Session Refine_Imperative_HOL

Theory Concl_Pres_Clarification

theory Concl_Pres_Clarification
imports Main
begin
  text ‹Clarification and clarsimp that preserve the structure of 
    the subgoal's conclusion, i.e., neither solve it, nor swap it 
    with premises, as, e.g., @{thm [source] notE} does.
    ›

  ML local 
      open Classical

      (* is_cp_brl (is_elim, thm): accept a classical rule only if it is
         "conclusion preserving". An intro rule (is_elim = false) must have
         exactly one premise and an elim rule exactly two. In both cases the
         conclusion of the last premise (after Logic.strip_assums_concl) must
         be alpha-equivalent (aconv) to the rule's conclusion. Per the theory
         text, the intent is that such steps do not swap the subgoal's
         conclusion with a premise. *)
      fun is_cp_brl (is_elim,thm) = let
        val prems = Thm.prems_of thm
        val nprems = length prems
        val concl = Thm.concl_of thm
      in
        (* the premise count is checked first, so hd (rev prems) below is
           never applied to an empty list *)
        (if is_elim then nprems=2 else nprems=1) andalso let
          val lprem_concl = hd (rev prems)
            |> Logic.strip_assums_concl
        in
          concl aconv lprem_concl
        end
      end

      val not_elim = @{thm notE}
      val hyp_subst_tacs = [Hypsubst.hyp_subst_tac]

      (* Contradiction: eliminate a negated assumption with notE (by
         matching), then close the remaining subgoal with eq_assume_tac. *)
      fun eq_contr_tac ctxt i = ematch_tac ctxt [not_elim] i THEN eq_assume_tac i;
      (* Close the subgoal by an identical assumption, or by contradiction. *)
      fun eq_assume_contr_tac ctxt = eq_assume_tac ORELSE' eq_contr_tac ctxt;

      (* Biresolution from a net pair, restricted to rules accepted by
         is_cp_brl (then ordered by order_list). The `true` argument selects
         matching, as the "bimatch" in the name indicates. *)
      fun cp_bimatch_from_nets_tac ctxt =
        biresolution_from_nets_tac ctxt (order_list o filter (is_cp_brl o snd)) true;


    in
      (* One conclusion-preserving clarification step on a subgoal, wrapped
         by the claset's safe wrappers (appSWrappers). Tried in order:
         assumption/contradiction, hypothesis substitution, and biresolution
         with the conclusion-preserving rules of the safe net pair. *)
      fun cp_clarify_step_tac ctxt =
        let val {safep_netpair, ...} = (rep_cs o claset_of) ctxt in
          appSWrappers ctxt
           (FIRST'
             [eq_assume_contr_tac ctxt,
              FIRST' (map (fn tac => tac ctxt) hyp_subst_tacs),
              cp_bimatch_from_nets_tac ctxt safep_netpair
              ])
        end;
      
        (* Apply cp_clarify_step_tac repeatedly (REPEAT_DETERM, no
           backtracking) to the selected subgoal. *)
        fun cp_clarify_tac ctxt = SELECT_GOAL (REPEAT_DETERM (cp_clarify_step_tac ctxt 1));

        (* Safe simplification using the assumptions, followed by
           cp_clarify_tac on every resulting subgoal. The clarification
           runs with `addSss ctxt`; presumably, as in clarsimp, this makes
           the safe steps simplifier-aware (TODO confirm). *)
        fun cp_clarsimp_tac ctxt =
          Simplifier.safe_asm_full_simp_tac ctxt THEN_ALL_NEW
          cp_clarify_tac (addSss ctxt);


    end

  (* cp_clarify: clarify restricted to conclusion-preserving steps. It
     accepts the classical-reasoner modifiers (cla_method'). CHANGED_PROP
     makes it fail when the goal proposition is unchanged. *)
  method_setup cp_clarify = (Classical.cla_method' (CHANGED_PROP oo cp_clarify_tac))

  (* cp_clarsimp: clarsimp restricted to conclusion-preserving steps. It
     accepts the classical and simplifier modifiers (clasimp_modifiers) and
     fails when the goal proposition is unchanged. *)
  method_setup cp_clarsimp = let
    fun clasimp_method' tac =
      Method.sections clasimp_modifiers >> K (SIMPLE_METHOD' o tac);
  in
    clasimp_method' (CHANGED_PROP oo cp_clarsimp_tac)
  end



end

Theory Named_Theorems_Rev

theory Named_Theorems_Rev 
imports Main
keywords "named_theorems_rev" :: thy_decl
begin

ML signature NAMED_THEOREMS_REV =
sig
  (* Named theorem collections, addressed by their full internal name.
     Operations on an undeclared name fail with an "Undeclared named
     theorems" error. *)
  val member: Proof.context -> string -> thm -> bool
  (* theorems of a collection, in Item_Net.content order *)
  val get: Proof.context -> string -> thm list
  val add_thm: string -> thm -> Context.generic -> Context.generic
  val del_thm: string -> thm -> Context.generic -> Context.generic
  (* declaration attributes performing add_thm / del_thm *)
  val add: string -> attribute
  val del: string -> attribute
  (* resolve a user-written name (with position for errors) to a declared
     collection's internal name *)
  val check: Proof.context -> string * Position.T -> string
  (* declare a new collection with a description; returns its full name *)
  val declare: binding -> string -> local_theory -> string * local_theory
end;

structure Named_Theorems_Rev: NAMED_THEOREMS_REV =
struct

(* context data *)

(* Maps each declared collection name to an Item_Net of its theorems. *)
structure Data = Generic_Data
(
  type T = thm Item_Net.T Symtab.table;
  val empty: T = Symtab.empty;
  val extend = I;
  val merge : T * T -> T = Symtab.join (K Item_Net.merge);
);

(* Register an empty collection (a Thm.full_rules net) under name.
   Declaring the same name twice is an error. *)
fun new_entry name =
  Data.map (fn data =>
    if Symtab.defined data name
    then error ("Duplicate declaration of named theorems: " ^ quote name)
    else Symtab.update (name, Thm.full_rules) data);

fun undeclared name = "Undeclared named theorems " ^ quote name;

(* The net of collection name; error if it was never declared. *)
fun the_entry context name =
  (case Symtab.lookup (Data.get context) name of
    NONE => error (undeclared name)
  | SOME entry => entry);

(* Apply f to the net of collection name. the_entry is evaluated only for
   its "undeclared" error; its result is discarded. *)
fun map_entry name f context =
  (the_entry context name; Data.map (Symtab.map_entry name f) context);


(* maintain content *)

fun member ctxt = Item_Net.member o the_entry (Context.Proof ctxt);

(* Theorems of a collection, in Item_Net.content order (no rev applied).
   NOTE(review): the "_Rev" suffix presumably refers to this order being
   reversed relative to Isabelle's Named_Theorems; confirm. *)
fun content context = Item_Net.content o the_entry context;
val get = content o Context.Proof;

fun add_thm name = map_entry name o Item_Net.update;
fun del_thm name = map_entry name o Item_Net.remove;

val add = Thm.declaration_attribute o add_thm;
val del = Thm.declaration_attribute o del_thm;


(* check *)

(* Resolve xname via fact lookup. The resulting fact name must also be a
   declared collection; otherwise report an "undeclared" error at pos. *)
fun check ctxt (xname, pos) =
  let
    val context = Context.Proof ctxt;
    val fact_ref = Facts.Named ((xname, Position.none), NONE);
    fun err () = error (undeclared xname ^ Position.here pos);
  in
    (case try (Proof_Context.get_fact_generic context) fact_ref of
      SOME (SOME name, _) => if can (the_entry context) name then name else err ()
    | _ => err ())
  end;


(* declaration *)

(* Declare a new collection:
   - create an empty entry in the background theory and in all local
     contexts;
   - expose it as a dynamic fact named by binding;
   - set up an add/del attribute of the same name.
   The description defaults to "declaration of <name> rules". *)
fun declare binding descr lthy =
  let
    val name = Local_Theory.full_name lthy binding;
    val description =
      "declaration of " ^ (if descr = "" then Binding.name_of binding ^ " rules" else descr);
    val lthy' = lthy
      |> Local_Theory.background_theory (Context.theory_map (new_entry name))
      |> Local_Theory.map_contexts (K (Context.proof_map (new_entry name)))
      |> Local_Theory.add_thms_dynamic (binding, fn context => content context name)
      |> Attrib.local_setup binding (Attrib.add_del (add name) (del name)) description
  in (name, lthy') end;

(* Outer-syntax command: named_theorems_rev b1 "descr1" and b2 ... *)
val _ =
  Outer_Syntax.local_theory @{command_keyword named_theorems_rev}
    "declare named collection of theorems"
    (Parse.and_list1 (Parse.binding -- Scan.optional Parse.text "") >>
      fold (fn (b, descr) => snd o declare b descr));


(* ML antiquotation *)

(* @{named_theorems_rev name} inlines the checked internal name as an ML
   string literal. *)
val _ = Theory.setup
  (ML_Antiquotation.inline @{binding named_theorems_rev}
    (Args.context -- Scan.lift Args.name_position >>
      (fn (ctxt, name) => ML_Syntax.print_string (check ctxt name))));

end;

end

Theory Pf_Add

theory Pf_Add
imports Automatic_Refinement.Misc "HOL-Library.Monad_Syntax"
begin

(* fun_ordI and fun_ordD together show that fun_ord ord relates two
   functions exactly when ord relates them pointwise. *)
lemma fun_ordI:
  assumes "⋀x. ord (f x) (g x)"
  shows "fun_ord ord f g"
  using assms unfolding fun_ord_def by auto

lemma fun_ordD:
  assumes "fun_ord ord f g"
  shows "ord (f x) (g x)"
  using assms unfolding fun_ord_def by auto

(* A curried functional F is monotone into functions if it is monotone
   for every fixed second argument d. *)
lemma mono_fun_fun_cnv:
  assumes "⋀d. monotone (fun_ord ordA) ordB (λx. F x d)"
  shows "monotone (fun_ord ordA) (fun_ord ordB) F"
  apply rule
  apply (rule fun_ordI)
  using assms
  by (blast dest: monotoneD)

(* For the complete-lattice operations, fun_lub and fun_ord coincide with
   the pointwise Sup and ≤ on functions. *)
lemma fun_lub_Sup[simp]: "fun_lub Sup = Sup"
  unfolding fun_lub_def[abs_def]
  by (clarsimp intro!: ext; metis image_def)

lemma fun_ord_le[simp]: "fun_ord (≤) = (≤)"
  unfolding fun_ord_def[abs_def]
  by (auto intro!: ext simp: le_fun_def)

end

Theory Pf_Mono_Prover

section ‹Interfacing Partial-Function's Monotonicity Prover›
theory Pf_Mono_Prover
imports Separation_Logic_Imperative_HOL.Sep_Main
begin
  (* TODO: Adjust mono-prover accordingly  *)
  (* Wraps the mono-prover of the partial-function package so that it first
    erases all premises. This is a workaround for mono_tac, which does not
    accept premises if the case-split rule is applied. *)

ML ‹
  structure Pf_Mono_Prover = struct
    (* Discard every premise of the subgoal (repeated thin_rl elimination),
       then run Partial_Function.mono_tac on the premise-free goal. *)
    fun mono_tac ctxt = (REPEAT o eresolve_tac ctxt @{thms thin_rl})
      THEN' Partial_Function.mono_tac ctxt
  end
›

method_setup pf_mono = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Pf_Mono_Prover.mono_tac ctxt))›
  ‹Monotonicity prover of the partial function package›

end

Theory PO_Normalizer

theory PO_Normalizer
imports Automatic_Refinement.Refine_Lib
begin
  ML_file ‹PO_Normalizer.ML›
end

File ‹PO_Normalizer.ML›

(* Normalization of goals "R lhs ?rhs" by transitivity, congruence,
   normalization and reflexivity rules. *)
signature PO_NORMALIZER = sig 
  type norm_set = {
    trans_rules : thm list, (* Transitivity rules, of form "R x y ⟹ R y z ⟹ R x z" *)
    cong_rules : thm list, (* Congruence rules, of form: "⟦ R1 a1 b1; ... ⟧ ⟹ R (f a1 ...) (f b1 ...)" *)
    norm_rules : thm list, (* Normalization rules, of form: "R f g" *)
    refl_rules : thm list (* Reflexivity rules, of form: "R x x"*)
  }

  (* tactic that normalizes a subgoal using the given rule set *)
  val gen_norm_tac : norm_set -> Proof.context -> tactic'
  (* gen_norm_rule init_thms norm_set ctxt thm: transform thm by eliminating
     its conclusion with one of init_thms and normalizing the result *)
  val gen_norm_rule : thm list -> norm_set -> Proof.context -> thm -> thm
end

structure PO_Normalizer : PO_NORMALIZER = struct
  type norm_set = {
    trans_rules : thm list, (* Transitivity rules, of form "R x y ⟹ R y z ⟹ R x z" *)
    cong_rules : thm list, (* Congruence rules, of form: "⟦ R1 a1 b1; ... ⟧ ⟹ R (f a1 ...) (f b1 ...)" *)
    norm_rules : thm list, (* Normalization rules, of form: "R f g" *)
    refl_rules : thm list (* Reflexivity rules, of form: "R x x"*)
  }

  (* [[norm_rel_trace]]: print tracing messages during normalization *)
  val cfg_trace = 
    Attrib.setup_config_bool @{binding "norm_rel_trace"} (K false)

  (* [[norm_rel_depth_limit = n]]: maximal recursion depth; a negative value
     (default ~1) means no limit *)
  val cfg_depth_limit = 
    Attrib.setup_config_int @{binding "norm_rel_depth_limit"} (K ~1)


  fun gen_norm_tac {trans_rules, cong_rules, norm_rules, refl_rules} ctxt = let
    val do_trace = Config.get ctxt cfg_trace

    (* always succeeds with the unchanged state; prints str if tracing *)
    fun trace_tac str _ st = if do_trace then 
      (tracing str; Seq.single st)
    else Seq.single st
    (* shadows the library print_tac: a no-op unless tracing is enabled *)
    val print_tac = if do_trace then print_tac else (K (K all_tac))

    val depth_limit = Config.get ctxt cfg_depth_limit

    (* Normalize subgoal i at recursion depth d. The steps are:
       - optionally, a transitivity step followed by congruence, with
         recursive normalization (depth d+1) of the new subgoals;
       - optionally, a transitivity step whose first part is solved by a
         normalization rule, followed by another transitivity step and
         recursive normalization;
       - finally, a reflexivity rule must solve the goal.
       Beyond the depth limit, only the reflexivity step is attempted. *)
    fun norm_tac d ctxt i st = let
      val transr_tac = resolve_tac ctxt trans_rules
      val congr_tac = resolve_tac ctxt cong_rules
      val rewrr_tac = resolve_tac ctxt norm_rules
      val solver_tac = resolve_tac ctxt refl_rules

      val cong_tac = (transr_tac THEN' (
        (congr_tac THEN' trace_tac "cong") THEN_ALL_NEW_FWD norm_tac (d+1) ctxt))
      val rewr_tac = (transr_tac THEN' (SOLVED' rewrr_tac) 
        THEN' trace_tac "rewr" THEN' transr_tac THEN' norm_tac (d+1) ctxt)
      val solve_tac = SOLVED' solver_tac THEN' (K (print_tac ctxt "solved"))
    in 
      if depth_limit>=0 andalso d>depth_limit then
        (K (print_tac ctxt "Norm-Depth limit reached"))
        THEN' solve_tac
      else
        (K (print_tac ctxt ("Normalizing ("^ string_of_int d  ^")"))) THEN'
        (TRY o cong_tac)
        THEN' (TRY o rewr_tac)
        THEN' solve_tac
    end i st
  in norm_tac 1 ctxt end

  (* Import thm and prove "PROP concl ⟹ PROP ?x" (with concl the conclusion
     of thm) by eliminating with init_thms and then gen_norm_tac. thm is
     resolved with this auxiliary theorem (RS), and the result is exported
     back to the original context. *)
  fun gen_norm_rule init_thms norm_set ctxt thm = let
    val orig_ctxt = ctxt
    val ((_,[thm]),ctxt) = Variable.import false [thm] ctxt

    fun tac ctxt = 
      eresolve_tac ctxt init_thms
      THEN' gen_norm_tac norm_set ctxt

    (* concl and x are referenced as ?concl / ?x by the mk_term antiquotation *)
    val concl = Thm.concl_of thm
    val x = Var (("x",0),@{typ prop})
    val t = @{mk_term "PROP ?concl  PROP ?x"}

    val thm2 = Goal.prove ctxt [] [] t 
      (fn {context = ctxt, ...} => tac ctxt 1)
    
    val thm = thm RS thm2 
    val [thm] = Variable.export ctxt orig_ctxt [thm]
  in
    thm
  end
  
end

Theory Sepref_Misc

theory Sepref_Misc
imports 
  Refine_Monadic.Refine_Monadic
  PO_Normalizer
  "List-Index.List_Index"
  Separation_Logic_Imperative_HOL.Sep_Main
  Named_Theorems_Rev
  "HOL-Eisbach.Eisbach"
  Separation_Logic_Imperative_HOL.Array_Blit
begin

  (* Hide the unqualified name CONSTRAINT; the constant remains accessible
     by its qualified name. *)
  hide_const (open) CONSTRAINT

  (* Additions for List_Index *)  
  (* In a distinct list, the last element sits at position length - 1. *)
  lemma index_of_last_distinct[simp]: 
    "distinct l  index l (last l) = length l - 1"  
    apply (cases l rule: rev_cases)
    apply (auto simp: index_append)
    done

  (* index returns the length exactly for elements that are not in the list.
     NOTE(review): logical connectives appear lost in this copy of the
     statement; confirm against the original. *)
  lemma index_eqlen_conv[simp]: "index l x = length l  xset l"
    by (auto simp: index_size_conv)


  subsection ‹Iterated Curry and Uncurry›    


  text ‹Uncurry0›  
  (* uncurry0 c ignores its unit argument and returns c; curry0 f applies f
     to (). The two are mutually inverse (curry_uncurry0_id,
     uncurry_curry0_id). *)
  definition "uncurry0 c  λ_::unit. c"
  definition curry0 :: "(unit  'a)  'a" where "curry0 f = f ()"
  lemma uncurry0_apply[simp]: "uncurry0 c x = c" by (simp add: uncurry0_def)

  lemma curry_uncurry0_id[simp]: "curry0 (uncurry0 f) = f" by (simp add: curry0_def)
  lemma uncurry_curry0_id[simp]: "uncurry0 (curry0 g) = g" by (auto simp: curry0_def)
  (* parametricity rule for uncurry0, registered with the [param] attribute *)
  lemma param_uncurry0[param]: "(uncurry0,uncurry0)  A  (unit_relA)" by auto
    
  text ‹Abbreviations for higher-order uncurries›    
  (* uncurryN / curryN: N-fold nesting of uncurry / curry, defined
     inductively from the (N-1)-fold versions. *)
  abbreviation "uncurry2 f  uncurry (uncurry f)"
  abbreviation "curry2 f  curry (curry f)"
  abbreviation "uncurry3 f  uncurry (uncurry2 f)"
  abbreviation "curry3 f  curry (curry2 f)"
  abbreviation "uncurry4 f  uncurry (uncurry3 f)"
  abbreviation "curry4 f  curry (curry3 f)"
  abbreviation "uncurry5 f  uncurry (uncurry4 f)"
  abbreviation "curry5 f  curry (curry4 f)"
  abbreviation "uncurry6 f  uncurry (uncurry5 f)"
  abbreviation "curry6 f  curry (curry5 f)"
  abbreviation "uncurry7 f  uncurry (uncurry6 f)"
  abbreviation "curry7 f  curry (curry6 f)"
  abbreviation "uncurry8 f  uncurry (uncurry7 f)"
  abbreviation "curry8 f  curry (curry7 f)"
  abbreviation "uncurry9 f  uncurry (uncurry8 f)"
  abbreviation "curry9 f  curry (curry8 f)"

    
    
  (* An uncurry applied to a tuple-splitting lambda is uncurry2. *)
  lemma fold_partial_uncurry: "uncurry (λ(ps, cf). f ps cf) = uncurry2 f" by auto

  (* curry_shl / curry_shr move a curry (or curry0) from one side of an
     equation to the other, as an uncurry (uncurry0) on the opposite side.
     uncurry_shl / uncurry_shr are the symmetric versions. *)
  lemma curry_shl: 
    "g f. (g  curry f)  (uncurry g  f)"
    "g f. (g  curry0 f)  (uncurry0 g  f)"
    by (atomize (full); auto)+
  
  lemma curry_shr: 
    "f g. (curry f  g)  (f  uncurry g)"
    "f g. (curry0 f  g)  (f  uncurry0 g)"
    by (atomize (full); auto)+
  
  lemmas uncurry_shl = curry_shr[symmetric]  
  lemmas uncurry_shr = curry_shl[symmetric]  
  
end

Theory Structured_Apply

section ‹Subgoal Structure for Apply Scripts›
theory Structured_Apply
imports Main
keywords 
  "focus" "solved" "applyS" "apply1" "applyF" "applyT" :: prf_script
begin

text ‹This theory provides some variants of the apply command 
  that make the proof structure explicit. See below for examples.

  Compared to the @{command subgoal} command, this set of commands is more lightweight,
  and fully supports schematic variables.
›

(*
  focus, focus <method text>, applyF <method text>
    Focus on current subgoal, and then (optionally) apply method. applyF m is a synonym for focus m.

  solved
    Assert that subgoal is solved and release focus.

  applyT <method text>
    Apply method to current subgoal only. Same as apply m [].

  applyS <method text>
    Apply method to current subgoal, and assert that subgoal is solved.
    "applyS m" is roughly equal to "focus m solved"

  apply1 <method text>
    Apply method to current subgoal, and assert that there is exactly one resulting subgoal.

*)

ML signature STRUCTURED_APPLY = sig
  (* restrict the goal to its first subgoal *)
  val focus: Proof.state -> Proof.state
  (* fail unless the focused subgoal is solved, then release the focus *)
  val solved: Proof.state -> Proof.state
  (* release the focus without any check *)
  val unfocus: Proof.state -> Proof.state

  (* focus, apply method, require the number of subgoals to be unchanged, unfocus *)
  val apply1: Method.text_range -> Proof.state -> Proof.state Seq.result Seq.seq
  (* apply method to the first subgoal only *)
  val applyT: Method.text * Position.range -> Proof.state -> Proof.state Seq.result Seq.seq
  (* focus, then apply method *)
  val apply_focus: Method.text_range -> Proof.state -> Proof.state Seq.result Seq.seq
  (* focus, apply method, require the subgoal to be solved, unfocus *)
  val apply_solve: Method.text_range -> Proof.state -> Proof.state Seq.result Seq.seq
end

structure Structured_Apply: STRUCTURED_APPLY = struct
  (* Goal.restrict 1 1: expose only subgoal 1 *)
  val focus = Proof.refine_primitive (K (Goal.restrict 1 1))
  val unfocus = Proof.refine_primitive (K (Goal.unrestrict 1))
  (* error "Subgoal not solved" if the (focused) goal still has premises *)
  val solved = Proof.refine_primitive (fn _ => fn thm => let
      val _ = if Thm.nprems_of thm > 0 then error "Subgoal not solved" else ()
    in
      Goal.unrestrict 1 thm
    end
  )

  fun apply_focus m = focus #> Proof.apply m

  (* Apply method m to state s. Each successful result must have exactly d
     fewer subgoals than s; otherwise fail with error msg. *)
  fun assert_num_solved d msg m s = let
    val n_subgoals = Proof.raw_goal #> #goal #> Thm.nprems_of
    val n1 = n_subgoals s

    fun do_assert s = if n1 - n_subgoals s <> d then error msg else s
  in
    s 
    |> Proof.apply m
    |> Seq.map_result do_assert
  end

  (* within the focus, one subgoal must disappear, i.e. it must be solved *)
  fun apply_solve m = 
      focus 
    #> assert_num_solved 1 "Subgoal not solved" m
    #> Seq.map_result unfocus

  (* within the focus, the number of subgoals must stay at one *)
  fun apply1 m = 
      focus 
    #> assert_num_solved 0 "Method must not produce or solve subgoals" m 
    #> Seq.map_result unfocus

  (* wrap m in the Select_Goals 1 combinator, i.e. "m[1]" *)
  fun applyT (m,pos) = let
    open Method
    val m = Combinator (no_combinator_info, Select_Goals 1, [m])
  in
    Proof.apply (m,pos)
  end  


end

(* Outer-syntax commands; each one delegates to the Structured_Apply
   function of the same meaning. Method.report adds markup for the parsed
   method text. *)

(* solved *)
val _ =
  Outer_Syntax.command @{command_keyword solved} "Primitive unfocus after subgoal is solved"
    (Scan.succeed ( Toplevel.proof (Structured_Apply.solved) ));

(* focus, optionally followed by a method *)
val _ =
  Outer_Syntax.command @{command_keyword focus} "Primitive focus then optionally apply method"
    (Scan.option Method.parse >> (fn 
        NONE => Toplevel.proof (Structured_Apply.focus)
      | SOME m => (Method.report m; Toplevel.proofs (Structured_Apply.apply_focus m))
    ));

(* applyF m: same as focus m *)
val _ =
  Outer_Syntax.command @{command_keyword applyF} "Primitive focus then apply method"
    (Method.parse >> (fn m => (Method.report m; 
      Toplevel.proofs (Structured_Apply.apply_focus m)
    )));

(* applyS m *)
val _ =
  Outer_Syntax.command @{command_keyword applyS} "Apply method that solves exactly one subgoal"
    (Method.parse >> (fn m => (Method.report m; 
      Toplevel.proofs (Structured_Apply.apply_solve m) 
    )));

(* apply1 m *)
val _ =
  Outer_Syntax.command @{command_keyword apply1} "Apply method that does not change number of subgoals"
    (Method.parse >> (fn m => (Method.report m; 
      Toplevel.proofs (Structured_Apply.apply1 m) 
    )));

(* applyT m *)
val _ =
  Outer_Syntax.command @{command_keyword applyT} "Apply method on first subgoal"
    (Method.parse >> (fn m => (Method.report m; 
      Toplevel.proofs (Structured_Apply.applyT m) 
    )));


end

Theory Term_Synth

section ‹Rule-Based Synthesis of Terms›
theory Term_Synth
imports Sepref_Misc
begin
  (* SYNTH_TERM x y is logically trivial (it is defined as True and is a
     simp rule). It only serves as a goal format for rule-based synthesis
     of y from x. *)
  definition SYNTH_TERM :: "'a::{}  'b::{}  bool"
    ― ‹Indicate synthesis of @{term y} from @{term x}.›
    where [simp]: "SYNTH_TERM x y  True"
  consts SDUMMY :: "'a :: {}"
    ― ‹After synthesis has been completed, these are replaced by fresh schematic variables›

  named_theorems_rev synth_rules ‹Term synthesis rules›

  text ‹Term synthesis works by proving @{term "SYNTH_TERM t v"}, by repeatedly applying the 
    first matching intro-rule from ‹synth_rules›.  ›


ML signature TERM_SYNTH = sig
    (* synth_term thms ctxt t: synthesize a term from t by proving
      SYNTH_TERM t ?v. The given theorems are placed before the synth_rules
      collection and can be used to install additional intro-rules for
      SYNTH_TERM. Remaining SDUMMY constants in the result are replaced by
      fresh schematic variables. *)
    val synth_term: thm list -> Proof.context -> term -> term
  end


  structure Term_Synth : TERM_SYNTH = struct

    (* Replace every SDUMMY constant by a fresh schematic variable
      ?_dummy<n>, numbered from 0 in left-to-right order. The term must not
      already contain Vars whose name starts with "_"; if it does, exception
      TERM is raised. *)
    fun replace_sdummies t = let
      fun r (t1$t2) n = let
              val (t1,n) = r t1 n
              val (t2,n) = r t2 n
            in (t1$t2,n) end
        | r (Abs (x,T,t)) n = let
              val (t,n) = r t n
            in (Abs (x,T,t),n) end
        | r @{mpat (typs) "SDUMMY::?'v_T"} n = (Var (("_dummy",n),T),n+1)
        | r (t' as (Var ((name,_),_))) n = if String.isPrefix "_" name then raise TERM ("replace_sdummies: Term already contains dummy patterns",[t',t]) else (t',n)
        | r t n = (t,n)
    in
      fst (r t 0)
    end    

    (* Use synthesis rules to transform the given term: prove
      SYNTH_TERM t ?v with the rule net (thms first, then synth_rules), read
      off the instantiated ?v, export it to the original context and replace
      SDUMMYs. Raises THM if the proved conclusion does not have the expected
      shape. *)
    fun synth_term thms ctxt t = let
      val orig_ctxt = ctxt
      val (t,ctxt) = yield_singleton (Variable.import_terms true) t ctxt
      (* v is referenced as ?v by the mk_term antiquotation *)
      val v = Var (("result",0),TVar (("T",0),[]))
      val goal = @{mk_term "Trueprop (SYNTH_TERM ?t ?v)"} |> Thm.cterm_of ctxt
  
      val rules = thms @ Named_Theorems_Rev.get ctxt @{named_theorems_rev synth_rules}
        |> Tactic.build_net
      (* on every goal, resolve with the net deterministically as long as the
         state changes (see TRY_SOLVED' / REPEAT_DETERM' in the project
         library for the exact failure behaviour) *)
      fun tac ctxt = ALLGOALS (TRY_SOLVED' (
        REPEAT_DETERM' (CHANGED o resolve_from_net_tac ctxt rules)))
      
      val thm = Goal.prove_internal ctxt [] goal (fn _ => tac ctxt)

      val res = case Thm.concl_of thm of
          @{mpat "Trueprop (SYNTH_TERM _ ?res)"} => res 
        | _ => raise THM("Synth_Term: Proved a different theorem?",~1,[thm])

      val res = singleton (Variable.export_terms ctxt orig_ctxt) res
        |> replace_sdummies
  
    in
      res
    end
  end



end

Theory User_Smashing

theory User_Smashing
  imports Pure
begin
(* Alternative flex-flex smasher by Simon Wimmer *)
(* enumerate [x0, ..., xk] = [(xk, k), ..., (x0, 0)]: pair each element with
   its position; the pairs come out in reverse order. *)
ML fun enumerate xs = fold (fn x => fn (i, xs) => (i +1, (x, i) :: xs)) xs (0, []) |> snd
›

(* dummy_abs n Ts t: wrap t in one abstraction per type in Ts (the first
   type becomes the outermost binder), naming the binders "x<n>",
   "x<n+1>", ... . t is not lifted, so its loose bound variables are bound
   by the new abstractions. *)
ML fun dummy_abs _ [] t = t
    | dummy_abs n (T :: Ts) t = Abs ("x" ^ Int.toString n, T, dummy_abs (n + 1) Ts t)

ML ‹
  (* common_prefix Ts t u: strip the leading abstractions that t and u share
     (with equal binder types). Stripping stops at the first position where
     the binder types differ or one side is not an abstraction. Returns
     (Ts', t', u'), where:
     - Ts' is Ts extended by the stripped binder types, innermost first (the
       order expected by fastype_of1);
     - t', u' are the remaining terms. *)
  fun common_prefix Ts (t1 as Abs (_, T, t)) (u1 as Abs (_, U, u)) =
      if U = T then common_prefix (T :: Ts) t u
      (* Differing binder types: keep the prefix collected so far, just like
         the clause below. Returning [] here would drop the types of the
         loose bound variables of t1/u1. For a well-typed pair of equal type
         this branch cannot normally be reached; it is defensive. *)
      else (Ts, t1, u1)
    | common_prefix Ts t u = (Ts, t, u);

  (* dest_app acc t: split t into its head and argument list, with the
     arguments prepended to acc. *)
  fun dest_app acc (t $ u) = dest_app (u :: acc) t
    | dest_app acc t = (t, acc);

  (* add_bound (Bound i, n) bs = (i, n) :: bs; other terms leave bs as is. *)
  fun add_bound (Bound i, n) bs = (i, n) :: bs
    | add_bound _ bs = bs;
›

ML ‹
  (* smash_pair ctxt thm (t, u): eliminate the flex-flex pair (t, u) of thm.

     Both sides must be (possibly abstracted) applications of schematic
     variables, ?F ts and ?G us; otherwise the Var patterns below fail with
     exception Bind.

     ?F and ?G are instantiated with terms built from a fresh variable
     ?simon (index maxidx + 1). ?simon is applied only to those bound
     variables of the common binder prefix that occur as direct arguments
     on both sides. The instantiated theorem is normalized by
     instantiate_normalize. *)
  fun smash_pair ctxt thm (t, u) =
    let
      val idx = Thm.maxidx_of thm + 1;
      val (Ts, t1, _) = common_prefix [] t u;
      val (tas, t2) = Term.strip_abs t;
      val (uas, u2) = Term.strip_abs u;
      val (tx as Var (_, T1), ts) = Term.strip_comb t2;
      val (ux as Var (_, U1), us) = Term.strip_comb u2;
      val Ts1 = Term.binder_types T1;
      val Us1 = Term.binder_types U1;
      val T = Term.fastype_of1 (Ts, t1);
      (* number of abstractions of each side beyond the common prefix *)
      val tshift = length tas - length Ts;
      val ushift = length uas - length Ts;
      (* Collect the arguments that are bound variables, as pairs of:
         - the bound index, shifted to be relative to the common prefix;
         - the loose-bound index the argument receives under dummy_abs. *)
      val tbs = fold add_bound (enumerate (rev ts)) [] |> map (apfst (fn i => i - tshift));
      val ubs = fold add_bound (enumerate (rev us)) [] |> map (apfst (fn i => i - ushift));
      (* bound variables passed to both sides, without duplicates *)
      val bounds = inter (op =) (map fst tbs) (map fst ubs) |> distinct (op =);
      val T' = map (nth Ts) bounds ---> T;
      val v = Var (("simon", idx), T');
      val tbs' = map (fn i => find_first (fn (j, _) => i = j) tbs |> the |> snd |> Bound) bounds;
      val t' = list_comb (v, tbs') |> dummy_abs 0 Ts1;
      (* Need to add bounds for superfluous abstractions here *)
      val ubs' = map (fn i => find_first (fn (j, _) => i = j) ubs |> the |> snd |> Bound) bounds;
      val u' = list_comb (v, ubs') |> dummy_abs 0 Us1;
      val subst = [(Term.dest_Var tx, Thm.cterm_of ctxt t'), (Term.dest_Var ux, Thm.cterm_of ctxt u')];
    in
      instantiate_normalize ([], subst) thm
    end;

  (* Smash the first flex-flex pair of thm; thm is returned unchanged when
     it has none. *)
  fun smash ctxt thm =
    case (Thm.tpairs_of thm) of
      [] => thm
    | (p :: _) => smash_pair ctxt thm p;

  (* Result in the (context option, thm option) shape of an attribute:
     context unchanged, theorem replaced by its smashed version. *)
  fun smashed_attrib ctxt thm =
    (NONE, SOME (smash ctxt thm));
›

(* smash_new_rule ctxt thm: smash (first flex-flex pair of thm), with the
   result wrapped as a single-element sequence. *)
ML val smash_new_rule = Seq.single oo smash;

end

Theory Sepref_Chapter_Tool

chapter ‹The Sepref Tool›
text ‹This chapter contains the Sepref tool and related tools.›
(*<*)
theory Sepref_Chapter_Tool
imports Main
begin
end
(*>*)

Theory Sepref_Id_Op

section ‹Operation Identification Phase›
theory Sepref_Id_Op
imports 
  Main 
  Automatic_Refinement.Refine_Lib
  Automatic_Refinement.Autoref_Tagging
  "Lib/Named_Theorems_Rev"
begin

text ‹
  The operation identification phase is adapted from the Autoref tool.
  The basic idea is to have a type system, which works on so called 
  interface types (also called conceptual types). Each conceptual type
  denotes an abstract data type, e.g., set, map, priority queue.
  
  Each abstract operation, which must be a constant applied to its arguments,
  is assigned a conceptual type. Additionally, there is a set of 
  \emph{pattern rewrite rules},
  which are applied to subterms before type inference takes place, and 
  which may be backtracked over. 
  This way, encodings of abstract operations in Isabelle/HOL, like 
  @{term [source] "λ_. None"} for the empty map, 
  or @{term [source] "fun_upd m k (Some v)"} for map update, can be rewritten
  to abstract operations, and get properly typed.
›

subsection "Proper Protection of Term"
text ‹ The following constants are meant to encode abstraction and 
  application as proper HOL-constants, and thus avoid strange effects with
  HOL's higher-order unification heuristics and automatic 
  beta and eta-contraction.

  The first step of operation identification is to protect the term
  by replacing all function applications and abstractions by 
  the constants defined below.
›

(* PROTECT2 x y = x; the second (prop) argument is always DUMMY
   (see PROTECT2_syn and ABS2). *)
definition [simp]: "PROTECT2 x (y::prop)  x"
consts DUMMY :: "prop"

abbreviation PROTECT2_syn ("'(#_#')") where "PROTECT2_syn t  PROTECT2 t DUMMY"

(* Input-only binder syntax λ2x. f x for a protected abstraction, whose
   body is wrapped in PROTECT2 _ DUMMY. *)
abbreviation (input)ABS2 :: "('a'b)'a'b" (binder "λ2" 10)
  where "ABS2 f  (λx. PROTECT2 (f x) DUMMY)"

(* beta reduction for a protected abstraction applied via $ *)
lemma beta: "(λ2x. f x)$x  f x" by simp

text ‹
  Another version of @{const "APP"}. Treated like @{const APP} by our tool.
  Required to avoid infinite pattern rewriting in some cases, e.g., map-lookup.
›

(* f $' a is plain application; it is a simp and autoref_tag_defs rule. *)
definition APP' (infixl "$''" 900) where [simp, autoref_tag_defs]: "f$'a  f a"

text ‹
  Sometimes, whole terms should be protected from being processed by our tool.
  For example, our tool should not look into numerals. For this reason,
  the ‹PR_CONST› tag indicates terms that our tool shall handle as
  atomic constants, and never look into.

  The special form ‹UNPROTECT› can be used inside pattern rewrite rules.
  It has the effect of reverting the protection of its argument and then
  wrapping it into a ‹PR_CONST›.
›
(* Both tags are identities; they are simp and autoref_tag_defs rules. *)
definition [simp, autoref_tag_defs]: "PR_CONST x  x" ― ‹Tag to protect constant›
definition [simp, autoref_tag_defs]: "UNPROTECT x  x" ― ‹Gets 
  converted to @{term PR_CONST}, after unprotecting its content›


subsection ‹Operation Identification›

text ‹ Indicator predicate for conceptual typing of a constant ›
(* c ::i I is defined as True; it only records an interface type I for c. *)
definition intf_type :: "'a  'b itself  bool" (infix "::i" 10) where
  [simp]: "c::iI  True"

lemma itypeI: "c::iI" by simp
lemma itypeI': "intf_type c TYPE('T)" by (rule itypeI)

(* every term has its own HOL type as an interface type *)
lemma itype_self: "(c::'a) ::i TYPE('a)" by simp

(* c :::i I is c, annotated with interface type I *)
definition CTYPE_ANNOT :: "'b  'a itself  'b" (infix ":::i" 10) where
  [simp]: "c:::iI  c"

text ‹ Wrapper predicate for conceptual type inference ›
(* ID t t' T: t equals t' (the identified/protected form) and is assigned
   interface type T. Only the equality is part of the definition. *)
definition ID :: "'a  'a  'c itself  bool" 
  where [simp]: "ID t t' T  t=t'"

subsubsection ‹Conceptual Typing Rules›

(* An ID premise yields an equation between its first two arguments. *)
lemma ID_unfold_vars: "ID x y T  xy" by simp
(* trivial rule, used as a trigger by id_pr_const_rename_tac *)
lemma ID_PR_CONST_trigger: "ID (PR_CONST x) y T  ID (PR_CONST x) y T" .

(* pattern rewriting: rewrite p to p', then identify p' *)
lemma pat_rule:
  " pp'; ID p' t' T   ID p t' T" by simp

(* application: function of type 'a→'b applied to an argument of type 'a *)
lemma app_rule:
  " ID f f' TYPE('a'b); ID x x' TYPE('a)  ID (f$x) (f'$x') TYPE('b)"
  by simp

(* same for the alternative application $' *)
lemma app'_rule:
  " ID f f' TYPE('a'b); ID x x' TYPE('a)  ID (f$'x) (f'$x') TYPE('b)"
  by simp

(* protected abstraction: identify the body under the bound variable *)
lemma abs_rule:
  " x x'. ID x x' TYPE('a)  ID (t x) (t' x x') TYPE('b)  
    ID (λ2x. t x) (λ2x'. t' x' x') TYPE('a'b)"
  by simp

(* a constant with a declared interface type is identified as itself *)
lemma id_rule: "c::iI  ID c c I" by simp

(* an explicit annotation fixes the interface type *)
lemma annot_rule: "ID t t' I  ID (t:::iI) t' I"
  by simp

(* Fallback: any term with an arbitrary interface type 'c. Id_Op
   instantiates it with a constant and that constant's declared type. *)
lemma fallback_rule:
  "ID (c::'a) c TYPE('c)"
  by simp

(* UNPROTECT x is identified via PR_CONST x *)
lemma unprotect_rl1: "ID (PR_CONST x) t T  ID (UNPROTECT x) t T"
  by simp

subsection ‹ ML-Level code ›
ML ‹
infix 0 THEN_ELSE_COMB'

signature ID_OP_TACTICAL = sig
  val SOLVE_FWD: tactic' -> tactic'
  val DF_SOLVE_FWD: bool -> tactic' -> tactic'
end

structure Id_Op_Tactical :ID_OP_TACTICAL = struct

  (* Solve subgoal i completely: apply tac and then, recursively,
    SOLVE_FWD tac to all newly generated subgoals. Fails (SOLVED') unless
    the subgoal is solved. *)
  fun SOLVE_FWD tac i st = SOLVED' (
    tac 
    THEN_ALL_NEW_FWD (SOLVE_FWD tac)) i st


  (* Search for solution with DFS-strategy. If dbg-flag is given,
    return sequence of stuck states if no solution is found.

    The list of stuck states is created afresh for every application of
    the resulting tactic. Consequently:
    - states recorded while working on one subgoal are never reported for
      another;
    - separate (possibly parallel) applications do not share a mutable
      reference.
  *)
  fun DF_SOLVE_FWD dbg tac i st = let
    val stuck_list_ref = Unsynchronized.ref []

    (* record a state in which no further progress was possible (debug mode
       only); always fails *)
    fun stuck_tac _ st = if dbg then (
      stuck_list_ref := st :: !stuck_list_ref;
      Seq.empty
    ) else Seq.empty

    fun rec_tac i st = (
        (tac THEN_ALL_NEW_FWD (SOLVED' rec_tac))
        ORELSE' stuck_tac
      ) i st

    (* ORELSE' falls back to this only if rec_tac yields no result at all;
       then report the recorded stuck states in the order encountered *)
    fun fail_tac _ _ = if dbg then
      Seq.of_list (rev (!stuck_list_ref))
    else Seq.empty
  in
    (rec_tac ORELSE' fail_tac) i st
  end

end
›


(* Rule collections read by Id_Op.id_tac: interface typings of constants,
   pattern rewrite rules (backtracked over), and definite pattern rules. *)
named_theorems_rev id_rules "Operation identification rules"
named_theorems_rev pat_rules "Operation pattern rules"
named_theorems_rev def_pat_rules "Definite operation pattern rules (not backtracked over)"



ML structure Id_Op = struct

    (* Apply conversion cnv to the first argument t of a term ID t t' T. *)
    fun id_a_conv cnv ct = case Thm.term_of ct of
      @{mpat "ID _ _ _"} => Conv.fun_conv (Conv.fun_conv (Conv.arg_conv cnv)) ct
    | _ => raise CTERM("id_a_conv",[ct])

    (* Protect a term:
       - applications become $ (APP);
       - abstractions become λx. PROTECT2 body DUMMY;
       - annotations t:::iI are kept, with t protected;
       - PR_CONST terms are left untouched.
       env is the list of enclosing binder types (innermost first). *)
    fun 
      protect env (@{mpat "?t:::i?I"}) = let
        val t = protect env t
      in 
        @{mk_term env: "?t:::i?I"}
      end
    | protect _ (t as @{mpat "PR_CONST _"}) = t
    | protect env (t1$t2) = let
        val t1 = protect env t1
        val t2 = protect env t2
      in
        @{mk_term env: "?t1.0 $ ?t2.0"}
      end
    | protect env (Abs (x,T,t)) = let
        val t = protect (T::env) t
      in
        @{mk_term env: "λv_x::?'v_T. PROTECT2 ?t DUMMY"}
      end
    | protect _ t = t

    (* Conversion to the protected form. The equation is justified by
       simp_tac with PROTECT2_def and APP_def (see Refine_Util.f_tac_conv). *)
    fun protect_conv ctxt = Refine_Util.f_tac_conv ctxt
      (protect []) 
      (simp_tac 
        (put_simpset HOL_basic_ss ctxt addsimps @{thms PROTECT2_def APP_def}) 1)

    (* Undo protection by rewriting with PROTECT2_def and APP_def. *)
    fun unprotect_conv ctxt
      = Simplifier.rewrite (put_simpset HOL_basic_ss ctxt 
        addsimps @{thms PROTECT2_def APP_def})

    (* Turn ID (UNPROTECT x) into ID (PR_CONST x), and unprotect the first
       argument of ID in the conclusion. *)
    fun do_unprotect_tac ctxt =
      resolve_tac ctxt @{thms unprotect_rl1} THEN'
      CONVERSION (Refine_Util.HOL_concl_conv (fn ctxt => id_a_conv (unprotect_conv ctxt)) ctxt)

    (* [[id_debug]]: on failure, return the stuck states (see DF_SOLVE_FWD) *)
    val cfg_id_debug = 
      Attrib.setup_config_bool @{binding id_debug} (K false)

    (* [[id_trace_fallback]]: trace every application of the fallback rule *)
    val cfg_id_trace_fallback = 
      Attrib.setup_config_bool @{binding id_trace_fallback} (K false)

    (* Split an id rule "c ::i TYPE('T)" into (c, 'T); raise THM otherwise. *)
    fun dest_id_rl thm = case Thm.concl_of thm of
      @{mpat (typs) "Trueprop (?c::iTYPE(?'v_T))"} => (c,T)
    | _ => raise THM("dest_id_rl",~1,[thm])

    
    (* Add a theorem to the id_rules collection of a context. *)
    val add_id_rule = snd oo Thm.proof_attributes [Named_Theorems_Rev.add @{named_theorems_rev id_rules}]

    (* Init: protect only; Step: one step; Normal: protect, then solve;
       Solve: solve an already protected goal *)
    datatype id_tac_mode = Init | Step | Normal | Solve

    fun id_tac ss ctxt = let
      open Id_Op_Tactical
      val certT = Thm.ctyp_of ctxt
      val cert = Thm.cterm_of ctxt

      val thy = Proof_Context.theory_of ctxt

      val id_rules = Named_Theorems_Rev.get ctxt @{named_theorems_rev id_rules}
      val pat_rules = Named_Theorems_Rev.get ctxt @{named_theorems_rev pat_rules}
      val def_pat_rules = Named_Theorems_Rev.get ctxt @{named_theorems_rev def_pat_rules}

      (* Backtrackable rules: pattern rules (composed with pat_rule), the
         structural rules, and id rules (composed with id_rule). *)
      val rl_net = Tactic.build_net (
        (pat_rules |> map (fn thm => thm RS @{thm pat_rule})) 
        @ @{thms annot_rule app_rule app'_rule abs_rule} 
        @ (id_rules |> map (fn thm => thm RS @{thm id_rule}))
      )

      val def_rl_net = Tactic.build_net (
        (def_pat_rules |> map (fn thm => thm RS @{thm pat_rule}))
      )  

      (* For ID (PR_CONST x) goals:
         1. Inside the first argument of ID, rewrite with the equations
            obtained from the ID premises of the subgoal (ID_unfold_vars).
         2. Close the goal with id_rule and an id rule. *)
      val id_pr_const_rename_tac = 
          resolve_tac ctxt @{thms ID_PR_CONST_trigger} THEN'
          Subgoal.FOCUS (fn { context=ctxt, prems, ... } => 
            let
              fun is_ID @{mpat "Trueprop (ID _ _ _)"} = true | is_ID _ = false
              val prems = filter (Thm.prop_of #> is_ID) prems
              val eqs = map (fn thm => thm RS @{thm ID_unfold_vars}) prems
              val conv = Conv.rewrs_conv eqs
              val conv = fn ctxt => (Conv.top_sweep_conv (K conv) ctxt)
              val conv = fn ctxt => Conv.fun2_conv (Conv.arg_conv (conv ctxt))
              val conv = Refine_Util.HOL_concl_conv conv ctxt
            in CONVERSION conv 1 end 
          ) ctxt THEN'
          resolve_tac ctxt @{thms id_rule} THEN'
          resolve_tac ctxt id_rules 

      (* constant name -> interface types, from id rules with a constant head *)
      val ityping = id_rules 
        |> map dest_id_rl
        |> filter (is_Const o #1)
        |> map (apfst (#1 o dest_Const))
        |> Symtab.make_list

      val has_type = Symtab.defined ityping

      (* Instance of fallback_rule for constant name at type cT. The
         interface type is the constant's declared type constraint. NONE if
         the constant is unknown or instantiation fails. *)
      fun mk_fallback name cT =
        case try (Sign.the_const_constraint thy) name of
          SOME T => try (Thm.instantiate' 
                          [SOME (certT cT), SOME (certT T)] [SOME (cert (Const (name,cT)))])
                        @{thm fallback_rule} 
        | NONE => NONE

      (* Trace thm if id_trace_fallback is set. The result is always false
         and is ignored; the function is called for its side effect only. *)
      fun trace_fallback thm = 
        Config.get ctxt cfg_id_trace_fallback       
        andalso let 
          open Pretty
          val p = block [str "ID_OP: Applying fallback rule: ", Thm.pretty_thm ctxt thm]
        in 
          string_of p |> tracing; 
          false
        end  

      (* After eta-contraction, for a goal ID c _ _ with c a constant that
         has no id rule, resolve with the fallback instance for c. *)
      val fallback_tac = CONVERSION Thm.eta_conversion THEN' IF_EXGOAL (fn i => fn st =>
        case Logic.concl_of_goal (Thm.prop_of st) i of
          @{mpat "Trueprop (ID (mpaq_STRUCT (mpaq_Const ?name ?cT)) _ _)"} => (
            if not (has_type name) then 
              case mk_fallback name cT of
                SOME thm => (trace_fallback thm; resolve_tac ctxt [thm] i st)
              | NONE => Seq.empty  
            else Seq.empty
          )
        | _ => Seq.empty)

      (* protect the first argument of ID in the conclusion *)
      val init_tac = CONVERSION (
        Refine_Util.HOL_concl_conv (fn ctxt => (id_a_conv (protect_conv ctxt))) 
          ctxt
      )

      (* One identification step; the alternatives are tried in order. *)
      val step_tac = (FIRST' [
        assume_tac ctxt, 
        eresolve_tac ctxt @{thms id_rule},
        resolve_from_net_tac ctxt def_rl_net, 
        resolve_from_net_tac ctxt rl_net, 
        id_pr_const_rename_tac,
        do_unprotect_tac ctxt, 
        fallback_tac])

      val solve_tac = DF_SOLVE_FWD (Config.get ctxt cfg_id_debug) step_tac  

    in
      case ss of
        Init => init_tac 
      | Step => step_tac 
      | Normal => init_tac THEN' solve_tac
      | Solve => solve_tac

    end

  end

subsection ‹Default Setup›

subsubsection ‹Numerals› 
(* Numerals are treated as atomic: the definite pattern rule wraps them via
   UNPROTECT, and nat/int constants under PR_CONST get their own type. *)
lemma pat_numeral[def_pat_rules]: "numeral$x  UNPROTECT (numeral$x)" by simp

lemma id_nat_const[id_rules]: "(PR_CONST (a::nat)) ::i TYPE(nat)" by simp
lemma id_int_const[id_rules]: "(PR_CONST (a::int)) ::i TYPE(int)" by simp

(*subsection ‹Example›
schematic_lemma 
  "ID (λa b. (b(1::int↦2::nat) |`(-{3})) a, Map.empty, λa. case a of None ⇒ Some a | Some _ ⇒ None) (?c) (?T::?'d itself)"
  (*"TERM (?c,?T)"*)
  using [[id_debug]]
  apply (tactic {* Id_Op.id_tac Id_Op.Normal @{context} 1  *})  
  done
*)

end

Theory Sepref_Basic

section ‹Basic Definitions›
theory Sepref_Basic
imports 
  "HOL-Eisbach.Eisbach"
  Separation_Logic_Imperative_HOL.Sep_Main
  Refine_Monadic.Refine_Monadic
  "Lib/Sepref_Misc"
  "Lib/Structured_Apply"
  Sepref_Id_Op
begin
(* Disable the infix syntax of i_ANNOT and CONST_INTF (used by the imported
   identification setup) in this theory. The reason is not visible here;
   presumably it avoids clashes with other notation -- confirm. *)
no_notation i_ANNOT (infixr ":::i" 10)
no_notation CONST_INTF (infixr "::i" 10)

text ‹
  In this theory, we define the basic concept of refinement 
  from a nondeterministic program specified in the 
  Isabelle Refinement Framework to an imperative deterministic one 
  specified in Imperative/HOL.
›

subsection ‹Values on Heap›
text ‹We tag every refinement assertion with the tag ‹hn_ctxt›, to
  avoid higher-order unification problems when the refinement assertion 
  is schematic.›
(* hn_ctxt is only a tag: by hn_ctxt_def, hn_ctxt P a c unfolds to P a c. *)
definition hn_ctxt :: "('a'cassn)  'a  'c  assn" 
  ― ‹Tag for refinement assertion›
  where
  "hn_ctxt P a c  P a c"

(* pure R turns a relation into a refinement assertion between an abstract
   value a and a concrete value c. It states that (c,a) is in R; note the pair
   order, concrete first. By mod_pure_conv, its models have an empty address set. *)
definition pure :: "('b × 'a) set  'a  'b  assn"
  ― ‹Pure binding, not involving the heap›
  where "pure R  (λa c. ((c,a)R))"

(* Unfolding of pure applied to arguments. *)
lemma pure_app_eq: "pure R a c = ((c,a)R)" by (auto simp: pure_def)

(* pure is injective in the relation. *)
lemma pure_eq_conv[simp]: "pure R = pure R'  R=R'"
  unfolding pure_def 
  apply (rule iffI)
  apply safe
  apply (meson pure_assn_eq_conv)
  apply (meson pure_assn_eq_conv)
  done

(* pure R x y is the false assertion exactly when (y,x) is not in R. *)
lemma pure_rel_eq_false_iff: "pure R x y = false  (y,x)R"
  by (auto simp: pure_def)
    
    
(* is_pure P: every instance P x x' is a pure assertion given by a
   predicate P' on x and x'. *)
definition "is_pure P  P'. x x'. P x x'=(P' x x')"
lemma is_pureI[intro?]: 
  assumes "x x'. P x x' = (P' x x')"
  shows "is_pure P"
  using assms unfolding is_pure_def by blast

lemma is_pureE:
  assumes "is_pure P"
  obtains P' where "x x'. P x x' = (P' x x')"
  using assms unfolding is_pure_def by blast

(* pure assertions are pure, and the hn_ctxt tag preserves purity. *)
lemma pure_pure[simp]: "is_pure (pure P)"
  unfolding pure_def by rule blast
lemma pure_hn_ctxt[intro!]: "is_pure P  is_pure (hn_ctxt P)"
  unfolding hn_ctxt_def[abs_def] .


(* the_pure P picks (via THE) the relation underlying a pure assertion family;
   see the_pure_pure and pure_the_pure for the round trips. *)
definition "the_pure P  THE P'. x x'. P x x'=((x',x)P')"

lemma the_pure_pure[simp]: "the_pure (pure R) = R"
  unfolding pure_def the_pure_def
  by (rule theI2[where a=R]) auto

(* Characterisation of is_pure by a relation Ri instead of a predicate. *)
lemma is_pure_alt_def: "is_pure R  (Ri. x y. R x y = ((y,x)Ri))"
  unfolding is_pure_def
  apply auto
  apply (rename_tac P')
  apply (rule_tac x="{(x,y). P' y x}" in exI)
  apply auto
  done

lemma pure_the_pure[simp]: "is_pure R  pure (the_pure R) = R"
  unfolding is_pure_alt_def pure_def the_pure_def
  apply (intro ext)
  apply clarsimp
  apply (rename_tac a c Ri)
  apply (rule_tac a=Ri in theI2)
  apply auto
  done
  
(* A pure assertion family is exactly one of the form pure R'. *)
lemma is_pure_conv: "is_pure R  (R'. R = pure R')"
  unfolding pure_def is_pure_alt_def by force

(* For pure R, the underlying relation is Id iff R is pure Id. *)
lemma is_pure_the_pure_id_eq[simp]: "is_pure R  the_pure R = Id  R=pure Id"  
  by (auto simp: is_pure_conv)

(* Link to the assertion-level notion is_pure_assn. *)
lemma is_pure_iff_pure_assn: "is_pure P = (x x'. is_pure_assn (P x x'))"
  unfolding is_pure_def is_pure_assn_def by metis



(* hn_val R: a tagged pure assertion; hn_val R a b holds iff (b,a) is in R. *)
abbreviation "hn_val R  hn_ctxt (pure R)"

lemma hn_val_unfold: "hn_val R a b = ((b,a)R)"
  by (simp add: hn_ctxt_def pure_def)


(* invalid_assn R x y only records that R x y was satisfiable by some heap h.
   It owns no specific heap cells, because the rest is absorbed by true. It
   describes a value whose representation may have been handed on; for example,
   hnr_RETURN_pass leaves hn_invalid behind. *)
definition "invalid_assn R x y  (h. hR x y) * true"

abbreviation "hn_invalid R  hn_ctxt (invalid_assn R)"

(* Invalidation while keeping the original assertion: R x y can be split into
   invalid_assn R x y and R x y itself. *)
lemma invalidate_clone: "R x y A invalid_assn R x y * R x y"
  apply (rule entailsI)
  unfolding invalid_assn_def
  apply (auto simp: models_in_range mod_star_trueI)
  done

(* Same as invalidate_clone, with an additional true conjunct. *)
lemma invalidate_clone': "R x y A invalid_assn R x y * R x y * true"
  apply (rule entailsI)
  unfolding invalid_assn_def
  apply (auto simp: models_in_range mod_star_trueI)
  done

(* Plain invalidation: R x y entails invalid_assn R x y. *)
lemma invalidate: "R x y A invalid_assn R x y"
  apply (rule entailsI)
  unfolding invalid_assn_def
  apply (auto simp: models_in_range mod_star_trueI)
  done

(* For pure relations nothing is lost by invalidation, up to a true conjunct. *)
lemma invalid_pure_recover: "invalid_assn (pure R) x y = pure R x y * true"
  apply (rule ent_iffI) 
  subgoal
    apply (rule entailsI)
    unfolding invalid_assn_def
    by (auto simp: pure_def)
  subgoal
    unfolding invalid_assn_def
    by (auto simp: pure_def)
  done    

(* If some heap h satisfies hn_ctxt P x y, then hn_invalid P x y is just true. *)
lemma hn_invalidI: "hhn_ctxt P x y  hn_invalid P x y = true"
  apply (cases h)
  apply (rule ent_iffI)
  apply (auto simp: invalid_assn_def hn_ctxt_def)
  done

(* Simplifier congruence rule: rewriting under invalid_assn only needs the
   instance R x' y' to be rewritten, not the function R as a whole. *)
lemma invalid_assn_cong[cong]:
  assumes "xx'"
  assumes "yy'"
  assumes "R x' y'  R' x' y'"
  shows "invalid_assn R x y = invalid_assn R' x' y'"
  using assms unfolding invalid_assn_def
  by simp

subsection ‹Constraints in Refinement Relations›

(* Models of pure R a b: the address set is empty and (b,a) is in R. *)
lemma mod_pure_conv[simp]: "(h,as)pure R a b  (as={}  (b,a)R)"
  by (auto simp: pure_def)

(* rdomp R a: the abstract value a is representable, i.e. some heap h and some
   concrete value c satisfy R a c. rdom R is the corresponding set. *)
definition rdomp :: "('a  'c  assn)  'a  bool" where
  "rdomp R a  h c. h  R a c"

abbreviation "rdom R  Collect (rdomp R)"

(* The hn_ctxt tag does not change the domain. *)
lemma rdomp_ctxt[simp]: "rdomp (hn_ctxt R) = rdomp R"
  by (simp add: hn_ctxt_def[abs_def])  

(* For pure R, the representable abstract values are exactly Range R. *)
lemma rdomp_pure[simp]: "rdomp (pure R) a  aRange R"
  unfolding rdomp_def pure_def by auto

lemma rdom_pure[simp]: "rdom (pure R) = Range R"
  unfolding rdomp_def[abs_def] pure_def by auto

(* Constraining the abstract (right) side of A to C restricts its range to C. *)
lemma Range_of_constraint_conv[simp]: "Range (AUNIV×C) = Range A  C"
  by auto


subsection ‹Heap-Nres Refinement Calculus›

text ‹Predicate that expresses refinement. Given a heap
  ‹Γ›, program ‹c› produces a heap ‹Γ'› and
  a concrete result that is related with predicate ‹R› to some
  abstract result from ‹m›.›
(* hn_refine: if the abstract program m does not fail, then the concrete
   program c, run from a heap satisfying the precondition, ends in a heap
   satisfying the postcondition. In addition, R x r holds for the concrete
   result r and some abstract result x of m. The triple has the t suffix, so
   leftover heap (garbage) is allowed. *)
definition "hn_refine Γ c Γ' R m  nofail m 
  <Γ> c <λr. Γ' * (Ax. R x r * (RETURN x  m)) >t"

(* TODO: Can we change the patterns of assn_simproc to add this pattern? *)
(* Reuses the assertion simproc function for terms matching the partially
   applied hn_refine pattern; presumably this normalises the assertion
   arguments of hn_refine -- confirm against Seplogic_Auto. *)
simproc_setup assn_simproc_hnr ("hn_refine Γ c Γ'")
  = ‹K Seplogic_Auto.assn_simproc_fun

(* Introduction and destruction rules that just unfold the definition. *)
lemma hn_refineI[intro?]:
  assumes "nofail m 
     <Γ> c <λr. Γ' * (Ax. R x r * (RETURN x  m)) >t"
  shows "hn_refine Γ c Γ' R m"
  using assms unfolding hn_refine_def by blast

lemma hn_refineD:
  assumes "hn_refine Γ c Γ' R m"
  assumes "nofail m"
  shows "<Γ> c <λr. Γ' * (Ax. R x r * (RETURN x  m)) >t"
  using assms unfolding hn_refine_def by blast

(* When proving hn_refine, one may assume that some heap h satisfies the
   precondition. *)
lemma hn_refine_preI: 
  assumes "h. hΓ  hn_refine Γ c Γ' R a"
  shows "hn_refine Γ c Γ' R a"
  using assms unfolding hn_refine_def
  by (auto intro: hoare_triple_preI)

(* When proving hn_refine, one may assume that the abstract program does not fail. *)
lemma hn_refine_nofailI: 
  assumes "nofail a  hn_refine Γ c Γ' R a"  
  shows "hn_refine Γ c Γ' R a"
  using assms by (auto simp: hn_refine_def)

(* Trivial cases: an unsatisfiable precondition, and the failing abstract program. *)
lemma hn_refine_false[simp]: "hn_refine false c Γ' R m"
  by rule auto

lemma hn_refine_fail[simp]: "hn_refine Γ c Γ' R FAIL"
  by rule auto

(* Frame rule: if P entails F * P' (up to garbage), a refinement from P' to Q'
   lifts to a refinement from P to F * Q'. *)
lemma hn_refine_frame:
  assumes "hn_refine P' c Q' R m"
  assumes "P t F * P'"
  shows "hn_refine P c (F * Q') R m"
  using assms
  unfolding hn_refine_def entailst_def
  apply clarsimp
  apply (erule cons_pre_rule)
  apply (rule cons_post_rule)
  apply (erule fi_rule, frame_inference)
  apply (simp only: star_aci)
  apply simp
  done

(* Consequence rule. I strengthens the precondition, I' weakens the
   postcondition, and R' weakens the result assertion. All three are
   entailments up to garbage (entailst). *)
lemma hn_refine_cons:
  assumes I: "PtP'"
  assumes R: "hn_refine P' c Q R m"
  assumes I': "QtQ'"
  assumes R': "x y. R x y t R' x y"
  shows "hn_refine P c Q' R' m"
  using R unfolding hn_refine_def
  apply clarify
  apply (rule cons_pre_rulet[OF I])
  apply (rule cons_post_rulet)
  apply assumption
  apply (sep_auto simp: entailst_def)
  apply (rule enttD)
  apply (intro entt_star_mono I' R')
  done

(*lemma hn_refine_cons:
  assumes I: "P⟹AP'"
  assumes R: "hn_refine P' c Q R m"
  assumes I': "Q⟹AQ'"
  assumes R': "⋀x y. R x y ⟹A R' x y"
  shows "hn_refine P c Q' R' m"
  using R unfolding hn_refine_def
  apply clarsimp
  apply (rule cons_pre_rule[OF I])
  apply (erule cons_post_rule)
  apply (rule ent_star_mono ent_refl I' R' ent_ex_preI ent_ex_postI)+
  done
*)
(* Special cases of hn_refine_cons: only the precondition, only the
   postcondition, or only the result assertion changes. *)
lemma hn_refine_cons_pre:
  assumes I: "PtP'"
  assumes R: "hn_refine P' c Q R m"
  shows "hn_refine P c Q R m"
  by (rule hn_refine_cons[OF I R]) sep_auto+

lemma hn_refine_cons_post:
  assumes R: "hn_refine P c Q R m"
  assumes I: "QtQ'"
  shows "hn_refine P c Q' R m"
  using assms
  by (rule hn_refine_cons[OF entt_refl _ _ entt_refl])

lemma hn_refine_cons_res: 
  " hn_refine Γ f Γ' R g; a c. R a c t R' a c   hn_refine Γ f Γ' R' g"
  by (erule hn_refine_cons[OF entt_refl]) sep_auto+

(* Replace the abstract program m by m', using the order assumption LE
   (discharged via pw_le_iff and order_trans). *)
lemma hn_refine_ref:
  assumes LE: "mm'"
  assumes R: "hn_refine P c Q R m"
  shows "hn_refine P c Q R m'"
  apply rule
  apply (rule cons_post_rule)
  apply (rule hn_refineD[OF R])
  using LE apply (simp add: pw_le_iff)
  apply (sep_auto intro: order_trans[OF _ LE])
  done

(* hn_refine_cons and hn_refine_ref combined. *)
lemma hn_refine_cons_complete:
  assumes I: "PtP'"
  assumes R: "hn_refine P' c Q R m"
  assumes I': "QtQ'"
  assumes R': "x y. R x y t R' x y"
  assumes LE: "mm'"
  shows "hn_refine P c Q' R' m'"
  apply (rule hn_refine_ref[OF LE])
  apply (rule hn_refine_cons[OF I R I' R'])
  done
 
(* If g refines SPEC Φ in the le-or-fail sense (B, via pw_leof_iff), the
   result assertion may be strengthened by the pure fact Φ a. *)
lemma hn_refine_augment_res:
  assumes A: "hn_refine Γ f Γ' R g"
  assumes B: "g n SPEC Φ"
  shows "hn_refine Γ f Γ' (λa c. R a c * (Φ a)) g"
  apply (rule hn_refineI)
  apply (rule cons_post_rule)
  apply (erule A[THEN hn_refineD])
  using B
  apply (sep_auto simp: pw_le_iff pw_leof_iff)
  done


subsection ‹Product Types›
text ‹Some notions for product types are already defined here, as they are used 
  for currying and uncurrying, which is fundamental for the sepref tool.›
(* prod_assn P1 P2 relates an abstract pair to a concrete pair componentwise:
   P1 on the first components, separated by star from P2 on the second ones. *)
definition prod_assn :: "('a1'c1assn)  ('a2'c2assn) 
   'a1*'a2  'c1*'c2  assn" where
  "prod_assn P1 P2 a c  case (a,c) of ((a1,a2),(c1,c2)) 
  P1 a1 c1 * P2 a2 c2"

notation prod_assn (infixr "×a" 70)
  
(* The product of two pure assertions is pure, with the product relation. *)
lemma prod_assn_pure_conv[simp]: "prod_assn (pure R1) (pure R2) = pure (R1 ×r R2)"
  by (auto simp: pure_def prod_assn_def intro!: ext)

(* Unfolding on explicit pairs. *)
lemma prod_assn_pair_conv[simp]: 
  "prod_assn A B (a1,b1) (a2,b2) = A a1 a2 * B b1 b2"
  unfolding prod_assn_def by auto

(* The product of two trivially true assertions is trivially true. *)
lemma prod_assn_true[simp]: "prod_assn (λ_ _. true) (λ_ _. true) = (λ_ _. true)"
  by (auto intro!: ext simp: hn_ctxt_def prod_assn_def)

subsection "Convenience Lemmas"

(* Replace the concrete program by one that is proved equal to it. *)
lemma hn_refine_guessI:
  assumes "hn_refine P f P' R f'"
  assumes "f=f_conc"
  shows "hn_refine P f_conc P' R f'"
  ― ‹To prove a refinement, first synthesize one, and then prove equality›
  using assms by simp


(* Combine a refinement with a correctness statement C for the abstract
   program. The result is a Hoare triple for c whose postcondition asserts Φ
   of the related abstract result. *)
lemma imp_correctI:
  assumes R: "hn_refine Γ c Γ' R a"
  assumes C: "a  SPEC Φ"
  shows "<Γ> c <λr'. Ar. Γ' * R r r' * (Φ r)>t"
  apply (rule cons_post_rule)
  apply (rule hn_refineD[OF R])
  apply (rule le_RES_nofailI[OF C])
  apply (sep_auto dest: order_trans[OF _ C])
  done

(* An existentially quantified precondition is equivalent to a refinement for
   every witness x. *)
lemma hnr_pre_ex_conv: 
  shows "hn_refine (Ax. Γ x) c Γ' R a  (x. hn_refine (Γ x) c Γ' R a)"
  unfolding hn_refine_def
  apply safe
  apply (erule cons_pre_rule[rotated])
  apply (rule ent_ex_postI)
  apply (rule ent_refl)
  apply sep_auto
  done

(* A pure fact P conjoined to the precondition becomes a hypothesis P. *)
lemma hnr_pre_pure_conv:  
  shows "hn_refine (Γ * P) c Γ' R a  (P  hn_refine Γ c Γ' R a)"
  unfolding hn_refine_def
  by auto

(* Weaken the postcondition to a disjunction, keeping the left or the right
   disjunct respectively. *)
lemma hn_refine_split_post:
  assumes "hn_refine Γ c Γ' R a"
  shows "hn_refine Γ c (Γ' A Γ'') R a"
  apply (rule hn_refine_cons_post[OF assms])
  by (rule entt_disjI1_direct)

lemma hn_refine_post_other: 
  assumes "hn_refine Γ c Γ'' R a"
  shows "hn_refine Γ c (Γ' A Γ'') R a"
  apply (rule hn_refine_cons_post[OF assms])
  by (rule entt_disjI2_direct)


subsubsection ‹Return›

(* Returning a heap value hands it on. Afterwards, only hn_invalid R x p
   remains in the context. *)
lemma hnr_RETURN_pass:
  "hn_refine (hn_ctxt R x p) (return p) (hn_invalid R x p) R (RETURN x)"
  ― ‹Pass on a value from the heap as return value›
  apply rule 
  apply (sep_auto simp: hn_ctxt_def eintros: invalidate_clone')
  done

(* Returning a concrete value c related to a by R, with an empty heap context. *)
lemma hnr_RETURN_pure:
  assumes "(c,a)R"
  shows "hn_refine emp (return c) emp (pure R) (RETURN a)"
  ― ‹Return pure value›
  unfolding hn_refine_def using assms
  by (sep_auto simp: pure_def)
  
subsubsection ‹Assertion›
(* Same statement as hn_refine_fail, additionally declared as an intro! rule. *)
lemma hnr_FAIL[simp, intro!]: "hn_refine Γ c Γ' R FAIL"
  unfolding hn_refine_def
  by simp

(* The asserted fact Φ may be assumed when refining the continuation c'. *)
lemma hnr_ASSERT:
  assumes "Φ  hn_refine Γ c Γ' R c'"
  shows "hn_refine Γ c Γ' R (do { ASSERT Φ; c'})"
  using assms
  apply (cases Φ)
  by auto

subsubsection ‹Bind›
(* If x is a possible result of m and y is a possible result of f x, then y is
   a possible result of the bind of m and f. *)
lemma bind_det_aux: " RETURN x  m; RETURN y  f x   RETURN y  m  f"
  apply (rule order_trans[rotated])
  apply (rule Refine_Basic.bind_mono)
  apply assumption
  apply (rule order_refl)
  apply simp
  done

(* Bind rule.
   D1: m' refines m, leaving context Γ1 and result assertion Rh.
   D2: for each result x that m may return, f' x' refines f x under the extra
       context hn_ctxt Rh x x'.
   IMP: the post-state of the continuation entails Γ' plus some assertion on
       the bound variable, which is then discarded.
   Together these give a refinement of the sequential compositions. *)
lemma hnr_bind:
  assumes D1: "hn_refine Γ m' Γ1 Rh m"
  assumes D2: 
    "x x'. RETURN x  m  hn_refine (Γ1 * hn_ctxt Rh x x') (f' x') (Γ2 x x') R (f x)"
  assumes IMP: "x x'. Γ2 x x' t Γ' * hn_ctxt Rx x x'"
  shows "hn_refine Γ (m'f') Γ' R (mf)"
  using assms
  unfolding hn_refine_def
  apply (clarsimp simp add: pw_bind_nofail)
  apply (rule Hoare_Triple.bind_rule)
  apply assumption
  apply (clarsimp intro!: normalize_rules simp: hn_ctxt_def)
(* Remaining goal: the triple for the continuation, for a fixed abstract
   result x of m and concrete result x'. *)
proof -
  fix x' x
  assume 1: "RETURN x  m" 
    and "nofail m" "x. inres m x  nofail (f x)"
  hence "nofail (f x)" by (auto simp: pw_le_iff)
  moreover assume "x x'. RETURN x  m 
           nofail (f x)  <Γ1 * Rh x x'> f' x'
           <λr'. Ar. Γ2 x x' * R r r' * true *  (RETURN r  f x)>"
  ultimately have "x'. <Γ1 * Rh x x'> f' x'
           <λr'. Ar. Γ2 x x' * R r r' * true *  (RETURN r  f x)>"
    using 1 by simp
  (* Weaken the post-state Γ2 x x' to Γ' using IMP. *)
  also have "r'. Ar. Γ2 x x' * R r r' * true *  (RETURN r  f x) A
    Ar. Γ' * R r r' * true *  (RETURN r  f x)"
    apply (sep_auto)
    apply (rule ent_frame_fwd[OF IMP[THEN enttD]])
    apply frame_inference
    apply (solve_entails)
    done
  finally (cons_post_rule) have 
    R: "<Γ1 * Rh x x'> f' x' 
        <λr'. Ar. Γ' * R r r' * true * (RETURN r  f x)>"
    .
  (* Lift results of f x to results of the bind via bind_det_aux. *)
  show "<Γ1 * Rh x x' * true> f' x'
          <λr'. Ar. Γ' * R r r' * true *  (RETURN r  m  f)>"
    by (sep_auto heap: R intro: bind_det_aux[OF 1])
qed

subsubsection ‹Recursion›

(* hn_rel P m: the result part of hn_refine, as a function of the concrete
   result r. It says P x r holds for some abstract result x of m. *)
definition "hn_rel P m  λr. Ax. P x r * (RETURN x  m)"

(* hn_refine restated with hn_rel; the recursion rules below work on this form. *)
lemma hn_refine_alt: "hn_refine Fpre c Fpost P m  nofail m 
  <Fpre> c <λr. hn_rel P m r * Fpost>t"
  apply (rule eq_reflection)
  unfolding hn_refine_def hn_rel_def
  apply (simp add: hn_ctxt_def)
  apply (simp only: star_aci)
  done

(* Auxiliary lemma used by hn_admissible. The premises are a triple W with
   postcondition true, and, for every x with A x, a triple for postcondition
   Q x. The conclusion is a single triple whose postcondition states that Q x r
   holds for all x with A x. This universal statement is encoded with
   assertion-level negation and an existential. *)
lemma wit_swap_forall:
  assumes W: "<P> c <λ_. true>"
  assumes T: "(x. A x  <P> c <Q x>)"
  shows "<P> c <λr. ¬A (Ax. (A x) * ¬A Q x r)>"
  unfolding hoare_triple_def Let_def
  apply (intro conjI impI allI)
  subgoal by (elim conjE) (rule hoare_tripleD[OF W], assumption+) []

  subgoal
    apply (clarsimp, intro conjI allI)
    apply1 (rule models_in_range)
    applyS (rule hoare_tripleD[OF W]; assumption; fail)
    apply1 (simp only: disj_not2, intro impI)
    apply1 (drule spec[OF T, THEN mp])
    apply1 (drule (2) hoare_tripleD(2))
    by assumption

  subgoal by (elim conjE) (rule hoare_tripleD[OF W], assumption+)

  subgoal by (elim conjE) (rule hoare_tripleD[OF W], assumption+) 
  done

(* Admissibility of hn_rel triples under infima of abstract programs. Suppose
   every non-failing f x (for f in A) satisfies the triple and the infimum does
   not fail. Then the triple also holds for the infimum. Precision of Ry is
   essential: preciseD' identifies the abstract results obtained for different
   f on the same heap and concrete result. hnr_RECT_old uses this lemma (in its
   garbage form hn_admissible') for gfp induction. *)
lemma hn_admissible:
  assumes PREC: "precise Ry"
  assumes E: "fA. nofail (f x)  <P> c <λr. hn_rel Ry (f x) r * F>"
  assumes NF: "nofail (INF fA. f x)"
  shows "<P> c <λr. hn_rel Ry (INF fA. f x) r * F>"
(* Outline: pick one non-failing f to get a termination witness W, use
   wit_swap_forall to gather all triples into one, then show that the common
   abstract result is below the infimum. *)
proof -
  from NF obtain f where "fA" and "nofail (f x)"
    by (simp only: refine_pw_simps) blast

  with E have "<P> c <λr. hn_rel Ry (f x) r * F>" by blast
  hence W: "<P> c <λ_. true>" by (rule cons_post_rule, simp)

  from E have 
    E': "f. fA  nofail (f x)  <P> c <λr. hn_rel Ry (f x) r * F>"
    by blast
  from wit_swap_forall[OF W E'] have 
    E'': "<P> c
     <λr. ¬A (Axa.  (xa  A  nofail (xa x)) *
                ¬A (hn_rel Ry (xa x) r * F))>" .
  
  thus ?thesis
    apply (rule cons_post_rule)
    unfolding entails_def hn_rel_def
    apply clarsimp
  proof -
    fix h as p
    assume A: "f. fA  (a.
      ((h, as)  Ry a p * F  RETURN a  f x))  ¬ nofail (f x)"
    with fA and ‹nofail (f x) obtain a where 
      1: "(h, as)  Ry a p * F" and "RETURN a  f x"
      by blast
    have
      "fA. nofail (f x)  (h, as)  Ry a p * F  RETURN a  f x"
    proof clarsimp
      fix f'
      assume "f'A" and "nofail (f' x)"
      with A obtain a' where 
        2: "(h, as)  Ry a' p * F" and "RETURN a'  f' x"
        by blast

      moreover note preciseD'[OF PREC 1 2] 
      ultimately show "(h, as)  Ry a p * F  RETURN a  f' x" by simp
    qed
    hence "RETURN a  (INF fA. f x)"
      by (metis (mono_tags) le_INF_iff le_nofailI)
    with 1 show "a. (h, as)  Ry a p * F  RETURN a  (INF fA. f x)"
      by blast
  qed
qed

(* Variant of hn_admissible for triples with garbage, obtained by instantiating
   the frame F with F*true. *)
lemma hn_admissible':
  assumes PREC: "precise Ry"
  assumes E: "fA. nofail (f x)  <P> c <λr. hn_rel Ry (f x) r * F>t"
  assumes NF: "nofail (INF fA. f x)"
  shows "<P> c <λr. hn_rel Ry (INF fA. f x) r * F>t"
  apply (rule hn_admissible[OF PREC, where F="F*true", simplified])
  apply simp
  by fact+

(* Older recursion rule, based on RECT_gfp_def (greatest fixed point).
   S: if the recursive call cf refines af for all arguments, then the body
      cB cf refines aB af.
   M: the concrete body is monotone, so its fixed point can be unfolded.
   PREC: the result assertion must be precise; this is needed for
      admissibility via hn_admissible'. *)
lemma hnr_RECT_old:
  assumes S: "cf af ax px. 
    ax px. hn_refine (hn_ctxt Rx ax px * F) (cf px) (F' ax px) Ry (af ax) 
     hn_refine (hn_ctxt Rx ax px * F) (cB cf px) (F' ax px) Ry (aB af ax)"
  assumes M: "(x. mono_Heap (λf. cB f x))"
  assumes PREC: "precise Ry"
  shows "hn_refine 
    (hn_ctxt Rx ax px * F) (heap.fixp_fun cB px) (F' ax px) Ry (RECT aB ax)"
  unfolding RECT_gfp_def
(* Induction on gfp. Its three goals are admissibility (hn_admissible'), the
   base case, and the step (unfold the concrete fixed point with M, then apply S). *)
proof (simp, intro conjI impI)
  assume "trimono aB"
  hence "mono aB" by (simp add: trimonoD)
  have "ax px. 
    hn_refine (hn_ctxt Rx ax px * F) (heap.fixp_fun cB px) (F' ax px) Ry 
      (gfp aB ax)"
    apply (rule gfp_cadm_induct[OF _ _ ‹mono aB])

    apply rule
    apply (auto simp: hn_refine_alt intro: hn_admissible'[OF PREC]) []

    apply (auto simp: hn_refine_alt) []

    apply clarsimp
    apply (subst heap.mono_body_fixp[of cB, OF M])
    apply (rule S)
    apply blast
    done
  thus "hn_refine (hn_ctxt Rx ax px * F)
     (ccpo.fixp (fun_lub Heap_lub) (fun_ord Heap_ord) cB px) (F' ax px) Ry
     (gfp aB ax)" by simp
qed

(* Recursion rule for RECT, based on RECT_def and the flat-ordering fixed point
   flatf_gfp. Unlike hnr_RECT_old, it needs no precision assumption: here
   admissibility follows from flatf_admissible_pointwise. *)
lemma hnr_RECT:
  assumes S: "cf af ax px. 
    ax px. hn_refine (hn_ctxt Rx ax px * F) (cf px) (F' ax px) Ry (af ax) 
     hn_refine (hn_ctxt Rx ax px * F) (cB cf px) (F' ax px) Ry (aB af ax)"
  assumes M: "(x. mono_Heap (λf. cB f x))"
  shows "hn_refine 
    (hn_ctxt Rx ax px * F) (heap.fixp_fun cB px) (F' ax px) Ry (RECT aB ax)"
  unfolding RECT_def
proof (simp, intro conjI impI)
  assume "trimono aB"
  hence "flatf_mono_ge aB" by (simp add: trimonoD)
  have "ax px. 
    hn_refine (hn_ctxt Rx ax px * F) (heap.fixp_fun cB px) (F' ax px) Ry 
      (flatf_gfp aB ax)"
      
    apply (rule flatf_ord.fixp_induct[OF _ ‹flatf_mono_ge aB])  

    apply (rule flatf_admissible_pointwise)
    apply simp

    apply (auto simp: hn_refine_alt) []

    apply clarsimp
    apply (subst heap.mono_body_fixp[of cB, OF M])
    apply (rule S)
    apply blast
    done
  thus "hn_refine (hn_ctxt Rx ax px * F)
     (ccpo.fixp (fun_lub Heap_lub) (fun_ord Heap_ord) cB px) (F' ax px) Ry
     (flatf_gfp aB ax)" by simp
qed

(* Refinement of if-then-else.
   P: the context provides the condition a' as a pure refinement of a
      (via bool_rel).
   RT, RE: each branch may assume the value of the condition.
   IMP: the disjunction of the two branch post-states entails Γ'. *)
lemma hnr_If:
  assumes P: "Γ t Γ1 * hn_val bool_rel a a'"
  assumes RT: "a  hn_refine (Γ1 * hn_val bool_rel a a') b' Γ2b R b"
  assumes RE: "¬a  hn_refine (Γ1 * hn_val bool_rel a a') c' Γ2c R c"
  assumes IMP: "Γ2b A Γ2c t Γ'"
  shows "hn_refine Γ (if a' then b' else c') Γ' R (if a then b else c)"
  apply (rule hn_refine_cons[OF P])
  apply1 (rule hn_refine_preI)
  applyF (cases a; simp add: hn_ctxt_def pure_def)
    (* Case a: the then-branch post-state Γ2b is the left disjunct. *)
    focus
      apply1 (rule hn_refine_split_post)
      applyF (rule hn_refine_cons_pre[OF _ RT])
        applyS (simp add: hn_ctxt_def pure_def)
        applyS simp
      solved
    solved
    (* Case not a: the else-branch post-state Γ2c is the right disjunct. *)
    apply1 (rule hn_refine_post_other)
    applyF (rule hn_refine_cons_pre[OF _ RE])
      applyS (simp add: hn_ctxt_def pure_def)
      applyS simp
    solved
  solved
  applyS (rule IMP)
  applyS (rule entt_refl)
  done


subsection ‹ML-Level Utilities›
ML signature SEPREF_BASIC = sig
    (*** Working under lambda abstractions ***)

    (* Opens an abstraction: returns the body with the bound variable replaced
       by a fresh Free, a function that abstracts over that Free again, and the
       context in which the fresh name is fixed. Raises TERM on non-abstractions. *)
    val dest_lambda_rc: Proof.context -> term -> ((term * (term -> term)) * Proof.context)
    (* Opens an abstraction, transforms its body, and closes it again. *)
    val apply_under_lambda: (Proof.context -> term -> term) -> Proof.context -> term -> term

    (*** The nres type ***)

    val is_nresT: typ -> bool
    val mk_nresT: typ -> typ
    (* Element type of an nres type; raises TYPE for any other type. *)
    val dest_nresT: typ -> typ

    (*** Certified equality and entailment ***)

    (* Certified meta-equality of two cterms. *)
    val mk_cequals: cterm * cterm -> cterm
    (* The entailment term entails P Q. *)
    val mk_entails: term * term -> term

    (*** Pre-terms, built with dummy types ***)

    (* Type constraint t :: T *)
    val constrain_type_pre: typ -> term -> term
    (* Membership of the pair of the first two arguments in the third one *)
    val mk_pair_in_pre: term -> term -> term -> term
    (* For n = 0 this is the application f $ g; for n > 0 it is the term
       %x1 .. x(n-1). f o (g x1 .. x(n-1)) *)
    val mk_compN_pre: int -> term -> term -> term
    (* curry0 f, curry f, and n-fold currying: n = 0 gives curry0, n = 1 leaves
       the term unchanged, larger n nest n-1 curry applications *)
    val mk_curry0_pre: term -> term
    val mk_curry_pre: term -> term
    val mk_curryN_pre: int -> term -> term
    (* uncurry0 f, uncurry f, and n-fold uncurrying, analogous to the curry case *)
    val mk_uncurry0_pre: term -> term
    val mk_uncurry_pre: term -> term
    val mk_uncurryN_pre: int -> term -> term

    (*** hn_refine terms ***)

    (* Applies the five conversions to the arguments P, c, Q, R, a of
       hn_refine P c Q R a; raises CTERM on any other term. *)
    val hn_refine_conv: conv -> conv -> conv -> conv -> conv -> conv
    (* Converts only the abstract program, which is the last argument. *)
    val hn_refine_conv_a: conv -> conv
    (* Converts the abstract program of an hn_refine in a HOL conclusion. *)
    val hn_refine_concl_conv_a: (Proof.context -> conv) -> Proof.context -> conv
    (* Splits hn_refine P c Q R a into its five arguments; raises TERM otherwise. *)
    val dest_hn_refine: term -> term * term * term * term * term
    (* Inverse of dest_hn_refine. *)
    val mk_hn_refine: term * term * term * term * term -> term
    (* Tests for Trueprop (hn_refine ...); intended for use with CONCL_COND'. *)
    val is_hn_refine_concl: term -> bool

    (*** Abstract functions ***)

    (* Returns a flag telling whether an outer RETURN was stripped, plus the head
       and argument list of the APP-application underneath. PR_CONST terms are
       not decomposed. *)
    val dest_hnr_absfun: term -> bool * (term * term list)
    (* Rebuilds the application, wrapping it in RETURN if the flag is set. *)
    val mk_hnr_absfun: bool * (term * term list) -> term
    (* Rebuilds the application, wrapping it in RETURN unless it already has
       an nres type. *)
    val mk_hnr_absfun': (term * term list) -> term

    (*** Separation conjunction ***)

    (* Proves permutations of star products (simp with star_aci on all goals);
       meant to be combined with f_tac_conv. *)
    val star_permute_tac: Proof.context -> tactic
    val mk_star: term * term -> term
    (* Star product of a list; the empty list yields emp. Elements are nested
       in reverse order, so the list a, b, c gives c * b * a grouped to the left. *)
    val list_star: term list -> term
    (* Flattens nested star products into a list, dropping emp. *)
    val strip_star: term -> term list
    (* Tests whether the term is the assertion true. *)
    val is_true: term -> bool

    (*** hn_ctxt assertions ***)

    val is_hn_ctxt: term -> bool
    (* Splits hn_ctxt R a p into R, a and p; raises TERM otherwise. *)
    val dest_hn_ctxt: term -> term * term * term
    (* Like dest_hn_ctxt, but yields NONE instead of raising. *)
    val dest_hn_ctxt_opt: term -> (term * term * term) option

    (*** Phase sequencing ***)

    type phases_ctrl = {
      trace: bool,            (* print a message for every phase *)
      int_res: bool,          (* on failure, stop with the intermediate state *)
      start: string option,   (* first phase to run; NONE selects the first one *)
      stop: string option     (* last phase to run; NONE selects the last one *)
    }

    (* Runs all phases silently, without intermediate results. *)
    val dflt_phases_ctrl: phases_ctrl
    (* Runs all phases with tracing and intermediate results. *)
    val dbg_phases_ctrl: phases_ctrl
    (* Yields dbg_phases_ctrl if the flag is set, otherwise dflt_phases_ctrl. *)
    val flag_phases_ctrl: bool -> phases_ctrl

    (* A named tactic together with the expected change of the number of goals;
       a negative change means that goals get solved. *)
    type phase = string * (Proof.context -> tactic') * int

    (* Runs the selected phases in order. If a phase fails, or changes the number
       of goals by anything other than the expected amount, the tactic fails. With
       int_res set, it instead stops and returns the state reached so far. *)
    val PHASES': phase list -> phases_ctrl -> Proof.context -> tactic'

  end

  structure Sepref_Basic: SEPREF_BASIC = struct

    (* The nres type constructor *)
    fun is_nresT (Type (@{type_name nres},[_])) = true | is_nresT _ = false
    fun mk_nresT T = Type(@{type_name nres},[T])
    fun dest_nresT (Type (@{type_name nres},[T])) = T | dest_nresT T = raise TYPE("dest_nresT",[T],[])


    (* Open an abstraction. The bound variable becomes a fresh Free, whose name
       is fixed in the returned context. The returned function abstracts over
       that Free again, reusing the original binder name x. *)
    fun dest_lambda_rc ctxt (Abs (x,T,t)) = let
        val (u,ctxt) = yield_singleton Variable.variant_fixes x ctxt
        val u = Free (u,T)
        val t = subst_bound (u,t)
        val reconstruct = Term.lambda_name (x,u)
      in
        ((t,reconstruct),ctxt)
      end
    | dest_lambda_rc _ t = raise TERM("dest_lambda_rc",[t])

    (* Transform the opened body in the extended context, then re-abstract it. *)
    fun apply_under_lambda f ctxt t = let
      val ((t,rc),ctxt) = dest_lambda_rc ctxt t
      val t = f ctxt t
    in
      rc t
    end


    (* Functions on pre-terms *)
    (* All constants below are built with dummyT; presumably the caller
       type-checks the resulting terms. *)
    fun mk_pair_in_pre x y r = Const (@{const_name Set.member}, dummyT) $
      (Const (@{const_name Product_Type.Pair}, dummyT) $ x $ y) $ r


    (* n = 0: uncurry0; n = 1: unchanged; n > 1: n-1 nested uncurry *)
    fun mk_uncurry_pre t = Const(@{const_name uncurry}, dummyT)$t
    fun mk_uncurry0_pre t = Const(@{const_name uncurry0}, dummyT)$t
    fun mk_uncurryN_pre 0 = mk_uncurry0_pre
      | mk_uncurryN_pre 1 = I
      | mk_uncurryN_pre n = mk_uncurry_pre o mk_uncurryN_pre (n-1)

    (* n = 0: curry0; n = 1: unchanged; n > 1: n-1 nested curry *)
    fun mk_curry_pre t = Const(@{const_name curry}, dummyT)$t
    fun mk_curry0_pre t = Const(@{const_name curry0}, dummyT)$t
    fun mk_curryN_pre 0 = mk_curry0_pre
      | mk_curryN_pre 1 = I
      | mk_curryN_pre n = mk_curry_pre o mk_curryN_pre (n-1)


    (* n = 0: application f $ g. Otherwise g is applied to the n-1 loose bound
       variables, composed with f, and then n-1 abstractions named x1 .. x(n-1)
       are added, x1 outermost. *)
    fun mk_compN_pre 0 f g = f $ g
      | mk_compN_pre n f g = let
          val g = fold (fn i => fn t => t$Bound i) (n-2 downto 0) g
          val t = Const(@{const_name "Fun.comp"},dummyT) $ f $ g

          val t = fold (fn i => fn t => Abs ("x"^string_of_int i,dummyT,t)) (n-1 downto 1) t
        in
          t
        end

    (* Parser-style type constraint t :: T *)
    fun constrain_type_pre T t = Const(@{syntax_const "_type_constraint_"},T-->T) $ t




    local open Conv in
      (* Five nested combination_convs, one per argument of hn_refine; the
         innermost all_conv leaves the hn_refine constant itself unchanged. *)
      fun hn_refine_conv c1 c2 c3 c4 c5 ct = case Thm.term_of ct of
        @{mpat "hn_refine _ _ _ _ _"} => let
          val cc = combination_conv
        in
          cc (cc (cc (cc (cc all_conv c1) c2) c3) c4) c5 ct
        end
      | _ => raise CTERM ("hn_refine_conv",[ct])
  
      (* Converts only the last argument, the abstract program. *)
      val hn_refine_conv_a = hn_refine_conv all_conv all_conv all_conv all_conv
  
      (* Same, applied to the HOL conclusion via Refine_Util.HOL_concl_conv. *)
      fun hn_refine_concl_conv_a conv ctxt = Refine_Util.HOL_concl_conv 
        (fn ctxt => hn_refine_conv_a (conv ctxt)) ctxt
  
    end

    (* FIXME: Strange dependency! *)
    (* SMT_Util is used here only to build a certified meta-equality. *)
    val mk_cequals = uncurry SMT_Util.mk_cequals
  
    val mk_entails = HOLogic.mk_binrel @{const_name "entails"}
  
    val mk_star = HOLogic.mk_binop @{const_name "Groups.times_class.times"}

    (* Note the nesting order: the head of the list becomes the rightmost
       operand, so strip_star returns the elements in reverse order. *)
    fun list_star [] = @{term "emp::assn"}
      | list_star [a] = a
      | list_star (a::l) = mk_star (list_star l,a)

    (* Flatten nested star products, dropping emp. *)
    fun strip_star @{mpat "?a*?b"} = strip_star a @ strip_star b
      | strip_star @{mpat "emp"} = []
      | strip_star t = [t]

    fun is_true @{mpat "true"} = true | is_true _ = false
  
    fun is_hn_ctxt @{mpat "hn_ctxt _ _ _"} = true | is_hn_ctxt _ = false
    fun dest_hn_ctxt @{mpat "hn_ctxt ?R ?a ?p"} = (R,a,p) 
      | dest_hn_ctxt t = raise TERM("dest_hn_ctxt",[t])
  
    fun dest_hn_ctxt_opt @{mpat "hn_ctxt ?R ?a ?p"} = SOME (R,a,p) 
      | dest_hn_ctxt_opt _ = NONE
  
    (* Split an APP-application into its head and its arguments, in
       application order; PR_CONST terms are kept atomic. *)
    fun strip_abs_args (t as @{mpat "PR_CONST _"}) = (t,[])
      | strip_abs_args @{mpat "?f$?a"} = (case strip_abs_args f of (f,args) => (f,args@[a]))
      | strip_abs_args t = (t,[])
  
    (* The flag records whether an outer RETURN was stripped. *)
    fun dest_hnr_absfun @{mpat "RETURN$?a"} = (true, strip_abs_args a)
      | dest_hnr_absfun f = (false, strip_abs_args f)
  
    fun mk_hnr_absfun (true,fa) = Autoref_Tagging.list_APP fa |> (fn a => @{mk_term "RETURN$?a"})
      | mk_hnr_absfun (false,fa) = Autoref_Tagging.list_APP fa
  
    (* Wrap in RETURN unless the rebuilt term already has an nres type. *)
    fun mk_hnr_absfun' fa = let
      val t = Autoref_Tagging.list_APP fa
      val T = fastype_of t
    in
      case T of
        Type (@{type_name nres},_) => t
      | _ => @{mk_term "RETURN$?t"}
  
    end  
  
    fun dest_hn_refine @{mpat "hn_refine ?P ?c ?Q ?R ?a"} = (P,c,Q,R,a)
      | dest_hn_refine t = raise TERM("dest_hn_refine",[t])
  
    fun mk_hn_refine (P,c,Q,R,a) = @{mk_term "hn_refine ?P ?c ?Q ?R ?a"}
  
    val is_hn_refine_concl = can (HOLogic.dest_Trueprop #> dest_hn_refine)
  
    (* Simp with only star_aci on top of HOL_basic_ss, on all goals. *)
    fun star_permute_tac ctxt = ALLGOALS (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}))
      

    type phases_ctrl = {
      trace: bool,            
      int_res: bool,          
      start: string option,   
      stop: string option     
    }

    val dflt_phases_ctrl = {trace=false,int_res=false,start=NONE,stop=NONE} 
    val dbg_phases_ctrl = {trace=true,int_res=true,start=NONE,stop=NONE}
    fun flag_phases_ctrl dbg = if dbg then dbg_phases_ctrl else dflt_phases_ctrl

    type phase = string * (Proof.context -> tactic') * int

    local
      (* Select the phases from start to stop, inclusive, by name. Raises an
         error for unknown phase names or an empty selection. *)
      fun ph_range phases start stop = let
        fun find_phase name = let
          val i = find_index (fn (n,_,_) => n=name) phases
          val _ = if i<0 then error ("No such phase: " ^ name) else ()
        in
          i
        end

        val i = case start of NONE => 0 | SOME n => find_phase n
        val j = case stop of NONE => length phases - 1 | SOME n => find_phase n

        val phases = take (j+1) phases |> drop i

        val _ = case phases of [] => error "No phases selected, range is empty" | _ => ()
      in
        phases
      end
    in  
  
      (* Run the selected phases one after another on goal i. Before each phase
         the current number of goals n is recorded; afterwards it must be n+d.
         On failure or a mismatch, bailout_tac either keeps the current state
         (int_res) or fails. *)
      fun PHASES' phases ctrl ctxt = let
        val phases = ph_range phases (#start ctrl) (#stop ctrl)
        val phases = map (fn (n,tac,d) => (n,tac ctxt,d)) phases
  
        fun r [] _ st = Seq.single st
          | r ((name,tac,d)::tacs) i st = let
              val n = Thm.nprems_of st
              val bailout_tac = if #int_res ctrl then all_tac else no_tac
              fun trace_tac msg st = (if #trace ctrl then tracing msg else (); Seq.single st)
              val trace_start_tac = trace_tac ("Phase " ^ name)
            in
              K trace_start_tac THEN' IF_EXGOAL (tac)
              THEN_ELSE' (
                fn i => fn st => 
                  (* Bail out if a phase does not solve/create exactly the expected subgoals *)
                  if Thm.nprems_of st = n+d then
                    ((trace_tac "  Done" THEN r tacs i) st)
                  else
                    (trace_tac "*** Wrong number of produced goals" THEN bailout_tac) st
              , 
                K (trace_tac "*** Phase tactic failed" THEN bailout_tac))
            end i st
  
      in
        r phases
      end


    end

(* NOTE(review): the two commented-out fragments below are stale leftovers and
   look safe to delete. The first is a copy of the PHASES' specification; the
   second is an earlier xPHASES' implementation without tracing. *)
(*    (* Perform sequence of tactics (tac,n), each expected to create n new goals, 
       or solve goals if n is negative. 
       Debug-flag: Stop with intermediate state after tactic 
       fails or produces less/more goals as expected. *)   
    val PHASES': phase list -> phases_ctrl -> Proof.context -> tactic'
*)



(*

    fun xPHASES' dbg tacs ctxt = let
      val tacs = map (fn (tac,d) => (tac ctxt,d)) tacs

      fun r [] _ st = Seq.single st
        | r ((tac,d)::tacs) i st = let
            val n = Thm.nprems_of st
            val bailout_tac = if dbg then all_tac else no_tac
          in
            IF_EXGOAL (tac)
            THEN_ELSE' (
              fn i => fn st => 
                (* Bail out if a phase does not solve/create exactly the expected subgoals *)
                if Thm.nprems_of st = n+d then
                  (r tacs i st)
                else
                  bailout_tac st
            , 
              K bailout_tac)
          end i st

    in
      r tacs
    end
*)
  end


  signature SEPREF_DEBUGGING = sig
    (*** Debug switches ***)

    (* Global switch that turns on every debug facility below. *)
    val cfg_debug_all: bool Config.T
    (* Debugging is active if the given flag or the global switch is set. *)
    val is_debug: bool Config.T -> Proof.context -> bool
    (* Debugging is active if the global switch is set. *)
    val is_debug': Proof.context -> bool

    (*** Conversions with traced failures ***)

    (* Applies a conversion to a subgoal. If it raises THM, CTERM, TERM or TYPE,
       the tactic fails; the exception is traced when debugging is active
       (given flag or global switch). *)
    val DBG_CONVERSION: bool Config.T -> Proof.context -> conv -> tactic'
    (* Same as DBG_CONVERSION, but only the global switch enables tracing. *)
    val DBG_CONVERSION': Proof.context -> conv -> tactic'

    (*** Printing a message together with the current subgoal ***)

    (* Tracing and warning output leave the goal state unchanged. *)
    val tracing_tac': string -> Proof.context -> tactic'
    val warning_tac': string -> Proof.context -> tactic'
    (* Reports through error, which raises ERROR. *)
    val error_tac': string -> Proof.context -> tactic'

    (*** Conditional debug output ***)

    (* Traces the message when debugging is active (flag or global switch,
       respectively global switch only). *)
    val dbg_trace_msg: bool Config.T -> Proof.context -> string -> unit
    val dbg_trace_msg': Proof.context -> string -> unit

    (* When debugging is active, traces the message computed from context, goal
       index and goal state and leaves the state unchanged; otherwise this is
       the identity tactic. *)
    val dbg_msg_tac: bool Config.T -> (Proof.context -> int -> thm -> string) -> Proof.context -> tactic'
    val dbg_msg_tac': (Proof.context -> int -> thm -> string) -> Proof.context -> tactic'

    (* Message builders suitable for dbg_msg_tac. Presumably they produce fixed
       text, text plus the current subgoal, text computed from the subgoal term,
       and text plus all goals -- confirm against their definitions. *)
    val msg_text: string -> Proof.context -> int -> thm -> string
    val msg_subgoal: string -> Proof.context -> int -> thm -> string
    val msg_from_subgoal: string -> (term -> Proof.context -> string) -> Proof.context -> int -> thm -> string
    val msg_allgoals: string -> Proof.context -> int -> thm -> string

  end

  structure Sepref_Debugging: SEPREF_DEBUGGING = struct

    (* Central debug flag, available as configuration option sepref_debug_all *)
    val cfg_debug_all = 
      Attrib.setup_config_bool @{binding sepref_debug_all} (K false)

    (* Debugging is on if the given flag or the central flag is set *)
    fun is_debug cfg ctxt = Config.get ctxt cfg orelse Config.get ctxt cfg_debug_all
    (* Debugging is on if the central flag is set *)
    fun is_debug' ctxt = Config.get ctxt cfg_debug_all

    (* Trace an exception caught by a debug conversion, if debugging is on.
       The argument is pinned to type exn. The make_string antiquotation
       renders a value according to its static type. In a polymorphic
       helper that type is a type variable, and the actual exception would
       not be shown. NOTE(review): relies on Poly/ML's type-directed
       pretty printing. *)
    fun dbg_trace cfg ctxt (e: exn) = 
      if is_debug cfg ctxt then  
        tracing (@{make_string} e)
      else ()

    (* Like dbg_trace, but only consults the central flag *)
    fun dbg_trace' ctxt (e: exn) = 
      if is_debug' ctxt then  
        tracing (@{make_string} e)
      else ()

    (* Trace a plain message if debugging is on *)
    fun dbg_trace_msg cfg ctxt msg =   
      if is_debug cfg ctxt then  
        tracing msg
      else ()
    fun dbg_trace_msg' ctxt msg = 
      if is_debug' ctxt then  
        tracing msg
      else ()

    (* Apply conversion cv to subgoal i.
       THM, CTERM, TERM and TYPE exceptions make the tactic fail (empty
       result sequence) instead of escaping. The exception is first handed
       to trace. All other exceptions propagate. *)
    fun gen_dbg_conversion trace cv i st = 
      Seq.single (Conv.gconv_rule cv i st)
      handle e as THM _ => (trace e; Seq.empty)
           | e as CTERM _ => (trace e; Seq.empty)
           | e as TERM _ => (trace e; Seq.empty)
           | e as TYPE _ => (trace e; Seq.empty);

    fun DBG_CONVERSION cfg ctxt = gen_dbg_conversion (dbg_trace cfg ctxt)

    fun DBG_CONVERSION' ctxt = gen_dbg_conversion (dbg_trace' ctxt)


    local 
      (* Print msg followed by subgoal i via do_msg, then leave the state
         unchanged. IF_EXGOAL makes the tactic fail if subgoal i does not exist. *)
      fun gen_subgoal_msg_tac do_msg msg ctxt = IF_EXGOAL (fn i => fn st => let
        val t = nth (Thm.prems_of st) (i-1)
        val _ = Pretty.block [Pretty.str msg, Pretty.fbrk, Syntax.pretty_term ctxt t]
          |> Pretty.string_of |> do_msg

      in
        Seq.single st
      end)
    in       
      val tracing_tac' = gen_subgoal_msg_tac tracing
      val warning_tac' = gen_subgoal_msg_tac warning
      val error_tac' = gen_subgoal_msg_tac error
    end


    (* When debugging is off, these reduce to all_tac; otherwise they trace
       the computed message and return the goal state unchanged. *)
    fun dbg_msg_tac cfg msg ctxt =
      if is_debug cfg ctxt then (fn i => fn st => (tracing (msg ctxt i st); Seq.single st))
      else K all_tac
    fun dbg_msg_tac' msg ctxt =
      if is_debug' ctxt then (fn i => fn st => (tracing (msg ctxt i st); Seq.single st))
      else K all_tac

    (* Constant message, independent of context and goal state *)
    fun msg_text msg _ _ _ = msg

    (* msg, newline, then sgmsg applied to subgoal i.
       Reports "Subgoal out of range" if subgoal i does not exist. *)
    fun msg_from_subgoal msg sgmsg ctxt i st = 
      case try (nth (Thm.prems_of st)) (i-1) of
        NONE => msg ^ "\n" ^ "Subgoal out of range"
      | SOME t => msg ^ "\n" ^ sgmsg t ctxt

    (* msg followed by the pretty-printed subgoal i *)
    fun msg_subgoal msg = msg_from_subgoal msg (fn t => fn ctxt =>
      Syntax.pretty_term ctxt t |> Pretty.string_of
    )

    (* msg followed by the display of all goals of the state *)
    fun msg_allgoals msg ctxt _ st = 
      msg ^ "\n" ^ Pretty.string_of (Pretty.chunks (Goal_Display.pretty_goals ctxt st))

  end


ML (* Tactics for produced subgoals *)
  infix 1 THEN_NEXT THEN_ALL_NEW_LIST THEN_ALL_NEW_LIST'
  signature STACTICAL = sig
    (* Apply first tactic on this subgoal, and then second tactic on next subgoal.
       The next subgoal is the one following all goals produced by the first
       tactic. Raises an error if the first tactic solved more than one goal. *)
    val THEN_NEXT: tactic' * tactic' -> tactic'
    (* Apply tactics to the current and following subgoals.
       The list is chained with THEN_NEXT; the empty list does nothing. *)
    val APPLY_LIST: tactic' list -> tactic'
    (* Apply list of tactics on subgoals emerging from tactic. 
      Requires exactly one tactic per emerging subgoal (raises an error otherwise).*)
    val THEN_ALL_NEW_LIST: tactic' * tactic' list -> tactic'
    (* Apply list of tactics to subgoals emerging from tactic, use fallback for additional subgoals.
       Raises an error if fewer subgoals emerge than tactics are given. *)
    val THEN_ALL_NEW_LIST': tactic' * (tactic' list * tactic') -> tactic'

  end

  structure STactical : STACTICAL = struct
    infix 1 THEN_WITH_GOALDIFF

    (* Run tac1. Then pass tac2 the change in the number of subgoals
       (count after minus count before), together with the resulting state. *)
    fun (tac1 THEN_WITH_GOALDIFF tac2) st = let
      val n_before = Thm.nprems_of st
      fun with_diff st' = tac2 (Thm.nprems_of st' - n_before) st'
    in
      (tac1 THEN with_diff) st
    end

    (* Goal difference d after tac1 on subgoal i:
       ~1 means subgoal i was solved, 0 means it was replaced by one goal,
       k > 0 means k additional goals were created.
       So the goal after those produced by tac1 sits at index i+1+d. *)
    fun (tac1 THEN_NEXT tac2) i = 
      tac1 i THEN_WITH_GOALDIFF (fn d =>
        if d >= ~1 then tac2 (i + 1 + d)
        else (error "THEN_NEXT: Tactic solved more than one goal"; no_tac))

    (* Right fold: [t1, ..., tn] becomes t1 THEN_NEXT (... (tn THEN_NEXT K all_tac)) *)
    fun APPLY_LIST tacs =
      fold_rev (fn tac => fn rest => tac THEN_NEXT rest) tacs (K all_tac)

    (* Exactly one tactic per emerging subgoal (d+1 goals emerge) *)
    fun (tac1 THEN_ALL_NEW_LIST tacs) i = let
      val n_tacs = length tacs
    in
      tac1 i THEN_WITH_GOALDIFF (fn d =>
        if d + 1 = n_tacs then APPLY_LIST tacs i
        else (error "THEN_ALL_NEW_LIST: Tactic produced wrong number of goals"; no_tac))
    end

    (* Like THEN_ALL_NEW_LIST, but surplus emerging goals are handled by rtac *)
    fun (tac1 THEN_ALL_NEW_LIST' (tacs, rtac)) i =  
      tac1 i THEN_WITH_GOALDIFF (fn d => let
        val n_goals = d + 1
        val n_tacs = length tacs
        val _ = n_goals >= n_tacs
          orelse error "THEN_ALL_NEW_LIST': Tactic produced too few goals"
        val padding = replicate (n_goals - n_tacs) rtac
      in
        APPLY_LIST (tacs @ padding) i
      end)

  end


  open STactical

end

Theory Sepref_Monadify

section ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Sepref_Id_Op
begin


text ‹
  In this phase, a monadic program is converted to complete monadic form,
  that is, computations of compound expressions are made visible as top-level 
  operations in the monad.

  The monadify process is separated into two steps.
  \begin{enumerate}
    \item In a first step, eta-expansion is used to add missing operands 
      to operations and combinators. This way, operators and combinators
      always occur with the same arity, which simplifies further processing.

    \item In a second step, computation of compound operands is flattened,
      introducing new bindings for the intermediate values. 
  \end{enumerate}
›


(* Tags used by the monadify phase. Each is a [simp] definition
   (SP and RCALL are identities, EVAL unfolds to RETURN), so they only
   steer rewriting and do not change meaning. *)
definition SP ― ‹Tag to protect content from further application of arity 
  and combinator equations›
  where [simp]: "SP x  x"
(* Premise-free congruence rules: the simplifier does not rewrite below
   SP and PR_CONST. arity_tac and comb_tac also add them explicitly. *)
lemma SP_cong[cong]: "SP x  SP x" by simp
lemma PR_CONST_cong[cong]: "PR_CONST x  PR_CONST x" by simp

definition RCALL ― ‹Tag that marks recursive call›
  where [simp]: "RCALL D  D"
definition EVAL ― ‹Tag that marks evaluation of plain expression for monadify phase›
  where [simp]: "EVAL x  RETURN x"

text ‹
  Internally, the package first applies rewriting rules from 
  ‹sepref_monadify_arity›, which use eta-expansion to ensure that
  every combinator has enough actual parameters. Moreover, this phase will
  mark recursive calls by the tag @{const RCALL}.

  Next, rewriting rules from ‹sepref_monadify_comb› are used to
  add @{const EVAL}-tags to plain expressions that should be evaluated
  in the monad. The @{const EVAL} tags are flattened using a default simproc 
  that generates left-to-right argument order.
›

(* Cleanup rules for proving the result of the monadify simproc:
   a bind on RETURN is contracted, and a remaining EVAL becomes RETURN. *)
lemma monadify_simps: 
  "Refine_Basic.bind$(RETURN$x)$(λ2x. f x) = f x" 
  "EVAL$x  RETURN$x"
  by simp_all

(* PASS is RETURN marking that a parameter is handed on (mark_params phase) *)
definition [simp]: "PASS  RETURN"
  ― ‹Pass on value, invalidating old one›

(* Used by remove_pass_tac: drop a bind on PASS, and a trailing bind that only PASSes *)
lemma remove_pass_simps:
  "Refine_Basic.bind$(PASS$x)$(λ2x. f x)  f x" 
  "Refine_Basic.bind$m$(λ2x. PASS$x)  m"
  by simp_all


(* COPY is the identity, marking that a parameter needs to be copied.
   RET_COPY_PASS_eq justifies the dup phase replacing a PASS by RETURN of COPY. *)
definition COPY :: "'a  'a" 
  ― ‹Marks required copying of parameter›
  where [simp]: "COPY x  x"
lemma RET_COPY_PASS_eq: "RETURN$(COPY$p) = PASS$p" by simp


(* Rule collections read by arity_tac and comb_tac *)
named_theorems_rev sepref_monadify_arity "Sepref.Monadify: Arity alignment equations"
named_theorems_rev sepref_monadify_comb "Sepref.Monadify: Combinator equations"

ML structure Sepref_Monadify = struct
    (* Monadify phase. The phases run by monadify_tac are:
       arity (eta-expansion via sepref_monadify_arity),
       comb (flattening via sepref_monadify_comb and the monadify simproc),
       a check that no EVAL remains, marking of parameters with PASS,
       insertion of COPY for duplicated parameters, and removal of PASS. *)
    local
      (* Variable for the i-th operand: bound name v<i>, represented by
         the free variable __v<i> of type T *)
      fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))

      (* Wrap t as PROTECT2 t DUMMY, then abstract over n, a pair of
         (bound name, free variable) as produced by cr_var *)
      fun lambda2_name n t = let
        val t = @{mk_term "PROTECT2 ?t DUMMY"}
      in
        Term.lambda_name n t
      end


      (* bind_args exp0 [(x1,m1), ..., (xn,mn)] builds
           bind m1 (lambda x1. ... bind mn (lambda xn. exp0))
         so the first operand is bound outermost (left-to-right order). *)
      fun 
        bind_args exp0 [] = exp0
      | bind_args exp0 ((x,m)::xms) = let
          (* NOTE(review): incr_boundvars shifts loose bound variables before
             the new abstraction is built, presumably to avoid capturing them *)
          val lr = bind_args exp0 xms 
            |> incr_boundvars 1 
            |> lambda2_name x
        in @{mk_term "Refine_Basic.bind$?m$?lr"} end

      (* Monadify an application f$a0$...$a(n-1):
           bind (EVAL a0) (v0. ... bind (EVAL a(n-1)) (v(n-1). SP (RETURN (f v0 ... v(n-1)))))
         Operands are not monadified recursively here; the EVAL tags are
         handled by further simplifier passes. Raises TERM if f is an abstraction. *)
      fun monadify t = let
        val (f,args) = Autoref_Tagging.strip_app t
        val _ = not (is_Abs f) orelse 
          raise TERM ("monadify: higher-order",[t])

        val argTs = map fastype_of args
        (*val args = map monadify args*)
        val args = map (fn a => @{mk_term "EVAL$?a"}) args

        (*val fT = fastype_of f
        val argTs = binder_types fT*)
  
        val argVs = tag_list 0 argTs
          |> map cr_var

        val res0 = let
          val x = Autoref_Tagging.list_APP (f,map #2 argVs)
        in 
          @{mk_term "SP (RETURN$?x)"}
        end

        val res = bind_args res0 (argVs ~~ args)
      in
        res
      end

      (* Conversion on EVAL$t: rewrites to monadify t. The equation is checked
         by simp_tac with monadify_simps and SP_def (via Refine_Util.f_tac_conv).
         Raises TERM if the term is not of the form EVAL$_. *)
      fun monadify_conv_aux ctxt ct = case Thm.term_of ct of
        @{mpat "EVAL$_"} => let
          val ss = put_simpset HOL_basic_ss ctxt
          val ss = (ss addsimps @{thms monadify_simps SP_def})
          val tac = (simp_tac ss 1)
        in (*Refine_Util.monitor_conv "monadify"*) (
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct
        end
      | t => raise TERM ("monadify_conv",[t])

      (*fun extract_comb_conv ctxt = Conv.rewrs_conv 
        (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})
      *)  
    in
      (*
      val monadify_conv = Conv.top_conv 
        (fn ctxt => 
          Conv.try_conv (
            extract_comb_conv ctxt else_conv monadify_conv_aux ctxt
          )
        )
      *)  

      (* Simproc on EVAL$a. A failing conversion yields NONE (via try),
         i.e. no rewrite. *)
      val monadify_simproc = 
        Simplifier.make_simproc @{context} "monadify_simproc"
         {lhss =
          [Logic.varify_global @{term "EVAL$a"}],
          proc = K (try o monadify_conv_aux)};

    end

    local
      open Sepref_Basic
      (* In the abstract program a of an hn_refine term, replace RETURN$x by
         PASS$x if x is a bound variable or a parameter that appears in an
         hn_ctxt of the precondition P. A RETURN that is not replaced is
         returned unchanged, without descending into its argument. *)
      fun mark_params t = let
        val (P,c,Q,R,a) = dest_hn_refine t
        val pps = strip_star P |> map_filter (dest_hn_ctxt_opt #> map_option #2)

        fun tr env (t as @{mpat "RETURN$?x"}) = 
              if is_Bound x orelse member (aconv) pps x then
                @{mk_term env: "PASS$?x"}
              else t
          | tr env (t1$t2) = tr env t1 $ tr env t2
          | tr env (Abs (x,T,t)) = Abs (x,T,tr (T::env) t)
          | tr _ t = t

        val a = tr [] a
      in
        mk_hn_refine (P,c,Q,R,a)
      end

    in  
    (* Conversion applying mark_params, justified by unfolding PASS_def *)
    fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt 
      (mark_params) 
      (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms PASS_def}) 1)

    end  

    local

      open Sepref_Basic

      (* Duplicate detection along a chain of binds on PASS.
         Returns the rewritten term and the parameters passed in that chain.
         If parameter p is passed again later in the same chain (p in ps),
         this occurrence becomes bind (RETURN (COPY p)) instead of bind (PASS p).
         For applications and abstractions, recursion restarts with an empty list. *)
      fun dp ctxt (@{mpat "Refine_Basic.bind$(PASS$?p)$(?t' ASp (λ_. PROTECT2 _ DUMMY))"}) = 
          let
            val (t',ps) = let
                val ((t',rc),ctxt) = dest_lambda_rc ctxt t'
                val f = case t' of @{mpat "PROTECT2 ?f _"} => f | _ => raise Match 
                val (f,ps) = dp ctxt f
                val t' = @{mk_term "PROTECT2 ?f DUMMY"}
                val t' = rc t'
              in
                (t',ps)
              end
  
            val dup = member (aconv) ps p
            val t = if dup then
              @{mk_term "Refine_Basic.bind$(RETURN$(COPY$?p))$?t'"}
            else
              @{mk_term "Refine_Basic.bind$(PASS$?p)$?t'"}
          in
            (t,p::ps)
          end
        | dp ctxt (t1$t2) = (#1 (dp ctxt t1) $ #1 (dp ctxt t2),[])
        | dp ctxt (t as (Abs _)) = (apply_under_lambda (#1 oo dp) ctxt t,[])
        | dp _ t = (t,[])

      (* Conversion applying dp, justified by RET_COPY_PASS_eq *)
      fun dp_conv ctxt = Refine_Util.f_tac_conv ctxt 
        (#1 o dp ctxt) 
        (ALLGOALS (simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms RET_COPY_PASS_eq}))) 


    in
      (* Apply dp_conv via Sepref_Basic.hn_refine_concl_conv_a
         (presumably to the abstract program of the hn_refine conclusion) *)
      fun dup_tac ctxt = CONVERSION (Sepref_Basic.hn_refine_concl_conv_a dp_conv ctxt)
    end


    (* Arity phase: rewrite with the sepref_monadify_arity rules (with SP and
       PR_CONST protected by congruence rules), then unfold SP and beta-reduce. *)
    fun arity_tac ctxt = let
      val arity1_ss = put_simpset HOL_basic_ss ctxt 
        addsimps ((Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_arity}))
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val arity2_ss = put_simpset HOL_basic_ss ctxt 
        addsimps @{thms beta SP_def}
    in
      simp_tac arity1_ss THEN' simp_tac arity2_ss
    end

    (* Combinator phase: rewrite with the sepref_monadify_comb rules and the
       monadify simproc (SP and PR_CONST protected), then unfold SP. *)
    fun comb_tac ctxt = let
      val comb1_ss = put_simpset HOL_basic_ss ctxt 
        addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_comb})
        (*addsimps (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})*)
        addsimprocs [monadify_simproc]
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val comb2_ss = put_simpset HOL_basic_ss ctxt 
        addsimps @{thms SP_def}
    in
      simp_tac comb1_ss THEN' simp_tac comb2_ss
    end

    (*fun ops_tac ctxt = CONVERSION (
      Sepref_Basic.hn_refine_concl_conv_a monadify_conv ctxt)*)

    (* Apply mark_params_conv to the HOL conclusion of the subgoal *)
    fun mark_params_tac ctxt = CONVERSION (
      Refine_Util.HOL_concl_conv (K (mark_params_conv ctxt)) ctxt)

    (* True iff the abstract program of an hn_refine judgement still contains EVAL.
       Raises TERM for terms that are not Trueprop (hn_refine ...). *)
    fun contains_eval @{mpat "Trueprop (hn_refine _ _ _ _ ?a)"} =   
      Term.exists_subterm (fn @{mpat EVAL} => true | _ => false) a
    | contains_eval t = raise TERM("contains_eval",[t]);  

    (* Simplify with remove_pass_simps to eliminate PASS binds *)
    fun remove_pass_tac ctxt = 
      simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps})

    (* Complete monadify procedure, run as named phases by Sepref_Basic.PHASES'.
       dbg is passed to flag_phases_ctrl. NOTE(review): the third component
       of each phase is presumably the expected change in the number of subgoals. *)
    fun monadify_tac dbg ctxt = let
      open Sepref_Basic
    in
      PHASES' [
        ("arity", arity_tac, 0),
        ("comb", comb_tac, 0),
        (*("ops", ops_tac, 0),*)
        ("check_EVAL", K (CONCL_COND' (not o contains_eval)), 0),
        ("mark_params", mark_params_tac, 0),
        ("dup", dup_tac, 0),
        ("remove_pass", remove_pass_tac, 0)
      ] (flag_phases_ctrl dbg) ctxt
    end

  end

(* Default arity rules: eta-expand the standard combinators (RETURN, RECT,
   case_list, case_prod, case_option, If, Let) so they always occur fully
   applied. For RECT, the recursive calls are additionally tagged with RCALL.
   SP protects the expanded form from being rewritten again. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RETURN  λ2x. SP RETURN$x" 
  "RECT  λ2B x. SP RECT$(λ2D x. B$(λ2x. RCALL$D$x)$x)$x" 
  "case_list  λ2fn fc l. SP case_list$fn$(λ2x xs. fc$x$xs)$l" 
  "case_prod  λ2fp p. SP case_prod$(λ2a b. fp$a$b)$p" 
  "case_option  λ2fn fs ov. SP case_option$fn$(λ2x. fs$x)$ov" 
  "If  λ2b t e. SP If$b$t$e" 
  "Let  λ2x f. SP Let$x$(λ2x. f$x)"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)


(* Default combinator rules: the scrutinee / plain argument of each monadic
   combinator is evaluated first (EVAL) and bound, then the combinator is
   applied to the bound value (protected by SP). *)
lemma dflt_comb[sepref_monadify_comb]:
  "B x. RECT$B$x  Refine_Basic.bind$(EVAL$x)$(λ2x. SP (RECT$B$x))"
  "D x. RCALL$D$x  Refine_Basic.bind$(EVAL$x)$(λ2x. SP (RCALL$D$x))"
  "fn fc l. case_list$fn$fc$l  Refine_Basic.bind$(EVAL$l)$(λ2l. (SP case_list$fn$fc$l))"
  "fp p. case_prod$fp$p  Refine_Basic.bind$(EVAL$p)$(λ2p. (SP case_prod$fp$p))"
  "fn fs ov. case_option$fn$fs$ov 
     Refine_Basic.bind$(EVAL$ov)$(λ2ov. (SP case_option$fn$fs$ov))"
  "b t e. If$b$t$e  Refine_Basic.bind$(EVAL$b)$(λ2b. (SP If$b$t$e))"
  "x. RETURN$x  Refine_Basic.bind$(EVAL$x)$(λ2x. SP (RETURN$x))"
  "x f. Let$x$f  Refine_Basic.bind$(EVAL$x)$(λ2x. (SP Let$x$f))"
  by (simp_all)


(* Evaluation of plain (non-monadic) combinators: evaluate the scrutinee
   first, then push EVAL into the branches / continuation. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e)  Refine_Basic.bind$(EVAL$b)$(λ2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ2x xs. fc x xs)$l)  
    Refine_Basic.bind$(EVAL$l)$(λ2l. case_list$(EVAL$fn)$(λ2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ2a b. fp a b)$p)  
    Refine_Basic.bind$(EVAL$p)$(λ2p. case_prod$(λ2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ2x. fs x)$ov)  
    Refine_Basic.bind$(EVAL$ov)$(λ2ov. case_option$(EVAL$fn)$(λ2x. EVAL$(fs x))$ov)"
  "EVAL $ (Let $ v $ (λ2x. f x))  (⤜) $ (EVAL $ v) $ (λ2x. EVAL $ (f x))"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

(* A protected constant is returned as is, without decomposing it *)
lemma evalcomb_PR_CONST[sepref_monadify_comb]:
  "EVAL$(PR_CONST x)  SP (RETURN$(PR_CONST x))"
  by simp


end

Theory Sepref_Constraints

theory Sepref_Constraints
imports Main Automatic_Refinement.Refine_Lib Sepref_Basic
begin

(* Constraint slot: a dedicated subgoal, wrapped in CONSTRAINT_SLOT, that
   collects deferred constraints as a meta-conjunction. An empty slot
   contains Trueprop True. *)
definition "CONSTRAINT_SLOT (x::prop)  x"

(* TODO: Find something better than True to put in empty slot! Perhaps "A⟹A" *)
(* Insert goal P into an empty slot (used by to_slot_tac) *)
lemma insert_slot_rl1:
  assumes "PROP P  PROP (CONSTRAINT_SLOT (Trueprop True))  PROP Q"
  shows "PROP (CONSTRAINT_SLOT (PROP P))  PROP Q"
  using assms unfolding CONSTRAINT_SLOT_def by simp

(* Insert goal P into a non-empty slot S, giving S &&& P (used by to_slot_tac) *)
lemma insert_slot_rl2:
  assumes "PROP P  PROP (CONSTRAINT_SLOT S)  PROP Q"
  shows "PROP (CONSTRAINT_SLOT (PROP S &&& PROP P))  PROP Q"
  using assms unfolding CONSTRAINT_SLOT_def conjunction_def .

(* Discharges an empty slot (used by remove_slot_tac) *)
lemma remove_slot: "PROP (CONSTRAINT_SLOT (Trueprop True))"
  unfolding CONSTRAINT_SLOT_def by (rule TrueI)

(* Tag marking a subgoal P x as a constraint for the constraint solver *)
definition CONSTRAINT where [simp]: "CONSTRAINT P x  P x"

(* Remove the CONSTRAINT tag from an assumption *)
lemma CONSTRAINT_D:
  assumes "CONSTRAINT (P::'a => bool) x"
  shows "P x"
  using assms unfolding CONSTRAINT_def by simp

(* Prove a CONSTRAINT-tagged goal from the untagged one *)
lemma CONSTRAINT_I:
  assumes "P x"
  shows "CONSTRAINT (P::'a => bool) x"
  using assms unfolding CONSTRAINT_def by simp

text ‹Special predicate to indicate an unsolvable constraint.
  The constraint solver refuses to put such constraints into the slot.
  Thus, safe rules that introduce this predicate can be used to indicate
  unsolvable constraints early.
›
(* CN_FALSE P x is always False; constraints of this form are never slotted *)
definition CN_FALSE :: "('abool)  'a  bool" where [simp]: "CN_FALSE P x  False"  
lemma CN_FALSEI: "CN_FALSE P x  P x" by simp


(* Simplification rules applied to constraints by the solver *)
named_theorems constraint_simps ‹Simplification of constraints›

(* Abbreviations unfolded by the solver, both in goals and in registered rules *)
named_theorems constraint_abbrevs ‹Constraint Solver: Abbreviations›
(* Rules used to split conjunctions when normalizing registered constraint rules *)
lemmas split_constraint_rls 
    = atomize_conj[symmetric] imp_conjunction all_conjunction conjunction_imp

ML signature SEPREF_CONSTRAINTS = sig
    (******** Constraint Slot *)
    (* Tactic with slot subgoal (fails with a warning if there is no slot) *)
    val WITH_SLOT: tactic' -> tactic
    (* Process all goals in slot *)
    val ON_SLOT: tactic -> tactic
    (* Create slot as last subgoal. Fail if slot already present. *)
    val create_slot_tac: tactic
    (* Create slot if there isn't one already *)
    val ensure_slot_tac: tactic
    (* Remove empty slot *)
    val remove_slot_tac: tactic
    (* Move slot to first subgoal *)
    val prefer_slot_tac: tactic
    (* Destruct slot *)
    val dest_slot_tac: tactic'
    (* Check if goal state has slot *)
    val has_slot: thm -> bool
    (* Defer subgoal to slot (only subgoals before the slot) *)
    val to_slot_tac: tactic'
    (* Print slot constraints *)
    val print_slot_tac: Proof.context -> tactic

    (* Focus on goals in slot *)
    val focus: tactic
    (* Unfocus goals in slot *)
    val unfocus: tactic
    (* Unfocus goals, and insert them as first subgoals *)
    val unfocus_ins:tactic

    (* Focus on some goals in slot *)
    val cond_focus: (term -> bool) -> tactic
    (* Move some goals to slot *)
    val some_to_slot_tac: (term -> bool) -> tactic


    (******** Constraints *)
    (* Check if subgoal is a constraint. To be used with COND' *)
    val is_constraint_goal: term -> bool
    (* Identity on constraint subgoal, no_tac otherwise *)
    val is_constraint_tac: tactic'
    (* Defer constraint to slot *)
    val slot_constraint_tac: int -> tactic

    (******** Constraint solving *)

    (* Unsafe constraint rules: may be applied with backtracking.
       Rules without premises are rejected and must be registered as safe. *)
    val add_constraint_rule: thm -> Context.generic -> Context.generic
    val del_constraint_rule: thm -> Context.generic -> Context.generic
    val get_constraint_rules: Proof.context -> thm list

    (* Safe constraint rules: applied deterministically *)
    val add_safe_constraint_rule: thm -> Context.generic -> Context.generic
    val del_safe_constraint_rule: thm -> Context.generic -> Context.generic
    val get_safe_constraint_rules: Proof.context -> thm list

    (* Solve constraint subgoal *)
    val solve_constraint_tac: Proof.context -> tactic'
    (* Solve constraint subgoal if solvable, fail if definitely unsolvable, 
      apply simplification and unique rules otherwise. *)
    val safe_constraint_tac: Proof.context -> tactic'

    (* CONSTRAINT tag on goal is optional *)
    val solve_constraint'_tac: Proof.context -> tactic'
    (* CONSTRAINT tag on goal is optional *)
    val safe_constraint'_tac: Proof.context -> tactic'
    
    (* Solve, or apply safe-rules and defer to constraint slot *)
    val constraint_tac: Proof.context -> tactic'

    (* Apply safe rules to all constraint goals in slot *)
    val process_constraint_slot: Proof.context -> tactic

    (* Solve all constraint goals in slot, insert unsolved ones as first subgoals *)
    val solve_constraint_slot: Proof.context -> tactic


    (* Registers the constraint rule attributes *)
    val setup: theory -> theory

  end


  structure Sepref_Constraints: SEPREF_CONSTRAINTS  = struct
    (* A subgoal is the slot iff it is literally CONSTRAINT_SLOT _ *)
    fun is_slot_goal @{mpat "CONSTRAINT_SLOT _"} = true | is_slot_goal _ = false

    (* 1-based index of the first slot subgoal, or 0 if there is none
       (find_index returns -1 when nothing is found) *)
    fun slot_goal_num st = let
      val i = find_index is_slot_goal (Thm.prems_of st) + 1
    in
      i
    end

    fun has_slot st = slot_goal_num st > 0

    (* Run tac on the slot's subgoal index; warn and fail if there is no slot *)
    fun WITH_SLOT tac st = let
      val si = slot_goal_num st
    in
      if si>0 then tac si st else (warning "Constraints: No slot"; Seq.empty)
    end

    (* Move subgoal i into the slot. Only subgoals before the slot (i < si)
       are supported; otherwise the tactic fails.
       prefer_tac si moves the slot to position 1, which shifts goal i to i+1;
       prefer_tac (i+1) then moves goal i to position 1, with the slot at 2.
       The goal is merged into the slot with insert_slot_rl1 (empty slot) or,
       failing that, insert_slot_rl2, and the slot is deferred to the end. *)
    val to_slot_tac = IF_EXGOAL (fn i => WITH_SLOT (fn si => 
      if i<si then
        prefer_tac si THEN prefer_tac (i+1)
        THEN (
          PRIMITIVE (fn st => Drule.comp_no_flatten (st, 0) 1 @{thm insert_slot_rl1}) 
          ORELSE PRIMITIVE (fn st => Drule.comp_no_flatten (st, 0) 1 @{thm insert_slot_rl2})
        )
        THEN defer_tac 1
      else no_tac))

    (* Add an empty slot as new premise (goal 1) and defer it to the last
       position; fails if a slot exists already *)
    val create_slot_tac = 
      COND (has_slot) no_tac
        (PRIMITIVE (Thm.implies_intr @{cterm "CONSTRAINT_SLOT (Trueprop True)"}) 
        THEN defer_tac 1)
        
    val ensure_slot_tac = TRY create_slot_tac
          
      
    val prefer_slot_tac = WITH_SLOT prefer_tac

    (* Unfold the slot, split its meta-conjunction into separate goals,
       discharge the True placeholder, and remove duplicate goals *)
    val dest_slot_tac = SELECT_GOAL (
      ALLGOALS (
        CONVERSION (Conv.rewr_conv @{thm CONSTRAINT_SLOT_def}) 
        THEN' Goal.conjunction_tac
        THEN' TRY o resolve0_tac @{thms TrueI})
      THEN distinct_subgoals_tac
    )

    val remove_slot_tac = WITH_SLOT (resolve0_tac @{thms remove_slot})

    (* Restrict the goal state to the slot, split it into its constraints,
       and create a fresh empty slot in the restricted state *)
    val focus = WITH_SLOT (fn i => 
      PRIMITIVE (Goal.restrict i 1) 
      THEN ALLGOALS dest_slot_tac
      THEN create_slot_tac)

    (* Undo the restriction of focus; the slot is then deferred to the end,
       so the remaining focused goals come first *)
    val unfocus_ins = 
      PRIMITIVE (Goal.unrestrict 1)
      THEN WITH_SLOT defer_tac

    (* Move every non-slot goal that satisfies cond into the slot *)
    fun some_to_slot_tac cond = (ALLGOALS (COND' (fn t => is_slot_goal t orelse not (cond t)) ORELSE' to_slot_tac))

    val unfocus = 
      some_to_slot_tac (K true)
      THEN unfocus_ins

    (* Focus, then move goals not satisfying cond back into the slot *)
    fun cond_focus cond =
      focus 
      THEN some_to_slot_tac (not o cond)


    fun ON_SLOT tac = focus THEN tac THEN unfocus

    fun print_slot_tac ctxt = ON_SLOT (print_tac ctxt "SLOT:")

    local
      (*fun prepare_constraint_conv ctxt = let
        open Conv 
        fun CONSTRAINT_conv ct = case Thm.term_of ct of
          @{mpat "Trueprop (_ _)"} => 
            HOLogic.Trueprop_conv 
              (rewr_conv @{thm CONSTRAINT_def[symmetric]}) ct
          | _ => raise CTERM ("CONSTRAINT_conv", [ct])

        fun rec_conv ctxt ct = (
          CONSTRAINT_conv
          else_conv 
          implies_conv (rec_conv ctxt) (rec_conv ctxt)
          else_conv
          forall_conv (rec_conv o #2) ctxt
        ) ct
      in
        rec_conv ctxt
      end*)

      (* Normalize a registered rule: unfold CONSTRAINT, abbreviations and
         constraint simps, split conjunctions, and return the resulting
         list of rules *)
      fun unfold_abbrevs ctxt = 
        Local_Defs.unfold0 ctxt (
          @{thms split_constraint_rls CONSTRAINT_def} 
          @ Named_Theorems.get ctxt @{named_theorems constraint_abbrevs}
          @ Named_Theorems.get ctxt @{named_theorems constraint_simps})
        #> Conjunction.elim_conjunctions
  
      (* Check that every premise and the conclusion have the form
         Trueprop (C _) (possibly under meta-quantifiers and implications)
         with a non-schematic head; raises TERM otherwise *)
      fun check_constraint_rl thm = let
        fun ck (t as @{mpat "Trueprop (?C _)"}) = 
              if is_Var (Term.head_of C) then
                raise TERM ("Schematic head in constraint rule",[t,Thm.prop_of thm])
              else ()
          | ck @{mpat "_. PROP ?t"} = ck t
          | ck @{mpat "PROP ?s  PROP ?t"} = (ck s; ck t)
          | ck t = raise TERM ("Invalid part of constraint rule",[t,Thm.prop_of thm])
  
      in
        ck (Thm.prop_of thm); thm
      end

      (* Unsafe rules must have premises; raises TERM for unconditional rules *)
      fun check_unsafe_constraint_rl thm = let
        val _ = Thm.nprems_of thm = 0 
          andalso raise TERM("Unconditional constraint rule must be safe (register this as safe rule)",[Thm.prop_of thm])
      in
        thm
      end

    in
      (* Attribute-managed collections of constraint rules; registered rules
         are normalized and checked by the transform functions *)
      structure constraint_rules = Named_Sorted_Thms (
        val name = @{binding constraint_rules}
        val description = "Constraint rules"
        val sort = K I
        fun transform context = let
          open Conv
          val ctxt = Context.proof_of context
        in
          unfold_abbrevs ctxt #> map (check_constraint_rl o check_unsafe_constraint_rl)
        end
      )

      structure safe_constraint_rules = Named_Sorted_Thms (
        val name = @{binding safe_constraint_rules}
        val description = "Safe Constraint rules"
        val sort = K I
        fun transform context = let
          open Conv
          val ctxt = Context.proof_of context
        in
          unfold_abbrevs ctxt #> map check_constraint_rl
        end
      )

    end  

    val add_constraint_rule = constraint_rules.add_thm
    val del_constraint_rule = constraint_rules.del_thm
    val get_constraint_rules = constraint_rules.get

    val add_safe_constraint_rule = safe_constraint_rules.add_thm
    val del_safe_constraint_rule = safe_constraint_rules.del_thm
    val get_safe_constraint_rules = safe_constraint_rules.get

    (* Constraint goal: conclusion (under any premises) is Trueprop (CONSTRAINT _ _) *)
    fun is_constraint_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false

    val is_constraint_tac = COND' is_constraint_goal

    (* Constraints that may go into the slot: all except CN_FALSE ones *)
    fun is_slottable_constraint_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT (CN_FALSE _) _)"} => false
    | @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false

    val slot_constraint_tac = COND' is_slottable_constraint_goal THEN' to_slot_tac

    (* Classification of a result sequence by its number of elements.
       The carried sequence is the complete original one. *)
    datatype 'a seq_cases = SC_NONE | SC_SINGLE of 'a Seq.seq | SC_MULTIPLE of 'a Seq.seq

    fun seq_cases seq = 
      case Seq.pull seq of
        NONE => SC_NONE
      | SOME (st1,seq) => case Seq.pull seq of
          NONE => SC_SINGLE (Seq.single st1)
        | SOME (st2,seq) => SC_MULTIPLE (Seq.cons st1 (Seq.cons st2 seq))  

    (* Run tac, then continue with single_tac or multiple_tac depending on
       whether it produced one or several results; fail on no result *)
    fun SEQ_CASES tac (single_tac, multiple_tac) st = let
      val res = tac st
    in
      case seq_cases res of
        SC_NONE => Seq.empty
      | SC_SINGLE res => Seq.maps single_tac res
      | SC_MULTIPLE res => Seq.maps multiple_tac res
    end

    (* Succeeds only if tac has exactly one result *)
    fun SAFE tac = SEQ_CASES tac (all_tac, no_tac)
    fun SAFE' tac = SAFE o tac

    local
      (* Simplify with constraint_simps only *)
      fun simp_constraints_tac ctxt = let
        val ctxt = put_simpset HOL_basic_ss ctxt 
          addsimps (Named_Theorems.get ctxt @{named_theorems constraint_simps})
      in
        simp_tac ctxt
      end

      (* Unfold constraint_abbrevs (also in premises), then split
         conjunctions in premises (conjE) and in the conclusion (conjI) *)
      fun unfold_abbrevs_tac ctxt =  let
        val ctxt = put_simpset HOL_basic_ss ctxt 
          addsimps (Named_Theorems.get ctxt @{named_theorems constraint_abbrevs})
        val ethms = @{thms conjE}  
        val ithms = @{thms conjI}  
      in
        full_simp_tac ctxt 
        THEN_ALL_NEW TRY o REPEAT_ALL_NEW (ematch_tac ctxt ethms)
        THEN_ALL_NEW TRY o REPEAT_ALL_NEW (match_tac ctxt ithms)
      end
  
      (* Provide tac with discrimination nets of the safe and unsafe rules *)
      fun WITH_RULE_NETS tac ctxt = let
        val scn_net = safe_constraint_rules.get ctxt |> Tactic.build_net
        val cn_net = constraint_rules.get ctxt |> Tactic.build_net
      in
        tac (scn_net,cn_net) ctxt
      end

      (* Repeatedly: simplify, unfold abbreviations, then one step_tac *)
      fun wrap_tac step_tac ctxt = REPEAT_ALL_NEW (
        simp_constraints_tac ctxt 
        THEN_ALL_NEW unfold_abbrevs_tac ctxt
        THEN_ALL_NEW step_tac ctxt
      )

      (* Safe rules deterministically; otherwise unsafe rules with backtracking *)
      fun solve_step_tac (scn_net,cn_net) ctxt = REPEAT_ALL_NEW (
        DETERM o resolve_from_net_tac ctxt scn_net
        ORELSE' resolve_from_net_tac ctxt cn_net
      )

      (* Like solve_step_tac, but unsafe rules only when they have a unique result *)
      fun safe_step_tac (scn_net,cn_net) ctxt = REPEAT_ALL_NEW (
        DETERM o resolve_from_net_tac ctxt scn_net
        ORELSE' SAFE' (resolve_from_net_tac ctxt cn_net)
      )

      (* Must solve the goal completely *)
      fun solve_tac cn_nets ctxt = SOLVED' (wrap_tac (solve_step_tac cn_nets) ctxt)
      (* Normalize, then either solve completely or apply safe steps *)
      fun safe_tac cn_nets ctxt =  
        simp_constraints_tac ctxt
        THEN_ALL_NEW unfold_abbrevs_tac ctxt
        THEN_ALL_NEW (solve_tac cn_nets ctxt ORELSE' TRY o wrap_tac (safe_step_tac cn_nets) ctxt)

    in
      val solve_constraint_tac = TRADE (fn ctxt =>
        is_constraint_tac
        THEN' resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS solve_tac ctxt)

      (* Remaining goals are tagged with CONSTRAINT again via CONSTRAINT_D *)
      val safe_constraint_tac = TRADE (fn ctxt =>
        is_constraint_tac
        THEN' resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS safe_tac ctxt
        THEN_ALL_NEW fo_resolve_tac @{thms CONSTRAINT_D} ctxt) (* TODO/FIXME: fo_resolve_tac has non-canonical parameter order *)

      val solve_constraint'_tac = TRADE (fn ctxt =>
        TRY o resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS solve_tac ctxt)

      val safe_constraint'_tac = TRADE (fn ctxt =>
        TRY o resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS safe_tac ctxt)


    end  

    (* Safe steps, then defer every remaining slottable constraint to the slot *)
    fun constraint_tac ctxt = 
      safe_constraint_tac ctxt THEN_ALL_NEW slot_constraint_tac

    fun process_constraint_slot ctxt = ON_SLOT (ALLGOALS (TRY o safe_constraint_tac ctxt))

    (* Focus on constraint goals in the slot; solve each one or apply safe
       steps; the leftovers become the first subgoals *)
    fun solve_constraint_slot ctxt = 
      cond_focus is_constraint_goal 
        THEN ALLGOALS (
          COND' is_slot_goal
          ORELSE' (
            solve_constraint_tac ctxt
            ORELSE' TRY o safe_constraint_tac ctxt
          )
        )
      THEN unfocus_ins


    val setup = I
      #> constraint_rules.setup
      #> safe_constraint_rules.setup

  end

(* Register the constraint_rules and safe_constraint_rules attributes *)
setup Sepref_Constraints.setup

(* Proof method printing the constraints currently in the slot *)
method_setup print_slot = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD (Sepref_Constraints.print_slot_tac ctxt))

(* Proof methods for one subgoal; the CONSTRAINT tag on the goal is optional *)
method_setup solve_constraint = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Constraints.solve_constraint'_tac ctxt))
method_setup safe_constraint = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Constraints.safe_constraint'_tac ctxt))


end

Theory Sepref_Frame

section ‹Frame Inference›
theory Sepref_Frame
imports Sepref_Basic Sepref_Constraints
begin
  text ‹ In this theory, we provide specific frame inference tactics
    for Sepref.

    The first tactic, ‹frame_tac›, is a standard frame inference tactic, 
    based on the assumption that only @{const hn_ctxt}-assertions need to be
    matched.

    The second tactic, ‹merge_tac›, resolves entailments of the form
      ‹F1 ∨⇩A F2 ⟹⇩t ?F›
    that occur during translation of if and case statements.
    It synthesizes a new frame ?F, where refinements of variables 
    with equal refinements in ‹F1› and ‹F2› are preserved,
    and the others are set to @{const hn_invalid}.
    ›

(* Combination (assertion disjunction) of two refinement relations for the
   same pair of values; records that the refinement differs between two
   merged branches (see hn_merge_mismatch) *)
definition mismatch_assn :: "('a  'c  assn)  ('a  'c  assn)  'a  'c  assn"
  where "mismatch_assn R1 R2 x y  R1 x y A R2 x y"

abbreviation "hn_mismatch R1 R2  hn_ctxt (mismatch_assn R1 R2)"

(* An invalidated refinement can be recovered if the relation is pure *)
lemma recover_pure_aux: "CONSTRAINT is_pure R  hn_invalid R x y t hn_ctxt R x y"
  by (auto simp: is_pure_conv invalid_pure_recover hn_ctxt_def)



(* Basic frame rules, in order:
   reflexivity; monotonicity of star;
   a context may be invalidated; a context may be weakened to (λ_ _. true);
   an invalidated pure context may be recovered *)
lemma frame_thms:
  "P t P"
  "PtP'  FtF'  F*P t F'*P'"
  "hn_ctxt R x y t hn_invalid R x y"
  "hn_ctxt R x y t hn_ctxt (λ_ _. true) x y"
  "CONSTRAINT is_pure R  hn_invalid R x y t hn_ctxt R x y"
  apply -
  applyS simp
  applyS (rule entt_star_mono; assumption)
  subgoal
    apply (simp add: hn_ctxt_def)
    apply (rule enttI)
    apply (rule ent_trans[OF invalidate[of R]])
    by solve_entails
  applyS (sep_auto simp: hn_ctxt_def)  
  applyS (erule recover_pure_aux)
  done

(* User-extensible collection of additional frame matching rules *)
named_theorems_rev sepref_frame_match_rules ‹Sepref: Additional frame rules›

text ‹Rules to discharge unmatched stuff›
(*lemma frame_rem_thms:
  "P ⟹t P"
  "P ⟹t emp"
  by sep_auto+
*)
(* Reflexivity *)
lemma frame_rem1: "PtP" by simp

(* Keep an hn_ctxt that occurs on both sides *)
lemma frame_rem2: "F t F'  F * hn_ctxt A x y t F' * hn_ctxt A x y"
  apply (rule entt_star_mono) by auto

(* Drop an unmatched hn_ctxt on the left *)
lemma frame_rem3: "F t F'  F * hn_ctxt A x y t F'"
  using frame_thms(2) by fastforce
  
(* Everything tail-entails emp *)
lemma frame_rem4: "P t emp" by simp

lemmas frame_rem_thms = frame_rem1 frame_rem2 frame_rem3 frame_rem4

(* User-extensible collection of additional remainder rules *)
named_theorems_rev sepref_frame_rem_rules
  ‹Sepref: Additional rules to resolve remainder of frame-pairing›

(* Componentwise entailment of a disjunction of star products:
   if A or C entails E, and B or D entails F, then A*B or C*D entails E*F *)
lemma ent_disj_star_mono:
  " A A C A E; B A D A F   A*B A C*D A E*F"
  by (metis ent_disjI1 ent_disjI2 ent_disjE ent_star_mono)  

(* The same for tail-entailment; basis of the pointwise merge rule hn_merge1 *)
lemma entt_disj_star_mono:
  " A A C t E; B A D t F   A*B A C*D t E*F"
proof -
  assume a1: "A A C t E"
  assume "B A D t F"
  then have "A * B A C * D A true * E * (true * F)"
    using a1 by (simp add: ent_disj_star_mono enttD)
  then show ?thesis
    by (metis (no_types) assn_times_comm enttI merge_true_star_ctx star_aci(3))
qed
    


(* Merge rules for the two branches of a case distinction.
   First rule: identical frames merge to themselves.
   Second rule: the last hn_ctxt entries of both sides (same variable pair)
   are merged into hn_ctxt R, and the remaining frames Fl and Fr are merged
   recursively (proved with entt_disj_star_mono). *)
lemma hn_merge1:
  (*"emp ∨A emp ⟹A emp"*)
  "F A F t F"
  " hn_ctxt R1 x x' A hn_ctxt R2 x x' t hn_ctxt R x x'; Fl A Fr t F  
     Fl * hn_ctxt R1 x x' A Fr * hn_ctxt R2 x x' t F * hn_ctxt R x x'"
  apply simp
  by (rule entt_disj_star_mono; simp)

(* If either side has invalidated the assertion, the merge result is
   hn_invalid. *)
lemma hn_merge2:
  "hn_invalid R x x' A hn_ctxt R x x' t hn_invalid R x x'"
  "hn_ctxt R x x' A hn_invalid R x x' t hn_invalid R x x'"
  by (sep_auto eintros: invalidate ent_disjE intro!: ent_imp_entt simp: hn_ctxt_def)+

(* hn_invalid is monotone with respect to entailst between the hn_ctxt
   assertions. *)
lemma invalid_assn_mono: "hn_ctxt A x y t hn_ctxt B x y 
   hn_invalid A x y t hn_invalid B x y"
  by (clarsimp simp: invalid_assn_def entailst_def entails_def hn_ctxt_def)
      (force simp: mod_star_conv)

(* Variant of hn_merge2 for differing assertions. The NO_MATCH premises
   exclude the case where the other side is already hn_invalid. *)
lemma hn_merge3: (* Not used *)
  "NO_MATCH (hn_invalid XX) R2; hn_ctxt R1 x x' A hn_ctxt R2 x x' t hn_ctxt Rm x x'  hn_invalid R1 x x' A hn_ctxt R2 x x' t hn_invalid Rm x x'"
  "NO_MATCH (hn_invalid XX) R1; hn_ctxt R1 x x' A hn_ctxt R2 x x' t hn_ctxt Rm x x'  hn_ctxt R1 x x' A hn_invalid R2 x x' t hn_invalid Rm x x'"
  apply (meson entt_disjD1 entt_disjD2 entt_disjE entt_trans frame_thms(3) invalid_assn_mono)  
  apply (meson entt_disjD1 entt_disjD2 entt_disjE entt_trans frame_thms(3) invalid_assn_mono)  
  done

(* Built-in merge rules; Sepref_Frame.frame_step_tac extends them with
   sepref_frame_merge_rules. *)
lemmas merge_thms = hn_merge1 hn_merge2 

(* Extra merge rules. Unlike the other rule sets this is a plain
   named_theorems; Sepref_Frame.frame_step_tac reads it with
   Named_Theorems.get and appends it to merge_thms. *)
named_theorems sepref_frame_merge_rules ‹Sepref: Additional merge rules›


(* Merges two different refinement assertions on the same variable pair
   into hn_mismatch R1 R2. It is not part of merge_thms. *)
lemma hn_merge_mismatch: "hn_ctxt R1 x x' A hn_ctxt R2 x x' t hn_mismatch R1 R2 x x'"
  by (sep_auto simp: hn_ctxt_def mismatch_assn_def)

(* Trivial rule, proved by the '.' step: its conclusion repeats its
   premise, which has the shape of a merge obligation. *)
lemma is_merge: "P1AP2tP  P1AP2tP" .

(* Merge after weakening each disjunct separately. *)
lemma merge_mono: "AtA'; BtB'; A'AB' t C  AAB t C"
  by (meson entt_disjE entt_disjI1_direct entt_disjI2_direct entt_trans)
  
text ‹Apply forward rule on left or right side of merge›
(* gen_merge_cons1 weakens the left disjunct and gen_merge_cons2 the right
   disjunct before merging; both are instances of merge_mono with
   entt_refl on the other side. *)
lemma gen_merge_cons1: "AtA'; A'AB t C  AAB t C"
  by (meson merge_mono entt_refl)

lemma gen_merge_cons2: "BtB'; AAB' t C  AAB t C"
  by (meson merge_mono entt_refl)
  
lemmas gen_merge_cons = gen_merge_cons1 gen_merge_cons2


text ‹These rules are applied to recover pure values that have been destroyed by rule application›

(* Tag for recovery goals: RECOVER_PURE P Q is the entailst from P to Q.
   Sepref_Frame.recover_pure_tac solves goals of this form. *)
definition "RECOVER_PURE P Q  P t Q"

(* Structural rules used by recover_pure_tac:
   - emp recovers to emp;
   - separating conjunction, with premises ordered P2 before P1;
   - an hn_invalid is restored to hn_ctxt when R is pure (side goal
     CONSTRAINT is_pure R);
   - any hn_ctxt is kept unchanged. *)
lemma recover_pure:
  "RECOVER_PURE emp emp"
  "RECOVER_PURE P2 Q2; RECOVER_PURE P1 Q1  RECOVER_PURE (P1*P2) (Q1*Q2)"
  "CONSTRAINT is_pure R  RECOVER_PURE (hn_invalid R x y) (hn_ctxt R x y)"
  "RECOVER_PURE (hn_ctxt R x y) (hn_ctxt R x y)"
  unfolding RECOVER_PURE_def
  subgoal by sep_auto
  subgoal by (drule (1) entt_star_mono)
  subgoal by (rule recover_pure_aux)
  subgoal by sep_auto
  done
  
(* Identity recovery; recover_pure_tac uses it when the first argument
   contains no hn_invalid. *)
lemma recover_pure_triv: 
  "RECOVER_PURE P P"
  unfolding RECOVER_PURE_def by sep_auto


text ‹Weakening the postcondition by converting @{const invalid_assn} to @{term "λ_ _. true"}›
(* WEAKEN_HNR_POST Γ P P': if some heap satisfies Γ, then the third argument
   entails the second (entailst). weaken_hnr_postI uses this to replace a
   postcondition P of an hn_refine goal by P'. *)
definition "WEAKEN_HNR_POST Γ Γ' Γ''  (h. hΓ)  (Γ'' t Γ')"

(* The goal's postcondition Γ'' is replaced by Γ', given
   WEAKEN_HNR_POST Γ Γ'' Γ'. *)
lemma weaken_hnr_postI:
  assumes "WEAKEN_HNR_POST Γ Γ'' Γ'"
  assumes "hn_refine Γ c Γ' R a"
  shows "hn_refine Γ c Γ'' R a"
  apply (rule hn_refine_preI)
  apply (rule hn_refine_cons_post)
  apply (rule assms)
  using assms(1) unfolding WEAKEN_HNR_POST_def by blast

(* Leaves a postcondition unchanged. *)
lemma weaken_hnr_post_triv: "WEAKEN_HNR_POST Γ P P"
  unfolding WEAKEN_HNR_POST_def
  by sep_auto

(* Rules applied by Sepref_Frame.weaken_post_tac:
   - decompose along separating conjunction;
   - keep an hn_ctxt unchanged;
   - replace hn_invalid R x y by hn_ctxt (λ_ _. true) x y. *)
lemma weaken_hnr_post:
  "WEAKEN_HNR_POST Γ P P'; WEAKEN_HNR_POST Γ' Q Q'  WEAKEN_HNR_POST (Γ*Γ') (P*Q) (P'*Q')"
  "WEAKEN_HNR_POST (hn_ctxt R x y) (hn_ctxt R x y) (hn_ctxt R x y)"
  "WEAKEN_HNR_POST (hn_ctxt R x y) (hn_invalid R x y) (hn_ctxt (λ_ _. true) x y)"
proof (goal_cases)
  case 1 thus ?case
    unfolding WEAKEN_HNR_POST_def
    apply clarsimp
    apply (rule entt_star_mono) 
    by (auto simp: mod_star_conv)
next
  case 2 thus ?case by (rule weaken_hnr_post_triv)
next
  case 3 thus ?case 
    unfolding WEAKEN_HNR_POST_def 
    by (sep_auto simp: invalid_assn_def hn_ctxt_def)
qed



(* Reduces an equation between two entailst statements to equalities of
   their sides, each starred with true. Sepref_Frame.prepare_fi_conv uses it
   to prove the reordered frame goal. *)
lemma reorder_enttI:
  assumes "A*true = C*true"
  assumes "B*true = D*true"
  shows "(AtB)  (CtD)"
  apply (intro eq_reflection)
  unfolding entt_def_true
  by (simp add: assms)
  
  

(* If A and A' merge to Am, then A also merges with Am to Am (merge_sat1),
   and Am merges with A' to Am (merge_sat2). *)
lemma merge_sat1: "(AAA' t Am)  (AAAm t Am)"
  using entt_disjD1 entt_disjE by blast
lemma merge_sat2: "(AAA' t Am)  (AmAA' t Am)"
  using entt_disjD2 entt_disjE by blast





ML ‹signature SEPREF_FRAME = sig


  (* Check if subgoal is a frame obligation *)
  (*val is_frame : term -> bool *)
  (* Check if subgoal is a merge obligation *)
  val is_merge: term -> bool
  (* Perform frame inference *)
  val frame_tac: (Proof.context -> tactic') -> Proof.context -> tactic'
  (* Perform merging *)
  val merge_tac: (Proof.context -> tactic') -> Proof.context -> tactic'

  (* Perform a single frame or merge step. The bool argument is a debug flag
     selecting TRY_SOLVED' instead of SOLVED' for side conditions. *)
  val frame_step_tac: (Proof.context -> tactic') -> bool -> Proof.context -> tactic'

  (* Reorder frame *)
  val prepare_frame_tac : Proof.context -> tactic'
  (* Solve a RECOVER_PURE goal, inserting constraints as necessary *)
  val recover_pure_tac: Proof.context -> tactic'

  (* Split precondition of hnr-goal into frame and arguments *)
  val align_goal_tac: Proof.context -> tactic'
  (* Normalize goal's precondition *)
  val norm_goal_pre_tac: Proof.context -> tactic'
  (* Rearrange precondition of hnr-term according to parameter order, normalize all relations *)
  val align_rl_conv: Proof.context -> conv

  (* Convert hn_invalid to λ_ _. true in postcondition of hnr-goal. Makes proving the goal easier.*)
  val weaken_post_tac: Proof.context -> tactic'

  (* Add, delete and read the sepref_frame_normrel_eqs rules: equations used
     to normalize relations before frame matching. *)
  val add_normrel_eq : thm -> Context.generic -> Context.generic
  val del_normrel_eq : thm -> Context.generic -> Context.generic
  val get_normrel_eqs : Proof.context -> thm list

  (* Configuration option sepref_debug_frame (default false). *)
  val cfg_debug: bool Config.T

  (* Theory setup of the sepref_frame_normrel_eqs rule set. *)
  val setup: theory -> theory
end


structure Sepref_Frame : SEPREF_FRAME = struct

  val cfg_debug = 
    Attrib.setup_config_bool @{binding sepref_debug_frame} (K false)

  (* Debug-aware conversion and message tactics, controlled by cfg_debug. *)
  val DCONVERSION = Sepref_Debugging.DBG_CONVERSION cfg_debug
  val dbg_msg_tac = Sepref_Debugging.dbg_msg_tac cfg_debug


  structure normrel_eqs = Named_Thms (
    val name = @{binding sepref_frame_normrel_eqs}
    val description = "Equations to normalize relations for frame matching"
  )

  val add_normrel_eq = normrel_eqs.add_thm
  val del_normrel_eq = normrel_eqs.del_thm
  val get_normrel_eqs = normrel_eqs.get

  val mk_entailst = HOLogic.mk_binrel @{const_name "entailst"}


  local
    open Sepref_Basic Refine_Util Conv
  
    (* Order on assertions. hn_ctxt entries come before all other entries.
       Two hn_ctxt entries are compared by their abstract (second) argument
       with Term_Ord.fast_term_ord. All other entries compare EQUAL. *)
    fun assn_ord p = case apply2 dest_hn_ctxt_opt p of
        (NONE,NONE) => EQUAL
      | (SOME _, NONE) => LESS
      | (NONE, SOME _) => GREATER
      | (SOME (_,a,_), SOME (_,a',_)) => Term_Ord.fast_term_ord (a,a')

  in
    (* Sort the separating conjuncts of ct by assn_ord. The resulting
       equation is proved with star_aci. *)
    fun reorder_ctxt_conv ctxt ct = let
      val cert = Thm.cterm_of ctxt

      val new_ct = Thm.term_of ct 
        |> strip_star
        |> sort assn_ord
        |> list_star
        |> cert

      val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct)) 
        (fn _ => simp_tac 
          (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) 1)

    in
      thm
    end
  
    (* Conversion for an entailst goal P ⟹⇩t Q.
       Each hn_ctxt in P is paired with the hn_ctxt in Q for the same
       abstract variable. The goal is rewritten to P' ⟹⇩t Q', where:
       - P' lists the matched entries of P first, then the unmatched ones;
       - Q' lists the partners in the same order, then the unmatched
         entries of Q;
       - true conjuncts are dropped from both sides.
       The equation is proved with reorder_enttI and star_permute_tac.
       If Q has two hn_ctxt entries for the same abstract variable,
       Termtab.DUP is re-raised after a tracing message. Any other term
       gives no_conv. *)
    fun prepare_fi_conv ctxt ct = case Thm.term_of ct of
      @{mpat "?P ⟹⇩t ?Q"} => let
        val cert = Thm.cterm_of ctxt
  
        (* Build table from abs-vars to ctxt *)
        val (Qm, Qum) = strip_star Q |> filter_out is_true |> List.partition is_hn_ctxt

        val Qtab = (
          Qm |> map (fn x => (#2 (dest_hn_ctxt x),(NONE,x))) 
          |> Termtab.make
        ) handle
            e as (Termtab.DUP _) => (
              tracing ("Dup heap: " ^ @{make_string} ct); raise e)
        
        (* Go over entries in P and try to find a partner. Only the first
           P entry for a given variable is paired; later ones stay unmatched. *)
        val (Qtab,Pum) = fold (fn a => fn (Qtab,Pum) => 
          case dest_hn_ctxt_opt a of
            NONE => (Qtab,a::Pum)
          | SOME (_,p,_) => ( case Termtab.lookup Qtab p of
              SOME (NONE,tg) => (Termtab.update (p,(SOME a,tg)) Qtab, Pum)
            | _ => (Qtab,a::Pum)
            )
        ) (strip_star P) (Qtab,[])

        val Pum = filter_out is_true Pum

        (* Read out information from Qtab *)
        val (pairs,Qum2) = Termtab.dest Qtab |> map #2 
          |> List.partition (is_some o #1)
          |> apfst (map (apfst the))
          |> apsnd (map #2)
  
        (* Build reordered terms: P' = fst pairs * Pum, Q' = snd pairs * (Qum2*Qum) *)
        val P' = mk_star (list_star (map fst pairs), list_star Pum)
        val Q' = mk_star (list_star (map snd pairs), list_star (Qum2@Qum))
        
        val new_ct = mk_entailst (P', Q') |> cert
  
        val msg_tac = dbg_msg_tac (Sepref_Debugging.msg_allgoals "Solving frame permutation") ctxt 1
        val tac = msg_tac THEN ALLGOALS (resolve_tac ctxt @{thms reorder_enttI}) THEN star_permute_tac ctxt

        val thm = Goal.prove_internal ctxt [] (mk_cequals (ct,new_ct)) (fn _ => tac)
  
      in 
        thm
      end
    | _ => no_conv ct
  
  end

  (* is_merge: the conclusion is an entailst whose left side is a
     disjunction (a merge obligation). is_gen_frame: the conclusion is any
     entailst. *)
  fun is_merge @{mpat "Trueprop (_ ∨⇩A _ ⟹⇩t _)"} = true | is_merge _ = false
  fun is_gen_frame @{mpat "Trueprop (_ ⟹⇩t _)"} = true | is_gen_frame _ = false


  (* Steps, in order:
     1. eta-normalize the subgoal;
     2. remove multiplicative units (mult_1_left and mult_1_right on assn);
     3. reorder the entailst conclusion with prepare_fi_conv. *)
  fun prepare_frame_tac ctxt = let
    open Refine_Util Conv
    val frame_ss = put_simpset HOL_basic_ss ctxt addsimps 
      @{thms mult_1_right[where 'a=assn] mult_1_left[where 'a=assn]}
  in
    CONVERSION Thm.eta_conversion THEN'
    (*CONCL_COND' is_frame THEN'*)
    simp_tac frame_ss THEN'
    CONVERSION (HOL_concl_conv (fn _ => prepare_fi_conv ctxt) ctxt)
  end    


  local
    (* Run tac. Every resulting subgoal that is not itself an entailst goal
       is handed to side_tac: with SOLVED' normally, with TRY_SOLVED' in
       debug mode (presumably leaving the goal open if side_tac fails). *)
    fun wrap_side_tac side_tac dbg tac = tac THEN_ALL_NEW_FWD (
      CONCL_COND' is_gen_frame 
      ORELSE' (if dbg then TRY_SOLVED' else SOLVED') side_tac
    )
  in  
    (* One frame step.
       1. Simplify with the normrel equations.
       2. If the goal is a merge obligation, resolve with merge_thms plus
          sepref_frame_merge_rules; otherwise resolve with frame_thms plus
          sepref_frame_match_rules.
       Side conditions go to Sepref_Constraints.constraint_tac first, then
       to side_tac. *)
    fun frame_step_tac side_tac dbg ctxt = let
      open Refine_Util Conv

      (* Constraint solving is built-in *)
      val side_tac = Sepref_Constraints.constraint_tac ctxt ORELSE' side_tac ctxt

      val frame_thms = @{thms frame_thms} @
        Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_frame_match_rules} 
      val merge_thms = @{thms merge_thms} @
        Named_Theorems.get ctxt @{named_theorems sepref_frame_merge_rules}
      val ss = put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt
      fun frame_thm_tac dbg = wrap_side_tac side_tac dbg (resolve_tac ctxt frame_thms)
      fun merge_thm_tac dbg = wrap_side_tac side_tac dbg (resolve_tac ctxt merge_thms)
  
      fun thm_tac dbg = CONCL_COND' is_merge THEN_ELSE' (merge_thm_tac dbg, frame_thm_tac dbg)
    in
      full_simp_tac ss THEN' thm_tac dbg
    end
  end  

  (* Repeatedly apply frame_step_tac in non-debug mode to the goal and all
     subgoals it produces. Wrapped in TRY, so it never fails; whatever
     cannot be processed is left to the caller. *)
  fun frame_loop_tac side_tac ctxt =
    TRY o REPEAT_ALL_NEW (DETERM o frame_step_tac side_tac false ctxt)


  (* Frame inference.
     1. Prepare the goal.
     2. Split it with ent_star_mono or entt_star_mono.
     3. Via THEN_ALL_NEW_LIST, run frame_loop_tac on the first new subgoal
        (presumably the paired part) and solve_remainder_tac on the second
        (the unmatched remainder). solve_remainder_tac uses frame_rem_thms
        plus sepref_frame_rem_rules. *)
  fun frame_tac side_tac ctxt = let
    open Refine_Util Conv
    val frame_rem_thms = @{thms frame_rem_thms}
      @ Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_frame_rem_rules}
    val solve_remainder_tac = TRY o REPEAT_ALL_NEW (DETERM o resolve_tac ctxt frame_rem_thms)
  in
    (prepare_frame_tac ctxt
      THEN' resolve_tac ctxt @{thms ent_star_mono entt_star_mono})
    THEN_ALL_NEW_LIST [
      frame_loop_tac side_tac ctxt,
      solve_remainder_tac
    ]  
  end

  (* Merging. For a merge obligation:
     1. eta-normalize;
     2. normalize with star_aci;
     3. sort both disjuncts of the left side with reorder_ctxt_conv;
     4. run frame_loop_tac. *)
  fun merge_tac side_tac ctxt = let
    open Refine_Util Conv
    val merge_conv = arg1_conv (binop_conv (reorder_ctxt_conv ctxt))
  in
    CONVERSION Thm.eta_conversion THEN'
    CONCL_COND' is_merge THEN'
    simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms star_aci}) THEN'
    CONVERSION (HOL_concl_conv (fn _ => merge_conv) ctxt) THEN'
    frame_loop_tac side_tac ctxt
  end

  val setup = normrel_eqs.setup

  local
    open Sepref_Basic
    fun is_invalid @{mpat "hn_invalid _ _ _ :: assn"} = true | is_invalid _ = false
    fun contains_invalid @{mpat "Trueprop (RECOVER_PURE ?Q _)"} = exists is_invalid (strip_star Q)
      | contains_invalid _ = false

  in
    (* If the first argument of RECOVER_PURE contains an hn_invalid,
       decompose with the recover_pure rules, discharging constraint goals
       with constraint_tac. Otherwise resolve with recover_pure_triv. *)
    fun recover_pure_tac ctxt = 
      CONCL_COND' contains_invalid THEN_ELSE' (
        REPEAT_ALL_NEW (DETERM o (resolve_tac ctxt @{thms recover_pure} ORELSE' Sepref_Constraints.constraint_tac ctxt)),
        resolve_tac ctxt @{thms recover_pure_triv}
      )
  end

  local
    open Sepref_Basic Refine_Util
    datatype cte = Other of term | Hn of term * term * term
    fun dest_ctxt_elem @{mpat "hn_ctxt ?R ?a ?c"} = Hn (R,a,c)
      | dest_ctxt_elem t = Other t

    fun mk_ctxt_elem (Other t) = t 
      | mk_ctxt_elem (Hn (R,a,c)) = @{mk_term "hn_ctxt ?R ?a ?c"}

    fun match x (Hn (_,y,_)) = x aconv y
      | match _ _ = false

    (* Split the precondition of an hn_refine term into the hn_ctxt entries
       for the arguments of the abstract function (as selected by
       split_matching, presumably in argument order) and the remaining
       frame. Raises TERM if not all arguments can be matched. *)
    fun dest_with_frame (*ctxt*) _ t = let
      val (P,c,Q,R,a) = dest_hn_refine t
  
      val (_,(_,args)) = dest_hnr_absfun a
      val pre_ctes = strip_star P |> map dest_ctxt_elem
  
      val (pre_args,frame) = 
        (case split_matching match args pre_ctes of
            NONE => raise TERM("align_conv: Could not match all arguments",[P,a])
          | SOME x => x)

    in
      ((frame,pre_args),c,Q,R,a)
    end
  
    (* Goal: rebuild the precondition as the frame followed by the argument entries. *)
    fun align_goal_conv_aux ctxt t = let
      val ((frame,pre_args),c,Q,R,a) = dest_with_frame ctxt t
      val P' = apply2 (list_star o map mk_ctxt_elem) (frame,pre_args) |> mk_star
      val t' = mk_hn_refine (P',c,Q,R,a)
    in t' end  

    (* Rule: the precondition must consist of argument entries only
       (otherwise TERM is raised); it is rebuilt from those entries. *)
    fun align_rl_conv_aux ctxt t = let
      val ((frame,pre_args),c,Q,R,a) = dest_with_frame ctxt t

      val _ = frame = [] orelse raise TERM ("align_rl_conv: Extra preconditions in rule",[t,list_star (map mk_ctxt_elem frame)])

      val P' = list_star (map mk_ctxt_elem pre_args)
      val t' = mk_hn_refine (P',c,Q,R,a)
    in t' end  


    (* Rewrite with the sepref_frame_normrel_eqs rules only. *)
    fun normrel_conv ctxt = let
      val ss = put_simpset HOL_basic_ss ctxt addsimps normrel_eqs.get ctxt
    in
      Simplifier.rewrite ss
    end

  in
    (* The alignment equation is proved with star_permute_tac via f_tac_conv. *)
    fun align_goal_conv ctxt = f_tac_conv ctxt (align_goal_conv_aux ctxt) (star_permute_tac ctxt)

    (* Apply normrel_conv to the first hn_refine_conv position (the precondition) only. *)
    fun norm_goal_pre_conv ctxt = let
      open Conv
      val nr_conv = normrel_conv ctxt
    in
      HOL_concl_conv (fn _ => hn_refine_conv nr_conv all_conv all_conv all_conv all_conv) ctxt
    end  

    fun norm_goal_pre_tac ctxt = CONVERSION (norm_goal_pre_conv ctxt)

    (* Align the rule's precondition, then apply normrel_conv in the first,
       third and fourth hn_refine_conv positions. *)
    fun align_rl_conv ctxt = let
      open Conv
      val nr_conv = normrel_conv ctxt
    in
      HOL_concl_conv (fn ctxt => f_tac_conv ctxt (align_rl_conv_aux ctxt) (star_permute_tac ctxt)) ctxt
      then_conv HOL_concl_conv (K (hn_refine_conv nr_conv all_conv nr_conv nr_conv all_conv)) ctxt
    end

    (* Only applies to subgoals whose conclusion is hn_refine; the
       conversion runs through the debug-aware DCONVERSION. *)
    fun align_goal_tac ctxt = 
      CONCL_COND' is_hn_refine_concl 
      THEN' DCONVERSION ctxt (HOL_concl_conv align_goal_conv ctxt)
  end


  (* Resolve with weaken_hnr_postI. The first new subgoal (the
     WEAKEN_HNR_POST premise) must then be solved completely by the
     weaken_hnr_post and weaken_hnr_post_triv rules. *)
  fun weaken_post_tac ctxt = TRADE (fn ctxt =>
    resolve_tac ctxt @{thms weaken_hnr_postI} 
    THEN' SOLVED' (REPEAT_ALL_NEW (DETERM o resolve_tac ctxt @{thms weaken_hnr_post weaken_hnr_post_triv}))
  ) ctxt

end›

(* Installs the sepref_frame_normrel_eqs rule set (Sepref_Frame.setup is normrel_eqs.setup). *)
setup Sepref_Frame.setup

(* Proof method wrapping Sepref_Frame.weaken_post_tac on the current subgoal. *)
method_setup weaken_hnr_post = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Frame.weaken_post_tac ctxt))›
  ‹Convert "hn_invalid" to "hn_ctxt (λ_ _. true)" in postcondition of hn_refine goal›

(* TODO: Improper, modifies all h⊨_ premises that happen to be there. Use tagging to protect! *)
method extract_hnr_invalids = (
  rule hn_refine_preI,
  ((drule mod_starD hn_invalidI | elim conjE exE)+)?
) ― ‹Extract ‹hn_invalid _ _ _ = true› preconditions from ‹hn_refine› goal.›
  


(* Registers the_pure_pure and pure_the_pure as relation-normalizing
   equations for frame matching. *)
lemmas [sepref_frame_normrel_eqs] = the_pure_pure pure_the_pure

end

Theory Sepref_Rules

section ‹Refinement Rule Management›
theory Sepref_Rules
imports Sepref_Basic Sepref_Constraints
begin
  text ‹This theory contains tools for managing the refinement rules used by Sepref›

  text ‹The theories are based on uncurried functions, i.e.,
    every function has type @{typ "'a⇒'b"}, where @{typ 'a} is the 
    tuple of parameters, or unit if there are none.
    ›


  subsection ‹Assertion Interface Binding›
  text ‹Binding of interface types to refinement assertions›
  (* Tag predicate. It is always True (and a simp rule); it only records a
     link between refinement assertion a and the interface type given by b. *)
  definition intf_of_assn :: "('a  _  assn)  'b itself  bool" where
    [simp]: "intf_of_assn a b = True"

  lemma intf_of_assnI: "intf_of_assn R TYPE('a)" by simp
  
  named_theorems_rev intf_of_assn ‹Links between refinement assertions and interface types›  

  (* Links R to its own abstract type 'a. The name suggests it is used when
     no specific intf_of_assn rule applies. *)
  lemma intf_of_assn_fallback: "intf_of_assn (R :: 'a  _  assn) TYPE('a)" by simp

  subsection ‹Function Refinement with Precondition›
  (* fref P R S: pairs (f,g) such that (f x, g y) is in S for all (x,y) in
     R with P y. P is a precondition on the abstract argument. *)
  definition fref :: "('c  bool)  ('a × 'c) set  ('b × 'd) set
             (('a  'b) × ('c  'd)) set"
    ("[_]f _  _" [0,60,60] 60)         
  where "[P]f R  S  {(f,g). x y. P y  (x,y)R  (f x, g y)S}"
  
  (* fref without precondition (P = λ_. True). *)
  abbreviation freft ("_ f _" [60,60] 60) where "R f S  ([λ_. True]f R  S)"
  
  lemma rel2p_fref[rel2p]: "rel2p (fref P R S) 
    = (λf g. (x y. P y  rel2p R x y  rel2p S (f x) (g y)))"  
    by (auto simp: fref_def rel2p_def[abs_def])

  (* Consequence rule: change the precondition (valid on related
     arguments), shrink the argument relation, enlarge the result relation. *)
  lemma fref_cons:  
    assumes "(f,g)  [P]f R  S"
    assumes "c a. (c,a)R'  Q a  P a"
    assumes "R'  R"
    assumes "S  S'"
    shows "(f,g)  [Q]f R'  S'"
    using assms
    unfolding fref_def
    by fastforce

  lemmas fref_cons' = fref_cons[OF _ _ order_refl order_refl]  

  lemma frefI[intro?]: 
    assumes "x y. P y; (x,y)R  (f x, g y)S"
    shows "(f,g)fref P R S"
    using assms
    unfolding fref_def
    by auto

  (* fref_ncI and fref_ncD: without precondition, fref agrees with the plain
     function relation. *)
  lemma fref_ncI: "(f,g)RS  (f,g)RfS"  
    apply (rule frefI)
    apply parametricity
    done

  lemma frefD: 
    assumes "(f,g)fref P R S"
    shows "P y; (x,y)R  (f x, g y)S"
    using assms
    unfolding fref_def
    by auto

  lemma fref_ncD: "(f,g)RfS  (f,g)RS"  
    apply (rule fun_relI)
    apply (drule frefD)
    apply simp
    apply assumption+
    done


  (* Composition of frefs. Argument and result relations compose with O.
     The composed precondition requires Q x and P y for every y related to
     x by S1. *)
  lemma fref_compI: 
    "fref P R1 R2 O fref Q S1 S2 
      fref (λx. Q x  (y. (y,x)S1  P y)) (R1 O S1) (R2 O S2)"
    unfolding fref_def
    apply (auto)
    apply blast
    done

  (* The same composition, stated for concrete pairs (f,g) and (g,h). *)
  lemma fref_compI':
    " (f,g)fref P R1 R2; (g,h)fref Q S1 S2  
       (f,h)  fref (λx. Q x  (y. (y,x)S1  P y)) (R1 O S1) (R2 O S2)"
    using fref_compI[of P R1 R2 Q S1 S2]   
    by auto

  (* Unfold fref for unit arguments and for uncurried two-argument functions. *)
  lemma fref_unit_conv:
    "(λ_. c, λ_. a)  fref P unit_rel S  (P ()  (c,a)S)"   
    by (auto simp: fref_def)

  lemma fref_uncurry_conv:
    "(uncurry c, uncurry a)  fref P (R1×rR2) S 
     (x1 y1 x2 y2. P (y1,y2)  (x1,y1)R1  (x2,y2)R2  (c x1 x2, a y1 y2)  S)"
    by (auto simp: fref_def)

  (* Antitone in precondition and argument relation, monotone in the result relation. *)
  lemma fref_mono: " x. P' x  P x; R'  R; S  S'  
     fref P R S  fref P' R' S'"  
    unfolding fref_def
    by auto blast

  (* Composition with adjustable precondition (C1, C2) and relations
     (R1, R2); f' and h' are given by the equations FH. *)
  lemma fref_composeI:
    assumes FR1: "(f,g)fref P R1 R2"
    assumes FR2: "(g,h)fref Q S1 S2"
    assumes C1: "x. P' x  Q x"
    assumes C2: "x y. P' x; (y,x)S1  P y"
    assumes R1: "R'  R1 O S1"
    assumes R2: "R2 O S2  S'"
    assumes FH: "f'=f" "h'=h"
    shows "(f',h')  fref P' R' S'"
    unfolding FH
    apply (rule subsetD[OF fref_mono fref_compI'[OF FR1 FR2]])
    using C1 C2 apply blast
    using R1 apply blast
    using R2 apply blast
    done

  (* A function refines itself if the argument relation is contained in Id
     and the result relation is Id. *)
  lemma fref_triv: "AId  (f,f)[P]f A  Id"
    by (auto simp: fref_def)


  subsection ‹Heap-Function Refinement›
  text ‹
    The following relates a heap-function with a pure function.
    It contains a precondition, a refinement assertion for the arguments
    before and after execution, and a refinement relation for the result.
    ›
  (* TODO: We only use this with keep/destroy information, so we could model
    the parameter relations as such (('a⇒'ai ⇒ assn) × bool) *)
  (* hfref P RS T: pairs of a heap function f and an abstract nres function
     g such that, for every abstract a with P a and every concrete c,
     hn_refine holds with precondition fst RS a c, postcondition
     snd RS a c and result assertion T. *)
  definition hfref 
    :: "
      ('a  bool) 
    (('a  'ai  assn) × ('a  'ai  assn)) 
    ('b  'bi  assn) 
    (('ai  'bi Heap) × ('a'b nres)) set"
   ("[_]a _  _" [0,60,60] 60)
   where
    "[P]a RS  T  { (f,g) . c a.  P a  hn_refine (fst RS a c) (f c) (snd RS a c) T (g a)}"

  (* hfref without precondition (P = λ_. True). *)
  abbreviation hfreft ("_ a _" [60,60] 60) where "RS a T  ([λ_. True]a RS  T)"

  lemma hfrefI[intro?]: 
    assumes "c a. P a  hn_refine (fst RS a c) (f c) (snd RS a c) T (g a)"
    shows "(f,g)hfref P RS T"
    using assms unfolding hfref_def by blast

  lemma hfrefD: 
    assumes "(f,g)hfref P RS T"
    shows "c a. P a  hn_refine (fst RS a c) (f c) (snd RS a c) T (g a)"
    using assms unfolding hfref_def by blast

  (* Moves a non-trivial precondition P into the abstract program as an
     ASSERT. The NO_MATCH premise excludes the trivial precondition
     λ_. True. *)
  lemma hfref_to_ASSERT_conv: 
    "NO_MATCH (λ_. True) P  (a,b)[P]a R  S  (a,λx. ASSERT (P x)  b x)  R a S"  
    unfolding hfref_def
    apply (clarsimp; safe; clarsimp?)
    apply (rule hn_refine_nofailI)
    apply (simp add: refine_pw_simps)
    subgoal for xc xa
      apply (drule spec[of _ xc])
      apply (drule spec[of _ xa])
      by simp
    done

  text ‹
    A pair of argument refinement assertions can be created by the 
    input assertion and the information whether the parameter is kept or destroyed
    by the function.
    ›  
  (* hf_pres R True keeps R after the call: the pair is (R,R).
     hf_pres R False destroys the argument: the pair is (R, invalid_assn R). *)
  primrec hf_pres 
    :: "('a  'b  assn)  bool  ('a  'b  assn)×('a  'b  assn)"
    where 
      "hf_pres R True = (R,R)" | "hf_pres R False = (R,invalid_assn R)"

  abbreviation hfkeep 
    :: "('a  'b  assn)  ('a  'b  assn)×('a  'b  assn)" 
    ("(_k)" [1000] 999)
    where "Rk  hf_pres R True"
  abbreviation hfdrop 
    :: "('a  'b  assn)  ('a  'b  assn)×('a  'b  assn)" 
    ("(_d)" [1000] 999)
    where "Rd  hf_pres R False"

  (* hn_ctxt built from the second (post-call) component of hf_pres R kd. *)
  abbreviation "hn_kede R kd  hn_ctxt (snd (hf_pres R kd))"
  abbreviation "hn_keep R  hn_kede R True"
  abbreviation "hn_dest R  hn_kede R False"

  lemma keep_drop_sels[simp]:  
    "fst (Rk) = R"
    "snd (Rk) = R"
    "fst (Rd) = R"
    "snd (Rd) = invalid_assn R"
    by auto

  lemma hf_pres_fst[simp]: "fst (hf_pres R k) = R" by (cases k) auto

  text ‹
    The following operator combines multiple argument assertion-pairs to
    argument assertion-pairs for the product. It is required to state
    argument assertion-pairs for uncurried functions.
    ›  
  (* Componentwise product of assertion pairs: the first components are
     combined with prod_assn, and so are the second components. *)
  definition hfprod :: "
    (('a  'b  assn)×('a  'b  assn)) 
     (('c  'd  assn)×('c  'd  assn))
     ((('a×'c)  ('b × 'd)  assn) × (('a×'c)  ('b × 'd)  assn))"
    (infixl "*a" 65)
    where "RR *a SS  (prod_assn (fst RR) (fst SS), prod_assn (snd RR) (snd SS))"

  lemma hfprod_fst_snd[simp]:
    "fst (A *a B) = prod_assn (fst A) (fst B)" 
    "snd (A *a B) = prod_assn (snd A) (snd B)" 
    unfolding hfprod_def by auto



  subsubsection ‹Conversion from fref to hfref›  
  (* TODO: Variant of import-param! Automate this! *)
  (* Turns an fref for a function f into an hfref for return o f', with
     pure assertions and the argument kept. Requires that f x is
     RETURN (f' x) on the relevant part of the domain. *)
  lemma fref_to_pure_hfref':
    assumes "(f,g)  [P]f RSnres_rel"
    assumes "x. xDomain R  R¯``Collect P  f x = RETURN (f' x)"
    shows "(return o f', g)  [P]a (pure R)kpure S"
    apply (rule hfrefI) apply (rule hn_refineI)
    using assms
    apply ((sep_auto simp: fref_def pure_def pw_le_iff pw_nres_rel_iff
      refine_pw_simps eintros del: exI))
    apply force
    done


  subsubsection ‹Conversion from hfref to hnr›  
  text ‹This section contains the lemmas. The ML code is further down. ›
  (* Unfolds an hfref into an hn_refine statement for all arguments, with
     emp placed in front of the pre- and postcondition. *)
  lemma hf2hnr:
    assumes "(f,g)  [P]a R  S"
    shows "x xi. P x  hn_refine (emp * hn_ctxt (fst R) x xi) (f$xi) (emp * hn_ctxt (snd R) x xi) S (g$x)"
    using assms
    unfolding hfref_def 
    by (auto simp: hn_ctxt_def)

  (*lemma hf2hnr_new:
    assumes "(f,g) ∈ [P]a R → S"
    shows "∀x xi. (∀h. h⊨fst R x xi ⟶ P x) ⟶ hn_refine (emp * hn_ctxt (fst R) x xi) (f xi) (emp * hn_ctxt (snd R) x xi) S (g$x)"
    using assms
    unfolding hfref_def 
    by (auto simp: hn_ctxt_def intro: hn_refine_preI)
  *)


  (* Products that stem from currying are tagged by a special refinement relation *)  
  definition [simp]: "to_hnr_prod  prod_assn"

  lemma to_hnr_prod_fst_snd:
    "fst (A *a B) = to_hnr_prod (fst A) (fst B)" 
    "snd (A *a B) = to_hnr_prod (snd A) (snd B)" 
    unfolding hfprod_def by auto

  (* Warning: This lemma is carefully set up to be applicable as an unfold rule,
    for more than one level of uncurrying*)
  (* Splits an argument pair tagged with to_hnr_prod into separate hn_ctxt
     entries for its components. *)
  lemma hnr_uncurry_unfold: "
    (x xi. P x  
      hn_refine 
        (Γ * hn_ctxt (to_hnr_prod A B) x xi) 
        (fi xi) 
        (Γ' * hn_ctxt (to_hnr_prod A' B') x xi) 
        R 
        (f x))
 (b bi a ai. P (a,b) 
      hn_refine 
        (Γ * hn_ctxt B b bi * hn_ctxt A a ai) 
        (fi (ai,bi)) 
        (Γ' * hn_ctxt B' b bi * hn_ctxt A' a ai)
        R
        (f (a,b))
    )"
    by (auto simp: hn_ctxt_def prod_assn_def star_aci)
    
  (* Places emp in front of the pre- and postcondition. *)
  lemma hnr_intro_dummy:
    "x xi. P x  hn_refine (Γ x xi) (c xi) (Γ' x xi) R (a x)  x xi. P x  hn_refine (emp*Γ x xi) (c xi) (emp*Γ' x xi) R (a x)" 
    by simp

  lemma hn_ctxt_ctxt_fix_conv: "hn_ctxt (hn_ctxt R) = hn_ctxt R"
    by (simp add: hn_ctxt_def[abs_def])

  lemma uncurry_APP: "uncurry f$(a,b) = f$a$b" by auto

  (* TODO: Replace by more general rule. *)  
  (* norm_RETURN_o and norm_return_o: RETURN (resp. return) composed with
     a function of up to five arguments, applied via $, is normalized to
     RETURN (resp. return) of the application. *)
  lemma norm_RETURN_o: 
    "f. (RETURN o f)$x = (RETURN$(f$x))"
    "f. (RETURN oo f)$x$y = (RETURN$(f$x$y))"
    "f. (RETURN ooo f)$x$y$z = (RETURN$(f$x$y$z))"
    "f. (λx. RETURN ooo f x)$x$y$z$a = (RETURN$(f$x$y$z$a))"
    "f. (λx y. RETURN ooo f x y)$x$y$z$a$b = (RETURN$(f$x$y$z$a$b))"
    by auto

  lemma norm_return_o: 
    "f. (return o f)$x = (return$(f$x))"
    "f. (return oo f)$x$y = (return$(f$x$y))"
    "f. (return ooo f)$x$y$z = (return$(f$x$y$z))"
    "f. (λx. return ooo f x)$x$y$z$a = (return$(f$x$y$z$a))"
    "f. (λx y. return ooo f x y)$x$y$z$a$b = (return$(f$x$y$z$a$b))"
    by auto

  
  lemma hn_val_unit_conv_emp[simp]: "hn_val unit_rel x y = emp"
    by (auto simp: hn_ctxt_def pure_def)

  subsubsection ‹Conversion from hnr to hfref›  
  text ‹This section contains the lemmas. The ML code is further down. ›

  (* Pure assertion for the identity relation, and its instance at type unit. *)
  abbreviation "id_assn  pure Id"
  abbreviation "unit_assn  id_assn :: unit  _"

  lemma pure_unit_rel_eq_empty: "unit_assn x y = emp"  
    by (auto simp: pure_def)

  (* Selectors of a product assertion pair, unfolded on concrete tuples. *)
  lemma uc_hfprod_sel:
    "fst (A *a B) a c = (case (a,c) of ((a1,a2),(c1,c2))  fst A a1 c1 * fst B a2 c2)" 
    "snd (A *a B) a c = (case (a,c) of ((a1,a2),(c1,c2))  snd A a1 c1 * snd B a2 c2)" 
    unfolding hfprod_def prod_assn_def[abs_def] by auto


  subsubsection ‹Conversion from relation to fref›  
  text ‹This section contains the lemmas. The ML code is further down. ›

  (* CURRY R: the pairs (f,g) whose uncurried versions are related by R. *)
  definition "CURRY R  { (f,g). (uncurry f, uncurry g)  R }"

  (* A plain function relation is an fref without precondition. *)
  lemma fref_param1: "RS = fref (λ_. True) R S"  
    by (auto simp: fref_def fun_relD)

  (* A nested fref is the CURRY of an fref on pairs whose precondition
     conjoins both preconditions. *)
  lemma fref_nest: "fref P1 R1 (fref P2 R2 S) 
     CURRY (fref (λ(a,b). P1 a  P2 b) (R1×rR2) S)"
    apply (rule eq_reflection)
    by (auto simp: fref_def CURRY_def)

  lemma in_CURRY_conv: "(f,g)  CURRY R  (uncurry f, uncurry g)  R"  
    unfolding CURRY_def by auto

  lemma uncurry0_APP[simp]: "uncurry0 c $ x = c" by auto

  (* Zero-argument functions, via uncurry0 and unit_rel. *)
  lemma fref_param0I: "(c,a)R  (uncurry0 c, uncurry0 a)  fref (λ_. True) unit_rel R"
    by (auto simp: fref_def)

  subsubsection ‹Composition›
  (* hr_comp R1 R2 a c: for some intermediate b with (b,a) in R2,
     R1 b c holds. *)
  definition hr_comp :: "('b  'c  assn)  ('b × 'a) set  'a  'c  assn"
    ― ‹Compose refinement assertion with refinement relation›
    where "hr_comp R1 R2 a c  Ab. R1 b c * ((b,a)R2)"

  definition hrp_comp 
    :: "('d  'b  assn) × ('d  'c  assn)
         ('d × 'a) set  ('a  'b  assn) × ('a  'c  assn)"
    ― ‹Compose argument assertion-pair with refinement relation›    
    where "hrp_comp RR' S  (hr_comp (fst RR') S, hr_comp (snd RR') S) "

  lemma hr_compI: "(b,a)R2  R1 b c A hr_comp R1 R2 a c"  
    unfolding hr_comp_def
    by sep_auto

  (* Composing with identity (on either side) simplifies away. *)
  lemma hr_comp_Id1[simp]: "hr_comp (pure Id) R = pure R"  
    unfolding hr_comp_def[abs_def] pure_def
    apply (intro ext ent_iffI)
    by sep_auto+

  lemma hr_comp_Id2[simp]: "hr_comp R Id = R"  
    unfolding hr_comp_def[abs_def]
    apply (intro ext ent_iffI)
    by sep_auto+
    
  (*lemma hr_comp_invalid[simp]: "hr_comp (λa c. true) R a c = true * ↑(∃b. (b,a)∈R)"
    unfolding hr_comp_def[abs_def]
    apply (intro ext ent_iffI)
    apply sep_auto+
    done*)
    
  lemma hr_comp_emp[simp]: "hr_comp (λa c. emp) R a c = (b. (b,a)R)"
    unfolding hr_comp_def[abs_def]
    apply (intro ext ent_iffI)
    apply sep_auto+
    done

  (* hr_comp distributes over prod_assn and product relations. *)
  lemma hr_comp_prod_conv[simp]:
    "hr_comp (prod_assn Ra Rb) (Ra' ×r Rb') 
    = prod_assn (hr_comp Ra Ra') (hr_comp Rb Rb')"  
    unfolding hr_comp_def[abs_def] prod_assn_def[abs_def]
    apply (intro ext ent_iffI)
    apply solve_entails apply clarsimp apply sep_auto
    apply clarsimp apply (intro ent_ex_preI)
    apply (rule ent_ex_postI) apply (sep_auto split: prod.splits)
    done

  (* Composing a pure assertion yields the pure assertion of the relational composition. *)
  lemma hr_comp_pure: "hr_comp (pure R) S = pure (R O S)"  
    apply (intro ext)
    apply (rule ent_iffI)
    unfolding hr_comp_def[abs_def] 
    apply (sep_auto simp: pure_def)+
    done

  lemma hr_comp_is_pure[safe_constraint_rules]: "is_pure A  is_pure (hr_comp A B)"
    by (auto simp: hr_comp_pure is_pure_conv)

  lemma hr_comp_the_pure: "is_pure A  the_pure (hr_comp A B) = the_pure A O B"
    unfolding is_pure_conv
    by (clarsimp simp: hr_comp_pure)

  lemma rdomp_hrcomp_conv: "rdomp (hr_comp A R) x  (y. rdomp A y  (y,x)R)"
    by (auto simp: rdomp_def hr_comp_def)

  (* Lifts hr_compI to hn_rel, for a non-failing abstract result a refined
     by b via R2. *)
  lemma hn_rel_compI: 
    "nofail a; (b,a)R2nres_rel  hn_rel R1 b c A hn_rel (hr_comp R1 R2) a c"
    unfolding hr_comp_def hn_rel_def nres_rel_def
    apply (clarsimp intro!: ent_ex_preI)
    apply (drule (1) order_trans)
    apply (simp add: ret_le_down_conv)
    by sep_auto

  (* Precision is preserved by composition with a single-valued relation. *)
  lemma hr_comp_precise[constraint_rules]:
    assumes [safe_constraint_rules]: "precise R"
    assumes SV: "single_valued S"
    shows "precise (hr_comp R S)"
    apply (rule preciseI)
    unfolding hr_comp_def
    apply clarsimp
    by (metis SV assms(1) preciseD single_valuedD)

  lemma hr_comp_assoc: "hr_comp (hr_comp R S) T = hr_comp R (S O T)"
    apply (intro ext)
    unfolding hr_comp_def
    apply (rule ent_iffI; clarsimp)
    apply sep_auto
    apply (rule ent_ex_preI; clarsimp) (* TODO: 
      sep_auto/solve_entails is too eager splitting the subgoal here! *)
    apply sep_auto
    done


  (* Inputs:
     - R: hn_refine rule for b, with argument assertion R1;
     - S: refinement of a by b via R1' and R';
     - PQ: precondition compatibility;
     - Q: the precondition holds for a1.
     Result: hn_refine for a, with arguments and result composed by
     hr_comp. *)
  lemma hnr_comp:
    assumes R: "b1 c1. P b1  hn_refine (R1 b1 c1 * Γ) (c c1) (R1p b1 c1 * Γ') R (b b1)"
    assumes S: "a1 b1. Q a1; (b1,a1)R1'  (b b1,a a1)R'nres_rel"
    assumes PQ: "a1 b1. Q a1; (b1,a1)R1'  P b1"
    assumes Q: "Q a1"
    shows "hn_refine 
      (hr_comp R1 R1' a1 c1 * Γ) 
      (c c1)
      (hr_comp R1p R1' a1 c1 * Γ') 
      (hr_comp R R') 
      (a a1)"
    unfolding hn_refine_alt
  proof clarsimp
    assume NF: "nofail (a a1)"
    show "
      <hr_comp R1 R1' a1 c1 * Γ> 
        c c1 
      <λr. hn_rel (hr_comp R R') (a a1) r * (hr_comp R1p R1' a1 c1 * Γ')>t"
      apply (subst hr_comp_def)
      apply (clarsimp intro!: norm_pre_ex_rule)
    proof -
      fix b1
      assume R1: "(b1, a1)  R1'"

      from S R1 Q have R': "(b b1, a a1)  R'nres_rel" by blast
      with NF have NFB: "nofail (b b1)" 
        by (simp add: nres_rel_def pw_le_iff refine_pw_simps)
      
      from PQ R1 Q have P: "P b1" by blast
      with NFB R have "<R1 b1 c1 * Γ> c c1 <λr. hn_rel R (b b1) r * (R1p b1 c1 * Γ')>t"
        unfolding hn_refine_alt by auto
      thus "<R1 b1 c1 * Γ> 
        c c1 
        <λr. hn_rel (hr_comp R R') (a a1) r * (hr_comp R1p R1' a1 c1 * Γ')>t"
        apply (rule cons_post_rule)
        apply (solve_entails)
        by (intro ent_star_mono hn_rel_compI[OF NF R'] hr_compI[OF R1] ent_refl)
    qed
  qed    

  (* hnr_comp specialised to an empty frame (Γ = Γ' = emp) and hn_ctxt
     arguments. *)
  lemma hnr_comp1_aux:
    assumes R: "b1 c1. P b1  hn_refine (hn_ctxt R1 b1 c1) (c c1) (hn_ctxt R1p b1 c1) R (b$b1)"
    assumes S: "a1 b1. Q a1; (b1,a1)R1'  (b$b1,a$a1)R'nres_rel"
    assumes PQ: "a1 b1. Q a1; (b1,a1)R1'  P b1"
    assumes Q: "Q a1"
    shows "hn_refine 
      (hr_comp R1 R1' a1 c1) 
      (c c1)
      (hr_comp R1p R1' a1 c1) 
      (hr_comp R R') 
      (a a1)"
    using assms hnr_comp[where Γ=emp and Γ'=emp and a=a and b=b and c=c and P=P and Q=Q]  
    unfolding hn_ctxt_def
    by auto

  (* Composes an hfref for (f,g) with an fref for (g,h): assertions are
     composed with hrp_comp and hr_comp. *)
  lemma hfcomp:
    assumes A: "(f,g)  [P]a RR'  S"
    assumes B: "(g,h)  [Q]f T  Unres_rel"
    shows "(f,h)  [λa. Q a  (a'. (a',a)T  P a')]a 
      hrp_comp RR' T  hr_comp S U"
    using assms  
    unfolding fref_def hfref_def hrp_comp_def
    apply clarsimp
    apply (rule hnr_comp1_aux[of 
        P "fst RR'" f "snd RR'" S g "λa. Q a  (a'. (a',a)T  P a')" T h U])
    apply (auto simp: hn_ctxt_def)
    done

  (* The precondition may additionally assume nofail (g x). *)
  lemma hfref_weaken_pre_nofail: 
    assumes "(f,g)  [P]a R  S"  
    shows "(f,g)  [λx. nofail (g x)  P x]a R  S"
    using assms
    unfolding hfref_def hn_refine_def
    by auto

  (* Consequence rule for hfref:
     - strengthen the precondition;
     - strengthen the pre-assertion;
     - weaken the post-assertion and the result assertion. *)
  lemma hfref_cons:
    assumes "(f,g)  [P]a R  S"
    assumes "x. P' x  P x"
    assumes "x y. fst R' x y t fst R x y"
    assumes "x y. snd R x y t snd R' x y"
    assumes "x y. S x y t S' x y"
    shows "(f,g)  [P']a R'  S'"
    unfolding hfref_def
    apply clarsimp
    apply (rule hn_refine_cons)
    apply (rule assms(3))
    defer
    apply (rule entt_trans[OF assms(4)]; sep_auto)
    apply (rule assms(5))
    apply (frule assms(2))
    using assms(1)
    unfolding hfref_def
    apply auto
    done

  subsubsection ‹Composition Automation›  
  text ‹This section contains the lemmas. The ML code is further down. ›

  (* hrp_comp distributes over products of assertion pairs and product relations. *)
  lemma prod_hrp_comp: 
    "hrp_comp (A *a B) (C ×r D) = hrp_comp A C *a hrp_comp B D"
    unfolding hrp_comp_def hfprod_def by simp
  
  (* hrp_comp commutes with the keep annotation. *)
  lemma hrp_comp_keep: "hrp_comp (Ak) B = (hr_comp A B)k"
    by (auto simp: hrp_comp_def)

  lemma hr_comp_invalid: "hr_comp (invalid_assn R1) R2 = invalid_assn (hr_comp R1 R2)"
    apply (intro ent_iffI entailsI ext)
    unfolding invalid_assn_def hr_comp_def
    by auto

  (* hrp_comp commutes with the destroy annotation, using hr_comp_invalid. *)
  lemma hrp_comp_dest: "hrp_comp (Ad) B = (hr_comp A B)d"
    by (auto simp: hrp_comp_def hr_comp_invalid)



  (* hrp_imp RR RR': pointwise entailments between the components of two
     argument-relation pairs (the fst of RR' entails the fst of RR, and the snd of RR
     entails the snd of RR'). By hfref_imp, under this condition an hfref with
     argument relation RR can be converted into one with RR'. *)
  definition "hrp_imp RR RR'  
    a b. (fst RR' a b t fst RR a b)  (snd RR a b t snd RR' a b)"

  (* Monotonicity of hfref in the argument relation, w.r.t. hrp_imp; proved via hfref_cons. *)
  lemma hfref_imp: "hrp_imp RR RR'  [P]a RR  S  [P]a RR'  S"  
    apply clarsimp
    apply (erule hfref_cons)
    apply (simp_all add: hrp_imp_def)
    done
    
  (* Reflexivity of hrp_imp. *)
  lemma hrp_imp_refl: "hrp_imp RR RR"
    unfolding hrp_imp_def by auto

  (* Reflexivity of hrp_imp, stated for syntactically different but equal arguments. *)
  lemma hrp_imp_reflI: "RR = RR'  hrp_imp RR RR'"
    unfolding hrp_imp_def by auto


  (* Congruence of hrp_comp w.r.t. hrp_imp in its first argument
     (registered as fcomp_norm_cong below). *)
  lemma hrp_comp_cong: "hrp_imp A A'  B=B'  hrp_imp (hrp_comp A B) (hrp_comp A' B')"
    by (sep_auto simp: hrp_imp_def hrp_comp_def hr_comp_def entailst_def)
    
  (* Congruence of the argument-assertion product w.r.t. hrp_imp
     (registered as fcomp_norm_cong below). *)
  lemma hrp_prod_cong: "hrp_imp A A'  hrp_imp B B'  hrp_imp (A*aB) (A'*aB')"
    by (sep_auto simp: hrp_imp_def prod_assn_def intro: entt_star_mono)

  (* Transitivity of hrp_imp (registered as fcomp_norm_trans below). *)
  lemma hrp_imp_trans: "hrp_imp A B  hrp_imp B C  hrp_imp A C"  
    unfolding hrp_imp_def
    by (fastforce intro: entt_trans)

  (* Initialization rule of the fcomp normalizer (registered as fcomp_norm_init below):
     to show membership in the hfref with argument relation S, it suffices to show
     membership with R plus hrp_imp R S. *)
  lemma fcomp_norm_dflt_init: "x[P]a R  T  hrp_imp R S  x[P]a S  T"
    apply (erule rev_subsetD)
    by (rule hfref_imp)

  (* Precondition of a composed refinement.
     comp_PRE R P Q S x combines a guard S x, the precondition P x of the
     abstract level, and Q x y for every y with (y,x) in R.
     S acts as a guard: auto_weaken_pre_comp_PRE_I derives comp_PRE from S x ==> P x
     and from Q x y under the assumptions (y,x) in R, P x, S x. *)
  definition "comp_PRE R P Q S  λx. S x  (P x  (y. (y,x)R  Q x y))"

  (* Congruence rule: while rewriting Q, the simplifier may assume P x,
     (y,x) in R, y in Domain R and the new guard S' x. *)
  lemma comp_PRE_cong[cong]: 
    assumes "RR'"
    assumes "x. P x  P' x"
    assumes "x. S x  S' x"
    assumes "x y. P x; (y,x)R; yDomain R; S' x   Q x y  Q' x y"
    shows "comp_PRE R P Q S  comp_PRE R' P' Q' S'"
    using assms
    by (fastforce simp: comp_PRE_def intro!: eq_reflection ext)

  (* fref composition (fref_compI) with the composed precondition expressed via
     comp_PRE and a trivial guard. *)
  lemma fref_compI_PRE:
    " (f,g)fref P R1 R2; (g,h)fref Q S1 S2  
       (f,h)  fref (comp_PRE S1 Q (λ_. P) (λ_. True)) (R1 O S1) (R2 O S2)"
    using fref_compI[of P R1 R2 Q S1 S2]   
    unfolding comp_PRE_def
    by auto

  (* Introduction of comp_PRE when the inner precondition does not depend on y. *)
  lemma PRE_D1: "(Q x  P x)  comp_PRE S1 Q (λx _. P x) S x"
    by (auto simp: comp_PRE_def)

  (* General introduction of comp_PRE. *)
  lemma PRE_D2: "(Q x  (y. (y,x)S1  S x  P x y))  comp_PRE S1 Q P S x"
    by (auto simp: comp_PRE_def)

  (* Replace the precondition of an fref theorem by a stronger one P
     (P x must imply the original P' x). Proved by monotonicity (fref_mono). *)
  lemma fref_weaken_pre: 
    assumes "x. P x  P' x"  
    assumes "(f,h)  fref P' R S"
    shows "(f,h)  fref P R S"
    apply (rule rev_subsetD[OF assms(2) fref_mono])
    using assms(1) by auto
    
  (* Simplify a comp_PRE precondition whose inner part ignores y (via PRE_D1). *)
  lemma fref_PRE_D1:
    assumes "(f,h)  fref (comp_PRE S1 Q (λx _. P x) X) R S"  
    shows "(f,h)  fref (λx. Q x  P x) R S"
    by (rule fref_weaken_pre[OF PRE_D1 assms])

  (* Unfold a general comp_PRE precondition (via PRE_D2). *)
  lemma fref_PRE_D2:
    assumes "(f,h)  fref (comp_PRE S1 Q P X) R S"  
    shows "(f,h)  fref (λx. Q x  (y. (y,x)S1  X x  P x y)) R S"
    by (rule fref_weaken_pre[OF PRE_D2 assms])

  lemmas fref_PRE_D = fref_PRE_D1 fref_PRE_D2

  (* Replace the precondition of an hfref theorem by a stronger one P
     (P x must imply the original P' x). *)
  lemma hfref_weaken_pre: 
    assumes "x. P x  P' x"  
    assumes "(f,h)  hfref P' R S"
    shows "(f,h)  hfref P R S"
    using assms
    by (auto simp: hfref_def)

  (* Variant of hfref_weaken_pre: while proving P' x, one may additionally
     assume rdomp (fst R) x. *)
  lemma hfref_weaken_pre': 
    assumes "x. P x; rdomp (fst R) x  P' x"  
    assumes "(f,h)  hfref P' R S"
    shows "(f,h)  hfref P R S"
    apply (rule hfrefI)
    apply (rule hn_refine_preI)
    using assms
    by (auto simp: hfref_def rdomp_def)

  (* Variant of hfref_weaken_pre: while proving P x, one may additionally
     assume nofail (g x). Combines hfref_weaken_pre with hfref_weaken_pre_nofail. *)
  lemma hfref_weaken_pre_nofail': 
    assumes "(f,g)  [P]a R  S"  
    assumes "x. nofail (g x); Q x  P x"
    shows "(f,g)  [Q]a R  S"
    apply (rule hfref_weaken_pre[OF _ assms(1)[THEN hfref_weaken_pre_nofail]])
    using assms(2) 
    by blast

  (* Compose an hfref theorem (f,g) with an fref theorem (g,h) via hfcomp.
     The composed precondition is expressed with comp_PRE and a trivial guard. *)
  lemma hfref_compI_PRE_aux:
    assumes A: "(f,g)  [P]a RR'  S"
    assumes B: "(g,h)  [Q]f T  Unres_rel"
    shows "(f,h)  [comp_PRE T Q (λ_. P) (λ_. True)]a 
      hrp_comp RR' T  hr_comp S U"
    apply (rule hfref_weaken_pre[OF _ hfcomp[OF A B]])
    by (auto simp: comp_PRE_def)


  (* As hfref_compI_PRE_aux, but the composed precondition uses the guard
     nofail (h x) (obtained via hfref_weaken_pre_nofail). *)
  lemma hfref_compI_PRE:
    assumes A: "(f,g)  [P]a RR'  S"
    assumes B: "(g,h)  [Q]f T  Unres_rel"
    shows "(f,h)  [comp_PRE T Q (λx y. P y) (λx. nofail (h x))]a 
      hrp_comp RR' T  hr_comp S U"
    using hfref_compI_PRE_aux[OF A B, THEN hfref_weaken_pre_nofail]  
    apply (rule hfref_weaken_pre[rotated])
    apply (auto simp: comp_PRE_def)
    done

  (* hfref analogue of fref_PRE_D1: simplify a comp_PRE precondition whose
     inner part ignores y. *)
  lemma hfref_PRE_D1:
    assumes "(f,h)  hfref (comp_PRE S1 Q (λx _. P x) X) R S"  
    shows "(f,h)  hfref (λx. Q x  P x) R S"
    by (rule hfref_weaken_pre[OF PRE_D1 assms])

  (* hfref analogue of fref_PRE_D2: unfold a general comp_PRE precondition. *)
  lemma hfref_PRE_D2:
    assumes "(f,h)  hfref (comp_PRE S1 Q P X) R S"  
    shows "(f,h)  hfref (λx. Q x  (y. (y,x)S1  X x  P x y)) R S"
    by (rule hfref_weaken_pre[OF PRE_D2 assms])

  (* Identity rule: keeps the comp_PRE precondition unchanged. *)
  lemma hfref_PRE_D3:
    assumes "(f,h)  hfref (comp_PRE S1 Q P X) R S"  
    shows "(f,h)  hfref (comp_PRE S1 Q P X) R S"
    using assms .

  (* Unlike fref_PRE_D, this collection contains the identity rule D3 instead of D2.
     NOTE(review): presumably D3 is a no-op fallback for when D1 does not apply,
     so that comp_PRE is kept folded; confirm against the users of hfref_PRE_D. *)
  lemmas hfref_PRE_D = hfref_PRE_D1 hfref_PRE_D3

  subsection ‹Automation›  
  text ‹Purity configuration for constraint solver›
  lemmas [safe_constraint_rules] = pure_pure

  text ‹Configuration for hfref to hnr conversion›
  (* The to_hnr_post rules are unfolded as the last post-processing step of
     Sepref_Rules.to_hnr and prepare_hfref_synth_tac (see the ML code below). *)
  named_theorems to_hnr_post ‹to_hnr converter: Postprocessing unfold rules›

  (* Inserts an explicit APP tag between RETURN and the constant c. *)
  lemma uncurry0_add_app_tag: "uncurry0 (RETURN c) = uncurry0 (RETURN$c)" by simp

  (* The mult_1 / mult_1_right instances are specialized to assertions (type assn). *)
  lemmas [to_hnr_post] = norm_RETURN_o norm_return_o
    uncurry0_add_app_tag uncurry0_apply uncurry0_APP hn_val_unit_conv_emp
    mult_1[of "x::assn" for x] mult_1_right[of "x::assn" for x]

  named_theorems to_hfref_post ‹to_hfref converter: Postprocessing unfold rules› 
  (* A curried constant function is a constant function on pairs. *)
  lemma prod_casesK[to_hfref_post]: "case_prod (λ_ _. k) = (λ_. k)" by auto
  (* Normalize the uncurry0-wrapped trivial precondition to λ_. True. *)
  lemma uncurry0_hfref_post[to_hfref_post]: "hfref (uncurry0 True) R S = hfref (λ_. True) R S" 
    apply (fo_rule arg_cong fun_cong)+ by auto


  (* Currently not used, we keep it in here anyway. *)  
  text ‹Configuration for relation normalization after composition›
  (* Named-theorem collections configuring the relation normalizer that runs
     after composition. NOTE(review): the comment preceding this section says
     this is currently not used; confirm before relying on it. *)
  named_theorems fcomp_norm_unfold ‹fcomp-normalizer: Unfold theorems›
  named_theorems fcomp_norm_simps ‹fcomp-normalizer: Simplification theorems›
  named_theorems fcomp_norm_init "fcomp-normalizer: Initialization rules"  
  named_theorems fcomp_norm_trans "fcomp-normalizer: Transitivity rules"  
  named_theorems fcomp_norm_cong "fcomp-normalizer: Congruence rules"  
  named_theorems fcomp_norm_norm "fcomp-normalizer: Normalization rules"  
  named_theorems fcomp_norm_refl "fcomp-normalizer: Reflexivity rules"  

  text ‹Default Setup›
  (* Unfold rules: collapse relation compositions with Id, distribute over products,
     and push hrp_comp through argument products and keep/destroy annotations. *)
  lemmas [fcomp_norm_unfold] = prod_rel_comp nres_rel_comp Id_O_R R_O_Id
  lemmas [fcomp_norm_unfold] = hr_comp_Id1 hr_comp_Id2
  lemmas [fcomp_norm_unfold] = hr_comp_prod_conv
  lemmas [fcomp_norm_unfold] = prod_hrp_comp hrp_comp_keep hrp_comp_dest hr_comp_pure
  (*lemmas [fcomp_norm_unfold] = prod_casesK uncurry0_hfref_post*)

  lemma [fcomp_norm_simps]: "CONSTRAINT is_pure P  pure (the_pure P) = P" by simp
  lemmas [fcomp_norm_simps] = True_implies_equals 

  (* Structural rules for hrp_imp-based normalization: init, transitivity,
     congruence and reflexivity. *)
  lemmas [fcomp_norm_init] = fcomp_norm_dflt_init
  lemmas [fcomp_norm_trans] = hrp_imp_trans
  lemmas [fcomp_norm_cong] = hrp_comp_cong hrp_prod_cong
  (*lemmas [fcomp_norm_norm] = hrp_comp_dest*)
  lemmas [fcomp_norm_refl] = refl hrp_imp_refl

  (* Lift a plain fref theorem to the nres level: compose both functions with
     RETURN and wrap the result relation into nres_rel. *)
  lemma ensure_fref_nresI: "(f,g)[P]f RS  (RETURN o f, RETURN o g)[P]f RSnres_rel" 
    by (auto intro: nres_relI simp: fref_def)

  (* Push RETURN-composition through uncurry0 / uncurry, to clean up the result of
     ensure_fref_nresI on uncurried functions. NOTE(review): presumably used by the
     ML function ensure_fref_nres; its implementation is not visible here. *)
  lemma ensure_fref_nres_unfold:
    "f. RETURN o (uncurry0 f) = uncurry0 (RETURN f)" 
    "f. RETURN o (uncurry f) = uncurry (RETURN oo f)"
    "f. (RETURN ooo uncurry) f = uncurry (RETURN ooo f)"
    by auto

  text ‹Composed precondition normalizer›  
  (* Simp rules used to normalize the precondition of a composed theorem. *)
  named_theorems fcomp_prenorm_simps ‹fcomp precondition-normalizer: Simplification theorems›

  text ‹Support for preconditions of the form ‹_∈Domain R›, 
    where ‹R› is the relation of the next more abstract level.›
  (* DomainI lets the precondition normalizer discharge Domain-membership conditions. *)
  declare DomainI[fcomp_prenorm_simps]

  (* Entry rule of automatic precondition weakening for hfref theorems.
     Identical to hfref_weaken_pre, except that the new precondition is wrapped
     in PROTECT in the side condition. *)
  lemma auto_weaken_pre_init_hf: 
    assumes "x. PROTECT P x  P' x"  
    assumes "(f,h)  hfref P' R S"
    shows "(f,h)  hfref P R S"
    using assms
    by (auto simp: hfref_def)

  (* Same as auto_weaken_pre_init_hf, for fref theorems. *)
  lemma auto_weaken_pre_init_f: 
    assumes "x. PROTECT P x  P' x"  
    assumes "(f,h)  fref P' R S"
    shows "(f,h)  fref P R S"
    using assms
    by (auto simp: fref_def)

  lemmas auto_weaken_pre_init = auto_weaken_pre_init_hf auto_weaken_pre_init_f  

  (* Helper rules for automatic precondition weakening. *)

  (* One uncurrying step: a PROTECTed pair-consuming function applied to (a,b)
     is reduced to f' b, where PROTECT f a is already known to reduce to f'. *)
  lemma auto_weaken_pre_uncurry_step:
    assumes "PROTECT f a  f'"
    shows "PROTECT (λ(x,y). f x y) (a,b)  f' b" 
    using assms
    by (auto simp: curry_def dest!: meta_eq_to_obj_eq intro!: eq_reflection)

  (* Final step: drop the PROTECT wrapper. *)
  lemma auto_weaken_pre_uncurry_finish:  
    "PROTECT f x  f x" by (auto)

  (* Start rule: rewrite P to P' (first premise), then continue with P'. *)
  lemma auto_weaken_pre_uncurry_start:
    assumes "P  P'"
    assumes "P'Q"
    shows "PQ"
    using assms by (auto)

  (* Introduction rule for comp_PRE, splitting the goal into the outer precondition
     (under the guard S x) and the inner one for every related y. *)
  lemma auto_weaken_pre_comp_PRE_I:
    assumes "S x  P x"
    assumes "y. (y,x)R; P x; S x  Q x y"
    shows "comp_PRE R P Q S x"
    using assms by (auto simp: comp_PRE_def)

  (* Rewrite rules that normalize implication/conjunction chains; see the
     statements for the exact connectives. *)
  lemma auto_weaken_pre_to_imp_nf:
    "(ABC) = (AB  C)"
    "((AB)C) = (ABC)"
    by auto

  (* Adds a trivial True-premise to P. *)
  lemma auto_weaken_pre_add_dummy_imp:
    "P  True  P" by simp


  text ‹Synthesis for hfref statements›  
  (* Tag recording that abstract argument a is related by assertion R.
     It is always True (definition is [simp]); it only carries R and a
     syntactically, so that hfsynth_ID_R_D can turn it into an interface premise. *)
  definition hfsynth_ID_R :: "('a  _  assn)  'a  bool" where
    [simp]: "hfsynth_ID_R _ _  True"

  (* From the tag and an intf_of_assn fact for R, derive the interface-typing
     premise "a ::i I". Used by prepare_hfref_synth_tac via dresolve_tac. *)
  lemma hfsynth_ID_R_D:
    fixes I :: "'a itself"
    assumes "hfsynth_ID_R R a"
    assumes "intf_of_assn R I"
    shows "a ::i I"
    by simp

  (* Reduce an hfref goal to an hn_refine goal for arbitrary arguments x, xi that
     satisfy the precondition. The hfsynth_ID_R tag records the argument assertion
     for later generation of interface premises. *)
  lemma hfsynth_hnr_from_hfI:
    assumes "x xi. P x  hfsynth_ID_R (fst R) x  hn_refine (emp * hn_ctxt (fst R) x xi) (f$xi) (emp * hn_ctxt (snd R) x xi) S (g$x)"
    shows "(f,g)  [P]a R  S"
    using assms
    unfolding hfref_def 
    by (auto simp: hn_ctxt_def)


  (* Distribute the hfsynth_ID_R tag over uncurried (paired) arguments and
     through hf_pres. *)
  lemma hfsynth_ID_R_uncurry_unfold: 
    "hfsynth_ID_R (to_hnr_prod R S) (a,b)  hfsynth_ID_R R a  hfsynth_ID_R S b" 
    "hfsynth_ID_R (fst (hf_pres R k))  hfsynth_ID_R R"
    by (auto intro!: eq_reflection)

  ML signature SEPREF_RULES = sig
      (* Interface of Sepref_Rules: analysis and construction of refinement
         relations, and conversions between the theorem forms
         ho-param, fref, hfref and hn_refine, plus composition (FCOMP) support. *)

      (* Analysis of relations, both fref and fun_rel *)
      (* "R1→...→Rn→_" / "[_]f ((R1×rR2)...×rRn)"  ↦  "[R1,...,Rn]" *)
      val binder_rels: term -> term list 
      (* "_→...→_→S" / "[_]f _ → S"  ↦  "S" *)
      val body_rel: term -> term 
      (* Map →/fref to (precond,args,res). NONE if no/trivial precond. *)
      val analyze_rel: term -> term option * term list * term 
      (* Make trivial ("λ_. True") precond *)
      val mk_triv_precond: term list -> term 
      (* Make "[P]f ((R1×rR2)...×rRn) → S". Insert trivial precond if NONE. *)
      val mk_rel: term option * term list * term -> term 
      (* Map relation to (args,res) *)
      val strip_rel: term -> term list * term 

      (* Make hfprod (op *a) *)
      val mk_hfprod : term * term -> term
      val mk_hfprods : term list -> term

      (* Determine interface type of refinement assertion, using default fallback
        if necessary. Use named_thms intf_of_assn for configuration. *)
      val intf_of_assn : Proof.context -> term -> typ

      (*
        Convert a parametricity theorem in higher-order form to
        uncurried fref-form. For functions without arguments, 
        a unit-argument is added.

        TODO/FIXME: Currently this only works for higher-order theorems,
          i.e., theorems of the form (f,g)∈R1→…→Rn. 
          
          First-order theorems are silently treated as refinement theorems
          for functions with zero arguments, i.e., a unit-argument is added.
      *)
      val to_fref : Proof.context -> thm -> thm

      (* Convert a parametricity or fref theorem to first order form *)
      val to_foparam : Proof.context -> thm -> thm

      (* Convert schematic hfref goal to hnr-goal *)
      val prepare_hfref_synth_tac : Proof.context -> tactic'

      (* Convert theorem in hfref-form to hnr-form *)
      val to_hnr : Proof.context -> thm -> thm

      (* Convert theorem in hnr-form to hfref-form *)
      val to_hfref: Proof.context -> thm -> thm

      (* Convert theorem to given form, if not yet in this form *)
      val ensure_fref : Proof.context -> thm -> thm
      val ensure_fref_nres : Proof.context -> thm -> thm
      val ensure_hfref : Proof.context -> thm -> thm
      val ensure_hnr : Proof.context -> thm -> thm


      type hnr_analysis = {
        thm: thm,                     (* Original theorem, may be normalized *)
        precond: term,                (* Precondition, abstracted over abs-arguments *)
        prems : term list,            (* Premises not depending on arguments *)
        ahead: term * bool,           (* Abstract function, has leading RETURN *)
        chead: term * bool,           (* Concrete function, has leading return *)
        argrels: (term * bool) list,  (* Argument relations, preserved (keep-flag) *)
        result_rel: term              (* Result relation *)
      }
  
      (* Decompose an hn_refine theorem into an hnr_analysis; raises THM if the
         theorem does not have the expected shape *)
      val analyze_hnr: Proof.context -> thm -> hnr_analysis
      (* Pretty-print an hnr_analysis *)
      val pretty_hnr_analysis: Proof.context -> hnr_analysis -> Pretty.T
      (* Build and prove the hfref theorem described by an hnr_analysis *)
      val mk_hfref_thm: Proof.context -> hnr_analysis -> thm
  
  

      (* Simplify precondition of fref/hfref-theorem *)
      val simplify_precond: Proof.context -> thm -> thm

      (* Normalize hfref-theorem after composition *)
      val norm_fcomp_rule: Proof.context -> thm -> thm

      (* Replace "pure ?A" by "?A'" and is_pure constraint, then normalize *)
      val add_pure_constraints_rule: Proof.context -> thm -> thm

      (* Compose fref/hfref and fref theorem, to produce hfref theorem.
        The input theorems may also be in ho-param or hnr form, and
        are converted accordingly.
      *)
      val gen_compose : Proof.context -> thm -> thm -> thm

      (* FCOMP-attribute *)
      val fcomp_attrib: attribute context_parser
    end

    structure Sepref_Rules: SEPREF_RULES = struct

      local open Refine_Util Relators in
        (* Argument relations: the binders of a fun_rel chain, or the components of
           the left-nested product relation of an fref. *)
        fun binder_rels @{mpat "?F  ?G"} = F::binder_rels G
          | binder_rels @{mpat "fref _ ?F _"} = strip_prodrel_left F
          | binder_rels _ = []
    
        local 
          (* Follow the fun_rel chain to its final (result) relation *)
          fun br_aux @{mpat "_  ?G"} = br_aux G
            | br_aux R = R
        in    
          (* Result relation of an fref or of a fun_rel chain *)
          fun body_rel @{mpat "fref _ _ ?G"} = G
            | body_rel R = br_aux R
        end
    
        fun strip_rel R = (binder_rels R, body_rel R)   
    
        (* The trivial precondition λ_. True is reported as NONE, as is the
           (absent) precondition of a fun_rel chain. *)
        fun analyze_rel @{mpat "fref (λ_. True) ?R ?S"} = (NONE,strip_prodrel_left R,S)
          | analyze_rel @{mpat "fref ?P ?R ?S"} = (SOME P,strip_prodrel_left R,S)
          | analyze_rel R = let
              val (args,res) = strip_rel R
            in
              (NONE,args,res)
            end
    
        (* λ_. True over the left-nested product of the abstract types of Rs *)
        fun mk_triv_precond Rs = absdummy (map rel_absT Rs |> list_prodT_left) @{term True}
    
        fun mk_rel (P,Rs,S) = let 
          val R = list_prodrel_left Rs 
    
          val P = case P of 
              SOME P => P 
            | NONE => mk_triv_precond Rs
    
        in 
          (* The mk_term antiquotation picks up the ML values P, R, S by name;
             do not rename these locals. *)
          @{mk_term "fref ?P ?R ?S"} 
        end
      end


      (* The mk_term antiquotation refers to the ML arguments a and b by name *)
      fun mk_hfprod (a, b) = @{mk_term "?a*a?b"}
  
      local 
        (* Builds the product from a reversed list; the empty product is the
           unit assertion (keep-annotated) *)
        fun mk_hfprods_rev [] = @{mk_term "unit_assnk"}
          | mk_hfprods_rev [Rk] = Rk
          | mk_hfprods_rev (Rkn::Rks) = mk_hfprod (mk_hfprods_rev Rks, Rkn)
      in
        (* [R1,...,Rn] is mapped to the left-nested product ((R1*R2)*...)*Rn *)
        val mk_hfprods = mk_hfprods_rev o rev
      end


      (* Infer the interface type of assertion t. This proves "intf_of_assn t ?v" for a
         fresh schematic type ?v, using the intf_of_assn rules followed by
         intf_of_assn_fallback, and returns the instantiated type exported
         back to the caller's context. *)
      fun intf_of_assn ctxt t = let
        val orig_ctxt = ctxt
        val (t,ctxt) = yield_singleton (Variable.import_terms false) t ctxt

        (* t and v are referenced by name in the mk_term antiquotation below *)
        val v = TVar (("T",0),Proof_Context.default_sort ctxt ("T",0)) |> Logic.mk_type
        val goal = @{mk_term "Trueprop (intf_of_assn ?t ?v)"}

        (* Configured rules first, fallback rules last *)
        val i_of_assn_rls = 
          Named_Theorems_Rev.get ctxt @{named_theorems_rev intf_of_assn}
          @ @{thms intf_of_assn_fallback}

        fun tac ctxt = REPEAT_ALL_NEW (resolve_tac ctxt i_of_assn_rls)

        val thm = Goal.prove ctxt [] [] goal (fn {context,...} => ALLGOALS (tac context))
        (* Read the instantiated type back from the proved conclusion *)
        val intf = case Thm.concl_of thm of
            @{mpat "Trueprop (intf_of_assn _ (?v ASp TYPE (_)))"} => v 
          | _ => raise THM("Intf_of_assn: Proved a different theorem?",~1,[thm])

        val intf = singleton (Variable.export_terms ctxt orig_ctxt) intf
          |> Logic.dest_type

      in
        intf
      end

      (* Classification of refinement theorems by the shape of their conclusion *)
      datatype rthm_type = 
        RT_HOPARAM    (* (_,_) ∈ _ → … → _ *)
      | RT_FREF       (* (_,_) ∈ [_]f _ → _ *)
      | RT_HNR        (* hn_refine _ _ _ _ _ *)
      | RT_HFREF      (* (_,_) ∈ [_]a _ → _ *)
      | RT_OTHER      (* anything else *)

      (* Classify a theorem by the shape of its conclusion.
         Conclusions that are not HOL propositions (no Trueprop, e.g. meta-level
         equations) are classified as RT_OTHER, rather than letting
         HOLogic.dest_Trueprop escape with a TERM exception. *)
      fun rthm_type thm =
        case try HOLogic.dest_Trueprop (Thm.concl_of thm) of
          NONE => RT_OTHER
        | SOME t =>
           (case t of
              @{mpat "(_,_)  fref _ _ _"} => RT_FREF
            | @{mpat "(_,_)  hfref _ _ _"} => RT_HFREF
            | @{mpat "hn_refine _ _ _ _ _"} => RT_HNR
            | @{mpat "(_,_)  _"} => RT_HOPARAM (* TODO: Distinction between ho-param and fo-param *)
            | _ => RT_OTHER)


      (* Convert a parametricity theorem to uncurried fref form.
         Function-relation chains: unfold fref_param1, then repeatedly rewrite
         nested frefs with fref_nest, then unfold in_CURRY_conv.
         Any other "(_,_) ∈ _" theorem: resolve with fref_param0I (the
         zero-argument case of the signature comment).
         Raises THM for other conclusions. *)
      fun to_fref ctxt thm = let
        open Conv
      in  
        case Thm.concl_of thm |> HOLogic.dest_Trueprop of
          @{mpat "(_,_)__"} =>
            Local_Defs.unfold0 ctxt @{thms fref_param1} thm
            |> fconv_rule (repeat_conv (Refine_Util.ftop_conv (K (rewr_conv @{thm fref_nest})) ctxt))
            |> Local_Defs.unfold0 ctxt @{thms in_CURRY_conv}
        | @{mpat "(_,_)_"} => thm RS @{thm fref_param0I}   
        | _ => raise THM ("to_fref: Expected theorem of form (_,_)∈_",~1,[thm])
      end

      (* Convert an fref or parametricity theorem to first-order form.
         fref: apply frefD, generalize free variables, split tuples and product
         relations, and turn conjunctions into meta-premises.
         Plain parametricity theorems are delegated to Parametricity.fo_rule. *)
      fun to_foparam ctxt thm = let
        val unf_thms = @{thms 
          split_tupled_all prod_rel_simp uncurry_apply cnv_conj_to_meta Product_Type.split}
      in
        case Thm.concl_of thm of
          @{mpat "Trueprop ((_,_)  fref _ _ _)"} =>
            (@{thm frefD} OF [thm])
            |> forall_intr_vars
            |> Local_Defs.unfold0 ctxt unf_thms
            |> Variable.gen_all ctxt
        | @{mpat "Trueprop ((_,_)  _)"} =>
            Parametricity.fo_rule thm
        | _ => raise THM("Expected parametricity or fref theorem",~1,[thm])
      end

      (* Convert an hfref theorem to hn_refine form: resolve with hf2hnr, then run
         the unfolding passes below in order, and finally normalize the result
         (Goal.norm_result, eta conversion). *)
      fun to_hnr ctxt thm =
        (thm RS @{thm hf2hnr})
        |> Local_Defs.unfold0 ctxt @{thms to_hnr_prod_fst_snd keep_drop_sels} (* Resolve fst and snd over *a and Rk, Rd *)
        |> Local_Defs.unfold0 ctxt @{thms hnr_uncurry_unfold} (* Resolve products for uncurried parameters *)
        |> Local_Defs.unfold0 ctxt @{thms uncurry_apply uncurry_APP assn_one_left split} (* Remove the uncurry modifiers, the emp-dummy, and unfold product cases *)
        |> Local_Defs.unfold0 ctxt @{thms hn_ctxt_ctxt_fix_conv} (* Remove duplicate hn_ctxt tagging *)
        |> Local_Defs.unfold0 ctxt @{thms all_to_meta imp_to_meta HOL.True_implies_equals HOL.implies_True_equals Pure.triv_forall_equality cnv_conj_to_meta} (* Convert to meta-level, remove vacuous condition *)
        |> Local_Defs.unfold0 ctxt (Named_Theorems.get ctxt @{named_theorems to_hnr_post}) (* Post-Processing *)
        |> Goal.norm_result ctxt
        |> Conv.fconv_rule Thm.eta_conversion

      (* Convert schematic hfref-goal to hn_refine goal *)  
      (* If hfsynth_hnr_from_hfI does not apply to the goal, the goal is left
         unchanged (K all_tac branch of THEN_ELSE'). *)
      fun prepare_hfref_synth_tac ctxt = let
        (* Configured intf_of_assn rules first, fallback rules last *)
        val i_of_assn_rls = 
          Named_Theorems_Rev.get ctxt @{named_theorems_rev intf_of_assn}
          @ @{thms intf_of_assn_fallback}

        val to_hnr_post_rls = 
          Named_Theorems.get ctxt @{named_theorems to_hnr_post}

        (* Turn each hfsynth_ID_R premise into an interface premise "_ ::i _";
           the intf_of_assn side goal must be solved completely (SOLVED') *)
        val i_of_assn_tac = (
          REPEAT' (
            DETERM o dresolve_tac ctxt @{thms hfsynth_ID_R_D}
            THEN' DETERM o SOLVED' (REPEAT_ALL_NEW (resolve_tac ctxt i_of_assn_rls))
          )
        )
      in
        (* Note: To re-use the to_hnr infrastructure, we first work with
          $-tags on the abstract function, which are finally removed.
        *)
        resolve_tac ctxt @{thms hfsynth_hnr_from_hfI} THEN_ELSE' (
          SELECT_GOAL (
            unfold_tac ctxt @{thms to_hnr_prod_fst_snd keep_drop_sels hf_pres_fst} (* Distribute fst,snd over product and hf_pres *)
            THEN unfold_tac ctxt @{thms hnr_uncurry_unfold hfsynth_ID_R_uncurry_unfold} (* Curry parameters *)
            THEN unfold_tac ctxt @{thms uncurry_apply uncurry_APP assn_one_left split} (* Curry parameters (II) and remove emp assertion *)
            (*THEN unfold_tac ctxt @{thms hn_ctxt_ctxt_fix_conv} (* Remove duplicate hn_ctxt (Should not be necessary) *)*)
            THEN unfold_tac ctxt @{thms all_to_meta imp_to_meta HOL.True_implies_equals HOL.implies_True_equals Pure.triv_forall_equality cnv_conj_to_meta} (* Convert precondition to meta-level *)
            THEN ALLGOALS i_of_assn_tac (* Generate _::i_ premises*)
            THEN unfold_tac ctxt to_hnr_post_rls (* Postprocessing *)
            THEN unfold_tac ctxt @{thms APP_def} (* Get rid of $ - tags *)
          )
        ,
          K all_tac
        )
      end


      (************************************)  
      (* Analyze hnr *)
      (* Tables keyed by pairs of terms; analyze_hnr uses them for
         (abstract argument, concrete argument) pairs. *)
      structure Termtab2 = Table(
        type key = term * term 
        val ord = prod_ord Term_Ord.fast_term_ord Term_Ord.fast_term_ord);
  
      (* Result of analyze_hnr; field meanings are documented in SEPREF_RULES *)
      type hnr_analysis = {
        thm: thm,                     
        precond: term,                
        prems : term list,
        ahead: term * bool,           
        chead: term * bool,           
        argrels: (term * bool) list,  
        result_rel: term              
      }
  
    
      (* Decompose an hn_refine theorem into an hnr_analysis record.
         Determines the abstract/concrete heads (stripping a leading RETURN/return),
         the argument pairs and their relations with keep-flags, the result relation,
         and splits premises into argument-dependent ones (folded into the
         precondition) and argument-independent ones (prems).
         On malformed input, traces the collected debug info and raises THM. *)
      fun analyze_hnr (ctxt:Proof.context) thm = let
    
        (* Debug information: Stores string*term pairs, which are pretty-printed on error *)
        val dbg = Unsynchronized.ref []
        fun add_dbg msg ts = (
          dbg := (msg,ts) :: !dbg;
          ()
        )
        fun pretty_dbg (msg,ts) = Pretty.block [
          Pretty.str msg,
          Pretty.str ":",
          Pretty.brk 1,
          Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) ts)
        ]
        fun pretty_dbgs l = map pretty_dbg l |> Pretty.fbreaks |> Pretty.block
    
        (* dbg is accumulated in reverse; print it in insertion order *)
        fun trace_dbg msg = Pretty.block [Pretty.str msg, Pretty.fbrk, pretty_dbgs (rev (!dbg))] |> Pretty.string_of |> tracing
    
        fun fail msg = (trace_dbg msg; raise THM(msg,~1,[thm])) 
        (* Returns true, or fails with msg *)
        fun assert cond msg = cond orelse fail msg;
    
    
        (* Heads may have a leading return/RETURN.
          The following code strips off the leading return, unless it has the form
          "return x" for an argument x
        *)
        fun check_strip_leading args t f = (* Handle the case RETURN x, where x is an argument *)
          if Termtab.defined args f then (t,false) else (f,true)
    
        fun strip_leading_RETURN args (t as @{mpat "RETURN$(?f)"}) = check_strip_leading args t f
          | strip_leading_RETURN args (t as @{mpat "RETURN ?f"}) = check_strip_leading args t f
          | strip_leading_RETURN _ t = (t,false)
    
        fun strip_leading_return args (t as @{mpat "return$(?f)"}) = check_strip_leading args t f
            | strip_leading_return args (t as @{mpat "return ?f"}) = check_strip_leading args t f
            | strip_leading_return _ t = (t,false)
    
    
        (* The following code strips the arguments of the concrete or abstract
          function. It knows how to handle APP-tags ($), and stops at PR_CONST-tags.
    
          Moreover, it only strips actual arguments that occur in the 
          precondition-section of the hn_refine-statement. This ensures
          that non-arguments, like maxsize, are treated correctly.
        *)    
        fun strip_fun _ (t as @{mpat "PR_CONST _"}) = (t,[])
          | strip_fun s (t as @{mpat "?f$?x"}) = check_arg s t f x
          | strip_fun s (t as @{mpat "?f ?x"}) = check_arg s t f x
          | strip_fun _ f = (f,[])
        and check_arg s t f x = 
            if Termtab.defined s x then
              strip_fun s f |> apsnd (curry op :: x)
            else (t,[])  
    
        (* Arguments in the pre/postcondition are wrapped into hn_ctxt tags. 
          This function strips them off. *)    
        fun dest_hn_ctxt @{mpat "hn_ctxt ?R ?a ?c"} = ((a,c),R)
          | dest_hn_ctxt _ = fail "Invalid hn_ctxt parameter in pre or postcondition"
    
    
        fun dest_hn_refine @{mpat "(hn_refine ?G ?c ?G' ?R ?a)"} = (G,c,G',R,a) 
          | dest_hn_refine _ = fail "Conclusion is not a hn_refine statement"
    
        (*
          Strip separation conjunctions. Special case for "emp", which is ignored. 
        *)  
        fun is_emp @{mpat emp} = true | is_emp _ = false
  
        val strip_star' = Sepref_Basic.strip_star #> filter (not o is_emp)
  
        (* Compare Termtab2s for equality of keys *)  
        fun pairs_eq pairs1 pairs2 = 
                  Termtab2.forall (Termtab2.defined pairs1 o fst) pairs2
          andalso Termtab2.forall (Termtab2.defined pairs2 o fst) pairs1
    
    
        (* Premises must be plain HOL propositions *)
        fun atomize_prem @{mpat "Trueprop ?p"} = p
          | atomize_prem _ = fail "Non-atomic premises"
    
        (* Make HOL conjunction list *)  
        fun mk_conjs [] = @{const True}
          | mk_conjs [p] = p
          | mk_conjs (p::ps) = HOLogic.mk_binop @{const_name "HOL.conj"} (p,mk_conjs ps)
    
    
        (***********************)      
        (* Start actual analysis *)
    
        val _ = add_dbg "thm" [Thm.prop_of thm]
        val prems = Thm.prems_of thm
        val concl = Thm.concl_of thm |> HOLogic.dest_Trueprop
        val (G,c,G',R,a) = dest_hn_refine concl
    
        (* (abstract arg, concrete arg) -> relation, from the precondition *)
        val pre_pairs = G 
          |> strip_star'
          |> tap (add_dbg "precondition")
          |> map dest_hn_ctxt
          |> Termtab2.make
    
        (* (abstract arg, concrete arg) -> relation, from the postcondition *)
        val post_pairs = G' 
          |> strip_star'
          |> tap (add_dbg "postcondition")
          |> map dest_hn_ctxt
          |> Termtab2.make
    
        val _ = assert (pairs_eq pre_pairs post_pairs) 
          "Parameters in precondition do not match postcondition"
    
        (* Candidate abstract / concrete arguments, as found in the precondition *)
        val aa_set = pre_pairs |> Termtab2.keys |> map fst |> Termtab.make_set
        val ca_set = pre_pairs |> Termtab2.keys |> map snd |> Termtab.make_set
    
        val (a,leading_RETURN) = strip_leading_RETURN aa_set a
        val (c,leading_return) = strip_leading_return ca_set c
    
        val _ = add_dbg "stripped abstract term" [a]
        val _ = add_dbg "stripped concrete term" [c]
    
        val (ahead,aargs) = strip_fun aa_set a;
        val (chead,cargs) = strip_fun ca_set c;
    
        val _ = add_dbg "abstract head" [ahead]
        val _ = add_dbg "abstract args" aargs
        val _ = add_dbg "concrete head" [chead]
        val _ = add_dbg "concrete args" cargs
    
    
        val _ = assert (length cargs = length aargs) "Different number of abstract and concrete arguments";
    
        val _ = assert (not (has_duplicates op aconv aargs)) "Duplicate abstract arguments"
        val _ = assert (not (has_duplicates op aconv cargs)) "Duplicate concrete arguments"
    
        (* Arguments are paired positionally; the pairs must match the hn_ctxt pairs *)
        val argpairs = aargs ~~ cargs
        val ap_set = Termtab2.make_set argpairs
        val _ = assert (pairs_eq pre_pairs ap_set) "Arguments from pre/postcondition do not match operation's arguments"
    
        (* Lookups cannot fail: pairs_eq checks above ensure the keys exist *)
        val pre_rels = map (the o (Termtab2.lookup pre_pairs)) argpairs
        val post_rels = map (the o (Termtab2.lookup post_pairs)) argpairs
    
        val _ = add_dbg "pre-rels" pre_rels
        val _ = add_dbg "post-rels" post_rels

        (* Replace "snd (R keep-annotated)" in the postcondition by R itself *)
        fun adjust_hf_pres @{mpat "snd (?Rk)"} = R
          | adjust_hf_pres t = t
          
        val post_rels = map adjust_hf_pres post_rels
    
        (* Post relation is the invalidated version of the pre relation R *)
        fun is_invalid R @{mpat "invalid_assn ?R'"} = R aconv R'
          | is_invalid _ @{mpat "snd (_d)"} = true
          | is_invalid _ _ = false
    
        (* Keep-flag: true if the argument is preserved (same relation before and
           after), false if it is invalidated; anything else is an error *)
        fun is_keep (R,R') =
          if R aconv R' then true
          else if is_invalid R R' then false
          else fail "Mismatch between pre and post relation for argument"
    
        val keep = map is_keep (pre_rels ~~ post_rels)
    
        val argrels = pre_rels ~~ keep

        (* Re-bind to the sets of actually stripped arguments (shadowing the
           precondition-based sets above); used to classify the premises *)
        val aa_set = Termtab.make_set aargs
        val ca_set = Termtab.make_set cargs

        (* A premise belongs to the precondition iff it mentions an abstract argument.
           Premises mentioning a concrete argument are rejected. *)
        fun is_precond t =
          (exists_subterm (Termtab.defined ca_set) t andalso fail "Premise contains concrete argument")
          orelse exists_subterm (Termtab.defined aa_set) t

        val (preconds, prems) = split is_precond prems  
    
        (* Conjunction of the argument-dependent premises, abstracted over the abstract arguments *)
        val precond = 
          map atomize_prem preconds 
          |> mk_conjs
          |> fold lambda aargs
    
        val _ = add_dbg "precond" [precond]
        val _ = add_dbg "prems" prems
    
      in
        {
          thm = thm,
          precond = precond,
          prems = prems,
          ahead = (ahead,leading_RETURN),
          chead = (chead,leading_return),
          argrels = argrels,
          result_rel = R
        }
      end  
    
      fun pretty_hnr_analysis 
        ctxt 
        ({thm,precond,ahead,chead,argrels,result_rel,...}) 
        : Pretty.T =
      let
        (* thm is only bound by the record pattern; touch it to silence the unused warning *)
        val _ = thm

        val pterm = Syntax.pretty_term ctxt

        (* Print a head as-is, or prefixed with kw if a leading return/RETURN was stripped *)
        fun pretty_head kw (t, has_leading) =
          if has_leading then Pretty.block [Pretty.str kw, pterm t]
          else pterm t

        (* An argument relation followed by its keep/destroy marker *)
        fun pretty_argrel (R, keep) =
          Pretty.block [pterm R, Pretty.str (if keep then "k" else "d")]

        (*Display.pretty_thm ctxt thm,*)
        val heads = Pretty.enclose "[" "]" [pretty_head "return " chead, pretty_head "RETURN " ahead]
        val pre = Pretty.enclose "[" "]" [pterm precond]
        val rels = Pretty.block (Pretty.separate " →" (map pretty_argrel argrels @ [pterm result_rel]))
      in
        Pretty.block (Pretty.fbreaks [Pretty.block [heads, pre, Pretty.brk 1, rels]])
      end
    
    
      fun mk_hfref_thm 
        ctxt 
        ({thm,precond,prems,ahead,chead,argrels,result_rel}) = 
      let
    
        fun mk_keep (R,true) = @{mk_term "?Rk"}
          | mk_keep (R,false) = @{mk_term "?Rd"}
    
        (* TODO: Move, this is of general use! *)  
        fun mk_uncurry f = @{mk_term "uncurry ?f"}  
      
        (* Uncurry function for the given number of arguments. 
          For zero arguments, add a unit-parameter.
        *)
        fun rpt_uncurry n t =
          if n=0 then @{mk_term "uncurry0 ?t"}
          else if n=1 then t 
          else funpow (n-1) mk_uncurry t
      
        (* Rewrite uncurried lambda's to λ(_,_). _ form. Use top-down rewriting
          to correctly handle nesting to the left. 
    
          TODO: Combine with abstraction and  uncurry-procedure,
            and mark the deviation about uncurry as redundant 
            intermediate step to be eliminated.
        *)  
        fun rew_uncurry_lambda t = let
          val rr = map (Logic.dest_equals o Thm.prop_of) @{thms uncurry_def uncurry0_def}
          val thy = Proof_Context.theory_of ctxt
        in 
          Pattern.rewrite_term_top thy rr [] t 
        end  
    
        (* Shortcuts for simplification tactics *)
        fun gsimp_only ctxt sec = let
          val ss = put_simpset HOL_basic_ss ctxt |> sec
        in asm_full_simp_tac ss end
    
        fun simp_only ctxt thms = gsimp_only ctxt (fn ctxt => ctxt addsimps thms)
    
    
        (********************************)
        (* Build theorem statement *)
        (* ⟦prems⟧ ⟹ (chead,ahead) ∈ [precond] rels → R *)
    
        (* Uncurry precondition *)
        val num_args = length argrels
        val precond = precond
          |> rpt_uncurry num_args
          |> rew_uncurry_lambda (* Convert to nicer λ((...,_),_) - form*)

        (* Re-attach leading RETURN/return *)
        fun mk_RETURN (t,r) = if r then 
            let
              val T = funpow num_args range_type (fastype_of (fst ahead))
              val tRETURN = Const (@{const_name RETURN}, T --> Type(@{type_name nres},[T]))
            in
              Refine_Util.mk_compN num_args tRETURN t
            end  
          else t
    
        fun mk_return (t,r) = if r then 
            let
              val T = funpow num_args range_type (fastype_of (fst chead))
              val tRETURN = Const (@{const_name return}, T --> Type(@{type_name Heap},[T]))
            in
              Refine_Util.mk_compN num_args tRETURN t
            end  
          else t
          
        (* Hrmpf!: Gone for good from 2015→2016. Inserting ctxt-based substitute here. *)  
        fun certify_inst ctxt (instT, inst) =
         (map (apsnd (Thm.ctyp_of ctxt)) instT,
          map (apsnd (Thm.cterm_of ctxt)) inst);

        (*  
        fun mk_RETURN (t,r) = if r then @{mk_term "RETURN o ?t"} else t
        fun mk_return (t,r) = if r then @{mk_term "return o ?t"} else t
        *)
    
        (* Uncurry abstract and concrete function, append leading return *)
        val ahead = ahead |> mk_RETURN |> rpt_uncurry num_args  
        val chead = chead |> mk_return |> rpt_uncurry num_args 
    
        (* Add keep-flags and summarize argument relations to product *)
        val argrel = map mk_keep argrels |> rev (* TODO: Why this rev? *) |> mk_hfprods
    
        (* Produce final result statement *)
        val result = @{mk_term "Trueprop ((?chead,?ahead)  [?precond]a ?argrel  ?result_rel)"}
        val result = Logic.list_implies (prems,result)
    
        (********************************)
        (* Prove theorem *)
    
        (* Create context and import result statement and original theorem *)
        val orig_ctxt = ctxt
        (*val thy = Proof_Context.theory_of ctxt*)
        val (insts, ctxt) = Variable.import_inst true [result] ctxt
        val insts' = certify_inst ctxt insts
        val result = Term_Subst.instantiate insts result
        val thm = Thm.instantiate insts' thm
    
        (* Unfold APP tags. This is required as some APP-tags have also been unfolded by analysis *)
        val thm = Local_Defs.unfold0 ctxt @{thms APP_def} thm
    
        (* Tactic to prove the theorem. 
          A first step uses hfrefI to get a hnr-goal.
          This is then normalized in several consecutive steps, which 
            get rid of uncurrying. Finally, the original theorem is used for resolution,
            where the pre- and postcondition, and result relation are connected with 
            a consequence rule, to handle unfolded hn_ctxt-tags, re-ordered relations,
            and introduced unit-parameters (TODO: 
              Mark artificially introduced unit-parameter specially, it may get confused 
              with intentional unit-parameter, e.g., functional empty_set ()!)
    
          *)
        fun tac ctxt = 
                resolve_tac ctxt @{thms hfrefI}
          THEN' gsimp_only ctxt (fn c => c 
            addsimps @{thms uncurry_def hn_ctxt_def uncurry0_def
                            keep_drop_sels uc_hfprod_sel o_apply
                            APP_def}
            |> Splitter.add_split @{thm prod.split}
          ) 
    
          THEN' TRY o (
            REPEAT_ALL_NEW (match_tac ctxt @{thms allI impI})
            THEN' simp_only ctxt @{thms Product_Type.split prod.inject})
    
          THEN' TRY o REPEAT_ALL_NEW (ematch_tac ctxt @{thms conjE})
          THEN' TRY o hyp_subst_tac ctxt
          THEN' simp_only ctxt @{thms triv_forall_equality}
          THEN' (
            resolve_tac ctxt @{thms hn_refine_cons[rotated]} 
            THEN' (resolve_tac ctxt [thm] THEN_ALL_NEW assume_tac ctxt))
          THEN_ALL_NEW simp_only ctxt 
            @{thms hn_ctxt_def entt_refl pure_unit_rel_eq_empty
              mult_ac mult_1 mult_1_right keep_drop_sels}  
    
        (* Prove theorem *)  
        val result = Thm.cterm_of ctxt result
        val rthm = Goal.prove_internal ctxt [] result (fn _ => ALLGOALS (tac ctxt))
    
        (* Export statement to original context *)
        val rthm = singleton (Variable.export ctxt orig_ctxt) rthm
    
        (* Post-processing *)
        val rthm = Local_Defs.unfold0 ctxt (Named_Theorems.get ctxt @{named_theorems to_hfref_post}) rthm

      in
        rthm
      end
  
      (* Convert a theorem into hfref form in two stages, both under ctxt:
         1. decompose the input theorem with analyze_hnr;
         2. hand the analysis result to mk_hfref_thm, which states the
            hfref-style goal, proves it, and returns the resulting theorem.
         NOTE(review): the input is presumably an hn_refine-style theorem
         (inferred from the name analyze_hnr); confirm against its definition. *)
      fun to_hfref ctxt thm =
        thm
        |> analyze_hnr ctxt
        |> mk_hfref_thm ctxt




      (***********************************)
      (* Composition *)

      local