# Theory Cotree

(* Author: Andreas Lochbihler, ETH Zurich
Author: Peter Gammie *)

section ‹A codatatype of infinite binary trees›

theory Cotree imports
Main
Applicative_Lifting.Applicative
"HOL-Library.BNF_Corec"
begin

(* Infinite binary trees: the only constructor is Node, carrying a label (root)
   and two subtrees (left, right); there is no leaf constructor.
   The enclosing context enables bnf_internals for this declaration only. *)
context notes [[bnf_internals]]
begin
codatatype 'a tree = Node (root: 'a) (left: "'a tree") (right: "'a tree")
end

(* Destruction rules for the relator rel_tree: related trees have related roots
   and related left and right subtrees. Each conclusion also gets its own name. *)
lemma rel_treeD:
assumes "rel_tree A x y"
shows rel_tree_rootD: "A (root x) (root y)"
and rel_tree_leftD: "rel_tree A (left x) (left y)"
and rel_tree_rightD: "rel_tree A (right x) (right y)"
using assms
by(cases x y rule: tree.exhaust[case_product tree.exhaust], simp_all)+

(* Declare the map/selector and map/composition facts of tree as default simp rules. *)
lemmas [simp] = tree.map_sel tree.map_comp

(* Induction over membership in set_tree, stated with selectors:
   case root  - the element is the root label of the tree;
   case left  - the element occurs in the left subtree and P holds there;
   case right - the element occurs in the right subtree and P holds there. *)
lemma set_tree_induct[consumes 1, case_names root left right]:
assumes x: "x ∈ set_tree t"
and root: "⋀t. P (root t) t"
and left: "⋀x t. ⟦ x ∈ set_tree (left t); P x (left t) ⟧ ⟹ P x t"
and right: "⋀x t. ⟦ x ∈ set_tree (right t); P x (right t) ⟧ ⟹ P x t"
shows "P x t"
using x
proof(rule tree.set_induct)
fix l x r
from root[of "Node x l r"] show "P x (Node x l r)" by simp
qed(auto intro: left right)

(* Congruence rule for corec_tree: the stop and continuation functions only need
   to agree where they are actually consulted, i.e. STOPL/STOPR where the
   corresponding stop predicate holds and LEFT/RIGHT where it does not.
   Both negated assumptions use the same symbol (¬) so the four premises read
   uniformly. *)
lemma corec_tree_cong:
assumes "⋀x. stopL x ⟹ STOPL x = STOPL' x"
and "⋀x. ¬ stopL x ⟹ LEFT x = LEFT' x"
and "⋀x. stopR x ⟹ STOPR x = STOPR' x"
and "⋀x. ¬ stopR x ⟹ RIGHT x = RIGHT' x"
shows "corec_tree ROOT stopL STOPL LEFT stopR STOPR RIGHT =
corec_tree ROOT stopL STOPL' LEFT' stopR STOPR' RIGHT'"
(is "?lhs = ?rhs")
proof
fix x
show "?lhs x = ?rhs x"
by(coinduction arbitrary: x rule: tree.coinduct_strong)(auto simp add: assms)
qed

(* Corecursive unfold from a seed a: g1 computes the node label, g22 the seed
   of the left subtree and g32 the seed of the right subtree. *)
context
fixes g1 :: "'a ⇒ 'b"
and g22 :: "'a ⇒ 'a"
and g32 :: "'a ⇒ 'a"
begin

corec unfold_tree :: "'a ⇒ 'b tree"
where "unfold_tree a = Node (g1 a) (unfold_tree (g22 a)) (unfold_tree (g32 a))"

(* Selector view of the defining equation, declared as simp rules. *)
lemma unfold_tree_simps [simp]:
"root (unfold_tree a) = g1 a"
"left (unfold_tree a) = unfold_tree (g22 a)"
"right (unfold_tree a) = unfold_tree (g32 a)"
by(subst unfold_tree.code; simp; fail)+

end

(* Uniqueness of unfold_tree: any function f satisfying the three selector
   equations for ROOT, LEFT and RIGHT coincides with unfold_tree ROOT LEFT RIGHT. *)
lemma unfold_tree_unique:
assumes "⋀s. root (f s) = ROOT s"
and "⋀s. left (f s) = f (LEFT s)"
and "⋀s. right (f s) = f (RIGHT s)"
shows "f s = unfold_tree ROOT LEFT RIGHT s"
by(rule unfold_tree.unique[THEN fun_cong])(auto simp add: fun_eq_iff assms intro: tree.expand)

subsection ‹Applicative functor for @{typ "'a tree"}›

(* The constant tree: every node carries the label x and both subtrees are
   again the same constant tree. *)
context fixes x :: "'a" begin
corec pure_tree :: "'a tree"
where "pure_tree = Node x pure_tree pure_tree"
end

(* The defining equation of pure_tree under a descriptive name. *)
lemmas pure_tree_unfold = pure_tree.code

(* Selector equations for pure_tree, declared as simp rules. *)
lemma pure_tree_simps [simp]:
"root (pure_tree x) = x"
"left (pure_tree x) = pure_tree x"
"right (pure_tree x) = pure_tree x"
by(subst pure_tree_unfold; simp; fail)+

(* Let the overloaded applicative constant pure resolve to pure_tree at tree
   type. The statements from here on are written with pure and need this
   resolution to be meaningful. *)
adhoc_overloading Applicative.pure pure_tree

(* Parametricity of pure_tree: labels related by A yield constant trees
   related by rel_tree A. Registered as a transfer rule. *)
lemma pure_tree_parametric [transfer_rule]: "(rel_fun A (rel_tree A)) pure pure"
by(rule rel_funI)(coinduction, auto)

(* Mapping f over the constant tree of x gives the constant tree of f x. *)
lemma map_pure_tree [simp]: "map_tree f (pure x) = pure (f x)"
by(coinduction arbitrary: x) auto

(* Uniqueness principle generated by corec for pure_tree, under a shorter name. *)
lemmas pure_tree_unique = pure_tree.unique

(* Pointwise application: the function at each position of f is applied to the
   argument at the same position of x. The (transfer) option asks primcorec to
   also derive a parametricity rule. *)
primcorec (transfer) ap_tree :: "('a ⇒ 'b) tree ⇒ 'a tree ⇒ 'b tree"
where
"root (ap_tree f x) = root f (root x)"
| "left (ap_tree f x) = ap_tree (left f) (left x)"
| "right (ap_tree f x) = ap_tree (right f) (right x)"

(* Let the overloaded applicative application resolve to ap_tree at tree type,
   so that the infix notation enabled below denotes ap_tree on trees. *)
adhoc_overloading Applicative.ap ap_tree

unbundle applicative_syntax

(* Constructor view: applying a constant function tree to a Node. *)
lemma ap_tree_pure_Node [simp]:
"pure f ⋄ Node x l r = Node (f x) (pure f ⋄ l) (pure f ⋄ r)"
by(rule tree.expand) auto

(* Constructor view: applying a Node of functions to a Node of arguments. *)
lemma ap_tree_Node_Node [simp]:
"Node f fl fr ⋄ Node x l r = Node (f x) (fl ⋄ l) (fr ⋄ r)"
by(rule tree.expand) auto

text ‹Applicative functor laws›

(* Applying the constant tree of f is the same as mapping f. *)
lemma map_tree_ap_tree_pure_tree:
"pure f ⋄ u = map_tree f u"
by(coinduction arbitrary: u) auto

(* Identity law: applying the constant identity tree changes nothing.
   Reduces to map_tree id t = t via map_tree_ap_tree_pure_tree. *)
lemma ap_tree_identity: "pure id ⋄ t = t"
by(simp add: map_tree_ap_tree_pure_tree tree.map_id)

(* Composition law of the applicative functor. *)
lemma ap_tree_composition:
"pure (∘) ⋄ r1 ⋄ r2 ⋄ r3 = r1 ⋄ (r2 ⋄ r3)"
by(coinduction arbitrary: r1 r2 r3) auto

(* Homomorphism law: applying a constant function tree to a constant argument
   tree gives the constant tree of the application.
   Rewrites the application into map_tree and then uses map_pure_tree. *)
lemma ap_tree_homomorphism:
"pure f ⋄ pure x = pure (f x)"
by(simp add: map_tree_ap_tree_pure_tree)

(* Interchange law of the applicative functor. *)
lemma ap_tree_interchange:
"t ⋄ pure x = pure (λf. f x) ⋄ t"
by(coinduction arbitrary: t) auto

(* K combinator: the second argument tree may be discarded. *)
lemma ap_tree_K_tree: "pure (λx y. x) ⋄ u ⋄ v = u"
by(coinduction arbitrary: u v)(auto)

(* C combinator: two argument trees may be swapped. *)
lemma ap_tree_C_tree: "pure (λf x y. f y x) ⋄ u ⋄ v ⋄ w = u ⋄ w ⋄ v"
by(coinduction arbitrary: u v w)(auto)

(* W combinator: an argument tree may be duplicated. *)
lemma ap_tree_W_tree: "pure (λf x. f x x) ⋄ f ⋄ x = f ⋄ x ⋄ x"
by(coinduction arbitrary: f x)(auto)

(* Register tree as an applicative functor with combinators K and W, giving
   pure_tree, ap_tree, rel_tree and set_tree as its operations.
   The explicit proof block shows that ap respects rel_tree when the function
   trees are related on the elements of the argument tree; the remaining
   obligations are closed by the laws proved above. *)
applicative tree (K, W) for
pure: pure_tree
ap: ap_tree
rel: rel_tree
set: set_tree
proof -
fix R :: "'b ⇒ 'c ⇒ bool" and f :: "('a ⇒ 'b) tree" and g x
assume [transfer_rule]: "rel_tree (rel_fun (eq_on (set_tree x)) R) f g"
have [transfer_rule]: "rel_tree (eq_on (set_tree x)) x x" by(rule tree.rel_refl_strong) simp
show "rel_tree R (f ⋄ x) (g ⋄ x)" by transfer_prover
qed(rule ap_tree_homomorphism ap_tree_composition[unfolded o_def[abs_def]] ap_tree_K_tree ap_tree_W_tree ap_tree_interchange pure_tree_parametric)+

(* Let applicative_unfold rewrite map_tree f u into pure f ⋄ u. *)
declare map_tree_ap_tree_pure_tree[symmetric, applicative_unfold]

(* Strong extensionality: two function trees are equal as soon as they agree on
   all constant argument trees. Proved by coinduction; each of the three
   selector obligations is obtained by unfolding ap_tree once in the
   hypothesis instantiated at x. *)
lemma ap_tree_strong_extensional:
"(⋀x. f ⋄ pure x = g ⋄ pure x) ⟹ f = g"
proof(coinduction arbitrary: f g)
case [rule_format]: (Eq_tree f g)
have "root f = root g"
proof
fix x
show "root f x = root g x"
using Eq_tree[of x] by(subst (asm) (1 2) ap_tree.ctr) simp
qed
moreover {
fix x
have "left f ⋄ pure x = left g ⋄ pure x"
using Eq_tree[of x] by(subst (asm) (1 2) ap_tree.ctr) simp
} moreover {
fix x
have "right f ⋄ pure x = right g ⋄ pure x"
using Eq_tree[of x] by(subst (asm) (1 2) ap_tree.ctr) simp
} ultimately show ?case by simp
qed

(* Extensionality: function trees that agree on all argument trees are equal.
   A direct consequence of the strong version above. *)
lemma ap_tree_extensional:
"(⋀x. f ⋄ x = g ⋄ x) ⟹ f = g"
by(rule ap_tree_strong_extensional) simp

subsection ‹Standard tree combinators›

subsubsection ‹Recurse combinator›

text ‹
This will be the main combinator to define trees recursively.

Uniqueness for this gives us the unique fixed-point theorem for guarded recursive definitions.
›
(* For the unfold whose seeds are functions (label F x, left seed F ∘ l, right
   seed F ∘ r), mapping G over the result equals unfolding from G ∘ F. *)
lemma map_unfold_tree [simp]: fixes l r x
defines "unf ≡ unfold_tree (λf. f x) (λf. f ∘ l) (λf. f ∘ r)"
shows "map_tree G (unf F) = unf (G ∘ F)"
by(coinduction arbitrary: F G)(auto 4 3 simp add: unf_def o_assoc)

(* Register map_tree, at the type-preserving instance ('a ⇒ 'a), as a friend
   for corec. The first subgoal is the stated equation, the second the
   parametricity obligation. *)
friend_of_corec map_tree :: "('a ⇒ 'a) ⇒ 'a tree ⇒ 'a tree" where
"map_tree f t = Node (f (root t)) (map_tree f (left t)) (map_tree f (right t))"
subgoal by (rule tree.expand; simp)
subgoal by (fold relator_eq; transfer_prover)
done

(* Recursion combinator: the tree with root x whose left and right subtrees are
   the whole tree mapped with l and with r respectively. The corecursive calls
   sit under map_tree, which the friend registration above permits. *)
context fixes l :: "'a ⇒ 'a" and r :: "'a ⇒ 'a" and x :: "'a" begin
corec tree_recurse :: "'a tree"
where "tree_recurse = Node x (map_tree l tree_recurse) (map_tree r tree_recurse)"
end

(* Selector equations for tree_recurse, declared as simp rules. *)
lemma tree_recurse_simps [simp]:
"root (tree_recurse l r x) = x"
"left (tree_recurse l r x) = map_tree l (tree_recurse l r x)"
"right (tree_recurse l r x) = map_tree r (tree_recurse l r x)"
by(subst tree_recurse.code; simp; fail)+

(* The defining equation with explicit parameters; not a simp rule, since it
   mentions tree_recurse on both sides. *)
lemma tree_recurse_unfold:
"tree_recurse l r x = Node x (map_tree l (tree_recurse l r x)) (map_tree r (tree_recurse l r x))"
by(fact tree_recurse.code)

(* Fusion law for tree_recurse: if h commutes with the step functions
   (h ∘ l = l' ∘ h and h ∘ r = r' ∘ h), then mapping h over the tree built
   from l, r, x gives the tree built from l', r', h x.
   Proof: the mapped tree satisfies the defining equation of the right-hand
   side, so uniqueness of tree_recurse applies; the three selector equations
   are closed by simp using the commutation assumptions. *)
lemma tree_recurse_fusion:
assumes "h ∘ l = l' ∘ h" and "h ∘ r = r' ∘ h"
shows "map_tree h (tree_recurse l r x) = tree_recurse l' r' (h x)"
by(rule tree_recurse.unique)(simp add: tree.expand assms)

subsubsection ‹Tree iteration›

(* Iteration combinator: the node for state s is labelled s; the left subtree
   continues from l s and the right subtree from r s. *)
context fixes l :: "'a ⇒ 'a" and r :: "'a ⇒ 'a" begin
primcorec tree_iterate :: " 'a ⇒ 'a tree"
where "tree_iterate s = Node s (tree_iterate (l s)) (tree_iterate (r s))"
end

(* unfold_tree factors into iterating the seed and then mapping the output
   function over the resulting tree of seeds. *)
lemma unfold_tree_tree_iterate:
"unfold_tree out l r = map_tree out ∘ tree_iterate l r"
by(rule ext)(rule unfold_tree_unique[symmetric]; simp)

(* Fusion law for tree_iterate: if h commutes with the step functions, mapping
   h over the iteration from x equals the iteration with the primed steps
   from h x. *)
lemma tree_iterate_fusion:
assumes "h ∘ l = l' ∘ h"
assumes "h ∘ r = r' ∘ h"
shows "map_tree h (tree_iterate l r x) = tree_iterate l' r' (h x)"
apply(coinduction arbitrary: x)
using assms by(auto simp add: fun_eq_iff)

subsubsection ‹Tree traversal›

(* A direction selects the left (L) or right (R) subtree; a path is a list of
   directions. *)
datatype dir = L | R
type_synonym path = "dir list"

(* traverse_tree path t is the subtree of t reached by following path; the
   head of the list is the first step taken from the root. *)
definition traverse_tree :: "path ⇒ 'a tree ⇒ 'a tree"
where "traverse_tree path ≡ foldr (λd f. f ∘ case_dir left right d) path id"

(* Recursion equations for traverse_tree: the empty path is the identity, and
   a path d # path first takes the step d and then follows path.
   Both follow by unfolding the definition in terms of foldr. *)
lemma traverse_tree_simps[simp]:
"traverse_tree [] = id"
"traverse_tree (d # path) = traverse_tree path ∘ (case d of L ⇒ left | R ⇒ right)"
by (simp_all add: traverse_tree_def)

(* Traversal commutes with map_tree. *)
lemma traverse_tree_map_tree [simp]:
"traverse_tree path (map_tree f t) = map_tree f (traverse_tree path t)"
by (induct path arbitrary: t) (simp_all split: dir.splits)

(* Following a concatenated path means following the first part and then the
   second part from the subtree reached. *)
lemma traverse_tree_append [simp]:
"traverse_tree (path @ ext) t = traverse_tree ext (traverse_tree path t)"
by (induct path arbitrary: t) simp_all

text‹@{const "traverse_tree"} is an applicative-functor homomorphism.›

(* Every subtree of a constant tree is the same constant tree. *)
lemma traverse_tree_pure_tree [simp]:
"traverse_tree path (pure x) = pure x"
by (induct path arbitrary: x) (simp_all split: dir.splits)

(* Traversal distributes over applicative application. *)
lemma traverse_tree_ap [simp]:
"traverse_tree path (f ⋄ x) = traverse_tree path f ⋄ traverse_tree path x"
by (induct path arbitrary: f x) (simp_all split: dir.splits)

(* Interpreting paths on states rather than on trees: direction L acts as l,
   direction R acts as r, and a path acts by folding these steps over the
   state, starting with the head of the list. *)
context fixes l r :: "'a ⇒ 'a" begin

primrec traverse_dir :: "dir ⇒ 'a ⇒ 'a"
where
"traverse_dir L = l"
| "traverse_dir R = r"

abbreviation traverse_path :: "path ⇒ 'a ⇒ 'a"
where "traverse_path ≡ fold traverse_dir"

end

(* The subtree of an iterated tree at a path is the iterated tree started from
   the state obtained by running the path on the initial state. *)
lemma traverse_tree_tree_iterate:
"traverse_tree path (tree_iterate l r s) =
tree_iterate l r (traverse_path l r path s)"
by (induct path arbitrary: s) (simp_all split: dir.splits)

text‹

\citeauthor{DBLP:journals/jfp/Hinze09} shows that if the tree
construction function is suitably monoidal then recursion and
iteration define the same tree.

›

(* If f is associative with two-sided unit ε, then recursing with the left
   multiplications (f l, f r) and iterating with the right multiplications
   (λx. f x l, λx. f x r) from ε define the same tree.
   Proof: the iterated tree satisfies the defining equation of tree_recurse;
   the selector equations follow from tree_iterate_fusion and the monoid laws. *)
lemma tree_recurse_iterate:
assumes monoid:
"⋀x y z. f (f x y) z = f x (f y z)"
"⋀x. f x ε = x"
"⋀x. f ε x = x"
shows "tree_recurse (f l) (f r) ε = tree_iterate (λx. f x l) (λx. f x r) ε"
apply(rule tree_recurse.unique[symmetric])
apply(rule tree.expand)
apply(simp add: tree_iterate_fusion[where r'="λx. f x r" and l'="λx. f x l"] fun_eq_iff monoid)
done

subsubsection ‹Mirroring›

(* Mirror image of a tree: the root label is kept and the left and right
   subtrees are swapped at every node. *)
primcorec mirror :: "'a tree ⇒ 'a tree"
where
"root (mirror t) = root t"
| "left (mirror t) = mirror (right t)"
| "right (mirror t) = mirror (left t)"

(* Constructor view of mirror. *)
lemma mirror_unfold: "mirror (Node x l r) = Node x (mirror r) (mirror l)"
by(rule tree.expand) simp

(* A constant tree is its own mirror image. *)
lemma mirror_pure: "mirror (pure x) = pure x"
by(coinduction rule: tree.coinduct) simp

(* Mirroring distributes over applicative application. *)
lemma mirror_ap_tree: "mirror (f ⋄ x) = mirror f ⋄ mirror x"
by(coinduction arbitrary: f x) auto

end


# Theory Cotree_Algebra

(*  Author: Andreas Lochbihler, ETH Zurich
Author: Joshua Schneider, ETH Zurich
*)

subsection ‹Pointwise arithmetic on infinite binary trees›

theory Cotree_Algebra
imports Cotree
begin

subsubsection ‹Constants and operators›

(* 0 on trees is the constant tree of 0. The definition is registered for
   applicative_unfold. *)
instantiation tree :: (zero) zero begin
definition [applicative_unfold]: "0 = pure_tree 0"
instance ..
end

(* 1 on trees is the constant tree of 1. *)
instantiation tree :: (one) one begin
definition [applicative_unfold]: "1 = pure_tree 1"
instance ..
end

(* Addition on trees is pointwise addition of the labels, expressed with the
   applicative operators. *)
instantiation tree :: (plus) plus begin
definition [applicative_unfold]: "plus x y = pure (+) ⋄ x ⋄ (y :: 'a tree)"
instance ..
end

(* Selector equations for pointwise addition, declared as simp rules.
   They follow by unfolding the definition of addition on trees and the
   selector equations of ap_tree and pure_tree. *)
lemma plus_tree_simps [simp]:
"root (t + t') = root t + root t'"
"left (t + t') = left t + left t'"
"right (t + t') = right t + right t'"
by(simp_all add: plus_tree_def)

(* Register addition on trees as a friend for corec. The first subgoal is the
   stated equation, the second the parametricity obligation. *)
friend_of_corec plus where "t + t' = Node (root t + root t') (left t + left t') (right t + right t')"
subgoal by(rule tree.expand; simp)
subgoal by transfer_prover
done

(* Subtraction on trees is pointwise subtraction of the labels. *)
instantiation tree :: (minus) minus begin
definition [applicative_unfold]: "minus x y = pure (-) ⋄ x ⋄ (y :: 'a tree)"
instance ..
end

(* Selector equations for pointwise subtraction, declared as simp rules.
   They follow by unfolding the definition of subtraction on trees. *)
lemma minus_tree_simps [simp]:
"root (t - t') = root t - root t'"
"left (t - t') = left t - left t'"
"right (t - t') = right t - right t'"
by(simp_all add: minus_tree_def)

(* Unary minus on trees is pointwise negation, defined as the partial
   application of the applicative operator to the constant tree of uminus. *)
instantiation tree :: (uminus) uminus begin
definition [applicative_unfold tree]: "uminus = ((⋄) (pure uminus) :: 'a tree ⇒ 'a tree)"
instance ..
end

(* Multiplication on trees is pointwise multiplication of the labels. *)
instantiation tree :: (times) times begin
definition [applicative_unfold]: "times x y = pure (*) ⋄ x ⋄ (y :: 'a tree)"
instance ..
end

(* Selector equations for pointwise multiplication, declared as simp rules.
   They follow by unfolding the definition of multiplication on trees. *)
lemma times_tree_simps [simp]:
"root (t * t') = root t * root t'"
"left (t * t') = left t * left t'"
"right (t * t') = right t * right t'"
by(simp_all add: times_tree_def)

(* Trees over a type of class dvd are again of class dvd; no proof obligations
   remain. *)
instance tree :: (Rings.dvd) Rings.dvd ..

(* Division and remainder on trees are the pointwise operations on the labels. *)
instantiation tree :: (modulo) modulo begin
definition [applicative_unfold]: "x div y = pure_tree (div) ⋄ x ⋄ (y :: 'a tree)"
definition [applicative_unfold]: "x mod y = pure_tree (mod) ⋄ x ⋄ (y :: 'a tree)"
instance ..
end

(* Selector equations for the pointwise remainder, declared as simp rules.
   They follow by unfolding the definition of mod on trees. *)
lemma mod_tree_simps [simp]:
"root (t mod t') = root t mod root t'"
"left (t mod t') = left t mod left t'"
"right (t mod t') = right t mod right t'"
by(simp_all add: modulo_tree_def)

subsubsection ‹Algebraic instances›

(* Pointwise lifting of the additive and multiplicative structure of the label
   type to trees. Every instance command carries its class header; the
   equational classes are discharged by applicative_lifting, which reduces
   each law on trees to the same law on labels. The semigroup instances come
   first because the monoid and group instances build on them. *)
instance tree :: (semigroup_mult) semigroup_mult
using mult.assoc by intro_classes applicative_lifting

instance tree :: (ab_semigroup_mult) ab_semigroup_mult
using mult.commute by intro_classes applicative_lifting

instance tree :: (semigroup_add) semigroup_add
using add.assoc by intro_classes applicative_lifting

instance tree :: (ab_semigroup_add) ab_semigroup_add
using add.commute by intro_classes applicative_lifting

instance tree :: (monoid_add) monoid_add
by intro_classes (applicative_lifting, simp)+

instance tree :: (comm_monoid_add) comm_monoid_add
by intro_classes (applicative_lifting, simp)

instance tree :: (comm_monoid_diff) comm_monoid_diff
by intro_classes (applicative_lifting, simp add: diff_diff_add)+

instance tree :: (monoid_mult) monoid_mult
by intro_classes (applicative_lifting, simp)+

instance tree :: (comm_monoid_mult) comm_monoid_mult
by intro_classes (applicative_lifting, simp)

(* Cancellation is an implication rather than an equation, so it is proved by
   coinduction: equal sums have equal roots and equal subtrees, and
   cancellation on the labels closes each position. *)
instance tree :: (cancel_semigroup_add) cancel_semigroup_add
proof
fix a b c :: "'a tree"
assume "a + b = a + c"
thus "b = c"
proof (coinduction arbitrary: a b c)
case (Eq_tree a b c)
hence "root (a + b) = root (a + c)"
"left (a + b) = left (a + c)"
"right (a + b) = right (a + c)"
by simp_all
thus ?case by (auto)
qed
next
fix a b c :: "'a tree"
assume "b + a = c + a"
thus "b = c"
proof (coinduction arbitrary: a b c)
case (Eq_tree a b c)
hence "root (b + a) = root (c + a)"
"left (b + a) = left (c + a)"
"right (b + a) = right (c + a)"
by simp_all
thus ?case by (auto)
qed
qed

instance tree :: (cancel_ab_semigroup_add) cancel_ab_semigroup_add
by intro_classes (applicative_lifting, simp add: diff_diff_eq)+

instance tree :: (cancel_comm_monoid_add) cancel_comm_monoid_add ..

instance tree :: (group_add) group_add
by intro_classes (applicative_lifting, simp)+

instance tree :: (ab_group_add) ab_group_add
by intro_classes (applicative_lifting, simp)+

(* Semiring structure on trees. The distributivity and zero laws are lifted
   pointwise; instances closed with ".." leave no proof obligations beyond the
   instances already established. *)
instance tree :: (semiring) semiring
by intro_classes (applicative_lifting, simp add: ring_distribs)+

instance tree :: (mult_zero) mult_zero
by intro_classes (applicative_lifting, simp)+

instance tree :: (semiring_0) semiring_0 ..

instance tree :: (semiring_0_cancel) semiring_0_cancel ..

instance tree :: (comm_semiring) comm_semiring
by intro_classes(rule distrib_right)

instance tree :: (comm_semiring_0) comm_semiring_0 ..

instance tree :: (comm_semiring_0_cancel) comm_semiring_0_cancel ..

(* pure_tree is injective: equal constant trees have equal roots, hence equal
   labels. *)
lemma pure_tree_inject[simp]: "pure_tree x = pure_tree y ⟷ x = y"
proof
assume "pure_tree x = pure_tree y"
hence "root (pure_tree x) = root (pure_tree y)" by simp
thus "x = y" by simp
qed simp

(* Instances with a unit element, rings and numerals. zero_neq_one is proved by
   unfolding 0 and 1 to constant trees; most of the remaining instances leave
   no proof obligations. *)
instance tree :: (zero_neq_one) zero_neq_one
by intro_classes (applicative_unfold tree)

instance tree :: (semiring_1) semiring_1 ..

instance tree :: (comm_semiring_1) comm_semiring_1 ..

instance tree :: (semiring_1_cancel) semiring_1_cancel ..

instance tree :: (comm_semiring_1_cancel) comm_semiring_1_cancel
by(intro_classes; applicative_lifting, rule right_diff_distrib')

instance tree :: (ring) ring ..

instance tree :: (comm_ring) comm_ring ..

instance tree :: (ring_1) ring_1 ..

instance tree :: (comm_ring_1) comm_ring_1 ..

instance tree :: (numeral) numeral ..

instance tree :: (neg_numeral) neg_numeral ..

instance tree :: (semiring_numeral) semiring_numeral ..

(* The embedding of a natural number into trees is the constant tree of its
   embedding into the label type. Proved by induction on n. *)
lemma of_nat_tree: "of_nat n = pure_tree (of_nat n)"
proof (induction n)
case 0 show ?case by (simp add: zero_tree_def)
next
case (Suc n)
have "1 + pure (of_nat n) = pure (1 + of_nat n)" by applicative_nf rule
with Suc.IH show ?case by simp
qed

(* Injectivity of of_nat on trees, obtained through of_nat_tree. *)
instance tree :: (semiring_char_0) semiring_char_0
by intro_classes (simp add: inj_on_def of_nat_tree)

(* A numeral on trees has that numeral as root and itself as both subtrees. *)
lemma numeral_tree_simps [simp]:
"root (numeral n) = numeral n"
"left (numeral n) = numeral n"
"right (numeral n) = numeral n"
by(induct n)(auto simp add: numeral.simps plus_tree_def one_tree_def)

(* Hence a numeral on trees is the constant tree of that numeral; registered
   for applicative_unfold. *)
lemma numeral_tree_conv_pure [applicative_unfold]: "numeral n = pure (numeral n)"
by(rule pure_tree_unique)(rule tree.expand; simp)

(* No proof obligations beyond the instances established above. *)
instance tree :: (ring_char_0) ring_char_0 ..

end


# Theory Stern_Brocot_Tree

(* Author: Peter Gammie
Author: Andreas Lochbihler, ETH Zurich *)

section ‹The Stern-Brocot Tree›

theory Stern_Brocot_Tree
imports
HOL.Rat
"HOL-Library.Sublist"
Cotree_Algebra
Applicative_Lifting.Stream_Algebra
begin

text‹
The Stern-Brocot tree is discussed at length by \citet[\S4.5]{GrahamKnuthPatashnik1994CM}.
In essence the tree enumerates the rational numbers in their lowest terms by constructing the
‹mediant› of two bounding fractions.
›

(* A fraction is a pair of naturals; it is not normalised and the second
   component may be 0. *)
type_synonym fraction = "nat × nat"

(* The mediant of two fractions adds the first components and adds the second
   components. *)
definition mediant :: "fraction × fraction ⇒ fraction"
where "mediant ≡ λ((a, c), (b, d)). (a + b, c + d)"

(* Iterative definition of the Stern-Brocot tree. The state is a pair (lb, ub)
   of bounding fractions: the node label is their mediant, the left subtree
   keeps lb and takes the mediant as new upper bound, the right subtree takes
   the mediant as new lower bound and keeps ub. The initial bounds are (0, 1)
   and (1, 0). *)
definition stern_brocot :: "fraction tree"
where
"stern_brocot = unfold_tree
(λ(lb, ub). mediant (lb, ub))
(λ(lb, ub). (lb, mediant (lb, ub)))
(λ(lb, ub). (mediant (lb, ub), ub))
((0, 1), (1, 0))"

text‹
This process is visualised in Figure~\ref{fig:stern-brocot-iterate}.
Intuitively, each node is labelled with the mediant of its rightmost and leftmost ancestors.

\begin{figure}
\centering
\begin{tikzpicture}[auto,thick,node distance=3cm,main node/.style={circle,draw,font=\sffamily\Large\bfseries}]
\node[main node] (0) at (0, 0) {$\frac{1}{1}$};
\node[main node] (1) at (-4, 1) {$\frac{0}{1}$};
\node[main node] (2) at (4, 1) {$\frac{1}{0}$};
\node[main node] (3) at (-2, -1) {$\frac{1}{2}$};
\node[main node] (4) at (2, -1) {$\frac{2}{1}$};
\node[main node] (5) at (-3, -2) {$\frac{1}{3}$};
\node[main node] (6) at (3, -2) {$\frac{3}{1}$};
\node[main node] (7) at (-1, -2) {$\frac{2}{3}$};
\node[main node] (8) at (1, -2) {$\frac{3}{2}$};
\node (9) at (-3.5, -3) {};
\node (10) at (-2.5, -3) {};
\node (11) at (-1.5, -3) {};
\node (12) at (-0.5, -3) {};
\node (13) at (0.5, -3) {};
\node (14) at (1.5, -3) {};
\node (15) at (2.5, -3) {};
\node (16) at (3.5, -3) {};
\path
(1) edge[dashed] (0)
(2) edge[dashed] (0)
(0) edge (3)
(0) edge (4)
(3) edge (5)
(3) edge (7)
(4) edge (6)
(4) edge (8)
(5) edge[dotted] (9)
(5) edge[dotted] (10)
(6) edge[dotted] (15)
(6) edge[dotted] (16)
(7) edge[dotted] (11)
(7) edge[dotted] (12)
(8) edge[dotted] (13)
(8) edge[dotted] (14);
\end{tikzpicture}
\label{fig:stern-brocot-iterate}
\caption{Constructing the Stern-Brocot tree iteratively.}
\end{figure}

Our ultimate goal is to show that the Stern-Brocot tree contains all rationals (in lowest terms),
and that each occurs exactly once in the tree. A proof is sketched in \citet[\S4.5]{GrahamKnuthPatashnik1994CM}.
›

subsection ‹Specification via a recursion equation›

text ‹
\citet{Hinze2009JFP} derives the following recurrence relation for the Stern-Brocot tree.
We will show in \S\ref{section:eq:rec:iterative} that his derivation is sound with respect to the
standard iterative definition of the tree shown above.
›

(* On fractions: succ adds the second component to the first (m/n becomes
   (m + n)/n), and recip swaps the two components. *)
abbreviation succ :: "fraction ⇒ fraction"
where "succ ≡ λ(m, n). (m + n, n)"

abbreviation recip :: "fraction ⇒ fraction"
where "recip ≡ λ(m, n). (n, m)"

(* Recursive specification of the Stern-Brocot tree: the root is (1, 1), the
   left subtree is the whole tree mapped with recip, succ, recip in turn, and
   the right subtree is the whole tree mapped with succ. It is written with
   map_tree, which is registered as a friend for corec. *)
corec stern_brocot_recurse :: "fraction tree"
where
"stern_brocot_recurse =
Node (1, 1)
(map_tree recip (map_tree succ (map_tree recip stern_brocot_recurse)))
(map_tree succ stern_brocot_recurse)"

text ‹Actually, we would like to write the specification below, but ‹(⋄)› cannot be registered as friendly due to varying type parameters.›
(* The defining equation restated with the applicative operators, obtained by
   rewriting each map_tree f into an application of pure f. *)
lemma stern_brocot_unfold:
"stern_brocot_recurse =
Node (1, 1)
(pure recip ⋄ (pure succ ⋄ (pure recip ⋄ stern_brocot_recurse)))
(pure succ ⋄ stern_brocot_recurse)"
by(fact stern_brocot_recurse.code[unfolded map_tree_ap_tree_pure_tree[symmetric]])

(* Selector equations for stern_brocot_recurse, declared as simp rules. *)
lemma stern_brocot_simps [simp]:
"root stern_brocot_recurse = (1, 1)"
"left stern_brocot_recurse = pure recip ⋄ (pure succ ⋄ (pure recip ⋄ stern_brocot_recurse))"
"right stern_brocot_recurse = pure succ ⋄ stern_brocot_recurse"
by (subst stern_brocot_unfold, simp)+

(* stern_brocot_recurse is an instance of the generic recursion combinator,
   with left step recip ∘ succ ∘ recip, right step succ and root (1, 1).
   Proved by uniqueness of tree_recurse. *)
lemma stern_brocot_conv:
"stern_brocot_recurse = tree_recurse (recip ∘ succ ∘ recip) succ (1, 1)"
apply(rule tree_recurse.unique)
apply(subst stern_brocot_unfold)
apply(rule conjI; applicative_nf; simp)
done

subsection ‹Basic properties›

text ‹
The recursive definition is useful for showing some basic properties of the tree,
such as that the pairs of numbers at each node are coprime, and have non-zero denominators.
Both are simple inductions on the path.
›

(* Both components of every node label are positive; in particular the
   denominator is non-zero. Induction on the path. *)
lemma stern_brocot_denominator_non_zero:
"case root (traverse_tree path stern_brocot_recurse) of (m, n) ⇒ m > 0 ∧ n > 0"
by(induct path)(auto split: dir.splits)

(* The two components of every node label are coprime. Induction on the path. *)
lemma stern_brocot_coprime:
"case root (traverse_tree path stern_brocot_recurse) of (m, n) ⇒ coprime m n"
by (induct path) (auto split: dir.splits simp add: coprime_iff_gcd_eq_1, metis gcd.commute gcd_add1)

subsection ‹All the rationals›

text‹
For every pair of positive naturals, we can construct a path into the Stern-Brocot tree such that the naturals at the end of the path define the same rational as the pair we started with.
Intuitively, the choices made by Euclid's algorithm define this path.
›

(* Path construction by repeated subtraction on two positive naturals:
   equal arguments give the empty path; if the first is smaller, step L and
   subtract it from the second; if the first is larger, step R and subtract
   the second from it. The result is left undefined when either argument is 0. *)
function mk_path :: "nat ⇒ nat ⇒ path" where
"m = n ⟹ mk_path (Suc m) (Suc n) = []"
| "m < n ⟹ mk_path (Suc m) (Suc n) = L # mk_path (Suc m) (n - m)"
| "m > n ⟹ mk_path (Suc m) (Suc n) = R # mk_path (m - n) (Suc n)"
| "mk_path 0 _ = undefined"
| "mk_path _ 0 = undefined"
by atomize_elim(auto, arith)
termination mk_path by lexicographic_order

(* The induction rule of mk_path with names for its first three cases. *)
lemmas mk_path_induct[case_names equal less greater] = mk_path.induct

(* Interpret a pair of naturals as the rational with the first component as
   numerator and the second as denominator. *)
abbreviation rat_of :: "fraction ⇒ rat"
where "rat_of ≡ λ(x, y). Fract (int x) (int y)"

(* Every positive rational occurs in the tree: for positive m and n, the node
   reached by mk_path m n denotes the rational m / n.
   Induction along mk_path. The less and greater cases use the induction
   hypothesis together with positivity of the components at the intermediate
   node. The closing step handles the equal case (the root (1, 1) against
   Fract (Suc m) (Suc m)) and the two cases with a zero argument, which
   contradict the positivity assumptions. *)
theorem stern_brocot_rationals:
"⟦ m > 0; n > 0 ⟧ ⟹
root (traverse_tree (mk_path m n) (pure rat_of ⋄ stern_brocot_recurse)) = Fract (int m) (int n)"
proof(induction m n rule: mk_path_induct)
case (less m n)
with stern_brocot_denominator_non_zero[where path="mk_path (Suc m) (n - m)"]
show ?case
by (simp add: eq_rat field_simps of_nat_diff split: prod.split_asm)
next
case (greater m n)
with stern_brocot_denominator_non_zero[where path="mk_path (m - n) (Suc n)"]
show ?case
by (simp add: eq_rat field_simps of_nat_diff split: prod.split_asm)
qed (simp_all add: eq_rat)

subsection ‹No repetitions›

text ‹
We establish that the Stern-Brocot tree does not contain repetitions, i.e.,
that each rational number appears at most once in it.
Note that this property is stronger than merely requiring that pairs of naturals not be repeated,
though it is implied by that property and @{thm [source] "stern_brocot_coprime"}.

Intuitively, the tree enjoys the \emph{binary search tree} ordering property when we map our
pairs of naturals into rationals. This suffices to show that each rational appears at most once
in the tree. To establish this seems to require more structure than is present in the recursion
equations, and so we follow \citet{BackhouseFerreira2008MPC} and \citet{Hinze2009JFP} by
introducing another definition of the tree, which summarises the path to each node using a matrix.

We then derive an iterative version and use invariant reasoning on that.
We begin by defining some matrix machinery.
This is all elementary and primitive (we do not need much algebra).
›

(* A 2x2 matrix of naturals is stored as a pair of columns ((a, c), (b, d));
   a vector is a single column. *)
type_synonym matrix = "fraction × fraction"
type_synonym vector = fraction

(* Matrix product in the column representation. *)
definition times_matrix :: "matrix ⇒ matrix ⇒ matrix" (infixl "⊗" 70)
where "times_matrix = (λ((a, c), (b, d)) ((a', c'), (b', d')).
((a * a' + b * c', c * a' + d * c'),
(a * b' + b * d', c * b' + d * d')))"

(* Matrix-vector product in the column representation. *)
definition times_vector :: "matrix ⇒ vector ⇒ vector" (infixr "⊙" 70)
where "times_vector = (λ((a, c), (b, d)) (a', c'). (a * a' + b * c', c * a' + d * c'))"

context begin

(* Fixed matrices in the column representation: F has the columns (0, 1) and
   (1, 0), I is the identity, LL has its off-diagonal 1 in the lower left and
   UR in the upper right. *)
private definition F :: matrix where "F = ((0, 1), (1, 0))"
private definition I :: matrix where "I = ((1, 0), (0, 1))"
private definition LL :: matrix where "LL = ((1, 1), (0, 1))"
private definition UR :: matrix where "UR = ((1, 0), (1, 1))"

(* Determinant computed on naturals, so the subtraction is truncated at 0. *)
definition Det :: "matrix ⇒ nat" where "Det ≡ λ((a, c), (b, d)). a * d - b * c"

(* I, LL and UR have determinant 1. *)
lemma Dets [iff]:
"Det I = 1"
"Det LL = 1"
"Det UR = 1"
unfolding Det_def I_def LL_def UR_def by simp_all

(* Multiplying by LL or UR on either side preserves determinant 1. *)
lemma LL_UR_Det:
"Det m = 1 ⟹ Det (m ⊗ LL) = 1"
"Det m = 1 ⟹ Det (LL ⊗ m) = 1"
"Det m = 1 ⟹ Det (m ⊗ UR) = 1"
"Det m = 1 ⟹ Det (UR ⊗ m) = 1"
by (cases m, simp add: Det_def LL_def UR_def times_matrix_def split_def field_simps)+

(* Both F and I have mediant (1, 1). *)
lemma mediant_I_F [simp]:
"mediant F = (1, 1)"
"mediant I = (1, 1)"
by (simp_all add: F_def I_def mediant_def)

(* I is a two-sided unit for the matrix product. *)
lemma times_matrix_I [simp]:
"I ⊗ x = x"
"x ⊗ I = x"
by (simp_all add: times_matrix_def I_def split_def)

(* The matrix product is associative; as a simp rule this reassociates
   products to the right. *)
lemma times_matrix_assoc [simp]:
"(x ⊗ y) ⊗ z = x ⊗ (y ⊗ z)"
by (simp add: times_matrix_def field_simps split_def)

(* Right multiplication by LL or UR keeps the second component of the mediant
   positive. *)
lemma LL_UR_pos:
"0 < snd (mediant m) ⟹ 0 < snd (mediant (m ⊗ LL))"
"0 < snd (mediant m) ⟹ 0 < snd (mediant (m ⊗ UR))"
by (cases m) (simp_all add: LL_def UR_def times_matrix_def split_def field_simps mediant_def)

(* Closed form of the left-subtree step of the recursive specification. *)
lemma recip_succ_recip: "recip ∘ succ ∘ recip = (λ(x, y). (x, x + y))"
by (clarsimp simp: fun_eq_iff)

text ‹
\citeauthor{BackhouseFerreira2008MPC} work with the identity matrix @{const "I"} at the root.
This has the advantage that all relevant matrices have determinants of @{term "1 :: nat"}.
›

(* Tree of matrices: each left step multiplies the current matrix by LL on the
   right, each right step multiplies it by UR on the right. *)
definition stern_brocot_iterate_aux :: "matrix ⇒ matrix tree"
where "stern_brocot_iterate_aux ≡ tree_iterate (λs. s ⊗ LL) (λs. s ⊗ UR)"

(* Iterative, matrix-based version of the tree: start from I and label each
   node with the mediant of its matrix. *)
definition stern_brocot_iterate :: "fraction tree"
where "stern_brocot_iterate ≡ map_tree mediant (stern_brocot_iterate_aux I)"

(* The recursive specification and the matrix-based iteration define the same
   tree. Three steps:
   1. by tree_recurse_iterate, instantiated with the matrix product and its
      unit I, the iteration equals the recursion over matrices (the monoid
      premises are closed by the simp rules for the matrix product);
   2. by tree_recurse_fusion, taking mediants turns the recursion over
      matrices into a recursion over vectors starting at (1, 1);
   3. that recursion is stern_brocot_recurse, by stern_brocot_conv. *)
lemma stern_brocot_recurse_iterate: "stern_brocot_recurse = stern_brocot_iterate" (is "?lhs = ?rhs")
proof -
have "?rhs = map_tree mediant (tree_recurse ((⊗) LL) ((⊗) UR) I)"
using tree_recurse_iterate[where f="(⊗)" and l="LL" and r="UR" and ε="I"]
by (simp add: stern_brocot_iterate_def stern_brocot_iterate_aux_def)
also have "… = tree_recurse ((⊙) LL) ((⊙) UR) (1, 1)"
unfolding mediant_I_F(2)[symmetric]
by (rule tree_recurse_fusion)(simp_all add: fun_eq_iff mediant_def times_matrix_def times_vector_def LL_def UR_def)[2]
also have "… = ?lhs"
by (simp add: stern_brocot_conv recip_succ_recip times_vector_def LL_def UR_def)
finally show ?thesis by simp
qed

text‹
The following are the key ordering properties derived by \citet{BackhouseFerreira2008MPC}.
They hinge on the matrices containing only natural numbers.
›

(* Ordering property for left descendants: for matrices X and Y of determinant
   1, with the mediant of X having a positive second component, the rational
   denoted by the mediant of X ⊗ LL ⊗ Y is strictly smaller than the one
   denoted by the mediant of X.
   F records that the second component of the descendant's mediant is
   positive. The auxiliary inequality * is the determinant condition on X
   scaled by y12 + y22; it is proved by unfolding Det on the explicit
   components and a case distinction on y12. *)
lemma tree_ordering_left:
assumes DX: "Det X = 1"
assumes DY: "Det Y = 1"
assumes MX: "0 < snd (mediant X)"
shows "rat_of (mediant (X ⊗ LL ⊗ Y)) < rat_of (mediant X)"
proof -
from DX DY have F: "0 < snd (mediant (X ⊗ LL ⊗ Y))"
by (auto simp: Det_def times_matrix_def LL_def split_def mediant_def)
obtain x11 x12 x21 x22 where X: "X = ((x11, x12), (x21, x22))" by(cases X) auto
obtain y11 y12 y21 y22 where Y: "Y = ((y11, y12), (y21, y22))" by(cases Y) auto
from DX DY have *: "(x12 * x21) * (y12 + y22) < (x11 * x22) * (y12 + y22)"
by(simp add: X Y Det_def)(cases y12, simp_all add: field_simps)
from DX DY MX F show ?thesis
apply (simp add: split_def X Y of_nat_mult [symmetric] del: of_nat_mult)
apply (clarsimp simp: Det_def times_matrix_def LL_def UR_def mediant_def split_def)
using * by (simp add: field_simps)
qed

(* Ordering property for right descendants: under the same assumptions as for
   the left case, the rational denoted by the mediant of X is strictly smaller
   than the one denoted by the mediant of X ⊗ UR ⊗ Y. *)
lemma tree_ordering_right:
assumes DX: "Det X = 1"
assumes DY: "Det Y = 1"
assumes MX: "0 < snd (mediant X)"
shows "rat_of (mediant X) < rat_of (mediant (X ⊗ UR ⊗ Y))"
proof -
from DX DY have F: "0 < snd (mediant (X ⊗ UR ⊗ Y))"
by (auto simp: Det_def times_matrix_def UR_def split_def mediant_def)
obtain x11 x12 x21 x22 where X: "X = ((x11, x12), (x21, x22))" by(cases X) auto
obtain y11 y12 y21 y22 where Y: "Y = ((y11, y12), (y21, y22))" by(cases Y) auto
show ?thesis using DX DY MX F
apply (simp add: X Y split_def of_nat_mult [symmetric] del: of_nat_mult)
apply (simp add: Det_def times_matrix_def LL_def UR_def mediant_def split_def algebra_simps)
apply (cases y21; simp)
done
qed

(* Invariants of the matrix tree: starting from a matrix of determinant 1 whose
   mediant has a positive second component, every node reached has both
   properties. Induction on the path with the start matrix generalised. *)
lemma stern_brocot_iterate_aux_Det:
assumes "Det m = 1" "0 < snd (mediant m)"
shows "Det (root (traverse_tree path (stern_brocot_iterate_aux m))) = 1"
and "0 < snd (mediant (root (traverse_tree path (stern_brocot_iterate_aux m))))"
using assms
by (induct path arbitrary: m)
(simp_all add: stern_brocot_iterate_aux_def LL_UR_Det LL_UR_pos split: dir.splits)

(* Every matrix reached from m along a path factors as m ⊗ m'' for some m'' of
   determinant 1. The witness is I for the empty path; for a longer path the
   induction hypothesis is used at m ⊗ UR and at m ⊗ LL. *)
lemma stern_brocot_iterate_aux_decompose:
"∃m''. m ⊗ m'' = root (traverse_tree path (stern_brocot_iterate_aux m)) ∧ Det m'' = 1"
proof(induction path arbitrary: m)
case Nil show ?case
by (auto simp add: stern_brocot_iterate_aux_def intro: exI[where x=I] simp del: split_paired_Ex)
next
case (Cons d ds m)
from Cons.IH[where m="m ⊗ UR"] Cons.IH[where m="m ⊗ LL"] show ?case
by(simp add: stern_brocot_iterate_aux_def split: dir.splits del: split_paired_Ex)(fastforce simp: LL_UR_Det)
qed

(* A node and one of its proper descendants never carry the same label.
   With m the matrix at path and m' the matrix at path', the first step d
   below path puts m' strictly on one side of m in the order on rationals (by
   the two tree ordering lemmas), which contradicts equality of the labels. *)
lemma stern_brocot_fractions_not_repeated_strict_prefix:
assumes "root (traverse_tree path stern_brocot_iterate) = root (traverse_tree path' stern_brocot_iterate)"
assumes pp': "strict_prefix path path'"
shows False
proof -
from pp' obtain d ds where pp': "path' = path @ [d] @ ds" by (auto elim!: strict_prefixE')
define m where "m = root (traverse_tree path (stern_brocot_iterate_aux I))"
then have Dm: "Det m = 1" and Pm: "0 < snd (mediant m)"
using stern_brocot_iterate_aux_Det[where path="path" and m="I"] by simp_all
define m' where "m' = root (traverse_tree path' (stern_brocot_iterate_aux I))"
then have Dm': "Det m' = 1"
using stern_brocot_iterate_aux_Det[where path=path' and m="I"] by simp
let ?M = "case d of L ⇒ m ⊗ LL | R ⇒ m ⊗ UR"
from pp' have "root (traverse_tree ds (stern_brocot_iterate_aux ?M)) = m'"
by(simp add: m_def m'_def stern_brocot_iterate_aux_def traverse_tree_tree_iterate split: dir.splits)
then obtain m'' where mm'm'': "?M ⊗ m''= m'" and Dm'': "Det m'' = 1"
using stern_brocot_iterate_aux_decompose[where path="ds" and m="?M"] by clarsimp
hence "case d of L ⇒ rat_of (mediant m') < rat_of (mediant m) | R ⇒ rat_of (mediant m) < rat_of (mediant m')"
using tree_ordering_left[OF Dm Dm'' Pm] tree_ordering_right[OF Dm Dm'' Pm]
by (simp split: dir.splits)
with assms show False
by (simp add: stern_brocot_iterate_def m_def m'_def split: dir.splits)
qed

(* Two nodes whose paths share the prefix pref and then diverge (d ≠ d') never
   carry the same label. With m the matrix at pref, pm the matrix at path and
   p'm the matrix at path', each of pm and p'm lies strictly on the side of m
   given by its first step after pref; since the steps differ, the two sides
   are opposite, contradicting equality of the labels.
   NOTE(review): the ordering facts for p'm are also instantiated with
   [unfolded pm'] rather than p'm' - confirm this is intended. *)
lemma stern_brocot_fractions_not_repeated_parallel:
assumes "root (traverse_tree path stern_brocot_iterate) = root (traverse_tree path' stern_brocot_iterate)"
assumes p: "path = pref @ d # ds"
assumes p': "path' = pref @ d' # ds'"
assumes dd': "d ≠ d'"
shows False
proof -
define m where "m = root (traverse_tree pref (stern_brocot_iterate_aux I))"
then have Dm: "Det m = 1" and Pm: "0 < snd (mediant m)"
using stern_brocot_iterate_aux_Det[where path="pref" and m="I"] by simp_all
define pm where "pm = root (traverse_tree path (stern_brocot_iterate_aux I))"
then have Dpm: "Det pm = 1"
using stern_brocot_iterate_aux_Det[where path=path and m="I"] by simp
let ?M = "case d of L ⇒ m ⊗ LL | R ⇒ m ⊗ UR"
from p
have "root (traverse_tree ds (stern_brocot_iterate_aux ?M)) = pm"
by(simp add: stern_brocot_iterate_aux_def m_def pm_def traverse_tree_tree_iterate split: dir.splits)
then obtain pm'
where pm': "?M ⊗ pm'= pm" and Dpm': "Det pm' = 1"
using stern_brocot_iterate_aux_decompose[where path="ds" and m="?M"] by clarsimp
hence "case d of L ⇒ rat_of (mediant pm) < rat_of (mediant m) | R ⇒ rat_of (mediant m) < rat_of (mediant pm)"
using tree_ordering_left[OF Dm Dpm' Pm, unfolded pm']
tree_ordering_right[OF Dm Dpm' Pm, unfolded pm']
by (simp split: dir.splits)
moreover
define p'm where "p'm = root (traverse_tree path' (stern_brocot_iterate_aux I))"
then have Dp'm: "Det p'm = 1"
using stern_brocot_iterate_aux_Det[where path=path' and m="I"] by simp
let ?M' = "case d' of L ⇒ m ⊗ LL | R ⇒ m ⊗ UR"
from p'
have "root (traverse_tree ds' (stern_brocot_iterate_aux ?M')) = p'm"
by(simp add: stern_brocot_iterate_aux_def m_def p'm_def traverse_tree_tree_iterate split: dir.splits)
then obtain p'm'
where p'm': "?M' ⊗ p'm' = p'm" and Dp'm': "Det p'm' = 1"
using stern_brocot_iterate_aux_decompose[where path="ds'" and m="?M'"] by clarsimp
hence "case d' of L ⇒ rat_of (mediant p'm) < rat_of (mediant m) | R ⇒ rat_of (mediant m) < rat_of (mediant p'm)"
using tree_ordering_left[OF Dm Dp'm' Pm, unfolded pm']
tree_ordering_right[OF Dm Dp'm' Pm, unfolded pm']
by (simp split: dir.splits)
ultimately show False using pm' p'm' assms
by(simp add: m_def pm_def p'm_def stern_brocot_iterate_def split: dir.splits)
qed

(* Case analysis for two distinct lists: one is a strict prefix of the other,
   or they share a common prefix after which their next elements differ. *)
lemma lists_not_eq:
assumes "xs ≠ ys"
obtains
(c1) "strict_prefix xs ys"
| (c2) "strict_prefix ys xs"
| (c3) ps x y xs' ys'
where "xs = ps @ x # xs'" and "ys = ps @ y # ys'" and "x ≠ y"
using assms
by (cases xs ys rule: prefix_cases)
(blast dest: parallel_decomp prefix_order.neq_le_trans)+

(* No label is repeated in the iterative tree: equal labels imply equal paths.
   By contradiction, using the case analysis on distinct lists together with
   the strict-prefix and parallel lemmas. *)
lemma stern_brocot_fractions_not_repeated:
assumes "root (traverse_tree path stern_brocot_iterate) = root (traverse_tree path' stern_brocot_iterate)"
shows "path = path'"
proof(rule ccontr)
assume "path ≠ path'"
then show False using assms
by (cases path path' rule: lists_not_eq)
(blast intro: stern_brocot_fractions_not_repeated_strict_prefix sym
stern_brocot_fractions_not_repeated_parallel)+
qed

text ‹ The function @{const Fract} is injective under certain conditions. ›

(* Fract is injective on coprime pairs with positive denominators.
   From the equality of the fractions we get the cross-multiplication
   a * d = c * b, hence a and c have the same sign. The crossproduct
   characterisation for coprime integers then gives equal absolute values,
   and abs_sgn turns that into equality of a and c and of b and d. *)
lemma rat_inv_eq:
assumes "Fract a b = Fract c d"
assumes "b > 0"
assumes "d > 0"
assumes "coprime a b"
assumes "coprime c d"
shows "a = c ∧ b = d"
proof -
from ‹b > 0› ‹d > 0› ‹Fract a b = Fract c d›
have *: "a * d = c * b" by (simp add: eq_rat)
from arg_cong[where f=sgn, OF this] ‹b > 0› ‹d > 0›
have "sgn a = sgn c" by (simp add: sgn_mult)
with * show ?thesis
using ‹b > 0› ‹d > 0› coprime_crossproduct_int[OF ‹coprime a b› ‹coprime c d›]
by (simp add: abs_sgn)
qed

(* Two paths that reach the same rational (rat_of lifted over
   stern_brocot_recurse) are the same path. The proof feeds the coprimality
   and denominator facts for both paths into rat_inv_eq and concludes with
   stern_brocot_fractions_not_repeated, after rewriting the recursive tree
   with stern_brocot_recurse_iterate[symmetric]. *)
theorem stern_brocot_rationals_not_repeated:
assumes "root (traverse_tree path (pure rat_of ⋄ stern_brocot_recurse))
= root (traverse_tree path' (pure rat_of ⋄ stern_brocot_recurse))"
shows "path = path'"
using assms
using stern_brocot_coprime[where path=path]
stern_brocot_coprime[where path=path']
stern_brocot_denominator_non_zero[where path=path]
stern_brocot_denominator_non_zero[where path=path']
by(auto simp: gcd_int_def dest!: rat_inv_eq intro: stern_brocot_fractions_not_repeated simp add: stern_brocot_recurse_iterate[symmetric] split: prod.splits)

subsection ‹Equivalence of recursive and iterative version \label{section:eq:rec:iterative}›

text ‹
\citeauthor{Hinze2009JFP} shows that it does not matter whether we use @{const I} or
@{const "F"} at the root provided we swap the left and right matrices too.
›

(* Variant of the iterative tree: mediant mapped over the tree iterated from
   F, with right-multiplication by UR as the first step function and by LL as
   the second (the swapped order is used in stern_brocot_iterate below). *)
definition stern_brocot_Hinze_iterate :: "fraction tree"
where "stern_brocot_Hinze_iterate = map_tree mediant (tree_iterate (λs. s ⊗ UR) (λs. s ⊗ LL) F)"

(* Right-multiplying a matrix by F before taking the mediant does not change
   the mediant. Proved by unfolding the definitions involved. *)
lemma mediant_times_F: "mediant ∘ (λs. s ⊗ F) = mediant"
  by(simp add: times_matrix_def F_def mediant_def split_def o_def add.commute)

(* stern_brocot coincides with the iterative tree. The calculation goes
   through three steps:
   1. stern_brocot equals stern_brocot_Hinze_iterate (unfold_tree_tree_iterate);
   2. the Hinze tree is the I-rooted iteration with LL/UR swapped, post-composed
      with multiplication by F (tree_iterate_fusion);
   3. the extra multiplication by F disappears under mediant (mediant_times_F). *)
lemma stern_brocot_iterate: "stern_brocot = stern_brocot_iterate"
proof -
have "stern_brocot = stern_brocot_Hinze_iterate"
unfolding stern_brocot_def stern_brocot_Hinze_iterate_def
by(subst unfold_tree_tree_iterate)(simp add: F_def times_matrix_def mediant_def UR_def LL_def split_def)
also have "… = map_tree mediant (map_tree (λs. s ⊗ F) (tree_iterate (λs. s ⊗ LL) (λs. s ⊗ UR) I))"
unfolding stern_brocot_Hinze_iterate_def
by(subst tree_iterate_fusion[where l'="λs. s ⊗ UR" and r'="λs. s ⊗ LL"])
(simp_all add: fun_eq_iff times_matrix_def UR_def LL_def F_def I_def)
also have "… = stern_brocot_iterate"
by(simp only: tree.map_comp mediant_times_F stern_brocot_iterate_def stern_brocot_iterate_aux_def)
finally show ?thesis .
qed

(* The mediant-based tree equals the recursive tree: both are equal to
   stern_brocot_iterate, by stern_brocot_iterate and
   stern_brocot_recurse_iterate respectively. *)
theorem stern_brocot_mediant_recurse: "stern_brocot = stern_brocot_recurse"
  by(simp add: stern_brocot_recurse_iterate stern_brocot_iterate)

end

(* Remove the infix syntax for times_matrix and times_vector from here on. *)
no_notation times_matrix (infixl "⊗" 70)
and times_vector (infixl "⊙" 70)

section ‹Linearising the Stern-Brocot Tree›

subsection ‹Turning a tree into a stream›

(* tree_chop t removes the root of t: the result has the root of t's left
   subtree as its root, t's right subtree as its left subtree, and the
   (corecursively) chopped left subtree of t as its right subtree. *)
corec tree_chop :: "'a tree ⇒ 'a tree"
where "tree_chop t = Node (root (left t)) (right t) (tree_chop (left t))"

(* Selector equations for tree_chop, each obtained by one unfolding of
   tree_chop.code. *)
lemma tree_chop_sel [simp]:
"root (tree_chop t) = root (left t)"
"left (tree_chop t) = right t"
"right (tree_chop t) = tree_chop (left t)"
by(subst tree_chop.code; simp; fail)+

text ‹@{const tree_chop} is an idiom homomorphism›

(* tree_chop leaves a constant tree unchanged; proved by strong coinduction. *)
lemma tree_chop_pure_tree [simp]:
"tree_chop (pure x) = pure x"
by(coinduction rule: tree.coinduct_strong) auto

(* tree_chop distributes over idiomatic application; proved by strong
   coinduction generalising over both trees. *)
lemma tree_chop_ap_tree [simp]:
"tree_chop (f ⋄ x) = tree_chop f ⋄ tree_chop x"
by(coinduction arbitrary: f x rule: tree.coinduct_strong) auto

(* tree_chop distributes over addition of trees. After unfolding the
   definition of + on trees, the [simp] homomorphism lemmas
   tree_chop_pure_tree and tree_chop_ap_tree finish the proof. *)
lemma tree_chop_plus: "tree_chop (t + t') = tree_chop t + tree_chop t'"
  by(simp add: plus_tree_def)

(* Linearises a tree into a stream: the head is the root of t and the tail is
   the stream of tree_chop t. *)
corec stream :: "'a tree ⇒ 'a stream"
where "stream t = root t ## stream (tree_chop t)"

(* Head and tail of stream t, each obtained by one unfolding of stream.code. *)
lemma stream_sel [simp]:
"shd (stream t) = root t"
"stl (stream t) = stream (tree_chop t)"
by(subst stream.code; simp; fail)+

text‹@{const "stream"} is an idiom homomorphism.›

(* stream maps a constant tree to the constant stream; by coinduction. *)
lemma stream_pure [simp]: "stream (pure x) = pure x"
by coinduction auto

(* stream distributes over idiomatic application; by coinduction generalising
   over both trees. *)
lemma stream_ap [simp]: "stream (f ⋄ x) = stream f ⋄ stream x"
by(coinduction arbitrary: f x) auto

(* stream commutes with addition: unfold + on both streams and trees and use
   the [simp] homomorphism lemmas stream_pure and stream_ap. *)
lemma stream_plus [simp]: "stream (t + t') = stream t + stream t'"
  by(simp add: plus_stream_def plus_tree_def)

(* stream commutes with subtraction: unfold - on both streams and trees and
   use the [simp] homomorphism lemmas stream_pure and stream_ap. *)
lemma stream_minus [simp]: "stream (t - t') = stream t - stream t'"
  by(simp add: minus_stream_def minus_tree_def)

(* stream commutes with multiplication: unfold * on both streams and trees
   and use the [simp] homomorphism lemmas stream_pure and stream_ap. *)
lemma stream_times [simp]: "stream (t * t') = stream t * stream t'"
  by(simp add: times_stream_def times_tree_def)

(* stream commutes with mod: unfold mod on both streams and trees and use the
   [simp] homomorphism lemmas stream_pure and stream_ap. *)
lemma stream_mod [simp]: "stream (t mod t') = stream t mod stream t'"
  by(simp add: modulo_stream_def modulo_tree_def)

(* stream maps the tree 1 to the stream 1: unfold 1 on both sides and use the
   [simp] lemma stream_pure. *)
lemma stream_1 [simp]: "stream 1 = 1"
  by(simp add: one_tree_def one_stream_def)

(* stream preserves numerals; by induction on the numeral, rewriting only with
   numeral.simps, stream_plus and stream_1. *)
lemma stream_numeral [simp]: "stream (numeral n) = numeral n"
by(induct n)(simp_all only: numeral.simps stream_plus stream_1)

subsection ‹Split the Stern-Brocot tree into numerators and denominators›

(* One corecursive definition for two trees, selected by the boolean argument
   (True is abbreviated as num, False as den below). Both trees have root 1.
   For True the subtrees are  num  and  num + den;
   for False they are  num + den  and  den. *)
corec num_den :: "bool ⇒ nat tree"
where
"num_den x =
Node 1
(if x then num_den True else num_den True + num_den False)
(if x then num_den True + num_den False else num_den False)"

(* Shorthands for the two instances of num_den. *)
abbreviation num where "num ≡ num_den True"
abbreviation den where "den ≡ num_den False"

(* One-step unfolding of num, from num_den.code. *)
lemma num_unfold: "num = Node 1 num (num + den)"
by(subst num_den.code; simp)

(* One-step unfolding of den, from num_den.code. *)
lemma den_unfold: "den = Node 1 (num + den) den"
by(subst num_den.code; simp)

(* Selector equations for num, read off from num_unfold. *)
lemma num_simps [simp]:
"root num = 1"
"left num = num"
"right num = num + den"
by(subst num_unfold, simp)+

(* Selector equations for den, read off from den_unfold. *)
lemma den_simps [simp]:
"root den = 1"
"left den = num + den"
"right den = den"
by (subst den_unfold, simp)+

(* Pairing num and den pointwise yields stern_brocot_recurse. Shown with the
   uniqueness principle of stern_brocot_recurse: unfold den and num once, then
   discharge the resulting equations by applicative lifting. *)
lemma stern_brocot_num_den:
"pure_tree Pair ⋄ num ⋄ den = stern_brocot_recurse"
apply(rule stern_brocot_recurse.unique)
apply(subst den_unfold)
apply(subst num_unfold)
apply(simp; intro conjI)
apply(applicative_lifting; simp)+
done

(* den is num with its root chopped off; by strong coinduction. *)
lemma den_eq_chop_num: "den = tree_chop num"
by(coinduction rule: tree.coinduct_strong) simp

(* num is the first-component projection of stern_brocot_recurse: replace the
   latter by the pairing of num and den, then argue pointwise. *)
lemma num_conv: "num = pure fst ⋄ stern_brocot_recurse"
  unfolding stern_brocot_num_den[symmetric]
  by(applicative_lifting; simp)

(* den is the second-component projection of stern_brocot_recurse: replace the
   latter by the pairing of num and den, then argue pointwise. *)
lemma den_conv: "den = pure snd ⋄ stern_brocot_recurse"
  unfolding stern_brocot_num_den[symmetric]
  by(applicative_lifting; simp)

(* Tree with root 0, left subtree num, and itself as right subtree. The lemma
   of the same name further down shows it equals num mod den. *)
corec num_mod_den :: "nat tree"
where "num_mod_den = Node 0 num num_mod_den"

(* Selector equations for num_mod_den, each from one unfolding of its code
   equation. *)
lemma num_mod_den_simps [simp]:
"root num_mod_den = 0"
"left num_mod_den = num"
"right num_mod_den = num_mod_den"
by(subst num_mod_den.code; simp; fail)+

text‹
The arithmetic transformations need the precondition that @{const den} contains only
positive numbers, no @{term "0 :: nat"}. \citet[p502]{Hinze2009JFP} gets a bit sloppy here; it is
not straightforward to adapt his lifting framework \cite{Hinze2010Lifting} to conditional equations.
›

(* Conditional pointwise law on nat trees: if every element of y is positive
   then x mod (x + y) = x. The side condition is handled by first proving the
   statement as rel_tree (=) via applicative lifting and then turning the
   relator into equality with tree.rel_eq. *)
lemma mod_tree_lemma1:
fixes x :: "nat tree"
assumes "∀i∈set_tree y. 0 < i"
shows "x mod (x + y) = x"
proof -
have "rel_tree (=) (x mod (x + y)) x" by applicative_lifting(simp add: assms)
thus ?thesis by(unfold tree.rel_eq)
qed

(* Unconditional pointwise law: (x + y) mod y = x mod y on trees over a
   unique_euclidean_semiring; by applicative lifting. *)
lemma mod_tree_lemma2:
fixes x y :: "'a :: unique_euclidean_semiring tree"
shows "(x + y) mod y = x mod y"
by applicative_lifting simp

(* Every element of a tree is the root of the subtree reached by some path.
   By set_tree_induct; the witnesses are the empty path and the path extended
   with L or R at the front. *)
lemma set_tree_pathD: "x ∈ set_tree t ⟹ ∃p. x = root (traverse_tree p t)"
by(induct rule: set_tree_induct)(auto intro: exI[where x="[]"] exI[where x="L # p" for p] exI[where x="R # p" for p])

(* Every element of den is positive: locate the element by a path
   (set_tree_pathD) and apply stern_brocot_denominator_non_zero for that path,
   using den_conv to relate den to stern_brocot_recurse. *)
lemma den_gt_0: "0 < x" if "x ∈ set_tree den"
proof -
from that obtain p where "x = root (traverse_tree p den)" by(blast dest: set_tree_pathD)
with stern_brocot_denominator_non_zero[of p] show "0 < x" by(simp add: den_conv split_beta)
qed

(* num mod den satisfies the defining equation of num_mod_den and is therefore
   equal to it (uniqueness principle); the two mod lemmas and den_gt_0 handle
   the subtrees. *)
lemma num_mod_den: "num mod den = num_mod_den"
by(rule num_mod_den.unique)(rule tree.expand, simp add: mod_tree_lemma2 mod_tree_lemma1 den_gt_0)

(* Chopping the root off den gives num + den - 2 * (num mod den).
   The subtraction is on nat, so the calculation is carried out on int-valued
   copies of the trees and transferred back at the end. *)
lemma tree_chop_den: "tree_chop den = num + den - 2 * (num mod den)"
proof -
  (* Arithmetic fact that makes the nat subtraction exact. *)
  have le: "0 < y ⟹ 2 * (x mod y) ≤ x + y" for x y :: nat
    by (simp add: mult_2 add_mono)

  text ‹We switch to @{typ int} such that all cancellation laws are available.›
  define den' where "den' = pure int ⋄ den"
  define num' where "num' = pure int ⋄ num"
  define num_mod_den' where "num_mod_den' = pure int ⋄ num_mod_den"

  (* Selector equations for the int-valued copies. *)
  have [simp]: "root num' = 1" "left num' = num'" unfolding den'_def num'_def by simp_all
  have [simp]: "right num' = num' + den'" unfolding den'_def num'_def ap_tree.sel pure_tree_simps num_simps
    by applicative_lifting simp

  have num_mod_den'_simps [simp]: "root num_mod_den' = 0" "left num_mod_den' = num'" "right num_mod_den' = num_mod_den'"
    by(simp_all add: num_mod_den'_def num'_def)
  have den'_eq_chop_num': "den' = tree_chop num'" by(simp add: den'_def num'_def den_eq_chop_num)
  (* 2 * num_mod_den' is the unique solution of x = Node 0 (2 * num') x ... *)
  have num_mod_den'2_unique: "⋀x. x = Node 0 (2 * num') x ⟹ x = 2 * num_mod_den'"
    by(corec_unique)(rule tree.expand; simp)
  (* ... and num' + den' - tree_chop den' is a solution of that equation. *)
  have num'_plus_den'_minus_chop_den': "num' + den' - tree_chop den' = 2 * num_mod_den'"
    by(rule num_mod_den'2_unique)(rule tree.expand, simp add: tree_chop_plus den'_eq_chop_num')

  (* Calculation on int, wrapped in the conversion back to nat. *)
  have "tree_chop den = pure nat ⋄ (tree_chop den')"
    unfolding den_conv tree_chop_ap_tree tree_chop_pure_tree den'_def by applicative_nf simp
  also have "tree_chop den' = num' + den' - tree_chop den' + tree_chop den' - 2 * num_mod_den'"
    by(subst num'_plus_den'_minus_chop_den') simp
  also have "… = num' + den' - 2 * (num' mod den')"
    unfolding num_mod_den'_def num'_def den'_def num_mod_den[symmetric]
    by(applicative_lifting)(simp add: zmod_int)
  (* Conditional step: needs den_gt_0 and le, hence stated via rel_tree. *)
  also have [unfolded tree.rel_eq]: "rel_tree (=) … (pure int ⋄ (num + den - 2 * (num mod den)))"
    unfolding num'_def den'_def by(applicative_lifting)(simp add: of_nat_diff zmod_int le den_gt_0)
  also have "pure nat ⋄ (pure int ⋄ (num + den - 2 * (num mod den))) = num + den - 2 * (num mod den)" by(applicative_nf) simp
  finally show ?thesis .
qed

subsection‹Loopless linearisation of the Stern-Brocot tree.›

text ‹
This is a loopless linearisation of the Stern-Brocot tree that gives Stern's diatomic sequence,
which is also known as Dijkstra's fusc function \cite{Dijkstra1982EWD570,Dijkstra1982EWD578}.
Loopless \`a la \cite{Bird2006MPC} means that the first element of the stream can be computed in linear
time and every further element in constant time.
›

(* Register smap as a friend for corec, with the given unfolding equation.
   First obligation: the equation itself (stream.expand). Second obligation:
   discharged by transfer_prover after folding relator_eq. *)
friend_of_corec smap :: "('a ⇒ 'a) ⇒ 'a stream ⇒ 'a stream"
where "smap f xs = SCons (f (shd xs)) (smap f (stl xs))"
subgoal by(rule stream.expand) simp
subgoal by(fold relator_eq)(transfer_prover)
done

(* One step on pairs of naturals: (n, d) becomes (d, n + d - 2 * (n mod d)).
   Subtraction here is on nat. *)
definition step :: "nat × nat ⇒ nat × nat"
where "step = (λ(n, d). (d, n + d - 2 * (n mod d)))"

(* Stream starting with (1, 1) whose tail is step mapped over the stream
   itself; relies on smap being a registered friend. *)
corec stern_brocot_loopless :: "fraction stream"
where "stern_brocot_loopless = (1, 1) ## smap step stern_brocot_loopless"

(* Alias for the code equation generated by corec. *)
lemmas stern_brocot_loopless_rec = stern_brocot_loopless.code

(* Register + on streams as a friend for corec, with a head/tail unfolding
   equation; the second obligation is discharged by transfer_prover. *)
friend_of_corec plus where "s + s' = (shd s + shd s') ## (stl s + stl s')"
subgoal by (rule stream.expand; simp add: plus_stream_shd plus_stream_stl)
subgoal by transfer_prover
done

(* Register - on streams as a friend for corec, with a head/tail unfolding
   equation; the second obligation is discharged by transfer_prover. *)
friend_of_corec minus where "t - t' = (shd t - shd t') ## (stl t - stl t')"
subgoal by (rule stream.expand; simp add: minus_stream_def)
subgoal by transfer_prover
done

(* Register * on streams as a friend for corec, with a head/tail unfolding
   equation; the second obligation is discharged by transfer_prover. *)
friend_of_corec times where "t * t' = (shd t * shd t') ## (stl t * stl t')"
subgoal by (rule stream.expand; simp add: times_stream_def)
subgoal by transfer_prover
done

(* Register mod on streams as a friend for corec, with a head/tail unfolding
   equation; the second obligation is discharged by transfer_prover. *)
friend_of_corec modulo where "t mod t' = (shd t mod shd t') ## (stl t mod stl t')"
subgoal by (rule stream.expand; simp add: modulo_stream_def)
subgoal by transfer_prover
done

(* Stream fusc': head 1, tail  f + fusc' - 2 * (f mod fusc')  where f is
   1 ## fusc' (named fusc below). The arithmetic operators used on the
   right-hand side are the friends registered above. *)
corec fusc' :: "nat stream"
where "fusc' = 1 ## (((1 ## fusc') + fusc') - 2 * ((1 ## fusc') mod fusc'))"

(* fusc is fusc' with an extra leading 1. *)
definition fusc where "fusc = 1 ## fusc'"

(* The definition restated as an unfolding lemma. *)
lemma fusc_unfold: "fusc = 1 ## fusc'" by(fact fusc_def)

(* The code equation of fusc' restated with fusc in place of 1 ## fusc':
   unfold fusc' once on the left, then expand fusc on the right. *)
lemma fusc'_unfold: "fusc' = 1 ## (fusc + fusc' - 2 * (fusc mod fusc'))"
  by(subst fusc'.code)(simp add: fusc_def)

(* Head and tail of fusc, read off from fusc_unfold. *)
lemma fusc_simps [simp]:
  "shd fusc = 1"
  "stl fusc = fusc'"
  by(simp_all add: fusc_unfold)

(* Head and tail of fusc', read off from fusc'_unfold. *)
lemma fusc'_simps [simp]:
"shd fusc' = 1"
"stl fusc' = fusc + fusc' - 2 * (fusc mod fusc')"
by(subst fusc'_unfold, simp)+

subsection ‹Equivalence with Dijkstra's fusc function›

(* stern_brocot_loopless is the iteration of step from (1, 1): siterate step
   (1, 1) satisfies the defining equation, so uniqueness applies. *)
lemma stern_brocot_loopless_siterate: "stern_brocot_loopless = siterate step (1, 1)"
by(rule stern_brocot_loopless.unique[symmetric])(rule stream.expand; simp add: smap_siterate[symmetric])

(* Pairing fusc with fusc' pointwise gives stern_brocot_loopless: the paired
   stream satisfies the defining equation, which is checked pointwise. *)
lemma fusc_fusc'_iterate: "pure Pair ⋄ fusc ⋄ fusc' = stern_brocot_loopless"
  by(rule stern_brocot_loopless.unique)(applicative_lifting; simp)

(* Linearising stern_brocot_recurse with stream gives stern_brocot_loopless.
   By uniqueness it suffices that the linearised tree satisfies the defining
   equation of stern_brocot_loopless. *)
theorem stern_brocot_loopless:
  "stream stern_brocot_recurse = stern_brocot_loopless" (is "?lhs = ?rhs")
proof(rule stern_brocot_loopless.unique)
  (* Work with the numerator/denominator decomposition of the tree. *)
  have eq: "?lhs = stream (pure_tree Pair ⋄ num ⋄ den)" by (simp only: stern_brocot_num_den)
  (* Stream-level unfoldings of num and den. *)
  have num: "stream num = 1 ## stream den"
    by (rule stream.expand) (simp add: den_eq_chop_num)
  have den: "stream den = 1 ## (stream num + stream den - 2 * (stream num mod stream den))"
    by (rule stream.expand) (simp add: tree_chop_den)
  (* Head: both roots are 1. Tail: chopping num gives den and chopping den is
     given by tree_chop_den; the rest is the definition of step, pointwise. *)
  show "?lhs = (1, 1) ## smap step ?lhs" unfolding eq
    by(rule stream.expand)
      (simp add: den_eq_chop_num[symmetric] tree_chop_den; applicative_lifting; simp add: step_def)
qed

end


# Theory Bird_Tree

(* Author: Andreas Lochbihler, ETH Zurich
Author: Peter Gammie *)

section ‹ The Bird tree ›

text ‹
We define the Bird tree following \cite{Hinze2009JFP} and prove that it is a
permutation of the Stern-Brocot tree. As a corollary, we derive that the Bird tree also
contains all rational numbers in lowest terms exactly once.
›

theory Bird_Tree imports Stern_Brocot_Tree begin

(* The Bird tree: root (1, 1); the left subtree maps succ and then recip over
   the tree itself, the right subtree maps recip and then succ. *)
corec bird :: "fraction tree"
where
"bird = Node (1, 1) (map_tree recip (map_tree succ bird)) (map_tree succ (map_tree recip bird))"

(* The code equation of bird with map_tree rewritten into idiomatic
   application of pure functions. *)
lemma bird_unfold:
"bird = Node (1, 1) (pure recip ⋄ (pure succ ⋄ bird)) (pure succ ⋄ (pure recip ⋄ bird))"
using bird.code unfolding map_tree_ap_tree_pure_tree[symmetric] .

(* Selector equations for bird, read off from bird_unfold. *)
lemma bird_simps [simp]:
"root bird = (1, 1)"
"left bird = pure recip ⋄ (pure succ ⋄ bird)"
"right bird = pure succ ⋄ (pure recip ⋄ bird)"
by(subst bird_unfold, simp)+

(* Mirroring the Bird tree equals taking reciprocals pointwise. Both sides are
   shown to be fixed points of the same guarded functional ?F, whose fixed
   point is unique. *)
lemma mirror_bird: "mirror bird = pure recip ⋄ bird" (is "?lhs = ?rhs")
proof(rule sym)
  let ?F = "λt. Node (1, 1) (map_tree succ (map_tree recip t)) (map_tree recip (map_tree succ t))"
  (* mirror bird is a fixed point of ?F. *)
  have *: "mirror bird = ?F (mirror bird)"
    by(rule tree.expand; simp add: mirror_ap_tree mirror_pure map_tree_ap_tree_pure_tree[symmetric])
  (* Any fixed point of ?F equals mirror bird. *)
  show "t = mirror bird" when "t = ?F t" for t using that by corec_unique (fact *)
  (* pure recip ⋄ bird is a fixed point of ?F as well; the remaining equation
     is checked pointwise. *)
  show "pure recip ⋄ bird = ?F (pure recip ⋄ bird)"
    by(rule tree.expand; simp add: map_tree_ap_tree_pure_tree; applicative_lifting; simp add: split_beta)
qed

(* Mirrors a tree on alternating levels. The root is kept. When the flag is
   True the two subtrees are swapped, when it is False they stay in place; the
   flag is negated for the recursive calls on the subtrees. *)
primcorec even_odd_mirror :: "bool ⇒ 'a tree ⇒ 'a tree"
where
"⋀even. root (even_odd_mirror even t) = root t"
| "⋀even. left (even_odd_mirror even t) = even_odd_mirror (¬ even) (if even then right t else left t)"
| "⋀even. right (even_odd_mirror even t) = even_odd_mirror (¬ even) (if even then left t else right t)"

(* even_odd_mirror started with the flag True (swaps at the root). *)
definition even_mirror :: "'a tree ⇒ 'a tree"
where "even_mirror = even_odd_mirror True"

(* even_odd_mirror started with the flag False (no swap at the root). *)
definition odd_mirror :: "'a tree ⇒ 'a tree"
where "odd_mirror = even_odd_mirror False"

(* Selector equations for even_mirror and odd_mirror in terms of each other,
   obtained by unfolding both definitions and using the selector equations of
   even_odd_mirror. *)
lemma even_mirror_simps [simp]:
  "root (even_mirror t) = root t"
  "left (even_mirror t) = odd_mirror (right t)"
  "right (even_mirror t) = odd_mirror (left t)"
  and odd_mirror_simps [simp]:
  "root (odd_mirror t) = root t"
  "left (odd_mirror t) = even_mirror (left t)"
  "right (odd_mirror t) = even_mirror (right t)"
  by(simp_all add: even_mirror_def odd_mirror_def)

(* even_odd_mirror leaves a constant tree unchanged, for either flag value;
   by coinduction generalising over the flag. *)
lemma even_odd_mirror_pure [simp]: fixes even shows
"even_odd_mirror even (pure_tree x) = pure_tree x"
by(coinduction arbitrary: even) auto

(* even_odd_mirror distributes over idiomatic application; by coinduction
   generalising over the flag and both trees. *)
lemma even_odd_mirror_ap_tree [simp]: fixes even shows
"even_odd_mirror even (f ⋄ x) = even_odd_mirror even f ⋄ even_odd_mirror even x"
by(coinduction arbitrary: even f x) auto

(* even_mirror and odd_mirror leave constant trees unchanged; instances of
   even_odd_mirror_pure after unfolding the two definitions. *)
lemma [simp]:
  shows even_mirror_pure: "even_mirror (pure_tree x) = pure_tree x"
  and odd_mirror_pure: "odd_mirror (pure_tree x) = pure_tree x"
  by(simp_all add: even_mirror_def odd_mirror_def)

(* even_mirror and odd_mirror distribute over idiomatic application; instances
   of even_odd_mirror_ap_tree after unfolding the two definitions. *)
lemma [simp]:
  shows even_mirror_ap_tree: "even_mirror (f ⋄ x) = even_mirror f ⋄ even_mirror x"
  and odd_mirror_ap_tree: "odd_mirror (f ⋄ x) = odd_mirror f ⋄ odd_mirror x"
  by(simp_all add: even_mirror_def odd_mirror_def)

(* Path counterparts of even_mirror / odd_mirror, defined by mutual recursion:
   even_mirror_path flips the first direction (L to R, R to L) and continues
   with odd_mirror_path; odd_mirror_path keeps the first direction and
   continues with even_mirror_path. *)
fun even_mirror_path :: "path ⇒ path"
and odd_mirror_path :: "path ⇒ path"
where
"even_mirror_path [] = []"
| "even_mirror_path (d # ds) = (case d of L ⇒ R | R ⇒ L) # odd_mirror_path ds"
| "odd_mirror_path [] = []"
| "odd_mirror_path (d # ds) = d # even_mirror_path ds"

(* Following a path in a mirrored tree reaches the same root value as
   following the correspondingly mirrored path in the original tree. Both
   statements are proved together by induction on the path. *)
lemma even_mirror_traverse_tree [simp]:
"root (traverse_tree path (even_mirror t)) = root (traverse_tree (even_mirror_path path) t)"
and odd_mirror_traverse_tree [simp]:
"root (traverse_tree path (odd_mirror t)) = root (traverse_tree (odd_mirror_path path) t)"
by (induct path arbitrary: t) (simp_all split: dir.splits)

(* Both path transformations are involutions; by induction on the path. *)
lemma even_odd_mirror_path_involution [simp]:
"even_mirror_path (even_mirror_path path) = path"
"odd_mirror_path (odd_mirror_path path) = path"
by (induct path) (simp_all split: dir.splits)

(* Both path transformations are injective; by induction on the first path
   with a case distinction on the second. *)
lemma even_odd_mirror_path_injective [simp]:
"even_mirror_path path = even_mirror_path path' ⟷ path = path'"
"odd_mirror_path path = odd_mirror_path path' ⟷ path = path'"
by (induct path arbitrary: path') (case_tac path', simp_all split: dir.splits)+

(* The odd-mirrored Bird tree is stern_brocot_recurse. Both trees are shown to
   be fixed points of the functional ?R, which spells out the top two levels
   explicitly and builds the four grandchildren by mapping a composition of
   recip and succ over its argument; corec_unique gives uniqueness of that
   fixed point. *)
lemma odd_mirror_bird_stern_brocot:
"odd_mirror bird = stern_brocot_recurse"
proof -
let ?rsrs = "map_tree (recip ∘ succ ∘ recip ∘ succ)"
let ?rssr = "map_tree (recip ∘ succ ∘ succ ∘ recip)"
let ?srrs = "map_tree (succ ∘ recip ∘ recip ∘ succ)"
let ?srsr = "map_tree (succ ∘ recip ∘ succ ∘ recip)"
let ?R = "λt. Node (1, 1) (Node (1, 2) (?rssr t) (?rsrs t)) (Node (2, 1) (?srsr t) (?srrs t))"

(* stern_brocot_recurse is a fixed point of ?R. *)
have *: "stern_brocot_recurse = ?R stern_brocot_recurse"
by(rule tree.expand; simp; intro conjI; rule tree.expand; simp; intro conjI) ― ‹Expand the tree twice›
(* Any fixed point of ?R equals stern_brocot_recurse. *)
show "f = stern_brocot_recurse" when "f = ?R f" for f using that * by corec_unique
(* odd_mirror bird is a fixed point of ?R as well. *)
show "odd_mirror bird = ?R (odd_mirror bird)"
by(rule tree.expand; simp; intro conjI; rule tree.expand; simp; intro conjI) ― ‹Expand the tree twice›
(applicative_lifting; simp)+
qed

theorem bird_rationals:
assumes "m > 0" "n > 0"
shows "root (traverse_tree (odd_mirror_path (mk_path m n)) (pure rat_of ⋄ bird)) = Fract (int m) (int n)"
using stern_brocot_rationals[OF assms]
`