Session Stream_Fusion_Code

Theory Stream_Fusion

(* Title: Stream_Fusion.thy
 Authors: Alexandra Maximova, ETH Zurich
          Andreas Lochbihler, ETH Zurich
*)

section ‹Stream fusion implementation›

theory Stream_Fusion
imports
  Main
begin

(* Load the ML implementation: rule management, the [stream_fusion] attribute and the simproc. *)
ML_file ‹stream_fusion.ML›

(* Register the simproc, but remove it from the default simpset right away;
   it is only installed in the code preprocessor below. *)
simproc_setup stream_fusion ("f x") = ‹K Stream_Fusion.fusion_simproc›
declare [[simproc del: stream_fusion]]

text ‹Install stream fusion as a simproc in the preprocessor for code equations›
setup ‹Code_Preproc.map_pre (fn ss => ss addsimprocs [@{simproc "stream_fusion"}])›

end

File ‹stream_fusion.ML›

(* Title: stream_fusion.ML
  Author: Alexandra Maximova, ETH Zurich,
          Andreas Lochbihler, ETH Zurich 

Implementation of the stream fusion transformation as a simproc for the preprocessor of the
code generator
*)

signature STREAM_FUSION =
sig
  (* Configuration option "stream_fusion_trace": print what the simproc tries *)
  val trace: bool Config.T

  (* Queries on the declared fusion data *)
  val get_rules: Proof.context -> thm list
  val get_conspats: Proof.context -> (term * thm) list
  val get_unstream: Proof.context -> string list
  val match_consumer: Proof.context -> term -> bool

  (* Declaring and removing fusion rules and unstream constants *)
  val add_fusion_rule: thm -> Context.generic -> Context.generic
  val del_fusion_rule: thm -> Context.generic -> Context.generic
  val add_unstream: string -> Context.generic -> Context.generic
  val del_unstream: string -> Context.generic -> Context.generic

  (* Declaration attributes behind [stream_fusion] / [stream_fusion del] *)
  val fusion_add: attribute
  val fusion_del: attribute

  (* The transformation: a rewriting conversion and the simproc wrapping it *)
  val fusion_conv: Proof.context -> conv
  val fusion_simproc: Proof.context -> cterm -> thm option
end;

structure Stream_Fusion : STREAM_FUSION = 
struct

(* Declared stream-fusion data:
   - rules:    every declared fusion rule, stored with its sides swapped (via sym), so that
               rewriting turns the right-hand side of a declared rule into its left-hand side;
   - conspats: for each consumer rule, a pattern of the consumer application on the
               rule's right-hand side, paired with the rule; the simproc only fires on
               terms matched by one of these patterns;
   - unstream: names of the constants that convert generators back into lists. *)
type fusion_rules = 
  { rules : thm Item_Net.T,
    conspats : (term * thm) Item_Net.T,
    unstream : string list
  }

fun map_fusion_rules f1 f2 f3
  {rules, conspats, unstream}
  =
  {rules = f1 rules,
   conspats = f2 conspats,
   unstream = f3 unstream};

fun map_rules f = map_fusion_rules f I I;
fun map_conspats f = map_fusion_rules I f I;
fun map_unstream f = map_fusion_rules I I f;

structure Fusion_Rules = Generic_Data
(
  type T = fusion_rules;
  val empty = {rules = Thm.full_rules,
               conspats = Item_Net.init (Thm.eq_thm_prop o apply2 snd) (single o fst),
               unstream = []};
  (* NOTE(review): presumably newer Isabelle releases no longer have "extend" in
     Generic_Data; drop this line when porting. *)
  val extend = I;
  fun merge 
    ({rules = r, conspats = cp, unstream = u},
     {rules = r', conspats = cp', unstream = u'}) =
    {rules = Item_Net.merge (r, r'),
     conspats = Item_Net.merge (cp, cp'),
     unstream = Library.merge (op =) (u, u')}
);


val get_rules = Item_Net.content o #rules o Fusion_Rules.get o Context.Proof;
val get_conspats = Item_Net.content o #conspats o Fusion_Rules.get o Context.Proof;
val get_unstream = #unstream o Fusion_Rules.get o Context.Proof;

(* true iff the consumer-pattern net retrieves at least one candidate matching t *)
fun match_consumer ctxt t = 
  Context.Proof ctxt
  |> Fusion_Rules.get
  |> #conspats
  |> (fn net => Item_Net.retrieve_matching net t)
  |> not o null

datatype classification = ProducerTransformer | Consumer of term

(* Does a constant from ts occur in head position of an application (Const c $ t)
   somewhere in the term?  Used to detect 'unstream' occurrences. *)
fun occur_in ts ((Const (c, _)) $ t) =
    member (op =) ts c orelse occur_in ts t
  | occur_in ts (op $ (u, t)) = occur_in ts u orelse occur_in ts t
  | occur_in ts (Abs (_, _, t)) = occur_in ts t
  | occur_in _ _ = false;

(* Head of an application spine together with the number of arguments applied to it. *)
fun first_depth (t1 $ _) = let val (f,d) = first_depth t1 in (f,d+1) end |
    first_depth t1 = (t1,0)

(* Pattern for a consumer: the head of rhs applied to d fresh schematic variables,
   where d is the number of arguments of rhs. *)
fun mk_conspat rhs ctxt =
  let
    val (f,d) = first_depth rhs
    (* Only the first d argument types are needed.  If the head's result type is itself a
       function type, binder_types returns more than d types and ~~ would raise
       ListPair.UnequalLengths. *)
    val types = take d (binder_types (fastype_of f))
    val (vfixes, ctxt1) = Variable.variant_fixes (replicate d "x") ctxt 
  in
    (hd o Variable.export_terms ctxt1 ctxt o single) (list_comb (f, map Free (vfixes ~~ types)))
  end

(* Classify an equation lhs = rhs: 'unstream' on the lhs makes it a producer/transformer
   rule; otherwise 'unstream' on the rhs makes it a consumer rule.  Anything else,
   including non-equations, yields NONE. *)
fun classify ctxt thm = case Thm.full_prop_of thm
  of (@{const Trueprop} $ (Const (@{const_name "HOL.eq"}, _) $ lhs $ rhs)) =>
    let val unstream = get_unstream ctxt in
      if occur_in unstream lhs then SOME ProducerTransformer
      else if occur_in unstream rhs then SOME (Consumer (mk_conspat rhs ctxt))
      else NONE
    end
  | _ => NONE;

fun sym thm = thm RS @{thm sym}

fun format_error ctxt thm =
  warning (Pretty.string_of (Pretty.block [
    Pretty.str "Wrong format for fusion rule: ",
    Pretty.brk 2,
    Syntax.pretty_term (Context.proof_of ctxt) (Thm.prop_of thm)]))

(* Unclassifiable rules only produce a warning and leave the context unchanged. *)
fun register thm NONE = (fn ctxt =>
  let
    val _ = format_error ctxt thm
  in
    ctxt
  end)
| register thm (SOME ProducerTransformer) = Fusion_Rules.map (
    map_rules (Item_Net.update (sym thm)))
| register thm (SOME (Consumer cp)) = Fusion_Rules.map (
    map_rules (Item_Net.update (sym thm)) o map_conspats (Item_Net.update (cp, thm)));

fun unregister thm NONE = (fn ctxt =>
  let
    val _ = format_error ctxt thm
  in
    ctxt
  end)
| unregister thm (SOME ProducerTransformer) = Fusion_Rules.map (
    map_rules (Item_Net.remove (sym thm)))
| unregister thm (SOME (Consumer cp)) = Fusion_Rules.map (
    map_rules (Item_Net.remove (sym thm)) o map_conspats (Item_Net.remove (cp, thm)));

fun add_fusion_rule thm ctxt = register thm (classify (Context.proof_of ctxt) thm) ctxt
fun del_fusion_rule thm ctxt = unregister thm (classify (Context.proof_of ctxt) thm) ctxt

fun add_unstream c = Fusion_Rules.map (map_unstream (insert (op =) c))
fun del_unstream c = Fusion_Rules.map (map_unstream (remove (op =) c))

(* attributes and setup *)
val fusion_add = Thm.declaration_attribute add_fusion_rule;
val fusion_del = Thm.declaration_attribute del_fusion_rule;

val _ =
  Theory.setup
   (Attrib.setup @{binding "stream_fusion"} (Attrib.add_del fusion_add fusion_del)
      "declaration of a rule for stream fusion" #>
    Global_Theory.add_thms_dynamic
      (@{binding "stream_fusion"}, Item_Net.content o #rules o Fusion_Rules.get));

val trace = Attrib.setup_config_bool @{binding "stream_fusion_trace"} (K false)

(* Trace output is built lazily and only when stream_fusion_trace is enabled. *)
fun tracing ctxt msg = if Config.get ctxt trace then Output.tracing (msg ()) else ()

(* Rewrite with the declared (swapped) fusion rules on top of HOL_basic_ss. *)
fun fusion_conv ctxt = Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps get_rules ctxt)

(* Fire only on terms matching a consumer pattern.  The rewrite counts as failed if it
   changed nothing or if an 'unstream' constant is still left in the result. *)
fun fusion_simproc ctxt ct =
  let
    val matches = match_consumer ctxt (Thm.term_of ct)
  in
    if matches then 
      let
        val _ = tracing ctxt (fn _ => Pretty.string_of (Pretty.block 
          [Pretty.str "Trying stream fusion on ",
           Pretty.brk 2,
           Syntax.pretty_term ctxt (Thm.term_of ct)]))
        val thm = fusion_conv ctxt ct
        val failed = Thm.is_reflexive thm orelse occur_in (get_unstream ctxt) (Thm.term_of (Thm.rhs_of thm))
        val _ = tracing ctxt (fn _ => Pretty.string_of (Pretty.block 
          [Pretty.str (if failed then "FAILED: " else "SUCCEEDED: "),
           Pretty.brk 2,
           Syntax.pretty_term ctxt (Thm.prop_of thm)]))
      in
        if failed then NONE else SOME thm
      end
    else NONE
  end

end;

Theory Stream_Fusion_List

(* Title: Stream_Fusion_List.thy
 Authors: Alexandra Maximova, ETH Zurich
          Andreas Lochbihler, ETH Zurich
*)

section ‹Stream fusion for finite lists›

theory Stream_Fusion_List
imports Stream_Fusion
begin

lemma map_option_mono [partial_function_mono]: (* To be moved to HOL *)
  "mono_option f ⟹ mono_option (λx. map_option g (f x))"
apply (rule monotoneI)
apply (drule (1) monotoneD)
apply (auto simp add: flat_ord_def split: option.split)
done

subsection ‹The type of generators for finite lists›

(* One generator step: stop, move to a new state without output, or output an element
   and move to a new state. *)
datatype ('a, 's) step = Done | is_Skip: Skip 's | is_Yield: Yield 'a 's

type_synonym ('a, 's) raw_generator = "'s ⇒ ('a,'s) step"

text ‹
  A raw generator need not reach @{const Done}: it may produce infinitely many @{const Yield}s
  (or @{const Skip}s) in a row. Such a generator cannot be converted to a finite list, because an
  infinite run of @{const Yield}s describes an infinite list and an infinite run of @{const Skip}s
  never produces a result. Therefore, we introduce the type of generators that always reach
  @{const Done} after finitely many steps.
›

(* The states from which g reaches Done after finitely many Skip/Yield steps. *)
inductive_set terminates_on :: "('a, 's) raw_generator ⇒ 's set"
  for g :: "('a, 's) raw_generator"
where
  stop: "g s = Done ⟹ s ∈ terminates_on g"
| pause: "⟦ g s = Skip s'; s' ∈ terminates_on g ⟧ ⟹ s ∈ terminates_on g"
| unfold: "⟦ g s = Yield a s'; s' ∈ terminates_on g ⟧ ⟹ s ∈ terminates_on g"

(* A raw generator terminates if Done is reached from every state. *)
definition terminates :: "('a, 's) raw_generator ⇒ bool"
where "terminates g ⟷ (terminates_on g = UNIV)"

lemma terminatesI [intro?]:
  "(⋀s. s ∈ terminates_on g) ⟹ terminates g"
by (auto simp add: terminates_def)

lemma terminatesD:
  "terminates g ⟹ s ∈ terminates_on g"
by (auto simp add: terminates_def)

lemma terminates_on_stop:
  "terminates_on (λ_. Done) = UNIV"
by (auto intro: terminates_on.stop)

(* Termination criterion: some well-founded relation decreases on every Skip and Yield step. *)
lemma wf_terminates:
  assumes "wf R"
  and skip: "⋀s s'. g s = Skip s' ⟹ (s',s) ∈ R"
  and yield: "⋀s s' a. g s = Yield a s' ⟹ (s',s) ∈ R"
  shows "terminates g"
proof (rule terminatesI)
  fix s
  from ‹wf R› show "s ∈ terminates_on g"
  proof (induction rule: wf_induct [rule_format, consumes 1, case_names wf])
    case (wf s)
    show ?case
    proof (cases "g s")
      case (Skip s')
      hence "(s', s) ∈ R" by (rule skip)
      hence "s' ∈ terminates_on g" by (rule wf.IH)
      with Skip show ?thesis by (rule terminates_on.pause)
    next
      case (Yield a s')
      hence "(s', s) ∈ R" by (rule yield)
      hence "s' ∈ terminates_on g" by (rule wf.IH)
      with Yield show ?thesis by (rule terminates_on.unfold)
    qed (rule terminates_on.stop)
  qed
qed

(* terminates_within s counts the steps from s until Done; by the lemma below, it is
   defined exactly on terminates_on g. *)
context fixes g :: "('a, 's) raw_generator" begin

partial_function (option) terminates_within :: "'s ⇒ nat option" where
  "terminates_within s = (case g s of
     Done ⇒ Some 0
  | Skip s' ⇒ map_option (λn. n + 1) (terminates_within s')
  | Yield a s' ⇒ map_option (λn. n + 1) (terminates_within s'))"

lemma terminates_on_conv_dom_terminates_within:
  "terminates_on g = dom terminates_within"
proof (rule set_eqI iffI)+
  fix s
  assume "s ∈ terminates_on g"
  hence "∃n. terminates_within s = Some n"
    by induction (subst terminates_within.simps, simp add: split_beta)+
  then show "s ∈ dom terminates_within" by blast
next
  fix s
  assume "s ∈ dom terminates_within"
  then obtain n where "terminates_within s = Some n" by blast
  then show "s ∈ terminates_on g"
  proof (induction rule: terminates_within.raw_induct[rotated 1, consumes 1])
    case (1 terminates_within s s')
    show ?case
    proof(cases "g s")
      case Done
      thus ?thesis by (simp add: terminates_on.stop)
    next
      case (Skip s')
      hence "s' ∈ terminates_on g" using 1 by(auto)
      thus ?thesis using ‹g s = Skip s'› by (simp add: terminates_on.pause)
    next
      case (Yield a s')
      hence "s' ∈ terminates_on g" using 1 by(auto)
      thus ?thesis using ‹g s = Yield a s'› by (auto intro: terminates_on.unfold)
    qed
  qed
qed

end

(* Every terminating generator admits a well-founded relation that decreases on Skip and
   Yield steps: the measure given by the number of remaining steps. *)
lemma terminates_wfE:
  assumes "terminates g"
  obtains R 
  where "wf R"
    "⋀s s'. (g s = Skip s') ⟹ (s',s) ∈ R"
    "⋀s a s'. (g s = Yield a s') ⟹ (s',s) ∈ R"
proof -
  let ?R = "measure (λs. the (terminates_within g s)) :: ('a × 'a) set"
  have "wf ?R" by simp
  moreover {
    fix s s'
    assume "g s = Skip s'"
    moreover from assms have "s' ∈ terminates_on g" by (rule terminatesD)
    then obtain n where "terminates_within g s' = Some n"
      unfolding terminates_on_conv_dom_terminates_within by (auto)
    ultimately have "the (terminates_within g s') < the (terminates_within g s)"
      by (simp add: terminates_within.simps)
    hence "(s',s) ∈ ?R" by (auto)
  } moreover {
    fix s s' a
    assume 2: "g s = Yield a s'"
    moreover from assms have "s' ∈ terminates_on g" by (rule terminatesD)
    then obtain n where "terminates_within g s' = Some n"
      unfolding terminates_on_conv_dom_terminates_within by (auto)
    ultimately have "(s',s) ∈ ?R"
      by simp (subst terminates_within.simps, simp add: split_beta)
  } ultimately 
  show thesis by (rule that)
qed

(* Generators: raw generators that terminate from every state. *)
typedef ('a,'s) generator = "{g :: ('a,'s) raw_generator. terminates g}"
  morphisms generator Generator
proof
  show "(λ_. Done) ∈ ?generator"
    by (simp add: terminates_on_stop terminates_def)
qed

setup_lifting type_definition_generator

subsection ‹Conversion to @{typ "'a list"}›

(* unstream s: the list produced by running g from state s.
   force s: the first yielded element and its successor state (Skips are passed over),
   or None if Done is reached first. *)
context fixes g :: "('a, 's) generator" begin

function unstream :: "'s ⇒ 'a list"
where
  "unstream s = (case generator g s of
     Done ⇒ []
   | Skip s' ⇒ unstream s'
   | Yield x s' ⇒ x # unstream s')"
by pat_completeness auto
termination
proof -
  have "terminates (generator g)" using generator[of g] by simp
  thus ?thesis by(rule terminates_wfE)(erule "termination")
qed

lemma unstream_simps [simp]:
  "generator g s = Done ⟹ unstream s = []"
  "generator g s = Skip s' ⟹ unstream s = unstream s'"
  "generator g s = Yield x s' ⟹ unstream s = x # unstream s'"
by(simp_all)

declare unstream.simps[simp del]

function force :: "'s ⇒ ('a × 's) option"
where
  "force s = (case generator g s of Done ⇒ None 
     | Skip s' ⇒ force s'
     | Yield x s' ⇒ Some (x, s'))"
by pat_completeness auto
termination
proof -
  have "terminates (generator g)" using generator[of g] by simp
  thus ?thesis by(rule terminates_wfE)(rule "termination")
qed

lemma force_simps [simp]:
  "generator g s = Done ⟹ force s = None"
  "generator g s = Skip s' ⟹ force s = force s'"
  "generator g s = Yield x s' ⟹ force s = Some (x, s')"
by(simp_all)

declare force.simps[simp del]

lemma unstream_force_None [simp]: "force s = None ⟹ unstream s = []"
proof(induction s rule: force.induct)
  case (1 s)
  thus ?case by(cases "generator g s") simp_all
qed

lemma unstream_force_Some [simp]: "force s = Some (x, s') ⟹ unstream s = x # unstream s'"
proof(induction s rule: force.induct)
  case (1 s)
  thus ?case by(cases "generator g s") simp_all
qed

end

setup ‹Context.theory_map (Stream_Fusion.add_unstream @{const_name unstream})›

subsection ‹Producers›

subsubsection ‹Conversion to streams›

(* Conversion from a list: the state is the remaining list. *)
fun stream_raw :: "'a list ⇒ ('a, 'a list) step"
where
  "stream_raw [] = Done"
| "stream_raw (x # xs) = Yield x xs"

lemma terminates_stream_raw: "terminates stream_raw"
proof (rule terminatesI)
  fix s :: "'a list"
  show "s ∈ terminates_on stream_raw"
    by(induction s)(auto intro: terminates_on.intros)
qed

lift_definition stream :: "('a, 'a list) generator" is "stream_raw" by(rule terminates_stream_raw)

lemma unstream_stream: "unstream stream xs = xs"
by(induction xs)(auto simp add: stream.rep_eq)

subsubsection ‹@{const replicate}›

(* Producer for replicate: the state counts the copies still to be yielded. *)
fun replicate_raw :: "'a ⇒ ('a, nat) raw_generator"
where
  "replicate_raw a 0 = Done"
| "replicate_raw a (Suc n) = Yield a n"

lemma terminates_replicate_raw: "terminates (replicate_raw a)"
proof (rule terminatesI)
  fix s :: "nat"
  show "s ∈ terminates_on (replicate_raw a)"
    by(induction s)(auto intro: terminates_on.intros)
qed

lift_definition replicate_prod :: "'a ⇒ ('a, nat) generator" is "replicate_raw"
by(rule terminates_replicate_raw)

lemma unstream_replicate_prod [stream_fusion]: "unstream (replicate_prod x) n = replicate n x"
by(induction n)(simp_all add: replicate_prod.rep_eq)

subsubsection ‹@{const upt}›

(* Producer for upt: the state is the next number; the generator stops once it reaches n. *)
definition upt_raw :: "nat ⇒ (nat, nat) raw_generator"
where "upt_raw n m = (if m ≥ n then Done else Yield m (Suc m))"

lemma terminates_upt_raw: "terminates (upt_raw n)"
proof (rule terminatesI)
  fix s :: nat
  show "s ∈ terminates_on (upt_raw n)"
    by(induction "n-s" arbitrary: s rule: nat.induct)(auto 4 3 simp add: upt_raw_def intro: terminates_on.intros)
qed

lift_definition upt_prod :: "nat ⇒ (nat, nat) generator" is "upt_raw" by(rule terminates_upt_raw)

lemma unstream_upt_prod [stream_fusion]: "unstream (upt_prod n) m = upt m n"
by(induction "n-m" arbitrary: n m)(simp_all add: upt_prod.rep_eq upt_conv_Cons upt_raw_def unstream.simps)


subsubsection ‹@{const upto}›

(* Producer for upto: yields the state m and moves to m + 1 as long as m ≤ n. *)
definition upto_raw :: "int ⇒ (int, int) raw_generator"
where "upto_raw n m = (if m ≤ n then Yield m (m + 1) else Done)"

lemma terminates_upto_raw: "terminates (upto_raw n)"
proof (rule terminatesI)
  fix s :: int
  show "s ∈ terminates_on (upto_raw n)"
    by(induction "nat(n-s+1)" arbitrary: s)(auto 4 3 simp add: upto_raw_def intro: terminates_on.intros)
qed

lift_definition upto_prod :: "int ⇒ (int, int) generator" is "upto_raw" by (rule terminates_upto_raw)

lemma unstream_upto_prod [stream_fusion]: "unstream (upto_prod n) m = upto m n"
by(induction "nat (n - m + 1)" arbitrary: m)(simp_all add: upto_prod.rep_eq upto.simps upto_raw_def)

subsubsection ‹@{term "[]"}›

(* Generator for the empty list: the state type is unit and every step is Done. *)
lift_definition Nil_prod :: "('a, unit) generator" is "λ_. Done"
by(auto simp add: terminates_def intro: terminates_on.intros)

lemma generator_Nil_prod: "generator Nil_prod = (λ_. Done)"
by(fact Nil_prod.rep_eq)

(* Producer rule: unstreaming Nil_prod from its only state gives the empty list. *)
lemma unstream_Nil_prod [stream_fusion]: "unstream Nil_prod () = []"
by(simp add: generator_Nil_prod)

subsection ‹Consumers›

subsubsection ‹@{const nth}›

(* Consumer for nth; running out of elements yields undefined n. *)
context fixes g :: "('a, 's) generator" begin

definition nth_cons :: "'s ⇒ nat ⇒ 'a" 
where [stream_fusion]: "nth_cons s n = unstream g s ! n"

lemma nth_cons_code [code]:
  "nth_cons s n =
  (case generator g s of Done => undefined n
    | Skip s' => nth_cons s' n
    | Yield x s' => (case n of 0 => x | Suc n' => nth_cons s' n'))"
by(cases "generator g s")(simp_all add: nth_cons_def nth_def split: nat.split)

end

subsubsection ‹@{term length}›

(* Consumers for length; gen_length_cons carries an accumulator like List.gen_length. *)
context fixes g :: "('a, 's) generator" begin

definition length_cons :: "'s ⇒ nat"
where "length_cons s = length (unstream g s)"

lemma length_cons_code [code]:
  "length_cons s =
    (case generator g s of
      Done ⇒ 0
    | Skip s' ⇒ length_cons s'
    | Yield a s' ⇒ 1 + length_cons s')"
by(cases "generator g s")(simp_all add: length_cons_def)

definition gen_length_cons :: "nat ⇒ 's ⇒ nat"
where "gen_length_cons n s = n + length (unstream g s)"

lemma gen_length_cons_code [code]:
  "gen_length_cons n s = (case generator g s of
     Done ⇒ n | Skip s' ⇒ gen_length_cons n s' | Yield a s' ⇒ gen_length_cons (Suc n) s')"
by(simp add: gen_length_cons_def split: step.split)

lemma unstream_gen_length [stream_fusion]: "gen_length_cons 0 s = length (unstream g s)"
by(simp add: gen_length_cons_def)

lemma unstream_gen_length2 [stream_fusion]: "gen_length_cons n s = List.gen_length n (unstream g s)"
by(simp add: List.gen_length_def gen_length_cons_def)

end

subsubsection ‹@{const foldr}›

(* Consumer for foldr with fixed function f and start value z. *)
context 
  fixes g :: "('a, 's) generator"
  and f :: "'a ⇒ 'b ⇒ 'b"
  and z :: "'b"
begin

definition foldr_cons :: "'s ⇒ 'b"
where [stream_fusion]: "foldr_cons s = foldr f (unstream g s) z"

lemma foldr_cons_code [code]:
  "foldr_cons s =
    (case generator g s of
      Done ⇒ z
    | Skip s' ⇒ foldr_cons s'
    | Yield a s' ⇒ f a (foldr_cons s'))"
by(cases "generator g s")(simp_all add: foldr_cons_def)

end

subsubsection ‹@{const foldl}›

(* Consumer for foldl; the accumulator z is threaded through the generator steps. *)
context
  fixes g :: "('b, 's) generator"
  and f :: "'a ⇒ 'b ⇒ 'a"
begin

definition foldl_cons :: "'a ⇒ 's ⇒ 'a"
where [stream_fusion]: "foldl_cons z s = foldl f z (unstream g s)"

lemma foldl_cons_code [code]:
  "foldl_cons z s =
    (case generator g s of
      Done ⇒ z
    | Skip s' ⇒ foldl_cons z s'
    | Yield a s' ⇒ foldl_cons (f z a) s')"
by (cases "generator g s")(simp_all add: foldl_cons_def)

end

subsubsection ‹@{const fold}›

(* Consumer for fold; the accumulator z is threaded through the generator steps. *)
context
  fixes g :: "('a, 's) generator"
  and f :: "'a ⇒ 'b ⇒ 'b"
begin

definition fold_cons :: "'b ⇒ 's ⇒ 'b"
where [stream_fusion]: "fold_cons z s = fold f (unstream g s) z"

lemma fold_cons_code [code]:
  "fold_cons z s =
    (case generator g s of
      Done ⇒ z
    | Skip s' ⇒ fold_cons z s'
    | Yield a s' ⇒ fold_cons (f a z) s')"
by (cases "generator g s")(simp_all add: fold_cons_def)

end

subsubsection ‹@{const List.null}›

(* Consumer for List.null: stops at the first Yield. *)
definition null_cons :: "('a, 's) generator ⇒ 's ⇒ bool"
where [stream_fusion]: "null_cons g s = List.null (unstream g s)"

lemma null_cons_code [code]:
  "null_cons g s = (case generator g s of Done ⇒ True | Skip s' ⇒ null_cons g s' | Yield _ _ ⇒ False)"
by(cases "generator g s")(simp_all add: null_cons_def null_def)

subsubsection ‹@{const hd}›

(* Consumer for hd; an empty stream gives undefined. *)
context fixes g :: "('a, 's) generator" begin

definition hd_cons :: "'s ⇒ 'a"
where [stream_fusion]: "hd_cons s = hd (unstream g s)"

lemma hd_cons_code [code]:
  "hd_cons s =
    (case generator g s of
      Done ⇒ undefined
    | Skip s' ⇒ hd_cons s'
    | Yield a s' ⇒ a)"
by (cases "generator g s")(simp_all add: hd_cons_def hd_def)

end

subsubsection ‹@{const last}›

(* Consumer for last: x holds the most recently yielded element (None before the first one). *)
context fixes g :: "('a, 's) generator" begin

definition last_cons :: "'a option ⇒ 's ⇒ 'a"
where "last_cons x s = (if unstream g s = [] then the x else last (unstream g s))"

lemma last_cons_code [code]:
  "last_cons x s =
  (case generator g s of Done ⇒ the x
             | Skip s' ⇒ last_cons x s'
             | Yield a s' ⇒ last_cons (Some a) s')"
by (cases "generator g s")(simp_all add: last_cons_def)

lemma unstream_last_cons [stream_fusion]: "last_cons None s = last (unstream g s)"
by (simp add: last_cons_def last_def option.the_def)

end

subsubsection ‹@{const sum_list}›

(* Consumer for sum_list over a monoid_add element type. *)
context fixes g :: "('a :: monoid_add, 's) generator" begin

definition sum_list_cons :: "'s ⇒ 'a"
where [stream_fusion]: "sum_list_cons s = sum_list (unstream g s)"

lemma sum_list_cons_code [code]:
  "sum_list_cons s =
    (case generator g s of
      Done ⇒ 0
    | Skip s' ⇒ sum_list_cons s'
    | Yield a s' ⇒ a + sum_list_cons s')"
by (cases "generator g s")(simp_all add: sum_list_cons_def)

end

subsubsection ‹@{const list_all2}›

(* Consumer for list_all2 over two generators.  list_all2_cons1 x sg' sh: x has already been
   pulled from g and still has to be paired with the next element of h. *)
context
  fixes g :: "('a, 's1) generator"
  and h :: "('b, 's2) generator"
  and P :: "'a ⇒ 'b ⇒ bool"
begin

definition list_all2_cons :: "'s1 ⇒ 's2 ⇒ bool"
where [stream_fusion]: "list_all2_cons sg sh = list_all2 P (unstream g sg) (unstream h sh)"

definition list_all2_cons1 :: "'a ⇒ 's1 ⇒ 's2 ⇒ bool"
where "list_all2_cons1 x sg' sh = list_all2 P (x # unstream g sg') (unstream h sh)"

lemma list_all2_cons_code [code]:
  "list_all2_cons sg sh = 
  (case generator g sg of
     Done ⇒ null_cons h sh
   | Skip sg' ⇒ list_all2_cons sg' sh
   | Yield a sg' ⇒ list_all2_cons1 a sg' sh)"
by(simp split: step.split add: list_all2_cons_def null_cons_def List.null_def list_all2_cons1_def)

lemma list_all2_cons1_code [code]:
  "list_all2_cons1 x sg' sh = 
  (case generator h sh of
     Done ⇒ False
   | Skip sh' ⇒ list_all2_cons1 x sg' sh'
   | Yield y sh' ⇒ P x y ∧ list_all2_cons sg' sh')"
by(simp split: step.split add: list_all2_cons_def null_cons_def List.null_def list_all2_cons1_def)

end

subsubsection ‹@{const list_all}›

(* Consumer for list_all. *)
context
  fixes g :: "('a, 's) generator"
  and P :: "'a ⇒ bool"
begin

definition list_all_cons :: "'s ⇒ bool"
where [stream_fusion]: "list_all_cons s = list_all P (unstream g s)"

lemma list_all_cons_code [code]:
  "list_all_cons s ⟷
  (case generator g s of
    Done ⇒ True | Skip s' ⇒ list_all_cons s' | Yield x s' ⇒ P x ∧ list_all_cons s')"
by(simp add: list_all_cons_def split: step.split)

end

subsubsection ‹@{const ord.lexordp}›

(* Consumers for lexicographic comparison of two streams.  The definitions are excluded from
   code generation; the code equations are registered after the context. *)
context ord begin

definition lexord_fusion :: "('a, 's1) generator ⇒ ('a, 's2) generator ⇒ 's1 ⇒ 's2 ⇒ bool"
where [code del]: "lexord_fusion g1 g2 s1 s2 = ord_class.lexordp (unstream g1 s1) (unstream g2 s2)"

definition lexord_eq_fusion :: "('a, 's1) generator ⇒ ('a, 's2) generator ⇒ 's1 ⇒ 's2 ⇒ bool"
where [code del]: "lexord_eq_fusion g1 g2 s1 s2 = lexordp_eq (unstream g1 s1) (unstream g2 s2)"

lemma lexord_fusion_code:
  "lexord_fusion g1 g2 s1 s2 ⟷
  (case generator g1 s1 of
     Done ⇒ ¬ null_cons g2 s2
   | Skip s1' ⇒ lexord_fusion g1 g2 s1' s2
   | Yield x s1' ⇒ 
     (case force g2 s2 of
        None ⇒ False
      | Some (y, s2') ⇒ x < y ∨ ¬ y < x ∧ lexord_fusion g1 g2 s1' s2'))"
unfolding lexord_fusion_def
by(cases "generator g1 s1" "force g2 s2" rule: step.exhaust[case_product option.exhaust])(auto simp add: null_cons_def null_def)

lemma lexord_eq_fusion_code:
  "lexord_eq_fusion g1 g2 s1 s2 ⟷
  (case generator g1 s1 of
     Done ⇒ True
   | Skip s1' ⇒ lexord_eq_fusion g1 g2 s1' s2
   | Yield x s1' ⇒
     (case force g2 s2 of
        None ⇒ False
      | Some (y, s2') ⇒ x < y ∨ ¬ y < x ∧ lexord_eq_fusion g1 g2 s1' s2'))"
unfolding lexord_eq_fusion_def
by(cases "generator g1 s1" "force g2 s2" rule: step.exhaust[case_product option.exhaust]) auto

end

(* Register the code equations and fusion rules from the ord context above, in both their
   unqualified and their ord.-qualified forms (the definitions themselves are [code del]). *)
lemmas [code] =
  lexord_fusion_code ord.lexord_fusion_code
  lexord_eq_fusion_code ord.lexord_eq_fusion_code

lemmas [stream_fusion] =
  lexord_fusion_def ord.lexord_fusion_def
  lexord_eq_fusion_def ord.lexord_eq_fusion_def

subsection ‹Transformers›

subsubsection ‹@{const map}›

(* Transformer for map: applies f to each yielded element; the state is unchanged. *)
definition map_raw :: "('a ⇒ 'b) ⇒ ('a, 's) raw_generator ⇒ ('b, 's) raw_generator"
where
  "map_raw f g s = (case g s of
     Done ⇒ Done
   | Skip s' ⇒ Skip s'
   | Yield a s' ⇒ Yield (f a) s')"

lemma terminates_map_raw: 
  assumes "terminates g"
  shows "terminates (map_raw f g)"
proof (rule terminatesI)
  fix s
  from assms
  have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "s ∈ terminates_on (map_raw f g)"
    by (induction s)(auto intro: terminates_on.intros simp add: map_raw_def)
qed

lift_definition map_trans :: "('a ⇒ 'b) ⇒ ('a, 's) generator ⇒ ('b, 's) generator" is "map_raw"
by (rule terminates_map_raw)

lemma unstream_map_trans [stream_fusion]: "unstream (map_trans f g) s = map f (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH" by (cases "generator g s")(simp_all add: map_trans.rep_eq map_raw_def)
qed

subsubsection ‹@{const drop}›

(* Transformer for drop: the state pairs the number of elements still to drop with g's state. *)
fun drop_raw :: "('a, 's) raw_generator ⇒ ('a, (nat × 's)) raw_generator"
where
  "drop_raw g (n, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (n, s')
   | Yield a s' ⇒ (case n of 0 ⇒ Yield a (0, s') | Suc n ⇒ Skip (n, s')))"

lemma terminates_drop_raw:
  assumes "terminates g"
  shows "terminates (drop_raw g)"
proof (rule terminatesI)
  fix st :: "nat × 'a"
  obtain n s where "st = (n, s)" by(cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "st ∈ terminates_on (drop_raw g)" unfolding ‹st = (n, s)›
    apply(induction arbitrary: n)
    apply(case_tac [!] n)
    apply(auto intro: terminates_on.intros)
    done
qed

lift_definition drop_trans :: "('a, 's) generator ⇒ ('a, nat × 's) generator" is "drop_raw"
by (rule terminates_drop_raw)

lemma unstream_drop_trans [stream_fusion]: "unstream (drop_trans g) (n, s) = drop n (unstream g s)"
proof (induction s arbitrary: n taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH"(1)[of _ n] "1.IH"(2)[of _ _ n] "1.IH"(2)[of _ _ "n - 1"]
    by(cases "generator g s" "n" rule: step.exhaust[case_product nat.exhaust])
      (simp_all add: drop_trans.rep_eq)
qed

subsubsection ‹@{const dropWhile}›

fun dropWhile_raw :: "('a ⇒ bool) ⇒ ('a, 's) raw_generator ⇒ ('a, bool × 's) raw_generator"
  ― ‹Boolean flag indicates whether we are still in dropping phase›
where
  "dropWhile_raw P g (True, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (True, s')
   | Yield a s' ⇒ (if P a then Skip (True, s') else Yield a (False, s')))"
| "dropWhile_raw P g (False, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (False, s') | Yield a s' ⇒ Yield a (False, s'))"

lemma terminates_dropWhile_raw:
  assumes "terminates g"
  shows "terminates (dropWhile_raw P g)"
proof (rule terminatesI)
  fix st :: "bool × 'a"
  obtain b s where "st = (b, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (dropWhile_raw P g)" unfolding ‹st = (b, s)›
  proof (induction s arbitrary: b)
    case (stop s b)
    then show ?case by (cases b)(simp_all add: terminates_on.stop)
  next
    case (pause s s' b)
    then show ?case by (cases b)(simp_all add: terminates_on.pause)
  next
    case (unfold s a s' b)
    then show ?case
      by(cases b)(cases "P a", auto intro: terminates_on.pause terminates_on.unfold)
  qed
qed

lift_definition dropWhile_trans :: "('a ⇒ bool) ⇒ ('a, 's) generator ⇒ ('a, bool × 's) generator"
is "dropWhile_raw" by (rule terminates_dropWhile_raw)

(* Once the dropping phase is over, the transformer just passes g through. *)
lemma unstream_dropWhile_trans_False:
  "unstream (dropWhile_trans P g) (False, s) = unstream g s"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by (cases "generator g s")(simp_all add: dropWhile_trans.rep_eq)
qed

lemma unstream_dropWhile_trans [stream_fusion]:
  "unstream (dropWhile_trans P g) (True, s) = dropWhile P (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof(cases "generator g s")
    case (Yield a s')
    then show ?thesis using "1.IH"(2) unstream_dropWhile_trans_False
      by (cases "P a")(simp_all add: dropWhile_trans.rep_eq)
  qed(simp_all add: dropWhile_trans.rep_eq)
qed

subsubsection ‹@{const take}›

(* Transformer for take: the state pairs the number of elements still to take with g's state. *)
fun take_raw :: "('a, 's) raw_generator ⇒ ('a, (nat × 's)) raw_generator"
where
  "take_raw g (0, s) = Done"
| "take_raw g (Suc n, s) = (case g s of 
     Done ⇒ Done | Skip s' ⇒ Skip (Suc n, s') | Yield a s' ⇒ Yield a (n, s'))"

lemma terminates_take_raw:
  assumes "terminates g"
  shows "terminates (take_raw g)"
proof (rule terminatesI)
  fix st :: "nat × 'a"
  obtain n s where "st = (n, s)" by(cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "st ∈ terminates_on (take_raw g)" unfolding ‹st = (n, s)›
    apply(induction s arbitrary: n)
    apply(case_tac [!] n)
    apply(auto intro: terminates_on.intros)
    done
qed

lift_definition take_trans :: "('a, 's) generator ⇒ ('a, nat × 's) generator" is "take_raw"
by (rule terminates_take_raw)

lemma unstream_take_trans [stream_fusion]: "unstream (take_trans g) (n, s) = take n (unstream g s)" 
proof (induction s arbitrary: n taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH"(1)[of _ n] "1.IH"(2)
    by(cases "generator g s" n rule: step.exhaust[case_product nat.exhaust])
      (simp_all add: take_trans.rep_eq)
qed

subsubsection ‹@{const takeWhile}›

(* Transformer for takeWhile: stops with Done at the first yielded element violating P. *)
definition takeWhile_raw :: "('a ⇒ bool) ⇒ ('a, 's) raw_generator ⇒ ('a, 's) raw_generator"
where
  "takeWhile_raw P g s = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip s' | Yield a s' ⇒ if P a then Yield a s' else Done)"

lemma terminates_takeWhile_raw: 
  assumes "terminates g"
  shows "terminates (takeWhile_raw P g)"
proof (rule terminatesI)
  fix s
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "s ∈ terminates_on (takeWhile_raw P g)"
  proof (induction s rule: terminates_on.induct)
    case (unfold s a s')
    then show ?case by(cases "P a")(auto simp add: takeWhile_raw_def intro: terminates_on.intros)
  qed(auto intro: terminates_on.intros simp add: takeWhile_raw_def)
qed

lift_definition takeWhile_trans :: "('a ⇒ bool) ⇒ ('a, 's) generator ⇒ ('a, 's) generator"
is "takeWhile_raw" by (rule terminates_takeWhile_raw)

lemma unstream_takeWhile_trans [stream_fusion]:
  "unstream (takeWhile_trans P g) s = takeWhile P (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by(cases "generator g s")(simp_all add: takeWhile_trans.rep_eq takeWhile_raw_def)
qed

subsubsection ‹@{const append}›

(* Transformer for append: the state is Inl while running g and Inr while running h;
   sh_start is the state in which h is started once g is Done. *)
fun append_raw :: "('a, 'sg) raw_generator ⇒ ('a, 'sh) raw_generator ⇒ 'sh ⇒ ('a, 'sg + 'sh) raw_generator"
where
  "append_raw g h sh_start (Inl sg) = (case g sg of
     Done ⇒ Skip (Inr sh_start) | Skip sg' ⇒ Skip (Inl sg') | Yield a sg' ⇒ Yield a (Inl sg'))"
| "append_raw g h sh_start (Inr sh) = (case h sh of
     Done ⇒ Done | Skip sh' ⇒ Skip (Inr sh') | Yield a sh' ⇒ Yield a (Inr sh'))"

lemma terminates_on_append_raw_Inr: 
  assumes "terminates h"
  shows "Inr sh ∈ terminates_on (append_raw g h sh_start)"
proof -
  from assms have "sh ∈ terminates_on h" by (simp add: terminates_def)
  thus ?thesis by(induction sh)(auto intro: terminates_on.intros)
qed

lemma terminates_append_raw:
  assumes "terminates g" "terminates h"
  shows "terminates (append_raw g h sh_start)"
proof (rule terminatesI)
  fix s
  show "s ∈ terminates_on (append_raw g h sh_start)"
  proof (cases s)
    case (Inl sg)
    from ‹terminates g› have "sg ∈ terminates_on g" by (simp add: terminates_def)
    thus "s ∈ terminates_on (append_raw g h sh_start)" unfolding Inl
      by induction(auto intro: terminates_on.intros terminates_on_append_raw_Inr[OF ‹terminates h›])
  qed(simp add: terminates_on_append_raw_Inr[OF ‹terminates h›])
qed

lift_definition append_trans :: "('a, 'sg) generator ⇒ ('a, 'sh) generator ⇒ 'sh ⇒ ('a, 'sg + 'sh) generator"
is "append_raw" by (rule terminates_append_raw)

lemma unstream_append_trans_Inr: "unstream (append_trans g h sh) (Inr sh') = unstream h sh'"
proof (induction sh' taking: h rule: unstream.induct)
  case (1 sh')
  then show ?case by (cases "generator h sh'")(simp_all add: append_trans.rep_eq)
qed

lemma unstream_append_trans [stream_fusion]:
  "unstream (append_trans g h sh) (Inl sg) = append (unstream g sg) (unstream h sh)"
proof(induction sg taking: g rule: unstream.induct)
  case (1 sg)
  then show ?case using unstream_append_trans_Inr 
    by (cases "generator g sg")(simp_all add: append_trans.rep_eq)
qed

subsubsection ‹@{const filter}›

(* Raw generator for filter: elements that fail P are replaced by a Skip, so the state
   space is that of g. *)
definition filter_raw :: "('a ⇒ bool) ⇒ ('a, 's) raw_generator ⇒ ('a, 's) raw_generator"
where 
  "filter_raw P g s = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip s' | Yield a s' ⇒ if P a then Yield a s' else Skip s')"

(* filter_raw follows g step by step, so it terminates whenever g does. *)
lemma terminates_filter_raw:
  assumes "terminates g"
  shows "terminates (filter_raw P g)"
proof (rule terminatesI)
  fix s
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "s ∈ terminates_on (filter_raw P g)"
  proof(induction s)
    case (unfold s a s')
    thus ?case
      by(cases "P a")(auto intro: terminates_on.intros simp add: filter_raw_def)
  qed(auto intro: terminates_on.intros simp add: filter_raw_def)
qed

(* Lift filter_raw to terminating generators. *)
lift_definition filter_trans :: "('a ⇒ bool) ⇒ ('a,'s) generator ⇒ ('a,'s) generator"
is "filter_raw" by (rule terminates_filter_raw)

(* Fusion rule for filter. *)
lemma unstream_filter_trans [stream_fusion]: "unstream (filter_trans P g) s = filter P (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by(cases "generator g s")(simp_all add: filter_trans.rep_eq filter_raw_def)
qed

subsubsection ‹@{const zip}›

(* Raw generator for zip. With None in the cache it pulls from g. A Yield a of g is stored
   as Some a (via a Skip), after which it pulls from h. A Yield b of h emits (a, b) and
   clears the cache. Done of either side ends the stream. *)
fun zip_raw :: "('a, 'sg) raw_generator ⇒ ('b, 'sh) raw_generator ⇒ ('a × 'b, 'sg × 'sh × 'a option) raw_generator"
  ― ‹We first search the left list for the next element and cache it in the @{typ "'a option"}
        part of the state once we have found one›
where
  "zip_raw g h (sg, sh, None) = (case g sg of
      Done ⇒ Done | Skip sg' ⇒ Skip (sg', sh, None) | Yield a sg' ⇒ Skip (sg', sh, Some a))"
| "zip_raw g h (sg, sh, Some a) = (case h sh of
      Done ⇒ Done | Skip sh' ⇒ Skip (sg, sh', Some a) | Yield b sh' ⇒ Yield (a, b) (sg, sh', None))"

(* zip_raw terminates if both g and h do. Each phase is handled by a nested induction over
   the termination of the generator that is currently being consulted. *)
lemma terminates_zip_raw: 
  assumes "terminates g" "terminates h"
  shows "terminates (zip_raw g h)"
proof (rule terminatesI)
  fix s :: "'a × 'c × 'b option"
  obtain sg sh m where "s = (sg, sh, m)" by(cases s)
  show "s ∈ terminates_on (zip_raw g h)" 
  proof(cases m)
    case None
    from ‹terminates g› have "sg ∈ terminates_on g" by (simp add: terminates_def)
    then show ?thesis unfolding ‹s = (sg, sh, m)› None
    proof (induction sg arbitrary: sh)
      case (unfold sg a sg')
      from ‹terminates h› have "sh ∈ terminates_on h" by (simp add: terminates_def)
      hence "(sg', sh, Some a) ∈ terminates_on (zip_raw g h)"
        by induction(auto intro: terminates_on.intros unfold.IH)
      thus ?case using unfold.hyps by(auto intro: terminates_on.pause)
    qed(simp_all add: terminates_on.stop terminates_on.pause)
  next
    case (Some a')
    from ‹terminates h› have "sh ∈ terminates_on h" by (simp add: terminates_def)
    thus ?thesis unfolding ‹s = (sg, sh, m)› Some
    proof (induction sh arbitrary: sg a')
      case (unfold sh b sh')
      from ‹terminates g› have "sg ∈ terminates_on g" by (simp add: terminates_def)
      hence "(sg, sh', None) ∈ terminates_on (zip_raw g h)"
        by induction(auto intro: terminates_on.intros unfold.IH)
      thus ?case using unfold.hyps by(auto intro: terminates_on.unfold)
    qed(simp_all add: terminates_on.stop terminates_on.pause)
  qed
qed

(* Lift zip_raw to terminating generators. *)
lift_definition zip_trans :: "('a, 'sg) generator ⇒ ('b, 'sh) generator ⇒ ('a × 'b,'sg × 'sh × 'a option) generator"
is "zip_raw" by (rule terminates_zip_raw)

(* Fusion rule for zip, starting with an empty cache. The inner induction handles the phase
   in which an element of g is cached. *)
lemma unstream_zip_trans [stream_fusion]:
  "unstream (zip_trans g h) (sg, sh, None) = zip (unstream g sg) (unstream h sh)"
proof (induction sg arbitrary: sh taking: g rule: unstream.induct)
  case (1 sg)
  then show ?case
  proof (cases "generator g sg")
    case (Yield a sg')
    note IH = "1.IH"(2)[OF Yield]
    have "unstream (zip_trans g h) (sg', sh, Some a) = zip (a # (unstream g sg')) (unstream h sh)"
    proof(induction sh taking: h rule: unstream.induct)
      case (1 sh)
      then show ?case using IH by(cases "generator h sh")(simp_all add: zip_trans.rep_eq)
    qed
    then show ?thesis using Yield by (simp add: zip_trans.rep_eq)
  qed(simp_all add: zip_trans.rep_eq)
qed

subsubsection ‹@{const tl}›

(* Raw generator for tl. While the flag is False, the first Yield of g is turned into a
   Skip that sets the flag. Afterwards, elements are passed through. *)
fun tl_raw :: "('a, 'sg) raw_generator ⇒ ('a, bool × 'sg) raw_generator"
  ― ‹The Boolean flag stores whether we have already skipped the first element›
where
  "tl_raw g (False, sg) = (case g sg of
      Done ⇒ Done | Skip sg' ⇒ Skip (False, sg') | Yield a sg' ⇒ Skip (True,sg'))"
| "tl_raw g (True, sg) = (case g sg of
      Done ⇒ Done | Skip sg' ⇒ Skip (True, sg') | Yield a sg' ⇒ Yield a (True, sg'))"

(* Termination is shown first for flag True (used as calculation), then for flag False. *)
lemma terminates_tl_raw: 
  assumes "terminates g"
  shows "terminates (tl_raw g)"
proof (rule terminatesI)
  fix s :: "bool × 'a"
  obtain b sg where "s = (b, sg)" by(cases s)
  { fix sg
    from assms have "sg ∈ terminates_on g" by(simp add: terminates_def)
    hence "(True, sg) ∈ terminates_on (tl_raw g)"
      by(induction sg)(auto intro: terminates_on.intros) }
  moreover from assms have "sg ∈ terminates_on g" by(simp add: terminates_def)
  hence "(False, sg) ∈ terminates_on (tl_raw g)"
    by(induction sg)(auto intro: terminates_on.intros calculation)
  ultimately show "s ∈ terminates_on (tl_raw g)" using ‹s = (b, sg)›
    by(cases b) simp_all
qed

(* Lift tl_raw to terminating generators. *)
lift_definition tl_trans :: "('a, 'sg) generator ⇒ ('a, bool × 'sg) generator"
is "tl_raw" by(rule terminates_tl_raw)

(* Once the first element has been dropped, the stream coincides with g's stream. *)
lemma unstream_tl_trans_True: "unstream (tl_trans g) (True, s) = unstream g s"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH" by (cases "generator g s")(simp_all add: tl_trans.rep_eq)
qed

(* Fusion rule for tl, starting with the flag unset. *)
lemma unstream_tl_trans [stream_fusion]: "unstream (tl_trans g) (False, s) = tl (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case using unstream_tl_trans_True
    by (cases "generator g s")(simp_all add: tl_trans.rep_eq)
qed

subsubsection ‹@{const butlast}›

(* Raw generator for butlast. Each element of g is held back in the option until the next
   element arrives, so the last element is never emitted. *)
fun butlast_raw :: "('a, 's) raw_generator ⇒ ('a, 'a option × 's) raw_generator"
  ― ‹The @{typ "'a option"} caches the previous element we have seen›
where
  "butlast_raw g (None,s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (None, s') | Yield a s' ⇒ Skip (Some a, s'))"
| "butlast_raw g (Some b, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (Some b, s') | Yield a s' ⇒ Yield b (Some a, s'))"

(* butlast_raw terminates whenever g does, for either cache content. *)
lemma terminates_butlast_raw:
  assumes "terminates g"
  shows "terminates (butlast_raw g)"
proof (rule terminatesI)
  fix st :: "'b option × 'a"
  obtain ma s where "st = (ma,s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (butlast_raw g)" unfolding ‹st = (ma, s)›
    apply(induction s arbitrary: ma)
    apply(case_tac [!] ma)
    apply(auto intro: terminates_on.intros)
    done
qed

(* Lift butlast_raw to terminating generators. *)
lift_definition butlast_trans :: "('a,'s) generator ⇒ ('a, 'a option × 's) generator"
is "butlast_raw" by (rule terminates_butlast_raw)

(* With b cached, the result is butlast of b prepended to g's stream. *)
lemma unstream_butlast_trans_Some:
  "unstream (butlast_trans g) (Some b,s) = butlast (b # (unstream g s))"
proof (induction s arbitrary: b taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by (cases "generator g s")(simp_all add: butlast_trans.rep_eq)
qed

(* Fusion rule for butlast, starting with an empty cache. *)
lemma unstream_butlast_trans [stream_fusion]:
  "unstream (butlast_trans g) (None, s) = butlast (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case using 1 unstream_butlast_trans_Some[of g]
    by (cases "generator g s")(simp_all add: butlast_trans.rep_eq)
qed

subsubsection ‹@{const concat}›

text ‹
  We only implement the easy version here, where
  the generator has type @{typ "('a list,'s) generator"}, not @{typ "(('a, 'si) generator, 's) generator"}.
›

(* Raw generator for concat. The list component buffers the remaining elements of the
   current inner list. While it is non-empty, its head is yielded. When it is empty, the
   next inner list is fetched from g (via a Skip). *)
fun concat_raw :: "('a list, 's) raw_generator ⇒ ('a, 'a list × 's) raw_generator"
where
  "concat_raw g ([], s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip ([], s') | Yield xs s' ⇒ Skip (xs, s'))"
| "concat_raw g (x # xs, s) = Yield x (xs, s)"

(* concat_raw terminates whenever g does: every buffered list is finite, and each buffer is
   drained by induction on its length. *)
lemma terminates_concat_raw: 
  assumes "terminates g"
  shows "terminates (concat_raw g)"
proof (rule terminatesI)
  fix st :: "'b list × 'a"
  obtain xs s where "st = (xs, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (concat_raw g)" unfolding ‹st = (xs, s)›
  proof (induction s arbitrary: xs)
    case (stop s xs)
    then show ?case by (induction xs)(auto intro: terminates_on.stop terminates_on.unfold)
  next
    case (pause s s' xs)
    then show ?case by (induction xs)(auto intro: terminates_on.pause terminates_on.unfold)
  next
    case (unfold s a s' xs)
    then show ?case by (induction xs)(auto intro: terminates_on.pause terminates_on.unfold)
  qed
qed

(* Lift concat_raw to terminating generators. *)
lift_definition concat_trans :: "('a list, 's) generator ⇒ ('a, 'a list × 's) generator"
is "concat_raw" by (rule terminates_concat_raw)

(* Generalised statement for an arbitrary buffered prefix xs. *)
lemma unstream_concat_trans_gen: "unstream (concat_trans g) (xs, s) = xs @ (concat (unstream g s))"
proof (induction s arbitrary: xs taking: g rule: unstream.induct)
  case (1 s)
  then show "unstream (concat_trans g) (xs, s) = xs @ (concat (unstream g s))"
  proof (cases "generator g s")
    case Done
    then show ?thesis by (induction xs)(simp_all add: concat_trans.rep_eq)
  next
    case (Skip s')
    then show ?thesis using "1.IH"(1)[of s' Nil]
      by (induction xs)(simp_all add: concat_trans.rep_eq)
  next
    case (Yield a s')
    then show ?thesis using "1.IH"(2)[of a s' a]
      by (induction xs)(simp_all add: concat_trans.rep_eq)
  qed
qed

(* Fusion rule for concat, starting with an empty buffer. *)
lemma unstream_concat_trans [stream_fusion]:
  "unstream (concat_trans g) ([], s) = concat (unstream g s)"
by(simp only: unstream_concat_trans_gen append_Nil)

subsubsection ‹@{const splice}›

(* State of the splice generator: Left/Right say whose turn it is while both inputs are
   alive; Left_only/Right_only are used once the other input is exhausted. *)
datatype ('a, 'b) splice_state = Left 'a 'b | Right 'a 'b | Left_only 'a | Right_only 'b

(* Raw generator for splice. It alternates between g and h after every Yield. When the
   generator whose turn it is reaches Done, only the other one is continued. *)
fun splice_raw :: "('a, 'sg) raw_generator ⇒ ('a, 'sh) raw_generator ⇒ ('a, ('sg, 'sh) splice_state) raw_generator"
where
  "splice_raw g h (Left_only sg) = (case g sg of
     Done ⇒ Done | Skip sg' ⇒ Skip (Left_only sg') | Yield a sg' ⇒ Yield a (Left_only sg'))"
| "splice_raw g h (Left sg sh) = (case g sg of
     Done ⇒ Skip (Right_only sh) | Skip sg' ⇒ Skip (Left sg' sh) | Yield a sg' ⇒ Yield a (Right sg' sh))"
| "splice_raw g h (Right_only sh) = (case h sh of
     Done ⇒ Done | Skip sh' ⇒ Skip (Right_only sh') | Yield a sh' ⇒ Yield a (Right_only sh'))"
| "splice_raw g h (Right sg sh) = (case h sh of
     Done ⇒ Skip (Left_only sg) | Skip sh' ⇒ Skip (Right sg sh') | Yield a sh' ⇒ Yield a (Left sg sh'))"

(* Termination for each of the four state shapes. The single-sided cases are proved first
   and reused through calculation in the alternating cases. *)
lemma terminates_splice_raw: 
  assumes g: "terminates g" and h: "terminates h"
  shows "terminates (splice_raw g h)"
proof (rule terminatesI)
  fix s
  { fix sg
    from g have "sg ∈ terminates_on g" by (simp add: terminates_def)
    hence "Left_only sg ∈ terminates_on (splice_raw g h)"
      by induction(auto intro: terminates_on.intros)
  } moreover {
    fix sh
    from h have "sh ∈ terminates_on h" by (simp add: terminates_def)
    hence "Right_only sh ∈ terminates_on (splice_raw g h)"
      by induction(auto intro: terminates_on.intros)
  } moreover {
    fix sg sh
    from g have "sg ∈ terminates_on g" by (simp add: terminates_def)
    hence "Left sg sh ∈ terminates_on (splice_raw g h)"
    proof (induction sg arbitrary: sh)
      case (unfold sg a sg')
      from h have "sh ∈ terminates_on h" by (simp add: terminates_def)
      hence "Right sg' sh ∈ terminates_on (splice_raw g h)"
        by induction(auto intro: terminates_on.intros unfold.IH calculation)
      thus ?case using unfold.hyps by (auto intro: terminates_on.unfold)
    qed(auto intro: terminates_on.intros calculation)
  } moreover {
    fix sg sh
    from h have "sh ∈ terminates_on h" by (simp add: terminates_def)
    hence "Right sg sh ∈ terminates_on (splice_raw g h)"
      by(induction sh arbitrary: sg)(auto intro: terminates_on.intros calculation) }
  ultimately show "s ∈ terminates_on (splice_raw g h)" by(cases s)(simp_all)
qed

(* Lift splice_raw to terminating generators. *)
lift_definition splice_trans :: "('a, 'sg) generator ⇒ ('a, 'sh) generator ⇒ ('a, ('sg, 'sh) splice_state) generator"
is "splice_raw" by (rule terminates_splice_raw)

(* Once only h is left, the stream is h's stream. *)
lemma unstream_splice_trans_Right_only: "unstream (splice_trans g h) (Right_only sh) = unstream h sh" 
proof (induction sh taking: h rule: unstream.induct)
  case (1 sh)
  then show ?case by (cases "generator h sh")(simp_all add: splice_trans.rep_eq)
qed

(* Once only g is left, the stream is g's stream. *)
lemma unstream_splice_trans_Left_only: "unstream (splice_trans g h) (Left_only sg) = unstream g sg"
proof (induction sg taking: g rule: unstream.induct)
  case (1 sg)
  then show ?case by (cases "generator g sg")(simp_all add: splice_trans.rep_eq)
qed

(* Fusion rule for splice, starting with g's turn. *)
lemma unstream_splice_trans [stream_fusion]:
  "unstream (splice_trans g h) (Left sg sh) = splice (unstream g sg) (unstream h sh)"
proof (induction sg arbitrary: sh taking: g rule: unstream.induct)
  case (1 sg sh)
  then show ?case
  proof (cases "generator g sg")
    case Done
    with unstream_splice_trans_Right_only[of g h]
    show ?thesis by (simp add: splice_trans.rep_eq)
  next
    case (Skip sg')
    then show ?thesis using "1.IH"(1) by (simp add: splice_trans.rep_eq)
  next
    case (Yield a sg')
    note IH = "1.IH"(2)[OF Yield]

    have "a # (unstream (splice_trans g h) (Right sg' sh)) = splice (unstream g sg) (unstream h sh)"
    proof (induction sh taking: h rule: unstream.induct)
      case (1 sh)
      show ?case
      proof (cases "generator h sh")
        case Done
        with unstream_splice_trans_Left_only[of g h sg']
        show ?thesis using Yield by (simp add: splice_trans.rep_eq)
      next
        case (Skip sh')
        then show ?thesis using Yield "1.IH"(1) "1.prems" by(simp add: splice_trans.rep_eq)
      next
        case (Yield b sh')
        then show ?thesis using IH ‹generator g sg = Yield a sg'›
          by (simp add: splice_trans.rep_eq)
      qed
    qed
    then show ?thesis using Yield by (simp add: splice_trans.rep_eq)
  qed
qed


subsubsection ‹@{const list_update}›

(* Raw generator for list_update. The counter n is decremented on every Yield. The element
   emitted while n = 1 is replaced by b. When n = 0, elements pass through unchanged. The
   fusion rule below therefore starts with Suc n to update index n. *)
fun list_update_raw :: "('a,'s) raw_generator ⇒ 'a ⇒ ('a, nat × 's) raw_generator"
where
  "list_update_raw g b (n, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (n, s') 
   | Yield a s' ⇒ if n = 0 then Yield a (0,s')
                   else if n = 1 then Yield b (0, s')
                   else Yield a (n - 1, s'))"

(* list_update_raw follows g step by step, so it terminates whenever g does. *)
lemma terminates_list_update_raw:
  assumes "terminates g"
  shows "terminates (list_update_raw g b)"
proof (rule terminatesI)
  fix st :: "nat × 'a"
  obtain n s where "st = (n, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (list_update_raw g b)" unfolding ‹st = (n, s)›
  proof (induction s arbitrary: n)
    case (unfold s a s' n)
    then show "(n, s) ∈ terminates_on (list_update_raw g b)"
      by(cases "n = 0 ∨ n = 1")(auto intro: terminates_on.unfold)
  qed(simp_all add: terminates_on.stop terminates_on.pause)
qed

(* Lift list_update_raw to terminating generators. *)
lift_definition list_update_trans :: "('a,'s) generator ⇒ 'a ⇒ ('a, nat × 's)  generator"
is "list_update_raw" by (rule terminates_list_update_raw)

(* With the counter at 0, g's elements pass through unchanged.
   NOTE: despite "lift" and "None" in its name, this lemma concerns list_update_trans with
   counter 0; the name is kept because it is referenced below. *)
lemma unstream_lift_update_trans_None: "unstream (list_update_trans g b) (0, s) = unstream g s"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by (cases "generator g s")(simp_all add: list_update_trans.rep_eq)
qed

(* Fusion rule for list_update; the counter Suc n addresses index n. *)
lemma unstream_list_update_trans [stream_fusion]:
  "unstream (list_update_trans g b) (Suc n, s) = list_update (unstream g s) n b"
proof(induction s arbitrary: n taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof (cases "generator g s")
    case Done
    then show ?thesis by (simp add: list_update_trans.rep_eq)
  next
    case (Skip s')
    then show ?thesis using "1.IH"(1) by (simp add: list_update_trans.rep_eq)
  next
    case (Yield a s')
    then show ?thesis using unstream_lift_update_trans_None[of g b s'] "1.IH"(2) 
      by (cases n)(simp_all add: list_update_trans.rep_eq)
  qed
qed

subsubsection ‹@{const removeAll}›

(* Raw generator for removeAll: elements equal to b are replaced by a Skip. *)
definition removeAll_raw :: "'a ⇒ ('a, 's) raw_generator ⇒ ('a, 's) raw_generator"
where
 "removeAll_raw b g s = (case g s of
    Done ⇒ Done | Skip s' ⇒ Skip s' | Yield a s' ⇒ if a = b then Skip s' else Yield a s')"

(* removeAll_raw follows g step by step, so it terminates whenever g does. *)
lemma terminates_removeAll_raw:
  assumes "terminates g"
  shows "terminates (removeAll_raw b g)"
proof (rule terminatesI)
  fix s
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "s ∈ terminates_on (removeAll_raw b g)"
  proof(induction s)
    case (unfold s a s')
    then show ?case
      by(cases "a = b")(auto intro: terminates_on.intros simp add: removeAll_raw_def)
  qed(auto intro: terminates_on.intros simp add: removeAll_raw_def)
qed

(* Lift removeAll_raw to terminating generators. *)
lift_definition removeAll_trans :: "'a ⇒ ('a, 's) generator ⇒ ('a, 's) generator"
is "removeAll_raw" by (rule terminates_removeAll_raw)

(* Fusion rule for removeAll. *)
lemma unstream_removeAll_trans [stream_fusion]:
  "unstream (removeAll_trans b g) s = removeAll b (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof(cases "generator g s")
    case (Yield a s')
    then show ?thesis using "1.IH"(2)
      by(cases "a = b")(simp_all add: removeAll_trans.rep_eq removeAll_raw_def)
  qed(auto simp add: removeAll_trans.rep_eq removeAll_raw_def)
qed

subsubsection ‹@{const remove1}›

(* Raw generator for remove1. The flag is True while the first occurrence of x has not yet
   been removed. The first y = x seen with the flag set is skipped and the flag cleared. *)
fun remove1_raw :: "'a ⇒ ('a, 's) raw_generator ⇒ ('a, bool × 's) raw_generator"
where
  "remove1_raw x g (b, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (b, s') 
   | Yield y s' ⇒ if b ∧ x = y then Skip (False, s') else Yield y (b, s'))"

(* remove1_raw follows g step by step, so it terminates whenever g does. *)
lemma terminates_remove1_raw: 
  assumes "terminates g"
  shows "terminates (remove1_raw b g)"
proof (rule terminatesI)
  fix st :: "bool × 'a"
  obtain c s where "st = (c, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (remove1_raw b g)" unfolding ‹st = (c, s)›
  proof (induction s arbitrary: c)
    case (stop s)
    then show ?case by (cases c)(simp_all add: terminates_on.stop)
  next
    case (pause s s')
    then show ?case by (cases c)(simp_all add: terminates_on.pause)
  next
    case (unfold s a s')
    then show ?case
      by(cases c)(cases "a = b", auto intro: terminates_on.intros)
  qed
qed

(* Lift remove1_raw to terminating generators. *)
lift_definition remove1_trans :: "'a ⇒ ('a, 's) generator ⇒ ('a, bool × 's) generator "
is "remove1_raw" by (rule terminates_remove1_raw)

(* After the removal (flag False), the stream coincides with g's stream. *)
lemma unstream_remove1_trans_False: "unstream (remove1_trans b g) (False, s) = unstream g s"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by (cases "generator g s")(simp_all add: remove1_trans.rep_eq)
qed

(* Fusion rule for remove1, starting with the flag set. *)
lemma unstream_remove1_trans [stream_fusion]:
  "unstream (remove1_trans b g) (True, s) = remove1 b (unstream g s)"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof (cases "generator g s")
    case (Yield a s')
    then show ?thesis
      using Yield "1.IH"(2) unstream_remove1_trans_False[of b g]
      by (cases "a = b")(simp_all add: remove1_trans.rep_eq)
  qed(simp_all add: remove1_trans.rep_eq)
qed

subsubsection ‹@{term "(#)"}›

(* Raw generator for Cons. With the flag True, x is yielded first and the flag cleared.
   Afterwards, g's steps are passed through with the flag False. *)
fun Cons_raw :: "'a ⇒ ('a, 's) raw_generator ⇒ ('a, bool × 's) raw_generator"
where
  "Cons_raw x g (b, s) = (if b then Yield x (False, s) else case g s of
    Done ⇒ Done | Skip s' ⇒ Skip (False, s') | Yield y s' ⇒ Yield y (False, s'))"

(* Termination is shown for flag False by induction; flag True takes one extra Yield. *)
lemma terminates_Cons_raw: 
  assumes "terminates g"
  shows "terminates (Cons_raw x g)"
proof (rule terminatesI)
  fix st :: "bool × 'a"
  obtain b s where "st = (b, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  hence "(False, s) ∈ terminates_on (Cons_raw x g)"
    by(induction s arbitrary: b)(auto intro: terminates_on.intros)
  then show "st ∈ terminates_on (Cons_raw x g)" unfolding ‹st = (b, s)›
    by(cases b)(auto intro: terminates_on.intros)
qed

(* Lift Cons_raw to terminating generators. *)
lift_definition Cons_trans :: "'a ⇒ ('a, 's) generator ⇒ ('a, bool × 's) generator"
is Cons_raw by(rule terminates_Cons_raw)

(* With the flag cleared, the stream coincides with g's stream. *)
lemma unstream_Cons_trans_False: "unstream (Cons_trans x g) (False, s) = unstream g s"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by(cases "generator g s")(auto simp add: Cons_trans.rep_eq)
qed

text ‹
  We do not declare @{const Cons_trans} as a transformer.
  Otherwise, literal lists would be transformed into streams, which would add significant overhead
  to the stream state.
›
lemma unstream_Cons_trans: "unstream (Cons_trans x g) (True, s) = x # unstream g s"
using unstream_Cons_trans_False[of x g s] by(simp add: Cons_trans.rep_eq)

subsubsection ‹@{const List.maps}›

text ‹Stream version based on Coutts \cite{Coutts2010PhD}.›

text ‹
  We restrict the function for generating the inner lists to terminating
  generators because the code generator does not directly support nesting abstract
  datatypes in other types.
›

(* Raw generator for List.maps. In state (s, None), the outer generator g is advanced. A
   Yield x installs f x, an inner generator together with its start state. The inner
   generator then runs until Done, which returns to (s, None). *)
fun maps_raw
  :: "('a ⇒ ('b, 'sg) generator × 'sg) ⇒ ('a, 's) raw_generator
  ⇒ ('b, 's × (('b, 'sg) generator × 'sg) option) raw_generator"
where
  "maps_raw f g (s, None) = (case g s of
    Done ⇒ Done | Skip s' ⇒ Skip (s', None) | Yield x s' ⇒ Skip (s', Some (f x)))"
| "maps_raw f g (s, Some (g'', s'')) = (case generator g'' s'' of
    Done ⇒ Skip (s, None) | Skip s' ⇒ Skip (s, Some (g'', s')) | Yield x s' ⇒ Yield x (s, Some (g'', s')))"

(* Inner generators are terminating by type (the generator fact), so a state with an active
   inner generator terminates if the corresponding outer state does. *)
lemma terminates_on_maps_raw_Some: 
  assumes "(s, None) ∈ terminates_on (maps_raw f g)"
  shows "(s, Some (g'', s'')) ∈ terminates_on (maps_raw f g)"
proof -
  from generator[of g''] have "s'' ∈ terminates_on (generator g'')" by (simp add: terminates_def)
  thus ?thesis by(induction)(auto intro: terminates_on.intros assms)
qed

(* maps_raw terminates whenever the outer generator g does. *)
lemma terminates_maps_raw: 
  assumes "terminates g"
  shows "terminates (maps_raw f g)"
proof
  fix st :: "'a × (('c, 'd) generator × 'd) option"
  obtain s mgs where "st = (s, mgs)" by(cases st) 
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (maps_raw f g)" unfolding ‹st = (s, mgs)›
    apply(induction arbitrary: mgs)
    apply(case_tac [!] mgs)
    apply(auto intro: terminates_on.intros intro!: terminates_on_maps_raw_Some)
    done
qed

(* Lift maps_raw to terminating generators. *)
lift_definition maps_trans :: "('a ⇒ ('b, 'sg) generator × 'sg) ⇒ ('a, 's) generator
  ⇒ ('b, 's × (('b, 'sg) generator × 'sg) option) generator"
is "maps_raw" by(rule terminates_maps_raw)

(* While an inner generator is active, its stream is emitted first. *)
lemma unstream_maps_trans_Some:
  "unstream (maps_trans f g) (s, Some (g'', s'')) = unstream g'' s'' @ unstream (maps_trans f g) (s, None)"
proof(induction s'' taking: g'' rule: unstream.induct)
  case (1 s'')
  then show ?case by(cases "generator g'' s''")(simp_all add: maps_trans.rep_eq)
qed

(* Correctness of maps_trans. It is deliberately not declared [stream_fusion]. *)
lemma unstream_maps_trans:
  "unstream (maps_trans f g) (s, None) = List.maps (case_prod unstream ∘ f) (unstream g s)"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof(cases "generator g s")
    case (Yield x s')
    with "1.IH"(2)[OF this] show ?thesis
      using unstream_maps_trans_Some[of f g _ "fst (f x)" "snd (f x)"]
      by(simp add: maps_trans.rep_eq maps_simps split_def)
  qed(simp_all add: maps_trans.rep_eq maps_simps)
qed

text ‹
  The rule @{thm [source] unstream_maps_trans} is too complicated for fusion because of @{term case_prod},
  which does not arise naturally from stream fusion rules. Moreover, according to Farmer et al.
  \cite{FarmerHoenerGill2014PEPM}, this fusion is too general for further optimisations because the
  generators of the inner list are generated by the outer generator and therefore compilers may
  think that they are not known statically.

  Instead, they propose a weaker version using ‹flatten› below.
  (More precisely, Coutts already mentions this approach in his PhD thesis \cite{Coutts2010PhD},
  but dismisses it because it requires a stronger rewriting engine than GHC has. But Isabelle's
  simplifier language is sufficiently powerful.)
›

(* Tag the successor state of a step with the fixed parameter a. *)
fun fix_step :: "'a ⇒ ('b, 's) step ⇒ ('b, 'a × 's) step"
where
  "fix_step a Done = Done"
| "fix_step a (Skip s) = Skip (a, s)"
| "fix_step a (Yield x s) = Yield x (a, s)"

(* Turn a parameterised family of generators g into a single generator whose state
   carries the parameter a alongside the state of g a. *)
fun fix_gen_raw :: "('a ⇒ ('b, 's) raw_generator) ⇒ ('b, 'a × 's) raw_generator"
where "fix_gen_raw g (a, s) = fix_step a (g a s)"

(* fix_gen_raw terminates if every member of the family does. *)
lemma terminates_fix_gen_raw:
  assumes "⋀x. terminates (g x)"
  shows "terminates (fix_gen_raw g)"
proof
  fix st :: "'a × 'b"
  obtain a s where "st = (a, s)" by(cases st)
  from assms[of a] have "s ∈ terminates_on (g a)" by (simp add: terminates_def)
  then show "st ∈ terminates_on (fix_gen_raw g)" unfolding ‹st = (a, s)›
    by(induction)(auto intro: terminates_on.intros)
qed

(* Lift fix_gen_raw to terminating generators. *)
lift_definition fix_gen :: "('a ⇒ ('b, 's) generator) ⇒ ('b, 'a × 's) generator"
is "fix_gen_raw" by(rule terminates_fix_gen_raw)

(* Unstreaming fix_gen from (a, s) is unstreaming the member g a from s. *)
lemma unstream_fix_gen: "unstream (fix_gen g) (a, s) = unstream (g a) s"
proof(induction s taking: "g a" rule: unstream.induct)
  case (1 s)
  thus ?case by(cases "generator (g a) s")(simp_all add: fix_gen.rep_eq)
qed

(* Raw flatten with a fixed inner generator g''. f maps each element of the outer
   generator g to a start state of g''. *)
context 
  fixes f :: "('a ⇒ 's')"
  and g'' :: "('b, 's') raw_generator"
  and g :: "('a, 's) raw_generator"
begin

(* In state (s, None), g is advanced; Yield x starts the inner generator at f x. In state
   (s, Some s''), g'' is advanced until Done, which returns to (s, None). *)
fun flatten_raw :: "('b, 's × 's' option) raw_generator"
where
  "flatten_raw (s, None) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (s', None) | Yield x s' ⇒ Skip (s', Some (f x)))"
| "flatten_raw (s, Some s'') = (case g'' s'' of
     Done ⇒ Skip (s, None) | Skip s' ⇒ Skip (s, Some s') | Yield x s' ⇒ Yield x (s, Some s'))"

(* flatten_raw terminates if both g'' and g do. States with an active inner generator are
   handled by the local fact Some. *)
lemma terminates_flatten_raw: 
  assumes "terminates g''" "terminates g"
  shows "terminates flatten_raw"
proof
  fix st :: "'s × 's' option"
  obtain s ms where "st = (s, ms)" by(cases st)
  { fix s s''
    assume s: "(s, None) ∈ terminates_on flatten_raw"
    from ‹terminates g''› have "s'' ∈ terminates_on g''" by (simp add: terminates_def)
    hence "(s, Some s'') ∈ terminates_on flatten_raw"
      by(induction)(auto intro: terminates_on.intros s) }
  note Some = this
  from ‹terminates g› have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on flatten_raw" unfolding ‹st = (s, ms)›
    apply(induction arbitrary: ms)
    apply(case_tac [!] ms)
    apply(auto intro: terminates_on.intros intro!: Some)
    done
qed

end

(* Lift flatten_raw to terminating generators. *)
lift_definition flatten :: "('a ⇒ 's') ⇒ ('b, 's') generator ⇒ ('a, 's) generator ⇒ ('b, 's × 's' option) generator"
is "flatten_raw" by(fact terminates_flatten_raw)

(* While the inner generator g'' is active (state Some s'), its stream is emitted before
   the rest of the flattened stream. *)
lemma unstream_flatten_Some:
  "unstream (flatten f g'' g) (s, Some s') = unstream g'' s' @ unstream (flatten f g'' g) (s, None)"
proof(induction s' taking: g'' rule: unstream.induct)
  case (1 s')
  thus ?case by(cases "generator g'' s'")(simp_all add: flatten.rep_eq)
qed

text ‹HO rewrite equations can express the variable capture in the generator unlike GHC rules›

(* Fusion rule for flatten when the inner generator depends on the outer element:
   fix_gen pairs each outer element s with the inner state f s, and g'' s is the inner
   generator used for it. *)
lemma unstream_flatten_fix_gen [stream_fusion]:
  "unstream (flatten (λs. (s, f s)) (fix_gen g'') g) (s, None) =
   List.maps (λs'. unstream (g'' s') (f s')) (unstream g s)"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof(cases "generator g s")
    case (Yield x s')
    with "1.IH"(2)[OF this] unstream_flatten_Some[of "λs. (s, f s)" "fix_gen g''" g]
    show ?thesis
      by(subst (1 3) unstream.simps)(simp add: flatten.rep_eq maps_simps unstream_fix_gen)
  qed(simp_all add: flatten.rep_eq maps_simps)
qed

text ‹
  Separate fusion rule when the inner generator does not depend on the elements of the outer stream.
›
(* Fusion rule for flatten with a fixed inner generator g''; f gives its start state for
   each outer element. *)
lemma unstream_flatten [stream_fusion]:
  "unstream (flatten f g'' g) (s, None) = List.maps (λs'. unstream g'' (f s')) (unstream g s)"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case 
  proof(cases "generator g s")
    case (Yield x s')
    with "1.IH"(2)[OF this] show ?thesis
      using unstream_flatten_Some[of f g'' g s' "f x"]
      by(simp add: flatten.rep_eq maps_simps o_def)
  qed(simp_all add: maps_simps flatten.rep_eq)
qed

end

Theory Stream_Fusion_LList

(* Title: Stream_Fusion_LList
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Stream fusion for coinductive lists›

theory Stream_Fusion_LList imports
  Stream_Fusion_List
  Coinductive.Coinductive_List
begin

text ‹
  There are two choices of how many @{const Skip}s may occur consecutively.
  \begin{itemize}
  \item A generator for @{typ "'a llist"} may return only finitely many @{const Skip}s before
    it has to decide on a @{const Done} or @{const Yield}. Then, we can define stream versions
    for all functions that can be defined by corecursion up-to. This in particular excludes
    @{const lfilter}. Moreover, we have to prove that every generator satisfies this
    restriction.
  \item A generator for @{typ "'a llist"} may return infinitely many @{const Skip}s in a row.
    Then, the ‹lunstream› function suffers from the same definitional difficulties as
    @{const lfilter}, but we can define it using the least fixpoint approach described in
    \cite{LochbihlerHoelzl2014ITP}. Consequently, we can only fuse transformers that are monotone and
    continuous with respect to the ccpo ordering. This in particular excludes @{const lappend}.
  \end{itemize}
  Here, we take both approaches, but consider the first preferable to the second.
  Consequently, we define producers such that they produce generators of the first kind, if possible.
  There will be multiple equations for transformers and consumers that deal with all the different
  combinations for their parameter generators. Transformers should yield generators of the first
  kind whenever possible. Consumers can be defined using ‹lunstream› and refined with custom
  code equations, i.e., they can operate with infinitely many ‹Skip›s in a row. We just
  have to lift the fusion equation to the first kind, too.
›

(* Generators for lazy lists; unlike the list generators, no termination is required. *)
type_synonym ('a, 's) lgenerator = "'s ⇒ ('a, 's) step"

(* States from which g reaches a Done or a Yield after finitely many Skips. *)
inductive_set productive_on :: "('a, 's) lgenerator ⇒ 's set"
for g :: "('a, 's) lgenerator"
where
  Done: "g s = Done ⟹ s ∈ productive_on g"
| Skip: "⟦ g s = Skip s'; s' ∈ productive_on g ⟧ ⟹ s ∈ productive_on g"
| Yield: "g s = Yield x s' ⟹ s ∈ productive_on g"

(* A generator is productive if every state is productive. *)
definition productive :: "('a, 's) lgenerator ⇒ bool"
where "productive g ⟷ productive_on g = UNIV"

lemma productiveI [intro?]:
  "(⋀s. s ∈ productive_on g) ⟹ productive g"
by(auto simp add: productive_def)

lemma productive_onI [dest?]: "productive g ⟹ s ∈ productive_on g"
by(simp add: productive_def)

text ‹A type of generators that eventually yield something other than a @{const Skip}.›

(* Non-emptiness is witnessed by the generator that is immediately Done. *)
typedef ('a, 's) lgenerator' = "{g :: ('a, 's) lgenerator. productive g}"
  morphisms lgenerator Abs_lgenerator'
proof
  show "(λ_. Done) ∈ ?lgenerator'" by(auto intro: productive_on.intros productiveI)
qed

setup_lifting type_definition_lgenerator'

subsection ‹Conversions to @{typ "'a llist"}›

subsubsection ‹Infinitely many consecutive @{term Skip}s›

(* lunstream is defined with partial_function in the llist ccpo, so generators that Skip
   infinitely often are admitted. *)
context fixes g :: "('a, 's) lgenerator"
  notes [[function_internals]]
begin

partial_function (llist) lunstream :: "'s ⇒ 'a llist"
where
  "lunstream s = (case g s of 
     Done ⇒ LNil | Skip s' ⇒ lunstream s' | Yield x s' ⇒ LCons x (lunstream s'))"

declare lunstream.simps[code]

(* Conditional unfolding equations, one per step constructor. *)
lemma lunstream_simps:
  "g s = Done ⟹ lunstream s = LNil"
  "g s = Skip s' ⟹ lunstream s = lunstream s'"
  "g s = Yield x s' ⟹ lunstream s = LCons x (lunstream s')"
by(simp_all add: lunstream.simps)

(* Selector equations for lunstream, e.g. for coinduction proofs. *)
lemma lunstream_sels:
  shows lnull_lunstream: "lnull (lunstream s) ⟷ 
  (case g s of Done ⇒ True | Skip s' ⇒ lnull (lunstream s') | Yield _ _ ⇒ False)"
  and lhd_lunstream: "lhd (lunstream s) =
  (case g s of Skip s' ⇒ lhd (lunstream s') | Yield x _ ⇒ x)"
  and ltl_lunstream: "ltl (lunstream s) =
  (case g s of Done ⇒ LNil | Skip s' ⇒ ltl (lunstream s') | Yield _ s' ⇒ lunstream s')"
by(simp_all add: lhd_def lunstream_simps split: step.split)

end

subsubsection ‹Finitely many consecutive @{term Skip}s›

(* lunstream' is lunstream restricted to productive generators (lgenerator'),
   obtained by lifting lunstream through the subtype. *)
lift_definition lunstream' :: "('a, 's) lgenerator'  's  'a llist"
is lunstream .

(* Unfolding equations for lunstream', transferred from lunstream_simps. *)
lemma lunstream'_simps:
  "lgenerator g s = Done  lunstream' g s = LNil"
  "lgenerator g s = Skip s'  lunstream' g s = lunstream' g s'"
  "lgenerator g s = Yield x s'  lunstream' g s = LCons x (lunstream' g s')"
by(transfer, simp add: lunstream_simps)+

(* Selector equations for lunstream', transferred from lunstream_sels. *)
lemma lunstream'_sels:
  shows lnull_lunstream': "lnull (lunstream' g s)  
  (case lgenerator g s of Done  True | Skip s'  lnull (lunstream' g s') | Yield _ _  False)"
  and lhd_lunstream': "lhd (lunstream' g s) =
  (case lgenerator g s of Skip s'  lhd (lunstream' g s') | Yield x _  x)"
  and ltl_lunstream': "ltl (lunstream' g s) =
  (case lgenerator g s of Done  LNil | Skip s'  ltl (lunstream' g s') | Yield _ s'  lunstream' g s')"
by(transfer, simp add: lunstream_sels)+

(* Register both lunstream and lunstream' with the stream fusion simproc as unstream
   constants: Stream_Fusion.add_unstream is folded over the two constant names and the
   resulting Context.generic update is applied to the theory via Context.theory_map.
   The ML source must be enclosed in a cartouche. *)
setup ‹Context.theory_map (fold
  Stream_Fusion.add_unstream [@{const_name lunstream}, @{const_name lunstream'}])›

subsection ‹Producers›

subsubsection ‹Conversion to streams›

(* lstream streams a lazy list: the state is the remaining list, LNil gives Done and
   LCons x xs yields x and continues with xs.  It never produces Skip. *)
fun lstream :: "('a, 'a llist) lgenerator"
where
  "lstream LNil = Done"
| "lstream (LCons x xs) = Yield x xs"

(* A case distinction on lstream xs is a case distinction on xs itself; used in the
   continuity proof of lunstream lstream. *)
lemma case_lstream_conv_case_llist:
  "(case lstream xs of Done  done | Skip xs'  skip xs' | Yield x xs'  yield x xs') =
   (case xs of LNil  done | LCons x xs'  yield x xs')"
by(simp split: llist.split)

(* lunstream lstream is monotone and continuous for lprefix/lSup.  Through
   llist.mcont2mcont the result is registered as simp and cont_intro rule. *)
lemma mcont2mcont_lunstream[THEN llist.mcont2mcont, simp, cont_intro]:
  shows mcont_lunstream: "mcont lSup lprefix lSup lprefix (lunstream lstream)"
by(rule llist.fixp_preserves_mcont1[OF lunstream.mono lunstream_def])(simp add: case_lstream_conv_case_llist)

(* Streaming a lazy list with lstream and unstreaming it again gives back the list. *)
lemma lunstream_lstream: "lunstream lstream xs = xs"
by(induction xs)(simp_all add: lunstream_simps)

(* lstream as a productive generator: every state is productive because lstream only
   returns Done or Yield. *)
lift_definition lstream' :: "('a, 'a llist) lgenerator'"
is lstream
proof
  fix s :: "'a llist"
  show "s  productive_on lstream" by(cases s)(auto intro: productive_on.intros)
qed

(* lunstream_lstream transferred to the productive variant. *)
lemma lunstream'_lstream: "lunstream' lstream' xs = xs"
by(transfer)(rule lunstream_lstream)

subsubsection ‹@{const iterates}›

(* Generator for iterates f: the state is the current value, which is yielded before
   continuing with f applied to it. *)
definition iterates_raw :: "('a  'a)  ('a, 'a) lgenerator"
where "iterates_raw f s = Yield s (f s)"

(* Unstreaming iterates_raw f from x gives iterates f x (proof by coinduction). *)
lemma lunstream_iterates_raw: "lunstream (iterates_raw f) x = iterates f x"
by(coinduction arbitrary: x)(auto simp add: iterates_raw_def lunstream_sels)

(* iterates_raw f is productive since it always yields. *)
lift_definition iterates_prod :: "('a  'a)  ('a, 'a) lgenerator'" is iterates_raw
by(auto 4 3 intro: productiveI productive_on.intros simp add: iterates_raw_def)

(* Fusion rule for the producer iterates. *)
lemma lunstream'_iterates_prod [stream_fusion]: "lunstream' (iterates_prod f) x = iterates f x"
by transfer(rule lunstream_iterates_raw)

subsubsection ‹@{const unfold_llist}›

(* Generator for unfold_llist: stops with Done when stop s holds, otherwise yields
   head s and continues from tail s. *)
definition unfold_llist_raw :: "('a  bool)  ('a  'b)  ('a  'a)  ('b, 'a) lgenerator"
where
  "unfold_llist_raw stop head tail s = (if stop s then Done else Yield (head s) (tail s))"

(* Unstreaming unfold_llist_raw agrees with unfold_llist (proof by coinduction). *)
lemma lunstream_unfold_llist_raw:
  "lunstream (unfold_llist_raw stop head tail) s = unfold_llist stop head tail s"
by(coinduction arbitrary: s)(auto simp add: lunstream_sels unfold_llist_raw_def)

(* unfold_llist_raw is productive: each step is either Done or Yield. *)
lift_definition unfold_llist_prod :: "('a  bool)  ('a  'b)  ('a  'a)  ('b, 'a) lgenerator'"
is unfold_llist_raw
proof(rule productiveI)
  fix stop and head :: "'a  'b" and tail s
  show "s  productive_on (unfold_llist_raw stop head tail)"
    by(cases "stop s")(auto intro: productive_on.intros simp add: unfold_llist_raw_def)
qed

(* Fusion rule for the producer unfold_llist. *)
lemma lunstream'_unfold_llist_prod [stream_fusion]:
  "lunstream' (unfold_llist_prod stop head tail) s = unfold_llist stop head tail s"
by transfer(rule lunstream_unfold_llist_raw)

subsubsection ‹@{const inf_llist}›

(* Generator for inf_llist f: the state is an index n; it yields f n and moves to Suc n. *)
definition inf_llist_raw :: "(nat  'a)  ('a, nat) lgenerator"
where "inf_llist_raw f n = Yield (f n) (Suc n)"

(* Starting at index n yields the suffix of inf_llist f after dropping n elements. *)
lemma lunstream_inf_llist_raw: "lunstream (inf_llist_raw f) n = ldropn n (inf_llist f)"
by(coinduction arbitrary: n)(auto simp add: lunstream_sels inf_llist_raw_def)

(* inf_llist_raw f is productive since it always yields. *)
lift_definition inf_llist_prod :: "(nat  'a)  ('a, nat) lgenerator'" is inf_llist_raw
by(auto 4 3 intro: productiveI productive_on.intros simp add: inf_llist_raw_def)

(* Fusion rule for inf_llist: the stream starts at index 0. *)
lemma inf_llist_prod_fusion [stream_fusion]:
  "lunstream' (inf_llist_prod f) 0 = inf_llist f"
by transfer(simp add: lunstream_inf_llist_raw)

subsection ‹Consumers›

subsubsection ‹@{const lhd}›

context fixes g :: "('a, 's) lgenerator" begin

(* Stream consumer for lhd: the head of the list unstreamed from state s.
   The defining equation is itself registered as a fusion rule. *)
definition lhd_cons :: "'s  'a"
where [stream_fusion]: "lhd_cons s = lhd (lunstream g s)"

(* Code equation that runs directly on the generator.  Skips are followed until the first
   Yield; on Done the result is undefined (justified in the proof via lhd_def). *)
lemma lhd_cons_code[code]:
  "lhd_cons s = (case g s of Done  undefined | Skip s'  lhd_cons s' | Yield x _  x)"
by(simp add: lhd_cons_def lunstream_simps lhd_def split: step.split)

end

(* Fusion rule for lhd applied to the output of a productive generator. *)
lemma lhd_cons_fusion2 [stream_fusion]:
  "lhd_cons (lgenerator g) s = lhd (lunstream' g s)"
by transfer(rule lhd_cons_def)

subsubsection ‹@{const llength}›

context fixes g :: "('a, 's) lgenerator" begin

(* Accumulator-based consumer for llength: n plus the length of the list unstreamed from s. *)
definition gen_llength_cons :: "enat  's  enat"
where "gen_llength_cons n s = n + llength (lunstream g s)"

(* Tail-recursive code equation: the accumulator is incremented with eSuc on each Yield
   and left unchanged on Skip. *)
lemma gen_llength_cons_code [code]:
  "gen_llength_cons n s = (case g s of
    Done  n | Skip s'  gen_llength_cons n s' | Yield _ s'  gen_llength_cons (eSuc n) s')"
by(simp add: gen_llength_cons_def lunstream_simps iadd_Suc_right iadd_Suc split: step.split)

(* Fusion rule: llength is the consumer started with accumulator 0. *)
lemma gen_llength_cons_fusion [stream_fusion]:
  "gen_llength_cons 0 s = llength (lunstream g s)"
by(simp add: gen_llength_cons_def)

end

context fixes g :: "('a, 's) lgenerator'" begin

(* llength consumer for productive generators, defined via the raw consumer on lgenerator g. *)
definition gen_llength_cons' :: "enat  's  enat"
where "gen_llength_cons' = gen_llength_cons (lgenerator g)"

(* Code equation with the same shape as gen_llength_cons_code, stated for lgenerator g. *)
lemma gen_llength_cons'_code [code]:
  "gen_llength_cons' n s = (case lgenerator g s of
    Done  n | Skip s'  gen_llength_cons' n s' | Yield _ s'  gen_llength_cons' (eSuc n) s')"
by(simp add: gen_llength_cons'_def cong: step.case_cong)(rule gen_llength_cons_code)

(* Fusion rule for llength of lunstream'. *)
lemma gen_llength_cons'_fusion [stream_fusion]:
  "gen_llength_cons' 0 s = llength (lunstream' g s)"
by(simp add: gen_llength_cons'_def gen_llength_cons_fusion lunstream'.rep_eq)

end

subsubsection ‹@{const lnull}›

context fixes g :: "('a, 's) lgenerator" begin

(* Consumer for lnull: whether the list unstreamed from s is empty.
   The defining equation is also a fusion rule. *)
definition lnull_cons :: "'s  bool"
where [stream_fusion]: "lnull_cons s  lnull (lunstream g s)"

(* Code equation: Done means empty, Yield means non-empty, Skip continues. *)
lemma lnull_cons_code [code]:
  "lnull_cons s  (case g s of
    Done  True | Skip s'  lnull_cons s' | Yield _ _  False)"
by(simp add: lnull_cons_def lunstream_simps split: step.split)

end

context fixes g :: "('a, 's) lgenerator'" begin

(* lnull consumer for productive generators, defined via lnull_cons on lgenerator g. *)
definition lnull_cons' :: "'s  bool"
where "lnull_cons' = lnull_cons (lgenerator g)"

(* Code equation mirroring lnull_cons_code. *)
lemma lnull_cons'_code [code]:
  "lnull_cons' s  (case lgenerator g s of
    Done  True | Skip s'  lnull_cons' s' | Yield _ _  False)"
by(simp add: lnull_cons'_def cong: step.case_cong)(rule lnull_cons_code)

(* Fusion rule for lnull of lunstream'. *)
lemma lnull_cons'_fusion [stream_fusion]:
  "lnull_cons' s  lnull (lunstream' g s)"
by(simp add: lnull_cons'_def lnull_cons_def lunstream'.rep_eq)

end

subsubsection ‹@{const llist_all2}›

context
  fixes g :: "('a, 'sg) lgenerator"
  and h :: "('b, 'sh) lgenerator"
  and P :: "'a  'b  bool"
begin

(* Consumer for llist_all2 P on the outputs of two generators g and h. *)
definition llist_all2_cons :: "'sg  'sh  bool"
where [stream_fusion]: "llist_all2_cons sg sh  llist_all2 P (lunstream g sg) (lunstream h sh)"

(* Auxiliary state: g has already yielded x (continuing from sg') and we are waiting
   for h to yield the matching element. *)
definition llist_all2_cons1 :: "'a  'sg  'sh  bool"
where "llist_all2_cons1 x sg' sh = llist_all2 P (LCons x (lunstream g sg')) (lunstream h sh)"

(* Code equation driving g: when g is Done, the remaining output of h must be empty. *)
lemma llist_all2_cons_code [code]:
  "llist_all2_cons sg sh = 
  (case g sg of
     Done  lnull_cons h sh
   | Skip sg'  llist_all2_cons sg' sh
   | Yield a sg'  llist_all2_cons1 a sg' sh)"
by(simp split: step.split add: llist_all2_cons_def lnull_cons_def llist_all2_cons1_def lunstream_simps lnull_def)

(* Code equation driving h for a buffered element x of g: when h is Done the lists
   have different lengths, so the result is False. *)
lemma llist_all2_cons1_code [code]:
  "llist_all2_cons1 x sg' sh = 
  (case h sh of
     Done  False
   | Skip sh'  llist_all2_cons1 x sg' sh'
   | Yield y sh'  P x y  llist_all2_cons sg' sh')"
by(simp split: step.split add: llist_all2_cons_def lnull_cons_def lnull_def llist_all2_cons1_def lunstream_simps)

end

(* Fusion rules for llist_all2 when one or both input lists come from productive
   generators (lunstream' instead of lunstream). *)
lemma llist_all2_cons_fusion2 [stream_fusion]:
  "llist_all2_cons (lgenerator g) (lgenerator h) P sg sh  llist_all2 P (lunstream' g sg) (lunstream' h sh)"
by transfer(rule llist_all2_cons_def)

lemma llist_all2_cons_fusion3 [stream_fusion]:
  "llist_all2_cons g (lgenerator h) P sg sh  llist_all2 P (lunstream g sg) (lunstream' h sh)"
by transfer(rule llist_all2_cons_def)

lemma llist_all2_cons_fusion4 [stream_fusion]:
  "llist_all2_cons (lgenerator g) h P sg sh  llist_all2 P (lunstream' g sg) (lunstream h sh)"
by transfer(rule llist_all2_cons_def)

subsubsection ‹@{const lnth}›

context fixes g :: "('a, 's) lgenerator" begin

(* Consumer for lnth.  Note the argument order: index n first, generator state s second. *)
definition lnth_cons :: "nat  's  'a"
where [stream_fusion]: "lnth_cons n s = lnth (lunstream g s) n"

(* Code equation: the index is decremented on each Yield until it reaches 0.  If the
   generator is Done first, the result is undefined n (justified via lnth_LNil). *)
lemma lnth_cons_code [code]:
  "lnth_cons n s = (case g s of
    Done  undefined n
  | Skip s'  lnth_cons n s'
  | Yield x s'  (if n = 0 then x else lnth_cons (n - 1) s'))"
by(cases n)(simp_all add: lnth_cons_def lunstream_simps lnth_LNil split: step.split)

end

(* Fusion rule for lnth on the output of a productive generator. *)
lemma lnth_cons_fusion2 [stream_fusion]:
  "lnth_cons (lgenerator g) n s = lnth (lunstream' g s) n"
by transfer(rule lnth_cons_def)

subsubsection ‹@{const lprefix}›

context
  fixes g :: "('a, 'sg) lgenerator"
  and h :: "('a, 'sh) lgenerator"
begin

(* Consumer for lprefix: whether the output of g from sg is a prefix of the output of h from sh. *)
definition lprefix_cons :: "'sg  'sh  bool"
where [stream_fusion]: "lprefix_cons sg sh  lprefix (lunstream g sg) (lunstream h sh)"

(* Auxiliary state: g has yielded x (continuing from sg') and h must produce a matching element. *)
definition lprefix_cons1 :: "'a  'sg  'sh  bool"
where "lprefix_cons1 x sg' sh  lprefix (LCons x (lunstream g sg')) (lunstream h sh)"

(* Code equation driving g: an exhausted g (Done) is always a prefix. *)
lemma lprefix_cons_code [code]:
  "lprefix_cons sg sh  (case g sg of
     Done  True | Skip sg'  lprefix_cons sg' sh | Yield x sg'  lprefix_cons1 x sg' sh)"
by(simp add: lprefix_cons_def lprefix_cons1_def lunstream_simps split: step.split)

(* Code equation driving h for a buffered element x: h must yield an equal element. *)
lemma lprefix_cons1_code [code]:
  "lprefix_cons1 x sg' sh  (case h sh of
     Done  False | Skip sh'  lprefix_cons1 x sg' sh'
   | Yield y sh'  x = y  lprefix_cons sg' sh')"
by(simp add: lprefix_cons_def lprefix_cons1_def lunstream_simps split: step.split)

end

(* Fusion rules for lprefix when one or both lists come from productive generators. *)
lemma lprefix_cons_fusion2 [stream_fusion]:
  "lprefix_cons (lgenerator g) (lgenerator h) sg sh  lprefix (lunstream' g sg) (lunstream' h sh)"
by transfer(rule lprefix_cons_def)

lemma lprefix_cons_fusion3 [stream_fusion]:
  "lprefix_cons g (lgenerator h) sg sh  lprefix (lunstream g sg) (lunstream' h sh)"
by transfer(rule lprefix_cons_def)

lemma lprefix_cons_fusion4 [stream_fusion]:
  "lprefix_cons (lgenerator g) h sg sh  lprefix (lunstream' g sg) (lunstream h sh)"
by transfer(rule lprefix_cons_def)

subsection ‹Transformers›

subsubsection ‹@{const lmap}›

(* Transformer for lmap: reuses map_raw from the finite-list stream fusion. *)
definition lmap_trans :: "('a  'b)  ('a, 's) lgenerator  ('b, 's) lgenerator"
where "lmap_trans = map_raw"

(* Fusion rule for lmap.  Equality is shown by lprefix antisymmetry; each direction is
   proved by fixpoint induction over one of the two lunstream instances. *)
lemma lunstream_lmap_trans [stream_fusion]: fixes f g s
  defines [simp]: "g'  lmap_trans f g"
  shows "lunstream g' s = lmap f (lunstream g s)" (is "?lhs = ?rhs")
proof(rule lprefix_antisym)
  show "lprefix ?lhs ?rhs"
  proof(induction g' arbitrary: s rule: lunstream.fixp_induct) 
    case (3 lunstream_g')
    then show ?case
      by(cases "g s")(simp_all add: lmap_trans_def map_raw_def lunstream_simps)
  qed simp_all
next
  (* admissibility of lprefix goals, needed for the second fixpoint induction *)
  note [cont_intro] = ccpo.admissible_leI[OF llist_ccpo]
  show "lprefix ?rhs ?lhs"
  proof(induction g arbitrary: s rule: lunstream.fixp_induct) 
    case (3 lunstream_g)
    thus ?case by(cases "g s")(simp_all add: lmap_trans_def map_raw_def lunstream_simps)
  qed simp_all
qed

(* lmap_trans preserves productivity, so it lifts to lgenerator'. *)
lift_definition lmap_trans' :: "('a  'b)  ('a, 's) lgenerator'  ('b, 's) lgenerator'"
is lmap_trans
proof
  fix f :: "'a  'b" and g :: "('a, 's) lgenerator" and s
  assume "productive g"
  hence "s  productive_on g" ..
  thus "s  productive_on (lmap_trans f g)"
    by induction(auto simp add: lmap_trans_def map_raw_def intro: productive_on.intros)
qed

(* Fusion rule for lmap on productive generators. *)
lemma lunstream'_lmap_trans' [stream_fusion]:
  "lunstream' (lmap_trans' f g) s = lmap f (lunstream' g s)"
by transfer(rule lunstream_lmap_trans)

subsubsection ‹@{const ltake}›

(* Transformer for ltake.  The state pairs the remaining count n (an enat) with the inner
   state.  It stops once n = 0, decrements n with epred on each Yield, and leaves n
   unchanged on Skip. *)
fun ltake_trans :: "('a, 's) lgenerator  ('a, (enat × 's)) lgenerator"
where
  "ltake_trans g (n, s) =
  (if n = 0 then Done else case g s of 
    Done  Done | Skip s'  Skip (n, s') | Yield a s'  Yield a (epred n, s'))"

(* Fusion rule for ltake, by lprefix antisymmetry and fixpoint induction in both directions. *)
lemma ltake_trans_fusion [stream_fusion]:
  fixes g' g
  defines [simp]: "g'  ltake_trans g"
  shows "lunstream g' (n, s) = ltake n (lunstream g s)" (is "?lhs = ?rhs")
proof(rule lprefix_antisym)
  show "lprefix ?lhs ?rhs"
  proof(induction g' arbitrary: n s rule: lunstream.fixp_induct)
    case (3 lunstream_g')
    thus ?case
      by(cases "g s")(auto simp add: lunstream_simps neq_zero_conv_eSuc)
  qed simp_all
  show "lprefix ?rhs ?lhs"
  proof(induction g arbitrary: s n rule: lunstream.fixp_induct)
    case (3 lunstream_g)
    thus ?case by(cases "g s" n rule: step.exhaust[case_product enat_coexhaust])(auto simp add: lunstream_simps)
  qed simp_all
qed

(* ltake_trans preserves productivity: it only adds Done steps and keeps the Skip
   structure of the inner generator. *)
lift_definition ltake_trans' :: "('a, 's) lgenerator'  ('a, (enat × 's)) lgenerator'"
is "ltake_trans"
proof
  fix g :: "('a, 's) lgenerator" and s :: "enat × 's"
  obtain n sg where s: "s = (n, sg)" by(cases s)
  assume "productive g"
  hence "sg  productive_on g" ..
  then show "s  productive_on (ltake_trans g)" unfolding s = (n, sg)
    apply(induction arbitrary: n)
    apply(case_tac [!] n rule: enat_coexhaust)
    apply(auto intro: productive_on.intros)
    done
qed

(* Fusion rule for ltake on productive generators. *)
lemma ltake_trans'_fusion [stream_fusion]:
  "lunstream' (ltake_trans' g) (n, s) = ltake n (lunstream' g s)"
by transfer(rule ltake_trans_fusion)

subsubsection ‹@{const ldropn}›

(* Transformer for ldropn: an input abbreviation for drop_raw.  The state pairs the
   number of elements still to drop with the inner state. *)
abbreviation (input) ldropn_trans :: "('b, 'a) lgenerator  ('b, nat × 'a) lgenerator"
where "ldropn_trans  drop_raw"

(* Fusion rule for ldropn, by lprefix antisymmetry and fixpoint induction in both directions. *)
lemma ldropn_trans_fusion [stream_fusion]:
  fixes g defines [simp]: "g'  ldropn_trans g"
  shows "lunstream g' (n, s) = ldropn n (lunstream g s)" (is "?lhs = ?rhs")
proof(rule lprefix_antisym)
  show "lprefix ?lhs ?rhs"
  proof(induction g' arbitrary: n s rule: lunstream.fixp_induct)
    case (3 lunstream_g')
    thus ?case
      by(cases "g s" n rule: step.exhaust[case_product nat.exhaust])
        (auto simp add: lunstream_simps elim: meta_allE[where x=0])
  qed simp_all
  (* admissibility of lprefix goals, needed for the second fixpoint induction *)
  note [cont_intro] = ccpo.admissible_leI[OF llist_ccpo]
  show "lprefix ?rhs ?lhs"
  proof(induction g arbitrary: n s rule: lunstream.fixp_induct)
    case (3 lunstream_g)
    thus ?case by(cases n)(auto split: step.split simp add: lunstream_simps elim: meta_allE[where x=0])
  qed simp_all
qed

(* ldropn_trans preserves productivity.  The proof is by induction on the finite drop
   count n, with an inner induction on productive_on g in each case. *)
lift_definition ldropn_trans' :: "('a, 's) lgenerator'  ('a, nat × 's) lgenerator'"
is ldropn_trans
proof
  fix g :: "('a, 's) lgenerator" and ns :: "nat × 's"
  obtain n s where ns: "ns = (n, s)" by(cases ns)
  assume g: "productive g"
  show "ns  productive_on (ldropn_trans g)" unfolding ns
  proof(induction n arbitrary: s)
    case 0
    from g have "s  productive_on g" ..
    thus ?case by induction(auto intro: productive_on.intros)
  next
    case (Suc n)
    from g have "s  productive_on g" ..
    thus ?case by induction(auto intro: productive_on.intros Suc.IH)
  qed
qed

(* Fusion rule for ldropn on productive generators. *)
lemma ldropn_trans'_fusion [stream_fusion]:
  "lunstream' (ldropn_trans' g) (n, s) = ldropn n (lunstream' g s)"
by transfer(rule ldropn_trans_fusion)

subsubsection ‹@{const ldrop}›

(* Transformer for ldrop with an enat count.  A Yield is passed through once the count
   is 0; otherwise it is turned into a Skip and the count is decremented with epred. *)
fun ldrop_trans :: "('a, 's) lgenerator  ('a, enat × 's) lgenerator"
where
  "ldrop_trans g (n, s) = (case g s of 
    Done  Done | Skip s'  Skip (n, s')
  | Yield x s'  (if n = 0 then Yield x (n, s') else Skip (epred n, s')))"

(* Fusion rule for ldrop, by lprefix antisymmetry and fixpoint induction in both directions. *)
lemma ldrop_trans_fusion [stream_fusion]:
  fixes g g' defines [simp]: "g'  ldrop_trans g"
  shows "lunstream g' (n, s) = ldrop n (lunstream g s)" (is "?lhs = ?rhs")
proof(rule lprefix_antisym)
  show "lprefix ?lhs ?rhs"
    by(induction g' arbitrary: n s rule: lunstream.fixp_induct)
      (auto simp add: lunstream_simps neq_zero_conv_eSuc elim: meta_allE[where x=0] split: step.split)
  show "lprefix ?rhs ?lhs"
  proof(induction g arbitrary: n s rule: lunstream.fixp_induct)
    case (3 lunstream_g)
    thus ?case
      by(cases n rule: enat_coexhaust)(auto simp add: lunstream_simps split: step.split elim: meta_allE[where x=0])
  qed simp_all
qed

(* Fusion rule for ldrop on the output of a productive generator.  The result is a raw
   lgenerator, not an lgenerator'.  NOTE(review): presumably because an infinite count
   turns every Yield into a Skip, so productivity is not preserved — confirm. *)
lemma ldrop_trans_fusion2 [stream_fusion]:
  "lunstream (ldrop_trans (lgenerator g)) (n, s) = ldrop n (lunstream' g s)"
by transfer (rule ldrop_trans_fusion)

subsubsection ‹@{const ltakeWhile}›

(* Transformer for ltakeWhile: an input abbreviation for takeWhile_raw. *)
abbreviation (input) ltakeWhile_trans :: "('a  bool)  ('a, 's) lgenerator  ('a, 's) lgenerator"
where "ltakeWhile_trans  takeWhile_raw"

(* Fusion rule for ltakeWhile, by lprefix antisymmetry and fixpoint induction. *)
lemma ltakeWhile_trans_fusion [stream_fusion]:
  fixes P g g' defines [simp]: "g'  ltakeWhile_trans P g"
  shows "lunstream g' s = ltakeWhile P (lunstream g s)" (is "?lhs = ?rhs")
proof(rule lprefix_antisym)
  show "lprefix ?lhs ?rhs"
    by(induction g' arbitrary: s rule: lunstream.fixp_induct)(auto simp add: lunstream_simps takeWhile_raw_def split: step.split)
  show "lprefix ?rhs ?lhs"
    by(induction g arbitrary: s rule: lunstream.fixp_induct)(auto split: step.split simp add: lunstream_simps takeWhile_raw_def)
qed

(* ltakeWhile_trans preserves productivity; the Yield case splits on whether P holds. *)
lift_definition ltakeWhile_trans' :: "('a  bool)  ('a, 's) lgenerator'  ('a, 's) lgenerator'"
is ltakeWhile_trans
proof
  fix P and g :: "('a, 's) lgenerator" and s
  assume "productive g"
  hence "s  productive_on g" ..
  thus "s  productive_on (ltakeWhile_trans P g)"
    apply(induction)
    apply(case_tac [3] "P x")
    apply(auto intro: productive_on.intros simp add: takeWhile_raw_def)
    done
qed

(* Fusion rule for ltakeWhile on productive generators. *)
lemma ltakeWhile_trans'_fusion [stream_fusion]:
  "lunstream' (ltakeWhile_trans' P g) s = ltakeWhile P (lunstream' g s)"
by transfer(rule ltakeWhile_trans_fusion)

subsubsection ‹@{const ldropWhile}›

(* Transformer for ldropWhile: an input abbreviation for dropWhile_raw.  The boolean in
   the state is True while elements are still being dropped.  With False, the inner
   generator's output is passed through unchanged (second conjunct in the proof below). *)
abbreviation (input) ldropWhile_trans :: "('a  bool)  ('a, 'b) lgenerator  ('a, bool × 'b) lgenerator"
where "ldropWhile_trans  dropWhile_raw"

(* Fusion rule for ldropWhile, starting in the dropping mode (True, s).  Both prefix
   directions are proved together with the pass-through property of the False mode. *)
lemma ldropWhile_trans_fusion [stream_fusion]:
  fixes P g g' defines [simp]: "g'  ldropWhile_trans P g"
  shows "lunstream g' (True, s) = ldropWhile P (lunstream g s)" (is "?lhs = ?rhs")
proof -
  have "lprefix ?lhs ?rhs" "lprefix (lunstream g' (False, s)) (lunstream g s)"
    by(induction g' arbitrary: s rule: lunstream.fixp_induct)(simp_all add: lunstream_simps split: step.split)
  moreover have "lprefix ?rhs ?lhs" "lprefix (lunstream g s) (lunstream g' (False, s))"
    by(induction g arbitrary: s rule: lunstream.fixp_induct)(simp_all add: lunstream_simps split: step.split)
  ultimately show ?thesis by(blast intro: lprefix_antisym)
qed

(* Fusion rule for ldropWhile on the output of a productive generator. *)
lemma ldropWhile_trans_fusion2 [stream_fusion]:
  "lunstream (ldropWhile_trans P (lgenerator g)) (True, s) = ldropWhile P (lunstream' g s)"
by transfer(rule ldropWhile_trans_fusion)

subsubsection ‹@{const lzip}›

(* Transformer for lzip: an input abbreviation for zip_raw.  The state is
   (sg, sh, buffer).  None means no element from g is pending.  Some x means g has
   already produced x and we wait for a partner from h (see the second goal below). *)
abbreviation (input) lzip_trans :: "('a, 's1) lgenerator  ('b, 's2) lgenerator  ('a × 'b, 's1 × 's2 × 'a option) lgenerator"
where "lzip_trans  zip_raw"

(* Fusion rule for lzip, starting with an empty buffer.  Each prefix direction is proved
   simultaneously for the empty-buffer and the buffered (Some x) states. *)
lemma lzip_trans_fusion [stream_fusion]:
  fixes g h gh defines [simp]: "gh  lzip_trans g h"
  shows "lunstream gh (sg, sh, None) = lzip (lunstream g sg) (lunstream h sh)"
  (is "?lhs = ?rhs")
proof -
  have "lprefix ?lhs ?rhs"
    and "x. lprefix (lunstream gh (sg, sh, Some x)) (lzip (LCons x (lunstream g sg)) (lunstream h sh))"
  proof(induction gh arbitrary: sg sh rule: lunstream.fixp_induct) 
    case (3 lunstream)
    { case 1 show ?case using 3
        by(cases "g sg")(simp_all add: lunstream_simps) }
    { case 2 show ?case using 3
        by(cases "h sh")(simp_all add: lunstream_simps) }
  qed simp_all
  moreover
  (* admissibility of lprefix goals, needed for the fixpoint inductions below *)
  note [cont_intro] = ccpo.admissible_leI[OF llist_ccpo]
  have "lprefix ?rhs ?lhs" 
    and "x. lprefix (lzip (LCons x (lunstream g sg)) (lunstream h sh)) (lunstream gh (sg, sh, Some x))"
  proof(induction g arbitrary: sg sh rule: lunstream.fixp_induct)
    case (3 lunstream_g)
    note IH = "3.IH"
    { case 1 show ?case using 3
        by(cases "g sg")(simp_all add: lunstream_simps fun_ord_def) }
    { case 2 show ?case
      (* nested fixpoint induction over h for the buffered state *)
      proof(induction h arbitrary: sh sg x rule: lunstream.fixp_induct)
        case (3 unstream_h)
        thus ?case
        proof(cases "h sh")
          case (Yield y sh')
          thus ?thesis using "3.prems" IH "3.hyps"
            by(cases "g sg")(auto 4 3 simp add: lunstream_simps fun_ord_def intro: monotone_lzip2[THEN monotoneD] lprefix_trans)
        qed(simp_all add: lunstream_simps)
      qed simp_all }
  next
    case 2 case 2
    show ?case
      by(induction h arbitrary: sh rule: lunstream.fixp_induct)(simp_all add: lunstream_simps split: step.split)
  qed simp_all
  ultimately show ?thesis by(blast intro: lprefix_antisym)
qed

(* Fusion rules for lzip when exactly one input comes from a productive generator. *)
lemma lzip_trans_fusion2 [stream_fusion]:
  "lunstream (lzip_trans (lgenerator g) h) (sg, sh, None) = lzip (lunstream' g sg) (lunstream h sh)"
by transfer(rule lzip_trans_fusion)

lemma lzip_trans_fusion3 [stream_fusion]:
  "lunstream (lzip_trans g (lgenerator h)) (sg, sh, None) = lzip (lunstream g sg) (lunstream' h sh)"
by transfer(rule lzip_trans_fusion)

(* lzip_trans is productive when both inputs are.  Buffered states (Some x) are
   productive via productivity of h.  States with an empty buffer are productive via
   productivity of g, using the buffered case as an auxiliary fact. *)
lift_definition lzip_trans' :: "('a, 's1) lgenerator'  ('b, 's2) lgenerator'  ('a × 'b, 's1 × 's2 × 'a option) lgenerator'"
is "lzip_trans"
proof
  fix g :: "('a, 's1) lgenerator" and h :: "('b, 's2) lgenerator" and s :: "'s1 × 's2 × 'a option"
  assume "productive g" and "productive h"
  obtain sg sh mx where s: "s = (sg, sh, mx)" by(cases s)
  { fix x sg
    from ‹productive h have "sh  productive_on h" ..
    hence "(sg, sh, Some x)  productive_on (lzip_trans g h)"
      by(induction)(auto simp add: intro: productive_on.intros) }
  moreover
  from ‹productive g have "sg  productive_on g" ..
  then have "(sg, sh, None)  productive_on (lzip_trans g h)"
    by induction(auto intro: productive_on.intros calculation)
  ultimately show "s  productive_on (lzip_trans g h)" unfolding s
    by(cases mx) auto
qed

(* Fusion rule for lzip when both inputs are productive generators. *)
lemma lzip_trans'_fusion [stream_fusion]:
  "lunstream' (lzip_trans' g h) (sg, sh, None) = lzip (lunstream' g sg) (lunstream' h sh)"
by transfer(rule lzip_trans_fusion)

subsubsection ‹@{const lappend}›

(* Transformer for lappend: wraps append_raw.  The first generator must be productive;
   the second may be any lgenerator, started from state sh.  States are sums: Inl for
   the first generator, Inr for the second. *)
lift_definition lappend_trans :: "('a, 'sg) lgenerator'  ('a, 'sh) lgenerator  'sh  ('a, 'sg + 'sh) lgenerator"
is append_raw .

(* Correctness of append_raw for a productive first generator.  The proof is by
   coinduction; the Inr states are first shown to behave exactly like h (local fact Inr). *)
lemma lunstream_append_raw:
  fixes g h sh gh defines [simp]: "gh  append_raw g h sh"
  assumes "productive g"
  shows "lunstream gh (Inl sg) = lappend (lunstream g sg) (lunstream h sh)"
proof(coinduction arbitrary: sg rule: llist.coinduct_strong)
  case (Eq_llist sg)
  { fix sh'
    have "lprefix (lunstream gh (Inr sh')) (lunstream h sh')"
      by(induction gh arbitrary: sh' rule: lunstream.fixp_induct)(simp_all add: lunstream_simps split: step.split)
    moreover have "lprefix (lunstream h sh') (lunstream gh (Inr sh'))"
      by(induction h arbitrary: sh' rule: lunstream.fixp_induct)(simp_all add: lunstream_simps split: step.split)
    ultimately have "lunstream gh (Inr sh') = lunstream h sh'"
      by(blast intro: lprefix_antisym) }
  note Inr = this[unfolded gh_def]
  from ‹productive g have sg: "sg  productive_on g" ..
  then show ?case by induction(auto simp add: lunstream_sels Inr)
qed

(* Fusion rule for lappend with a productive first and arbitrary second generator. *)
lemma lappend_trans_fusion [stream_fusion]:
  "lunstream (lappend_trans g h sh) (Inl sg) = lappend (lunstream' g sg) (lunstream h sh)"
by transfer(rule lunstream_append_raw)

(* append_raw is productive when both inputs are productive.  Inr states inherit
   productivity from h; Inl states from g, using the Inr case as an auxiliary fact. *)
lift_definition lappend_trans' :: "('a, 'sg) lgenerator'  ('a, 'sh) lgenerator'  'sh  ('a, 'sg + 'sh) lgenerator'"
is append_raw
proof
  fix g :: "('a, 'sg) lgenerator" and h :: "('a, 'sh) lgenerator" and sh s
  assume "productive g" "productive h"
  { fix sh'
    from ‹productive h have "sh'  productive_on h" ..
    then have "Inr sh'  productive_on (append_raw g h sh)"
      by induction (auto intro: productive_on.intros)
  } moreover {
    fix sg
    from ‹productive g have "sg  productive_on g" ..
    then have "Inl sg  productive_on (append_raw g h sh)"
      by induction(auto intro: productive_on.intros calculation) }
  ultimately show "s  productive_on (append_raw g h sh)" by(cases s) auto
qed

(* Fusion rule for lappend when both inputs are productive generators. *)
lemma lappend_trans'_fusion [stream_fusion]:
  "lunstream' (lappend_trans' g h sh) (Inl sg) = lappend (lunstream' g sg) (lunstream' h sh)"
by transfer(rule lunstream_append_raw)

subsubsection ‹@{const lfilter}›

(* Transformer for lfilter: reuses filter_raw from the finite-list stream fusion.
   NOTE(review): no lgenerator' variant is given, presumably because filtering can
   produce infinitely many consecutive Skips (e.g. if no element satisfies P) — confirm. *)
definition lfilter_trans :: "('a  bool)  ('a, 's) lgenerator  ('a, 's) lgenerator"
where "lfilter_trans = filter_raw"

(* Fusion rule for lfilter, by lprefix antisymmetry and fixpoint induction. *)
lemma lunstream_lfilter_trans [stream_fusion]:
  fixes P g g' defines [simp]: "g'  lfilter_trans P g"
  shows "lunstream g' s = lfilter P (lunstream g s)" (is "?lhs = ?rhs")
proof(rule lprefix_antisym)
  show "lprefix ?lhs ?rhs"
    by(induction g' arbitrary: s rule: lunstream.fixp_induct)
      (simp_all add: lfilter_trans_def filter_raw_def lunstream_simps split: step.split)
  show "lprefix ?rhs ?lhs"
  by(induction g arbitrary: s rule: lunstream.fixp_induct) 
    (simp_all add: lfilter_trans_def filter_raw_def lunstream_simps split: step.split)
qed

(* Fusion rule for lfilter on the output of a productive generator. *)
lemma lunstream_lfilter_trans2 [stream_fusion]:
  "lunstream (lfilter_trans P (lgenerator g)) s = lfilter P (lunstream' g s)"
by transfer(rule lunstream_lfilter_trans)

subsubsection ‹@{const llist_of}›

(* A terminating (finite-list) generator is also a productive lazy-list generator.
   The underlying raw generator is unchanged (identity); the obligation is discharged
   by deriving productivity from termination. *)
lift_definition llist_of_trans :: "('a, 's) generator  ('a, 's) lgenerator'"
is "λx. x"
proof
  fix g :: "('a, 's) raw_generator" and s
  assume "terminates g"
  hence "s  terminates_on g" by(simp add: terminates_def)
  then show "s  productive_on g"
    by(induction)(auto intro: productive_on.intros)
qed

(* Fusion rule for llist_of: the lazy list of a finite stream equals the lazy-list
   unstreaming of the same generator. *)
lemma lunstream_llist_of_trans [stream_fusion]:
  "lunstream' (llist_of_trans g) s = llist_of (unstream g s)"
apply(induction s taking: g rule: unstream.induct)
apply(rule llist.expand)
apply(auto intro: llist.expand simp add: llist_of_trans.rep_eq lunstream_sels lunstream'.rep_eq split: step.split)
done

text ‹We cannot define a stream version of @{const list_of} because we would have to test
  for finiteness first and therefore traverse the list twice.›

end

Theory Stream_Fusion_Examples

(* Title: Stream_Fusion_Examples
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Examples and test cases for stream fusion›

theory Stream_Fusion_Examples imports Stream_Fusion_LList begin

(* Test: with the (by default disabled) stream_fusion simproc enabled locally and tracing
   turned on, simp rewrites nth over List.maps/upto into the fused consumer nth_cons over
   flatten/upto_prod.  The goal is then closed after unfolding rhs. *)
lemma fixes rhs z
  defines "rhs  nth_cons (flatten (λs'. s') (upto_prod 17) (upto_prod z)) (2, None) 8"
  shows "nth (List.maps (λx. upto x 17) (upto 2 z)) 8 = rhs"
using [[simproc add: stream_fusion, stream_fusion_trace]]
apply(simp del: id_apply) ― ‹fuses›
by(unfold rhs_def) rule

(* Test: as above, but the inner producer depends on the outer element through id.
   Fusion yields a flatten with fix_gen for the inner generator. *)
lemma fixes rhs z
  defines "rhs  nth_cons (flatten (λs. (s, 1)) (fix_gen (λx. upto_prod (id x))) (upto_prod z)) (2, None) 8"
  shows "nth (List.maps (λx. upto 1 (id x)) (upto 2 z)) 8 = rhs"
using [[simproc add: stream_fusion, stream_fusion_trace]]
apply(simp del: id_apply) ― ‹fuses›
by(unfold rhs_def) rule

(* Test: concat/map is rewritten to List.maps and only the inner sum_list/replicate
   pipeline is fused (into sum_list_cons/replicate_prod).  The outer List.maps over
   [2..<n] remains unfused, as stated by rhs. *)
lemma fixes rhs n
  defines "rhs  List.maps (λx. [Suc 0..<sum_list_cons (replicate_prod x) x]) [2..<n]"
  shows "(concat (map (λx. [1..<sum_list (replicate x x)]) [2..<n])) = rhs"
using [[simproc add: stream_fusion, stream_fusion_trace]]
apply(simp add: concat_map_maps) ― ‹fuses partially›
by(unfold rhs_def) rule

subsection ‹Micro-benchmarks from Farmer et al. \cite{FarmerHoenerGill2014PEPM}›

(* Micro-benchmark programs; their code equations, after preprocessing, are inspected
   with code_thms further down. *)
definition test_enum :: "nat  nat" ― ‹@{const id} required to avoid eta contraction›
where "test_enum n = foldl (+) 0 (List.maps (λx. upt 1 (id x)) (upt 1 n))"

definition test_nested :: "nat  nat"
where "test_nested n = foldl (+) 0 (List.maps (λx. List.maps (λy. upt y x) (upt 1 x)) (upt 1 n))"

(* The branches of the if produce streams from different start states; see unstream_if. *)
definition test_merge :: "integer  nat"
where "test_merge n = foldl (+) 0 (List.maps (λx. if 2 dvd x then upt 1 x else upt 2 x) (upt 1 (nat_of_integer n)))"

text ‹
  This rule performs the merge operation from \cite[\S 5.2]{FarmerHoenerGill2014PEPM} for ‹if›.
  In general, analogous rules would be needed for all case operators.
›
(* Fusion rule that pushes a conditional choice of generator and start state out of
   unstream, so that both branches can be fused separately. *)
lemma unstream_if [stream_fusion]:
  "unstream (if b then g else g') (if b then s else s') =
   (if b then unstream g s else unstream g' s')"
by simp

(* Code preprocessing rule that collapses conditionals with identical branches.
   NOTE(review): presumably meant to clean up conditionals left over after fusion — confirm. *)
lemma if_same [code_unfold]: "(if b then x else x) = x"
by simp

(* Show the code equations of the benchmarks after code preprocessing, where the
   stream fusion simproc is installed. *)
code_thms test_enum
code_thms test_nested
code_thms test_merge

subsection ‹Test stream fusion in the code generator›

(* Regression test for stream fusion in the code generator.  The mapped input list contains
   undefined at x = 0.  The first element that survives the filter, -2 (from x = -3), comes
   before that position, so a lazily fused pipeline never evaluates the undefined element. *)
definition fuse_test :: integer
where "fuse_test = 
  integer_of_int (lhd (lfilter (λx. x < 1) (lappend (lmap (λx. x + 1) (llist_of (map (λx. if x = 0 then undefined else x) [-3..5]))) (repeat 3))))"

ML_val ‹val ~2 = @{code fuse_test}› ― ‹If this test fails with exception Fail, the stream fusion simproc did not fire:
  without fusion the whole input list, including the undefined element, is evaluated strictly;
  the test relies on stream fusion introducing laziness.›

end