Session Dict_Construction

Theory Introduction

section ‹Dictionary Construction›

theory Introduction
imports Main
begin

subsection ‹Introduction›

text ‹
 Isabelle's logic features \emph{type classes}~\cite{haftmann2007typeclasses,wenzel1997typeclasses}.
 These are built into the kernel and are used extensively in theory developments.
 The existing \emph{code generator}, when targeting Standard ML, performs the well-known dictionary
 construction or \emph{dictionary translation}~\cite{haftmann2010codegeneration}.
 This works by replacing type classes with records, instances with values, and occurrences with
 explicit parameters.

 Haftmann and Nipkow give a pen-and-paper correctness proof of this construction
 \cite[§4.1]{haftmann2010codegeneration}, based on a notion of \emph{higher-order rewrite
 systems}.
 The resulting theorem then states that any well-typed term is reduction-equivalent before and after
 class elimination.
 In this work, the dictionary construction is performed in a certified fashion, that is, the
 equivalence is a theorem inside the logic.
›

subsection ‹Encoding classes›

text ‹
  The choice of representation of a dictionary itself is straightforward: we model it as a
  @{command datatype}, along with functions returning values of that type. The alternative would
  have been to use the @{command record} package. Its obvious advantage is that subclass
  relationships could be modelled directly through record inheritance. However, records do not
  support multiple inheritance, whereas type classes do. Since records thus offer no advantage over
  datatypes in that regard, we opted for the more modern @{command datatype} package.
›

text ‹Consider the following example:›

class plus =
  fixes plus :: "'a ⇒ 'a ⇒ 'a"

text ‹
  This will get translated to a @{command datatype} with a single constructor taking a single
  argument:
›

datatype 'a dict_plus =
  mk_plus (param_plus: "'a ⇒ 'a ⇒ 'a")

text ‹A function using the @{class plus} constraint, and its translation to an explicit dictionary parameter:›

definition double :: "'a::plus ⇒ 'a" where
"double x = plus x x"

definition double' :: "'a dict_plus ⇒ 'a ⇒ 'a" where
"double' dict x = param_plus dict x x"

subsection ‹Encoding instances›

text ‹
  A more controversial design decision is how to represent dictionary certificates. For example,
  given a value of type @{typ "nat dict_plus"}, how do we know that this is a faithful representation
  of the @{class plus} instance for @{typ nat}?
›

text ‹
  ▪ Florian Haftmann proposed a ``shallow encoding''. It works by exploiting the internal treatment
    of constants with sort constraints in the Isabelle kernel. Constants themselves do not carry
    sort constraints, only their definitional equations. The fact that a constant only appears with
    these constraints on the surface of the system is a feature of type inference.

    Instead, we can instruct the system to ignore these constraints. However, any attempt at
    ``hiding'' the constraints behind a type definition ultimately does not work: the nonemptiness
    proof requires a witness of a valid dictionary for an arbitrary, but fixed type @{typ 'a}, which
    is of course not possible (see §\ref{sec:impossibility} for details).

  ▪ The certificates contain the class axioms directly. For example, the @{class semigroup_add}
    class requires @{term "(a + b) + c = a + (b + c)"}.

    Translated into a definition, this would look as follows:

    @{term
      "cert_plus dict ⟷
        (∀a b c. param_plus dict (param_plus dict a b) c = param_plus dict a (param_plus dict b c))"}

    Proving that instances satisfy this certificate is trivial.

    However, the equality proof of ‹f'› and ‹f› is impossible: they are simply not equal in general.
    Nothing would prevent someone from defining an alternative dictionary using multiplication
    instead of addition; the certificate would still hold, but obviously functions using
    @{const plus} on numbers would expect addition.

    Intuitively, this makes sense: the above notion of ``certificate'' establishes no connection
    between the original instantiation and newly-generated dictionaries.

    Instead of proving equality, one would have to ``lift'' all existing theorems over the old
    constants to the new constants.

  ▪ In order for equality between new and old constants to hold, the certificate needs to capture
    that the dictionary corresponds exactly to the class constants. This is achieved by the
    representation below.
    It literally states that the fields of the dictionary are equal to the class constants.
    The resulting equation is conditional, and its condition can only be satisfied by dictionaries
    corresponding to existing class instances. This constitutes a ‹closed world› assumption, i.e.,
    callers of generated code may not invent their own instantiations.
›
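
text ‹
  For comparison, the rejected axiom-based certificate sketched above could be written as a
  definition like this (the name ‹cert_plus_weak› is ours, purely for illustration; the actual
  construction uses the stronger definition that follows):
›

definition cert_plus_weak :: "'a dict_plus ⇒ bool" where
"cert_plus_weak dict ⟷
  (∀a b c. param_plus dict (param_plus dict a b) c = param_plus dict a (param_plus dict b c))"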

definition cert_plus :: "'a::plus dict_plus ⇒ bool" where
"cert_plus dict ⟷ (param_plus dict = plus)"

text ‹
  Based on that definition, we can prove that @{const double} and @{const double'} are equivalent:
›

lemma "cert_plus dict  double' dict = double"
unfolding cert_plus_def double'_def double_def
by auto

text ‹
  An unconditional equation can be obtained by specializing the theorem to a ground type and
  supplying a valid dictionary.
›
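
text ‹
  A small sketch of the second ingredient (the statement and proofs are ours, purely illustrative,
  and no ground type is fixed here): supplying the canonical dictionary built from the class
  constant already discharges the condition.
›

lemma cert_plus_canonical: "cert_plus (mk_plus plus)"
unfolding cert_plus_def by simp

lemma "double' (mk_plus plus) = double"
by (rule ext) (simp add: double'_def double_def)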

subsection ‹Implementation›

text ‹
  When translating a constant ‹f›, we use existing mechanisms in Isabelle to obtain its
  ‹code graph›. The graph contains the code equations of all transitive dependencies (i.e.,
  other constants) of ‹f›. In general, we have to re-define each of these dependencies. For that,
  we use the internal interface of the @{command function} package and feed it the code equations
  after performing the dictionary construction. In the standard case, where the user has not
  performed a custom code setup, the resulting function looks similar to its original definition.
  But the user may also have changed the implementation of a function significantly afterwards
  (a sketch follows below). This imposes some restrictions:

  ▪ The new constant needs to be proven terminating. We apply some heuristics to transfer the
    original termination proof to the new definition. This only works when the termination
    condition does not rely on class axioms. (See §\ref{sec:termination} for details.)
  ▪ Pattern matching must be performed on datatypes, instead of the more general
    @{command code_datatype}s.
  ▪ The set of code equations must be exhaustive and non-overlapping.
›
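
text ‹
  To illustrate a custom code setup (the constant ‹quadruple› is ours, purely for illustration):
  a user may replace the code equation of a constant after the fact, and the construction then
  starts from the new equation rather than the original definition.
›

definition quadruple :: "'a::plus ⇒ 'a" where
"quadruple x = plus (plus x x) (plus x x)"

lemma quadruple_code [code]: "quadruple x = double (double x)"
unfolding quadruple_def double_def ..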

end

Theory Impossibility

subsection ‹Impossibility of hiding sort constraints›
text_raw ‹\label{sec:impossibility}›

text ‹Coauthor of this section: Florian Haftmann›

theory Impossibility
imports Main
begin

axiomatization of_prop :: "prop ⇒ bool" where
of_prop_Trueprop [simp]: "of_prop (Trueprop P) ⟷ P" and
Trueprop_of_prop [simp]: "Trueprop (of_prop Q) ≡ PROP Q"

text ‹A type satisfies the certificate if there is an instance of the class.›

definition is_sg :: "'a itself ⇒ bool" where
"is_sg TYPE('a) = of_prop OFCLASS('a, semigroup_add_class)"

text ‹We trick the parser into ignoring the sort constraint of @{const plus}.›

setup ‹Sign.add_const_constraint (@{const_name plus}, SOME @{typ "'a::{} ⇒ 'a ⇒ 'a"})›

definition sg :: "('a ⇒ 'a ⇒ 'a) ⇒ bool" where
"sg plus ⟷ plus = Groups.plus ∧ is_sg TYPE('a)" for plus

text ‹Attempt: Define a type that contains all legal @{const plus} functions.›

typedef (overloaded) 'a Sg = "Collect sg :: ('a ⇒ 'a ⇒ 'a) set"
  morphisms the_plus Sg
  unfolding sg_def[abs_def]
  apply (simp add: is_sg_def)

text ‹We need to prove @{term "OFCLASS('a, semigroup_add_class)"} for arbitrary @{typ 'a}, which is
impossible.›

oops

end

Theory Dict_Construction

section ‹Setup›

theory Dict_Construction
imports Automatic_Refinement.Refine_Util
keywords "declassify" :: thy_decl
begin

definition set_of :: "('a ⇒ 'b ⇒ bool) ⇒ ('a × 'b) set" where
"set_of P = {(x, y). P x y}"

lemma wfP_implies_wf_set_of: "wfP P ⟹ wf (set_of P)"
unfolding wfP_def set_of_def .

lemma wf_set_of_implies_wfP: "wf (set_of P) ⟹ wfP P"
unfolding wfP_def set_of_def .

lemma wf_simulate_simple:
  assumes "wf r"
  assumes "x y. (x, y)  r'  (g x, g y)  r"
  shows "wf r'"
using assms
by (metis in_inv_image wf_eq_minimal wf_inv_image)

lemma set_ofI: "P x y ⟹ (x, y) ∈ set_of P"
unfolding set_of_def by simp

lemma set_ofD: "(x, y) ∈ set_of P ⟹ P x y"
unfolding set_of_def by simp

lemma wfP_simulate_simple:
  assumes "wfP r"
  assumes "x y. r' x y  r (g x) (g y)"
  shows "wfP r'"
apply (rule wf_set_of_implies_wfP)
apply (rule wf_simulate_simple[where g = g])
apply (rule wfP_implies_wf_set_of)
apply (fact assms)
using assms(2) by (auto intro: set_ofI dest: set_ofD)

lemma wf_implies_dom: "wf (set_of R) ⟹ All (Wellfounded.accp R)"
apply (rule allI)
apply (rule accp_wfPD)
apply (rule wf_set_of_implies_wfP) .

lemma wfP_implies_dom: "wfP R ⟹ All (Wellfounded.accp R)"
by (metis wfP_implies_wf_set_of wf_implies_dom)
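
text ‹
  Several of the lemmas above (e.g.\ ‹wf_simulate_simple›, ‹set_ofI› and ‹wf_implies_dom›) are not
  used directly in this theory; they are referenced from the termination-transfer tactic in
  ‹transfer_termination.ML› below.
›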

named_theorems dict_construction_specs

ML_file ‹dict_construction_util.ML›
ML_file ‹transfer_termination.ML›
ML_file ‹congruences.ML›
ML_file ‹side_conditions.ML›
ML_file ‹class_graph.ML›
ML_file ‹dict_construction.ML›

method_setup fo_cong_rule = ‹Attrib.thm >> (fn thm => fn ctxt => SIMPLE_METHOD' (Dict_Construction_Util.fo_cong_tac ctxt thm))› ‹resolve congruence rule using first-order matching›

declare [[code drop: "(∧)"]]
lemma [code]: "True  p  p" "False  p  False" by auto

declare [[code drop: "(∨)"]]
lemma [code]: "True  p  True" "False  p  p" by auto

declare comp_cong[fundef_cong del]
declare fun.map_cong[fundef_cong]

end

File ‹dict_construction_util.ML›

infixr 5 ==>
infixr ===>
infix 1 CONTINUE_WITH CONTINUE_WITH_FW

signature DICT_CONSTRUCTION_UTIL = sig
  (* general *)
  val split_list3: ('a * 'b * 'c) list -> 'a list * 'b list * 'c list
  val symreltab_of_symtab: 'a Symtab.table Symtab.table -> 'a Symreltab.table
  val zip_symtabs: ('a -> 'b -> 'c) -> 'a Symtab.table -> 'b Symtab.table -> 'c Symtab.table
  val cat_options: 'a option list -> 'a list
  val partition: ('a -> bool) -> 'a list -> 'a list * 'a list
  val unappend: 'a list * 'b -> 'c list -> 'c list * 'c list
  val flat_right: ('a * 'b list) list -> ('a * 'b) list

  (* logic *)
  val ===> : term list * term -> term
  val ==> : term * term -> term
  val sortify: sort -> term -> term
  val sortify_typ: sort -> typ -> typ
  val typify: term -> term
  val typify_typ: typ -> typ
  val all_frees: term -> (string * typ) list
  val all_frees': term -> string list
  val all_tfrees: typ -> (string * sort) list

  (* printing *)
  val pretty_const: Proof.context -> string -> Pretty.T

  (* conversion/tactic *)
  val ANY: tactic list -> tactic
  val ANY': ('a -> tactic) list -> 'a -> tactic
  val CONTINUE_WITH: (int -> tactic) * (int -> tactic) list -> int -> thm -> thm Seq.seq
  val CONTINUE_WITH_FW: (int -> tactic) * (int -> tactic) list -> int -> thm -> thm Seq.seq
  val SOLVED: tactic -> tactic
  val TRY': ('a -> tactic) -> 'a -> tactic
  val descend_fun_conv: conv -> conv
  val lhs_conv: conv -> conv
  val rhs_conv: conv -> conv
  val rewr_lhs_head_conv: thm -> conv
  val rewr_rhs_head_conv: thm -> conv
  val conv_result: ('a -> thm) -> 'a -> term
  val changed_conv: ('a -> thm) -> 'a -> thm
  val maybe_induct_tac: thm list option -> term list list -> term list list -> Proof.context -> tactic
  val multi_induct_tac: thm list -> term list list -> term list list -> Proof.context -> tactic
  val print_tac': Proof.context -> string -> int -> tactic
  val fo_cong_tac: Proof.context -> thm -> int -> tactic

  (* theorem manipulation *)
  val contract: Proof.context -> thm -> thm
  val on_thms_complete: (unit -> 'a) -> thm list -> thm list

  (* theory *)
  val define_params_nosyn: term -> local_theory -> thm * local_theory
  val note_thm: binding -> thm -> local_theory -> thm * local_theory
  val note_thms: binding -> thm list -> local_theory -> thm list * local_theory

  (* timing *)
  val with_timeout: Time.time -> ('a -> 'a) -> 'a -> 'a

  (* debugging *)
  val debug: bool Config.T
  val if_debug: Proof.context -> (unit -> unit) -> unit
  val ALLGOALS': Proof.context -> (int -> tactic) -> tactic
  val prove': Proof.context -> string list -> term list -> term ->
    ({prems: thm list, context: Proof.context} -> tactic) -> thm
  val prove_common': Proof.context -> string list -> term list -> term list ->
    ({prems: thm list, context: Proof.context} -> tactic) -> thm list
end

structure Dict_Construction_Util : DICT_CONSTRUCTION_UTIL = struct

(* general *)

fun symreltab_of_symtab tab =
  Symtab.map (K Symtab.dest) tab |>
    Symtab.dest |>
    maps (fn (k, kvs) => map (apfst (pair k)) kvs) |>
    Symreltab.make

fun split_list3 [] = ([], [], [])
  | split_list3 ((x, y, z) :: rest) =
      let val (xs, ys, zs) = split_list3 rest in
        (x :: xs, y :: ys, z :: zs)
      end

fun zip_symtabs f t1 t2 =
  let
    open Symtab
    val ord = fast_string_ord
    fun aux acc [] [] = acc
      | aux acc ((k1, x) :: xs) ((k2, y) :: ys) =
        (case ord (k1, k2) of
           EQUAL   => aux (update_new (k1, f x y) acc) xs ys
         | LESS    => raise UNDEF k1
         | GREATER => raise UNDEF k2)
      | aux _ ((k, _) :: _) [] =
          raise UNDEF k
      | aux _ [] ((k, _) :: _) =
          raise UNDEF k
  in aux empty (dest t1) (dest t2) end

fun cat_options [] = []
  | cat_options (SOME x :: xs) = x :: cat_options xs
  | cat_options (NONE :: xs) = cat_options xs

fun partition f xs = (filter f xs, filter_out f xs)

fun unappend (xs, _) = chop (length xs)

fun flat_right [] = []
  | flat_right ((x, ys) :: rest) = map (pair x) ys @ flat_right rest

(* logic *)

fun x ==> y = Logic.mk_implies (x, y)
val op ===> = Library.foldr op ==>

fun sortify_typ sort (Type (tyco, args)) = Type (tyco, map (sortify_typ sort) args)
  | sortify_typ sort (TFree (name, _)) = TFree (name, sort)
  | sortify_typ _ (TVar _) = error "TVar encountered"

fun sortify sort (Const (name, typ)) = Const (name, sortify_typ sort typ)
  | sortify sort (Free (name, typ)) = Free (name, sortify_typ sort typ)
  | sortify sort (t $ u) = sortify sort t $ sortify sort u
  | sortify sort (Abs (name, typ, term)) = Abs (name, sortify_typ sort typ, sortify sort term)
  | sortify _ (Bound n) = Bound n
  | sortify _ (Var _) = error "Var encountered"

val typify_typ = sortify_typ @{sort type}
val typify = sortify @{sort type}

fun all_frees (Free (name, typ)) = [(name, typ)]
  | all_frees (t $ u) = union (op =) (all_frees t) (all_frees u)
  | all_frees (Abs (_, _, t)) = all_frees t
  | all_frees _ = []

val all_frees' = map fst o all_frees

fun all_tfrees (TFree (name, sort)) = [(name, sort)]
  | all_tfrees (Type (_, ts)) = fold (union (op =)) (map all_tfrees ts) []
  | all_tfrees _ = []

(* printing *)

fun pretty_const ctxt const =
  Syntax.pretty_term ctxt (Const (const, Sign.the_const_type (Proof_Context.theory_of ctxt) const))

(* conversion/tactic *)

fun ANY tacs = fold (curry op APPEND) tacs no_tac
fun ANY' tacs n = fold (curry op APPEND) (map (fn t => t n) tacs) no_tac
fun TRY' tac n = TRY (tac n)

fun descend_fun_conv cv =
  cv else_conv (fn ct =>
    case Thm.term_of ct of
      _ $ _ => Conv.fun_conv (descend_fun_conv cv) ct
    | _ => Conv.no_conv ct)

fun lhs_conv cv =
  cv |> Conv.arg1_conv |> Conv.arg_conv

fun rhs_conv cv =
  cv |> Conv.arg_conv |> Conv.arg_conv

fun rewr_lhs_head_conv thm =
  safe_mk_meta_eq thm |> Conv.rewr_conv |> descend_fun_conv |> lhs_conv

fun rewr_rhs_head_conv thm =
  safe_mk_meta_eq thm |> Conv.rewr_conv |> descend_fun_conv |> rhs_conv

fun conv_result cv ct =
  Thm.prop_of (cv ct) |> Logic.dest_equals |> snd

fun changed_conv cv = fn ct =>
  let
    val res = cv ct
    val (lhs, rhs) = Thm.prop_of res |> Logic.dest_equals
  in
    if lhs aconv rhs then
      raise CTERM ("no conversion", [])
    else
      res
  end

fun multi_induct_tac rules insts arbitrary ctxt =
  let
    val insts' = map (map (SOME o pair NONE o rpair false)) insts
    val arbitrary' = map (map dest_Free) arbitrary
  in
    DETERM (Induct.induct_tac ctxt false insts' arbitrary' [] (SOME rules) [] 1)
  end

fun maybe_induct_tac (SOME rules) insts arbitrary = multi_induct_tac rules insts arbitrary
  | maybe_induct_tac NONE _ _ = K all_tac

fun (tac CONTINUE_WITH tacs) i st =
  st |> (tac i THEN (fn st' =>
    let
      val n' = Thm.nprems_of st'
      val n = Thm.nprems_of st
      fun aux [] _ = all_tac
        | aux (tac :: tacs) i = tac i THEN aux tacs (i - 1)
    in
      if n' - n + 1 <> length tacs then
        raise THM ("CONTINUE_WITH: unexpected number of emerging subgoals", 0, [st'])
      else
        aux (rev tacs) (i + n' - n) st'
    end))

fun (tac CONTINUE_WITH_FW tacs) i st =
  st |> (tac i THEN (fn st' =>
    let
      val n' = Thm.nprems_of st'
      val n = Thm.nprems_of st
      fun aux [] _ st = all_tac st
        | aux (tac :: tacs) i st = st |>
            (tac i THEN (fn st' =>
             aux tacs (i + 1 + Thm.nprems_of st' - Thm.nprems_of st) st'))
    in
      if n' - n + 1 <> length tacs then
        raise THM ("unexpected number of emerging subgoals", 0, [st'])
      else
        aux tacs i st'
    end))

fun SOLVED tac = tac THEN ALLGOALS (K no_tac)

fun print_tac' ctxt str = SELECT_GOAL (print_tac ctxt str)

fun fo_cong_tac ctxt thm = SUBGOAL (fn (concl, i) =>
  let
    val lhs_of = HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst
    val concl_pat = lhs_of (Thm.concl_of thm) |> Thm.cterm_of ctxt
    val concl = lhs_of concl |> Thm.cterm_of ctxt

    val insts = Thm.first_order_match (concl_pat, concl)
  in
    resolve_tac ctxt [Drule.instantiate_normalize insts thm] i
  end handle Pattern.MATCH => no_tac)

(* theorem manipulation *)
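
(* contract: turns an extensional (conditional) equation "f x = g x" into "f = g" by dropping
   trailing arguments that occur identically on both sides and do not appear in the premises;
   used where abs_def is not applicable because of premises (cf. dict_construction.ML). *)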

fun contract ctxt thm =
  let
    val (((_, frees), [thm']), ctxt') = Variable.import true [thm] ctxt

    val prop = Thm.prop_of thm'
    val prems = Logic.strip_imp_prems prop
    val (lhs, rhs) =
      Logic.strip_imp_concl prop
      |> HOLogic.dest_Trueprop
      |> HOLogic.dest_eq

    fun used x =
      exists (exists_subterm (fn t => t = x)) prems

    val (f, xs) = strip_comb lhs
    val (g, ys) = strip_comb rhs

    fun loop [] ys = (0, (f, list_comb (g, rev ys)))
      | loop xs [] = (0, (list_comb (f, rev xs), g))
      | loop (x :: xs) (y :: ys) =
          if x = y andalso is_Free x andalso not (used x) then
            loop xs ys |> apfst (fn x => x + 1)
          else
            (0, (list_comb (f, rev (x :: xs)), list_comb (g, rev (y :: ys))))

    val (count, (lhs', rhs')) = loop (rev xs) (rev ys)

    val concl' = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'))

    fun tac ctxt 0 = resolve_tac ctxt [thm] THEN_ALL_NEW (Method.assm_tac ctxt)
      | tac ctxt n = resolve_tac ctxt @{thms ext} THEN' tac ctxt (n - 1)

    val prop = prems ===> concl'
  in
    Goal.prove_future ctxt' [] [] prop (fn {context, ...} => HEADGOAL (tac context count))
    |> singleton (Variable.export ctxt' ctxt)
  end

fun on_thms_complete f thms =
  (Future.fork (fn () => (Thm.consolidate thms; f ())); thms)

(* theory *)

fun define_params_nosyn term =
  Specification.definition NONE [] [] ((Binding.empty, []), term)
  #>> snd #>> snd

fun note_thm binding thm =
  Local_Theory.note ((binding, []), [thm]) #>> snd #>> the_single

fun note_thms binding thms =
  Local_Theory.note ((binding, []), thms) #>> snd

(* timing *)

fun with_timeout time f x =
  case Exn.interruptible_capture (Timeout.apply time f) x of
    Exn.Res y => y
  | Exn.Exn (Timeout.TIMEOUT _) => x
  | Exn.Exn e => Exn.reraise e

(* debugging *)

val debug = Attrib.setup_config_bool @{binding "dict_construction_debug"} (K false)

fun if_debug ctxt f =
  if Config.get ctxt debug then f () else ()

fun ALLGOALS' ctxt = if Config.get ctxt debug then ALLGOALS else PARALLEL_ALLGOALS
fun prove' ctxt = if Config.get ctxt debug then Goal.prove ctxt else Goal.prove_future ctxt
fun prove_common' ctxt = Goal.prove_common ctxt (if Config.get ctxt debug then NONE else SOME ~1)

end

File ‹transfer_termination.ML›

signature TRANSFER_TERMINATION = sig
  val termination_tac: Function.info -> Function.info -> Proof.context -> int -> tactic
end

structure Transfer_Termination : TRANSFER_TERMINATION = struct

open Dict_Construction_Util

fun termination_tac ({R = new_R, ...}: Function.info) (old_info: Function.info) ctxt =
  let
    fun fallback_tac warn _ =
      (if warn then warning "Falling back to another termination proof" else ();
        Seq.empty)

    (*copied from BNF_Comp, in turn copied from Envir.expand_term_free*)
    fun expand_term_const defs =
      let
        val eqs = map ((fn ((x, U), u) => (x, (U, u))) o apfst dest_Const) defs;
        val get = fn Const (x, _) => AList.lookup (op =) eqs x | _ => NONE;
      in Envir.expand_term get end

    fun map_comp_bnf typ =
      let
        (* we start from a fresh lthy to avoid local hyps interfering with BNF *)
        val lthy =
          Proof_Context.theory_of ctxt
          |> Named_Target.theory_init
          |> Config.put BNF_Comp.typedef_threshold ~1
        (* we just pretend that they're all live here *)
        val live_As = all_tfrees typ
        fun flatten_tyargs Ass =
          live_As
          |> filter (fn T => exists (fn Ts => member (op =) Ts T) Ass)
        (* Dont_Inline would create new definitions, always *)
        val ((bnf, _), ((_, {map_unfolds, ...}), _)) =
          BNF_Comp.bnf_of_typ false BNF_Def.Do_Inline I flatten_tyargs live_As [] typ
            ((BNF_Comp.empty_comp_cache, BNF_Comp.empty_unfolds), lthy)
        val subst = map (Logic.dest_equals o Thm.prop_of) map_unfolds
        val t = BNF_Def.map_of_bnf bnf
      in
        (live_As, expand_term_const subst t)
      end

      val tac = case old_info of
        {R = old_R, totality = SOME totality, ...} =>
          let
            fun get_eq R =
              Inductive.the_inductive ctxt R
              |> snd |> #eqs
              |> the_single
              |> Local_Defs.abs_def_rule ctxt
            val (old_R_eq, new_R_eq) = apply2 get_eq (old_R, new_R)

            fun get_typ R =
              fastype_of R
              |> strip_type
              |> fst |> hd
              |> Type.legacy_freeze_type
            val (old_R_typ, new_R_typ) = apply2 get_typ (old_R, new_R)

            (* simple strategy: old_R and new_R are identical *)
            val simple_tac =
              let
                val totality' = Local_Defs.unfold ctxt [old_R_eq] totality
              in
                Local_Defs.unfold_tac ctxt [new_R_eq] THEN
                  HEADGOAL (SOLVED' (resolve_tac ctxt [totality']))
              end

            (* smart strategy: new_R can be simulated by old_R *)
            (* FIXME this is trigger-happy *)
            val smart_tac = Exn.interruptible_capture (fn st =>
              let
                val old_R_stripped =
                  Thm.prop_of old_R_eq
                  |> Logic.dest_equals |> snd
                  |> map_types (K dummyT)
                  |> Syntax.check_term ctxt

                val futile =
                  old_R_stripped |> exists_type (exists_subtype
                    (fn TFree (_, sort) => sort <> @{sort type}
                      | TVar (_, sort) => sort <> @{sort type}
                      | _ => false))

                fun costrip_prodT new_t old_t =
                  if Type.could_match (old_t, new_t) then
                    (0, new_t)
                  else
                    case costrip_prodT (snd (HOLogic.dest_prodT new_t)) old_t of
                      (n, stripped_t) => (n + 1, stripped_t)

                fun construct_inner_proj new_t old_t =
                  let
                    val (diff, stripped_t) = costrip_prodT new_t old_t
                    val (tfrees, f_head) = map_comp_bnf stripped_t
                    val f_args = map (K (Abs ("x", dummyT, Const (@{const_name undefined}, dummyT)))) tfrees
                    fun add_snd 0 = list_comb (map_types (K dummyT) f_head, f_args)
                      | add_snd n = Const (@{const_name comp}, dummyT) $ add_snd (n - 1) $ Const (@{const_name snd}, dummyT)
                  in
                    add_snd diff
                  end

                fun construct_outer_proj new_t old_t = case (new_t, old_t) of
                  (Type (@{type_name sum}, new_ts), Type (@{type_name sum}, old_ts)) =>
                    let
                      val ps = map2 construct_outer_proj new_ts old_ts
                    in list_comb (Const (@{const_name map_sum}, dummyT), ps) end
                | _ => construct_inner_proj new_t old_t

                val outer_proj = construct_outer_proj new_R_typ old_R_typ

                val old_R_typ_imported =
                  yield_singleton Variable.importT_terms old_R ctxt
                  |> fst |> get_typ

                val c =
                  outer_proj
                  |> Type.constraint (new_R_typ --> old_R_typ_imported)
                  |> Syntax.check_term ctxt
                  |> Thm.cterm_of ctxt

                val wf_simulate =
                  Drule.infer_instantiate ctxt [(("g", 0), c)] @{thm wf_simulate_simple}

                val old_wf = (@{thm wfP_implies_wf_set_of} OF [@{thm accp_wfPI} OF [totality]])

                val inner_tac =
                  match_tac ctxt @{thms wf_implies_dom} THEN'
                    match_tac ctxt [wf_simulate] CONTINUE_WITH_FW
                      [resolve_tac ctxt [old_wf],
                       match_tac ctxt @{thms set_ofI} THEN'
                         dmatch_tac ctxt @{thms set_ofD} THEN'
                         SELECT_GOAL (Local_Defs.unfold_tac ctxt [old_R_eq, new_R_eq]) THEN'
                         TRY'
                           (REPEAT_ALL_NEW (ematch_tac ctxt @{thms conjE exE}) THEN'
                             hyp_subst_tac_thin true ctxt THEN'
                             REPEAT_ALL_NEW (match_tac ctxt @{thms conjI exI}))]

                val unfold_tac =
                  Local_Defs.unfold_tac ctxt @{thms comp_apply id_apply prod.sel} THEN
                    auto_tac ctxt

                val tac = SOLVED (HEADGOAL inner_tac THEN unfold_tac)
              in
                if futile then
                  (warning "Termination relation has sort constraints; termination proof is unlikely to be automatic or may even be impossible";
                   Seq.empty)
                else
                  (tracing "Trying to re-use termination proof";
                   tac st)
              end)
            #> Exn.get_res
            #> the_default Seq.empty
          in
            simple_tac ORELSE smart_tac ORELSE fallback_tac true
          end
      | _ => fallback_tac false
  in SELECT_GOAL tac end

end

File ‹congruences.ML›

signature CONGRUENCES = sig
  type rule =
    {rule: thm,
     concl: term,
     prems: term list,
     proper: bool}

  type ctx = (string * typ) list * term list
  datatype ctx_tree =
    Tree of (term * (rule * (ctx * ctx_tree) list) option)

  val export_term_ctx: ctx -> term -> term

  val import_rule: Proof.context -> thm -> rule
  val import_term: Proof.context -> rule list -> term -> ctx_tree

  val fold_tree:
    (term -> 'a) ->
    (term -> rule -> (ctx * 'a) list -> 'a) ->
    ctx_tree -> 'a
end

structure Congruences: CONGRUENCES = struct

type rule =
  {rule: thm,
   concl: term,
   prems: term list,
   proper: bool}

type ctx = (string * typ) list * term list

fun export_term_ctx (fixes, assumes) =
  fold_rev (curry Logic.mk_implies) assumes
  #> fold_rev (Logic.all o Free) fixes

datatype ctx_tree =
  Tree of (term * (rule * (ctx * ctx_tree) list) option)

fun fold_tree atom cong t =
  let
    fun go (Tree (t, NONE)) = atom t
      | go (Tree (t, SOME (r, ctxs))) = cong t r (map (apsnd go) ctxs)
  in go t end

fun raw_import_rule {check: bool, proper: bool} ctxt thm =
  let
    val concl = Thm.concl_of thm
    val (lhs, rhs) = Logic.dest_equals concl
    val prems = Thm.prems_of thm
    val rule = {rule = thm, concl = concl, prems = prems, proper = proper}
  in
    if check then
      let
        val ((f_lhs, _), _) = strip_comb lhs |>> dest_Const
        val ((r_lhs, _), _) = strip_comb rhs |>> dest_Const
      in
        if f_lhs <> r_lhs then
          error ("invalid cong rule " ^ Syntax.string_of_term ctxt (Thm.prop_of thm))
        else
          rule
      end
    else
      rule
  end

val import_rule = raw_import_rule {check = true, proper = true}

fun mk_cong n = if n <= 1 then @{thm cong} else mk_cong (n - 1) OF [@{thm cong}]

fun cong_rule n =
  raw_import_rule {check = false, proper = false} @{context} (mk_cong n RS @{thm eq_reflection})

val ext_rule =
  raw_import_rule {check = false, proper = false} @{context} (@{thm ext} RS @{thm eq_reflection})

fun import_term ctxt rules t =
  let
    val thy = Proof_Context.theory_of ctxt

    val lhs_of = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop

    fun go ctxt t =
      let
        (* FIXME eventually, this should be the arity of fst (strip_comb t) *)
        val arity = length (snd (strip_comb t))

        val rules = rules @ [cong_rule arity, ext_rule]

        fun mk_branch subst t =
          let
            val ((params, impl), ctxt') = Variable.focus NONE t ctxt
            val (assms, concl) = Logic.strip_horn impl
            val assms = map (Envir.subst_term subst) assms
          in
            ((map #2 params, assms), (ctxt', lhs_of concl))
          end

        fun match concl =
          try (Pattern.match thy (concl, Logic.mk_equals (t, t))) (Vartab.empty, Vartab.empty)

        fun apply_rule (rule as {concl, prems, ...}) =
          case match concl of
            NONE => NONE
          | SOME subst =>
              SOME (rule, map (mk_branch subst o Envir.beta_norm o Envir.subst_term subst) prems)

        val is_atom = is_Const t orelse is_Free t
      in
        if is_atom then
          Tree (t, NONE)
        else
          case get_first apply_rule rules of
            NONE => error "this shouldn't happen"
          | SOME (rule, branches) => Tree (t, SOME (rule, map (apsnd (uncurry go)) branches))
      end
  in
    go ctxt t
  end

end

File ‹side_conditions.ML›

signature SIDE_CONDITIONS = sig
  type predicate =
    {f: term,
     index: int,
     inductive: Inductive.result,
     alt: thm option}

  val transform_predicate: morphism -> predicate -> predicate
  val get_predicate: Proof.context -> term -> predicate option
  val set_alt: term -> thm -> Context.generic -> Context.generic
  val is_total: Proof.context -> term -> bool

  val mk_side: thm list -> thm list option -> local_theory -> predicate list * local_theory

  val time_limit: real Config.T
end

structure Side_Conditions : SIDE_CONDITIONS = struct

open Dict_Construction_Util

val time_limit = Attrib.setup_config_real @{binding side_conditions_time_limit} (K 5.0)

val inductive_config =
  {quiet_mode = true, verbose = true, alt_name = Binding.empty, coind = false,
    no_elim = false, no_ind = false, skip_mono = false}

type predicate =
  {f: term,
   index: int,
   inductive: Inductive.result,
   alt: thm option}

fun transform_predicate phi {f, index, inductive, alt} =
  {f = Morphism.term phi f,
   index = index,
   inductive = Inductive.transform_result phi inductive,
   alt = Option.map (Morphism.thm phi) alt}

structure Predicates = Generic_Data
(
  type T = predicate Item_Net.T
  val empty = Item_Net.init (op aconv o apply2 #f) (single o #f)
  val merge = Item_Net.merge
  val extend = I
)

fun get_predicate ctxt t =
  Item_Net.retrieve (Predicates.get (Context.Proof ctxt)) t
  |> try hd
  |> Option.map (transform_predicate (Morphism.transfer_morphism (Proof_Context.theory_of ctxt)))

fun is_total ctxt t =
  let
    val SOME {alt = SOME alt, ...} = get_predicate ctxt t
    val (_, rhs) = Logic.dest_equals (Thm.prop_of alt)
  in rhs = @{term True} end

(* must be of the form [f_side ?x ?y = True] *)
fun set_alt t thm context =
  let
    val thm = safe_mk_meta_eq thm
    val (lhs, _) = Logic.dest_equals (Thm.prop_of thm)
    val {f, index, inductive, ...} = hd (Item_Net.retrieve (Predicates.get context) t)
    val pred = nth (#preds inductive) index
    val (arg_typs, _) = strip_type (fastype_of pred)
    val args =
      Name.invent_names (Variable.names_of (Context.proof_of context)) "x" arg_typs
      |> map Free
    val new_pred = {f = f, index = index, inductive = inductive, alt = SOME thm}
  in
    if Pattern.matches (Context.theory_of context) (lhs, list_comb (pred, args)) then
      Predicates.map (Item_Net.update new_pred) context
    else
      error "Alternative is not fully general"
  end

fun apply_simps ctxt clear thms t =
  let
    val ctxt' =
      Context_Position.not_really ctxt
      |> clear ? put_simpset HOL_ss
  in conv_result (Simplifier.asm_full_rewrite (ctxt' addsimps thms)) t end

fun apply_alts ctxt =
  Item_Net.content (Predicates.get (Context.Proof ctxt))
  |> map #alt
  |> cat_options
  |> apply_simps ctxt true

fun apply_intros ctxt =
  Item_Net.content (Predicates.get (Context.Proof ctxt))
  |> map #inductive
  |> maps #intrs
  |> apply_simps ctxt false

fun dest_head (Free (name, typ)) = (name, typ)
  | dest_head (Const (name, typ)) = (Long_Name.base_name name, typ)

val sideN = "_side"

fun mk_side simps inducts lthy =
  let
    val thy = Proof_Context.theory_of lthy

    val ((_, simps), names) =
      Variable.import true simps lthy
      ||> Variable.names_of

    val (lhss, rhss) =
      map (HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) simps
      |> split_list

    val heads = map (`dest_head o (fst o strip_comb)) lhss

    fun mk_typ t = binder_types t ---> @{typ bool}

    val sides = map (apfst (suffix sideN) o apsnd mk_typ o fst) heads

    fun mk_pred_app pred (f, xs) =
      let
        val pred_typs = binder_types (fastype_of pred)
        val exp_param_count = length pred_typs

        val f_typs = take exp_param_count (binder_types (fastype_of f))

        val pred' =
          Envir.subst_term_types (fold (Sign.typ_match thy) (pred_typs ~~ f_typs) Vartab.empty) pred

        val diff = exp_param_count - length xs
      in
        if diff > 0 then
          let
            val bounds = map Bound (0 upto diff - 1)
            val alls = map (K ("x", dummyT)) (0 upto diff - 1)
            val prop = Logic.list_all (alls, HOLogic.mk_Trueprop (list_comb (pred', xs @ bounds)))
          in
            prop (* fishy *)
          end
        else
          HOLogic.mk_Trueprop (list_comb (pred', take exp_param_count xs))
      end

    fun mk_cond f xs =
      if is_Abs f then (* do not look this up in the Item_Net, it'll only end in tears *)
        NONE
      else
        case get_predicate lthy f of
          NONE =>
            (case find_index (equal f o snd) heads of
              ~1 => NONE (* in this case we don't know anything about f; it may be a constructor *)
            | index => SOME (mk_pred_app (Free (nth sides index)) (f, xs)))
        | SOME {index, inductive = {preds, ...}, ...} =>
            SOME (mk_pred_app (nth preds index) (f, xs))

    fun mk_atom f =
      (* in this branch, if f has a non-const-true predicate, it is most likely that there is a
         missing congruence rule *)
      the_list (mk_cond f [])

    fun mk_cong t _ cs =
      let
        val cs' = maps (fn (ctx, ts) => map (Congruences.export_term_ctx ctx) ts) (tl cs)
        val (f, xs) = strip_comb t
        val cs = mk_cond f xs
      in
        the_list cs @ cs'
      end

    val rules = map (Congruences.import_rule lthy) (Function.get_congs lthy)
    val premss =
      map (Congruences.import_term lthy rules) rhss
      |> map (Congruences.fold_tree mk_atom mk_cong)

    val concls =
      map Free sides ~~ map (snd o strip_comb) lhss
      |> map (HOLogic.mk_Trueprop o list_comb)

    val time = Time.fromReal (Config.get lthy time_limit)

    val intros =
      map Logic.list_implies (premss ~~ concls)
      |> Syntax.check_terms lthy
      |> map (apply_alts lthy o Thm.cterm_of lthy)
      |> Par_List.map (with_timeout time (apply_intros lthy o Thm.cterm_of lthy))

    val inds = map (rpair NoSyn o apfst Binding.name) (distinct op = sides)
    val (result, lthy') =
      Inductive.add_inductive inductive_config inds []
        (map (pair (Binding.empty, [])) intros) [] lthy

    fun mk_impartial_goal pred names =
      let
        val param_typs = binder_types (fastype_of pred)
        val (args, names) = fold_map (fn typ => apfst (Free o rpair typ) o Name.variant "x") param_typs names
        val goal = HOLogic.mk_Trueprop (list_comb (pred, args))
      in ((goal, args), names) end

    val ((props, instss), _) =
      fold_map mk_impartial_goal (#preds result) names
      |>> split_list
    val frees = flat instss |> map (fst o dest_Free)

    fun tactic {context = ctxt, ...} =
      let
        val simp_context =
          put_simpset HOL_ss (Context_Position.not_really ctxt) addsimps (#intrs result)
      in
       maybe_induct_tac inducts instss [] ctxt THEN
          PARALLEL_ALLGOALS (Nitpick_Util.DETERM_TIMEOUT time o asm_full_simp_tac simp_context)
      end

    val alts =
      try (Goal.prove_common lthy' NONE frees [] props) tactic
      |> Option.map (map (mk_eq o Thm.close_derivation \<^here>))

    val _ =
      if is_none alts then
        Pretty.str "Potentially underspecified function(s): " ::
          Pretty.commas (map (Syntax.pretty_term lthy o snd) (distinct op = heads))
        |> Pretty.block
        |> Pretty.string_of
        |> warning
      else
        ()

    fun mk_pred n t =
      {f = t, index = n, inductive = result,
       alt = Option.map (fn alts => nth alts n) alts}

    val preds = map_index (fn (n, (_, t)) => mk_pred n t) (distinct op = heads)

    val lthy'' =
      Local_Theory.declaration {pervasive = false, syntax = false}
        (fn phi => fold (Predicates.map o Item_Net.update o transform_predicate phi) preds) lthy'
  in
    (preds, lthy'')
  end

end

File ‹class_graph.ML›

signature CLASS_GRAPH =
sig
  type selector = typ -> term

  type node =
    {class: string,
     qname: string,
     selectors: selector Symtab.table,
     make: typ -> term,
     data_thms: thm list,
     cert: typ -> term,
     cert_thms: thm * thm * thm list}

  val dict_typ: node -> typ -> typ

  type edge =
    {super_selector: selector,
     subclass: thm}

  type path = edge list

  type ev

  val class_of: ev -> class
  val node_of: ev -> node
  val parents_of: ev -> (edge * ev) Symtab.table

  val find_path': ev -> (ev -> 'a option) -> (path * 'a) option
  val find_path: ev -> class -> path option
  val fold_path: path -> typ -> term -> term

  val ensure_class: class -> local_theory -> (ev * local_theory)

  val edges: local_theory -> class -> edge Symtab.table option
  val node: local_theory -> class -> node option
  val all_edges: local_theory -> edge Symreltab.table
  val all_nodes: local_theory -> node Symtab.table

  val pretty_ev: Proof.context -> ev -> Pretty.T

  (* utilities *)
  val mangle: string -> string
  val param_sorts: string -> class -> theory -> class list list
  val super_classes: class -> theory -> string list
end

structure Class_Graph: CLASS_GRAPH =
struct

open Dict_Construction_Util
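
(* e.g. mangle "Foo.bar_baz" = "Foo_bar__baz" *)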

val mangle =
  translate_string (fn x =>
    if x = "." then
      "_"
    else if x = "_" then
      "__"
    else
      x)

fun param_sorts tyco class thy =
  let val algebra = Sign.classes_of thy in
    Sorts.mg_domain algebra tyco [class] |> map (filter (Class.is_class thy))
  end

fun super_classes class thy =
  let val algebra = Sign.classes_of thy in
    Sorts.super_classes algebra class |>
      Sorts.minimize_sort algebra |>
      filter (Class.is_class thy) |>
      sort fast_string_ord
  end

type selector = typ -> term

type node =
  {class: string,
   qname: string,
   selectors: selector Symtab.table,
   make: typ -> term,
   data_thms: thm list,
   cert: typ -> term,
   cert_thms: thm * thm * thm list}

type edge =
  {super_selector: selector,
   subclass: thm}

type path = edge list

abstype ev = Evidence of class * node * (edge * ev) Symtab.table
with

fun class_of (Evidence (class, _, _)) = class
fun node_of (Evidence (_, node, _)) = node
fun parents_of (Evidence (_, _, tab)) = tab

fun mk_evidence class node tab = Evidence (class, node, tab)

fun find_path' ev is_goal =
  case is_goal ev of
    SOME a =>
      SOME ([], a)
  | NONE =>
    let
      fun f (_, (edge, ev)) = Option.map (apfst (cons edge)) (find_path' ev is_goal)
    in Symtab.get_first f (parents_of ev) end

fun find_path ev goal =
  find_path' ev (fn ev => if class_of ev = goal then SOME () else NONE) |> Option.map fst

fun pretty_ev ctxt (Evidence (class, {qname, ...}, tab)) =
  let
    val typ = @{typ 'a}
    fun mk_super ({super_selector, ...}, super_ev) = Pretty.block
      [Pretty.str "selector:",
       Pretty.brk 1,
       Syntax.pretty_term ctxt (super_selector typ),
       Pretty.fbrk,
       pretty_ev ctxt super_ev]
    val supers = Symtab.dest tab
      |> map (fn (_, super) => mk_super super)
      |> Pretty.big_list "super classes"
  in
    Pretty.block
      [Pretty.str "Evidence for ",
       Syntax.pretty_sort ctxt [class],
       Pretty.str ": ",
       Syntax.pretty_typ ctxt (Type (qname, [typ])),
       Pretty.str (" (qname = " ^ qname ^ ")"),
       Pretty.fbrk,
       supers]
  end

end

structure Classes = Generic_Data
(
  type T = (edge Symtab.table * node) Symtab.table
  val empty = Symtab.empty
  fun merge (t1, t2) =
    if Symtab.is_empty t1 andalso Symtab.is_empty t2 then
      Symtab.empty
    else
      error "merging not supported"
  val extend = I
)

fun node lthy class =
  Symtab.lookup (Classes.get (Context.Proof lthy)) class |> Option.map snd

fun edges lthy class =
  Symtab.lookup (Classes.get (Context.Proof lthy)) class |> Option.map fst

val all_nodes =
  Context.Proof #> Classes.get #> Symtab.map (K snd)

val all_edges =
  Context.Proof #> Classes.get #> Symtab.map (K fst) #> symreltab_of_symtab

fun dict_typ {qname, ...} typ =
  Type (qname, [typ])

fun fold_path path typ =
  fold (fn {super_selector = s, ...} => fn acc => s typ $ acc) path

fun mk_super_selector' qualified qname super_ev typ =
  let
    val {class = super_class, qname = super_qname, ...} = node_of super_ev
    val raw_name = mangle super_class ^ "__super"
    val name = if qualified then Long_Name.append qname raw_name else raw_name
  in (name, Type (qname, [typ]) --> Type (super_qname, [typ])) end

fun mk_node class info super_evs lthy =
  let
    fun print_info ctxt =
      Pretty.block [Pretty.str "Defining record for class ", Syntax.pretty_sort ctxt [class]]
      |> Pretty.writeln

    val name = mangle class ^ "__dict"
    val qname = Local_Theory.full_name lthy (Binding.name name)
    val tvar = @{typ 'a}
    val typ = Type (qname, [tvar])

    fun mk_field name ftyp = (Binding.name name, ftyp)

    val params = #params info
      |> map (fn (name', ftyp) =>
        let
          val ftyp' = typ_subst_atomic [(TFree ("'a", [class]), @{typ 'a})] ftyp
          val field_name = mangle name' ^ "__field"
          val field = mk_field field_name ftyp'
          fun sel tvar' =
            Const (Long_Name.append qname field_name,
                   typ_subst_atomic [(tvar, tvar')] (typ --> ftyp'))
        in (field, (name', sel)) end)
    val (fields, selectors) = split_list params

    val super_params = Symtab.dest super_evs |>
      map (fn (_, super_ev) =>
        let
          val {cert = raw_super_cert, qname = super_qname, ...} = node_of super_ev
          val (field_name, _) = mk_super_selector' false qname super_ev tvar
          val field = mk_field field_name (Type (super_qname, [tvar]))
          fun sel typ = Const (mk_super_selector' true qname super_ev typ)
          fun super_cert dict = raw_super_cert tvar $ (sel tvar $ dict)
          val raw_edge = (class_of super_ev, sel)
        in (field, raw_edge, super_cert) end)
    val (super_fields, raw_edges, super_certs) = split_list3 super_params

    val all_fields = super_fields @ fields

    fun make typ' =
      Const (Long_Name.append qname "Dict",
        typ_subst_atomic [(tvar, typ')] (map #2 all_fields ---> typ))

    val cert_name = name ^ "__cert"
    val cert_binding = Binding.name cert_name
    val cert_body =
      let
        fun local_param_eq ((_, typ), (name, sel)) dict =
          HOLogic.mk_eq (sel tvar $ dict, Const (name, typ))
      in
        map local_param_eq params @ super_certs
      end
    val cert_var_name = "dict"
    val cert_term =
      Abs (cert_var_name, typ,
        List.foldr HOLogic.mk_conj @{term True} (map (fn x => x (Bound 0)) cert_body))

    fun prove_thms (cert, cert_def) lthy =
      let
        val var = Free (cert_var_name, typ)
        fun tac ctxt = Local_Defs.unfold_tac ctxt [cert_def] THEN blast_tac ctxt 1
        fun prove prop =
          Goal.prove_future lthy [cert_var_name] [] prop (fn {context, ...} => tac context)

        fun mk_dest_props raw_prop =
          HOLogic.mk_Trueprop (cert $ var) ==> HOLogic.mk_Trueprop (raw_prop var)
        fun mk_intro_cond raw_prop =
          HOLogic.mk_Trueprop (raw_prop var)

        val dests =
          map (fn raw_prop => prove (mk_dest_props raw_prop)) cert_body
        val intro =
          prove (map mk_intro_cond cert_body ===> HOLogic.mk_Trueprop (cert $ var))

        val (dests', (intro', lthy')) =
          note_thms Binding.empty dests lthy ||> note_thm Binding.empty intro

        val (param_dests, super_dests) = chop (length params) dests'

        fun pre_edges phi =
          let
            fun mk_edge thm (sc, sel) =
              (sc, {super_selector = sel, subclass = Morphism.thm phi thm})
          in Symtab.make (map2 mk_edge super_dests raw_edges) end
      in
        ((param_dests, pre_edges, intro'), lthy')
      end

    val constructor =
      (((Binding.empty, Binding.name "Dict"), all_fields), NoSyn)
    val datatyp =
      (([(NONE, (@{typ 'a}, @{sort type}))], Binding.name name), NoSyn)

    val dtspec =
      (Ctr_Sugar.default_ctr_options,
       [(((datatyp, [constructor]), (Binding.empty, Binding.empty, Binding.empty)), [])])

    val (((raw_cert, raw_cert_def), (param_dests, pre_edges, intro)), (lthy', lthy)) = lthy
      |> tap print_info
      |> BNF_FP_Def_Sugar.co_datatypes BNF_Util.Least_FP BNF_LFP.construct_lfp dtspec
      (* FIXME ideally BNF would return a fp_sugar value right here so that I can avoid constructing
         long names by hand above *)
      |> (snd o Local_Theory.begin_nested)
      |> Local_Theory.define ((cert_binding, NoSyn), ((Thm.def_binding cert_binding, []), cert_term))
      |>> apsnd snd
      |> (fn (raw_cert, lthy) => prove_thms raw_cert lthy |>> pair raw_cert)
      ||> `Local_Theory.end_nested

    val phi = Proof_Context.export_morphism lthy lthy'
    fun cert typ = subst_TVars [(("'a", 0), typ)] (Morphism.term phi raw_cert)
    val cert_def = Morphism.thm phi raw_cert_def
    val edges = pre_edges phi
    val param_dests' = map (Morphism.thm phi) param_dests
    val intro' = Morphism.thm phi intro

    val data_thms =
      BNF_FP_Def_Sugar.fp_sugar_of lthy' qname
      |> the |> #fp_ctr_sugar |> #ctr_sugar |> #sel_thmss |> flat
      |> map safe_mk_meta_eq

    val node =
      {class = class,
       qname = qname,
       selectors = Symtab.make selectors,
       make = make,
       data_thms = data_thms,
       cert = cert,
       cert_thms = (cert_def, intro', param_dests')}
  in (node, edges, lthy') end

fun ensure_class class lthy =
  if not (Class.is_class (Proof_Context.theory_of lthy) class) then
    error ("not a proper class: " ^ class)
  else
    let
      val thy = Proof_Context.theory_of lthy
      val super_classes = super_classes class thy
      fun collect_super mk_node =
        let
          val (super_evs, lthy') = fold_map ensure_class super_classes lthy
          val raw_tab = Symtab.make (super_classes ~~ super_evs)
          val (node, edges, lthy'') = mk_node raw_tab lthy'
          val tab = zip_symtabs pair edges raw_tab
          val ev = mk_evidence class node tab
        in (ev, edges, lthy'') end
    in
      case Symtab.lookup (Classes.get (Context.Proof lthy)) class of
        SOME (edge_tab, node) =>
          if super_classes = Symtab.keys edge_tab then
            let val (ev, _, lthy') = collect_super (fn _ => fn lthy => (node, edge_tab, lthy)) in
              (ev, lthy')
            end
          else
            (* This happens when a new subclass relationship is established which subsumes or
               augments previous superclasses. *)
            error "class with different super classes"
      | NONE =>
          let
            val ax_info = Axclass.get_info thy class
            val (ev, edges, lthy') = collect_super (mk_node class ax_info)
            val upd = Symtab.update_new (class, (edges, node_of ev))
          in
            (ev, Local_Theory.declaration {pervasive = false, syntax = false} (K (Classes.map upd)) lthy')
          end
    end

end

File ‹dict_construction.ML›

signature DICT_CONSTRUCTION =
sig
  datatype cert_proof = Cert | Skip

  type const

  type 'a sccs = (string * 'a) list list

  val annotate_code_eqs: local_theory -> string list -> (const sccs * local_theory)
  val new_names: local_theory -> const sccs -> (string * const) sccs
  val symtab_of_sccs: 'a sccs -> 'a Symtab.table

  val axclass: class -> local_theory -> Class_Graph.node * local_theory
  val instance: (string * const) Symtab.table -> string -> class -> local_theory -> term * local_theory
  val term: term Symreltab.table -> (string * const) Symtab.table -> term -> local_theory -> (term * local_theory)
  val consts: (string * const) Symtab.table -> cert_proof -> (string * const) list -> local_theory -> local_theory

  (* certification *)

  type const_info =
    {fun_info: Function.info option,
     inducts: thm list option,
     base_thms: thm list,
     base_certs: thm list,
     simps: thm list,
     code_thms: thm list, (* old defining theorems *)
     congs: thm list option}

  type fun_target = (string * class) list * (term * term)

  type dict_thms =
    {base_thms: thm list,
     def_thm: thm}

  type dict_target = (string * class) list * (term * string * class)

  val prove_fun_cert: fun_target list -> const_info -> cert_proof -> local_theory -> thm list
  val prove_dict_cert: dict_target -> dict_thms -> local_theory -> thm

  val the_info: Proof.context -> string -> const_info

  (* utilities *)

  val normalizer_conv: Proof.context -> conv
  val cong_of_const: Proof.context -> string -> thm option
  val get_code_eqs: Proof.context -> string -> thm list
  val group_code_eqs: Proof.context -> string list ->
    (string * (((string * sort) list * typ) * ((term list * term) * thm option) list)) list list
end

structure Dict_Construction: DICT_CONSTRUCTION =
struct

open Class_Graph
open Dict_Construction_Util

(* FIXME copied from skip_proof.ML *)

val (_, make_thm_cterm) =
  Context.>>>
    (Context.map_theory_result (Thm.add_oracle (@{binding cert_oracle}, I)))

fun make_thm ctxt prop = make_thm_cterm (Thm.cterm_of ctxt prop)

fun cheat_tac ctxt i st =
  resolve_tac ctxt [make_thm ctxt (Var (("A", 0), propT))] i st

(** utilities **)

val normalizer_conv = Axclass.overload_conv

fun cong_of_const ctxt name =
  let
    val head =
      Thm.concl_of
      #> Logic.dest_equals #> fst
      #> strip_comb #> fst
      #> dest_Const #> fst
    fun applicable thm =
      try head thm = SOME name
  in
    Function_Context_Tree.get_function_congs ctxt
    |> filter applicable
    |> try hd
  end

fun group_code_eqs ctxt consts =
  let
    val thy = Proof_Context.theory_of ctxt
    val graph = #eqngr (Code_Preproc.obtain true { ctxt = ctxt, consts = consts, terms = [] })

    fun mk_eqs name = name
      |> Code_Preproc.cert graph
      |> Code.equations_of_cert thy ||> these
      ||> map (apsnd fst o apfst (apsnd fst o apfst (map fst)))
      |> pair name
  in
    map (map mk_eqs) (rev (Graph.strong_conn graph))
  end

fun get_code_eqs ctxt const =
  AList.lookup op = (flat (group_code_eqs ctxt [const])) const
  |> the |> snd
  |> map snd
  |> cat_options
  |> map (Conv.fconv_rule (normalizer_conv ctxt))

(** certification **)

datatype cert_proof = Cert | Skip

type const_info =
  {fun_info: Function.info option,
   inducts: thm list option,
   base_thms: thm list,
   base_certs: thm list,
   simps: thm list,
   code_thms: thm list,
   congs: thm list option}

fun map_const_info f1 f2 f3 f4 f5 f6 f7 {fun_info, inducts, base_thms, base_certs, simps, code_thms, congs} =
  {fun_info = f1 fun_info,
   inducts = f2 inducts,
   base_thms = f3 base_thms,
   base_certs = f4 base_certs,
   simps = f5 simps,
   code_thms = f6 code_thms,
   congs = f7 congs}

fun morph_const_info phi =
  map_const_info
    (Option.map (Function_Common.transform_function_data phi))
    (Option.map (map (Morphism.thm phi)))
    (map (Morphism.thm phi))
    (map (Morphism.thm phi))
    (map (Morphism.thm phi))
    I (* sic *)
    (Option.map (map (Morphism.thm phi)))

type fun_target = (string * class) list * (term * term)

type dict_thms =
  {base_thms: thm list,
   def_thm: thm}

type dict_target = (string * class) list * (term * string * class)

fun fun_cert_tac base_thms base_certs simps code_thms =
  SOLVED' o Subgoal.FOCUS (fn {prems, context = ctxt, concl, ...} =>
    let
      val _ =
        if_debug ctxt (fn () =>
          tracing ("Proving " ^ Syntax.string_of_term ctxt (Thm.term_of concl)))

      fun is_ih prem =
        Thm.prop_of prem |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> can HOLogic.dest_eq

      val (ihs, certs) = partition is_ih prems
      val super_certs = all_edges ctxt |> Symreltab.dest |> map (#subclass o snd)
      val param_dests = all_nodes ctxt |> Symtab.dest |> maps (#3 o #cert_thms o snd)
      val congs = Function_Context_Tree.get_function_congs ctxt @ map safe_mk_meta_eq @{thms cong}

      val simp_context = (clear_simpset ctxt) addsimps (certs @ super_certs @ base_certs @ base_thms @ param_dests)
        addloop ("overload", CONVERSION o changed_conv o Axclass.overload_conv)

      val ihs = map (Simplifier.asm_full_simplify simp_context) ihs

      val ih_tac =
        resolve_tac ctxt ihs THEN_ALL_NEW
          (TRY' (SOLVED' (Simplifier.asm_full_simp_tac simp_context)))

      val unfold_new =
        ANY' (map (CONVERSION o rewr_lhs_head_conv) simps)

      val normalize =
        CONVERSION (normalizer_conv ctxt)

      val unfold_old =
        ANY' (map (CONVERSION o rewr_rhs_head_conv) code_thms)

      val simp =
        CONVERSION (lhs_conv (Simplifier.asm_full_rewrite simp_context))

      fun walk_congs i = i |>
        ((resolve_tac ctxt @{thms refl} ORELSE'
          SOLVED' (Simplifier.asm_full_simp_tac simp_context) ORELSE'
          ih_tac ORELSE'
          Method.assm_tac ctxt ORELSE'
          (resolve_tac ctxt @{thms meta_eq_to_obj_eq} THEN'
            fo_resolve_tac congs ctxt)) THEN_ALL_NEW walk_congs)

      val tacs =
        [unfold_new, normalize, unfold_old, simp, walk_congs]
    in
      EVERY' tacs 1
    end)

fun dict_cert_tac class def_thm base_thms =
  SOLVED' o Subgoal.FOCUS (fn {prems, context = ctxt, ...} =>
    let
      val (intro, sels) = case node ctxt class of
        SOME {cert_thms = (_, intro, _), data_thms = sels, ...} => (intro, sels)
      | NONE => error ("class " ^ class ^ " is not defined")

      val apply_intro =
        resolve_tac ctxt [intro]

      val unfold_dict =
        CONVERSION (Conv.rewr_conv def_thm |> Conv.arg_conv |> lhs_conv)

      val normalize =
        CONVERSION (normalizer_conv ctxt)

      val smash_sels =
        CONVERSION (lhs_conv (Conv.rewrs_conv sels))

      val solve =
        resolve_tac ctxt (@{thm HOL.refl} :: base_thms)

      val finally =
        resolve_tac ctxt prems

      val tacs =
        [apply_intro, unfold_dict, normalize, smash_sels, solve, finally]
    in
      EVERY (map (ALLGOALS' ctxt) tacs)
    end)

fun prepare_dicts classes names lthy =
  let
    val sorts = Symtab.make_list classes

    fun mk_dicts (param_name, (tvar, class)) =
      case node lthy class of
        NONE =>
          error ("unknown class " ^ class)
      | SOME {cert, qname, ...} =>
          let
            val sort = the (Symtab.lookup sorts tvar)
            val param = Free (param_name, Type (qname, [TFree (tvar, sort)]))
          in
            (param, HOLogic.mk_Trueprop (cert dummyT $ param))
          end

    val dict_names = Name.invent_names names "a" classes
    val names = fold Name.declare (map fst dict_names) names
    val (dict_params, prems) = split_list (map mk_dicts dict_names)
  in (dict_params, prems, names) end

fun prepare_fun_goal targets lthy =
  let
    fun mk_eq (classes, (lhs, rhs)) names =
      let
        val (lhs_name, _) = dest_Const lhs
        val (rhs_name, rhs_typ) = dest_Const rhs

        val (dict_params, prems, names) = prepare_dicts classes names lthy

        val param_names = fst (strip_type rhs_typ) |> map (K dummyT) |> Name.invent_names names "a"
        val names = fold Name.declare (map fst param_names) names
        val params = map Free param_names

        val lhs = list_comb (Const (lhs_name, dummyT), dict_params @ params)
        val rhs = list_comb (Const (rhs_name, dummyT), params)

        val eq = Const (@{const_name HOL.eq}, dummyT) $ lhs $ rhs

        val all_params = dict_params @ params
        val eq :: rest = Syntax.check_terms lthy (eq :: prems @ all_params)
        val (prems, all_params) = unappend (prems, all_params) rest

        val eq =
          if is_some (Axclass.inst_of_param (Proof_Context.theory_of lthy) rhs_name) then
            Thm.cterm_of lthy eq |> conv_result (Conv.arg_conv (normalizer_conv lthy))
          else
            eq

        val prop = prems ===> HOLogic.mk_Trueprop eq
      in ((all_params, prop), names) end
  in
    fold_map mk_eq targets Name.context
    |> fst
    |> split_list
  end

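(* states the goal that an instance constant, applied to dictionary parameters for its type
   arguments, satisfies the class certificate *)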
fun prepare_dict_goal (classes, (term, _, class)) lthy =
  let
    val cert = case node lthy class of
      NONE =>
        error ("unknown class " ^ class)
    | SOME {cert, ...} =>
        cert dummyT

    val names = Name.context
    val (dict_params, prems, _) = prepare_dicts classes names lthy
    val (term_name, _) = dest_Const term
    val dict = list_comb (Const (term_name, dummyT), dict_params)

    val prop = prems ===> HOLogic.mk_Trueprop (cert $ dict)
    val prop :: dict_params = Syntax.check_terms lthy (prop :: dict_params)
  in
    (dict_params, prop)
  end

fun prove_fun_cert targets {inducts, base_thms, base_certs, simps, code_thms, ...} proof lthy =
  let
    (* the props contain dictionary certs as prems
       we can't exclude them from the induction because the dicts are part of the function
       definition
       excluding them would mean that applying the induction rules becomes tricky or impossible
       proper fix would be if fun, akin to inductive, supported a "for" clause that marks parameters
       as "not changing" *)
    val (argss, props) = prepare_fun_goal targets lthy
    val frees = flat argss |> map (fst o dest_Free)

    (* we first prove the extensional variant (easier to prove), and then derive the
       contracted variant
       abs_def can't deal with premises, so we use our own version here *)
    val tac =
      case proof of
        Cert => fun_cert_tac base_thms base_certs simps code_thms
      | Skip => cheat_tac

    val long_thms =
       prove_common' lthy frees [] props (fn {context, ...} =>
          maybe_induct_tac inducts argss [] context THEN
            ALLGOALS' context (tac context))
  in
    map (contract lthy) long_thms
  end

fun prove_dict_cert target {base_thms, def_thm} lthy =
  let
    val (args, prop) = prepare_dict_goal target lthy
    val frees = map (fst o dest_Free) args
    val (_, (_, _, class)) = target
  in
    prove' lthy frees [] prop (fn {context, ...} =>
      dict_cert_tac class def_thm base_thms context 1)
  end

(** background data **)

type definitions =
  {instantiations: (term * thm) Symreltab.table, (* key: (class, tyco) *)
   constants: (string * (thm option * const_info)) Symtab.table (* key: constant name *) }

structure Definitions = Generic_Data
(
  type T = definitions
  val empty = {instantiations = Symreltab.empty, constants = Symtab.empty}
  val extend = I
  fun merge ({instantiations = i1, constants = c1}, {instantiations = i2, constants = c2}) =
    if Symreltab.is_empty i1 andalso Symtab.is_empty c1 andalso
       Symreltab.is_empty i2 andalso Symtab.is_empty c2 then
      empty
    else
      error "merging not supported"
)

fun map_definitions map_insts map_consts =
  Definitions.map (fn {instantiations, constants} =>
    {instantiations = map_insts instantiations,
     constants = map_consts constants})

fun the_info ctxt name =
  Symtab.lookup (#constants (Definitions.get (Context.Proof ctxt))) name
  |> the
  |> snd
  |> snd

fun add_instantiation (class, tyco) term cert =
  let
    fun upd phi =
      map_definitions
        (fn tab =>
          if Symreltab.defined tab (class, tyco) then
            error ("Duplicate instantiation " ^ quote tyco ^ " :: " ^ quote class)
          else
            tab
            |> Symreltab.update ((class, tyco), (Morphism.term phi term, Morphism.thm phi cert))) I
  in
    Local_Theory.declaration {pervasive = false, syntax = false} upd
  end

fun add_constant name name' (cert, info) lthy =
  let
    val qname = Local_Theory.full_name lthy (Binding.name name')
    fun upd phi =
      map_definitions I
        (fn tab =>
          if Symtab.defined tab name then
            error ("Duplicate constant " ^ quote name)
          else
            tab
            |> Symtab.update (name,
                (qname, (Option.map (Morphism.thm phi) cert, morph_const_info phi info))))
  in
    Local_Theory.declaration {pervasive = false, syntax = false} upd lthy
    |> Local_Theory.note ((Binding.empty, @{attributes [dict_construction_specs]}), #simps info)
    |> snd
  end

(** classes **)

fun axclass class =
  ensure_class class
  #>> node_of

(** grouping and annotating constants **)

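(* classification of constants: 'Fun' for ordinary functions with code equations, 'Constructor'
   for datatype constructors (kept as-is), and 'Classparam' for class parameters (replaced by a
   selector applied to a dictionary) *)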
datatype const =
  Fun of
    {dicts: ((string * class) * typ) list,
     certs: term list,
     param_typs: typ list,
     typ: typ, (* typified *)
     new_typ: typ,
     eqs: {params: term list, rhs: term, thm: thm} list,
     info: Function_Common.info option,
     cong: thm option} |
  Constructor |
  Classparam of
    {class: class,
     typ: typ, (* varified *)
     selector: term (* varified *)}

type 'a sccs = (string * 'a) list list

fun symtab_of_sccs x = Symtab.make (flat x)

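(* for every type parameter, and every element of its sort that is a proper class, produce the
   corresponding dictionary type and certificate *)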
fun raw_dict_params tparams lthy =
  let
    fun mk_dict tparam class lthy =
      let
        val (node, lthy') = axclass class lthy
        val targ = TFree (tparam, @{sort type})
        val typ = dict_typ node targ
        val cert = #cert node targ
      in ((((tparam, class), typ), cert), lthy') end
    fun mk_dicts (tparam, sort) = fold_map
      (mk_dict tparam)
      (filter (Class.is_class (Proof_Context.theory_of lthy)) sort)
   in fold_map mk_dicts tparams lthy |>> flat end

fun dict_params context dicts =
  let
    fun dict_param ((_, class), typ) =
      Name.variant (mangle class) #>> rpair typ #>> Free
  in
    fold_map dict_param dicts context
  end

fun get_sel class param typ lthy =
  let
    val ({selectors, ...}, lthy') = axclass class lthy
  in
    case Symtab.lookup selectors param of
      NONE => error ("unknown class parameter " ^ param)
    | SOME sel => (sel typ, lthy')
  end

fun annotate_const name ((tparams, typ), raw_eqs) lthy =
  if Code.is_constr (Proof_Context.theory_of lthy) name then
    ((name, Constructor), lthy)
  else if null raw_eqs then
    (* this detection is reliable, because code equations with overloaded heads are not allowed *)
    let
      val (_, class) = the_single tparams ||> the_single
      val (selector, thy') = get_sel class name (TVar (("'a", 0), @{sort type})) lthy
      val typ = range_type (fastype_of selector)
    in
      ((name, Classparam {class = class, typ = typ, selector = selector}), thy')
    end
  else
    let
      val info = try (Function.get_info lthy) (Const (name, typ))
      val cong = cong_of_const lthy name
      val ((raw_dicts, certs), lthy') = raw_dict_params tparams lthy |>> split_list
      val dict_typs = map snd raw_dicts
      val typ' = typify_typ typ
      fun mk_eq ((raw_params, rhs), SOME thm) =
            let
              val norm = normalizer_conv lthy'
              val transform = Thm.cterm_of lthy' #> conv_result norm #> typify
              val params = map transform raw_params
            in
              if has_duplicates (op =) (flat (map all_frees' params)) then
                (warning "ignoring code equation with non-linear pattern"; NONE)
              else
                SOME {params = params, rhs = rhs, thm = Conv.fconv_rule norm thm}
            end
        | mk_eq _ =
            error "no theorem"
      val const =
        Fun
          {dicts = raw_dicts, certs = certs, typ = typ', param_typs = binder_types typ',
           new_typ = dict_typs ---> typ', eqs = map_filter mk_eq raw_eqs, info = info, cong = cong}
    in ((name, const), lthy') end

fun annotate_code_eqs lthy consts =
  fold_map (fold_map (uncurry annotate_const)) (group_code_eqs lthy consts) lthy

(** instances and terms **)

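(* given the dictionaries available for a type variable, finds a path through the class graph
   from one of them to the goal class and folds the superclass selectors along it *)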
fun mk_path [] _ _ lthy = (NONE, lthy)
  | mk_path ((class, term) :: rest) typ goal lthy =
    let
      val (ev, lthy') = ensure_class class lthy
    in
      case find_path ev goal of
        SOME path => (SOME (fold_path path typ term), lthy')
      | NONE =>      mk_path rest typ goal lthy'
    end

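(* mutually recursive translation: 'instance' looks up or defines the dictionary constant for an
   instance (class, tyco) and proves its certificate; 'obtain_dict' constructs a dictionary for a
   given type and class, using instance constants for type constructors and dictionary parameters
   (via superclass paths) for type variables; 'term' replaces class-constrained constants in a
   term by their dictionary-passing counterparts *)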
fun instance consts tyco class lthy =
  case Symreltab.lookup (#instantiations (Definitions.get (Context.Proof lthy))) (class, tyco) of
    SOME (inst, _) =>
      (inst, lthy)
  | NONE =>
      let
        val thy = Proof_Context.theory_of lthy
        val tparam_sorts = param_sorts tyco class thy

        fun print_info ctxt =
          let
            val tvars =
              Name.invent_list [] Name.aT (length tparam_sorts) ~~ tparam_sorts
              |> map TFree
          in
            [Pretty.str "Defining instance ", Syntax.pretty_typ ctxt (Type (tyco, tvars)),
              Pretty.str " :: ", Syntax.pretty_sort ctxt [class]]
            |> Pretty.block
            |> Pretty.writeln
          end

        val ({make, ...}, lthy) = axclass class lthy

        val name = mangle class ^ "__instance__" ^ mangle tyco
        val tparams = Name.invent_names Name.context Name.aT tparam_sorts
        val ((dict_params, _), lthy) = raw_dict_params tparams lthy
          |>> map fst
          |>> dict_params (Name.make_context [name])
        val dict_context = Symreltab.make (flat_right tparams ~~ dict_params)

        val {params, ...} = Axclass.get_info thy class

        val (super_fields, lthy) = fold_map
          (obtain_dict dict_context consts (Type (tyco, map TFree tparams)))
          (super_classes class thy)
          lthy

        val tparams' = map (TFree o rpair @{sort type} o fst) tparams
        val typ_inst = (TFree ("'a", [class]), Type (tyco, tparams'))

        fun mk_field (field, typ) =
          let
            val param = Axclass.param_of_inst thy (field, tyco)
            (* check: did we already define all required fields? *)
            (* if not: abort (else we would run into an infinite loop) *)
            val _ = case Symtab.lookup (#constants (Definitions.get (Context.Proof lthy))) param of
              NONE =>
                (* necessary for zero_nat *)
                if Code.is_constr thy param then
                  ()
                else
                  error ("cyclic dependency: " ^ param ^ " not yet defined in the definition of " ^ tyco ^ " :: " ^ class)
            | SOME _ => ()
          in
            term dict_context consts (Const (param, typ_subst_atomic [typ_inst] typ))
           end

        val (fields, lthy) = fold_map mk_field params lthy

        val rhs = list_comb (make (Type (tyco, tparams')), super_fields @ fields)
        val typ = map fastype_of dict_params ---> fastype_of rhs
        val head = Free (name, typ)
        val lhs = list_comb (head, dict_params)
        val term = Logic.mk_equals (lhs, rhs)

        val (def, (lthy', lthy)) = lthy
          |> tap print_info
          |> (snd o Local_Theory.begin_nested)
          |> define_params_nosyn term
          ||> `Local_Theory.end_nested
        val phi = Proof_Context.export_morphism lthy lthy'
        val def = Morphism.thm phi def

        val base_thms =
          Definitions.get (Context.Proof lthy') |> #constants |> Symtab.dest
          |> map (apsnd fst o snd)
          |> map_filter snd

        val target = (flat_right tparams, (Morphism.term phi head, tyco, class))
        val args = {base_thms = base_thms, def_thm = def}
        val thm = prove_dict_cert target args lthy'

        val const = Const (Local_Theory.full_name lthy' (Binding.name name), typ)
      in
        (const, add_instantiation (class, tyco) const thm lthy')
      end
and obtain_dict dict_context consts =
  let
    val dict_context' = Symreltab.dest dict_context
    fun for_class (Type (tyco, args)) class lthy =
          let
            val inst_param_sorts = param_sorts tyco class (Proof_Context.theory_of lthy)
            val (raw_inst, lthy') = instance consts tyco class lthy
            val (const_name, _) = dest_Const raw_inst
            val (inst_args, lthy'') = fold_map for_sort (inst_param_sorts ~~ args) lthy'
            val head = Sign.mk_const (Proof_Context.theory_of lthy'') (const_name, args)
          in
            (list_comb (head, flat inst_args), lthy'')
          end
      | for_class (TFree (name, _)) class lthy =
          let
            val available = map_filter
              (fn ((tp, class), term) => if tp = name then SOME (class, term) else NONE)
              dict_context'
            val (path, lthy') = mk_path available (TFree (name, @{sort type})) class lthy
          in
            case path of
              SOME term => (term, lthy')
            | NONE => error "no path found"
          end
      | for_class (TVar _) _ _ = error "unexpected type variable"
    and for_sort (sort, typ) =
          fold_map (for_class typ) sort
  in for_class end
and term dict_context consts term lthy =
  let
    fun traverse (t as Const (name, typ)) lthy =
        (case Symtab.lookup consts name of
          NONE => error ("unknown constant " ^ name)
        | SOME (_, Constructor) =>
            (typify t, lthy)
        | SOME (_, Classparam {class, typ = typ', selector}) =>
            let
              val subst = Sign.typ_match (Proof_Context.theory_of lthy) (typ', typ) Vartab.empty
              val (_, targ) = the (Vartab.lookup subst ("'a", 0))
              val (dict, lthy') = obtain_dict dict_context consts targ class lthy
            in
              (subst_TVars [(("'a", 0), targ)] selector $ dict, lthy')
            end
        | SOME (name', Fun {dicts = dicts, typ = typ', new_typ, ...}) =>
            let
              val subst = Type.raw_match (Logic.varifyT_global typ', typ) Vartab.empty
                |> Vartab.dest |> map (apsnd snd)
              fun lookup tparam = the (AList.lookup (op =) subst (tparam, 0))
              val (dicts, lthy') =
                fold_map (uncurry (obtain_dict dict_context consts o lookup)) (map fst dicts) lthy
              val typ = typ_subst_TVars subst (Logic.varifyT_global new_typ)
              val head =
                case Symtab.lookup (#constants (Definitions.get (Context.Proof lthy))) name of
                  NONE => Free (name', typ)
                | SOME (n, _) => Const (n, typ)
              val res = list_comb (head, dicts)
            in
              (res, lthy')
            end)
      | traverse (f $ x) lthy =
          let
            val (f', lthy')  = traverse f lthy
            val (x', lthy'') = traverse x lthy'
          in (f' $ x', lthy'') end
      | traverse (Abs (name, typ, term)) lthy =
          let
            val (term', lthy') = traverse term lthy
          in (Abs (name, typify_typ typ, term'), lthy') end
      | traverse (Free (name, typ)) lthy = (Free (name, typify_typ typ), lthy)
      | traverse (Var (name, typ)) lthy  = (Var (name, typify_typ typ), lthy)
      | traverse (Bound n) lthy = (Bound n, lthy)
  in
    traverse term lthy
  end

(** group of constants **)

fun new_names lthy consts =
  let
    val (all_names, all_consts) = split_list (flat consts)
    val all_frees = map (fn Fun {eqs, ...} => eqs | _ => []) all_consts |> flat
      |> map #params |> flat
      |> map all_frees' |> flat
    val context = fold Name.declare (all_names @ all_frees) (Variable.names_of lthy)

    fun new_name (name, const) context =
      let val (name', context') = Name.variant (mangle name) context in
        ((name, (name', const)), context')
      end
  in
    fst (fold_map (fold_map new_name) consts context)
  end

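(* processes one strongly connected component of constants: translates the code equations,
   defines the new dictionary-passing constants (via the function package or as plain
   definitions), transfers the original termination proof where possible, derives side conditions
   and congruence rules, and proves the equivalence certificates *)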
fun consts consts proof group lthy =
  let
    val fun_config = Function_Common.FunctionConfig
      {sequential=true, default=NONE, domintros=false, partials=false}
    fun pat_completeness_auto ctxt =
      Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt

    val all_names = map fst group

    val pretty_consts = map (pretty_const lthy) all_names |> Pretty.commas

    fun print_info msg =
      Pretty.str (msg ^ " ") :: pretty_consts
      |> Pretty.block
      |> Pretty.writeln

    val _ = print_info "Redefining constant(s)"

    fun process_eqs (name, Fun {dicts, param_typs, new_typ, eqs, info, cong, ...}) lthy =
          let
            val new_name = case Symtab.lookup consts name of
              NONE => error ("no new name for " ^ name)
            | SOME (n, _) => n

            val all_frees = map #params eqs |> flat |> map all_frees' |> flat
            val context = Name.make_context (all_names @ all_frees)
            val (dict_params, context') = dict_params context dicts

            fun adapt_params param_typs params =
              let
                val real_params = dict_params @ params
                val ext_params = drop (length params) param_typs
                  |> map typify_typ
                  |> Name.invent_names context' "e0" |> map Free
              in (real_params, ext_params) end

            fun mk_eq {params, rhs, thm} lthy =
              let
                val (real_params, ext_params) = adapt_params param_typs params
                val lhs' = list_comb (Free (new_name, new_typ), real_params @ ext_params)
                val (rhs', lthy') = term (Symreltab.make (map fst dicts ~~ dict_params)) consts rhs lthy
                val rhs'' = list_comb (rhs', ext_params)
              in
                ((HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'')), thm), lthy')
              end

            val is_fun = length param_typs + length dicts > 0
          in
            fold_map mk_eq eqs lthy
            |>> rpair (new_typ, is_fun)
            |>> SOME
            |>> pair ((name, new_name, map fst dicts), {info = info, cong = cong})
          end
      | process_eqs (name, _) lthy =
          ((((name, name, []), {info = NONE, cong = NONE}), NONE), lthy)

    val (items, lthy') = fold_map process_eqs group lthy

    val ((metas, infos), ((eqs, code_thms), (new_typs, is_funs))) = items
      |> map_filter (fn (meta, eqs) => Option.map (pair meta) eqs)
      |> split_list
      ||> split_list ||> apfst (flat #> split_list #>> map typify)
      ||> apsnd split_list
      |>> split_list

    val _ = if_debug lthy (fn () =>
      if null code_thms then ()
      else
        map (Syntax.pretty_term lthy o Thm.prop_of) code_thms
        |> Pretty.big_list "Equations:"
        |> Pretty.string_of
        |> tracing)

    val is_fun =
      case distinct (op =) is_funs of
        [b] => b
      | [] => false
      | _ => error "unsupported feature: mixed non-function and function definitions"
    fun mk_binding (_, new_name, _) typ =
      (Binding.name new_name, SOME typ, NoSyn)
    val bindings = map2 mk_binding metas new_typs

    val {constants, instantiations} = Definitions.get (Context.Proof lthy')
    val base_thms = Symtab.dest constants |> map (apsnd fst o snd) |> map_filter snd
    val base_certs = Symreltab.dest instantiations |> map (snd o snd)

    val consts = Sign.consts_of (Proof_Context.theory_of lthy')

    fun prove_eq_fun (info as {simps = SOME simps, fs, inducts = SOME inducts, ...}) lthy =
      let
        fun mk_target (name, _, classes) new =
          (classes, (new, Const (Consts.the_const consts name)))
        val targets = map2 mk_target metas fs
        val args =
          {fun_info = SOME info, inducts = SOME inducts, simps = simps, base_thms = base_thms,
           base_certs = base_certs, code_thms = code_thms, congs = NONE}
      in
        (prove_fun_cert targets args proof lthy, args)
      end

    fun prove_eq_def defs lthy =
      let
        fun mk_target (name, _, classes) new =
          (classes, (new, Const (Consts.the_const consts name)))
        val targets = map2 mk_target metas (map (fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) defs)
        val args =
          {fun_info = NONE, inducts = NONE, simps = defs,
           base_thms = base_thms, base_certs = base_certs, code_thms = code_thms, congs = NONE}
      in
        (prove_fun_cert targets args proof lthy, args)
      end

    fun add_constants ((((name, name', _), _), SOME _) :: xs) ((thm :: thms), info) =
          add_constant name name' (SOME thm, info) #> add_constants xs (thms, info)
      | add_constants ((((name, name', _), _), NONE) :: xs) (thms, info) =
          add_constant name name' (NONE, info) #> add_constants xs (thms, info)
      | add_constants [] _ =
          I

    fun prove_termination new_info ctxt =
      let
        val termination_ctxt =
          ctxt addsimps (@{thms equal} @ base_thms)
            addloop ("overload", CONVERSION o changed_conv o Axclass.overload_conv)
        val fallback_tac =
          Function_Common.termination_prover_tac true termination_ctxt

        val tac = case try hd (cat_options (map #info infos)) of
          SOME old_info => HEADGOAL (Transfer_Termination.termination_tac new_info old_info ctxt)
        | NONE => no_tac

      in Function.prove_termination NONE (tac ORELSE fallback_tac) ctxt end

    fun prove_cong data lthy =
      let
        fun rewr_cong thm cong =
          if Thm.nprems_of thm > 0 then
            (warning "No fundef_cong rule can be derived; this will likely not work later"; NONE)
          else
            (print_info "Porting fundef_cong rule for ";
             SOME (Local_Defs.fold lthy [thm] cong))

        val congs' =
          map2 (Option.mapPartial o rewr_cong) (fst data) (map #cong infos)
          |> cat_options

        fun add_congs phi =
          fold Function_Context_Tree.add_function_cong (map (Morphism.thm phi) congs')

        val data' =
          apsnd (map_const_info I I I I I I (K (SOME congs'))) data
      in
        (data', Local_Theory.declaration {pervasive = false, syntax = false} add_congs lthy)
      end

    fun mk_fun lthy =
      let
        val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) eqs
        val (info, lthy') =
          Function.add_function bindings specs fun_config pat_completeness_auto lthy
          |-> prove_termination
        val simps = the (#simps info)
        val (_, lthy'') =
          (* [simp del] is required because otherwise non-matching function definitions
             (e.g. divmod_nat) make the simplifier loop
             separate step because otherwise we'll get tons of warnings because the psimp rules
             are not added to the simpset *)
          Local_Theory.note ((Binding.empty, @{attributes [simp del]}), simps) lthy'
        fun prove_eq phi =
          prove_eq_fun (Function_Common.transform_function_data phi info)
      in
        (((simps, #inducts info), prove_eq), lthy'')
      end

    fun mk_def lthy =
      let
        val (defs, lthy') = fold_map define_params_nosyn eqs lthy
        fun prove_eq phi = prove_eq_def (map (Morphism.thm phi) defs)
      in
        (((defs, NONE), prove_eq), lthy')
      end
  in
    if null eqs then
      lthy'
    else
      let
        (* the redefinition itself doesn't have a sort constraint, but the equality prop may have
           one; hence the proof needs to happen after exiting the local theory target
           conceptually, everything happening locally would be great, but the type checker won't
           allow us to add sort constraints to TFrees after they have been declared *)
        val ((side, prove_eq), (lthy', lthy)) = lthy'
          |> (snd o Local_Theory.begin_nested)
          |> (if is_fun then mk_fun else mk_def)
          |-> (fn ((simps, inducts), prove_eq) =>
                apfst (rpair prove_eq) o Side_Conditions.mk_side simps inducts)
          ||> `Local_Theory.end_nested
        val phi = Proof_Context.export_morphism lthy lthy'
      in
        lthy'
        |> `(prove_eq phi)
        |>> apfst (on_thms_complete (fn () => print_info "Proved equivalence for"))
        |-> prove_cong
        |-> add_constants items
      end
  end

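(* entry point of the command: reads the given constants, annotates and groups their code
   equations into strongly connected components, processes each group in turn, and notes the
   collected certificate theorems under the given binding *)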
fun const_raw (binding, raw_consts) proof lthy =
  let
    val _ =
      if proof = Skip then
        warning "Skipping certificate proofs"
      else ()

    val (name, _) = Syntax.read_terms lthy raw_consts |> map dest_Const |> split_list

    val (eqs, lthy) = annotate_code_eqs lthy name
    val tab = symtab_of_sccs (new_names lthy eqs)

    val lthy = fold (consts tab proof) eqs lthy

    val {instantiations, constants} = Definitions.get (Context.Proof lthy)
    val thms =
      map (snd o snd) (Symreltab.dest instantiations) @
        map_filter (fst o snd o snd) (Symtab.dest constants)
  in
    snd (Local_Theory.note (binding, thms) lthy)
  end

(** setup **)

val parse_flags =
  Scan.optional (Args.parens (Parse.reserved "skip" >> K Skip)) Cert

val _ =
  Outer_Syntax.local_theory
    @{command_keyword "declassify"}
    "redefines a constant after applying the dictionary construction"
    (parse_flags -- Parse_Spec.opt_thm_name ":" -- Scan.repeat1 Parse.const >>
        (fn ((flags, def_binding), consts) => const_raw (def_binding, consts) flags))

end

Theory Termination

section ‹Termination heuristics›
text_raw ‹\label{sec:termination}›

theory Termination
  imports "../Dict_Construction"
begin

text ‹
  As indicated in the introduction, the newly-defined functions must be proven terminating. In
  general, we cannot reuse the original termination proof, as the following example illustrates:
›

fun f :: "nat ⇒ nat" where
"f 0 = 0" |
"f (Suc n) = f n"

lemma [code]: "f x = f x" ..

text ‹
  The invocation of @{theory_text ‹declassify f›} would fail, because @{const f}'s code equations
  are not terminating.

  Hence, in the general case where users have modified the code equations, we need to fall back
  to an (automated) attempt to prove termination.

  In the remainder of this section, we will illustrate the special case where the user has not
  modified the code equations, i.e., the original termination proof should ``morally'' still be
  applicable. For this, we will perform the dictionary construction manually.
›

― ‹Some ML incantations to ensure that the dictionary types are present›
local_setup ‹Class_Graph.ensure_class @{class plus} #> snd›
local_setup ‹Class_Graph.ensure_class @{class zero} #> snd›

fun sum_list :: "'a::{plus,zero} list ⇒ 'a" where
"sum_list [] = 0" |
"sum_list (x # xs) = x + sum_list xs"

text ‹
  The above function carries two distinct class constraints, which are translated into two
  dictionary parameters:
›

function sum_list' where
"sum_list' d_plus d_zero [] = Groups_zero__class_zero__field d_zero" |
"sum_list' d_plus d_zero (x # xs) = Groups_plus__class_plus__field d_plus x (sum_list' d_plus d_zero xs)"
by pat_completeness auto

text ‹
  Now, we need to carry out the termination proof of @{const sum_list'}. The @{theory_text ‹function›}
  package analyzes the function definition and discovers one recursive call. In pseudo-notation:

  @{text [display] ‹(d_plus, d_zero, x # xs) ↝ (d_plus, d_zero, xs)›}

  The result of this analysis is captured in the inductive predicate @{const sum_list'_rel}. Its
  introduction rules look as follows:
›

thm sum_list'_rel.intros
― ‹@{thm sum_list'_rel.intros}›

text ‹Compare this to the relation for @{const sum_list}:›

thm sum_list_rel.intros
― ‹@{thm sum_list_rel.intros}›

text ‹
  Except for the additional (unchanging) dictionary arguments, these relations are more or less
  equivalent to each other. There is an important difference, though: @{const sum_list_rel} has
  sort constraints, @{const sum_list'_rel} does not. (This will become important later on.)
›

context
  notes [[show_sorts]]
begin

term sum_list_rel
― ‹@{typ ‹'a::{plus,zero} list ⇒ 'a::{plus,zero} list ⇒ bool›}›

term sum_list'_rel
― ‹@{typ ‹'a::type Groups_plus__dict × 'a::type Groups_zero__dict × 'a::type list ⇒ 'a::type Groups_plus__dict × 'a::type Groups_zero__dict × 'a::type list ⇒ bool›}›

end

text ‹
  Let us now discuss the rough concept of the termination proof for @{const sum_list'}. The goal is
  to show that @{const sum_list'_rel} is well-founded. Usually, this is proved by specifying a
  ‹measure function› that
    ▪ maps the arguments to natural numbers
    ▪ decreases for each recursive call.
›
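text ‹
  For @{const sum_list'}, one such measure could, for example, project out the list argument and
  take its length, i.e., something like ‹λ(_, _, xs). length xs›.
›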
text ‹
  Here, however, we instead want to show that each recursive call in @{const sum_list'} has a
  corresponding recursive call in @{const sum_list}. In other words, we want to show that the
  existing proof of well-foundedness of @{const sum_list_rel} can be lifted to a proof of
  well-foundedness of @{const sum_list'_rel}. This is what the theorem
  @{thm [source=true] wfP_simulate_simple} states:

  @{thm [display=true] wfP_simulate_simple}

  Given any well-founded relation ‹r› and a function ‹g› that maps function arguments from ‹r'› to
  ‹r›, we can deduce that ‹r'› is also well-founded.

  For our example, we need to provide a function ‹g› of type
  @{typ ‹'b Groups_plus__dict × 'b Groups_zero__dict × 'b list ⇒ 'a list›}.
  Because the dictionary parameters are not changing, they can safely be dropped by ‹g›.
  However, because of the sort constraint in @{const sum_list_rel}, the term @{term "snd ∘ snd"}
  is not a well-typed instantiation for ‹g›.

  Instead (this is where the heuristic comes in), we assume that the original function
  @{const sum_list} is parametric, i.e., termination does not depend on the elements of the list
  passed to it, but only on the structure of the list. Additionally, we assume that all involved
  type classes have at least one instantiation.

  With this in mind, we can use @{term "map (λ_. undefined) ∘ snd ∘ snd"} as ‹g›:
›

thm wfP_simulate_simple[where
  r = sum_list_rel and
  r' = sum_list'_rel and
  g = "map (λ_. undefined)  snd  snd"]

text ‹
  Finally, we can prove the termination of @{const sum_list'}.
›

termination sum_list'
proof -
  have "wfP sum_list'_rel"
  proof (rule wfP_simulate_simple)
    ― ‹We first need to obtain the well-foundedness theorem for @{const sum_list_rel} from the ML
        guts of the @{theory_text ‹function›} package.›
    show "wfP sum_list_rel"
      apply (rule accp_wfPI)
      apply (tactic ‹resolve_tac @{context} [Function.get_info @{context} @{term sum_list} |> #totality |> the] 1›)
      done

    define g :: "'b Groups_plus__dict × 'b Groups_zero__dict × 'b list ⇒ 'c::{plus,zero} list" where
      "g = map (λ_. undefined) ∘ snd ∘ snd"

    ― ‹Prove the simulation of @{const sum_list'_rel} by @{const sum_list_rel} by rule induction.›
    show "sum_list_rel (g x) (g y)" if "sum_list'_rel x y" for x y
      using that
      proof (induction x y rule: sum_list'_rel.induct)
        case (1 d_plus d_zero x xs)
        show ?case
          ― ‹Unfold the constituent parts of @{term g}:›
          apply (simp only: g_def comp_apply snd_conv list.map)
          ― ‹Use the corresponding introduction rule of @{const sum_list_rel} and hope for the best:›
          apply (rule sum_list_rel.intros(1))
          done
      qed
  qed

  ― ‹This is the goal that the @{theory_text ‹function›} package expects.›
  then show "∀x. sum_list'_dom x"
    by (rule wfP_implies_dom)
qed

text ‹This can be automated with a special tactic:›

experiment
begin

termination sum_list'
  apply (tactic ‹Transfer_Termination.termination_tac
      (Function.get_info @{context} @{term sum_list'})
      (Function.get_info @{context} @{term sum_list})
      @{context}
      1›; fail)
  done

end

text ‹
  A similar technique can be used for making functions defined in locales executable when, for some
  reason, the definition of a ``defs'' locale is not feasible.
›

locale foo =
  fixes A :: "nat"
  assumes "A > 0"
begin

fun f where
"f 0 = A" |
"f (Suc n) = Suc (f n)"

― ‹We carry out this proof in the locale for simplicity; a real implementation would probably
    have to set up a local theory properly.›
lemma f_total: "wfP f_rel"
apply (rule accp_wfPI)
apply (tactic ‹resolve_tac @{context} [Function.get_info @{context} @{term f} |> #totality |> the] 1›)
done

end

― ‹The dummy interpretation serves the same purpose as the assumption that class constraints have
    at least one instantiation.›
interpretation dummy: foo 1 by standard simp

function f' where
"f' A 0 = A" |
"f' A (Suc n) = Suc (f' A n)"
by pat_completeness auto

termination f'
  apply (rule wfP_implies_dom)
  apply (rule wfP_simulate_simple[where g = "snd"])
   apply (rule dummy.f_total)
  subgoal for x y
    apply (induction x y rule: f'_rel.induct)
    subgoal
     apply (simp only: snd_conv)
     apply (rule dummy.f_rel.intros)
     done
    done
  done

text ‹Automatic:›

experiment
begin

termination f'
  apply (tactic ‹Transfer_Termination.termination_tac
      (Function.get_info @{context} @{term f'})
      (Function.get_info @{context} @{term dummy.f})
      @{context}
      1›; fail)
  done

end

end

Theory Test_Dict_Construction

section ‹Test cases for dictionary construction›

theory Test_Dict_Construction
imports
  Dict_Construction
  "HOL-Library.ListVector"
begin

subsection ‹Code equations with different number of explicit arguments›

lemma [code]: "fold f [] = id" "fold f (x # xs) s = fold f xs (f x s)" "fold f [x, y] u  f y (f x u)"
by auto

experiment begin

  declassify valid: fold
  thm valid
  lemma "List_fold = fold" by (rule valid)

end

subsection ‹Complex class hierarchies›

local_setup ‹Class_Graph.ensure_class @{class zero} #> snd›
local_setup ‹Class_Graph.ensure_class @{class plus} #> snd›

experiment begin

  local_setup ‹Class_Graph.ensure_class @{class comm_monoid_add} #> snd›
  local_setup ‹Class_Graph.ensure_class @{class ring} #> snd›

  typ "nat Rings_ring__dict"

end

text ‹Check that ‹Class_Graph› does not leak out of locales›

ML‹@{assert} (is_none (Class_Graph.node @{context} @{class ring}))›


subsection ‹Instances with non-trivial arity›

fun f :: "'a::plus ⇒ 'a" where
"f x = x + x"

definition g :: "'a::{plus,zero} list ⇒ 'a list" where
"g x = f x"

datatype natt = Z | S natt

instantiation natt :: "{zero,plus}" begin
  definition zero_natt where
  "zero_natt = Z"

  fun plus_natt where
  "plus_natt Z x = x" |
  "plus_natt (S m) n = S (plus_natt m n)"

  instance ..
end

definition h :: "natt list" where
"h = g [Z,S Z]"

experiment begin

(* FIXME problem with smart_tac *)
declassify valid: h
thm valid
lemma "Test__Dict__Construction_h = h" by (fact valid)

ML‹Dict_Construction.the_info @{context} @{const_name plus_natt_inst.plus_natt}›

end

text ‹Check that @{command declassify} does not leak out of locales›

ML‹
  can (Dict_Construction.the_info @{context}) @{const_name plus_natt_inst.plus_natt}
  |> not |> @{assert}
›

subsection ‹[@{attribute fundef_cong}] rules›

datatype 'a seq = Cons 'a "'a seq" | Nil

experiment begin

declassify map_seq

text ‹Check presence of derived [@{attribute fundef_cong}] rule›

ML‹Dict_Construction.the_info @{context} @{const_name map_seq}
  |> #fun_info
  |> the
  |> #fs
  |> the_single
  |> dest_Const
  |> fst
  |> Dict_Construction.cong_of_const @{context}
  |> the
›

end


subsection ‹Mutual recursion›

fun odd :: "nat ⇒ bool" and even where
"odd 0 ⟷ False" |
"even 0 ⟷ True" |
"odd (Suc n) ⟷ even n" |
"even (Suc n) ⟷ odd n"

experiment begin

declassify valid: odd even
thm valid

end

datatype 'a bin_tree = Leaf | Node 'a "'a bin_tree" "'a bin_tree"

experiment begin

declassify valid: map_bin_tree rel_bin_tree
thm valid

end

datatype 'v env = Env "'v list"
datatype v = Closure "v env"

context
  notes is_measure_trivial[where f = "size_env size", measure_function]
begin

(* FIXME order is important! *)
fun test_v :: "v ⇒ bool" and test_w :: "v env ⇒ bool" where
"test_v (Closure env) ⟷ test_w env" |
"test_w (Env vs) ⟷ list_all test_v vs"

fun test_v1 :: "v ⇒ 'a::{one,monoid_add}" and test_w1 :: "v env ⇒ 'a" where
"test_v1 (Closure env) = 1 + test_w1 env" |
"test_w1 (Env vs) = sum_list (map test_v1 vs)"

end

experiment begin

declassify valid: test_w test_v
thm valid

end

experiment begin

(* FIXME derive fundef_cong rule for sum_list *)
declassify valid: test_w1 test_v1
thm valid

end


subsection ‹Non-trivial code dependencies; code equations where the head is not fully general›

definition "c  0 :: nat"
definition "d x  if x = 0 then 0 else x"

lemma contrived[code]: "c = d 0" unfolding c_def d_def by simp

experiment begin

declassify valid: c
thm valid
lemma "Test__Dict__Construction_c = c" by (fact valid)

end


subsection ‹Pattern matching on @{term "0::nat"}›

definition j where "j (n::nat) = (0::nat)"

lemma [code]: "j 0 = 0" "j (Suc n) = j n"
unfolding j_def by auto

fun k where
"k 0 = (0::nat)" |
"k (Suc n) = k n"

lemma f_code[code]: "k n = 0"
by (induct n) simp+

experiment begin

declassify valid: j k
thm valid
lemma
  "Test__Dict__Construction_j = j"
  "Test__Dict__Construction_k = k"
by (fact valid)+

end


subsection ‹Complex termination arguments›

fun fac :: "nat ⇒ nat" where
"fac n = (if n ≤ 1 then 1 else n * fac (n - 1))"

experiment begin

declassify valid: fac

end


subsection ‹Combination of various things›

experiment begin

declassify valid: sum_list

end


subsection ‹Interaction with the code generator›

declassify h
export_code Test__Dict__Construction_h in SML


end

Theory Test_Side_Conditions

subsection ‹Contrived side conditions›

theory Test_Side_Conditions
imports Dict_Construction
begin

ML ‹fun assert_alt_total ctxt term = @{assert} (Side_Conditions.is_total ctxt term)›

fun head where
"head (x # _) = x"

local_setup ‹snd o Side_Conditions.mk_side @{thms head.simps} NONE›

lemma head_side_eq: "head_side xs ⟷ xs ≠ []"
by (cases xs) (auto intro: head_side.intros elim: head_side.cases)

declaration ‹K (Side_Conditions.set_alt @{term head} @{thm head_side_eq})›

fun map where
"map f [] = []" |
"map f (x # xs) = f x # map f xs"

local_setup ‹snd o Side_Conditions.mk_side @{thms map.simps} (SOME @{thms map.induct})›
thm map_side.intros

ML ‹assert_alt_total @{context} @{term map}›

experiment begin

  text ‹Functions that apply partial functions only within their domain are processed correctly.›

  fun tail where
  "tail (_ # xs) = xs"

  local_setup ‹snd o Side_Conditions.mk_side @{thms tail.simps} NONE›

  lemma tail_side_eq: "tail_side xs ⟷ xs ≠ []"
  by (cases xs) (auto intro: tail_side.intros elim: tail_side.cases)

  declaration ‹K (Side_Conditions.set_alt @{term tail} @{thm tail_side_eq})›

  function map' where
  "map' f xs = (if xs = [] then [] else f (head xs) # map' f (tail xs))"
  by auto

  termination
    apply (relation "measure (size  snd)")
    apply rule
    subgoal for f xs by (cases xs) auto
    done

  local_setup ‹snd o Side_Conditions.mk_side @{thms map'.simps} (SOME @{thms map'.induct})›
  thm map'_side.intros

  ML ‹assert_alt_total @{context} @{term map'}›

end

lemma map_cong:
  assumes "xs = ys" "x. x  set ys  f x = g x"
  shows "map f xs = map g ys"
unfolding assms(1)
using assms(2)
by (induction ys) auto

definition map_head where
"map_head xs = map head xs"

experiment begin

  declare map_cong[fundef_cong]

  local_setup ‹snd o Side_Conditions.mk_side @{thms map_head_def} NONE›
  thm map_head_side.intros

  lemma "map_head_side xs  (x  set xs. x  [])"
  by (auto intro: map_head_side.intros elim: map_head_side.cases)

  definition map_head' where
  "map_head' xss = map (map head) xss"

  local_setup ‹snd o Side_Conditions.mk_side @{thms map_head'_def} NONE›
  thm map_head'_side.intros

  lemma "map_head'_side xss  (xs  set xss. x  set xs. x  [])"
  by (auto intro: map_head'_side.intros elim: map_head'_side.cases)

end

experiment begin

  local_setup ‹snd o Side_Conditions.mk_side @{thms map_head_def} NONE›
  term map_head_side
  thm map_head_side.intros

  lemma "¬ map_head_side xs"
  by (auto elim: map_head_side.cases)

end


definition head_known where
"head_known xs = head (3 # xs)"

local_setup ‹snd o Side_Conditions.mk_side @{thms head_known_def} NONE›
thm head_known_side.intros

ML‹assert_alt_total @{context} @{term head_known}›

fun odd :: "nat ⇒ bool" and even where
"odd 0 ⟷ False" |
"even 0 ⟷ True" |
"odd (Suc n) ⟷ even n" |
"even (Suc n) ⟷ odd n"

local_setup ‹snd o Side_Conditions.mk_side @{thms odd.simps even.simps} (SOME @{thms odd_even.induct})›
thm odd_side_even_side.intros

ML‹assert_alt_total @{context} @{term odd}›
ML‹assert_alt_total @{context} @{term even}›

definition odd_known where
"odd_known = odd (Suc 0)"

local_setup ‹snd o Side_Conditions.mk_side @{thms odd_known_def} NONE›
thm odd_known_side.intros

ML‹assert_alt_total @{context} @{term odd_known}›

end

Theory Test_Lazy_Case

subsection ‹Interaction with Lazy_Case›

theory Test_Lazy_Case
imports
  Dict_Construction
  Lazy_Case.Lazy_Case
  Show.Show_Instances
begin

datatype 'a tree = Node | Fork 'a "'a tree list"

lemma map_tree[code]:
  "map_tree f t = (case t of Node  Node | Fork x ts  Fork (f x) (map (map_tree f) ts))"
by (induction t) auto

experiment begin

(* FIXME proper qualified path *)
text ‹
  Dictionary construction of @{const map_tree} requires the [@{attribute fundef_cong}] rule of
  @{const Test_Lazy_Case.tree.case_lazy}.
›

declassify valid: map_tree
thm valid

lemma "Test__Lazy__Case_tree_map__tree = map_tree" by (fact valid)

end


definition i :: "(unit × (bool list × string × nat option) list) option ⇒ string" where
"i = show"

experiment begin

text ‹This currently requires @{theory Lazy_Case.Lazy_Case} because of @{const divmod_nat}.›

(* FIXME get rid of Lazy_Case dependency *)
declassify valid: i
thm valid

lemma "Test__Lazy__Case_i = i" by (fact valid)

end

end