Session Monad_Memo_DP

Theory State_Monad_Ext

subsection ‹State Monad›

theory State_Monad_Ext
  imports "HOL-Library.State_Monad"
begin

(* Lifted function application for the state monad.
   Runs fT to obtain a function f, then runs xT to obtain an argument x,
   and finally runs the computation f x.
   NOTE(review): the type arrows, the definitional equality and the bind arrows
   of this definition appear to have been lost in encoding
   (e.g. "fT xT  do", "f  fT").
   As shown here the definition will not parse; restore it from upstream. *)
definition fun_app_lifted :: "('M,'a  ('M, 'b) state) state  ('M,'a) state  ('M,'b) state" where
  "fun_app_lifted fT xT  do { f  fT; x  xT; f x }"

(* Syntax bundle for the state monad; enable it with "includes state_monad_syntax". *)
bundle state_monad_syntax begin

(* Infix "." for lifted application (fun_app_lifted). *)
notation fun_app_lifted (infixl "." 999)
(* fun_lifted: a function whose result is a state computation over memory 'M.
   dpfun: a fun_lifted whose memory type is built from 'a and 'b.
   Presumably that memory type is a partial map from 'a to 'b, but the map and
   arrow symbols look stripped here; confirm against upstream. *)
type_synonym ('a,'M,'b) fun_lifted = "'a  ('M,'b) state" ("_ ==_ _" [3,1000,2] 2)
type_synonym ('a,'b) dpfun = "'a ==('a'b) 'b" (infixr "T" 2)
(* Bracket notation for the state type. *)
type_notation state ("[_| _]")

(* Mixfix notation for State_Monad.return (its delimiters look stripped here),
   an ASCII alternative "(#_#)", and the short name "Rel" for Transfer.Rel. *)
notation State_Monad.return ("_")
notation (ASCII) State_Monad.return ("(#_#)")
notation Transfer.Rel ("Rel")

end

(* Basic laws for the lifted state-monad operators.
   These are stated using the syntax from state_monad_syntax.
   All names are qualified and are referenced elsewhere as e.g. State_Monad_Ext.ifT. *)
context includes state_monad_syntax begin

(* Lifted application of a returned function to a returned argument is plain
   application. It follows from unfolding fun_app_lifted and the monad's
   left-identity law (bind_left_identity). *)
qualified lemma return_app_return:
  "f . x = f x"
  unfolding fun_app_lifted_def bind_left_identity ..

(* The same law as a meta-level equation, usable with "unfolding". *)
qualified lemma return_app_return_meta:
  "f . x  f x"
  unfolding return_app_return .

(* Lifted conditional: run bT, then continue with xT if the result is True,
   otherwise with yT. *)
qualified definition ifT :: "('M, bool) state  ('M, 'a) state  ('M, 'a) state  ('M, 'a) state" where
  "ifT bT xT yT  do {b  bT; if b then xT else yT}"
end

end

Theory Pure_Monad

section ‹Monadification›

subsection ‹Monads›

theory Pure_Monad
  imports Main
begin

(* Identity "monad" for the pure side of the monadification.
   Wrap is the identity on values, and App f is just f, so App f x = f x.
   Later theories relate Wrap to State_Monad.return / Heap_Monad.return, and App
   to the lifted application fun_app_lifted, via transfer rules. *)
definition Wrap :: "'a  'a" where
  "Wrap x  x"

definition App :: "('a  'b)  'a  'b" where
  "App f  f"

(* Applying a wrapped function to a wrapped argument is ordinary application
   (a meta-level equation). *)
lemma Wrap_App_Wrap:
  "App (Wrap f) (Wrap x)  f x"
  unfolding App_def Wrap_def .


end

Theory DP_CRelVS

subsection ‹Parametricity of the State Monad›

theory DP_CRelVS
  imports "./State_Monad_Ext" "../Pure_Monad"
begin

(* lift_p P f: the state computation f preserves the invariant P.
   Running f from any state satisfying P ends in a state that again satisfies P.
   NOTE(review): the quantifier/implication symbols of this definition appear
   stripped in this copy. *)
definition lift_p :: "('s  bool)  ('s, 'a) state  bool" where
  "lift_p P f =
    ( heap. P heap  (case State_Monad.run_state f heap of (_, heap)  P heap))"

(* Elimination lemmas for lift_p.
   They fix a computation f and a start state heap that satisfies P. *)
context
  fixes P f heap
  assumes lift: "lift_p P f" and P: "P heap"
begin

(* The final state of running f from heap satisfies P (stated via case). *)
lemma run_state_cases:
  "case State_Monad.run_state f heap of (_, heap)  P heap"
  using lift P unfolding lift_p_def by auto

(* The same fact with the final state named by an equation. *)
lemma lift_p_P:
  "P heap'" if "State_Monad.run_state f heap = (v, heap')"
  using that run_state_cases by auto

end

(* Memory operations for memoization in the state monad, abstracted over the
   memory type 'mem.
   - lookup: returns an optional stored result for a parameter.
   - update: stores a result for a parameter. *)
locale state_mem_defs =
  fixes lookup :: "'param  ('mem, 'result option) state"
    and update :: "'param  'result  ('mem, unit) state"
begin

(* Memoizing wrapper.
   Look up param. If a result is stored, return it. Otherwise run calc,
   store its result under param via update, and return that result. *)
definition checkmem :: "'param  ('mem, 'result) state  ('mem, 'result) state" where
  "checkmem param calc  do {
    x  lookup param;
    case x of
      Some x  State_Monad.return x
    | None  do {
        x  calc;
        update param x;
        State_Monad.return x
      }
  }"

(* Notation "dpT$ param =CHECKMEM= calc" for the equation
   dpT param = checkmem param calc. *)
abbreviation checkmem_eq ::
  "('param  ('mem, 'result) state)  'param  ('mem, 'result) state  bool"
  ("_$ _ =CHECKMEM= _" [1000,51] 51) where
  "(dpT$ param =CHECKMEM= calc)  (dpT param = checkmem param calc)"
term 0 (**)

(* The partial map represented by a memory: the value that lookup k returns when
   run on heap. The state produced by the lookup is discarded. *)
definition map_of where
  "map_of heap k = fst (run_state (lookup k) heap)"

(* Variant of checkmem whose calculation is a thunk taking unit.
   calc () is only applied in the branch where lookup returned None.
   NOTE(review): presumably this delays evaluation of calc in generated code;
   confirm. *)
definition checkmem' :: "'param  (unit  ('mem, 'result) state)  ('mem, 'result) state" where
  "checkmem' param calc  do {
    x  lookup param;
    case x of
      Some x  State_Monad.return x
    | None  do {
        x  calc ();
        update param x;
        State_Monad.return x
      }
  }"

(* checkmem' applied to a constant thunk coincides with checkmem. *)
lemma checkmem_checkmem':
  "checkmem' param (λ_. calc) = checkmem param calc"
  unfolding checkmem'_def checkmem_def ..

(* checkmem_eq restated in terms of checkmem'. *)
lemma checkmem_eq_alt:
  "checkmem_eq dp param calc = (dp param = checkmem' param (λ _. calc))"
  unfolding checkmem_checkmem' ..

end (* Mem Defs *)


(* Correctness of the memory operations with respect to an invariant P.
   - lookup_inv / update_inv: lookup and update preserve P (lift_p).
   - lookup_correct: under P, the map represented by the memory after a lookup is
     contained in the map before it.
   - update_correct: under P, the map after update k v is contained in the old
     map extended with k mapped to v.
   NOTE(review): the containment relation is presumably map_le (later proofs
   unfold map_le_def), but its symbol looks stripped in this copy; confirm. *)
locale mem_correct = state_mem_defs +
  fixes P
  assumes lookup_inv: "lift_p P (lookup k)" and update_inv: "lift_p P (update k v)"
  assumes
    lookup_correct: "P m  map_of (snd (State_Monad.run_state (lookup k) m)) m (map_of m)"
      and
    update_correct: "P m  map_of (snd (State_Monad.run_state (update k v) m)) m (map_of m)(k  v)"
  (* assumes correct: "lookup (update m k v) ⊆m (lookup m)(k ↦ v)" *)

(* Consistency of a memoized state-monad program with the unmemoized function dp.
   The central relation crel_vs says that a state computation, run from any memory
   that is consistent with dp (cmem) and satisfies P, produces a value related to
   the pure value. It also leaves the memory consistent and satisfying P. *)
locale dp_consistency =
  mem_correct lookup update P
  for lookup :: "'param  ('mem, 'result option) state" and update and P +
  fixes dp :: "'param  'result"
begin

context
  includes lifting_syntax state_monad_syntax
begin

(* cmem M: every parameter in the domain of map_of M is mapped to dp param. *)
definition cmem :: "'mem  bool" where
  "cmem M  paramdom (map_of M). map_of M param = Some (dp param)"

(* crel_vs R v s: for every memory M with cmem M and P M, running s from M yields
   (v', M') such that R v v', cmem M' and P M'. *)
definition crel_vs :: "('a  'b  bool)  'a  ('mem, 'b) state  bool" where
  "crel_vs R v s  M. cmem M  P M  (case State_Monad.run_state s M of (v', M')  R v v'  cmem M'  P M')"
  
(* Function relator whose result side is crel_vs (notation "===>T"). *)
abbreviation rel_fun_lifted :: "('a  'c  bool)  ('b  'd  bool)  ('a  'b)  ('c ==_ 'd)  bool" (infixr "===>T" 55) where
  "rel_fun_lifted R R'  R ===> crel_vs R'"
term 0 (**)

(* consistentDP dpT: for every param, dpT param is crel_vs-related (by equality)
   to dp param. *)
definition consistentDP :: "('param == 'mem  'result)  bool" where
  "consistentDP  ((=) ===> crel_vs (=)) dp"
term 0 (**)
  
  (* cmem *)
(* A memory is consistent if every successful lookup returns dp of its key. *)
private lemma cmem_intro:
  assumes "param v M'. State_Monad.run_state (lookup param) M = (Some v, M')  v = dp param"
  shows "cmem M"
  unfolding cmem_def map_of_def
  apply safe
  subgoal for param y
    by (cases "State_Monad.run_state (lookup param) M") (auto intro: assms)
  done

(* Conversely, in a consistent memory a lookup returning Some v yields dp param. *)
lemma cmem_elim:
  assumes "cmem M" "State_Monad.run_state (lookup param) M = (Some v, M')"
  obtains "dp param = v"
  using assms unfolding cmem_def dom_def map_of_def by auto (metis fst_conv option.inject)
term 0 (**)
  
  (* crel_vs *)
(* Introduction and elimination rules that unfold crel_vs_def. *)
lemma crel_vs_intro:
  assumes "M v' M'. cmem M; P M; State_Monad.run_state vT M = (v', M')  R v v'  cmem M'  P M'"
  shows "crel_vs R v vT"
  using assms unfolding crel_vs_def by blast
term 0 (**)
  
lemma crel_vs_elim:
  assumes "crel_vs R v vT" "cmem M" "P M"
  obtains v' M' where "State_Monad.run_state vT M = (v', M')" "R v v'" "cmem M'" "P M'"
  using assms unfolding crel_vs_def by blast
term 0 (**)
  
  (* consistentDP *)
(* consistentDP follows from a pointwise Transfer.Rel statement. *)
lemma consistentDP_intro:
  assumes "param. Transfer.Rel (crel_vs (=)) (dp param) (dpT param)"
  shows "consistentDP dpT"
  using assms unfolding consistentDP_def Rel_def by blast
  
(* return is related to Wrap: returning a related value is correct. *)
lemma crel_vs_return:
  "Transfer.Rel R x y  Transfer.Rel (crel_vs R) (Wrap x) (State_Monad.return y)"
  unfolding State_Monad.return_def Wrap_def Rel_def by (fastforce intro: crel_vs_intro)
term 0 (**)
  
(* The same with Wrap unfolded. *)
lemma crel_vs_return_ext:
  "Transfer.Rel R x y  Transfer.Rel (crel_vs R) x (State_Monad.return y)"
  by (fact crel_vs_return[unfolded Wrap_def])
term 0 (**)

  (* Low level operators *)
(* Storing dp param under param keeps the memory consistent. *)
private lemma cmem_upd:
  "cmem M'" if "cmem M" "P M" "State_Monad.run_state (update param (dp param)) M = (v, M')"
  using update_correct[of M param "dp param"] that unfolding cmem_def map_le_def by simp force

(* update preserves the invariant P (from update_inv). *)
private lemma P_upd:
  "P M'" if "P M" "State_Monad.run_state (update param (dp param)) M = (v, M')"
  by (meson lift_p_P that update_inv)

(* Reading the whole memory with get, then continuing with a correct sf M,
   is correct. *)
private lemma crel_vs_get:
  "M. cmem M  crel_vs R v (sf M)  crel_vs R v (State_Monad.get  sf)"
  unfolding State_Monad.get_def State_Monad.bind_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)
  
(* Overwriting the memory with a consistent M satisfying P, then continuing with
   a correct sf, is correct. *)
private lemma crel_vs_set:
  "crel_vs R v sf; cmem M; P M  crel_vs R v (State_Monad.set M  sf)"
  unfolding State_Monad.set_def State_Monad.bind_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)
  
(* Bind whose first computation is equality-related to v. *)
private lemma crel_vs_bind_eq:
  "crel_vs (=) v s; crel_vs R (f v) (sf v)  crel_vs R (f v) (s  sf)"
  unfolding State_Monad.bind_def rel_fun_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)

(* Transfer rule relating pure application (λv f. f v) to monadic bind. *)
lemma bind_transfer[transfer_rule]:
  "(crel_vs R0 ===> (R0 ===>T R1) ===> crel_vs R1) (λv f. f v) (⤜)"
  unfolding State_Monad.bind_def rel_fun_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)

(* A lookup keeps the memory consistent (from lookup_correct). *)
private lemma cmem_lookup:
  "cmem M'" if "cmem M" "P M" "State_Monad.run_state (lookup param) M = (v, M')"
  using lookup_correct[of M param] that unfolding cmem_def map_le_def by force

(* lookup preserves the invariant P (from lookup_inv). *)
private lemma P_lookup:
  "P M'" if "P M" "State_Monad.run_state (lookup param) M = (v, M')"
  by (meson lift_p_P that lookup_inv)

(* From a consistent memory, lookup param returns either None or a value that
   equals dp param. *)
lemma crel_vs_lookup:
  "crel_vs (λ v v'. case v' of None  True | Some v'  v = v'  v = dp param) (dp param) (lookup param)"
  by (auto elim: cmem_elim intro: cmem_lookup crel_vs_intro P_lookup split: option.split)

(* Storing dp param is correct (it returns unit). *)
lemma crel_vs_update:
  "crel_vs (=) () (update param (dp param))"
  by (auto intro: cmem_upd crel_vs_intro P_upd)

(* Wrapping a computation that is correct for dp param in checkmem param keeps it
   correct. The relation is restricted to equality (is_equality R) because
   crel_vs_lookup only relates a stored value to dp param by equality. *)
private lemma crel_vs_checkmem:
  "is_equality R; Transfer.Rel (crel_vs R) (dp param) s
   Transfer.Rel (crel_vs R) (dp param) (checkmem param s)"
  unfolding checkmem_def Rel_def is_equality_def
  by (rule bind_transfer[unfolded rel_fun_def, rule_format, OF crel_vs_lookup])
     (auto 4 3 intro: crel_vs_lookup crel_vs_update crel_vs_return[unfolded Rel_def Wrap_def] crel_vs_bind_eq
               split: option.split_asm
     )

(* Variant of crel_vs_checkmem whose pure value v is given by an equation. *)
lemma crel_vs_checkmem_tupled:
  assumes "v = dp param"
  shows "is_equality R; Transfer.Rel (crel_vs R) v s
         Transfer.Rel (crel_vs R) v (checkmem param s)"
  unfolding assms by (fact crel_vs_checkmem)

  (** Transfer rules **)
  (* Basics *)
(* Transfer rule relating Wrap to State_Monad.return. *)
lemma return_transfer[transfer_rule]:
  "(R ===>T R) Wrap State_Monad.return"
  unfolding rel_fun_def by (metis crel_vs_return Rel_def)

(* Transfer rule relating App to lifted application. *)
lemma fun_app_lifted_transfer[transfer_rule]:
  "(crel_vs (R0 ===>T R1) ===> crel_vs R0 ===> crel_vs R1) App (.)"
  unfolding App_def fun_app_lifted_def by transfer_prover
    
(* fun_app_lifted_transfer stated with Transfer.Rel premises. *)
lemma crel_vs_fun_app:
  "Transfer.Rel (crel_vs R0) x xT; Transfer.Rel (crel_vs (R0 ===>T R1)) f fT  Transfer.Rel (crel_vs R1) (App f x) (fT . xT)"
  unfolding Rel_def using fun_app_lifted_transfer[THEN rel_funD, THEN rel_funD] .

  (* HOL *)
(* Transfer rule relating HOL's If to the lifted conditional ifT. *)
lemma ifT_transfer[transfer_rule]:
  "(crel_vs (=) ===> crel_vs R ===> crel_vs R ===> crel_vs R) If State_Monad_Ext.ifT"
  unfolding State_Monad_Ext.ifT_def by transfer_prover
end (* Lifting Syntax *)

end (* Consistency *)
end (* Theory *)

Theory State_Heap_Misc

subsection ‹Miscellaneous Parametricity Theorems›

theory State_Heap_Misc
  imports Main
begin
(* Composition lemmas for the function relator ===>.
   If f is related to g and g is related to h, then f is related to h under the
   composed (OO) argument and result relations. *)
context  includes lifting_syntax begin
lemma rel_fun_comp:
  assumes "(R1 ===> S1) f g" "(R2 ===> S2) g h"
  shows "(R1 OO R2 ===> S1 OO S2) f h"
  using assms by (auto intro!: rel_funI dest!: rel_funD)

(* Variant where the composed argument relation is given by an equation R' = R1 OO R2. *)
lemma rel_fun_comp1:
  assumes "(R1 ===> S1) f g" "(R2 ===> S2) g h" "R' = R1 OO R2"
  shows "(R' ===> S1 OO S2) f h"
  using assms rel_fun_comp by metis

(* Variant where the composed result relation is given by an equation S' = S1 OO S2. *)
lemma rel_fun_comp2:
  assumes "(R1 ===> S1) f g" "(R2 ===> S2) g h" "S' = S1 OO S2"
  shows "(R1 OO R2 ===> S') f h"
  using assms rel_fun_comp by metis

(* A composition of function relations is contained in the function relation of
   the compositions. *)
lemma rel_fun_relcompp:
  "((R1 ===> S1) OO (R2 ===> S2)) a b  ((R1 OO R2) ===> (S1 OO S2)) a b"
  unfolding OO_def rel_fun_def by blast

(* Variant where the argument relation R' only has to imply R1 OO R2. *)
lemma rel_fun_comp1':
  assumes "(R1 ===> S1) f g" "(R2 ===> S2) g h" " a b. R' a b  (R1 OO R2) a b"
  shows "(R' ===> S1 OO S2) f h"
  by (auto intro: assms rel_fun_mono[OF rel_fun_comp1])

(* Variant where S1 OO S2 only has to imply the result relation S'.
   The proof reuses rel_fun_comp1 and weakens the result relation via
   rel_fun_mono. *)
lemma rel_fun_comp2':
  assumes "(R1 ===> S1) f g" "(R2 ===> S2) g h" " a b. (S1 OO S2) a b  S' a b"
  shows "(R1 OO R2 ===> S') f h"
  by (auto intro: assms rel_fun_mono[OF rel_fun_comp1])

end
end

Theory Heap_Monad_Ext

subsection ‹Heap Monad›

theory Heap_Monad_Ext
  imports "HOL-Imperative_HOL.Imperative_HOL"
begin

(* Lifted function application for the heap monad.
   Runs fT to obtain a function f, then runs xT to obtain an argument x,
   and finally runs f x.
   NOTE(review): the arrow, definitional-equality and bind symbols appear stripped
   in this copy; restore them from upstream. *)
definition fun_app_lifted :: "('a  'b Heap) Heap  'a Heap  'b Heap" where
  "fun_app_lifted fT xT  do { f  fT; x  xT; f x }"

(* Syntax bundle for the heap monad; enable it with "includes heap_monad_syntax". *)
bundle heap_monad_syntax begin

(* Infix "." for lifted application. fun_lifted ("==H⟹") is the type of functions
   whose result is a Heap computation. Heap also gets bracket notation. *)
notation fun_app_lifted (infixl "." 999)
type_synonym ('a, 'b) fun_lifted = "'a  'b Heap" ("_ ==H⟹ _" [3,2] 2)
type_notation Heap ("[_]")

(* Mixfix notation for Heap_Monad.return (its delimiters look stripped here),
   an ASCII alternative "(#_#)", and "Rel" for Transfer.Rel. *)
notation Heap_Monad.return ("_")
notation (ASCII) Heap_Monad.return ("(#_#)")
notation Transfer.Rel ("Rel")

end

(* Basic laws for the lifted heap-monad operators. Names are qualified and are
   referenced as e.g. Heap_Monad_Ext.fun_app_lifted. *)
context includes heap_monad_syntax begin

(* Lifted application of a returned function to a returned argument is plain
   application. It follows from fun_app_lifted and return_bind. *)
qualified lemma return_app_return:
  "f . x = f x"
  unfolding fun_app_lifted_def return_bind ..

(* The same law as a meta-level equation. *)
qualified lemma return_app_return_meta:
  "f . x  f x"
  unfolding return_app_return .

(* Lifted conditional: run bT, then continue with xT or yT depending on the result. *)
qualified definition ifT :: "bool Heap  'a Heap  'a Heap  'a Heap" where
  "ifT bT xT yT  do {b  bT; if b then xT else yT}"
end

end

Theory State_Heap

subsection ‹Relation Between the State and the Heap Monad›

theory State_Heap
  imports
    "../state_monad/DP_CRelVS"
    "HOL-Imperative_HOL.Imperative_HOL"
    State_Heap_Misc
    Heap_Monad_Ext
begin

(* Heap variant of lift_p. From any heap satisfying P, executing f must succeed
   (execute does not return None) and end in a heap satisfying P.
   Unlike the state-monad lift_p, failure is ruled out here. *)
definition lift_p :: "(heap  bool)  'a Heap  bool" where
  "lift_p P f =
    ( heap. P heap  (case execute f heap of None  False | Some (_, heap)  P heap))"

(* Consequences of lift_p P f for a fixed heap satisfying P.
   - Executing f does not fail.
   - The resulting heap satisfies P.
   - "the (execute f heap)" can be turned back into a Some-equation. *)
context
  fixes P f heap
  assumes lift: "lift_p P f" and P: "P heap"
begin

lemma execute_cases:
  "case execute f heap of None  False | Some (_, heap)  P heap"
  using lift P unfolding lift_p_def by auto

lemma execute_cases':
  "case execute f heap of Some (_, heap)  P heap"
  using execute_cases by (auto split: option.split)

(* Execution cannot fail; used as a simp/dest rule to discharge None cases. *)
lemma lift_p_None[simp, dest]:
  False if "execute f heap = None"
  using that execute_cases by auto

lemma lift_p_P:
  "case the (execute f heap) of (_, heap)  P heap"
  using execute_cases by (auto split: option.split_asm)

lemma lift_p_P':
  "P heap'" if "the (execute f heap) = (v, heap')"
  using that lift_p_P by auto

lemma lift_p_P'':
  "P heap'" if "execute f heap = Some (v, heap')"
  using that lift_p_P by auto

lemma lift_p_the_Some[simp]:
  "execute f heap = Some (v, heap')" if "the (execute f heap) = (v, heap')"
  using that execute_cases by (auto split: option.split_asm)

lemma lift_p_E:
  obtains v heap' where "execute f heap = Some (v, heap')" "P heap'"
  using execute_cases by (cases "execute f heap") auto

end

(* Converts a heap computation into a state computation over heap by applying
   "the" to the result of execute. If execute returns None, the result is the
   unspecified value "the None". *)
definition "state_of s  State (λ heap. the (execute s heap))"

(* Heap-monad memory operations (lookup, update) with a heap invariant P.
   It also sets up the relation rel_state connecting state-monad computations
   over heap to Heap computations. *)
locale heap_mem_defs =
  fixes P :: "heap  bool"
    and lookup :: "'k  'v option Heap"
    and update :: "'k  'v  unit Heap"
begin

(* rel_state R f g holds when, for every heap satisfying P, the following all hold:
   - executing g succeeds;
   - g ends in the same heap as the state computation f;
   - the two results are related by R;
   - the final heap satisfies P. *)
definition rel_state :: "('a  'b  bool)  (heap, 'a) state  'b Heap  bool" where
  "rel_state R f g 
     heap. P heap 
      (case State_Monad.run_state f heap of (v1, heap1)  case execute g heap of
        Some (v2, heap2)  R v1 v2  heap1 = heap2  P heap2 | None  False)"

(* State-monad counterparts of lookup and update, obtained by running them with
   execute and taking "the" result. *)
definition "lookup' k  State (λ heap. the (execute (lookup k) heap))"

definition "update' k v  State (λ heap. the (execute (update k v) heap))"

(* Heap computation that returns the current heap without changing it. *)
definition "heap_get = Heap_Monad.Heap (λ heap. Some (heap, heap))"

(* Heap version of the memoizing wrapper.
   Return the stored value for param if there is one. Otherwise run calc,
   store its result via update, and return it. *)
definition checkmem :: "'k  'v Heap  'v Heap" where
  "checkmem param calc 
    Heap_Monad.bind (lookup param) (λ x.
    case x of
      Some x  return x
    | None  Heap_Monad.bind calc (λ x.
        Heap_Monad.bind (update param x) (λ _.
        return x
      )
    )
  )
  "

(* Thunked variant of checkmem; calc () is only applied on a lookup miss. *)
definition checkmem' :: "'k  (unit  'v Heap)  'v Heap" where
  "checkmem' param calc 
    Heap_Monad.bind (lookup param) (λ x.
    case x of
      Some x  return x
    | None  Heap_Monad.bind (calc ()) (λ x.
        Heap_Monad.bind (update param x) (λ _.
        return x
      )
    )
  )
  "

lemma checkmem_checkmem':
  "checkmem' param (λ_. calc) = checkmem param calc"
  unfolding checkmem'_def checkmem_def ..

(* Partial map represented by a heap: the value returned by lookup k. *)
definition map_of_heap where
  "map_of_heap heap k = fst (the (execute (lookup k) heap))"

(* Elimination rule for rel_state from a heap satisfying P. *)
lemma rel_state_elim:
  assumes "rel_state R f g" "P heap"
  obtains heap' v v' where
    "State_Monad.run_state f heap = (v, heap')" "execute g heap = Some (v', heap')" "R v v'" "P heap'"
  apply atomize_elim
  using assms unfolding rel_state_def
  apply auto
  apply (cases "State_Monad.run_state f heap")
  apply auto
  apply (auto split: option.split_asm)
  done

(* Introduction rule for rel_state. Two things must be shown for heaps
   satisfying P:
   (1) g produces a related value and the same final heap;
   (2) f preserves P. *)
lemma rel_state_intro:
  assumes
    " heap v heap'. P heap  State_Monad.run_state f heap = (v, heap')
         v'. R v v'  execute g heap = Some (v', heap')"
    " heap v heap'. P heap  State_Monad.run_state f heap = (v, heap')  P heap'"
  shows "rel_state R f g"
  unfolding rel_state_def
  apply auto
  apply (frule assms(1)[rotated])
   apply (auto intro: assms(2))
  done

context
  includes lifting_syntax state_monad_syntax
begin

(* Transfer rules relating the state-monad operations to the heap-monad ones. *)
lemma transfer_bind[transfer_rule]:
  "(rel_state R ===> (R ===> rel_state Q) ===> rel_state Q) State_Monad.bind Heap_Monad.bind"
  unfolding rel_fun_def State_Monad.bind_def Heap_Monad.bind_def
  by (force elim!: rel_state_elim intro!: rel_state_intro)

lemma transfer_return[transfer_rule]:
  "(R ===> rel_state R) State_Monad.return Heap_Monad.return"
  unfolding rel_fun_def State_Monad.return_def Heap_Monad.return_def
  by (fastforce intro: rel_state_intro elim: rel_state_elim simp: execute_heap)

lemma fun_app_lifted_transfer:
  "(rel_state (R ===> rel_state Q) ===> rel_state R ===> rel_state Q)
      State_Monad_Ext.fun_app_lifted Heap_Monad_Ext.fun_app_lifted"
  unfolding State_Monad_Ext.fun_app_lifted_def Heap_Monad_Ext.fun_app_lifted_def by transfer_prover

lemma transfer_get[transfer_rule]:
  "rel_state (=) State_Monad.get heap_get"
  unfolding State_Monad.get_def heap_get_def by (auto intro: rel_state_intro)

end (* Lifting Syntax *)

end (* Heap Mem Defs *)

(* Assumes that lookup and update preserve P (heap lift_p, so they also never fail).
   Under this assumption:
   - lookup' / update' are rel_state-related to lookup / update;
   - the state-monad checkmem over lookup' / update' is related to the heap
     checkmem.
   The for clause places lookup last among the locale parameters. *)
locale heap_inv = heap_mem_defs _ lookup for lookup :: "'k  'v option Heap"  +
  assumes lookup_inv: "lift_p P (lookup k)"
  assumes update_inv: "lift_p P (update k v)"
begin

lemma rel_state_lookup:
  "rel_state (=) (lookup' k) (lookup k)"
  unfolding rel_state_def lookup'_def using lookup_inv[of k] by (auto intro: lift_p_P')

lemma rel_state_update:
  "rel_state (=) (update' k v) (update k v)"
  unfolding rel_state_def update'_def using update_inv[of k v] by (auto intro: lift_p_P')

context
  includes lifting_syntax
begin

(* Pointwise versions of the two lemmas above, as relator statements. *)
lemma transfer_lookup:
  "((=) ===> rel_state (=)) lookup' lookup"
  unfolding rel_fun_def by (auto intro: rel_state_lookup)

lemma transfer_update:
  "((=) ===> (=) ===> rel_state (=)) update' update"
  unfolding rel_fun_def by (auto intro: rel_state_update)

(* The state-monad checkmem (built from lookup'/update') is related to the heap
   checkmem. Proved by transfer_prover using the rules above. *)
lemma transfer_checkmem:
  "((=) ===> rel_state (=) ===> rel_state (=))
    (state_mem_defs.checkmem lookup' update') checkmem"
  supply [transfer_rule] = transfer_lookup transfer_update
  unfolding state_mem_defs.checkmem_def checkmem_def by transfer_prover

end (* Lifting Syntax *)

end (* Heap Invariant *)

(* Correctness assumptions for the heap memory, mirroring mem_correct but phrased
   via map_of_heap and execute. From them it derives that the state counterparts
   lookup'/update' satisfy mem_correct (mem_correct_heap). *)
locale heap_correct =
  heap_inv +
  assumes lookup_correct:
      "P m  map_of_heap (snd (the (execute (lookup k) m))) m (map_of_heap m)"
  and update_correct:
      "P m  map_of_heap (snd (the (execute (update k v) m))) m (map_of_heap m)(k  v)"
begin

lemma lookup'_correct:
  "state_mem_defs.map_of lookup' (snd (State_Monad.run_state (lookup' k) m)) m (state_mem_defs.map_of lookup' m)" if "P m"
  using P m unfolding state_mem_defs.map_of_def map_le_def lookup'_def
  by simp (metis (mono_tags, lifting) domIff lookup_correct map_le_def map_of_heap_def)

lemma update'_correct:
  "state_mem_defs.map_of lookup' (snd (State_Monad.run_state (update' k v) m)) m state_mem_defs.map_of lookup' m(k  v)"
  if "P m"
  unfolding state_mem_defs.map_of_def map_le_def lookup'_def update'_def
  using update_correct[of m k v] that by (auto split: if_split_asm simp: map_le_def map_of_heap_def)

(* State-monad invariant preservation (DP_CRelVS.lift_p), derived from the heap
   invariants lookup_inv / update_inv. *)
lemma lookup'_inv:
  "DP_CRelVS.lift_p P (lookup' k)"
  unfolding DP_CRelVS.lift_p_def lookup'_def by (auto elim: lift_p_P'[OF lookup_inv])

lemma update'_inv:
  "DP_CRelVS.lift_p P (update' k v)"
  unfolding DP_CRelVS.lift_p_def update'_def by (auto elim: lift_p_P'[OF update_inv])

lemma mem_correct_heap: "mem_correct lookup' update' P"
  by (intro mem_correct.intro lookup'_correct update'_correct lookup'_inv update'_inv)

end (* Heap correct *)

(* Converse direction of mem_correct_heap.
   Suppose some state-monad operations lookups/updates satisfy mem_correct and are
   rel_state-related to the heap operations lookup/update. Then the heap
   operations satisfy heap_correct.
   The argument order (P, update, lookup) of heap_correct follows heap_inv, whose
   for clause lists lookup last. *)
context heap_mem_defs
begin

context
  includes lifting_syntax
begin

lemma mem_correct_heap_correct:
  assumes correct: "mem_correct lookups updates P"
    and lookup: "((=) ===> rel_state (=)) lookups lookup"
    and update: "((=) ===> (=) ===> rel_state (=)) updates update"
  shows "heap_correct P update lookup"
proof -
  interpret mem: mem_correct lookups updates P
    by (rule correct)
  (* Under P, executing the heap operations agrees with running the state ones,
     so map_of_heap coincides with mem.map_of. *)
  have [simp]: "the (execute (lookup k) m) = run_state (lookups k) m" if "P m" for k m
    using lookup[THEN rel_funD, OF HOL.refl, of k] P m by (auto elim: rel_state_elim)
  have [simp]: "the (execute (update k v) m) = run_state (updates k v) m" if "P m" for k v m
    using update[THEN rel_funD, THEN rel_funD, OF HOL.refl HOL.refl, of k v] P m
    by (auto elim: rel_state_elim)
  have [simp]: "map_of_heap m = mem.map_of m" if "P m" for m
    unfolding map_of_heap_def mem.map_of_def using P m by simp
  (* The four subgoals are, in order: lookup_inv, update_inv, lookup_correct,
     update_correct. *)
  show ?thesis
  apply standard
    subgoal for k
      using mem.lookup_inv[of k] lookup[THEN rel_funD, OF HOL.refl, of k]
      unfolding lift_p_def DP_CRelVS.lift_p_def
      by (auto split: option.splits elim: rel_state_elim)
    subgoal for k v
      using mem.update_inv[of k] update[THEN rel_funD, THEN rel_funD, OF HOL.refl HOL.refl, of k v]
      unfolding lift_p_def DP_CRelVS.lift_p_def
      by (auto split: option.splits elim: rel_state_elim)
    subgoal premises prems for m k
    proof -
      have "P (snd (run_state (lookups k) m))"
        by (meson DP_CRelVS.lift_p_P mem.lookup_inv prems prod.exhaust_sel)
      with mem.lookup_correct[OF P m, of k] P m show ?thesis
        by (simp add: prems)
    qed
    subgoal premises prems for m k v
    proof -
      have "P (snd (run_state (updates k v) m))"
        by (meson DP_CRelVS.lift_p_P mem.update_inv prems prod.exhaust_sel)
      with mem.update_correct[OF P m, of k] P m show ?thesis
        by (simp add: prems)
    qed
    done
qed

end

end

end (* Theory *)

Theory DP_CRelVH

subsection ‹Parametricity of the Heap Monad›

theory DP_CRelVH
  imports State_Heap
begin

(* Consistency of memoized heap programs with dp. This locale combines:
   - a state-monad consistency setup (lookup_st, update_st, invariant P);
   - heap operations (lookup, update, invariant Q);
   - the assumption that the state and heap operations are pointwise
     rel_state-related. *)
locale dp_heap =
  state_dp_consistency: dp_consistency lookup_st update_st P dp + heap_mem_defs Q lookup update
  for P Q :: "heap  bool" and dp :: "'k  'v" and lookup :: "'k  'v option Heap"
  and lookup_st update update_st +
  assumes
    rel_state_lookup: "rel_fun (=) (rel_state (=)) lookup_st lookup"
      and
    rel_state_update: "rel_fun (=) (rel_fun (=) (rel_state (=))) update_st update"
begin

context
  includes lifting_syntax heap_monad_syntax
begin

(* Heap crel_vs R v f: take any heap satisfying P and Q that is consistent
   (state_dp_consistency.cmem). Executing f from it must succeed, yield a value
   related to v by R, and leave a heap that again satisfies P, Q and cmem. *)
definition "crel_vs R v f 
  heap. P heap  Q heap  state_dp_consistency.cmem heap 
    (case execute f heap of
      None  False |
      Some (v', heap')  P heap'  Q heap'  R v v'  state_dp_consistency.cmem heap'
    )
"

(* Function relator whose result side is the heap crel_vs. *)
abbreviation rel_fun_lifted :: "('a  'c  bool)  ('b  'd  bool)  ('a  'b)  ('c ==H⟹ 'd)  bool" (infixr "===>T" 55) where
  "rel_fun_lifted R R'  R ===> crel_vs R'"


(* consistentDP dpT: dpT param is heap-crel_vs-related (by equality) to dp param
   for every param. *)
definition consistentDP :: "('k  'v Heap)  bool" where
  "consistentDP  ((=) ===> crel_vs (=)) dp"

lemma consistentDP_intro:
  assumes "param. Transfer.Rel (crel_vs (=)) (dp param) (dpT param)"
  shows "consistentDP dpT"
  using assms unfolding consistentDP_def Rel_def by blast

(* Elimination forms of crel_vs: execution cannot fail, and it yields a related
   value together with a heap satisfying the invariants. *)
lemma crel_vs_execute_None:
  False if "crel_vs R a b" "execute b heap = None" "P heap" "Q heap" "state_dp_consistency.cmem heap"
  using that unfolding crel_vs_def by auto

lemma crel_vs_execute_Some:
  assumes "crel_vs R a b" "P heap" "Q heap" "state_dp_consistency.cmem heap"
  obtains x heap' where "execute b heap = Some (x, heap')" "P heap'" "Q heap'"
  using assms unfolding crel_vs_def by (cases "execute b heap") auto

lemma crel_vs_executeD:
  assumes "crel_vs R a b" "P heap" "Q heap" "state_dp_consistency.cmem heap"
  obtains x heap' where
    "execute b heap = Some (x, heap')" "P heap'" "Q heap'" "state_dp_consistency.cmem heap'" "R a x"
  using assms unfolding crel_vs_def by (cases "execute b heap") auto

lemma crel_vs_success:
  assumes "crel_vs R a b" "P heap" "Q heap" "state_dp_consistency.cmem heap"
  shows "success b heap"
  using assms unfolding success_def by (auto elim: crel_vs_executeD)

(* The heap crel_vs follows from composing the state crel_vs with rel_state (=).
   This is how the state-monad results are transported to the heap. *)
lemma crel_vsI: "crel_vs R a b" if "(state_dp_consistency.crel_vs R OO rel_state (=)) a b"
  using that by (auto 4 3 elim: state_dp_consistency.crel_vs_elim rel_state_elim simp: crel_vs_def)

(* Transfer rule relating Wrap to Heap return, obtained by composing the state
   return_transfer with transfer_return. *)
lemma transfer'_return[transfer_rule]:
  "(R ===> crel_vs R) Wrap return"
proof -
  have "(R ===> (state_dp_consistency.crel_vs R OO rel_state (=))) Wrap return"
    by (rule rel_fun_comp1 state_dp_consistency.return_transfer transfer_return)+ auto
  then show ?thesis
    by (blast intro: rel_fun_mono crel_vsI)
qed

lemma crel_vs_return:
  "Transfer.Rel (crel_vs R) (Wrap x) (return y)" if "Transfer.Rel R x y"
  using that unfolding Rel_def by (rule transfer'_return[unfolded rel_fun_def, rule_format])

lemma crel_vs_return_ext:
  "Transfer.Rel R x y  Transfer.Rel (crel_vs R) x (Heap_Monad.return y)"
  by (fact crel_vs_return[unfolded Wrap_def])
term 0 (**)

(* Transfer rule relating pure application (λv f. f v) to heap bind. *)
lemma bind_transfer[transfer_rule]:
  "(crel_vs R0 ===> (R0 ===> crel_vs R1) ===> crel_vs R1) (λv f. f v) (⤜)"
  unfolding rel_fun_def bind_def
  by safe (subst crel_vs_def, auto 4 4 elim: crel_vs_execute_Some elim!: crel_vs_executeD)


(* Heap update and lookup rules, transported from the state-monad versions via
   crel_vsI and the rel_state assumptions of this locale. *)
lemma crel_vs_update:
  "crel_vs (=) () (update param (dp param))"
  by (rule
      crel_vsI relcomppI state_dp_consistency.crel_vs_update
      rel_state_update[unfolded rel_fun_def, rule_format] HOL.refl
     )+

lemma crel_vs_lookup:
  "crel_vs
    (λ v v'. case v' of None  True | Some v'  v = v'  v = dp param) (dp param) (lookup param)"
  by (rule
      crel_vsI relcomppI state_dp_consistency.crel_vs_lookup
      rel_state_lookup[unfolded rel_fun_def, rule_format] HOL.refl
     )+

(* Strengthens an equality relation to eq_onp fixing the value v. *)
lemma crel_vs_eq_eq_onp:
  "crel_vs (eq_onp (λ x. x = v)) v s" if "crel_vs (=) v s"
  using that unfolding crel_vs_def by (auto split: option.split simp: eq_onp_def)

lemma crel_vs_bind_eq:
  "crel_vs (=) v s; crel_vs R (f v) (sf v)  crel_vs R (f v) (s  sf)"
  by (erule bind_transfer[unfolded rel_fun_def, rule_format, OF crel_vs_eq_eq_onp])
     (auto simp: eq_onp_def)

(* Wrapping a correct computation for dp param in the heap checkmem keeps it
   correct. This is restricted to equality relations, as in the state version. *)
lemma crel_vs_checkmem:
  "Transfer.Rel (crel_vs R) (dp param) (checkmem param s)" if "is_equality R" "Transfer.Rel (crel_vs R) (dp param) s"
  unfolding checkmem_def Rel_def that(1)[unfolded is_equality_def]
  by (rule bind_transfer[unfolded rel_fun_def, rule_format, OF crel_vs_lookup])
     (auto 4 3 split: option.split_asm intro: crel_vs_bind_eq crel_vs_update crel_vs_return[unfolded Wrap_def Rel_def] that(2)[unfolded Rel_def that(1)[unfolded is_equality_def]])

lemma crel_vs_checkmem_tupled:
  assumes "v = dp param"
  shows "is_equality R; Transfer.Rel (crel_vs R) v s
         Transfer.Rel (crel_vs R) v (checkmem param s)"
  unfolding assms by (fact crel_vs_checkmem)

(* Transfer rule relating App to heap lifted application. *)
lemma transfer_fun_app_lifted[transfer_rule]:
  "(crel_vs (R0 ===> crel_vs R1) ===> crel_vs R0 ===> crel_vs R1)
    App Heap_Monad_Ext.fun_app_lifted"
  unfolding Heap_Monad_Ext.fun_app_lifted_def App_def by transfer_prover

lemma crel_vs_fun_app:
  "Transfer.Rel (crel_vs R0) x xT; Transfer.Rel (crel_vs (R0 ===>T R1)) f fT  Transfer.Rel (crel_vs R1) (App f x) (fT . xT)"
  unfolding Rel_def using transfer_fun_app_lifted[THEN rel_funD, THEN rel_funD] .

end (* Lifting Syntax *)

end (* Dynamic Programming Problem *)

(* Instantiates dp_heap for a heap_correct memory. The state counterparts
   lookup'/update' serve as the state-monad operations, and P is used as both
   invariants. *)
locale dp_consistency_heap = heap_correct +
  fixes dp :: "'a  'b"
begin

interpretation state_mem_correct: mem_correct lookup' update' P
  by (rule mem_correct_heap)

interpretation state_dp_consistency: dp_consistency lookup' update' P dp ..

(* NOTE(review): the locale predicate dp_heap is applied here without dp, while
   the sublocale below passes dp. Presumably this is because no assumption of
   dp_heap mentions dp; confirm. *)
lemma dp_heap: "dp_heap P P lookup lookup' update update'"
  by (standard; rule transfer_lookup transfer_update)

sublocale dp_heap P P dp lookup lookup' update update'
  by (rule dp_heap)

notation rel_fun_lifted (infixr "===>T" 55)
end

(* heap_correct extended with an initial heap "empty".
   Its map_of_heap is contained in the empty map, and it satisfies P. *)
locale heap_correct_empty = heap_correct +
  fixes empty
  assumes empty_correct: "map_of_heap empty m Map.empty" and P_empty: "P empty"

(* End-to-end correctness of memoized heap programs started from the empty heap. *)
locale dp_consistency_heap_empty =
  dp_consistency_heap + heap_correct_empty
begin

(* The empty heap is consistent with dp. *)
lemma cmem_empty:
  "state_dp_consistency.cmem empty"
  using empty_correct
  unfolding state_dp_consistency.cmem_def
  unfolding map_of_heap_def
  unfolding state_dp_consistency.map_of_def
  unfolding lookup'_def
  unfolding map_le_def
  by auto

(* If dpT is consistent with dp and executing dpT x from the empty heap returns
   (v, m), then v = dp x and m is consistent. *)
corollary memoization_correct:
  "dp x = v" "state_dp_consistency.cmem m" if
  "consistentDP dpT" "Heap_Monad.execute (dpT x) empty = Some (v, m)"
  using that unfolding consistentDP_def
  by (auto dest!: rel_funD[where x = x] elim!: crel_vs_executeD intro: P_empty cmem_empty)

(* Executing a consistent dpT from the empty heap succeeds. *)
lemma memoized_success:
  "success (dpT x) empty" if "consistentDP dpT"
  using that cmem_empty P_empty
  by (auto dest!: rel_funD intro: crel_vs_success simp: consistentDP_def)

(* The value computed from the empty heap is dp x. *)
lemma memoized:
  "dp x = fst (the (Heap_Monad.execute (dpT x) empty))" if "consistentDP dpT"
  using surjective_pairing memoization_correct(1)[OF that]
    memoized_success[OF that, unfolded success_def]
  by (cases "execute (dpT x) empty"; auto)

(* The final heap is consistent with dp. *)
lemma cmem_result:
  "state_dp_consistency.cmem (snd (the (Heap_Monad.execute (dpT x) empty)))" if "consistentDP dpT"
  using surjective_pairing memoization_correct(2)[OF that(1)]
    memoized_success[OF that, unfolded success_def]
  by (cases "execute (dpT x) empty"; auto)

end

end (* Theory *)

Theory Memory

section ‹Memoization›

subsection ‹Memory Implementations for the State Monad›

theory Memory
  imports "DP_CRelVS" "HOL-Library.Mapping"
begin

(* Introduction rule for the state-monad lift_p: f preserves P if every run of f
   from a state satisfying P ends in a state satisfying P. *)
lemma lift_pI[intro?]:
  "lift_p P f" if " heap x heap'. P heap  run_state f heap = (x, heap')  P heap'"
  unfolding lift_p_def by (auto intro: that)

(* The memory is the function itself: lookup k returns m k, and update k v sets
   the memory to m extended with k mapped to v. The invariant is trivially True. *)
lemma mem_correct_default:
  "mem_correct
    (λ k. do {m  State_Monad.get; State_Monad.return (m k)})
    (λ k v. do {m  State_Monad.get; State_Monad.set (m(kv))})
    (λ _. True)"
  by standard
    (auto simp: map_le_def state_mem_defs.map_of_def State_Monad.bind_def State_Monad.get_def State_Monad.return_def State_Monad.set_def lift_p_def)


(* Memory given by a HOL-Library Mapping: lookup uses Mapping.lookup and update
   uses Mapping.update. The invariant is trivially True.
   NOTE(review): despite the "rbt" in the name, this lemma only uses the abstract
   Mapping interface; the concrete implementation is chosen elsewhere. *)
lemma mem_correct_rbt_mapping:
  "mem_correct
    (λ k. do {m  State_Monad.get; State_Monad.return (Mapping.lookup m k)})
    (λ k v. do {m  State_Monad.get; State_Monad.set (Mapping.update k v m)})
    (λ _. True)"
  by standard
     (auto simp:
        map_le_def state_mem_defs.map_of_def State_Monad.bind_def State_Monad.get_def State_Monad.return_def State_Monad.set_def lookup_update' lift_p_def
     )



(* mem_correct extended with an initial memory "empty".
   It represents (is contained in) the empty map and satisfies P. *)
locale mem_correct_empty = mem_correct +
  fixes empty
  assumes empty_correct: "map_of empty m Map.empty" and P_empty: "P empty"

(* The initial memory stores no keys. *)
lemma (in mem_correct_empty) dom_empty[simp]:
  "dom (map_of empty) = {}"
  using empty_correct by (auto dest: map_le_implies_dom_le)

(* The function-based memory of mem_correct_default, with Map.empty as the
   initial memory. *)
lemma mem_correct_empty_default:
  "mem_correct_empty
    (λ k. do {m  State_Monad.get; State_Monad.return (m k)})
    (λ k v. do {m  State_Monad.get; State_Monad.set (m(kv))})
    (λ _. True)
    Map.empty"
  by (intro mem_correct_empty.intro[OF mem_correct_default] mem_correct_empty_axioms.intro)
     (auto simp: state_mem_defs.map_of_def map_le_def State_Monad.bind_def State_Monad.get_def State_Monad.return_def)

(* The Mapping-based memory of mem_correct_rbt_mapping, with Mapping.empty as the
   initial memory. *)
lemma mem_correct_rbt_empty_mapping:
  "mem_correct_empty
    (λ k. do {m  State_Monad.get; State_Monad.return (Mapping.lookup m k)})
    (λ k v. do {m  State_Monad.get; State_Monad.set (Mapping.update k v m)})
    (λ _. True)
    Mapping.empty"
  by (intro mem_correct_empty.intro[OF mem_correct_rbt_mapping] mem_correct_empty_axioms.intro)
     (auto simp: state_mem_defs.map_of_def map_le_def State_Monad.bind_def State_Monad.get_def State_Monad.return_def lookup_empty)

(* End-to-end correctness of memoized state-monad programs started from the
   initial memory "empty". *)
locale dp_consistency_empty =
  dp_consistency + mem_correct_empty
begin

(* The initial memory is consistent with dp. *)
lemma cmem_empty:
  "cmem empty"
  using empty_correct unfolding cmem_def by auto

(* If dpT is consistent with dp and running dpT x from empty yields (v, m),
   then v = dp x and m is consistent. *)
corollary memoization_correct:
  "dp x = v" "cmem m" if "consistentDP dpT" "State_Monad.run_state (dpT x) empty = (v, m)"
  using that unfolding consistentDP_def
  by (auto dest!: rel_funD[where x = x] elim!: crel_vs_elim intro: P_empty cmem_empty)

(* The value computed from the initial memory is dp x. *)
lemma memoized:
  "dp x = fst (State_Monad.run_state (dpT x) empty)" if "consistentDP dpT"
  using surjective_pairing memoization_correct(1)[OF that] by blast

(* The final memory is consistent with dp. *)
lemma cmem_result:
  "cmem (snd (State_Monad.run_state (dpT x) empty))" if "consistentDP dpT"
  using surjective_pairing memoization_correct(2)[OF that] by blast

end (* DP Consistency Empty *)

(* dp_consistency_empty instantiated with the function-based memory of
   mem_correct_empty_default (initial memory Map.empty, trivial invariant). *)
locale dp_consistency_default =
  fixes dp :: "'param  'result"
begin

sublocale dp_consistency_empty
  "λ k. do {(m::'param  'result)  State_Monad.get; State_Monad.return (m k)}"
  "λ k v. do {m  State_Monad.get; State_Monad.set (m(kv))}"
  "λ (_::'param  'result). True"
  dp
  Map.empty
  by (intro
      dp_consistency_empty.intro dp_consistency.intro mem_correct_default mem_correct_empty_default
     )

end (* DP Consistency Default *)

(* dp_consistency_empty instantiated with the Mapping-based memory of
   mem_correct_rbt_empty_mapping (initial memory Mapping.empty, trivial invariant). *)
locale dp_consistency_mapping =
  fixes dp :: "'param  'result"
begin

sublocale dp_consistency_empty
  "(λ k. do {(m::('param,'result) mapping)  State_Monad.get; State_Monad.return (Mapping.lookup m k)})"
    "(λ k v. do {m  State_Monad.get; State_Monad.set (Mapping.update k v m)})"
    "(λ _::('param,'result) mapping. True)" dp Mapping.empty
  by (intro
      dp_consistency_empty.intro dp_consistency.intro mem_correct_rbt_mapping
      mem_correct_rbt_empty_mapping
     )

end (* DP Consistency Mapping *)

subsubsection ‹Tracing Memory›
(* Tracing wrappers around lookup/update. The memory becomes a pair (log, m) and
   every operation prepends an entry to log:
   - (''Missed'', k) on a lookup that returns None;
   - (''Found'', k) on a lookup that returns Some v;
   - (''Stored'', k) on an update. *)
context state_mem_defs
begin

definition
  "lookup_trace k =
  State (λ (log, m). case State_Monad.run_state (lookup k) m of
    (None, m)  (None, ((''Missed'', k) # log, m)) |
    (Some v, m)  (Some v, ((''Found'', k) # log, m))
  )"

definition
  "update_trace k v =
  State (λ (log, m). case State_Monad.run_state (update k v) m of
    (_, m)  ((), ((''Stored'', k) # log, m))
  )"

end

context mem_correct
begin

(* The map view induced by the traced lookup ignores the log. It is the map
   view of the memory component (snd) of the traced state. *)
lemma map_of_simp:
  "state_mem_defs.map_of lookup_trace = map_of o snd"
  unfolding state_mem_defs.map_of_def lookup_trace_def
  by (rule ext) (auto split: prod.split option.split)

(* Tracing preserves memory correctness. The invariant of the traced memory
   is P applied to the memory component only.
   NOTE(review): lift_p_P appears twice in the metis fact list; harmless. *)
lemma mem_correct_tracing: "mem_correct lookup_trace update_trace (P o snd)"
  by standard
    (auto
      intro!: lift_pI
      elim: lift_p_P[OF lookup_inv]
      simp: lookup_trace_def update_trace_def state_mem_defs.map_of_def map_of_simp
      split: prod.splits option.splits;
      metis snd_conv lookup_correct update_correct lift_p_P update_inv lookup_inv lift_p_P
   )+

end

context mem_correct_empty
begin

(* The traced memory also has a valid initial state: an empty log paired with
   the untraced initial memory "empty". *)
lemma mem_correct_tracing_empty:
  "mem_correct_empty lookup_trace update_trace (P o snd) ([], empty)"
  by (intro mem_correct_empty.intro mem_correct_tracing mem_correct_empty_axioms.intro)
     (simp add: map_of_simp empty_correct P_empty)+

end

(* Mapping-based memory with access tracing. It reuses the mapping memory of
   dp_consistency_mapping (interpreted locally as "mapping") and wraps it with
   lookup_trace/update_trace. The invariant is trivially True on the memory
   component, and the initial state is an empty log with Mapping.empty. *)
locale dp_consistency_mapping_tracing =
  fixes dp :: "'param  'result"
begin

interpretation mapping: dp_consistency_mapping .

sublocale dp_consistency_empty
  mapping.lookup_trace mapping.update_trace "(λ _. True) o snd" dp "([], Mapping.empty)"
  by (rule
      dp_consistency_empty.intro dp_consistency.intro
      mapping.mem_correct_tracing_empty mem_correct_empty.axioms(1)
     )+

end (* DP Consistency Mapping Tracing *)

end (* Theory *)

Theory Pair_Memory

subsection ‹Pair Memory›

theory Pair_Memory
  imports "../state_monad/Memory"
begin

(* XXX Move *)
(* Monotonicity of map addition (++) with respect to map_le. If m1 is below
   m1' and m2 is below m2', then m1 ++ m2 is below m1' ++ m2'. The disjointness
   of dom m1 and dom m2' is required because ++ prefers its right argument: an
   entry of m1 must not be shadowed by m2' on the right-hand side. *)
lemma map_add_mono:
  "(m1 ++ m2) m (m1' ++ m2')" if "m1 m m1'" "m2 m m2'" "dom m1  dom m2' = {}"
  using that unfolding map_le_def map_add_def dom_def by (auto split: option.splits)

(* Updating the left operand of ++ at x equals updating the sum at x. This
   requires dom f and dom g to be disjoint and x not to be in dom g, so that g
   does not shadow the new entry. The proof commutes ++ via map_add_comm (using
   disjointness) and reduces the goal to the standard lemma for updates of the
   right operand. *)
lemma map_add_upd2:
  "f(x  y) ++ g = (f ++ g)(x  y)" if "dom f  dom g = {}" "x  dom g"
  apply (subst map_add_comm)
   defer
   apply simp
   apply (subst map_add_comm)
  using that
  by auto

(* A memory made of two "rows". Row 1 is accessed with lookup1/update1 and
   row 2 with lookup2/update2. Each row is tagged with a row key, read by
   get_k1 and get_k2 respectively. A full key k is split into a row key key1 k
   and an in-row key key2 k. move12 k1 is the operation that makes room for a
   new row tagged k1; its intended behaviour is axiomatised in pair_mem.
   P is the invariant of the underlying memory. *)
locale pair_mem_defs =
  fixes lookup1 lookup2 :: "'a  ('mem, 'v option) state"
    and update1 update2 :: "'a  'v  ('mem, unit) state"
    and move12 :: "'k1  ('mem, unit) state"
    and get_k1 get_k2 :: "('mem, 'k1) state"
    and P :: "'mem  bool"
  fixes key1 :: "'k  'k1" and key2 :: "'k  'a"
begin

text ‹We assume that look-ups happen on the older row, so it is biased towards the second entry.›
(* Row 2's tag is checked first, then row 1's. If key1 k tags neither row,
   the result is None and no row operation is run. *)
definition
  "lookup_pair k = do {
     let k' = key1 k;
     k2  get_k2;
     if k' = k2
     then lookup2 (key2 k)
     else do {
       k1  get_k1;
       if k' = k1
       then lookup1 (key2 k)
       else State_Monad.return None
     }
   }
   "

text ‹We assume that updates happen on the newer row, so it is biased towards the first entry.›
(* Row 1's tag is checked first, then row 2's. If key1 k tags neither row,
   move12 k' runs first to make room for a row tagged k', and the value is then
   stored in row 1. *)
definition
  "update_pair k v = do {
    let k' = key1 k;
    k1  get_k1;
    if k' = k1
    then update1 (key2 k) v
    else do {
      k2  get_k2;
      if k' = k2
      then update2 (key2 k) v
      else (move12 k'  update1 (key2 k) v)
    }
  }
  "

(* Map views of the combined memory and of each row. *)
sublocale pair: state_mem_defs lookup_pair update_pair .

sublocale mem1: state_mem_defs lookup1 update1 .

sublocale mem2: state_mem_defs lookup2 update2 .

(* Invariant of the pair memory:
   - every in-row key stored in row 1 (resp. row 2) is key2 k' for some full
     key k' whose key1 is that row's tag;
   - the two tags differ;
   - P holds for the underlying memory. *)
definition
  "inv_pair heap 
    let
      k1 = fst (State_Monad.run_state get_k1 heap);
      k2 = fst (State_Monad.run_state get_k2 heap)
    in
    ( k  dom (mem1.map_of heap).  k'. key1 k' = k1  key2 k' = k) 
    ( k  dom (mem2.map_of heap).  k'. key1 k' = k2  key2 k' = k) 
    k1  k2  P heap
  "

(* Row 1 viewed as a map on full keys: defined only for keys whose key1 is
   row 1's tag. *)
definition
  "map_of1 m k = (if key1 k = fst (State_Monad.run_state get_k1 m) then mem1.map_of m (key2 k) else None)"

(* Row 2 viewed as a map on full keys: defined only for keys whose key1 is
   row 2's tag. *)
definition
  "map_of2 m k = (if key1 k = fst (State_Monad.run_state get_k2 m) then mem2.map_of m (key2 k) else None)"

end (* Pair Mem Defs *)

(* Assumptions on the two-row memory of pair_mem_defs:
   - get_state: get_k1 and get_k2 do not modify the state;
   - move12_correct / move12_keys: after move12 k1, row 1 is empty, row 2 is
     contained in the old row 1, row 1 is tagged k1, and row 2 carries the old
     tag of row 1;
   - move12_inv, lookup_inv, update_inv: every operation preserves P;
   - lookup_keys, update_keys: lookups and updates leave both tags unchanged;
   - lookup_correct: lookups add no entries to either row;
   - update_correct: an update on one row adds at most the entry for k' to
     that row and adds nothing to the other row. *)
locale pair_mem = pair_mem_defs +
  assumes get_state:
    "State_Monad.run_state get_k1 m = (k, m')  m' = m"
    "State_Monad.run_state get_k2 m = (k, m')  m' = m"
  assumes move12_correct:
    "P m  State_Monad.run_state (move12 k1) m = (x, m')  mem1.map_of m' m Map.empty"
    "P m  State_Monad.run_state (move12 k1) m = (x, m')  mem2.map_of m' m mem1.map_of m"
  assumes move12_keys:
    "State_Monad.run_state (move12 k1) m = (x, m')  fst (State_Monad.run_state get_k1 m') = k1"
    "State_Monad.run_state (move12 k1) m = (x, m')  fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k1 m)"
  assumes move12_inv:
    "lift_p P (move12 k1)"
  assumes lookup_inv:
    "lift_p P (lookup1 k')" "lift_p P (lookup2 k')"
  assumes update_inv:
    "lift_p P (update1 k' v)" "lift_p P (update2 k' v)"
  assumes lookup_keys:
    "P m  State_Monad.run_state (lookup1 k') m = (v', m') 
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m  State_Monad.run_state (lookup1 k') m = (v', m') 
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
    "P m  State_Monad.run_state (lookup2 k') m = (v', m') 
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m  State_Monad.run_state (lookup2 k') m = (v', m') 
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
  assumes update_keys:
    "P m  State_Monad.run_state (update1 k' v) m = (x, m') 
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m  State_Monad.run_state (update1 k' v) m = (x, m') 
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
    "P m  State_Monad.run_state (update2 k' v) m = (x, m') 
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m  State_Monad.run_state (update2 k' v) m = (x, m') 
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
  assumes
    lookup_correct:
      "P m  mem1.map_of (snd (State_Monad.run_state (lookup1 k') m)) m (mem1.map_of m)"
      "P m  mem2.map_of (snd (State_Monad.run_state (lookup1 k') m)) m (mem2.map_of m)"
      "P m  mem1.map_of (snd (State_Monad.run_state (lookup2 k') m)) m (mem1.map_of m)"
      "P m  mem2.map_of (snd (State_Monad.run_state (lookup2 k') m)) m (mem2.map_of m)"
  assumes
    update_correct:
      "P m  mem1.map_of (snd (State_Monad.run_state (update1 k' v) m)) m (mem1.map_of m)(k'  v)"
      "P m  mem2.map_of (snd (State_Monad.run_state (update2 k' v) m)) m (mem2.map_of m)(k'  v)"
      "P m  mem2.map_of (snd (State_Monad.run_state (update1 k' v) m)) m mem2.map_of m"
      "P m  mem1.map_of (snd (State_Monad.run_state (update2 k' v) m)) m mem1.map_of m"
begin

(* Under inv_pair, the map view of the pair memory is contained in
   map_of1 m ++ map_of2 m. *)
lemma map_of_le_pair:
  "pair.map_of m m map_of1 m ++ map_of2 m"
  if "inv_pair m"
  using that
  unfolding pair.map_of_def map_of1_def map_of2_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  unfolding State_Monad.bind_def
  by (auto 4 4
        simp: mem2.map_of_def mem1.map_of_def Let_def
        dest: get_state split: prod.split_asm if_split_asm
     )

(* Converse of map_of_le_pair. *)
lemma pair_le_map_of:
  "map_of1 m ++ map_of2 m m pair.map_of m"
  if "inv_pair m"
  using that
  unfolding pair.map_of_def map_of1_def map_of2_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  unfolding State_Monad.bind_def
  by (auto
        simp: mem2.map_of_def mem1.map_of_def State_Monad.run_state_return Let_def
        dest: get_state split: prod.splits if_split_asm option.split
     )

(* Under inv_pair, the pair memory's map view is exactly the two row views
   combined with ++. This is the main rewriting tool used below. *)
lemma map_of_eq_pair:
  "map_of1 m ++ map_of2 m = pair.map_of m"
  if "inv_pair m"
  using that
  unfolding pair.map_of_def map_of1_def map_of2_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  unfolding State_Monad.bind_def
  by (auto 4 4
        simp: mem2.map_of_def mem1.map_of_def State_Monad.run_state_return Let_def
        dest: get_state split: prod.splits option.split
     )

(* Under inv_pair, the two row tags cannot be equal. *)
lemma inv_pair_neq[simp]:
  False if "inv_pair m" "fst (State_Monad.run_state get_k1 m) = fst (State_Monad.run_state get_k2 m)"
  using that unfolding inv_pair_def by auto

(* inv_pair implies the underlying invariant P. *)
lemma inv_pair_P_D:
  "P m" if "inv_pair m"
  using that unfolding inv_pair_def by (auto simp: Let_def)

(* Under inv_pair, the two row views have disjoint domains. This follows
   because the tags differ. *)
lemma inv_pair_domD[intro]:
  "dom (map_of1 m)  dom (map_of2 m) = {}" if "inv_pair m"
  using that unfolding inv_pair_def map_of1_def map_of2_def by (auto split: if_split_asm)

(* move12_correct(1) lifted to the full-key view: row 1 is empty afterwards. *)
lemma move12_correct1:
  "map_of1 heap' m Map.empty" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
  using move12_correct[OF that(2,1)] unfolding map_of1_def by (auto simp: move12_keys map_le_def)

(* move12_correct(2) lifted to the full-key view: the new row 2 is contained
   in the old row 1. *)
lemma move12_correct2:
  "map_of2 heap' m map_of1 heap" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
  using move12_correct(2)[OF that(2,1)] that unfolding map_of1_def map_of2_def
  by (auto simp: move12_keys map_le_def)

(* After move12, the row 1 view has empty domain. *)
lemma dom_empty[simp]:
  "dom (map_of1 heap') = {}" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
  using move12_correct1[OF that] by (auto dest: map_le_implies_dom_le)

(* A lookup in row 1 preserves inv_pair. *)
lemma inv_pair_lookup1:
  "inv_pair m'" if "State_Monad.run_state (lookup1 k) m = (v, m')" "inv_pair m"
  using that lookup_inv[of k] inv_pair_P_D[OF ‹inv_pair m] unfolding inv_pair_def
  by (auto 4 4
        simp: Let_def lookup_keys
        dest: lift_p_P lookup_correct[of _ k, THEN map_le_implies_dom_le]
     )

(* A lookup in row 2 preserves inv_pair. *)
lemma inv_pair_lookup2:
  "inv_pair m'" if "State_Monad.run_state (lookup2 k) m = (v, m')" "inv_pair m"
  using that lookup_inv[of k] inv_pair_P_D[OF ‹inv_pair m] unfolding inv_pair_def
  by (auto 4 4
        simp: Let_def lookup_keys
        dest: lift_p_P lookup_correct[of _ k, THEN map_le_implies_dom_le]
     )

(* Storing key2 k in row 1 preserves inv_pair, provided row 1 is tagged key1 k. *)
lemma inv_pair_update1:
  "inv_pair m'"
  if "State_Monad.run_state (update1 (key2 k) v) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k1 m) = key1 k"
  using that update_inv[of "key2 k" v] inv_pair_P_D[OF ‹inv_pair m] unfolding inv_pair_def
  apply (auto
        simp: Let_def update_keys
        dest: lift_p_P update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]
     )
   apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  done

(* Storing key2 k in row 2 preserves inv_pair, provided row 2 is tagged key1 k. *)
lemma inv_pair_update2:
  "inv_pair m'"
  if "State_Monad.run_state (update2 (key2 k) v) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k2 m) = key1 k"
  using that update_inv[of "key2 k" v] inv_pair_P_D[OF ‹inv_pair m] unfolding inv_pair_def
  apply (auto
        simp: Let_def update_keys
        dest: lift_p_P update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]
     )
   apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  done

(* move12 k preserves inv_pair when k differs from row 1's current tag. The
   new tags are k and the old row 1 tag, so they remain distinct. *)
lemma inv_pair_move12:
  "inv_pair m'"
  if "State_Monad.run_state (move12 k) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k1 m)  k"
  using that move12_inv[of "k"] inv_pair_P_D[OF ‹inv_pair m] unfolding inv_pair_def
  apply (auto
        simp: Let_def move12_keys
        dest: lift_p_P move12_correct[of _ "k", THEN map_le_implies_dom_le]
     )
  apply (blast dest: move12_correct[of _ "k", THEN map_le_implies_dom_le])
  done

(* Main result: if a full key is determined by its components key1 and key2,
   then lookup_pair and update_pair form a correct memory with invariant
   inv_pair. Goals 1 and 2 are invariant preservation; goal 3 is lookup
   correctness; goal 4 is update correctness. *)
lemma mem_correct_pair:
  "mem_correct lookup_pair update_pair inv_pair"
  if injective: " k k'. key1 k = key1 k'  key2 k = key2 k'  k = k'"
proof (standard, goal_cases)
  case (1 k) ― ‹Lookup invariant›
  show ?case
    unfolding lookup_pair_def Let_def
    by (auto 4 4
        intro!: lift_pI
        dest: get_state inv_pair_lookup1 inv_pair_lookup2
        simp: State_Monad.bind_def State_Monad.run_state_return
        split: if_split_asm prod.split_asm
        )
next
  case (2 k v) ― ‹Update invariant›
  show ?case
    unfolding update_pair_def Let_def
    apply (auto 4 4
        intro!: lift_pI intro: inv_pair_update1 inv_pair_update2
        dest: get_state
        simp: State_Monad.bind_def get_state State_Monad.run_state_return
        split: if_split_asm prod.split_asm
        )+
    apply (elim inv_pair_update1 inv_pair_move12)
      apply (((subst get_state, assumption)+)?, auto intro: move12_keys dest: get_state; fail)+
    done
next
  case (3 m k)
  (* Lookup correctness: a lookup adds no entries. The first block handles a
     lookup in row 2 and the second a lookup in row 1. When neither tag
     matches, the state is unchanged (map_le_refl). *)
  {
    let ?m = "snd (State_Monad.run_state (lookup2 (key2 k)) m)"
    have "map_of1 ?m m map_of1 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of1_def surjective_pairing)
    moreover have "map_of2 ?m m map_of2 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of2_def surjective_pairing)
    moreover have "dom (map_of1 ?m)  dom (map_of2 m) = {}"
      using 3 ‹map_of1 ?m m map_of1 m inv_pair_domD map_le_implies_dom_le by fastforce
    moreover have "inv_pair ?m"
      using 3 inv_pair_lookup2 surjective_pairing by metis
    ultimately have "pair.map_of ?m m pair.map_of m"
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric])
      by (auto intro: 3 map_add_mono)
  }
  moreover
  {
    let ?m = "snd (State_Monad.run_state (lookup1 (key2 k)) m)"
    have "map_of1 ?m m map_of1 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of1_def surjective_pairing)
    moreover have "map_of2 ?m m map_of2 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of2_def surjective_pairing)
    moreover have "dom (map_of1 ?m)  dom (map_of2 m) = {}"
      using 3 ‹map_of1 ?m m map_of1 m inv_pair_domD map_le_implies_dom_le by fastforce
    moreover have "inv_pair ?m"
      using 3 inv_pair_lookup1 surjective_pairing by metis
    ultimately have "pair.map_of ?m m pair.map_of m"
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric])
      by (auto intro: 3 map_add_mono)
  }
  ultimately show ?case
    by (auto
        split:if_split prod.split
        simp: Let_def lookup_pair_def State_Monad.bind_def State_Monad.run_state_return dest: get_state intro: map_le_refl
        )
next
  case prems: (4 m k v)
  (* Update correctness: the new map view is below the old one updated at k.
     Sub-goal 1: key1 k is row 1's tag. Sub-goal 2: it is row 2's tag.
     Sub-goal 3: it tags neither row, so move12 runs before update1. *)
  let ?m1 = "snd (State_Monad.run_state (update1 (key2 k) v) m)"
  let ?m2 = "snd (State_Monad.run_state (update2 (key2 k) v) m)"
  from prems have disjoint: "dom (map_of1 m)  dom (map_of2 m) = {}"
    by (simp add: inv_pair_domD)
  show ?case
    apply (auto
        intro: map_le_refl dest: get_state
        split: prod.split
        simp: Let_def update_pair_def State_Monad.bind_def State_Monad.run_state_return
        )
  proof goal_cases
    case (1 m')
    (* Update in row 1. *)
    then have "m' = m"
      by (rule get_state)
    from 1 prems have "map_of1 ?m1 m map_of1 m(k  v)"
      by (smt inv_pair_P_D map_le_def map_of1_def surjective_pairing domIff
          fst_conv fun_upd_apply injective update_correct update_keys
          )
    moreover from prems have "map_of2 ?m1 m map_of2 m"
      by (smt domIff inv_pair_P_D update_correct update_keys map_le_def map_of2_def surjective_pairing)
    moreover from prems have "dom (map_of1 ?m1)  dom (map_of2 m) = {}"
      by (smt inv_pair_P_D[OF ‹inv_pair m] domIff Int_emptyI eq_snd_iff inv_pair_neq 
          map_of1_def map_of2_def update_keys(1)
          )
    moreover from 1 prems have "k  dom (map_of2 m)"
      using inv_pair_neq map_of2_def by fastforce
    moreover from 1 prems have "inv_pair ?m1"
      using inv_pair_update1 fst_conv surjective_pairing by metis
    ultimately show "pair.map_of (snd (State_Monad.run_state (update1 (key2 k) v) m')) m pair.map_of m(k  v)"
      unfolding m' = m using disjoint
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric], rule prems)
       apply (subst map_add_upd2[symmetric])
      by (auto intro: map_add_mono)
  next
    case (2 k1 m' m'')
    (* Update in row 2. *)
    then have "m' = m" "m'' = m"
      by (auto dest: get_state)
    from 2 prems have "map_of2 ?m2 m map_of2 m(k  v)"
      unfolding m' = m m'' = m
      by (smt inv_pair_P_D map_le_def map_of2_def surjective_pairing domIff
          fst_conv fun_upd_apply injective update_correct update_keys
          )
    moreover from prems have "map_of1 ?m2 m map_of1 m"
      by (smt domIff inv_pair_P_D update_correct update_keys map_le_def map_of1_def surjective_pairing)
    moreover from 2 have "dom (map_of1 ?m2)  dom (map_of2 m(k  v)) = {}"
      unfolding m' = m
      by (smt domIff ‹map_of1 ?m2 m map_of1 m disjoint_iff_not_equal fst_conv fun_upd_apply
          map_le_def map_of1_def map_of2_def
          )
    moreover from 2 prems have "inv_pair ?m2"
      unfolding m' = m
      using inv_pair_update2 fst_conv surjective_pairing by metis
    ultimately show "pair.map_of (snd (State_Monad.run_state (update2 (key2 k) v) m'')) m pair.map_of m(k  v)"
      unfolding m' = m m'' = m
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric], rule prems)
       apply (subst map_add_upd[symmetric])
      by (rule map_add_mono)
  next
    case (3 k1 m1 k2 m2 m3)
    (* Neither tag matches: m3 is the state after move12. The new row 1 holds
       only the new entry, and the new row 2 is contained in the old row 1. *)
    then have "m1 = m" "m2 = m"
      by (auto dest: get_state)
    let ?m3 = "snd (State_Monad.run_state (update1 (key2 k) v) m3)"
    from 3 prems have "map_of1 ?m3 m map_of2 m(k  v)"
      unfolding m2 = m
      by (smt inv_pair_P_D map_le_def map_of1_def surjective_pairing domIff
          fst_conv fun_upd_apply injective
          inv_pair_move12 move12_correct move12_keys update_correct update_keys
          )
    moreover have "map_of2 ?m3 m map_of1 m"
    proof -
      from prems 3 have "P m" "P m3"
        unfolding m1 = m m2 = m
        using inv_pair_P_D[OF prems] by (auto elim: lift_p_P[OF move12_inv])
      from 3(3)[unfolded m2 = m] have "mem2.map_of ?m3 m mem1.map_of m"
        by - (erule map_le_trans[OF update_correct(3)[OF P m3] move12_correct(2)[OF P m]])
      with 3 prems show ?thesis
        unfolding m1 = m m2 = m map_le_def map_of2_def
        apply auto
        apply (frule move12_keys(2), simp)
        by (metis
            domI inv_pair_def map_of1_def surjective_pairing
            inv_pair_move12 move12_keys(2) update_keys(2)
            )
    qed
    moreover from prems 3 have "dom (map_of1 ?m3)  dom (map_of1 m) = {}"
      unfolding m1 = m m2 = m
      by (smt inv_pair_P_D disjoint_iff_not_equal map_of1_def surjective_pairing domIff
          fst_conv inv_pair_move12 move12_keys update_keys
          )
    moreover from 3 have "k  dom (map_of1 m)"
      by (simp add: domIff map_of1_def)
    moreover from 3 prems have "inv_pair ?m3"
      unfolding m2 = m
      by (metis inv_pair_move12 inv_pair_update1 move12_keys(1) fst_conv surjective_pairing)
    ultimately show ?case
      unfolding m1 = m m2 = m using disjoint
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric])
        apply (rule prems)
       apply (subst (2) map_add_comm)
        defer
        apply (subst map_add_upd2[symmetric])
          apply (auto intro: map_add_mono)
      done
  qed
qed

(* If both rows are empty and inv_pair holds, the pair memory is empty. *)
lemma emptyI:
  assumes "inv_pair m" "mem1.map_of m m Map.empty" "mem2.map_of m m Map.empty"
  shows "pair.map_of m m Map.empty"
  using assms by (auto simp: map_of1_def map_of2_def map_le_def map_of_eq_pair[symmetric])

end (* Pair Mem *)


(* A concrete two-row storage. Pair_Storage k1 k2 m1 m2 holds the row tags k1
   and k2 together with the two row memories m1 and m2. *)
datatype ('k, 'v) pair_storage = Pair_Storage 'k 'k 'v 'v

(* Concrete instance of the two-row idea built on pair_storage. Both rows are
   memories of the underlying type of mem_correct_empty. The function key
   selects the row for a key. The full key k is passed unchanged to the
   underlying lookup/update of the selected row. *)
context mem_correct_empty
begin

context
  fixes key :: "'a  'k"
begin

text ‹We assume that look-ups happen on the older row, so it is biased towards the second entry.›
(* Row 2 is checked first, then row 1. If key k tags neither row, the result
   is None and the storage is unchanged. *)
definition
  "lookup_pair k =
    State (λ mem.
    (
      case mem of Pair_Storage k1 k2 m1 m2  let k' = key k in
        if k' = k2 then case State_Monad.run_state (lookup k) m2 of (v, m)  (v, Pair_Storage k1 k2 m1 m)
        else if k' = k1 then case State_Monad.run_state (lookup k) m1 of (v, m)  (v, Pair_Storage k1 k2 m m2)
        else (None, mem)
    )
    )
  "

text ‹We assume that updates happen on the newer row, so it is biased towards the first entry.›
(* Row 1 is checked first, then row 2. If key k tags neither row, the value is
   stored in a fresh memory built from "empty", which becomes row 1 tagged key k.
   The old row 1 becomes row 2, and the old row 2 is dropped. *)
definition
  "update_pair k v =
    State (λ mem.
    (
      case mem of Pair_Storage k1 k2 m1 m2  let k' = key k in
        if k' = k1 then case State_Monad.run_state (update k v) m1 of (_, m)  ((), Pair_Storage k1 k2 m m2)
        else if k' = k2 then case State_Monad.run_state (update k v) m2 of (_, m)  ((),Pair_Storage k1 k2 m1 m)
        else case State_Monad.run_state (update k v) empty of (_, m)  ((), Pair_Storage k' k1 m m1)
    )
    )
  "

interpretation pair: state_mem_defs lookup_pair update_pair .

(* Storage invariant:
   - every key stored in m1 has key value k1, and every key stored in m2 has
     key value k2;
   - the tags differ;
   - P holds for both row memories. *)
definition
  "inv_pair p = (case p of Pair_Storage k1 k2 m1 m2 
    key ` dom (map_of m1)  {k1}  key ` dom (map_of m2)  {k2}  k1  k2  P m1  P m2
  )"

(* Under inv_pair, the storage's map view is contained in the two row views
   combined with ++. *)
lemma map_of_le_pair:
  "pair.map_of (Pair_Storage k1 k2 m1 m2) m (map_of m1 ++ map_of m2)"
  if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that
  unfolding pair.map_of_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  apply auto
  apply (auto 4 6 split: prod.split_asm if_split_asm option.split simp: Let_def)
  done

(* Converse of map_of_le_pair. *)
lemma pair_le_map_of:
  "map_of m1 ++ map_of m2 m pair.map_of (Pair_Storage k1 k2 m1 m2)"
  if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that
  unfolding pair.map_of_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  by (auto 4 5 split: prod.split_asm if_split_asm option.split simp: Let_def)

(* Under inv_pair, the storage's map view is exactly map_of m1 ++ map_of m2. *)
lemma map_of_eq_pair:
  "map_of m1 ++ map_of m2 = pair.map_of (Pair_Storage k1 k2 m1 m2)"
  if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that
  unfolding pair.map_of_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  by (auto 4 7 split: prod.split_asm if_split_asm option.split simp: Let_def)

(* A storage whose two tags are equal never satisfies inv_pair. *)
lemma inv_pair_neq[simp, dest]:
  False if "inv_pair (Pair_Storage k k x y)"
  using that unfolding inv_pair_def by auto

(* inv_pair implies P for row 1. *)
lemma inv_pair_P_D1:
  "P m1" if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that unfolding inv_pair_def by auto

(* inv_pair implies P for row 2. *)
lemma inv_pair_P_D2:
  "P m2" if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that unfolding inv_pair_def by auto

(* Under inv_pair, the rows have disjoint domains because their tags differ. *)
lemma inv_pair_domD[intro]:
  "dom (map_of m1)  dom (map_of m2) = {}" if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that unfolding inv_pair_def by fastforce

(* The storage is a correct memory with invariant inv_pair. Goals 1 and 2 are
   invariant preservation; goal 3 is lookup correctness; goal 4 is update
   correctness. *)
lemma mem_correct_pair:
  "mem_correct lookup_pair update_pair inv_pair"
proof (standard, goal_cases)
  case (1 k) ― ‹Lookup invariant›
  with lookup_inv[of k] show ?case
    unfolding lookup_pair_def Let_def
    by (auto intro!: lift_pI split: pair_storage.split_asm if_split_asm prod.split_asm)
       (auto dest: lift_p_P simp: inv_pair_def,
         (force dest!: lookup_correct[of _ k] map_le_implies_dom_le)+
       )
next
  case (2 k v) ― ‹Update invariant›
  with update_inv[of k v] update_correct[OF P_empty, of k v] P_empty show ?case
    unfolding update_pair_def Let_def
    by (auto intro!: lift_pI split: pair_storage.split_asm if_split_asm prod.split_asm)
       (auto dest: lift_p_P simp: inv_pair_def,
         (force dest: lift_p_P dest!: update_correct[of _ k v] map_le_implies_dom_le)+
       )
next
  case (3 m k)
  (* Lookup correctness. A covers a lookup in row 1 (k1 = key k) and B covers a
     lookup in row 2 (k2 = key k). When neither tag matches, the storage is
     unchanged. *)
  {
    fix m1 v1 m1' m2 v2 m2' k1 k2
    assume assms:
      "State_Monad.run_state (lookup k) m1 = (v1, m1')" "State_Monad.run_state (lookup k) m2 = (v2, m2')"
      "inv_pair (Pair_Storage k1 k2 m1 m2)"
    from assms have "P m1" "P m2"
      by (auto intro: inv_pair_P_D1 inv_pair_P_D2)
    have [intro]: "map_of m1' m map_of m1" "map_of m2' m map_of m2"
      using lookup_correct[OF P m1, of k] lookup_correct[OF P m2, of k] assms by auto
    from inv_pair_domD[OF assms(3)] have 1: "dom (map_of m1')  dom (map_of m2) = {}"
      by (metis (no_types) ‹map_of m1' m map_of m1 disjoint_iff_not_equal domIff map_le_def)
    have inv1: "inv_pair (Pair_Storage (key k) k2 m1' m2)" if "k2  key k" "k1 = key k"
      using that P m1 P m2 unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(1,3) lookup_correct[OF P m1, of k, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by auto
      subgoal for x' y
        using assms(3) unfolding inv_pair_def by fastforce
      using lookup_inv[of k] assms unfolding lift_p_def by force
    have inv2: "inv_pair (Pair_Storage k1 (key k) m1 m2')" if "k2 = key k" "k1  key k"
      using that P m1 P m2 unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(3) unfolding inv_pair_def by fastforce
      subgoal for x x' y
        using assms(2,3) lookup_correct[OF P m2, of k, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by fastforce
      using lookup_inv[of k] assms unfolding lift_p_def by force
    have A:
      "pair.map_of (Pair_Storage (key k) k2 m1' m2) m pair.map_of (Pair_Storage (key k) k2 m1 m2)"
      if "k2  key k" "k1 = key k"
      using inv1 assms(3) 1
      by (auto intro: map_add_mono map_le_refl simp: that map_of_eq_pair[symmetric])
    have B:
      "pair.map_of (Pair_Storage k1 (key k) m1 m2') m pair.map_of (Pair_Storage k1 (key k) m1 m2)"
      if "k2 = key k" "k1  key k"
      using inv2 assms(3) that
      by (auto intro: map_add_mono map_le_refl simp: map_of_eq_pair[symmetric] dest: inv_pair_domD)
    note A B
  }
  with ‹inv_pair m show ?case
    by (auto split: pair_storage.split if_split prod.split simp: Let_def lookup_pair_def)
next
  case (4 m k v)
  (* Update correctness. A covers the case where neither tag matches: m3 is
     built from "empty" and the old row 1 becomes row 2. B covers an update in
     row 2 and C an update in row 1. *)
  {
    fix m1 v1 m1' m2 v2 m2' m3 k1 k2
    assume assms:
      "State_Monad.run_state (update k v) m1 = ((), m1')" "State_Monad.run_state (update k v) m2 = ((), m2')"
      "State_Monad.run_state (update k v) empty = ((), m3)"
      "inv_pair (Pair_Storage k1 k2 m1 m2)"
    from assms have "P m1" "P m2"
      by (auto intro: inv_pair_P_D1 inv_pair_P_D2)
    from assms(3) P_empty update_inv[of k v] have "P m3"
      unfolding lift_p_def by auto
    have [intro]: "map_of m1' m map_of m1(k  v)" "map_of m2' m map_of m2(k  v)"
      using update_correct[OF P m1, of k v] update_correct[OF P m2, of k v] assms by auto
    have "map_of m3 m map_of empty(k  v)"
      using assms(3) update_correct[OF P_empty, of k v] by auto
    also have " m map_of m2(k  v)"
      using empty_correct by (auto elim: map_le_trans intro!: map_le_upd)
    finally have "map_of m3 m map_of m2(k  v)" .
    have 1: "dom (map_of m1)  dom (map_of m2(k  v)) = {}" if "k1  key k"
      using assms(4) that by (force simp: inv_pair_def)
    have 2: "dom (map_of m3)  dom (map_of m1) = {}" if "k1  key k"
      using ‹local.map_of m3 m local.map_of empty(k  v) assms(4) that
      by (fastforce dest!: map_le_implies_dom_le simp: inv_pair_def)
    have inv: "inv_pair (Pair_Storage (key k) k1 m3 m1)" if "k2  key k" "k1  key k"
      using that P m1 P m2 P m3 unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x x' y
        using assms(3) update_correct[OF P_empty, of k v, THEN map_le_implies_dom_le]
          empty_correct
        by (auto dest: map_le_implies_dom_le)
      subgoal for x x' y
        using assms(4) unfolding inv_pair_def by fastforce
      done
    have A:
      "pair.map_of (Pair_Storage (key k) k1 m3 m1) m pair.map_of (Pair_Storage k1 k2 m1 m2)(k  v)"
      if "k2  key k" "k1  key k"
      using inv assms(4) ‹map_of m3 m map_of m2(k  v) 1
      apply (simp add: that map_of_eq_pair[symmetric])
      apply (subst map_add_upd[symmetric], subst Map.map_add_comm, rule 2, rule that)
      by (rule map_add_mono; auto)
    have inv1: "inv_pair (Pair_Storage (key k) k2 m1' m2)" if "k2  key k" "k1 = key k"
      using that P m1 P m2 unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(1,4) update_correct[OF P m1, of k v, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by auto
      subgoal for x' y
        using assms(4) unfolding inv_pair_def by fastforce
      using update_inv[of k v] assms unfolding lift_p_def by force
    have inv2: "inv_pair (Pair_Storage k1 (key k) m1 m2')" if "k2 = key k" "k1  key k"
      using that P m1 P m2 unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(4) unfolding inv_pair_def by fastforce
      subgoal for x x' y
        using assms(2,4) update_correct[OF P m2, of k v, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by fastforce
      using update_inv[of k v] assms unfolding lift_p_def by force
    have C:
      "pair.map_of (Pair_Storage (key k) k2 m1' m2) m
       pair.map_of (Pair_Storage (key k) k2 m1 m2)(k  v)"
      if "k2  key k" "k1 = key k"
      using inv1[OF that] assms(4) ‹inv_pair m
      by (simp add: that map_of_eq_pair[symmetric])
         (subst map_add_upd2[symmetric]; force simp: inv_pair_def intro: map_add_mono map_le_refl)
    have B:
      "pair.map_of (Pair_Storage k1 (key k) m1 m2') m
       pair.map_of (Pair_Storage k1 (key k) m1 m2)(k  v)"
      if "k2 = key k" "k1  key k"
      using inv2[OF that] assms(4)
      by (simp add: that map_of_eq_pair[symmetric])
         (subst map_add_upd[symmetric]; rule map_add_mono; force simp: inv_pair_def)
    note A B C
  }
  with ‹inv_pair m show ?case
    by (auto split: pair_storage.split if_split prod.split simp: Let_def update_pair_def)
qed

end (* Key function *)

end (* Lookup & Update w/ Empty *)

end (* Theory *)

Theory Index

subsection ‹Index›

theory Index
  imports Main
begin

(* to_index is injective on all arguments whose index lies below size. Two
   arguments may share an index only if that index is >= size. *)
definition injective :: "nat  ('k  nat)  bool" where
  "injective size to_index   a b.
      to_index a = to_index b
     to_index a < size
     to_index b < size
     a = b"
  for size to_index

(* Row-major flattening stays in range: for a < a0 and b < b0, the flattened
   index a * b0 + b is below a0 * b0. The proof goes through (Suc a) * b0. *)
lemma index_mono:
  fixes a b a0 b0 :: nat
  assumes a: "a < a0" and b: "b < b0"
  shows "a * b0 + b < a0 * b0"
proof -
  have "a * b0 + b < (Suc a) * b0"
    using b by auto
  also have "  a0 * b0"
    using a[THEN Suc_leI, THEN mult_le_mono1] .
  finally show ?thesis .
qed

(* Row-major flattening is injective. If b < b0 and d < b0, then
   a * b0 + b = c * b0 + d forces a = c and b = d. *)
lemma index_eq_iff:
  fixes a b c d b0 :: nat
  assumes "b < b0" "d < b0" "a * b0 + b = c * b0 + d"
  shows "a = c  b = d"
proof (rule conjI)
  (* Auxiliary fact ac: if a < c and b < b0, then a * b0 + b is strictly
     below c * b0 + d, so the two flattened indices differ. *)
  { fix a b c d :: nat
    assume ac: "a < c" and b: "b < b0"
    have "a * b0 + b < (Suc a) * b0"
      using b by auto
    also have "  c * b0"
      using ac[THEN Suc_leI, THEN mult_le_mono1] .
    also have "  c * b0 + d"
      by auto
    finally have "a * b0 + b  c * b0 + d"
      by auto
  } note ac = this

  (* a <> c contradicts the assumed equality in either direction. *)
  { assume "a  c"
    then consider (le) "a < c" | (ge) "a > c"
      by fastforce
    hence False proof cases
      case le show ?thesis using ac[OF le assms(1)] assms(3) ..
    next
      case ge show ?thesis using ac[OF ge assms(2)] assms(3)[symmetric] ..
    qed
  }
  
  then show "a = c"
    by auto

  (* With a = c, the equality reduces to b = d. *)
  with assms(3) show "b = d"
    by auto
qed


(* Order on pairs built from two component orders. less relates (a, b) to
   (c, d) when less0 a c and less1 b d both hold; the order proof in
   prod_order relies on this being a conjunction. less_eq is "less or equal".
   NOTE(review): this is not the usual componentwise <=. Pairs that agree in
   one component but differ in the other are incomparable. As a result,
   in_bound on a pair bound with lower corner (l0, r0) admits only the corner
   itself and points strictly above it in both components; for example,
   (l0, r0 + 1) is out of bound. Confirm this is intended. *)
locale prod_order_def =
  order0: ord less_eq0 less0 +
  order1: ord less_eq1 less1
  for less_eq0 less0 less_eq1 less1
begin

fun less :: "'a × 'b  'a × 'b  bool" where
  "less (a,b) (c,d)  less0 a c  less1 b d"

fun less_eq :: "'a × 'b  'a × 'b  bool" where
  "less_eq ab cd  less ab cd  ab = cd"

end

(* If both component relations are orders, the pair relation of
   prod_order_def is an order as well. *)
locale prod_order =
  prod_order_def less_eq0 less0 less_eq1 less1 +
  order0: order less_eq0 less0 +
  order1: order less_eq1 less1
  for less_eq0 less0 less_eq1 less1
begin

sublocale order less_eq less
proof qed fastforce+

end

(* Lifts an order on 'a to 'a option:
   - None is below everything;
   - Some a is never below None;
   - Some a and Some b are compared with less_eq0.
   less_option is less_eq_option restricted to distinct arguments. *)
locale option_order =
  order0: order less_eq0 less0
  for less_eq0 less0
begin

fun less_eq_option :: "'a option  'a option  bool" where
  "less_eq_option None _  True"
| "less_eq_option (Some _) None  False"
| "less_eq_option (Some a) (Some b)  less_eq0 a b"

fun less_option :: "'a option  'a option  bool" where
  "less_option ao bo  less_eq_option ao bo  ao  bo"

(* Each order axiom is proved by case analysis on the option arguments. *)
sublocale order less_eq_option less_option
  apply standard
  subgoal for x y by (cases x; cases y) auto
  subgoal for x by (cases x) auto
  subgoal for x y z by (cases x; cases y; cases z) auto
  subgoal for x y by (cases x; cases y) auto
  done

end

(* A range given by its lower and upper ends. in_bound interprets it as
   lower inclusive, upper exclusive. *)
datatype 'a bound = Bound (lower: 'a) (upper:'a)

(* x lies within Bound l r when less_eq l x and less x r hold. This is the
   half-open range [l, r) with respect to the given relations. *)
definition in_bound :: "('a  'a  bool)  ('a  'a  bool)  'a bound  'a  bool" where
  "in_bound less_eq less bound x  case bound of Bound l r  less_eq l x  less x r" for less_eq less

(* Signature of an index structure:
   - an ordering (less_eq, less);
   - idx, which maps an element within a bound to a natural-number index;
   - size, the number of indices of a bound. *)
locale index_locale_def = ord less_eq less for less_eq less :: "'a  'a  bool" +
  fixes idx :: "'a bound  'a  nat"
    and size :: "'a bound  nat"

(* Axioms of an index structure:
   - the ordering is an order (idx_ord);
   - idx_valid: in-bound elements get indices below size bound;
   - idx_inj: idx is injective on in-bound elements. *)
locale index_locale = index_locale_def + idx_ord: order +
  assumes idx_valid: "in_bound less_eq less bound x  idx bound x < size bound"
    and idx_inj : "in_bound less_eq less bound x; in_bound less_eq less bound y; idx bound x = idx bound y  x = y"

(* Row-major index for pairs. A pair bound Bound (l0, r0) (l1, r1) has lower
   corner (l0, r0) and upper corner (l1, r1). First components range over
   Bound l0 l1 and second components over Bound r0 r1. The size is the product
   of the component sizes. *)
locale prod_index_def =
  index0: index_locale_def less_eq0 less0 idx0 size0 +
  index1: index_locale_def less_eq1 less1 idx1 size1
  for less_eq0 less0 idx0 size0 less_eq1 less1 idx1 size1
begin

fun idx :: "('a × 'b) bound  'a × 'b  nat" where
  "idx (Bound (l0, r0) (l1, r1)) (a, b) = (idx0 (Bound l0 l1) a) * (size1 (Bound r0 r1)) + idx1 (Bound r0 r1) b"

fun size :: "('a × 'b) bound  nat" where
  "size (Bound (l0, r0) (l1, r1)) = size0 (Bound l0 l1) * size1 (Bound r0 r1)"

end

(* The product of two index structures is an index structure. idx_valid
   follows from index_mono and idx_inj from index_eq_iff. *)
locale prod_index = prod_index_def less_eq0 less0 idx0 size0 less_eq1 less1 idx1 size1 +
  index0: index_locale less_eq0 less0 idx0 size0 +
  index1: index_locale less_eq1 less1 idx1 size1
  for less_eq0 less0 idx0 size0 less_eq1 less1 idx1 size1
begin

sublocale prod_order less_eq0 less0 less_eq1 less1 ..

sublocale index_locale less_eq less idx size proof
  (* idx_valid: split the pair bound into component bounds and apply index_mono. *)
  { fix ab :: "'a × 'b" and bound :: "('a × 'b) bound"
    assume bound: "in_bound less_eq less bound ab"

    obtain a b l0 r0 l1 r1 where defs:"ab = (a, b)" "bound = Bound (l0, r0) (l1, r1)"
      by (cases ab; cases bound) auto

    with bound have a: "in_bound less_eq0 less0 (Bound l0 l1) a" and b: "in_bound less_eq1 less1 (Bound r0 r1) b"
      unfolding in_bound_def by auto

    have "idx (Bound (l0, r0) (l1, r1)) (a, b) < size (Bound (l0, r0) (l1, r1))"
      using index_mono[OF index0.idx_valid[OF a] index1.idx_valid[OF b]] by auto

    thus "idx bound ab < size bound"
      unfolding defs .
  }

  (* idx_inj: equal flattened indices give equal component indices
     (index_eq_iff), and the component idx_inj facts finish the proof.
     Note that the bound components are named differently here than in the
     first part: lower = (l0, l1) and upper = (r0, r1). *)
  { fix ab cd :: "'a × 'b" and bound :: "('a × 'b) bound"
    assume bound: "in_bound less_eq less bound ab" "in_bound less_eq less bound cd"
      and idx_eq: "idx bound ab = idx bound cd"

    obtain a b c d l0 r0 l1 r1 where
      defs: "ab = (a, b)" "cd = (c, d)" "bound = Bound (l0, l1) (r0, r1)"
      by (cases ab; cases cd; cases bound) auto

    from defs bound have
          a: "in_bound less_eq0 less0 (Bound l0 r0) a"
      and b: "in_bound less_eq1 less1 (Bound l1 r1) b"
      and c: "in_bound less_eq0 less0 (Bound l0 r0) c"
      and d: "in_bound less_eq1 less1 (Bound l1 r1) d"
      unfolding in_bound_def by auto

    from index_eq_iff[OF index1.idx_valid[OF b] index1.idx_valid[OF d] idx_eq[unfolded defs, simplified]]
    have ac: "idx0 (Bound l0 r0) a = idx0 (Bound l0 r0) c" and bd: "idx1 (Bound l1 r1) b = idx1 (Bound l1 r1) d" by auto
    show "ab = cd"
      unfolding defs using index0.idx_inj[OF a c ac] index1.idx_inj[OF b d bd] by auto
  }
qed
end

(* Partial lifting of idx to option. It is specified only when the bound and
   the element are both Some; every other case is undefined. *)
locale option_index =
  index0: index_locale less_eq0 less0 idx0 size0
  for less_eq0 less0 idx0 size0
begin

fun idx :: "'a option bound  'a option  nat" where
  "idx (Bound (Some l) (Some r)) (Some a) = idx0 (Bound l r) a"
| "idx _ _ = undefined"
(* option is NOT an index *)
(* No size is defined and no index_locale interpretation is given here. *)

end

(* Index structure on nat. For Bound l r, the index of i is i - l and the size
   is r - l, so the range [l, r) maps onto 0 .. r - l - 1. *)
locale nat_index_def = ord "(≤) :: nat  nat  bool" "(<)"
begin

fun idx :: "nat bound  nat  nat" where
  "idx (Bound l _) i = i - l"

fun size :: "nat bound  nat" where
  "size (Bound l r) = r - l"

sublocale index_locale "(≤)" "(<)" idx size
proof qed (auto simp: in_bound_def split: bound.splits) 

end

(* nat_index_def together with the order on nat. *)
locale nat_index = nat_index_def + order "(≤) :: nat  nat  bool" "(<)"

(* Index structure on int. For Bound l r, the index of i is nat (i - l) and
   the size is nat (r - l). nat truncates negative differences to 0. *)
locale int_index_def = ord "(≤) :: int  int  bool" "(<)"
begin

fun idx :: "int bound  int  nat" where
  "idx (Bound l _) i = nat (i - l)"

fun size :: "int bound  nat" where
  "size (Bound l r) = nat (r - l)"

sublocale index_locale "(≤)" "(<)" idx size
proof qed (auto simp: in_bound_def split: bound.splits) 

end

(* int_index_def together with the order on int. *)
locale int_index = int_index_def + order "(≤) :: int  int  bool" "(<)"

(* Type class of index types.  An instance provides an order (less_eq, less),
   an index function idx and a size function on bounds, and these must
   together form an index_locale. *)
class index =
  fixes less_eq less :: "'a  'a  bool"
    and idx :: "'a bound  'a  nat"
    and size :: "'a bound  nat"
  assumes is_locale: "index_locale less_eq less idx size"

(* Indexing relative to one fixed bound of an index type.
   size is the class size of that bound.  checked_idx maps a key inside the
   bound to idx bound x and any key outside the bound to size.
   checked_idx_injective shows injective size checked_idx, using idx_inj from
   the index_locale interpretation. *)
locale bounded_index =
  fixes bound :: "'k :: index bound"
begin

interpretation index_locale less_eq less idx size
  using is_locale .

definition "size  index_class.size bound" for size

definition "checked_idx x  if in_bound less_eq less bound x then idx bound x else size"

lemma checked_idx_injective:
  "injective size checked_idx"
  unfolding injective_def
  unfolding checked_idx_def
  using idx_inj by (fastforce split: if_splits)
end

(* nat is an instance of the index class.  The class operations are defined
   as the operations of the nat_index interpretation and declared [simp], so
   the class assumption is discharged by unfolding them and then using the
   interpretation's index_locale_axioms. *)
instantiation nat :: index
begin

interpretation nat_index ..

definition [simp]: "less_eq_nat  (≤) :: nat  nat  bool"
definition [simp]: "less_nat  (<) :: nat  nat  bool"
definition [simp]: "idx_nat  idx"
definition size_nat where [simp]: "size_nat  size"

instance by (standard, simp, fact index_locale_axioms)

end

(* int is an instance of the index class.  The class operations are defined
   as the operations of the int_index interpretation and declared [simp], so
   the class assumption follows from index_locale_axioms.  size_int exposes the
   defining equation of the interpretation's size function. *)
instantiation int :: index
begin

interpretation int_index ..

definition [simp]: "less_eq_int  (≤) :: int  int  bool"
definition [simp]: "less_int  (<) :: int  int  bool"
definition [simp]: "idx_int  idx"
definition [simp]: "size_int  size"

lemmas size_int = size.simps

instance by (standard, simp, fact index_locale_axioms)
end

(* Pairs of index types form an index type.  prod_index is interpreted with
   the class operations of both component types, the class operations on
   pairs are defined as the interpretation's operations and declared [simp],
   and the class assumption follows from index_locale_axioms.  size_prod
   exposes the defining equation of the pair size function. *)
instantiation prod :: (index, index) index
begin

interpretation prod_index
  "less_eq::'a  'a  bool" less idx size
  "less_eq::'b  'b  bool" less idx size
  by (rule prod_index.intro; fact is_locale)

definition [simp]: "less_eq_prod  less_eq"
definition [simp]: "less_prod  less"
definition [simp]: "idx_prod  idx"
definition [simp]: "size_prod  size" for size_prod

lemmas size_prod = size.simps

instance by (standard, simp, fact index_locale_axioms)

end

(* Code equation: the size of a bound over pairs is the product of the
   component sizes nat (u1 - l1) and nat (u2 - l2).  The proof unfolds
   bounded_index.size_def and the pair and int size equations
   (size_prod, size_int). *)
lemma bound_int_simp[code]:
  "bounded_index.size (Bound (l1, l2) (u1, u2)) = nat (u1 - l1) * nat (u2 - l2)"
  by (simp add: bounded_index.size_def,unfold size_int_def[symmetric] size_prod,simp add: size_int)

(* Register the defining equations of bounded_index and of the nat, int, pair
   index and pair order locales as code equations, so that code can be
   generated for the index operations. *)
lemmas [code] = bounded_index.size_def bounded_index.checked_idx_def

lemmas [code] =
  nat_index_def.size.simps
  nat_index_def.idx.simps

lemmas [code] =
  int_index_def.size.simps
  int_index_def.idx.simps

lemmas [code] =
  prod_index_def.size.simps
  prod_index_def.idx.simps

lemmas [code] =
  prod_order_def.less_eq.simps
  prod_order_def.less.simps

(* All size equations collected under one name, presumably for unfolding
   concrete sizes in later proofs. *)
lemmas index_size_defs =
  prod_index_def.size.simps int_index_def.size.simps nat_index_def.size.simps bounded_index.size_def

end

Theory Memory_Heap

subsection ‹Heap Memory Implementations›

theory Memory_Heap
  imports State_Heap DP_CRelVH Pair_Memory "HOL-Eisbach.Eisbach" "../Index"
begin

text ‹TODO: move the following auxiliary definitions and lemmas to a more general theory.›
(* Shorthands for the result value and the final heap of running the Heap
   computation c on heap h.  Both apply the to the outcome of execute, so they
   are only meaningful when c succeeds on h. *)
abbreviation "result_of c h  fst (the (execute c h))"
abbreviation "heap_of   c h  snd (the (execute c h))"

(* A map that yields None at every key is below the empty map (map_le). *)
lemma map_emptyI:
  "m m Map.empty" if " x. m x = None"
  using that unfolding map_le_def by auto

(* The result of return x is x, on any heap. *)
lemma result_of_return[simp]:
  "result_of (Heap_Monad.return x) h = x"
  by (simp add: execute_simps)

(* Reading reference r yields the value stored for r in the heap. *)
lemma get_result_of_lookup:
  "result_of (!r) heap = x" if "Ref.get heap r = x"
  using that by (auto simp: execute_simps)

context
  fixes size :: nat
    and to_index :: "('k2 :: heap)  nat"
begin

(* mem_empty allocates a fresh array of length size whose entries are all None,
   i.e. an empty memory.  The following lemmas state that:
   - the allocation always succeeds;
   - the new array has length size in the resulting heap;
   - every entry with index below size reads as None. *)
definition
  "mem_empty = (Array.new size (None :: ('v :: heap) option))"

lemma success_empty[intro]:
  "success mem_empty heap"
  unfolding mem_empty_def by (auto intro: success_intros)

lemma length_mem_empty:
  "Array.length
    (heap_of (mem_empty:: (('b :: heap) option array) Heap) h)
    (result_of (mem_empty :: ('b option array) Heap) h) = size"
  unfolding mem_empty_def by (auto simp: execute_simps Array.length_alloc)

lemma nth_mem_empty:
  "result_of
    (Array.nth (result_of (mem_empty :: ('b option array) Heap) h) i)
    (heap_of (mem_empty :: (('b :: heap) option array) Heap) h) = None" if "i < size"
  apply (subst execute_nth(1))
  apply (simp add: length_mem_empty that)
  apply (simp add: execute_simps mem_empty_def Array.get_alloc that)
  done

context
  fixes mem :: "('v :: heap) option array"
begin

(* Memory operations on the fixed array mem.  Key k lives in slot to_index k.
   A key whose slot is not below size is never stored: mem_lookup returns None
   for it and mem_update leaves the heap unchanged. *)
definition
  "mem_lookup k = (let i = to_index k in
    if i < size then Array.nth mem i else return None
  )"

definition
  "mem_update k v = (let i = to_index k in
    if i < size then (Array.upd i (Some v) mem  (λ _. return ()))
    else return ()
  )
  "

context assumes injective: "injective size to_index"
begin

(* Assuming injective size to_index, the array-backed lookup and update form a
   heap_correct memory, with the invariant that mem has length size.
   - The first two subgoals show that lookup and update preserve this
     invariant.
   - The last two subgoals concern heap_mem_defs.map_of_heap.  The final one
     uses injectivity (via nth_list_update_neq): writing one slot leaves the
     slot of any other key unchanged.
   The result is exported as mem_heap_correct. *)
interpretation heap_correct "λheap. Array.length heap mem = size" mem_update mem_lookup
  apply standard
  subgoal lookup_inv
    unfolding State_Heap.lift_p_def mem_lookup_def by (simp add: Let_def execute_simps)
  subgoal update_inv
    unfolding State_Heap.lift_p_def mem_update_def by (simp add: Let_def execute_simps)
  subgoal for k heap
    unfolding heap_mem_defs.map_of_heap_def map_le_def mem_lookup_def
    by (auto simp: execute_simps Let_def split: if_split_asm)
  subgoal for heap k
    unfolding heap_mem_defs.map_of_heap_def map_le_def mem_lookup_def mem_update_def
    apply (auto simp: execute_simps Let_def length_def split: if_split_asm)
    apply (subst (asm) nth_list_update_neq)
    using injective[unfolded injective_def] apply auto
    done
  done

lemmas mem_heap_correct = heap_correct_axioms

context
  assumes [simp]: "mem = result_of mem_empty Heap.empty"
begin

(* When mem is the array that mem_empty allocates on the empty heap, the heap
   produced by that allocation is a valid initial heap:
   - its map_of_heap is empty, because all slots are None (nth_mem_empty);
   - mem has length size (length_mem_empty).
   Exported as array_heap_emptyI. *)
interpretation heap_correct_empty
  "λheap. Array.length heap mem = size" mem_update mem_lookup
  "heap_of (mem_empty :: 'v option array Heap) Heap.empty"
  apply standard
  subgoal
    apply (rule map_emptyI)
    unfolding map_of_heap_def mem_lookup_def by (auto simp: Let_def nth_mem_empty)
  subgoal
    by (simp add: length_mem_empty)
  done

lemmas array_heap_emptyI = heap_correct_empty_axioms

context
  fixes dp :: "'k2  'v"
begin

(* For any function dp, the empty array memory satisfies
   dp_consistency_heap_empty.  Exported as array_consistentI. *)
interpretation dp_consistency_heap_empty
  "λheap. Array.length heap mem = size" mem_update mem_lookup dp
  "heap_of (mem_empty :: 'v option array Heap) Heap.empty"
  by standard

lemmas array_consistentI = dp_consistency_heap_empty_axioms

end

end (* Empty Memory *)

end (* Injectivity *)

end (* Fixed array *)

(* Elimination rule: if f succeeds on h and the bind of f and g executes to
   Some (y, h''), then f yields some (x, h') and g x executes from h' to
   Some (y, h''). *)
lemma execute_bind_success':
  assumes "success f h" "execute (f  g) h = Some (y, h'')"
  obtains x h' where "execute f h = Some (x, h')" "execute (g x) h' = Some (y, h'')"
  using assms by (auto simp: execute_simps elim: successE)

(* Introduction rule: a bind succeeds if f succeeds and g succeeds on every
   result and heap that f produces. *)
lemma success_bind_I:
  assumes "success f h"
    and " x h'. execute f h = Some (x, h')  success (g x) h'"
  shows "success (f  g) h"
  by (rule successE[OF assms(1)]) (auto elim: assms(2) intro: success_bind_executeI)

(* Allocate two fresh references initialised with a and b and return them. *)
definition
  "alloc_pair a b  do {
    r1  ref a;
    r2  ref b;
    return (r1, r2)
  }"

(* Properties of a run of alloc_pair a b from heap to heap':
   - alloc_pair_alloc: the new references r1 and r2 hold a and b;
   - alloc_pairD1: a reference present in heap is distinct from r1 and r2
     and is still present in heap';
   - alloc_pairD2: r1 and r2 are distinct and both present in heap';
   - alloc_pairD3: arrays present in heap are still present in heap';
   - alloc_pairD4: references present in heap keep their value. *)
lemma alloc_pair_alloc:
  "Ref.get heap' r1 = a" "Ref.get heap' r2 = b"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis Ref.get_alloc fst_conv get_alloc_neq next_present present_alloc_neq snd_conv)+

lemma alloc_pairD1:
  "r =!= r1  r =!= r2  Ref.present heap' r"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')" "Ref.present heap r"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis next_fresh noteq_I Ref.present_alloc snd_conv)+

lemma alloc_pairD2:
  "r1 =!= r2  Ref.present heap' r2  Ref.present heap' r1"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis next_fresh next_present noteq_I Ref.present_alloc snd_conv)+

lemma alloc_pairD3:
  "Array.present heap' r"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')" "Array.present heap r"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis array_present_alloc snd_conv)

lemma alloc_pairD4:
  "Ref.get heap' r = x"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')"
     "Ref.get heap r = x" "Ref.present heap r"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis Ref.not_present_alloc Ref.present_alloc get_alloc_neq noteq_I snd_conv)

(* Allocating a pair of references does not change the contents of any
   array. *)
lemma alloc_pair_array_get:
  "Array.get heap' r = x"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')" "Array.get heap r = x"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis array_get_alloc snd_conv)

(* alloc_pair leaves array lengths unchanged.  Together with
   alloc_pair_array_get, this means that reading an element with Array.nth
   gives the same result before and after alloc_pair (alloc_pair_nth). *)
lemma alloc_pair_array_length:
  "Array.length heap' r = Array.length heap r"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')"
  using that unfolding alloc_pair_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF success_refI])
     (metis Ref.length_alloc snd_conv)

lemma alloc_pair_nth:
  "result_of (Array.nth r i) heap' = result_of (Array.nth r i) heap"
  if "execute (alloc_pair a b) heap = Some ((r1, r2), heap')"
  using alloc_pair_array_get[OF that(1) HOL.refl, of r] alloc_pair_array_length[OF that(1), of r]
  by (cases "(λh. i < Array.length h r) heap"; simp add: execute_simps Array.nth_def)

(* alloc_pair always succeeds.  The name succes_alloc_pair is misspelt but is
   referenced by later proofs, so it is kept. *)
lemma succes_alloc_pair[intro]:
  "success (alloc_pair a b) heap"
  unfolding alloc_pair_def by (auto intro: success_intros success_bind_I)

(* Allocate key references k_ref1 and k_ref2 holding k1 and k2, then memory
   references m_ref1 and m_ref2 holding m1 and m2, and return all four. *)
definition
  "init_state_inner k1 k2 m1 m2   do {
    (k_ref1, k_ref2)  alloc_pair k1 k2;
    (m_ref1, m_ref2)  alloc_pair m1 m2;
    return (k_ref1, k_ref2, m_ref1, m_ref2)
  }
  "

(* Properties of a run of init_state_inner k1 k2 m1 m2 from heap to heap':
   - init_state_inner_alloc: the four references hold k1, k2, m1 and m2;
   - init_state_inner_distinct: the four references are pairwise distinct;
   - init_state_inner_present: all four references are present. *)
lemma init_state_inner_alloc:
  assumes
    "execute (init_state_inner k1 k2 m1 m2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "Ref.get heap' k_ref1 = k1" "Ref.get heap' k_ref2 = k2"
    "Ref.get heap' m_ref1 = m1" "Ref.get heap' m_ref2 = m2"
  using assms unfolding init_state_inner_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF succes_alloc_pair])
     (auto intro: alloc_pair_alloc dest: alloc_pairD2 elim: alloc_pairD4)

lemma init_state_inner_distinct:
  assumes
    "execute (init_state_inner k1 k2 m1 m2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "m_ref1 =!= m_ref2  m_ref1 =!= k_ref1  m_ref1 =!= k_ref2  m_ref2 =!= k_ref1
    m_ref2 =!= k_ref2  k_ref1 =!= k_ref2"
  using assms unfolding init_state_inner_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF succes_alloc_pair])
     (blast dest: alloc_pairD1 alloc_pairD2 intro: noteq_sym)+

lemma init_state_inner_present:
  assumes
    "execute (init_state_inner k1 k2 m1 m2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "Ref.present heap' k_ref1" "Ref.present heap' k_ref2"
    "Ref.present heap' m_ref1" "Ref.present heap' m_ref2"
  using assms unfolding init_state_inner_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF succes_alloc_pair])
     (blast dest: alloc_pairD1 alloc_pairD2)+

(* Arrays present before init_state_inner are still present afterwards.
   The lemma name contains a typo (inite) but is used by later proofs, so it
   is kept.  The correctly spelt alias init_state_inner_present' is provided
   for new uses. *)
lemma inite_state_inner_present':
  assumes
    "execute (init_state_inner k1 k2 m1 m2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
    "Array.present heap a"
  shows
    "Array.present heap' a"
  using assms unfolding init_state_inner_def
  by (auto simp: execute_simps elim!: execute_bind_success'[OF succes_alloc_pair] alloc_pairD3)

lemmas init_state_inner_present' = inite_state_inner_present'

(* init_state_inner always succeeds, and it does not change the result of
   reading an array element with Array.nth.  In init_state_inner_nth the
   returned tuple is matched as a pair (r1, r2), where r2 stands for the
   nested remainder of the four references. *)
lemma succes_init_state_inner[intro]:
  "success (init_state_inner k1 k2 m1 m2) heap"
  unfolding init_state_inner_def by (auto 4 3 intro: success_intros success_bind_I)

lemma init_state_inner_nth:
  "result_of (Array.nth r i) heap' = result_of (Array.nth r i) heap"
  if "execute (init_state_inner k1 k2 m1 m2) heap = Some ((r1, r2), heap')"
  using that unfolding init_state_inner_def
  by (auto simp: execute_simps alloc_pair_nth elim!: execute_bind_success'[OF succes_alloc_pair])

(* init_state k1 k2 allocates two empty memories with mem_empty and stores the
   keys k1, k2 and the two memories in four fresh references
   (init_state_inner).  It always succeeds. *)
definition
  "init_state k1 k2  do {
    m1  mem_empty;
    m2  mem_empty;
    init_state_inner k1 k2 m1 m2
  }"

lemma succes_init_state[intro]:
  "success (init_state k1 k2) heap"
  unfolding init_state_def by (auto intro: success_intros success_bind_I)

(* The four references of the pair memory are pairwise distinct. *)
definition
  "inv_distinct k_ref1 k_ref2 m_ref1 m_ref2 
     m_ref1 =!= m_ref2  m_ref1 =!= k_ref1  m_ref1 =!= k_ref2  m_ref2 =!= k_ref1
    m_ref2 =!= k_ref2  k_ref1 =!= k_ref2
  "

(* init_state returns pairwise distinct references (inv_distinct), all of
   which are present in the resulting heap. *)
lemma init_state_distinct:
  assumes
    "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "inv_distinct k_ref1 k_ref2 m_ref1 m_ref2"
  using assms unfolding init_state_def inv_distinct_def
  by (elim execute_bind_success'[OF success_empty] init_state_inner_distinct)

lemma init_state_present:
  assumes
    "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "Ref.present heap' k_ref1" "Ref.present heap' k_ref2"
    "Ref.present heap' m_ref1" "Ref.present heap' m_ref2"
  using assms unfolding init_state_def
  by (auto
        simp: execute_simps elim!: execute_bind_success'[OF success_empty]
        dest: init_state_inner_present
     )

(* The array allocated by mem_empty is present afterwards (empty_present), and
   arrays that were present before remain present (empty_present'). *)
lemma empty_present:
  "Array.present h' x" if "execute mem_empty heap = Some (x, h')"
  using that unfolding mem_empty_def
  by (auto simp: execute_simps) (metis Array.present_alloc fst_conv snd_conv)

lemma empty_present':
  "Array.present h' a" if "execute mem_empty heap = Some (x, h')" "Array.present heap a"
  using that unfolding mem_empty_def
  by (auto simp: execute_simps Array.present_def Array.alloc_def Array.set_def Let_def)

(* After init_state, the two memory arrays stored in m_ref1 and m_ref2 are
   present (init_state_present2) and are different arrays (init_state_neq,
   stated with Array.noteq). *)
lemma init_state_present2:
  assumes
    "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "Array.present heap' (Ref.get heap' m_ref1)" "Array.present heap' (Ref.get heap' m_ref2)"
  using assms unfolding init_state_def
  by (auto 4 3
        simp: execute_simps init_state_inner_alloc elim!: execute_bind_success'[OF success_empty]
        dest: inite_state_inner_present' empty_present empty_present'
     )

lemma init_state_neq:
  assumes
    "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "Ref.get heap' m_ref1 =!!= Ref.get heap' m_ref2"
  using assms unfolding init_state_def
  by (auto 4 3
        simp: execute_simps init_state_inner_alloc elim!: execute_bind_success'[OF success_empty]
        dest: inite_state_inner_present' empty_present empty_present'
     )
    (metis empty_present execute_new fst_conv mem_empty_def option.inject present_alloc_noteq)

(* Allocating a new array does not change the contents of an array that was
   already present. *)
lemma present_alloc_get:
  "Array.get heap' a = Array.get heap a"
  if "Array.alloc xs heap = (a', heap')" "Array.present heap a"
  using that by (auto simp: Array.alloc_def Array.present_def Array.get_def Let_def Array.set_def)

(* After init_state, both memory arrays have length size. *)
lemma init_state_length:
  assumes
    "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows
    "Array.length heap' (Ref.get heap' m_ref1) = size"
    "Array.length heap' (Ref.get heap' m_ref2) = size"
  using assms unfolding init_state_def
  apply (auto
        simp: execute_simps init_state_inner_alloc elim!: execute_bind_success'[OF success_empty]
        dest: inite_state_inner_present' empty_present empty_present'
     )
   apply (auto
      simp: execute_simps init_state_inner_def alloc_pair_def mem_empty_def Array.length_def
      elim!: execute_bind_success'[OF success_refI]
     )
  apply (metis
      Array.alloc_def Array.get_set_eq Array.present_alloc array_get_alloc fst_conv length_replicate
      present_alloc_get snd_conv
     )+
  done

context
  fixes key1 :: "'k  ('k1 :: heap)" and key2 :: "'k  'k2"
    and m_ref1 m_ref2 :: "('v :: heap) option array ref"
    and k_ref1 k_ref2 :: "('k1 :: heap) ref"
begin

text ‹We assume that look-ups happen on the older row, so this is biased towards the second entry.›
definition
  "lookup_pair k = do {
    let k' = key1 k;
    k2  !k_ref2;
    if k' = k2 then
      do {
        m2  !m_ref2;
        mem_lookup m2 (key2 k)
      }
    else
      do {
      k1  !k_ref1;
      if k' = k1 then
        do {
          m1  !m_ref1;
          mem_lookup m1 (key2 k)
        }
      else
        return None
    }
  }
   "

text ‹We assume that updates happen on the newer row, so this is biased towards the first entry.›
definition
  "update_pair k v = do {
    let k' = key1 k;
      k1  !k_ref1;
      if k' = k1 then do {
        m  !m_ref1;
        mem_update m (key2 k) v
      }
      else do {
        k2  !k_ref2;
        if k' = k2 then do {
          m  !m_ref2;
          mem_update m (key2 k) v
        }
        else do {
          do {
            k1  !k_ref1;
            m  mem_empty;
            m1  !m_ref1;
            k_ref2 := k1;
            k_ref1 := k';
            m_ref2 := m1;
            m_ref1 := m
          }
        ;
        m  !m_ref1;
        mem_update m (key2 k) v
      }
    }
   }
   "

(* Heap invariant of the pair memory:
   - both memory arrays have length size;
   - all four references and both arrays are present;
   - the two arrays are distinct.
   inv_pair additionally requires the four references to be pairwise
   distinct (inv_distinct). *)
definition
  "inv_pair_weak heap = (
    let
      m1 = Ref.get heap m_ref1;
      m2 = Ref.get heap m_ref2
    in Array.length heap m1 = size  Array.length heap m2 = size
       Ref.present heap k_ref1  Ref.present heap k_ref2
       Ref.present heap m_ref1  Ref.present heap m_ref2
       Array.present heap m1  Array.present heap m2
       m1 =!!= m2
  )"

(* TODO: Remove? *)
definition
  "inv_pair heap  inv_pair_weak heap  inv_distinct k_ref1 k_ref2 m_ref1 m_ref2"

(* A heap produced by init_state satisfies inv_pair_weak. *)
lemma init_state_inv:
  assumes
    "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
  shows "inv_pair_weak heap'"
  using assms unfolding inv_pair_weak_def Let_def
  by (auto intro:
      init_state_present init_state_present2 init_state_neq init_state_length
      init_state_distinct
     )

(* Destruction rules that extract the individual conjuncts of
   inv_pair_weak. *)
lemma inv_pair_lengthD1:
  "Array.length heap (Ref.get heap m_ref1) = size" if "inv_pair_weak heap"
  using that unfolding inv_pair_weak_def by (auto simp: Let_def)

lemma inv_pair_lengthD2:
  "Array.length heap (Ref.get heap m_ref2) = size" if "inv_pair_weak heap"
  using that unfolding inv_pair_weak_def by (auto simp: Let_def)

lemma inv_pair_presentD:
  "Array.present heap (Ref.get heap m_ref1)" "Array.present heap (Ref.get heap m_ref2)"
  if "inv_pair_weak heap"
  using that unfolding inv_pair_weak_def by (auto simp: Let_def)

lemma inv_pair_presentD2:
  "Ref.present heap m_ref1" "Ref.present heap m_ref2"
  "Ref.present heap k_ref1" "Ref.present heap k_ref2"
  if "inv_pair_weak heap"
  using that unfolding inv_pair_weak_def by (auto simp: Let_def)

lemma inv_pair_not_eqD:
  "Ref.get heap m_ref1 =!!= Ref.get heap m_ref2" if "inv_pair_weak heap"
  using that unfolding inv_pair_weak_def by (auto simp: Let_def)

(* State-monad wrappers (via state_of) around the elementary heap operations
   on the two rows:
   - lookup1, lookup2: look up a key in the row-1 or row-2 array;
   - update1, update2: update a key in the row-1 or row-2 array;
   - move12 k: shift row 1 into row 2 and install a fresh empty array with row
     key k as row 1;
   - get_k1, get_k2: read the two row keys. *)
definition "lookup1 k  state_of (do {m  !m_ref1; mem_lookup m k})"

definition "lookup2 k  state_of (do {m  !m_ref2; mem_lookup m k})"

definition "update1 k v  state_of (do {m  !m_ref1; mem_update m k v})"

definition "update2 k v  state_of (do {m  !m_ref2; mem_update m k v})"

definition "move12 k  state_of (do {
    k1  !k_ref1;
    m  mem_empty;
    m1  !m_ref1;
    k_ref2 := k1;
    k_ref1 := k;
    m_ref2 := m1;
    m_ref1 := m
  })
  "

definition "get_k1  state_of (!k_ref1)"

definition "get_k2  state_of (!k_ref2)"

(* Running state_of p on m means executing p on m and taking the outcome with
   the, which is unspecified if p fails. *)
lemma run_state_state_of[simp]:
  "State_Monad.run_state (state_of p) m = the (execute p m)"
  unfolding state_of_def by simp

context assumes injective: "injective size to_index"
begin

context
  assumes inv_distinct: "inv_distinct k_ref1 k_ref2 m_ref1 m_ref2"
begin

(* Under inv_distinct, the pairwise distinctness of the four references is
   available as simp rules in both orientations. *)
lemma disjoint[simp]:
  "m_ref1 =!= m_ref2" "m_ref1 =!= k_ref1" "m_ref1 =!= k_ref2"
  "m_ref2 =!= k_ref1" "m_ref2 =!= k_ref2"
  "k_ref1 =!= k_ref2"
  using inv_distinct unfolding inv_distinct_def by auto

lemmas [simp] = disjoint[THEN noteq_sym]

(* Frame properties of array allocation and array update:
   - allocation keeps the contents of present arrays and the values of present
     references;
   - allocation keeps present arrays and references present;
   - Array.update does not change the length of any array's contents. *)
lemma [simp]:
  "Array.get (snd (Array.alloc xs heap)) a = Array.get heap a" if "Array.present heap a"
  using that unfolding Array.alloc_def Array.present_def
  apply (simp add: Let_def)
  apply (subst Array.get_set_neq)
  subgoal
    by (simp add: Array.noteq_def)
  subgoal
    unfolding Array.get_def by simp
  done

lemma [simp]:
  "Ref.get (snd (Array.alloc xs heap)) r = Ref.get heap r" if "Ref.present heap r"
  using that unfolding Array.alloc_def Ref.present_def
  by (simp add: Let_def Ref.get_def Array.set_def)

lemma alloc_present:
  "Array.present (snd (Array.alloc xs heap)) a" if "Array.present heap a"
  using that unfolding Array.present_def Array.alloc_def by (simp add: Let_def Array.set_def)

lemma alloc_present':
  "Ref.present (snd (Array.alloc xs heap)) r" if "Ref.present heap r"
  using that unfolding Ref.present_def Array.alloc_def by (simp add: Let_def Array.set_def)

lemma length_get_upd[simp]:
  "length (Array.get (Array.update a i x heap) r) = length (Array.get heap r)"
  unfolding Array.get_def Array.update_def Array.set_def by simp

(* Proof method: optionally derive the length and array-distinctness facts
   from an inv_pair_weak premise, then run auto with case splits on if in the
   assumptions and symmetry of Array.noteq. *)
method solve1 =
  (frule inv_pair_lengthD1, frule inv_pair_lengthD2, frule inv_pair_not_eqD)?,
  auto split: if_split_asm dest: Array.noteq_sym

(* The two array rows, move12 and the row-key readers form a pair_mem instance
   with invariant inv_pair_weak.  The two subgoals about updates use
   injectivity of to_index via nth_list_update_neq.  The resulting correctness
   theorem is exported as mem_correct_pair. *)
interpretation pair: pair_mem lookup1 lookup2 update1 update2 move12 get_k1 get_k2 inv_pair_weak
  supply [simp] =
    mem_empty_def state_mem_defs.map_of_def map_le_def
    move12_def update1_def update2_def lookup1_def lookup2_def get_k1_def get_k2_def
    mem_update_def mem_lookup_def
    execute_bind_success[OF success_newI] execute_simps Let_def Array.get_alloc length_def
    inv_pair_presentD inv_pair_presentD2
    Memory_Heap.lookup1_def Memory_Heap.lookup2_def Memory_Heap.mem_lookup_def
  apply standard
                      apply (solve1; fail)+
  subgoal
    apply (rule lift_pI)
    unfolding inv_pair_weak_def
    apply (auto simp:
        intro: alloc_present alloc_present'
        elim: present_alloc_noteq[THEN Array.noteq_sym]
        )
    done
                     apply (rule lift_pI, unfold inv_pair_weak_def, auto split: if_split_asm; fail)+
                 apply (solve1; fail)+
  subgoal
    using injective[unfolded injective_def] by - (solve1, subst (asm) nth_list_update_neq, auto)
  subgoal
    using injective[unfolded injective_def] by - (solve1, subst (asm) nth_list_update_neq, auto)
   apply (solve1; fail)+
  done

lemmas mem_correct_pair = pair.mem_correct_pair

(* Heap-monad counterparts of lookup1, lookup2, get_k1, get_k2, update1,
   update2 and move12, without the state_of wrapper.  They are related to the
   state-monad versions by the transfer rules below. *)
definition
  "mem_lookup1 k = do {m  !m_ref1; mem_lookup m k}"

definition
  "mem_lookup2 k = do {m  !m_ref2; mem_lookup m k}"

definition "get_k1'  !k_ref1"

definition "get_k2'  !k_ref2"

definition "update1' k v  do {m  !m_ref1; mem_update m k v}"

definition "update2' k v  do {m  !m_ref2; mem_update m k v}"

definition "move12' k  do {
    k1  !k_ref1;
    m  mem_empty;
    m1  !m_ref1;
    k_ref2 := k1;
    k_ref1 := k;
    m_ref2 := m1;
    m_ref1 := m
  }"

(* Set up heap_mem_defs (including rel_state) for the heap-monad lookup_pair
   and update_pair with invariant inv_pair_weak. *)
interpretation heap_mem_defs inv_pair_weak lookup_pair update_pair .

(* Criteria for relating state_of m to the heap computation m by
   rel_state (=): m must succeed on every heap satisfying inv_pair_weak and
   must preserve that invariant.  Preservation can be stated for the Heap
   computation (rel_state_ofI) or for the state computation state_of m
   (rel_state_ofI2, via lift_p_success).  inv_pair_iff shows that
   inv_pair_weak and inv_pair coincide under inv_distinct. *)
lemma rel_state_ofI:
  "rel_state (=) (state_of m) m" if
  " heap. inv_pair_weak heap  success m heap"
  "lift_p inv_pair_weak m"
  using that unfolding rel_state_def
  by (auto split: option.split intro: lift_p_P'' simp: success_def)

lemma inv_pair_iff:
  "inv_pair_weak = inv_pair"
  unfolding inv_pair_def using inv_distinct by simp

lemma lift_p_inv_pairI:
  "State_Heap.lift_p inv_pair m" if "State_Heap.lift_p inv_pair_weak m"
  using that unfolding inv_pair_iff by simp

lemma lift_p_success:
  "State_Heap.lift_p inv_pair_weak m"
  if "DP_CRelVS.lift_p inv_pair_weak (state_of m)" " heap. inv_pair_weak heap  success m heap"
  using that
  unfolding lift_p_def DP_CRelVS.lift_p_def
  by (auto simp: success_def split: option.split)

lemma rel_state_ofI2:
  "rel_state (=) (state_of m) m" if
  " heap. inv_pair_weak heap  success m heap"
  "DP_CRelVS.lift_p inv_pair_weak (state_of m)"
  using that by (blast intro: rel_state_ofI lift_p_success)

context
  includes lifting_syntax
begin

(* Transfer rules relating the state-monad operations used by the pair
   interpretation to their heap-monad counterparts:
   - move12 to move12';
   - lookup1, lookup2 to mem_lookup1, mem_lookup2;
   - get_k1, get_k2 to get_k1', get_k2';
   - update1, update2 to update1', update2'.
   Each rule is proved with rel_state_ofI2, in two steps:
   - show that the heap operation succeeds on heaps satisfying inv_pair_weak;
   - show invariant preservation, using the corresponding fact of the pair
     interpretation. *)
lemma [transfer_rule]:
  "((=) ===> rel_state (=)) move12 move12'"
  unfolding move12_def move12'_def
  apply (intro rel_funI)
  apply simp
  apply (rule rel_state_ofI2)
  subgoal
    by (auto
        simp: mem_empty_def inv_pair_lengthD1 execute_simps Let_def
        intro: success_intros intro!: success_bind_I
       )
  subgoal
    using pair.move12_inv unfolding move12_def .
  done

lemma [transfer_rule]:
  "((=) ===> rel_state (rel_option (=))) lookup1 mem_lookup1"
  unfolding lookup1_def mem_lookup1_def
  apply (intro rel_funI)
  apply (simp add: option.rel_eq)
  apply (rule rel_state_ofI2)
  subgoal
    by (auto 4 4
        simp: mem_lookup_def inv_pair_lengthD1 execute_simps Let_def
        intro: success_bind_executeI success_returnI Array.success_nthI
       )
  subgoal
    using pair.lookup_inv(1) unfolding lookup1_def .
  done

lemma [transfer_rule]:
  "((=) ===> rel_state (rel_option (=))) lookup2 mem_lookup2"
  unfolding lookup2_def mem_lookup2_def
  apply (intro rel_funI)
  apply (simp add: option.rel_eq)
  apply (rule rel_state_ofI2)
  subgoal
    by (auto 4 3
        simp: mem_lookup_def inv_pair_lengthD2 execute_simps Let_def
        intro: success_intros intro!: success_bind_I
       )
  subgoal
    using pair.lookup_inv(2) unfolding lookup2_def .
  done

lemma [transfer_rule]:
  "rel_state (=) get_k1 get_k1'"
  unfolding get_k1_def get_k1'_def
  apply (rule rel_state_ofI2)
  subgoal
    by (auto intro: success_lookupI)
  subgoal
    unfolding get_k1_def[symmetric] by (auto dest: pair.get_state(1) intro: lift_pI)
  done

lemma [transfer_rule]:
  "rel_state (=) get_k2 get_k2'"
  unfolding get_k2_def get_k2'_def
  apply (rule rel_state_ofI2)
  subgoal
    by (auto intro: success_lookupI)
  subgoal
    unfolding get_k2_def[symmetric] by (auto dest: pair.get_state(2) intro: lift_pI)
  done

lemma [transfer_rule]:
  "((=) ===> (=) ===> rel_state (=)) update1 update1'"
  unfolding update1_def update1'_def
  apply (intro rel_funI)
  apply simp
  apply (rule rel_state_ofI2)
  subgoal
    by (auto 4 3
        simp: mem_update_def inv_pair_lengthD1 execute_simps Let_def
        intro: success_intros intro!: success_bind_I
       )
  subgoal
    using pair.update_inv(1) unfolding update1_def .
  done

lemma [transfer_rule]:
  "((=) ===> (=) ===> rel_state (=)) update2 update2'"
  unfolding update2_def update2'_def
  apply (intro rel_funI)
  apply simp
  apply (rule rel_state_ofI2)
  subgoal
    by (auto 4 3
        simp: mem_update_def inv_pair_lengthD2 execute_simps Let_def
        intro: success_intros intro!: success_bind_I
       )
  subgoal
    using pair.update_inv(2) unfolding update2_def .
  done

(* The transfer rule relating lookup1 and mem_lookup1 is established earlier
   in this context.  A redundant restatement of the same rule, with an
   alternative proof, used to appear here and has been removed. *)

(* The generic lookup_pair and update_pair of the pair interpretation, built
   from the state-monad operations, are related to the heap-monad lookup_pair
   and update_pair defined above.  Both sides are first folded into the named
   primitives (mem_lookup1/2, update1'/2', move12', get_k1/2 and get_k1'/2'),
   so that transfer_prover can apply the transfer rules proved above. *)
lemma rel_state_lookup:
  "((=) ===> rel_state (=)) pair.lookup_pair lookup_pair"
  unfolding pair.lookup_pair_def lookup_pair_def
  unfolding
    mem_lookup1_def[symmetric] mem_lookup2_def[symmetric]
    get_k2_def[symmetric] get_k2'_def[symmetric]
    get_k1_def[symmetric] get_k1'_def[symmetric]
  by transfer_prover

lemma rel_state_update:
  "((=) ===> (=) ===> rel_state (=)) pair.update_pair update_pair"
  unfolding pair.update_pair_def update_pair_def
  unfolding move12'_def[symmetric]
  unfolding
    update1'_def[symmetric] update2'_def[symmetric]
    get_k2_def[symmetric] get_k2'_def[symmetric]
    get_k1_def[symmetric] get_k1'_def[symmetric]
  by transfer_prover

(* mem sets up heap_mem_defs for the heap-monad lookup_pair and update_pair,
   with the invariant pair.inv_pair.  pair.inv_pair implies inv_pair_weak
   (inv_pairD).
   mem_rel_state_ofI and mem_rel_state_ofI' lift rel_state (=) to
   mem.rel_state (=), provided the state computation m' preserves
   pair.inv_pair.  The primed version states this as
   DP_CRelVS.lift_p pair.inv_pair m'. *)
interpretation mem: heap_mem_defs pair.inv_pair lookup_pair update_pair .

lemma inv_pairD:
  "inv_pair_weak heap" if "pair.inv_pair heap"
  using that unfolding pair.inv_pair_def by (auto simp: Let_def)

lemma mem_rel_state_ofI:
  "mem.rel_state (=) m' m" if
  "rel_state (=) m' m"
  " heap. pair.inv_pair heap 
    (case State_Monad.run_state m' heap of (_, heap)  inv_pair_weak heap  pair.inv_pair heap)"
proof -
  show ?thesis
    apply (rule mem.rel_state_intro)
    subgoal for heap v heap'
      by (auto elim: rel_state_elim[OF that(1)] dest!: inv_pairD)
    subgoal premises prems for heap v heap'
    proof -
      from prems that(1) have "inv_pair_weak heap'"
        by (fastforce elim: rel_state_elim dest: inv_pairD)
      with prems show ?thesis
        by (auto dest: that(2))
    qed
    done
qed

lemma mem_rel_state_ofI':
  "mem.rel_state (=) m' m" if
  "rel_state (=) m' m"
  "DP_CRelVS.lift_p pair.inv_pair m'"
  using that by (auto elim: DP_CRelVS.lift_p_P intro: mem_rel_state_ofI)

context
  assumes keys: "k k'. key1 k = key1 k'  key2 k = key2 k'  k = k'"
begin

(* Under the assumption keys (a key is determined by key1 k and key2 k):
   - the generic pair memory is mem_correct (mem_correct_pair);
   - its lookup and update are related to the heap implementation under
     mem.rel_state (rel_state_lookup', rel_state_update');
   - hence the heap implementation with invariant pair.inv_pair is
     heap_correct, exported as heap_correct_pairI. *)
interpretation mem_correct pair.lookup_pair pair.update_pair pair.inv_pair
  by (rule mem_correct_pair[OF keys])

lemma rel_state_lookup':
  "((=) ===> mem.rel_state (=)) pair.lookup_pair lookup_pair"
  apply (intro rel_funI)
  apply simp
  apply (rule mem_rel_state_ofI')
  using rel_state_lookup apply (rule rel_funD) apply (rule refl)
  apply (rule lookup_inv)
  done

lemma rel_state_update':
  "((=) ===> (=) ===> mem.rel_state (=)) pair.update_pair update_pair"
  apply (intro rel_funI)
  apply simp
  apply (rule mem_rel_state_ofI')
  subgoal for x y a b
    using rel_state_update by (blast dest: rel_funD)
  by (rule update_inv)

interpretation heap_correct pair.inv_pair update_pair lookup_pair
  by (rule mem.mem_correct_heap_correct[OF _ rel_state_lookup' rel_state_update']) standard

lemmas heap_correct_pairI = heap_correct_axioms 

(* TODO: Generalize *)
(* If m' and m are related by mem.rel_state (=), then on a heap satisfying
   pair.inv_pair the result of m equals the result of m'.  Consequently, the
   map represented by the heap implementation coincides with the map of the
   generic pair memory (map_of_heap_eq). *)
lemma mem_rel_state_resultD:
  "result_of m heap = fst (run_state m' heap)" if "mem.rel_state (=) m' m" "pair.inv_pair heap"
  by (metis (mono_tags, lifting) mem.rel_state_elim option.sel that)

lemma map_of_heap_eq:
  "mem.map_of_heap heap = pair.pair.map_of heap" if "pair.inv_pair heap"
  unfolding mem.map_of_heap_def pair.pair.map_of_def
  using that by (simp add: mem_rel_state_resultD[OF rel_state_lookup'[THEN rel_funD]])

context
  fixes k1 k2 heap heap'
  assumes init: "execute (init_state k1 k2) heap = Some ((k_ref1, k_ref2, m_ref1, m_ref2), heap')"
begin

(* Properties of the heap heap' produced by init_state k1 k2 (assumption
   init):
   - the row maps pair.mem1.map_of and pair.mem2.map_of are None at every key;
   - the row key references read back as k1 and k2. *)
lemma init_state_empty1:
  "pair.mem1.map_of heap' k = None"
  using init
  unfolding pair.mem1.map_of_def lookup1_def mem_lookup_def init_state_def
  by (auto
        simp: init_state_inner_nth init_state_inner_alloc(3) execute_simps Let_def
        elim!: execute_bind_success'[OF success_empty])
     (metis
        Array.present_alloc Memory_Heap.length_mem_empty execute_new execute_nth(1) fst_conv
        length_def mem_empty_def nth_mem_empty option.sel present_alloc_get snd_conv
     )

lemma init_state_empty2:
  "pair.mem2.map_of heap' k = None"
  using init
  unfolding pair.mem2.map_of_def lookup2_def mem_lookup_def init_state_def
  by (auto
        simp: execute_simps init_state_inner_nth init_state_inner_alloc(4) Let_def
        elim!: execute_bind_success'[OF success_empty]
     )
     (metis fst_conv nth_mem_empty option.sel snd_conv)

lemma
  shows init_state_k1: "result_of (!k_ref1) heap' = k1"
    and init_state_k2: "result_of (!k_ref2) heap' = k2"
  using init init_state_inner_alloc
  by (auto simp: execute_simps init_state_def elim!: execute_bind_success'[OF success_empty])

context
  assumes neq: "k1  k2"
begin

(* If k1 and k2 differ (assumption neq), the initial heap heap' satisfies
   pair.inv_pair, and the map stored in it is empty. *)
lemma init_state_inv':
  "pair.inv_pair heap'"
  unfolding pair.inv_pair_def
  apply (auto simp: Let_def)
  subgoal
    using init_state_empty1 by simp
  subgoal
    using init_state_empty2 by simp
  subgoal
    using neq init by (simp add: get_k1_def get_k2_def init_state_k1 init_state_k2)
  subgoal
    by (rule init_state_inv[OF init])
  done

lemma init_state_empty:
  "pair.pair.map_of heap' m Map.empty"
  using neq by (intro pair.emptyI init_state_inv' map_emptyI init_state_empty1 init_state_empty2)

(* heap' is a valid initial heap for the pair memory: heap_correct_empty holds,
   using heap_correct_pairI, the emptiness of its map (via map_of_heap_eq) and
   init_state_inv'.  Exported as heap_correct_empty_pairI. *)
interpretation heap_correct_empty pair.inv_pair update_pair lookup_pair heap'
  apply (rule heap_correct_empty.intro)
   apply (rule heap_correct_pairI)
  apply standard
  subgoal
    by (subst map_of_heap_eq; intro init_state_inv' init_state_empty)
  subgoal
    by (rule init_state_inv')
  done

lemmas heap_correct_empty_pairI = heap_correct_empty_axioms

context
  fixes dp :: "'k  'v"
begin

(* For any function dp, the pair memory starting from heap' satisfies
   dp_consistency_heap_empty.  Exported as consistent_empty_pairI. *)
interpretation dp_consistency_heap_empty
  pair.inv_pair update_pair lookup_pair dp heap'
  by standard

lemmas consistent_empty_pairI = dp_consistency_heap_empty_axioms

end (* DP *)

end (* Unequal Keys *)

end (* Init State *)

end (* Keys injective *)

end (* Lifting Syntax *)

end (* Disjoint *)

end (* Injectivity *)

end (* Refs *)

end (* Key functions & Size *)

end (* Theory *)

Theory Transform_Cmd

subsection ‹Tool Setup›

theory Transform_Cmd
  imports
    "../Pure_Monad"
    "../state_monad/DP_CRelVS"
    "../heap_monad/DP_CRelVH"
  keywords
    "memoize_fun" :: thy_decl
    and "monadifies" :: thy_decl
    and "memoize_correct" :: thy_goal
    and "with_memory" :: quasi_command
    and "default_proof" :: quasi_command
begin

(* ML implementation of the memoization tool; later files use structures
   defined by earlier ones, so the order matters. *)
ML_file ‹../transform/Transform_Misc.ML›
ML_file ‹../transform/Transform_Const.ML›
ML_file ‹../transform/Transform_Data.ML›
ML_file ‹../transform/Transform_Tactic.ML›
ML_file ‹../transform/Transform_Term.ML›
ML_file ‹../transform/Transform.ML›
ML_file ‹../transform/Transform_Parser.ML›

(* Register the local-theory commands memoize_fun and monadifies. *)
ML ‹
val _ =
  Outer_Syntax.local_theory @{command_keyword memoize_fun} "whatever"
    (Transform_Parser.dp_fun_part1_parser >> Transform_DP.dp_fun_part1_cmd)

val _ =
  Outer_Syntax.local_theory @{command_keyword monadifies} "whatever"
    (Transform_Parser.dp_fun_part2_parser >> Transform_DP.dp_fun_part2_cmd)
›

(* Register memoize_correct, which opens a proof goal. *)
ML ‹
val _ =
  Outer_Syntax.local_theory_to_proof @{command_keyword memoize_correct} "whatever"
    (Scan.succeed Transform_DP.dp_correct_cmd)
›

(* Proof methods operating on the most recent memoization command's info. *)
method_setup memoize_prover = ‹
Scan.succeed (fn ctxt => (SIMPLE_METHOD' (
  Transform_Data.get_last_cmd_info ctxt
  |> Transform_Tactic.solve_consistentDP_tac ctxt)))
›

method_setup memoize_prover_init = ‹
Scan.succeed (fn ctxt => (SIMPLE_METHOD' (
  Transform_Data.get_last_cmd_info ctxt
  |> Transform_Tactic.prepare_consistentDP_tac ctxt)))
›

method_setup memoize_prover_case_init = ‹
Scan.succeed (fn ctxt => (SIMPLE_METHOD' (
  Transform_Data.get_last_cmd_info ctxt
  |> Transform_Tactic.prepare_case_tac ctxt)))
›

method_setup memoize_prover_match_step = ‹
Scan.succeed (fn ctxt => (SIMPLE_METHOD' (
  Transform_Data.get_last_cmd_info ctxt
  |> Transform_Tactic.step_tac ctxt)))
›

(* The next two methods optionally take "(monad_name) term" to select a
   registered command; without it, the last command's info is used. *)
method_setup memoize_unfold_defs  = ‹
Scan.option (Scan.lift (Args.parens Args.name) -- Args.term) >> (fn tm_opt => fn ctxt => (SIMPLE_METHOD' (
  Transform_Data.get_or_last_cmd_info ctxt tm_opt
  |> Transform_Tactic.dp_unfold_defs_tac ctxt)))
›

method_setup memoize_combinator_init  = ‹
Scan.option (Scan.lift (Args.parens Args.name) -- Args.term) >> (fn tm_opt => fn ctxt => (SIMPLE_METHOD' (
  Transform_Data.get_or_last_cmd_info ctxt tm_opt
  |> Transform_Tactic.prepare_combinator_tac ctxt)))
›
end (* theory *)

File ‹Transform_Misc.ML›

(* Miscellaneous helpers of the memoization tool: access to function-package
   data, small term manipulations and list utilities. *)
structure Transform_Misc = struct
  (* Function-package info of the function whose head is the given term, or of
     the most recently defined function when term_opt is NONE. *)
  fun import_function_info term_opt ctxt =
    case term_opt of
      SOME tm => (case Function_Common.import_function_data tm ctxt of
        SOME info => info
      | NONE => raise TERM("not a function", [tm]))
    | NONE => (case Function_Common.import_last_function ctxt of
        SOME info => info
      | NONE => error "no function defined yet")

  (* Read tm_pat as a constant and return it with a dummy type, so that it can
     serve as a pattern whose type is fixed later by type inference.
     Raises TERM (instead of an anonymous Bind) if tm_pat is not a constant. *)
  fun get_const_pat ctxt tm_pat =
    (case Proof_Context.read_const {proper=false, strict=false} ctxt tm_pat of
      Const (name, _) => Const (name, dummyT)
    | tm => raise TERM("not a constant: " ^ quote tm_pat, [tm]))

  (* The single function head / binding of a (non-mutual) function definition. *)
  fun head_of (func_info: Function.info) = #fs func_info |> the_single
  fun bind_of (func_info: Function.info) = #fnames func_info |> the_single

  (* Totality theorem of a function whose termination has been proved. *)
  fun totality_of (func_info: Function.info) =
    func_info |> #totality |> the;

  (* The defining equation of the function's recursion relation R. *)
  fun rel_of (func_info: Function.info) ctxt =
    Inductive.the_inductive ctxt (#R func_info) |> snd |> #eqs |> the_single;

  (* The common value of a non-empty list whose elements are all equal; used to
     check that all equations agree on their number of arguments.  The empty
     list is reported explicitly instead of failing inside hd/tl. *)
  fun the_element [] = error "inconsistent n_args: no elements"
    | the_element (l as x :: xs) =
        if forall (equal x) xs
          then x
          else (@{print} l; error "inconsistent n_args")

  (* Define a function named bind from the equations defs, with the fun
     configuration and pattern completeness proved by pat_completeness + auto. *)
  fun add_function bind defs =
    let
      val fixes = [(bind, NONE, NoSyn)];
      val specs = map (fn def => (((Binding.empty, []), def), [], [])) defs
      val pat_completeness_auto = fn ctxt =>
        Pat_Completeness.pat_completeness_tac ctxt 1
        THEN auto_tac ctxt
    in
      Function.add_function fixes specs Function_Fun.fun_config pat_completeness_auto
    end

  (* Split tm into the prefix matching head (head applied to its own number of
     arguments) and the remaining arguments; raises TERM on mismatch. *)
  fun behead head tm =
    let
      val head_nargs = strip_comb head |> snd |> length
      val (tm_head, tm_args) = strip_comb tm
      val (tm_args0, tm_args1) = chop head_nargs tm_args
      val tm_head' = list_comb (tm_head, tm_args0)
      val _ = if Term.aconv_untyped (head, tm_head')
        then () else raise TERM("head does not match", [head, tm_head'])
    in
      (tm_head', tm_args1)
    end

  (* Name of a free variable or constant; other terms need an explicit name. *)
  fun term_name tm =
    if is_Free tm orelse is_Const tm
      then Term.term_name tm
      else raise TERM("not an atom, explicit name required", [tm])

  (* Read a term / fetch theorems by their name qualified with locale_name. *)
  fun locale_term lthy locale_name term_name =
    Syntax.read_term lthy (Long_Name.qualify locale_name term_name)

  fun locale_thms lthy locale_name thms_name =
    Proof_Context.get_thms lthy (Long_Name.qualify locale_name thms_name)

  (* Turn a curried function term into one taking a single tuple of all its
     arguments. *)
  fun uncurry tm =
    let
      val arg_typs = fastype_of tm |> binder_types
      val arg_names = Name.invent_list [] "a" (length arg_typs)
      val args = map Free (arg_names ~~ arg_typs)
      val args_tuple = HOLogic.mk_tuple args
      val tm' = list_comb (tm, args) |> HOLogic.tupled_lambda args_tuple
    in
      tm'
    end
end

File ‹Transform_Const.ML›

(* Constants and term/type builders for the target monads (state and heap) of
   the monadification. *)
structure Transform_Const = struct
  (* Pure_Monad.App, used to mark applications in the pure definition. *)
  val pureappN = @{const_name Pure_Monad.App}
  fun pureapp tm = Const (pureappN, dummyT) $ tm

  (* What the monadification needs to know about a target monad:
       monad_name            key in monad_consts_dict and in Transform_Data
       mk_stateT             monadic type of a computation with the given result type
       return / app          build return-terms and lifted applications
       if_termN              name of the monadic if-then-else constant
       checkmemVN            the name "checkmem" of the memoizing operation
       rewrite_app_beta_conv rewrites app (return (λ..)) (return x) to the
                             plain application, or fails with no_conv *)
  type MONAD_CONSTS = {
    monad_name: string,
    mk_stateT: typ -> typ,
    return: term -> term,
    app: (term * term) -> term,
    if_termN: string,
    checkmemVN: string,
    rewrite_app_beta_conv: conv
  }

  val state_monad: MONAD_CONSTS =
    let
      (* The memory type is left open and resolved by type inference. *)
      val memT = dummyT

      fun mk_stateT tp =
        Type (@{type_name State_Monad.state}, [memT, tp])

      val returnN = @{const_name State_Monad.return}
      fun return tm = Const (returnN, dummyT --> mk_stateT dummyT) $ tm

      val appN = @{const_name State_Monad_Ext.fun_app_lifted}
      fun app (tm0, tm1) = Const (appN, dummyT) $ tm0 $ tm1

      val checkmemVN = "checkmem"

      (* Rewrite  fun_app_lifted (return (λ..)) (return x)  by
         State_Monad_Ext.return_app_return_meta; no_conv on any other term. *)
      fun rewrite_app_beta_conv ctm =
        case Thm.term_of ctm of
          Const (@{const_name State_Monad_Ext.fun_app_lifted}, _)
            $ (Const (@{const_name State_Monad.return}, _) $ Abs _)
            $ (Const (@{const_name State_Monad.return}, _) $ _)
            => Conv.rewr_conv @{thm State_Monad_Ext.return_app_return_meta} ctm
        | _ => Conv.no_conv ctm

    in {
      monad_name = "state",
      mk_stateT = mk_stateT,
      return = return,
      app = app,
      if_termN = @{const_name State_Monad_Ext.ifT},
      checkmemVN = checkmemVN,
      rewrite_app_beta_conv = rewrite_app_beta_conv
    } end

  val heap_monad: MONAD_CONSTS =
    let
      fun mk_stateT tp =
        Type (@{type_name Heap_Monad.Heap}, [tp])

      val returnN = @{const_name Heap_Monad.return}
      fun return tm = Const (returnN, dummyT --> mk_stateT dummyT) $ tm

      val appN = @{const_name Heap_Monad_Ext.fun_app_lifted}
      fun app (tm0, tm1) = Const (appN, dummyT) $ tm0 $ tm1

      val checkmemVN = "checkmem"

      (* Heap counterpart of the state-monad rewrite above, using
         Heap_Monad_Ext.return_app_return_meta. *)
      fun rewrite_app_beta_conv ctm =
        case Thm.term_of ctm of
          Const (@{const_name Heap_Monad_Ext.fun_app_lifted}, _)
            $ (Const (@{const_name Heap_Monad.return}, _) $ Abs _)
            $ (Const (@{const_name Heap_Monad.return}, _) $ _)
            => Conv.rewr_conv @{thm Heap_Monad_Ext.return_app_return_meta} ctm
        | _ => Conv.no_conv ctm
    in {
      monad_name = "heap",
      mk_stateT = mk_stateT,
      return = return,
      app = app,
      if_termN = @{const_name Heap_Monad_Ext.ifT},
      checkmemVN = checkmemVN,
      rewrite_app_beta_conv = rewrite_app_beta_conv
    } end

  (* Monads selectable by name. *)
  val monad_consts_dict = [
    ("state", state_monad),
    ("heap", heap_monad)
  ]

  (* Look up a monad by name; errors with the list of valid choices otherwise. *)
  fun get_monad_const name =
    case AList.lookup op= monad_consts_dict name of
      SOME consts => consts
    | NONE => error("unrecognized monad: " ^ name ^ " , choices: " ^ commas (map fst monad_consts_dict));

end

File ‹Transform_Data.ML›

(* Context data of the memoization tool: for each monad, a net of registered
   commands indexed by their head term, plus the most recent command. *)
structure Transform_Data = struct

(* Terms and theorems of a monadified definition. *)
type dp_info = {
  old_head: term,
  new_head': term,
  new_headT: term,

  old_defs: thm list,
  new_defT: thm,
  new_def': thm list
}

(* One memoization command: its scope binding, the head term it was issued for,
   an optional memory locale name, and the dp_info once it has been committed. *)
type cmd_info = {
  scope: binding,
  head: term,
  locale: string option,
  dp_info: dp_info option
}

fun map_cmd_info f0 f1 f2 f3 {scope, head, locale, dp_info} =
  {scope = f0 scope, head = f1 head, locale = f2 locale, dp_info = f3 dp_info}

fun map_cmd_dp_info f = map_cmd_info I I I f

structure Data = Generic_Data (
  type T = {
    monadified_terms: (string * cmd_info Item_Net.T) list,
    last_cmd_info: cmd_info option
  }

  (* One empty net per monad; entries are compared and indexed by #head. *)
  val empty = {
    monadified_terms =
      ["state", "heap"]
      ~~ replicate 2 (Item_Net.init (op aconv o apply2 #head) (single o #head)),
    last_cmd_info = NONE
  }

  val extend = I

  (* Merge the nets monad by monad; the last command is not carried over. *)
  fun merge (
    {monadified_terms = m0, ...},
    {monadified_terms = m1, ...}
  ) =
    let
      val keys0 = map fst m0
      val keys1 = map fst m1
      val _ = @{assert} (keys0 = keys1)
      val vals = map Item_Net.merge (map snd m0 ~~ map snd m1)
      val ms = keys0 ~~ vals
    in
      {monadified_terms = ms, last_cmd_info = NONE}
    end
)

(* Apply a morphism to all terms and theorems of a dp_info. *)
fun transform_dp_info phi {old_head, new_head', new_headT, old_defs, new_defT, new_def'} =
  {
    old_head = Morphism.term phi old_head,
    new_head' = Morphism.term phi new_head',
    new_headT = Morphism.term phi new_headT,

    old_defs = Morphism.fact phi old_defs,
    new_def' = Morphism.fact phi new_def',
    new_defT = Morphism.thm phi new_defT
  }

(* The net of commands registered for monad_name. *)
fun get_monadified_terms_generic monad_name ctxt =
  Data.get ctxt
  |> #monadified_terms
  |> (fn l => AList.lookup op= l monad_name)
  |> the

fun get_monadified_terms monad_name lthy =
  get_monadified_terms_generic monad_name (Context.Proof lthy)

fun map_data f0 f1 = Data.map (fn {monadified_terms, last_cmd_info} =>
  {monadified_terms = f0 monadified_terms, last_cmd_info = f1 last_cmd_info})

fun map_monadified_terms f = map_data f I
fun map_last_cmd_info f    = map_data I f

fun put_monadified_terms_generic monad_name new_terms ctxt =
  ctxt |> map_monadified_terms (AList.update op= (monad_name, new_terms))

fun map_monadified_terms_generic monad_name f ctxt =
  ctxt |> map_monadified_terms (AList.map_entry op= monad_name f)

fun put_last_cmd_info cmd_info_opt ctxt =
  map_last_cmd_info (K cmd_info_opt) ctxt

(* All registered commands of monad_name whose head matches tm. *)
fun get_cmd_info monad_name lthy tm =
  get_monadified_terms monad_name lthy
  |> (fn net => Item_Net.retrieve net tm)

(* The dp_info of the first matching command, if it has been committed. *)
fun get_dp_info monad_name lthy tm =
  get_cmd_info monad_name lthy tm
  |> (fn result => case result of
      {dp_info = SOME dp_info', ...} :: _ => SOME dp_info'
    | _ => NONE)

(* The most recent command; an informative error (instead of a bare Option
   exception) if no memoization command has been issued yet. *)
fun get_last_cmd_info_generic ctxt =
  (case Data.get ctxt |> #last_cmd_info of
    SOME cmd_info => cmd_info
  | NONE => error "No memoization command has been issued in this context")

fun get_last_cmd_info lthy =
  get_last_cmd_info_generic (Context.Proof lthy)

(* Attach dp_info to the last command and register it in the net of monad_name. *)
fun commit_dp_info monad_name dp_info =
  Local_Theory.declaration
    {pervasive = false, syntax = false}
    (fn phi => fn ctxt =>
      let
        val old_cmd_info = get_last_cmd_info_generic ctxt
        val new_dp_info = transform_dp_info phi dp_info
        val new_cmd_info = old_cmd_info |> map_cmd_dp_info (K (SOME new_dp_info))
      in
        ctxt
        |> map_monadified_terms_generic monad_name (Item_Net.update new_cmd_info)
        |> put_last_cmd_info (SOME new_cmd_info)
      end)

(* Record a command without dp_info as the last command (not yet in any net). *)
fun add_tmp_cmd_info (scope, head, locale_opt) =
  Local_Theory.declaration
    {pervasive = false, syntax = false}
    (fn phi => fn ctxt =>
      let
        val new_cmd_info = {
          scope = Morphism.binding phi scope,
          head = Morphism.term phi head,
          locale = locale_opt,
          dp_info = NONE
        }
      in
        ctxt |> put_last_cmd_info (SOME new_cmd_info)
      end )

(* The command selected by (monad_name, tm), or the last command when none is
   given.  Missing and ambiguous matches are reported separately. *)
fun get_or_last_cmd_info lthy monad_name_tm_opt =
  case monad_name_tm_opt of
    NONE => get_last_cmd_info lthy
  | SOME (monad_name, tm) =>
      (case get_cmd_info monad_name lthy tm of
        [cmd_info] => cmd_info
      | [] => raise TERM ("no memoized function registered for monad " ^ quote monad_name, [tm])
      | _ => raise TERM ("ambiguous memoized function for monad " ^ quote monad_name, [tm]))

end

File ‹Transform_Tactic.ML›

(* Tactics used to prove that a monadified function is consistent with the
   original one, and to transfer its termination proof. *)
structure Transform_Tactic = struct
  (* Emit msg as a tracing message and leave the goal state unchanged. *)
  fun my_print_tac msg st = (tracing msg; all_tac st)

  (* Derive the new totality theorem from totality0 by unfolding the old relation
     definition def0 and folding the new one def1.  Fails (no_tac) when folding
     with def1 does not change the theorem, i.e. the relations did not match. *)
  fun totality_resolve_tac totality0 def0 def1 ctxt =
    let
      val totality0_unfolded = totality0 |> Local_Defs.unfold ctxt [def0]
      val totality1 = totality0_unfolded |> Local_Defs.fold ctxt [def1]
    in
      if Thm.full_prop_of totality0_unfolded aconv Thm.full_prop_of totality1
        then no_tac
        else HEADGOAL (resolve_tac ctxt [totality1])
    end

  (* Reduce the new totality goal to totality0 and try to close the remaining
     accessibility-predicate goal with blast at depth 2. *)
  fun totality_blast_tac totality0 def0 def1 ctxt =
    HEADGOAL (
      (resolve_tac ctxt [totality0 RS @{thm rev_iffD1}])
      THEN' (resolve_tac ctxt [@{thm arg_cong[where f=HOL.All]}])
      THEN' SELECT_GOAL (unfold_tac ctxt (map (Local_Defs.abs_def_rule ctxt) [def0, def1]))
      THEN' (resolve_tac ctxt [@{thm arg_cong[where f=Wellfounded.accp]}])
      THEN' (Blast.depth_tac ctxt 2)
    )

  (* Prove termination of the new function from that of the old one: first by
     replaying, then by blast; reports which strategy succeeded. *)
  fun totality_replay_tac old_info new_info ctxt =
    let
      val totality0 = Transform_Misc.totality_of old_info
      val def0 = Transform_Misc.rel_of old_info ctxt
      val def1 = Transform_Misc.rel_of new_info ctxt
    in
      (totality_resolve_tac totality0 def0 def1 ctxt
        THEN my_print_tac "termination by replaying")
      ORELSE (totality_blast_tac totality0 def0 def1 ctxt
        THEN my_print_tac "termination by blast")
    end

  (* Resolve with the consistentDP_intro rules of the command's scope. *)
  fun dp_intro_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    let
      val scope_name = Binding.name_of (#scope cmd_info)
      val consistentDP_rule = Transform_Misc.locale_thms ctxt scope_name "consistentDP_intro"
    in
      resolve_tac ctxt consistentDP_rule
    end

  (* Fold the relator equations registered with Transfer. *)
  fun expand_relator_tac ctxt =
    SELECT_GOAL (Local_Defs.fold_tac ctxt (Transfer.get_relator_eq ctxt))

  fun solve_relator_tac ctxt =
    SOLVED' (Transfer.eq_tac ctxt)

  fun split_params_tac ctxt =
    clarify_tac ctxt

  (* Start an induction with the induction rules of the new function new_head'. *)
  fun dp_induct_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    let
      val dpT' = cmd_info |> #dp_info |> the |> #new_head'
      val dpT'_info = Function.get_info ctxt dpT'
      val induct_rule = dpT'_info |> #inducts |> the
    in
      resolve_tac ctxt induct_rule
    end

  (* Rewrite both sides of the (HOL) equation in the goal's conclusion, under all
     parameters, with the definitions selected from dp_info by sel. *)
  fun dp_unfold_def_tac ctxt (cmd_info: Transform_Data.cmd_info) sel =
    cmd_info |> #dp_info |> the |> sel
    |> map (Local_Defs.meta_rewrite_rule ctxt)
    |> Conv.rewrs_conv
    |> Conv.try_conv
    |> Conv.binop_conv
    |> HOLogic.Trueprop_conv
    |> Conv.concl_conv ~1
    |> (fn cv => Conv.params_conv ~1 (K cv) ctxt)
    |> CONVERSION
    (* |> EqSubst.eqsubst_tac ctxt [0] : may rewrite locale parameters in certain situations *)

  (* Resolve with the dp_match_rule theorems of the command's scope. *)
  fun dp_match_rule_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    let
      val scope_name = Binding.name_of (#scope cmd_info)
      val dp_match_rules = Transform_Misc.locale_thms ctxt scope_name "dp_match_rule"
    in
      resolve_tac ctxt dp_match_rules
    end

  (* Resolve with crel_vs_checkmem_tupled, then close the side conditions. *)
  fun checkmem_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    let
      val scope_name = Binding.name_of (#scope cmd_info)
      val checkmem_rules = Transform_Misc.locale_thms ctxt scope_name "crel_vs_checkmem_tupled"
    in
      resolve_tac ctxt checkmem_rules
      THEN' SOLVED' (clarify_tac ctxt)
      THEN' Transfer.eq_tac ctxt
    end

  fun solve_IH_tac ctxt =
    Method.assm_tac ctxt

  fun transfer_raw_tac ctxt =
    resolve_tac ctxt (Transfer.get_transfer_raw ctxt)

  (* One proof step: induction hypothesis, relator, match rule, or transfer rule. *)
  fun step_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    solve_IH_tac ctxt
    ORELSE' solve_relator_tac ctxt
    ORELSE' dp_match_rule_tac ctxt cmd_info
    ORELSE' transfer_raw_tac ctxt

  fun prepare_case_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    dp_unfold_def_tac ctxt cmd_info #new_def'
    THEN' checkmem_tac ctxt cmd_info
    THEN' dp_unfold_def_tac ctxt cmd_info #old_defs

  fun solve_case_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    prepare_case_tac ctxt cmd_info
    THEN' REPEAT_ALL_NEW (step_tac ctxt cmd_info)

  fun prepare_consistentDP_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    dp_intro_tac ctxt cmd_info
    THEN' expand_relator_tac ctxt
    THEN' split_params_tac ctxt
    THEN' dp_induct_tac ctxt cmd_info

  (* Full consistency proof: every induction case must be solved. *)
  fun solve_consistentDP_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    prepare_consistentDP_tac ctxt cmd_info
    THEN_ALL_NEW SOLVED' (solve_case_tac ctxt cmd_info)

  fun prepare_combinator_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    EqSubst.eqsubst_tac ctxt [0] @{thms Rel_def[symmetric]}
    THEN' dp_unfold_def_tac ctxt cmd_info (single o #new_defT)
    THEN' REPEAT_ALL_NEW (resolve_tac ctxt (@{thm Rel_abs} :: Transform_Misc.locale_thms ctxt "local" "crel_vs_return_ext"))
    THEN' (SELECT_GOAL (unfold_tac ctxt @{thms Rel_def}))

  fun dp_unfold_defs_tac ctxt (cmd_info: Transform_Data.cmd_info) =
    dp_unfold_def_tac ctxt cmd_info #new_def'
    THEN' dp_unfold_def_tac ctxt cmd_info #old_defs

end

File ‹Transform_Term.ML›

(* Conversion for an n-ary application: head_conv on the head, the i-th
   argument conversion on the i-th argument, combined left to right. *)
fun list_conv (head_conv, arg_convs) lthy =
  let
    val head = head_conv lthy
    val args = map (fn arg_conv => arg_conv lthy) arg_convs
  in
    fold (fn arg => fn acc => Conv.combination_conv acc arg) args head
  end

(* Expose one outer lambda: an abstraction is left unchanged; any other term
   is eta-long expanded and its body then eta-contracted under the new binder,
   presumably giving λx. t x — TODO confirm for terms with several binders. *)
fun eta_conv1 ctxt =
  (Conv.abs_conv (K Conv.all_conv) ctxt)
  else_conv
  (Thm.eta_long_conversion then_conv Conv.abs_conv (K Thm.eta_conversion) ctxt)

(* Apply eta_conv1 n times, descending under each exposed binder, so that n
   outer lambdas are exposed; the innermost body is left unchanged. *)
fun eta_conv_n n =
  funpow n (fn conv => fn ctxt => eta_conv1 ctxt then_conv Conv.abs_conv (fn (_, ctxt) => conv ctxt) ctxt) (K Conv.all_conv)

(* Like conv, but fails (no_conv) if conv produced a reflexive theorem, i.e.
   made no change; this lets repeat_conv terminate. *)
fun conv_changed conv ctm =
  let
    val thm = conv ctm
  in
    if not (Thm.is_reflexive thm) then thm
    else Conv.no_conv ctm
  end

(* Repeat top-down sweeps with conv until a sweep changes nothing. *)
fun repeat_sweep_conv conv ctxt =
  Conv.repeat_conv (conv_changed (Conv.top_sweep_conv conv ctxt))

(* Marker conversions based on Pure_Monad (App f ≡ f, Wrap x ≡ x):
   app_mark_conv wraps a term t as App t, app_unmark_conv removes an App,
   and wrap_mark_conv wraps a term t as Wrap t. *)
val app_mark_conv = Conv.rewr_conv @{thm App_def[symmetric]}
val app_unmark_conv = Conv.rewr_conv @{thm App_def}
val wrap_mark_conv = Conv.rewr_conv @{thm Wrap_def[symmetric]}

(* Monadification of terms: lift the types and terms of a pure definition into
   a target monad, and build alongside a conversion that inserts Wrap/App
   markers into the original definition. *)
structure Transform_Term = struct

(* Eta-expand tm by at most one argument (none if tm is not function-typed). *)
fun eta_expand tm =
  let
    val n_args = Integer.min 1 (length (binder_types (fastype_of tm)))
    val (args, body) = Term.strip_abs_eta n_args tm
  in
    Library.foldr (uncurry Term.absfree) (args, body)
  end

(* Whether tp_name is registered as a free-constructor (ctr_sugar) type. *)
fun is_ctr_sugar ctxt tp_name =
  is_some (Ctr_Sugar.ctr_sugar_of ctxt tp_name)

(* Number of argument types of a type / of a term's type. *)
fun type_nargs tp = tp |> strip_type |> fst |> length
fun term_nargs tm = type_nargs (fastype_of tm)

(* lift_type: monadic type of a computation whose value has type tp.
   lift_type': lifted value type.  Function types map the lifted argument type to
   a monadic result, ctr_sugar types have their arguments lifted, nullary other
   types and type variables stay unchanged, other type constructors are rejected. *)
fun
  lift_type (monad_consts: Transform_Const.MONAD_CONSTS) ctxt tp = #mk_stateT monad_consts  (lift_type' monad_consts ctxt tp)
and
  lift_type' monad_consts ctxt (tp as Type (@{type_name fun}, _))
    = lift_type' monad_consts ctxt (domain_type tp) --> lift_type monad_consts ctxt (range_type tp)
| lift_type' monad_consts ctxt (tp as Type (tp_name, tp_args))
    = if is_ctr_sugar ctxt tp_name then Type (tp_name, map (lift_type' monad_consts ctxt) tp_args)
      else if null tp_args then tp (* int, nat, … *)
      else raise TYPE("not a ctr_sugar", [tp], [])
| lift_type' _ _ tp = tp

(* A type is atomic if lifting leaves it unchanged (no function types inside). *)
fun is_atom_type monad_consts ctxt tp =
  tp = lift_type' monad_consts ctxt tp

(* First-order type: all argument types and the result type are atomic. *)
fun is_1st_type monad_consts ctxt tp =
  body_type tp :: binder_types tp
  |> forall (is_atom_type monad_consts ctxt)

(* The term that atom_name denotes in ctxt, read as a pattern. *)
fun orig_atom ctxt atom_name =
  Proof_Context.read_term_pattern ctxt atom_name

fun is_1st_term monad_consts ctxt tm =
  is_1st_type monad_consts ctxt (fastype_of tm)

fun is_1st_atom monad_consts ctxt atom_name =
  is_1st_term monad_consts ctxt (orig_atom ctxt atom_name)

(* Wrap a first-order term: strip (eta-expanding if needed) n_args lambdas,
   optionally wrap the body in return (inner_wrap), and wrap every lambda in
   return.  The returned conversion inserts matching Wrap markers. *)
fun wrap_1st_term monad_consts ctxt tm n_args_opt inner_wrap =
  let
    val n_args = the_default (term_nargs tm) n_args_opt
    val (vars_name_typ, body) = Term.strip_abs_eta n_args tm
    fun wrap (name_typ, (conv, tm)) = (
      eta_conv1 ctxt then_conv Conv.abs_conv (K conv) ctxt then_conv wrap_mark_conv,
      #return monad_consts (Term.absfree name_typ tm)
    )
    val (conv, result) = Library.foldr wrap (vars_name_typ, (
      if inner_wrap then (wrap_mark_conv, #return monad_consts body) else (Conv.all_conv, body)
    ))

  in
    (K conv, result)
  end

(* Lift a first-order constant or free variable (built by atom) by lifting its
   argument and result types and wrapping it as in wrap_1st_term. *)
fun lift_1st_atom monad_consts ctxt atom (name, tp) =
  let
    val (arg_typs, body_typ) = strip_type tp

    val n_args = term_nargs (orig_atom ctxt name)

    val (arg_typs, body_arg_typs) = chop n_args arg_typs
    val arg_typs' = map (lift_type' monad_consts ctxt) arg_typs
    val body_typ' = lift_type' monad_consts ctxt (body_arg_typs ---> body_typ)

    val tm' = atom (name, arg_typs' ---> body_typ')
  in
    wrap_1st_term monad_consts ctxt tm' (SOME n_args) true
  end

(* If head_n_args assigns tm's head an arity n: SOME (head, args) when tm has
   exactly n arguments, NONE when it has more (the caller handles the extra
   applications), and a TERM error when it has fewer. *)
fun fixed_args head_n_args tm =
  let
    val (tm_head, tm_args) = strip_comb tm
    val n_tm_args = length tm_args
  in
    head_n_args tm_head
    |> Option.mapPartial (fn n_args =>
      if n_tm_args > n_args then NONE
      else if n_tm_args < n_args then raise TERM("need " ^ string_of_int n_args ^ " args", [tm])
      else SOME (tm_head, tm_args))
  end

(* Lift a lambda body via cont.  A bound variable of non-atomic type is added to
   the lifting dictionary as  return x'  with x' of the lifted type. *)
fun lift_abs' monad_consts ctxt (name, typ) cont lift_dict body =
  let
    val free = Free (name, typ)
    val typ' = lift_type' monad_consts ctxt typ
    val freeT' = Free (name, typ')
    val freeT = #return monad_consts (freeT')

    val lift_dict' = if is_atom_type monad_consts ctxt typ
      then lift_dict
      else (free, (K wrap_mark_conv, freeT))::lift_dict
    val (conv_free, body_free) = (cont (lift_dict') body)

    val body' = lambda freeT' body_free
    fun conv lthy =
      eta_conv1 ctxt then_conv Conv.abs_conv (fn (_, lthy') => conv_free lthy') lthy
  in
    (conv, body')
  end

(* Lifting of a function argument; currently identical to lift_term. *)
fun lift_arg monad_consts ctxt lift_dict tm =
  (*
  let
    val (conv, tm') = lift_term ctxt lift_dict (eta_expand tm)
    fun conv' ctxt = Conv.try_conv (eta_conv1 ctxt) then_conv (conv ctxt)
  in
    (conv', tm')
  end

  eta_expand AFTER lifting
  *)
  lift_term monad_consts ctxt lift_dict tm
(* Lift tm into the monad, returning (marker conversion, lifted term).  The
   checks below are tried in order; the first that applies determines the result. *)
and lift_term monad_consts ctxt lift_dict tm = let
  val case_terms = Ctr_Sugar.ctr_sugars_of ctxt |> map #casex

  fun lookup_case_term tm =
    find_first (fn x => Term.aconv_untyped (x, tm)) case_terms

  val check_cont = lift_term monad_consts ctxt
  val check_cont_arg = lift_arg monad_consts ctxt

  (* A constant that has already been monadified: use its new head. *)
  fun check_const tm =
    case tm of
      Const (_, typ) => (
        case Transform_Data.get_dp_info (#monad_name monad_consts) ctxt tm of
          SOME {new_headT=Const (name, _), ...} => SOME (K Conv.all_conv, Const (name, lift_type monad_consts ctxt typ))
        | SOME {new_headT=tm', ...} => raise TERM("not a constant", [tm'])
        | NONE => NONE)
    | _ => NONE

  (* A first-order constant or free variable. *)
  fun check_1st_atom tm =
    case tm of
      Const (name, typ) =>
        if is_1st_atom monad_consts ctxt name then SOME (lift_1st_atom monad_consts ctxt Const (name, typ)) else NONE
    | Free (name, typ) =>
        if is_1st_atom monad_consts ctxt name then SOME (lift_1st_atom monad_consts ctxt Free (name, typ)) else NONE
    | _ => (*
        if is_1st_term ctxt tm andalso exists_subterm (AList.defined (op aconv) lift_dict) tm
          then SOME (wrap_1st_term tm NONE)
          else *) NONE

(*
  fun check_dict tm =
    (* TODO: map -> mapT, dummyT *)
    AList.lookup Term.aconv_untyped lift_dict tm
    |> Option.map (fn tm' =>
      if is_Const tm
        then (@{assert} (is_Const tm'); map_types (K (lift_type ctxt (type_of tm))) tm')
        else tm')
*)

  (* A term already recorded in the lifting dictionary. *)
  fun check_dict tm =
    AList.lookup Term.aconv_untyped lift_dict tm

  (* If-then-else with exactly three arguments becomes the monadic if. *)
  fun check_if tm =
    fixed_args (fn head => if Term.aconv_untyped (head, @{term If}) then SOME 3 else NONE) tm
    |> Option.map (fn (_, args) =>
      let
        val (arg_convs, args') = map (check_cont lift_dict) args |> split_list
        val conv = list_conv (K Conv.all_conv, arg_convs)
        val tm' = list_comb (Const (#if_termN monad_consts, dummyT), args')
      in
        (conv, tm')
      end)

  (* A lambda: lift its body and wrap the result in return. *)
  fun check_abs tm =
    case tm of
      Abs (name, typ, body) =>
        let
          val (name', body') = Term.dest_abs (name, typ, body)
          val (conv, tm') = lift_abs' monad_consts ctxt (name', typ) check_cont lift_dict body'
          fun convT lthy = conv lthy then_conv wrap_mark_conv
          val tmT = #return monad_consts tm'
        in SOME (convT, tmT) end
    | _ => NONE

  (* A fully applied case combinator: lift every clause under its pattern
     variables and wrap the rebuilt case expression in return. *)
  fun check_case tm =
    fixed_args (lookup_case_term #> Option.map (fn cs => term_nargs cs - 1)) tm
    |> Option.map (fn (head, args) =>
      let
        val (case_name, case_type) = lookup_case_term head |> the |> dest_Const
        val ((clause_typs, _), _) =
          strip_type case_type |>> split_last

        val clause_nparams = clause_typs |> map type_nargs
        (* ('a⇒'b) ⇒ ('a⇒'b) |> type_nargs = 1*)

        fun lift_clause n_param clause =
          let
            val (vars_name_typ, body) = Term.strip_abs_eta n_param clause
            val abs_lift_wraps = map (lift_abs' monad_consts ctxt) vars_name_typ
            val lift_wrap = Library.foldr (op o) (abs_lift_wraps, I) check_cont
            val (conv, clauseT) = lift_wrap lift_dict body
          in
            (conv, clauseT)
          end

        val head' = Const (case_name, dummyT) (* clauses are sufficient for type inference *)
        val (convs, clauses') = map2 lift_clause clause_nparams args |> split_list

        fun conv lthy = list_conv (K Conv.all_conv, convs) lthy then_conv wrap_mark_conv
        val tm' = #return monad_consts (list_comb (head', clauses'))
      in
        (conv, tm')
      end)

  (* Application f x: lift both parts and combine with the monad's app. *)
  fun check_app tm =
    case tm of
      f $ x =>
        let
          val (f_conv, tmf) = check_cont lift_dict f
          val (x_conv, tmx) = check_cont_arg lift_dict x
          val tm' = #app monad_consts (tmf, tmx)
          fun conv lthy = Conv.combination_conv (f_conv lthy then_conv app_mark_conv) (x_conv lthy)
        in
          SOME (conv, tm')
        end
    | _ => NONE

  (* A first-order term mentioning nothing from the dictionary is simply wrapped. *)
  fun check_pure tm =
    if tm |> exists_subterm (AList.defined (op aconv) lift_dict)
      orelse not (is_1st_term monad_consts ctxt tm)
      then NONE
      else SOME (wrap_1st_term monad_consts ctxt tm NONE true)

  fun choke tm =
    raise TERM("cannot process term", [tm])

  val checks = [
    check_pure,
    check_const,
    check_case,
    check_if,
    check_abs,
    check_app,
    check_dict,
    check_1st_atom,
    choke
  ]
  in get_first (fn check => check tm) checks |> the end

(* Rewrite  App (Wrap (λ..)) (Wrap x)  to the plain application (Wrap_App_Wrap);
   no_conv on any other term. *)
fun rewrite_pureapp_beta_conv ctm =
  case Thm.term_of ctm of
    Const (@{const_name Pure_Monad.App}, _)
      $ (Const (@{const_name Pure_Monad.Wrap}, _) $ Abs _)
      $ (Const (@{const_name Pure_Monad.Wrap}, _) $ _)
      => Conv.rewr_conv @{thm Wrap_App_Wrap} ctm
  | _ => Conv.no_conv ctm

(* Lift tm with an empty dictionary and type-check the result. *)
fun monadify monad_consts ctxt tm =
  let
    val (_, tm) = lift_term monad_consts ctxt [] tm
    (*val tm = rewrite_return_app_return tm*)
    val tm = Syntax.check_term ctxt tm
  in
    tm
  end

(* return (λx1. return (λx2. ... head x1 ... xn)) with n_args nested returns. *)
fun wrap_head (monad_consts: Transform_Const.MONAD_CONSTS) head n_args =
  Library.foldr
    (fn (typ, tm) => #return monad_consts (absdummy typ tm))
    (replicate n_args dummyT, list_comb (head, rev (map_range Bound n_args)))

(* The new head as a free variable whose first n_args argument types are lifted
   and whose remaining type is monadic, together with its wrapped form. *)
fun lift_head monad_consts ctxt head n_args =
  let
    val dest_head = if is_Const head then dest_Const else dest_Free
    val (head_name, head_typ) = dest_head head
    val (arg_typs, body_typ) = strip_type head_typ
    val (arg_typs0, arg_typs1) = chop n_args arg_typs
    val arg_typs0' = map (lift_type' monad_consts ctxt) arg_typs0
    val arg_typs1T = lift_type monad_consts ctxt (arg_typs1 ---> body_typ)
    val head_typ' = arg_typs0' ---> arg_typs1T
    val head' = Free (head_name, head_typ')

    val (head_conv, headT) = wrap_1st_term monad_consts ctxt head' (SOME n_args) false
  in
    (head', (head_conv, headT))
  end

(* Lift one defining equation lhs = rhs.  Returns the marker conversion for the
   rhs, the new equation (with the rhs passed to the memoizer, if any, together
   with the tuple of the lhs arguments) and the number of lhs arguments. *)
fun lift_equation monad_consts ctxt (lhs, rhs) memoizer_opt =
  let
    val (head, args) = strip_comb lhs
    val n_args = length args
    val (head', (head_conv, headT)) = lift_head monad_consts ctxt head n_args
    val args' = args |> map (map_aterms (fn tm => tm |> map_types
      (if is_Const tm then K dummyT else lift_type' monad_consts ctxt)))
    val lhs' = list_comb (head', args')

    val frees = fold Term.add_frees args []
     |> filter_out (is_atom_type monad_consts ctxt o snd)

    val lift_dict_args = frees |> map (fn (name, typ) => (
      Free (name, typ),
      (K wrap_mark_conv,
       #return monad_consts (Free (name, lift_type' monad_consts ctxt typ)))
    ))
    val lift_dict = (head, (head_conv, headT)) :: lift_dict_args
    val (rhs_conv, rhsT) = lift_term monad_consts ctxt lift_dict rhs

    val rhsT_memoized = case memoizer_opt of
      SOME memoizer =>
        memoizer
        $ HOLogic.mk_tuple args
        $ rhsT
    | NONE => rhsT

    val eqs' = (lhs', rhsT_memoized) |> HOLogic.mk_eq |> HOLogic.mk_Trueprop
  in
    (rhs_conv, eqs', n_args)
  end


end

File ‹Transform.ML›

structure Transform_DP = struct

fun dp_interpretation standard_proof locale_name instance qualifier dp_term lthy =
  lthy
  |> Interpretation.isar_interpretation ([(locale_name, ((qualifier, true), (Expression.Named (("dp", dp_term) :: instance), [])))], [])
  |> (if standard_proof then Proof.global_default_proof else Proof.global_immediate_proof)

fun prep_params (((scope, tm_str), def_thms_opt), mem_locale_opt) lthy =
  let
    val tm = Syntax.read_term lthy tm_str
    val scope' = (Binding.is_empty scope? Binding.map_name (fn _ => Transform_Misc.term_name tm ^ "T")) scope
    val def_thms_opt' = Option.map (Attrib.eval_thms lthy) def_thms_opt
    val mem_locale_opt' = Option.map (Locale.check (Proof_Context.theory_of lthy)) mem_locale_opt
  in
    (scope', tm, def_thms_opt', mem_locale_opt')
  end
(*
fun dp_interpretation_cmd args lthy =
  let
    val (scope, tm, _, mem_locale_opt) = prep_params args lthy
    val scope_name = Binding.name_of scope
  in
    case mem_locale_opt of
      NONE => lthy
    | SOME x => dp_interpretation x scope_name (Transform_Misc.uncurry tm) lthy
  end
*)

(* Monadify the constant tm.

   Lifts the defining equations of tm into the monad selected by heap_name,
   defines the lifted function under the name "<scope>'", proves its
   termination, defines "<scope>" as the wrapped head of the lifted function,
   and records all of it via Transform_Data.commit_dp_info.

   heap_name      -- monad selector, passed to Transform_Const.get_monad_const
   scope          -- binding from which the generated constant names derive
   tm             -- the term whose definition is lifted
   mem_locale_opt -- when SOME, the "checkmem" term found under scope's name
                     is used as memoizer
   def_thms_opt   -- explicit defining equations; when NONE, the ".simps" of
                     the function-package data for tm are used
   Raises TERM "no definition" if neither source of equations is available. *)
fun do_monadify heap_name scope tm mem_locale_opt def_thms_opt lthy =
  let
    val monad_consts = Transform_Const.get_monad_const heap_name
    val scope_name = Binding.name_of scope

    (* Only memoize when a memory locale was supplied. *)
    val memoizer_opt = if is_none mem_locale_opt then NONE else
      SOME (Transform_Misc.locale_term lthy scope_name "checkmem")

    (* Defining equations: the explicitly given ones if present, otherwise the
       ".simps" registered by the function package for tm (first SOME wins). *)
    val old_info_opt = Function_Common.import_function_data tm lthy
    val old_defs_opt = [
      K def_thms_opt,
      K (Option.mapPartial #simps old_info_opt)
    ] |> Library.get_first (fn x => x ())

    val old_defs = case old_defs_opt of
      SOME defs => defs
    | NONE => raise TERM("no definition", [tm])

    (* Imported copies of the equations (presumably with schematic variables
       fixed) are used for term manipulation; the originals are rewritten. *)
    val ((_, old_defs_imported), _) = Variable.import true old_defs lthy
(*
    val new_bind = Binding.suffix_name "T'" scope
    val new_bindT = Binding.suffix_name "T" scope
*)
    (* The lifted function is "<scope>'", its wrapped version is "<scope>". *)
    val new_bind = Binding.suffix_name "'" scope
    val new_bindT = scope

    (* Process one defining equation.  Returns the original equation with its
       right-hand side rewritten by the conversion obtained from
       Transform_Term.lift_equation, the lifted equation (not yet
       type-checked), and the n_args reported by lift_equation. *)
    fun dest_def (def, def_imported) =
      let
        val def_imported_meta = def_imported |> Local_Defs.meta_rewrite_rule lthy
        val eqs = def_imported_meta |> Thm.full_prop_of
        val (head, _) = Logic.dest_equals eqs |> fst |> Transform_Misc.behead tm

        (*val _ = if Term.aconv_untyped (head, tm) then () else raise THM("invalid definition", 0, [def])*)
        (* Abstract over the head so that its occurrences become a bound
           variable named after new_bind.
           NOTE(review): non-exhaustive binding -- raises Bind if the result of
           lambda_name is not an Abs. *)
        val Abs t = Term.lambda_name (Binding.name_of new_bind, head) eqs
        val (t_name, eqs') = Term.dest_abs t
        val _ = @{assert} (t_name = Binding.name_of new_bind)

        (*val eqs' = Term.subst_atomic [(head, Free (Binding.name_of new_bind, fastype_of head))] eqs*)
        val (rhs_conv, eqsT, n_args) = Transform_Term.lift_equation monad_consts lthy (Logic.dest_equals eqs') memoizer_opt
        val def_meta' = def |> Local_Defs.meta_rewrite_rule lthy |> Conv.fconv_rule (Conv.arg_conv (rhs_conv lthy))

        (* Clean up the rewritten original equation by repeated sweeps of
           rewrite_pureapp_beta_conv. *)
        val def_meta_simped = def_meta'
          |> Conv.fconv_rule (
               repeat_sweep_conv (K Transform_Term.rewrite_pureapp_beta_conv) lthy
             )
(*
        val eqsT_simped = eqsT
          |> Syntax.check_term lthy
          |> Thm.cterm_of lthy
          |> repeat_sweep_conv (K Transform_Term.rewrite_app_beta_conv) lthy
          |> Thm.full_prop_of |> Logic.dest_equals |> snd
*)

      in ((def_meta_simped, eqsT), n_args) end

    (* Transform_Misc.the_element presumably requires all equations to report
       the same n_args -- TODO confirm. *)
    val ((old_defs', new_defs_raw), n_args) =
      map dest_def (old_defs ~~ old_defs_imported)
      |> split_list |>> split_list
      ||> Transform_Misc.the_element

    (* Type-check the lifted equations together, then simplify each one with
       the monad's rewrite_app_beta_conv and keep the simplified result. *)
    val new_defs = Syntax.check_props lthy new_defs_raw |> map (fn eqsT => eqsT
      |> Thm.cterm_of lthy
      |> repeat_sweep_conv (K (#rewrite_app_beta_conv monad_consts)) lthy
      |> Thm.full_prop_of |> Logic.dest_equals |> snd)

    (*val _ = map (Pretty.writeln o Syntax.pretty_term @{context} o Thm.full_prop_of) old_defs'*)
    (*val (new_defs, lthy) = Variable.importT_terms new_defs lthy*)

    (* Define the lifted function.  For termination, first try to replay the
       original function's termination proof, then fall back to the default
       termination prover. *)
    val (new_info, lthy) = Transform_Misc.add_function new_bind new_defs lthy
    val replay_tac = case old_info_opt of
      NONE => no_tac
    | SOME info => Transform_Tactic.totality_replay_tac info new_info lthy
    val totality_tac =
      replay_tac
      ORELSE (Function_Common.termination_prover_tac false lthy
        THEN Transform_Tactic.my_print_tac "termination by default prover")

    val (new_info, lthy) = Function.prove_termination NONE totality_tac lthy
    val new_def' = new_info |> #simps |> the

    (* Define "<scope>" as the lifted head wrapped for n_args arguments. *)
    val head' = new_info |> #fs |> the_single
    val headT = Transform_Term.wrap_head monad_consts head' n_args |> Syntax.check_term lthy
    val ((headTC, (_, new_defT)), lthy) = Local_Theory.define ((new_bindT, NoSyn), ((Thm.def_binding new_bindT,[]), headT)) lthy

    (* Record the old and new heads and their definitions for later commands. *)
    val lthy = Transform_Data.commit_dp_info (#monad_name monad_consts) ({
      old_head = tm,
      new_head' = head',
      new_headT = headTC,

      old_defs = old_defs',
      new_def' = new_def',
      new_defT = new_defT
    }) lthy

    val _ = Proof_Display.print_consts true (Position.thread_data ()) lthy (K false) [
      (Binding.name_of new_bind, Term.type_of head'),
      (Binding.name_of new_bindT, Term.type_of headTC)]
  in lthy end

(* Monadify into the state monad ("state").  Parameter preparation is
   delegated entirely to prep_params; the leading term-reader argument is
   accepted for interface compatibility but is not consulted. *)
fun gen_dp_monadify _ args lthy =
  (case prep_params args lthy of
    (scope, tm, def_thms_opt, mem_locale_opt) =>
      lthy |> do_monadify "state" scope tm mem_locale_opt def_thms_opt)

(* Outer-syntax entry point, using the standard term reader. *)
val dp_monadify_cmd = gen_dp_monadify Syntax.read_term

(* First half of the dp_fun command.

   Reads tm_str (warning if it is a free variable).  If a memory locale is
   given, the locale is interpreted for the scope; the instance terms are read
   in lthy.  Finally the (scope, term, checked locale name) triple is stored
   as temporary command info for the second half of the command. *)
fun dp_fun_part1_cmd ((scope, tm_str), mem_locale_instance_opt) lthy =
  let
    val scope_name = Binding.name_of scope
    val tm = Syntax.read_term lthy tm_str
    val _ =
      if is_Free tm
      then warning ("Free term: " ^ (Syntax.pretty_term lthy tm |> Pretty.string_of))
      else ()

    (* Resolve the locale name exactly once.  It is needed both for the
       interpretation and for the recorded command info; checking it twice
       duplicated the lookup and its position report. *)
    val thy = Proof_Context.theory_of lthy
    val checked_opt = mem_locale_instance_opt |> Option.map
      (fn ((standard_proof, locale_name), instance) =>
        (standard_proof, Locale.check thy locale_name, instance))
    val mem_locale_opt' = Option.map (fn (_, checked_locale, _) => checked_locale) checked_opt

    val lthy =
      (case checked_opt of
        NONE => lthy
      | SOME (standard_proof, checked_locale, instance) =>
          dp_interpretation standard_proof checked_locale
            (map (apsnd (Syntax.read_term lthy)) instance)
            scope_name (Transform_Misc.uncurry tm) lthy)
  in
    Transform_Data.add_tmp_cmd_info (Binding.reset_pos scope, tm, mem_locale_opt') lthy
  end

(* Second half of the dp_fun command: monadify the function recorded by the
   last dp_fun command into the monad named heap_name, using the given
   defining theorems.  Raises TERM "already monadified" if that function
   already has dp info. *)
fun dp_fun_part2_cmd (heap_name, def_thms_str) lthy =
  let
    val {scope, head=tm, locale=locale_opt, dp_info=dp_info_opt} = Transform_Data.get_last_cmd_info lthy
    val _ = if is_none dp_info_opt then () else raise TERM("already monadified", [tm])
    val def_thms = Attrib.eval_thms lthy def_thms_str
  in
    do_monadify heap_name scope tm locale_opt (SOME def_thms) lthy
  end

(* Start the correctness proof for the function of the last dp_fun command.

   The goal is "consistentDP" (from the locale interpretation under the
   scope's name) applied to the uncurried monadified function new_head'.
   After the proof, the theorem is noted as "<scope>.crel".  If the locale
   provides a single "memoized" theorem, "<scope>.memoized_correct" is
   derived from it via RS and noted as well.
   Raises TERM if the function is not monadified yet or no locale was
   interpreted for it. *)
fun dp_correct_cmd lthy =
  let
    val {scope, head=tm, locale=locale_opt, dp_info=dp_info_opt} = Transform_Data.get_last_cmd_info lthy
    val dp_info = case dp_info_opt of SOME x => x | NONE => raise TERM("not yet monadified", [tm])
    val _ = if is_some locale_opt then () else raise TERM("not interpreted yet", [tm])

    val scope_name = Binding.name_of scope
    val consistentDP = Transform_Misc.locale_term lthy scope_name "consistentDP"
    val dpT' = #new_head' dp_info
    val dpT'_curried = dpT' |> Transform_Misc.uncurry
    val goal_pat = consistentDP $ dpT'_curried
    val goal_prop = Syntax.check_term lthy (HOLogic.mk_Trueprop goal_pat)

    (* A tuple of fresh schematic variables (a, b, ...), one per argument type
       of dpT'; used below to instantiate the second schematic variable of
       the "memoized" theorem. *)
    val tuple_pat = type_of dpT' |> strip_type |> fst |> length
      |> Name.invent_list [] "a"
      |> map (fn s => Var ((s, 0), TVar ((s, 0), @{sort type})))
      |> HOLogic.mk_tuple
      |> Thm.cterm_of lthy

    (* A missing "memoized" theorem is reported as a warning, not an error.
       NOTE(review): only ERROR is handled; an exception raised by
       the_single (for zero or several theorems) would presumably propagate --
       confirm this is intended. *)
    val memoized_thm_opt = Transform_Misc.locale_thms lthy scope_name "memoized" |> the_single |> SOME
      handle ERROR msg => (warning msg; NONE)
    val memoized_thm'_opt = memoized_thm_opt
      |> Option.map (Drule.infer_instantiate' lthy [NONE, SOME tuple_pat])

    fun display_thms thm_binds ctxt =
      Proof_Display.print_results true (Position.thread_data ()) ctxt((Thm.theoremK, ""), [thm_binds])

    val crel_thm_name = "crel"
    val memoized_thm_name = "memoized_correct"

    (* Called with the proved goal: note and display the resulting theorems. *)
    fun afterqed thmss ctxt =
      let
        val [[crel_thm]] = thmss

        val (crel_thm_binds, ctxt) = Local_Theory.note (
          (Binding.qualify_name true scope crel_thm_name, []),
          [crel_thm]
        ) ctxt

        val _ = display_thms crel_thm_binds ctxt

        (* NOTE(review): the prod.case unfolding below uses the outer lthy
           rather than the current ctxt. *)
        val ctxt = case memoized_thm'_opt of NONE => ctxt | SOME memoized_thm' => let
          val (memoized_thm_binds, ctxt) = Local_Theory.note (
            (Binding.qualify_name true scope memoized_thm_name, []),
            [(crel_thm RS memoized_thm') |> Local_Defs.unfold lthy @{thms prod.case}]
          ) ctxt

          val _ = display_thms memoized_thm_binds ctxt
        in ctxt end
      in
        ctxt
      end

    val goal = Proof.theorem NONE afterqed [[(goal_prop, [])]] lthy
  in
    goal
  end


end

File ‹Transform_Parser.ML›

(* Outer-syntax parsers for the memoization commands. *)
structure Transform_Parser = struct

(* A single binding. *)
val dp_fun_parser =
  Parse.binding (* name of instantiation and monadified term *)

(*
fun dp_fun binding =
  Transform_Data.update_last_binding binding
*)

(* A locale name together with its source position. *)
val memoizes_parser =
  Parse.name_position (* name of locale, e.g. dp_consistency_rbt *)

(* term ["(" thms ")"] *)
val monadifies_parser =
  Parse.term (* term to be monadified *)
  -- Scan.option (
    @{keyword "("}
    |--  Parse.thms1 --| (* optional definitions, ".simps" as default *)
    @{keyword ")"})

(* [scope ":"] term ["(" thms ")"] ["with_memory" locale]
   An omitted scope yields Binding.empty. *)
val dp_monadify_cmd_parser =
  Scan.optional (Parse.binding --| Parse.$$$ ":") Binding.empty (* optional scope *)
  -- Parse.term (* term to be monadified *)
  -- Scan.option (@{keyword "("} |-- (* optional definitions, ".simps" as default *)
      Parse.thms1
    --| @{keyword ")"})
  -- Scan.option (@{keyword with_memory} |-- Parse.name_position) (* e.g. dp_consistency_rbt *)

(* Optional locale instantiation "where x1 = t1 and ... and xn = tn";
   yields the empty list when no "where" clause is present. *)
val instance =
  (Parse.where_ |-- Parse.and_list1 (Parse.name -- (Parse.$$$ "=" |-- Parse.term))
  || Scan.succeed [])

(* scope ":" term ["with_memory" ["default_proof"] locale instance]
   The optional "default_proof" keyword is parsed into a boolean flag. *)
val dp_fun_part1_parser =
  (Parse.binding --| Parse.$$$ ":") (* scope, e.g., bfT *)
  -- Parse.term (* term to be monadified, e.g., bf *)
  -- Scan.option (@{keyword with_memory}
    |-- Parse.opt_keyword "default_proof" -- Parse.name_position -- instance
    ) (* e.g. dp_consistency_rbt *)

(* "(" heap_name ")" thms -- the monad/heap name and the defining theorems. *)
val dp_fun_part2_parser =
  (* monadifies *)
  (@{keyword "("} |-- Parse.name --| @{keyword ")"}) -- Parse.thms1

end

Theory Bottom_Up_Computation

subsection ‹Bottom-Up Computation›

theory Bottom_Up_Computation
  imports "../state_monad/Memory" "../state_monad/DP_CRelVS"
begin

(* Run the state computation f on every element of a list, left to right,
   discarding the results. *)
fun iterate_state where
  "iterate_state f [] = State_Monad.return ()" |
  "iterate_state f (x # xs) = do {f x; iterate_state f xs}"

(* An iterator is described by a predicate cnt (iteration continues from x
   while cnt x holds) and a successor function nxt.  Both definitions are
   well-founded recursions along the relation {(nxt x, x) | x. cnt x}. *)
locale iterator_defs =
  fixes cnt :: "'a  bool" and nxt :: "'a  'a"
begin

(* Run f at x, then continue at nxt x, as long as cnt holds. *)
definition
  "iter_state f 
    wfrec
      {(nxt x, x) | x. cnt x}
      (λ rec x. if cnt x then do {f x; rec (nxt x)} else State_Monad.return ())"

(* The list of points visited by the iteration starting at x. *)
definition
  "iterator_to_list 
    wfrec {(nxt x, x) | x. cnt x} (λ rec x. if cnt x then x # rec (nxt x) else [])
  "

end

(* A terminating iterator.  Only finitely many points satisfy cnt, and the
   measure sizef strictly increases along nxt from such points. *)
locale iterator = iterator_defs +
  fixes sizef :: "'a  nat"
  assumes terminating:
    "finite {x. cnt x}" " x. cnt x  sizef x < sizef (nxt x)"
begin

lemma admissible:
  "adm_wf
      {(nxt x, x) | x. cnt x}
      (λ rec x. if cnt x then do {f x; rec (nxt x)} else State_Monad.return ())"
  unfolding adm_wf_def by auto

(* The iteration relation is well-founded: it is acyclic (sizef increases
   along it) and finite. *)
lemma wellfounded:
  "wf {(nxt x, x) | x. cnt x}" (is "wf ?S")
proof -
  from terminating have "acyclic ?S"
    by (auto intro: acyclicI_order[where f = sizef])
  moreover have "finite ?S"
    using [[simproc add: finite_Collect]] terminating(1) by auto
  ultimately show ?thesis
    by - (rule finite_acyclic_wf)
qed

(* Unfolding equations for the two wfrec-based definitions. *)
lemma iter_state_unfold:
  "iter_state f x = (if cnt x then do {f x; iter_state f (nxt x)} else State_Monad.return ())"
  unfolding iter_state_def by (simp add: wfrec_fixpoint[OF wellfounded admissible])

lemma iterator_to_list_unfold:
  "iterator_to_list x = (if cnt x then x # iterator_to_list (nxt x) else [])"
  unfolding iterator_to_list_def by (simp add: adm_wf_def wfrec_fixpoint[OF wellfounded])

(* Iterating equals running f over the list of visited points. *)
lemma iter_state_iterate_state:
  "iter_state f x = iterate_state f (iterator_to_list x)"
  apply (induction "iterator_to_list x" arbitrary: x)
   apply (simp add: iterator_to_list_unfold split: if_split_asm)
   apply (simp add: iter_state_unfold)
  apply (subst (asm) (3) iterator_to_list_unfold)
  apply (simp split: if_split_asm)
  apply (auto simp: iterator_to_list_unfold iter_state_unfold)
  done

end (* Iterator *)

(* Consistency of iterated state computations inside dp_consistency. *)
context dp_consistency
begin

context
  includes lifting_syntax
begin

(* If f is related to some g, iterating f over any list is crel_vs-related
   (with (=)) to the unit value. *)
lemma crel_vs_iterate_state:
  "crel_vs (=) () (iterate_state f xs)" if "((=) ===>T R) g f"
proof (induction xs)
  case Nil
  then show ?case
    by (simp; rule crel_vs_return_ext[unfolded Transfer.Rel_def]; simp; fail)
next
  case (Cons x xs)
  have unit_expand: "() = (λ a f. f a) () (λ _. ())" ..
  from Cons show ?case
    by simp
      (rule
        bind_transfer[unfolded rel_fun_def, rule_format, unfolded unit_expand]
        that[unfolded rel_fun_def, rule_format] HOL.refl
      )+
qed

(* Running a related computation d before b keeps b related to a. *)
lemma crel_vs_bind_ignore:
  "crel_vs R a (do {d; b})" if "crel_vs R a b" "crel_vs S c d"
proof -
  have unit_expand: "a = (λ a f. f a) () (λ _. a)" ..
  show ?thesis
    by (subst unit_expand)
       (rule bind_transfer[unfolded rel_fun_def, rule_format, unfolded unit_expand] that)+
qed

(* Pre-filling by iterating f over xs, then computing f x, still yields g x. *)
lemma crel_vs_iterate_and_compute:
  assumes "((=) ===>T R) g f"
  shows "crel_vs R (g x) (do {iterate_state f xs; f x})"
  by (rule
        crel_vs_bind_ignore crel_vs_iterate_state HOL.refl
        assms[unfolded rel_fun_def, rule_format] assms
     )+

end (* Lifting Syntax *)

end (* DP Consistency *)

(* dp_consistency combined with a terminating iterator: bottom-up
   computation via iter_state followed by a final query. *)
locale dp_consistency_iterator =
  dp_consistency lookup update + iterator cnt nxt sizef
  for lookup :: "'a  ('b, 'c option) state" and update
    and cnt :: "'a  bool" and nxt and sizef
begin

lemma crel_vs_iter_and_compute:
  assumes "((=) ===>T R) g f"
  shows "crel_vs R (g x) (do {iter_state f y; f x})"
  unfolding iter_state_iterate_state using crel_vs_iterate_and_compute[OF assms] .

(* For a consistent f, iterating from any y and then querying x gives dp x. *)
lemma consistentDP_iter_and_compute:
  assumes "consistentDP f"
  shows "crel_vs (=) (dp x) (do {iter_state f y; f x})"
  using assms unfolding consistentDP_def by (rule crel_vs_iter_and_compute)

end (* Consistency + Iterator *)

(* Bottom-up computation started from the empty memory. *)
locale dp_consistency_iterator_empty =
  dp_consistency_iterator + dp_consistency_empty
begin

(* The result of iterating and then querying x from the empty memory is dp x. *)
lemma memoized:
  "dp x = fst (run_state (do {iter_state f y; f x}) empty)" if "consistentDP f"
  using consistentDP_iter_and_compute[OF that, of x y]
  by (auto elim!: crel_vs_elim intro: P_empty cmem_empty)

(* The final memory still satisfies cmem. *)
lemma cmem_result:
  "cmem (snd (run_state (do {iter_state f y; f x}) empty))" if "consistentDP f"
  using consistentDP_iter_and_compute[OF that, of x y]
  by (auto elim!: crel_vs_elim intro: P_empty cmem_empty)

end (* Consistency + Iterator *)

(* Introduction rule: an empty-memory consistency instance plus an iterator
   yields dp_consistency_iterator_empty. *)
lemma dp_consistency_iterator_emptyI:
  "dp_consistency_iterator_empty P lookup update cnt
    nxt sizef empty"
  if "dp_consistency_empty lookup update P empty"
     "iterator cnt nxt sizef"
   for empty
  by (meson
      dp_consistency_empty.axioms(1) dp_consistency_iterator_def
      dp_consistency_iterator_empty_def that
     )

(* Iterators over the index pairs of an (n+1) x (m+1) table. *)
context
  fixes m :: nat ― ‹Width of a row›
    and n :: nat ― ‹Number of rows›
begin

(* Row by row upwards: (x, y) steps to (x, y + 1) within a row and to
   (x + 1, 0) at the end of a row; the measure is x * (m + 1) + y. *)
lemma table_iterator_up:
  "iterator
    (λ (x, y). x  n  y  m)
    (λ (x, y). if y < m then (x, y + 1) else (x + 1, 0))
    (λ (x, y). x * (m + 1) + y)"
  by standard auto

(* Row by row downwards: (x, y) steps to (x, y - 1) within a row and to
   (x - 1, m) at the start of a row; the measure is
   (n - x) * (m + 1) + (m - y). *)
lemma table_iterator_down:
  "iterator
    (λ (x, y). x  n  y  m  x > 0)
    (λ (x, y). if y > 0 then (x, y - 1) else (x - 1, m))
    (λ (x, y). (n - x) * (m + 1) + (m - y))"
  using [[simproc add: finite_Collect]]  by standard (auto simp: Suc_diff_le)

end (* Table *)

end (* Theory *)

Theory Bottom_Up_Computation_Heap

theory Bottom_Up_Computation_Heap
  imports "../state_monad/Bottom_Up_Computation" "../heap_monad/DP_CRelVH"
begin

(* Heap-monad analogue of iter_state: run f at x, then continue at nxt x,
   as long as cnt holds. *)
definition (in iterator_defs)
  "iter_heap f 
    wfrec
      {(nxt x, x) | x. cnt x}
      (λ rec x. if cnt x then do {f x; rec (nxt x)} else return ())"

(* Unfolding equation for iter_heap, using the well-foundedness of the
   iteration relation. *)
lemma (in iterator) iter_heap_unfold:
  "iter_heap f x = (if cnt x then do {f x; iter_heap f (nxt x)} else return ())"
  unfolding iter_heap_def
  by (simp add: wfrec_fixpoint[OF iterator.wellfounded,OF iterator.intro,OF terminating] adm_wf_def)

(* Heap-monad consistency combined with a terminating iterator. *)
locale dp_consistency_iterator_heap =
  dp_consistency_heap P update lookup dp + iterator cnt nxt sizef
  for lookup :: "'a  ('c option) Heap" and update and P dp
    and cnt :: "'a  bool" and nxt and sizef
begin

context
  includes lifting_syntax
begin

(* Iterating a related f (via iter_heap) is crel_vs-related to unit.
   Proof by well-founded induction along the iteration relation. *)
lemma crel_vs_iterate_state:
  "crel_vs (=) () (iter_heap f x)" if "((=) ===> crel_vs R) g f"
  using wellfounded
proof induction
  case (less x)
  have unit_expand: "() = (λ a f. f a) () (λ _. ())" ..
  from less show ?case
    by (subst iter_heap_unfold)
       (auto intro:
          bind_transfer[unfolded rel_fun_def, rule_format, unfolded unit_expand]
          crel_vs_return_ext[unfolded Transfer.Rel_def] that[unfolded rel_fun_def, rule_format]
       )
qed

(* Running a related computation d before b keeps b related to a. *)
lemma crel_vs_bind_ignore:
  "crel_vs R a (do {d; b})" if "crel_vs R a b" "crel_vs S c d"
proof -
  have unit_expand: "a = (λ a f. f a) () (λ _. a)" ..
  show ?thesis
    by (subst unit_expand)
       (rule bind_transfer[unfolded rel_fun_def, rule_format, unfolded unit_expand] that)+
qed

lemma crel_vs_iter_and_compute:
  assumes "((=) ===> crel_vs R) g f"
  shows "crel_vs R (g x) (do {iter_heap f y; f x})"
  by (rule
        crel_vs_bind_ignore crel_vs_iterate_state HOL.refl
        assms[unfolded rel_fun_def, rule_format] assms
     )+

(* Prefixing a consistent f with a bottom-up iteration keeps it consistent. *)
lemma consistent_DP_iter_and_compute:
  assumes "consistentDP f"
  shows "consistentDP (λ x. do {iter_heap f y; f x})"
  apply (rule consistentDP_intro)
  using assms unfolding consistentDP_def Rel_def
  by (rule crel_vs_iter_and_compute)

end (* Lifting Syntax *)

end (* DP Consistency Iterator Heap *)

end (* Theory *)

Theory Solve_Cong

subsection ‹Setup for the Heap Monad›

theory Solve_Cong
  imports Main "HOL-Eisbach.Eisbach"
begin

text ‹Method for solving trivial equalities with congruence reasoning›

(* Congruence rules used by solve_cong. *)
named_theorems cong_rules

(* solve_cong tries, in order:
   (1) closing the goal by reflexivity;
   (2) applying a rule from cong_rules and recursing on all new subgoals;
   (3) the supplied method solve, which must close the goal completely. *)
method solve_cong methods solve =
  rule HOL.refl |
  rule cong_rules; solve_cong solve |
  solve; fail

end

Theory Heap_Main

theory Heap_Main
  imports
    "../heap_monad/Memory_Heap"
    "../transform/Transform_Cmd"
    Bottom_Up_Computation_Heap
    "../util/Solve_Cong"
begin

(* Congruence rules for the lifted if and lifted application, registered as
   fundef_cong so that termination proofs see through them. *)
context includes heap_monad_syntax begin

lemma ifT_cong:
  assumes "b = c" "c  x = u" "¬c  y = v"
  shows "Heap_Monad_Ext.ifT b x y = Heap_Monad_Ext.ifT c u v"
  unfolding Heap_Monad_Ext.ifT_def
  unfolding return_bind
  using if_cong[OF assms] .

lemma return_app_return_cong:
  assumes "f x = g y"
  shows "f . x = g . y"
  unfolding Heap_Monad_Ext.return_app_return_meta assms ..

lemmas [fundef_cong] =
  return_app_return_cong
  ifT_cong
end
(* Monadified function composition and its transfer rule. *)
memoize_fun compT: comp monadifies (heap) comp_def
lemma (in dp_consistency_heap) shows compT_transfer[transfer_rule]:
  "crel_vs ((R1 ===>T R2) ===>T (R0 ===>T R1) ===>T (R0 ===>T R2)) comp compT"
  apply memoize_combinator_init
  subgoal premises IH [transfer_rule] by memoize_unfold_defs transfer_prover
  done

(* Monadified map and its transfer rule, by induction on list_all2. *)
memoize_fun mapT: map monadifies (heap) list.map
lemma (in dp_consistency_heap) mapT_transfer[transfer_rule]:
  "crel_vs ((R0 ===>T R1) ===>T list_all2 R0 ===>T list_all2 R1) map mapT"
  apply memoize_combinator_init
  apply (erule list_all2_induct)
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  done

(* Monadified fold and its transfer rule, by induction on list_all2. *)
memoize_fun foldT: fold monadifies (heap) fold.simps
lemma (in dp_consistency_heap) foldT_transfer[transfer_rule]:
  "crel_vs ((R0 ===>T R1 ===>T R1) ===>T list_all2 R0 ===>T R1 ===>T R1) fold foldT"
  apply memoize_combinator_init
  apply (erule list_all2_induct)
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  subgoal premises [transfer_rule] by memoize_unfold_defs transfer_prover
  done

(* Congruence rules for the monadified combinators, registered both as
   fundef_cong and as cong_rules (for solve_cong). *)
context includes heap_monad_syntax begin

lemma mapT_cong:
  assumes "xs = ys" "x. xset ys  f x = g x"
  shows "mapT . f . xs = mapT . g . ys"
  unfolding mapT_def
  unfolding assms(1)
  using assms(2) by (induction ys) (auto simp: Heap_Monad_Ext.return_app_return_meta)

lemma foldT_cong:
  assumes "xs = ys" "x. xset ys  f x = g x"
  shows "foldT . f . xs = foldT . g . ys"
  unfolding foldT_def
  unfolding assms(1)
  using assms(2) by (induction ys) (auto simp: Heap_Monad_Ext.return_app_return_meta)

lemma abs_unit_cong:
  (* for lazy checkmem *)
  assumes "x = y"
  shows "(λ_::unit. x) = (λ_. y)"
  using assms ..


lemma arg_cong4:
  "f a b c d = f a' b' c' d'" if "a = a'" "b = b'" "c = c'" "d = d'"
  by (simp add: that)

lemmas [fundef_cong, cong_rules] =
  return_app_return_cong
  ifT_cong
  mapT_cong
  foldT_cong
  abs_unit_cong
lemmas [cong_rules] =
  arg_cong4[where f = heap_mem_defs.checkmem]
  arg_cong2[where f = fun_app_lifted]
end


(* Matching rules (dp_match_rule) relating original and monadified terms
   under Rel / crel_vs, used when proving consistency of monadified functions. *)
context dp_consistency_heap begin
context includes lifting_syntax heap_monad_syntax begin

named_theorems dp_match_rule

lemma ifT_cong2:
  assumes "Rel (=) b c" "c  Rel (crel_vs R) x xT" "¬c  Rel (crel_vs R) y yT"
  shows "Rel (crel_vs R) (if (Wrap b) then x else y) (Heap_Monad_Ext.ifT c xT yT)"
  using assms unfolding Heap_Monad_Ext.ifT_def bind_left_identity Rel_def Wrap_def
  by (auto split: if_split)

lemma mapT_cong2:
  assumes
    "is_equality R"
    "Rel R xs ys"
    "x. xset ys  Rel (crel_vs S) (f x) (fT' x)"
  shows "Rel (crel_vs (list_all2 S)) (App (App map (Wrap f)) (Wrap xs)) (mapT . fT' . ys)"
  unfolding mapT_def
  unfolding Heap_Monad_Ext.return_app_return_meta
  unfolding assms(2)[unfolded Rel_def assms(1)[unfolded is_equality_def]]
  using assms(3)
  unfolding Rel_def Wrap_def App_def
  apply (induction ys)
  subgoal premises by (memoize_unfold_defs (heap) map) transfer_prover
  subgoal premises prems for a ys
    apply (memoize_unfold_defs (heap) map)
    apply (unfold Heap_Monad_Ext.return_app_return_meta Wrap_App_Wrap)
    supply [transfer_rule] =
      prems(2)[OF list.set_intros(1)]
      prems(1)[OF prems(2)[OF list.set_intros(2)], simplified]
    by transfer_prover
  done

lemma foldT_cong2:
  assumes
    "is_equality R"
    "Rel R xs ys"
    "x. xset ys  Rel (crel_vs (S ===> crel_vs S)) (f x) (fT' x)"
  shows
    "Rel (crel_vs (S ===> crel_vs S)) (fold f xs) (foldT . fT' . ys)"
  unfolding foldT_def
  unfolding Heap_Monad_Ext.return_app_return_meta
  unfolding assms(2)[unfolded Rel_def assms(1)[unfolded is_equality_def]]
  using assms(3)
  unfolding Rel_def
  apply (induction ys)
  subgoal premises by (memoize_unfold_defs (heap) fold) transfer_prover
  subgoal premises prems for a ys
    apply (memoize_unfold_defs (heap) fold)
    apply (unfold Heap_Monad_Ext.return_app_return_meta Wrap_App_Wrap)
    supply [transfer_rule] =
      prems(2)[OF list.set_intros(1)]
      prems(1)[OF prems(2)[OF list.set_intros(2)], simplified]
    by transfer_prover
  done

lemma refl2:
  "is_equality R  Rel R x x"
  unfolding is_equality_def Rel_def by simp

lemma rel_fun2:
  assumes "is_equality R0" "x. Rel R1 (f x) (g x)"
  shows "Rel (rel_fun R0 R1) f g"
  using assms unfolding is_equality_def Rel_def by auto

lemma crel_vs_return_app_return:
  assumes "Rel R (f x) (g x)"
  shows "Rel R (App (Wrap f) (Wrap x)) (g . x)"
  using assms unfolding Heap_Monad_Ext.return_app_return_meta Wrap_App_Wrap .

lemma option_case_cong':
"Rel (=) option' option 
(option = None  Rel R f1 g1) 
(x2. option = Some x2  Rel R (f2 x2) (g2 x2)) 
Rel R (case option' of None  f1 | Some x2  f2 x2)
(case option of None  g1 | Some x2  g2 x2)"
  unfolding Rel_def by (auto split: option.split)

lemma prod_case_cong': fixes prod prod' shows
"Rel (=) prod prod' 
(x1 x2. prod' = (x1, x2)  Rel R (f x1 x2) (g x1 x2)) 
Rel R (case prod of (x1, x2)  f x1 x2)
(case prod' of (x1, x2)  g x1 x2)"
  unfolding Rel_def by (auto split: prod.splits)

lemmas [dp_match_rule] = prod_case_cong' option_case_cong'


lemmas [dp_match_rule] =
  crel_vs_return_app_return

lemmas [dp_match_rule] =
  mapT_cong2
  foldT_cong2
  ifT_cong2

lemmas [dp_match_rule] =
  crel_vs_return
  crel_vs_fun_app
  refl2
  rel_fun2

(*
lemmas [dp_match_rule] =
  crel_vs_checkmem_tupled
*)

end (* context lifting_syntax *)
end (* context dp_consistency *)

subsubsection ‹More Heap›

(* Read off the final heap and the result from a successful execute. *)
lemma execute_heap_ofD:
  "heap_of c h = h'" if "execute c h = Some (v, h')"
  using that by auto

lemma execute_result_ofD:
  "result_of c h = v" if "execute c h = Some (v, h')"
  using that by auto

(* Memory operations parameterised by a memory handle of type 'm, which is
   presumably created by an initialisation action (see below). *)
locale heap_correct_init_defs =
  fixes P :: "'m  heap  bool"
    and lookup :: "'m  'k  'v option Heap"
    and update :: "'m  'k  'v  unit Heap"
begin

(* The abstract map seen through lookup m: the result of running lookup m k
   on heap.  Only meaningful if that execution succeeds ("the" is partial). *)
definition map_of_heap' where
  "map_of_heap' m heap k = fst (the (execute (lookup m k) heap))"

end

(* lookup and update preserve the invariant P m for every handle m. *)
locale heap_correct_init_inv = heap_correct_init_defs +
  assumes lookup_inv: " m. lift_p (P m) (lookup m k)"
  assumes update_inv: " m. lift_p (P m) (update m k v)"

(* Correctness of lookup and update with respect to map_of_heap': the abstract
   map after each operation is related to the one before (for update, the one
   extended with k mapped to v), whenever P holds. *)
locale heap_correct_init =
  heap_correct_init_inv +
  assumes lookup_correct:
      " a. P a m  map_of_heap' a (snd (the (execute (lookup a k) m))) m (map_of_heap' a m)"
  and update_correct:
      " a. P a m 
        map_of_heap' a (snd (the (execute (update a k v) m))) m (map_of_heap' a m)(k  v)"
begin

end

(* Consistency for memories created by an initialisation action init.
   init must succeed on the empty heap, and its result must represent the
   empty map and satisfy P. *)
locale dp_consistency_heap_init = heap_correct_init _ lookup for lookup :: "'m  'k  'v option Heap"  +
  fixes dp :: "'k  'v"
  fixes init :: "'m Heap"
  assumes success: "success init Heap.empty"
  assumes empty_correct:
    " empty heap. execute init Heap.empty = Some (empty, heap)  map_of_heap' empty heap m Map.empty"
    and P_empty: " empty heap. execute init Heap.empty = Some (empty, heap)  P empty heap"
begin

(* The memory handle produced by running init on the empty heap. *)
definition "init_mem = result_of init Heap.empty"

(* Instantiate the generic heap consistency locale at init_mem. *)
sublocale dp_consistency_heap
  where P="P init_mem"
    and lookup="lookup init_mem"
    and update="update init_mem"
  apply standard
       apply (rule lookup_inv[of init_mem])
      apply (rule update_inv[of init_mem])
  subgoal
    unfolding heap_mem_defs.map_of_heap_def
    by (rule lookup_correct[of init_mem, unfolded map_of_heap'_def])
  subgoal
    unfolding heap_mem_defs.map_of_heap_def
    by (rule update_correct[of init_mem, unfolded map_of_heap'_def])
  done

(* The heap left by init serves as the "empty" starting heap. *)
interpretation consistent: dp_consistency_heap_empty
  where P="P init_mem"
    and lookup="lookup init_mem"
    and update="update init_mem"
    and empty= "heap_of init Heap.empty"
  apply standard
  subgoal
    apply (rule successE[OF success])
    apply (frule empty_correct)
    unfolding heap_mem_defs.map_of_heap_def init_mem_def map_of_heap'_def
    by simp
  subgoal
    apply (rule successE[OF success])
    apply (frule P_empty)
    unfolding init_mem_def
    by simp
  done

(* Running init and then dpT on its result computes dp. *)
lemma memoized_empty:
  "dp x = result_of (init  (λmem. dpT mem x)) Heap.empty"
  if "consistentDP (dpT (result_of init Heap.empty))"
  by (simp add: execute_bind_success consistent.memoized[OF that(1)] success)

end

(* Variant of dp_consistency_heap_init with the same assumptions.
   NOTE(review): the sublocale below refers to init_mem before the
   "definition init_mem" further down, and memoized_empty here passes
   init_mem as an extra argument to consistentDP (unlike
   dp_consistency_heap_init) -- confirm this ordering is intended. *)
locale dp_consistency_heap_init' = heap_correct_init _ lookup for lookup :: "'m  'k  'v option Heap"  +
  fixes dp :: "'k  'v"
  fixes init :: "'m Heap"
  assumes success: "success init Heap.empty"
  assumes empty_correct:
    " empty heap. execute init Heap.empty = Some (empty, heap)  map_of_heap' empty heap m Map.empty"
    and P_empty: " empty heap. execute init Heap.empty = Some (empty, heap)  P empty heap"
begin

sublocale dp_consistency_heap
  where P="P init_mem"
    and lookup="lookup init_mem"
    and update="update init_mem"
  apply standard
       apply (rule lookup_inv[of init_mem])
      apply (rule update_inv[of init_mem])
  subgoal
    unfolding heap_mem_defs.map_of_heap_def
    by (rule lookup_correct[of init_mem, unfolded map_of_heap'_def])
  subgoal
    unfolding heap_mem_defs.map_of_heap_def
    by (rule update_correct[of init_mem, unfolded map_of_heap'_def])
  done

definition "init_mem = result_of init Heap.empty"

interpretation consistent: dp_consistency_heap_empty
  where P="P init_mem"
    and lookup="lookup init_mem"
    and update="update init_mem"
    and empty= "heap_of init Heap.empty"
  apply standard
  subgoal
    apply (rule successE[OF success])
    apply (frule empty_correct)
    unfolding heap_mem_defs.map_of_heap_def init_mem_def map_of_heap'_def
    by simp
  subgoal
    apply (rule successE[OF success])
    apply (frule P_empty)
    unfolding init_mem_def
    by simp
  done

lemma memoized_empty:
  "dp x = result_of (init  (λmem. dpT mem x)) Heap.empty"
  if "consistentDP init_mem (dpT (result_of init Heap.empty))"
  by (simp add: execute_bind_success consistent.memoized[OF that(1)] success)

end

(* Consistency phrased directly in terms of dp_consistency_heap_empty: every
   successful result of init yields an empty-memory consistency instance. *)
locale dp_consistency_new =
  fixes dp :: "'k  'v"
  fixes P :: "'m  heap  bool"
    and lookup :: "'m  'k  'v option Heap"
    and update :: "'m  'k  'v  unit Heap"
    and init
  assumes
    success: "success init Heap.empty"
  assumes
    inv_init: " empty heap. execute init Heap.empty = Some (empty, heap)  P empty heap"
  assumes consistent:
    " empty heap. execute init Heap.empty = Some (empty, heap)
     dp_consistency_heap_empty (P empty) (update empty) (lookup empty) heap"
begin

(* Instantiate at the actual result of init on the empty heap. *)
sublocale dp_consistency_heap_empty
  where P="P (result_of init Heap.empty)"
    and lookup="lookup (result_of init Heap.empty)"
    and update="update (result_of init Heap.empty)"
    and empty= "heap_of init Heap.empty"
  using success by (auto 4 3 intro: consistent successE) (* Extract Theorem *)

(* Running init and then dpT on its result computes dp. *)
lemma memoized_empty:
  "dp x = result_of (init  (λmem. dpT mem x)) Heap.empty"
  if "consistentDP (dpT (result_of init Heap.empty))"
  by (simp add: execute_bind_success memoized[OF that(1)] success)

end

(* Like dp_consistency_new, but the memory handle is an explicit parameter mem
   that is assumed to equal the result of init on the empty heap. *)
locale dp_consistency_new' =
  fixes dp :: "'k  'v"
  fixes P :: "'m  heap  bool"
    and lookup :: "'m  'k  'v option Heap"
    and update :: "'m  'k  'v  unit Heap"
    and init
    and mem :: 'm
  assumes mem_is_init: "mem = result_of init Heap.empty"
  assumes
    success: "success init Heap.empty"
  assumes
    inv_init: " empty heap. execute init Heap.empty = Some (empty, heap)  P empty heap"
  assumes consistent:
    " empty heap. execute init Heap.empty = Some (empty, heap)
     dp_consistency_heap_empty (P empty) (update empty) (lookup empty) heap"
begin

sublocale dp_consistency_heap_empty
  where P="P mem"
    and lookup="lookup mem"
    and update="update mem"
    and empty= "heap_of init Heap.empty"
  unfolding mem_is_init
  using success by (auto 4 3 intro: consistent successE) (* Extract Theorem *)

lemma memoized_empty:
  "dp x = result_of (init  (λmem. dpT mem x)) Heap.empty"
  if "consistentDP (dpT (result_of init Heap.empty))"
  by (simp add: execute_bind_success memoized[OF that(1)] success)

end

(* Array-backed memory (created by mem_empty size, indexed via to_index) as an
   instance of dp_consistency_new'.  The invariant is that the array has
   length size. *)
locale dp_consistency_heap_array_new' =
  fixes size :: nat
    and to_index :: "('k :: heap)  nat"
    and mem :: "('v::heap) option array"
    and dp :: "'k  'v::heap"
  assumes mem_is_init: "mem = result_of (mem_empty size) Heap.empty"
  assumes injective: "injective size to_index"
begin

sublocale dp_consistency_new'
  where P      = "λ mem heap. Array.length heap mem = size"
    and lookup = "λ mem. mem_lookup size to_index mem"
    and update = "λ mem. mem_update size to_index mem"
    and init   = "mem_empty size"
    and mem    = mem
  apply (rule dp_consistency_new'.intro)
  subgoal
    by (rule mem_is_init)
  subgoal
    by (rule success_empty)
  subgoal for empty heap
    using length_mem_empty by (metis fst_conv option.sel snd_conv)
  subgoal
    apply (frule execute_heap_ofD[symmetric])
    apply (frule execute_result_ofD[symmetric])
    apply simp
    apply (rule array_consistentI[OF injective HOL.refl])
    done
  done

end

(* Array-backed memory (created by mem_empty size, indexed via to_index) as an
   instance of dp_consistency_new.  The invariant is that the array has
   length size. *)
locale dp_consistency_heap_array_new =
  fixes size :: nat
    and to_index :: "('k :: heap)  nat"
    and dp :: "'k  'v::heap"
  assumes injective: "injective size to_index"
begin

sublocale dp_consistency_new
  where P      = "λ mem heap. Array.length heap mem = size"
    and lookup = "λ mem. mem_lookup size to_index mem"
    and update = "λ mem. mem_update size to_index mem"
    and init   = "mem_empty size"
  apply (rule dp_consistency_new.intro)
  subgoal
    by (rule success_empty)
  subgoal for empty heap
    using length_mem_empty by (metis fst_conv option.sel snd_conv)
  subgoal
    apply (frule execute_heap_ofD[symmetric])
    apply (frule execute_result_ofD[symmetric])
    apply simp
    apply (rule array_consistentI[OF injective HOL.refl])
    done
  done

end

(* Array-backed memory as an instance of dp_consistency_heap_init.  The
   invariant is that the array has length size; to_index must be injective
   (in the sense of "injective size to_index"). *)
locale dp_consistency_heap_array =
  fixes size :: nat
    and to_index :: "('k :: heap)  nat"
    and dp :: "'k  'v::heap"
  assumes injective: "injective size to_index"
begin

(* Obligations in order: lookup/update preserve the invariant, lookup/update
   correctness, success of init, the initial map is empty, and the invariant
   holds initially. *)
sublocale dp_consistency_heap_init
  where P="λmem heap. Array.length heap mem = size"
    and lookup="λ mem. mem_lookup size to_index mem"
    and update="λ mem. mem_update size to_index mem"
    and init="mem_empty size"
  apply standard
  subgoal lookup_inv
    unfolding lift_p_def mem_lookup_def by (simp add: Let_def execute_simps)
  subgoal update_inv
    unfolding State_Heap.lift_p_def mem_update_def by (simp add: Let_def execute_simps)
  subgoal for k heap
    unfolding heap_correct_init_defs.map_of_heap'_def map_le_def mem_lookup_def
    by (auto simp: execute_simps Let_def split: if_split_asm)
  subgoal for heap k
    unfolding heap_correct_init_defs.map_of_heap'_def map_le_def mem_lookup_def mem_update_def
    apply (auto simp: execute_simps Let_def length_def split: if_split_asm)
    apply (subst (asm) nth_list_update_neq)
    using injective[unfolded injective_def] apply auto
    done
  subgoal
    by (rule success_empty)
  subgoal for empty' heap
    unfolding heap_correct_init_defs.map_of_heap'_def mem_lookup_def
    by (auto intro!: map_emptyI simp: Let_def ) (metis fst_conv option.sel snd_conv nth_mem_empty)
  subgoal for empty' heap
    unfolding heap_correct_init_defs.map_of_heap'_def mem_lookup_def map_le_def
    using length_mem_empty by (metis fst_conv option.sel snd_conv)
  done

end


(* Memory made of two arrays, selected by key1 and indexed via key2 /
   to_index, as an instance of dp_consistency_new'.  mem is the tuple of
   references returned by init_state size k1 k2; the pair of keys must
   determine k, and k1 and k2 must differ. *)
locale dp_consistency_heap_array_pair' =
  fixes size :: nat
  fixes key1 :: "'k  ('k1 :: heap)" and key2 :: "'k  'k2 :: heap"
    and to_index :: "'k2  nat"
    and dp :: "'k  'v::heap"
    and k1 k2 :: "'k1"
    and mem :: "('k1 ref ×
             'k1 ref ×
             'v option array ref ×
             'v option array ref)"
  assumes mem_is_init: "mem = result_of (init_state size k1 k2) Heap.empty"
  assumes injective: "injective size to_index"
      and keys_injective: "k k'. key1 k = key1 k'  key2 k = key2 k'  k = k'"
      and keys_neq: "k1  k2"
begin

(* The invariant on a tuple of references: pair_mem_defs.inv_pair built from
   the component lookups, key getters and inv_pair_weak. *)
definition
  "inv_pair' = (λ (k_ref1, k_ref2, m_ref1, m_ref2).
      pair_mem_defs.inv_pair (lookup1 size to_index m_ref1)
        (lookup2 size to_index m_ref2) (get_k1 k_ref1)
        (get_k2 k_ref2)
        (inv_pair_weak size m_ref1 m_ref2 k_ref1 k_ref2) key1 key2)"

sublocale dp_consistency_new'
  where P=inv_pair'
    and lookup="λ (k_ref1, k_ref2, m_ref1, m_ref2).
      lookup_pair size to_index key1 key2 m_ref1 m_ref2 k_ref1 k_ref2"
    and update="λ (k_ref1, k_ref2, m_ref1, m_ref2).
      update_pair size to_index key1 key2 m_ref1 m_ref2 k_ref1 k_ref2"
    and init="init_state size k1 k2"
  apply (rule dp_consistency_new'.intro)
  subgoal
    by (rule mem_is_init)
  subgoal
    by (rule succes_init_state)
  subgoal for empty heap
    unfolding inv_pair'_def
    apply safe
    apply (rule init_state_inv')
        apply (rule injective)
       apply (erule init_state_distinct)
      apply (rule keys_injective)
     apply assumption
    apply (rule keys_neq)
    done
  (* Remaining obligation: consistency of the initial pair memory. *)
  apply safe
  unfolding inv_pair'_def
  apply simp
  apply (rule consistent_empty_pairI)
      apply (rule injective)
     apply (erule init_state_distinct)
    apply (rule keys_injective)
   apply assumption
  apply (rule keys_neq)
  done

end

(* Pair-array memory combined with an iterator over keys (parameter cnt).
   The sublocale specialises everything to the concrete memory `mem` fixed
   by dp_consistency_heap_array_pair': the invariant is inv_pair' applied to
   `mem`, and lookup / update are lookup_pair / update_pair applied to the
   four components of `mem`.
   The function-type and case arrows are written in ASCII (=>), which
   Isabelle accepts as equivalent to the usual Unicode arrow. *)
locale dp_consistency_heap_array_pair_iterator =
  dp_consistency_heap_array_pair' where dp = dp + iterator where cnt = cnt
  for dp :: "'k => 'v::heap" and cnt :: "'k => bool"
begin

sublocale dp_consistency_iterator_heap
  where P = "inv_pair' mem"
  and update = "(case mem of
    (k_ref1, k_ref2, m_ref1, m_ref2) =>
      update_pair size to_index key1 key2 m_ref1 m_ref2 k_ref1 k_ref2)"
  and lookup = "(case mem of
    (k_ref1, k_ref2, m_ref1, m_ref2) =>
      lookup_pair size to_index key1 key2 m_ref1 m_ref2 k_ref1 k_ref2)"
  ..

end


(* Pair memory on the heap, without a fixed initial memory: memories are
   4-tuples (k_ref1, k_ref2, m_ref1, m_ref2) and the initial one is produced
   by init_state size k1 k2.  Keys of type 'k are split by key1 / key2;
   to_index maps key2 values to array indices.  This variant instantiates
   dp_consistency_new; compare dp_consistency_heap_array_pair', which
   additionally fixes `mem` and instantiates dp_consistency_new'.
   NOTE(review): the type arrows and the logical connectives of
   keys_injective and keys_neq appear to have been lost in extraction
   (keys_injective reads `k k'. key1 k = key1 k'  key2 k = key2 k'  k = k'`).
   Restore them from the upstream theory; whether the original used HOL
   forall/and/implies or meta-level quantification/implication cannot be
   determined from this file. *)
locale dp_consistency_heap_array_pair =
  fixes size :: nat
  fixes key1 :: "'k  ('k1 :: heap)" and key2 :: "'k  'k2 :: heap"
    and to_index :: "'k2  nat"
    and dp :: "'k  'v::heap"
    and k1 k2 :: "'k1"
  assumes injective: "injective size to_index"
      and keys_injective: "k k'. key1 k = key1 k'  key2 k = key2 k'  k = k'"
      and keys_neq: "k1  k2"
begin

(* Invariant on a pair memory (k_ref1, k_ref2, m_ref1, m_ref2):
   pair_mem_defs.inv_pair instantiated with the per-component lookups
   (lookup1 on m_ref1, lookup2 on m_ref2), the key getters get_k1 / get_k2,
   and the weak invariant inv_pair_weak.  Same text as the definition in
   dp_consistency_heap_array_pair'. *)
definition
  "inv_pair' = (λ (k_ref1, k_ref2, m_ref1, m_ref2).
      pair_mem_defs.inv_pair (lookup1 size to_index m_ref1)
        (lookup2 size to_index m_ref2) (get_k1 k_ref1)
        (get_k2 k_ref2)
        (inv_pair_weak size m_ref1 m_ref2 k_ref1 k_ref2) key1 key2)"

(* Instantiate dp_consistency_new with the pair memory: P is inv_pair',
   lookup/update destructure the memory tuple and call lookup_pair /
   update_pair, and init is init_state size k1 k2.
   The subgoals are discharged in order by: succes_init_state,
   init_state_inv' (invariant of the initial state), and finally
   consistent_empty_pairI (the initial memory agrees with the empty map);
   the last two both use injective, init_state_distinct, keys_injective
   and keys_neq. *)
sublocale dp_consistency_new
  where P=inv_pair'
    and lookup="λ (k_ref1, k_ref2, m_ref1, m_ref2).
      lookup_pair size to_index key1 key2 m_ref1 m_ref2 k_ref1 k_ref2"
    and update="λ (k_ref1, k_ref2, m_ref1, m_ref2).
      update_pair size to_index key1 key2 m_ref1 m_ref2 k_ref1 k_ref2"
    and init="init_state size k1 k2"
  apply (rule dp_consistency_new.intro)
  subgoal
    by (rule succes_init_state)
  subgoal for empty heap
    unfolding inv_pair'_def
    apply safe
    apply (rule init_state_inv')
        apply (rule injective)
       apply (erule init_state_distinct)
      apply (rule keys_injective)
     apply assumption
    apply (rule keys_neq)
    done
  apply safe
  unfolding inv_pair'_def
  apply simp
  apply (rule consistent_empty_pairI)
      apply (rule injective)
     apply (erule init_state_distinct)
    apply (rule keys_injective)
   apply assumption
  apply (rule keys_neq)
  done

end

subsubsection ‹Code Setup›
(* Code generation setup.
   - The symmetric form of heap_mem_defs.checkmem_checkmem' is registered as
     a code_unfold (preprocessor) rule.
   - The definitions of heap_mem_defs.checkmem' and Heap_Main.mapT are
     registered as code equations.
   NOTE(review): `mapT` may have lost a subscript in extraction; verify the
   constant name against the upstream theory. *)
lemmas [code_unfold] =
  heap_mem_defs.checkmem_checkmem'[symmetric]
declare
  heap_mem_defs.checkmem'_def [code]
  Heap_Main.mapT_def [code]

end (* theory *)

Theory State_Main

subsection ‹Setup for the State Monad›

theory State_Main
  imports
    "../transform/Transform_Cmd"
    Memory
begin

context includes state_monad_syntax begin

(* Congruence rule for the lifted conditional State_Monad_Ext.ifT applied
   to a returned condition.  Suppose the conditions agree, the then-branches
   agree under c, and the else-branches agree under the negation of c.  Then
   the two lifted conditionals are equal.
   Proof idea:
   - unfold ifT;
   - remove the bind of a return (bind_left_identity);
   - reuse HOL's if_cong.
   The returns are written as explicit State_Monad.return terms and the
   meta-implication in ASCII (==>).  This keeps the statement well-formed
   regardless of how the return-bracket notation is rendered. *)
lemma ifT_cong:
  assumes "b = c" "c ==> x = u" "¬c ==> y = v"
  shows "State_Monad_Ext.ifT (State_Monad.return b) x y
       = State_Monad_Ext.ifT (State_Monad.return c) u v"
  unfolding State_Monad_Ext.ifT_def
  unfolding bind_left_identity
  using if_cong[OF assms] .

(* Congruence rule for lifted application of returned values.
   return_app_return_meta reduces (return f) . (return x) to f x, so equality
   of the underlying applications suffices. *)
lemma return_app_return_cong:
  assumes "f x = g y"
  shows "(State_Monad.return f) . (State_Monad.return x)
       = (State_Monad.return g) . (State_Monad.return y)"
  unfolding State_Monad_Ext.return_app_return_meta assms ..

(* Register both rules as fundef_cong rules for the function package.
   The package uses such rules when it determines the context of recursive
   calls. *)
lemmas [fundef_cong] =
  return_app_return_cong
  ifT_cong
end

(* Monadified function composition: memoize_fun generates compT from
   comp_def, with the state monad as the target monad.
   NOTE(review): names and notation here look like they lost subscripts in
   extraction.  `compT` may originally have been comp\<^sub>T, and `===>T`
   may have been ===>\<^sub>T; check against the upstream theory. *)
memoize_fun compT: comp monadifies (state) comp_def
(* Transfer rule relating comp to its monadified version compT under
   crel_vs, stated inside the dp_consistency locale.  The proof:
   - memoize_combinator_init sets up the goal;
   - the resulting premises are made available as transfer rules (IH);
   - memoize_unfold_defs unfolds the definitions;
   - transfer_prover closes the goal. *)
lemma (in dp_consistency) compT_transfer[transfer_rule]:
  "crel_vs ((R1 ===>T R2) ===>T (R0 ===>T R1) ===>T (R0 ===>T R2)) comp compT"
  apply memoize_combinator_init
  subgoal premises IH [transfer_rule] by memoize_unfold_defs transfer_prover
  done

memoize_fun map