# Theory Introduction

section ‹Dictionary Construction›

theory Introduction
imports Main
begin

subsection ‹Introduction›

text ‹
Isabelle's logic features \emph{type classes}~\cite{haftmann2007typeclasses,wenzel1997typeclasses}.
These are built into the kernel and are used extensively in theory developments.
The existing \emph{code generator}, when targeting Standard ML, performs the well-known dictionary
construction or \emph{dictionary translation}~\cite{haftmann2010codegeneration}.
This works by replacing type classes with records, instances with values, and class constraints
with explicit dictionary parameters.

Haftmann and Nipkow give a pen-and-paper correctness proof of this construction
\cite[‹§›4.1]{haftmann2010codegeneration}, based on a notion of \emph{higher-order rewrite
systems}.
The resulting theorem then states that any well-typed term is reduction-equivalent before and after
class elimination.
In this work, the dictionary construction is performed in a certified fashion, that is, the
equivalence is a theorem inside the logic.
›
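For intuition, the translation can be sketched by hand in plain Standard ML. The sketch is loosely modelled on the code generator's output but written for this illustration only; the names `dict_plus`, `dict_plus_int` and `double` are hypothetical. The class becomes a record type with one field per class parameter. The instance becomes a record value. The constrained function receives the dictionary as an explicit argument.

```sml
(* Hand-written illustration of the dictionary translation (hypothetical names,
   not actual code generator output). *)

(* The class "plus" becomes a record type with one field per class parameter. *)
type 'a dict_plus = {plus : 'a -> 'a -> 'a}

(* The instance for int becomes a value of that record type. *)
val dict_plus_int : int dict_plus = {plus = fn x => fn y => x + y}

(* A function with a class constraint takes the dictionary as an explicit parameter
   and looks up the class operation in it. *)
fun double (dict : 'a dict_plus) (x : 'a) : 'a = #plus dict x x

(* Call sites pass the instance dictionary explicitly. *)
val six = double dict_plus_int 3
```

Here `six` evaluates to `6`. Selecting the instance, which type inference does implicitly for a class constraint, has become an ordinary argument.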

subsection ‹Encoding classes›

text ‹
The choice of representation of a dictionary itself is straightforward: We model it as a
@{command datatype}, along with functions returning values of that type. The alternative here
would have been to use the @{command record} package. The obvious advantage is that we could
easily model subclass relationships through record inheritance. However, records do not support
multiple inheritance. Since records offer no advantage over datatypes in that regard, we opted for
the more modern @{command datatype} package.
›

text ‹Consider the following example:›

class plus =
fixes plus :: "'a ⇒ 'a ⇒ 'a"

text ‹
This will get translated to a @{command datatype} with a single constructor taking a single
argument:
›

datatype 'a dict_plus =
mk_plus (param_plus: "'a ⇒ 'a ⇒ 'a")

text ‹A function using the @{class plus} constraint, together with its translation, which receives
the dictionary as an explicit parameter:›

definition double :: "'a::plus ⇒ 'a" where
"double x = plus x x"

definition double' :: "'a dict_plus ⇒ 'a ⇒ 'a" where
"double' dict x = param_plus dict x x"

subsection ‹Encoding instances›

text ‹
A more controversial design decision is how to represent dictionary certificates. For example,
given a value of type @{typ "nat dict_plus"}, how do we know that this is a faithful representation
of the @{class plus} instance for @{typ nat}?
›

text ‹
▪ Florian Haftmann proposed a ``shallow encoding''. It works by exploiting the internal treatment
of constants with sort constraints in the Isabelle kernel. Constants themselves do not carry
sort constraints, only their definitional equations. The fact that a constant only appears with
these constraints on the surface of the system is a feature of type inference.

Instead, we can instruct the system to ignore these constraints. However, any attempt at
``hiding'' the constraints behind a type definition ultimately does not work: the nonemptiness
proof requires a witness of a valid dictionary for an arbitrary but fixed type @{typ 'a}, which
is of course not possible (see ‹§›\ref{sec:impossibility} for details).

▪ The certificates contain the class axioms directly. For example, the @{class semigroup_add}
class requires @{term "(a + b) + c = a + (b + c)"}.

Translated into a definition, this would look as follows:

@{term
"cert_plus dict ⟷
(∀a b c. param_plus dict (param_plus dict a b) c = param_plus dict a (param_plus dict b c))"}

Proving that instances satisfy this certificate is trivial.

However, proving that a translated function ‹f'› equals its original ‹f› (such as
@{const double'} and @{const double}) is impossible: they are simply not equal in general.
Nothing would prevent someone from defining an alternative dictionary using multiplication
instead of addition and the certificate would still hold; but obviously functions using
@{const plus} on numbers would expect addition.

Intuitively, this makes sense: the above notion of ``certificate'' establishes no connection
between original instantiation and newly-generated dictionaries.

Instead of proving equality, one would have to ``lift'' all existing theorems over the old
constants to the new constants.

▪ In order for equality between new and old constants to hold, the certificate needs to capture
that the dictionary corresponds exactly to the class constants. This is achieved by the
representation below.
It literally states that the fields of the dictionary are equal to the class constants.
The condition of the resulting equation can only be instantiated with dictionaries corresponding
to existing class instances. This constitutes a ∗‹closed world› assumption, i.e., callers of
generated code may not invent their own instantiations.
›
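The weakness of the second, axiom-based encoding can be shown with a small plain Standard ML sketch. All names are hypothetical, and the associativity check only samples a few inputs rather than proving anything. A dictionary built from multiplication satisfies the same associativity condition as one built from addition, yet the translated function computes a different result.

```sml
(* Plain Standard ML sketch; hypothetical names, not generated code. *)
type 'a dict_plus = {plus : 'a -> 'a -> 'a}

fun double (dict : 'a dict_plus) (x : 'a) : 'a = #plus dict x x

(* Two dictionaries for int: the intended one and an impostor. *)
val add_dict : int dict_plus = {plus = fn a => fn b => a + b}
val mul_dict : int dict_plus = {plus = fn a => fn b => a * b}

(* Sample-based check of the semigroup_add axiom (a test, not a proof). *)
fun assoc_on (dict : int dict_plus) samples =
  List.all
    (fn (a, b, c) => #plus dict (#plus dict a b) c = #plus dict a (#plus dict b c))
    samples

val samples = [(1, 2, 3), (~4, 5, 0), (7, 7, 7)]

(* Both dictionaries pass the axiom-based check ... *)
val both_assoc = assoc_on add_dict samples andalso assoc_on mul_dict samples

(* ... but the translated function behaves differently: 6 versus 9. *)
val with_add = double add_dict 3
val with_mul = double mul_dict 3
```

A certificate that only restates the class axioms therefore cannot justify replacing the original constant with its translation. The certificate needs to pin the dictionary field to the class constant itself.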

definition cert_plus :: "'a::plus dict_plus ⇒ bool" where
"cert_plus dict ⟷ (param_plus dict = plus)"

text ‹
Based on that definition, we can prove that @{const double} and @{const double'} are equivalent:
›

lemma "cert_plus dict ⟹ double' dict = double"
unfolding cert_plus_def double'_def double_def
by auto

text ‹
An unconditional equation can be obtained by specializing the theorem to a ground type and
supplying a valid dictionary.
›

subsection ‹Implementation›

text ‹
When translating a constant ‹f›, we use existing mechanisms in Isabelle to obtain its
∗‹code graph›. The graph contains the code equations of all transitive dependencies (i.e.,
other constants) of ‹f›. In general, we have to re-define each of these dependencies. For that,
we use the internal interface of the @{command function} package and feed it the code equations
after performing the dictionary construction. In the standard case, where the user has not
performed a custom code setup, the resulting function looks similar to its original definition.
However, the user may also have changed the implementation of a function significantly afterwards.
This imposes some restrictions:

▪ The new constant needs to be proven terminating. We apply some heuristics to transfer the
original termination proof to the new definition. This only works when the termination condition
does not rely on class axioms. (See ‹§›\ref{sec:termination} for details.)
▪ Pattern matching must be performed on datatypes, instead of the more general
@{command code_datatype}s.
▪ The set of code equations must be exhaustive and non-overlapping.
›

end

# Theory Impossibility

subsection ‹Impossibility of hiding sort constraints›
text_raw ‹\label{sec:impossibility}›

text ‹Coauthor of this section: Florian Haftmann›

theory Impossibility
imports Main
begin

axiomatization of_prop :: "prop ⇒ bool" where
of_prop_Trueprop [simp]: "of_prop (Trueprop P) ⟷ P" and
Trueprop_of_prop [simp]: "Trueprop (of_prop Q) ≡ PROP Q"

text ‹A type satisfies the certificate if there is an instance of the class.›

definition is_sg :: "'a itself ⇒ bool" where
"is_sg TYPE('a) = of_prop OFCLASS('a, semigroup_add_class)"

text ‹We trick the parser into ignoring the sort constraint of @{const plus}.›

setup ‹Sign.add_const_constraint (@{const_name plus}, SOME @{typ "'a::{} => 'a ⇒ 'a"})›

definition sg :: "('a ⇒ 'a ⇒ 'a) ⇒ bool" where
"sg plus ⟷ plus = Groups.plus ∧ is_sg TYPE('a)" for plus

text ‹Attempt: Define a type that contains all legal @{const plus} functions.›

typedef (overloaded) 'a Sg = "Collect sg :: ('a ⇒ 'a ⇒ 'a) set"
morphisms the_plus Sg
unfolding sg_def[abs_def]

text ‹We need to prove @{term "OFCLASS('a, semigroup_add_class)"} for arbitrary @{typ 'a}, which is
impossible.›

oops

end


# Theory Dict_Construction

section ‹Setup›

theory Dict_Construction
imports Automatic_Refinement.Refine_Util
keywords "declassify" :: thy_decl
begin

definition set_of :: "('a ⇒ 'b ⇒ bool) ⇒ ('a × 'b) set" where
"set_of P = {(x, y). P x y}"

lemma wfP_implies_wf_set_of: "wfP P ⟹ wf (set_of P)"
unfolding wfP_def set_of_def .

lemma wf_set_of_implies_wfP: "wf (set_of P) ⟹ wfP P"
unfolding wfP_def set_of_def .

lemma wf_simulate_simple:
assumes "wf r"
assumes "⋀x y. (x, y) ∈ r' ⟹ (g x, g y) ∈ r"
shows "wf r'"
using assms
by (metis in_inv_image wf_eq_minimal wf_inv_image)

lemma set_ofI: "P x y ⟹ (x, y) ∈ set_of P"
unfolding set_of_def by simp

lemma set_ofD: "(x, y) ∈ set_of P ⟹ P x y"
unfolding set_of_def by simp

lemma wfP_simulate_simple:
assumes "wfP r"
assumes "⋀x y. r' x y ⟹ r (g x) (g y)"
shows "wfP r'"
apply (rule wf_set_of_implies_wfP)
apply (rule wf_simulate_simple[where g = g])
apply (rule wfP_implies_wf_set_of)
apply (fact assms)
using assms(2) by (auto intro: set_ofI dest: set_ofD)

lemma wf_implies_dom: "wf (set_of R) ⟹ All (Wellfounded.accp R)"
apply (rule allI)
apply (rule accp_wfPD)
apply (rule wf_set_of_implies_wfP) .

lemma wfP_implies_dom: "wfP R ⟹ All (Wellfounded.accp R)"
by (metis wfP_implies_wf_set_of wf_implies_dom)

named_theorems dict_construction_specs

ML_file ‹dict_construction_util.ML›
ML_file ‹transfer_termination.ML›
ML_file ‹congruences.ML›
ML_file ‹side_conditions.ML›
ML_file ‹class_graph.ML›
ML_file ‹dict_construction.ML›

method_setup fo_cong_rule = ‹
Attrib.thm >> (fn thm => fn ctxt => SIMPLE_METHOD' (Dict_Construction_Util.fo_cong_tac ctxt thm))
› "resolve congruence rule using first-order matching"

declare [[code drop: "(∧)"]]
lemma [code]: "True ∧ p ⟷ p" "False ∧ p ⟷ False" by auto

declare [[code drop: "(∨)"]]
lemma [code]: "True ∨ p ⟷ True" "False ∨ p ⟷ p" by auto

declare comp_cong[fundef_cong del]
declare fun.map_cong[fundef_cong]

end

# File ‹dict_construction_util.ML›

infixr 5 ==>
infixr ===>
infix 1 CONTINUE_WITH CONTINUE_WITH_FW

signature DICT_CONSTRUCTION_UTIL = sig
(* general *)
val split_list3: ('a * 'b * 'c) list -> 'a list * 'b list * 'c list
val symreltab_of_symtab: 'a Symtab.table Symtab.table -> 'a Symreltab.table
val zip_symtabs: ('a -> 'b -> 'c) -> 'a Symtab.table -> 'b Symtab.table -> 'c Symtab.table
val cat_options: 'a option list -> 'a list
val partition: ('a -> bool) -> 'a list -> 'a list * 'a list
val unappend: 'a list * 'b -> 'c list -> 'c list * 'c list
val flat_right: ('a * 'b list) list -> ('a * 'b) list

(* logic *)
val ===> : term list * term -> term
val ==> : term * term -> term
val sortify: sort -> term -> term
val sortify_typ: sort -> typ -> typ
val typify: term -> term
val typify_typ: typ -> typ
val all_frees: term -> (string * typ) list
val all_frees': term -> string list
val all_tfrees: typ -> (string * sort) list

(* printing *)
val pretty_const: Proof.context -> string -> Pretty.T

(* conversion/tactic *)
val ANY: tactic list -> tactic
val ANY': ('a -> tactic) list -> 'a -> tactic
val CONTINUE_WITH: (int -> tactic) * (int -> tactic) list -> int -> thm -> thm Seq.seq
val CONTINUE_WITH_FW: (int -> tactic) * (int -> tactic) list -> int -> thm -> thm Seq.seq
val SOLVED: tactic -> tactic
val TRY': ('a -> tactic) -> 'a -> tactic
val descend_fun_conv: conv -> conv
val lhs_conv: conv -> conv
val rhs_conv: conv -> conv
val conv_result: ('a -> thm) -> 'a -> term
val changed_conv: ('a -> thm) -> 'a -> thm
val maybe_induct_tac: thm list option -> term list list -> term list list -> Proof.context -> tactic
val multi_induct_tac: thm list -> term list list -> term list list -> Proof.context -> tactic
val print_tac': Proof.context -> string -> int -> tactic
val fo_cong_tac: Proof.context -> thm -> int -> tactic

(* theorem manipulation *)
val contract: Proof.context -> thm -> thm
val on_thms_complete: (unit -> 'a) -> thm list -> thm list

(* theory *)
val define_params_nosyn: term -> local_theory -> thm * local_theory
val note_thm: binding -> thm -> local_theory -> thm * local_theory
val note_thms: binding -> thm list -> local_theory -> thm list * local_theory

(* timing *)
val with_timeout: Time.time -> ('a -> 'a) -> 'a -> 'a

(* debugging *)
val debug: bool Config.T
val if_debug: Proof.context -> (unit -> unit) -> unit
val ALLGOALS': Proof.context -> (int -> tactic) -> tactic
val prove': Proof.context -> string list -> term list -> term ->
({prems: thm list, context: Proof.context} -> tactic) -> thm
val prove_common': Proof.context -> string list -> term list -> term list ->
({prems: thm list, context: Proof.context} -> tactic) -> thm list
end

structure Dict_Construction_Util : DICT_CONSTRUCTION_UTIL = struct

(* general *)

fun symreltab_of_symtab tab =
Symtab.map (K Symtab.dest) tab |>
Symtab.dest |>
maps (fn (k, kvs) => map (apfst (pair k)) kvs) |>
Symreltab.make

fun split_list3 [] = ([], [], [])
| split_list3 ((x, y, z) :: rest) =
let val (xs, ys, zs) = split_list3 rest in
(x :: xs, y :: ys, z :: zs)
end

fun zip_symtabs f t1 t2 =
let
open Symtab
val ord = fast_string_ord
fun aux acc [] [] = acc
| aux acc ((k1, x) :: xs) ((k2, y) :: ys) =
(case ord (k1, k2) of
EQUAL   => aux (update_new (k1, f x y) acc) xs ys
| LESS    => raise UNDEF k1
| GREATER => raise UNDEF k2)
| aux _ ((k, _) :: _) [] =
raise UNDEF k
| aux _ [] ((k, _) :: _) =
raise UNDEF k
in aux empty (dest t1) (dest t2) end

fun cat_options [] = []
| cat_options (SOME x :: xs) = x :: cat_options xs
| cat_options (NONE :: xs) = cat_options xs

fun partition f xs = (filter f xs, filter_out f xs)

fun unappend (xs, _) = chop (length xs)

fun flat_right [] = []
| flat_right ((x, ys) :: rest) = map (pair x) ys @ flat_right rest

(* logic *)

fun x ==> y = Logic.mk_implies (x, y)
val op ===> = Library.foldr op ==>

fun sortify_typ sort (Type (tyco, args)) = Type (tyco, map (sortify_typ sort) args)
| sortify_typ sort (TFree (name, _)) = TFree (name, sort)
| sortify_typ _ (TVar _) = error "TVar encountered"

fun sortify sort (Const (name, typ)) = Const (name, sortify_typ sort typ)
| sortify sort (Free (name, typ)) = Free (name, sortify_typ sort typ)
| sortify sort (t $ u) = sortify sort t $ sortify sort u
| sortify sort (Abs (name, typ, term)) = Abs (name, sortify_typ sort typ, sortify sort term)
| sortify _ (Bound n) = Bound n
| sortify _ (Var _) = error "Var encountered"

val typify_typ = sortify_typ @{sort type}
val typify = sortify @{sort type}

fun all_frees (Free (name, typ)) = [(name, typ)]
| all_frees (t $ u) = union (op =) (all_frees t) (all_frees u)
| all_frees (Abs (_, _, t)) = all_frees t
| all_frees _ = []

val all_frees' = map fst o all_frees

fun all_tfrees (TFree (name, sort)) = [(name, sort)]
| all_tfrees (Type (_, ts)) = fold (union (op =)) (map all_tfrees ts) []
| all_tfrees _ = []

(* printing *)

fun pretty_const ctxt const =
Syntax.pretty_term ctxt (Const (const, Sign.the_const_type (Proof_Context.theory_of ctxt) const))

(* conversion/tactic *)

fun ANY tacs = fold (curry op APPEND) tacs no_tac
fun ANY' tacs n = fold (curry op APPEND) (map (fn t => t n) tacs) no_tac

fun TRY' tac n = TRY (tac n)

fun descend_fun_conv cv =
cv else_conv (fn ct =>
case Thm.term_of ct of
_ $ _ => Conv.fun_conv (descend_fun_conv cv) ct
| _ => Conv.no_conv ct)

fun lhs_conv cv =
cv |> Conv.arg1_conv |> Conv.arg_conv

fun rhs_conv cv =
cv |> Conv.arg_conv |> Conv.arg_conv

safe_mk_meta_eq thm |> Conv.rewr_conv |> descend_fun_conv |> lhs_conv

safe_mk_meta_eq thm |> Conv.rewr_conv |> descend_fun_conv |> rhs_conv

fun conv_result cv ct =
Thm.prop_of (cv ct) |> Logic.dest_equals |> snd

fun changed_conv cv = fn ct =>
let
val res = cv ct
val (lhs, rhs) = Thm.prop_of res |> Logic.dest_equals
in
if lhs aconv rhs then
raise CTERM ("no conversion", [])
else
res
end

fun multi_induct_tac rules insts arbitrary ctxt =
let
val insts' = map (map (SOME o pair NONE o rpair false)) insts
val arbitrary' = map (map dest_Free) arbitrary
in
DETERM (Induct.induct_tac ctxt false insts' arbitrary' [] (SOME rules) [] 1)
end

fun maybe_induct_tac (SOME rules) insts arbitrary = multi_induct_tac rules insts arbitrary
| maybe_induct_tac NONE _ _ = K all_tac

fun (tac CONTINUE_WITH tacs) i st =
st |> (tac i THEN (fn st' =>
let
val n' = Thm.nprems_of st'
val n = Thm.nprems_of st
fun aux [] _ = all_tac
| aux (tac :: tacs) i = tac i THEN aux tacs (i - 1)
in
if n' - n + 1 <> length tacs then
raise THM ("CONTINUE_WITH: unexpected number of emerging subgoals", 0, [st'])
else
aux (rev tacs) (i + n' - n) st'
end))

fun (tac CONTINUE_WITH_FW tacs) i st =
st |> (tac i THEN (fn st' =>
let
val n' = Thm.nprems_of st'
val n = Thm.nprems_of st
fun aux [] _ st = all_tac st
| aux (tac :: tacs) i st = st |>
(tac i THEN (fn st' =>
aux tacs (i + 1 + Thm.nprems_of st' - Thm.nprems_of st) st'))
in
if n' - n + 1 <> length tacs then
raise THM ("unexpected number of emerging subgoals", 0, [st'])
else
aux tacs i st'
end))

fun SOLVED tac = tac THEN ALLGOALS (K no_tac)

fun print_tac' ctxt str = SELECT_GOAL (print_tac ctxt str)

fun fo_cong_tac ctxt thm = SUBGOAL (fn (concl, i) =>
let
val lhs_of = HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst
val concl_pat = lhs_of (Thm.concl_of thm) |> Thm.cterm_of ctxt
val concl = lhs_of concl |> Thm.cterm_of ctxt

val insts = Thm.first_order_match (concl_pat, concl)
in
resolve_tac ctxt [Drule.instantiate_normalize insts thm] i
end handle Pattern.MATCH => no_tac)

(* theorem manipulation *)

fun contract ctxt thm =
let
val (((_, frees), [thm']), ctxt') = Variable.import true [thm] ctxt

val prop = Thm.prop_of thm'
val prems = Logic.strip_imp_prems prop
val (lhs, rhs) =
Logic.strip_imp_concl prop
|> HOLogic.dest_Trueprop
|> HOLogic.dest_eq

fun used x =
exists (exists_subterm (fn t => t = x)) prems

val (f, xs) = strip_comb lhs
val (g, ys) = strip_comb rhs

fun loop [] ys = (0, (f, list_comb (g, rev ys)))
| loop xs [] = (0, (list_comb (f, rev xs), g))
| loop (x :: xs) (y :: ys) =
if x = y andalso is_Free x andalso not (used x) then
loop xs ys |> apfst (fn x => x + 1)
else
(0, (list_comb (f, rev (x :: xs)), list_comb (g, rev (y :: ys))))

val (count, (lhs', rhs')) = loop (rev xs) (rev ys)

val concl' = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'))

fun tac ctxt 0 = resolve_tac ctxt [thm] THEN_ALL_NEW (Method.assm_tac ctxt)
| tac ctxt n = resolve_tac ctxt @{thms ext} THEN' tac ctxt (n - 1)

val prop = prems ===> concl'
in
Goal.prove_future ctxt' [] [] prop (fn {context, ...} => HEADGOAL (tac context count))
|> singleton (Variable.export ctxt' ctxt)
end

fun on_thms_complete f thms =
(Future.fork (fn () => (Thm.consolidate thms; f ())); thms)

(* theory *)

fun define_params_nosyn term =
Specification.definition NONE [] [] ((Binding.empty, []), term)
#>> snd #>> snd

fun note_thm binding thm =
Local_Theory.note ((binding, []), [thm]) #>> snd #>> the_single

fun note_thms binding thms =
Local_Theory.note ((binding, []), thms) #>> snd

(* timing *)

fun with_timeout time f x =
case Exn.interruptible_capture (Timeout.apply time f) x of
Exn.Res y => y
| Exn.Exn (Timeout.TIMEOUT _) => x
| Exn.Exn e => Exn.reraise e

(* debugging *)

val debug = Attrib.setup_config_bool @{binding "dict_construction_debug"} (K false)

fun if_debug ctxt f =
if Config.get ctxt debug then f () else ()

fun ALLGOALS' ctxt = if Config.get ctxt debug then ALLGOALS else PARALLEL_ALLGOALS
fun prove' ctxt = if Config.get ctxt debug then Goal.prove ctxt else Goal.prove_future ctxt
fun prove_common' ctxt = Goal.prove_common ctxt (if Config.get ctxt debug then NONE else SOME ~1)

end
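Two of the utilities above are easier to follow with concrete inputs. The following standalone sketch is plain Standard ML, not Isabelle/ML. It re-implements `unappend` and `flat_right`, with a local `chop` standing in for the Isabelle/ML library function. It also models the syntactic core of `contract` on a hypothetical toy `term` type. Identical trailing `Free` arguments that are not used elsewhere (in the real function: not occurring in the premises) are dropped from both sides of an equation. The returned count is the number of `ext` steps the real function uses to re-prove the contracted equation. The toy datatype, the `used` predicate, `contract_eq` and the example values are assumptions for illustration only.

```sml
(* Plain Standard ML sketch; not Isabelle/ML. *)

(* Stand-in for Isabelle/ML's chop: split off the first n elements. *)
fun chop 0 xs = ([], xs)
  | chop _ [] = ([], [])
  | chop n (x :: xs) =
      let val (ys, zs) = chop (n - 1) xs in (x :: ys, zs) end

(* Given the pair (xs, ys) whose concatenation produced a list, split that list
   (or one of the same shape) back into parts of the original lengths. *)
fun unappend (xs, _) = chop (length xs)

(* Distribute each key over its list of values. *)
fun flat_right [] = []
  | flat_right ((x, ys) :: rest) = map (fn y => (x, y)) ys @ flat_right rest

(* Hypothetical toy term type standing in for Isabelle's term. *)
datatype term = Free of string | Const of string | App of term * term

fun strip_comb t =
  let
    fun go (App (f, x)) args = go f (x :: args)
      | go h args = (h, args)
  in go t [] end

fun list_comb (h, args) = List.foldl (fn (x, acc) => App (acc, x)) h args

fun is_Free (Free _) = true
  | is_Free _ = false

(* Syntactic core of contract: strip identical trailing Free arguments that are
   not "used"; the count is the number of ext steps needed for the proof. *)
fun contract_eq used (lhs, rhs) =
  let
    val (f, xs) = strip_comb lhs
    val (g, ys) = strip_comb rhs
    fun loop [] ys = (0, (f, list_comb (g, rev ys)))
      | loop xs [] = (0, (list_comb (f, rev xs), g))
      | loop (x :: xs) (y :: ys) =
          if x = y andalso is_Free x andalso not (used x) then
            let val (n, res) = loop xs ys in (n + 1, res) end
          else
            (0, (list_comb (f, rev (x :: xs)), list_comb (g, rev (y :: ys))))
  in loop (rev xs) (rev ys) end

val split = unappend ([1, 2], [3]) ["a", "b", "c"]
(* (["a", "b"], ["c"]) *)

val pairs = flat_right [("f", [1, 2]), ("g", [3])]
(* [("f", 1), ("f", 2), ("g", 3)] *)

(* f a x = g a x, where a is treated as used (e.g. occurs in a premise) *)
val eq = (App (App (Const "f", Free "a"), Free "x"),
          App (App (Const "g", Free "a"), Free "x"))
val contracted = contract_eq (fn t => t = Free "a") eq
(* (1, (App (Const "f", Free "a"), App (Const "g", Free "a"))) *)
```

Only `x` is dropped. `a` is kept because contracting it would lose the connection to the premise that mentions it, which is exactly why the real `contract` checks `used` against the premises.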

# File ‹transfer_termination.ML›

signature TRANSFER_TERMINATION = sig
val termination_tac: Function.info -> Function.info -> Proof.context -> int -> tactic
end

structure Transfer_Termination : TRANSFER_TERMINATION = struct

open Dict_Construction_Util

fun termination_tac ({R = new_R, ...}: Function.info) (old_info: Function.info) ctxt =
let
fun fallback_tac warn _ =
(if warn then warning "Falling back to another termination proof" else ();
Seq.empty)

(*copied from BNF_Comp, in turn copied from Envir.expand_term_free*)
fun expand_term_const defs =
let
val eqs = map ((fn ((x, U), u) => (x, (U, u))) o apfst dest_Const) defs;
val get = fn Const (x, _) => AList.lookup (op =) eqs x | _ => NONE;
in Envir.expand_term get end

fun map_comp_bnf typ =
let
(* we start from a fresh lthy to avoid local hyps interfering with BNF *)
val lthy =
Proof_Context.theory_of ctxt
|> Named_Target.theory_init
|> Config.put BNF_Comp.typedef_threshold ~1
(* we just pretend that they're all live here *)
val live_As = all_tfrees typ
fun flatten_tyargs Ass =
live_As
|> filter (fn T => exists (fn Ts => member (op =) Ts T) Ass)
(* Dont_Inline would create new definitions, always *)
val ((bnf, _), ((_, {map_unfolds, ...}), _)) =
BNF_Comp.bnf_of_typ false BNF_Def.Do_Inline I flatten_tyargs live_As [] typ
((BNF_Comp.empty_comp_cache, BNF_Comp.empty_unfolds), lthy)
val subst = map (Logic.dest_equals o Thm.prop_of) map_unfolds
val t = BNF_Def.map_of_bnf bnf
in
(live_As, expand_term_const subst t)
end

val tac = case old_info of
{R = old_R, totality = SOME totality, ...} =>
let
fun get_eq R =
Inductive.the_inductive ctxt R
|> snd |> #eqs
|> the_single
|> Local_Defs.abs_def_rule ctxt
val (old_R_eq, new_R_eq) = apply2 get_eq (old_R, new_R)

fun get_typ R =
fastype_of R
|> strip_type
|> fst |> hd
|> Type.legacy_freeze_type
val (old_R_typ, new_R_typ) = apply2 get_typ (old_R, new_R)

(* simple strategy: old_R and new_R are identical *)
val simple_tac =
let
val totality' = Local_Defs.unfold ctxt [old_R_eq] totality
in
Local_Defs.unfold_tac ctxt [new_R_eq] THEN
end

(* smart strategy: new_R can be simulated by old_R *)
(* FIXME this is trigger-happy *)
val smart_tac = Exn.interruptible_capture (fn st =>
let
val old_R_stripped =
Thm.prop_of old_R_eq
|> Logic.dest_equals |> snd
|> map_types (K dummyT)
|> Syntax.check_term ctxt

val futile =
old_R_stripped |> exists_type (exists_subtype
(fn TFree (_, sort) => sort <> @{sort type}
| TVar (_, sort) => sort <> @{sort type}
| _ => false))

fun costrip_prodT new_t old_t =
if Type.could_match (old_t, new_t) then
(0, new_t)
else
case costrip_prodT (snd (HOLogic.dest_prodT new_t)) old_t of
(n, stripped_t) => (n + 1, stripped_t)

fun construct_inner_proj new_t old_t =
let
val (diff, stripped_t) = costrip_prodT new_t old_t
val (tfrees, f_head) = map_comp_bnf stripped_t
val f_args = map (K (Abs ("x", dummyT, Const (@{const_name undefined}, dummyT)))) tfrees
| add_snd n = Const (@{const_name comp}, dummyT) $ add_snd (n - 1) $ Const (@{const_name snd}, dummyT)
in
end

fun construct_outer_proj new_t old_t = case (new_t, old_t) of
(Type (@{type_name sum}, new_ts), Type (@{type_name sum}, old_ts)) =>
let
val ps = map2 construct_outer_proj new_ts old_ts
in list_comb (Const (@{const_name map_sum}, dummyT), ps) end
| _ => construct_inner_proj new_t old_t

val outer_proj = construct_outer_proj new_R_typ old_R_typ

val old_R_typ_imported =
yield_singleton Variable.importT_terms old_R ctxt
|> fst |> get_typ

val c =
outer_proj
|> Type.constraint (new_R_typ --> old_R_typ_imported)
|> Syntax.check_term ctxt
|> Thm.cterm_of ctxt

val wf_simulate =
Drule.infer_instantiate ctxt [(("g", 0), c)] @{thm wf_simulate_simple}

val old_wf = (@{thm wfP_implies_wf_set_of} OF [@{thm accp_wfPI} OF [totality]])

val inner_tac =
match_tac ctxt @{thms wf_implies_dom} THEN'
match_tac ctxt [wf_simulate] CONTINUE_WITH_FW
[resolve_tac ctxt [old_wf],
match_tac ctxt @{thms set_ofI} THEN'
dmatch_tac ctxt @{thms set_ofD} THEN'
SELECT_GOAL (Local_Defs.unfold_tac ctxt [old_R_eq, new_R_eq]) THEN'
TRY'
(REPEAT_ALL_NEW (ematch_tac ctxt @{thms conjE exE}) THEN'
hyp_subst_tac_thin true ctxt THEN'
REPEAT_ALL_NEW (match_tac ctxt @{thms conjI exI}))]

val unfold_tac =
Local_Defs.unfold_tac ctxt @{thms comp_apply id_apply prod.sel} THEN
auto_tac ctxt

val tac = SOLVED (HEADGOAL inner_tac THEN unfold_tac)
in
if futile then
(warning "Termination relation has sort constraints; termination proof is unlikely to be automatic or may even be impossible";
Seq.empty)
else
(tracing "Trying to re-use termination proof";
tac st)
end)
#> Exn.get_res
#> the_default Seq.empty
in
simple_tac ORELSE smart_tac ORELSE fallback_tac true
end
| _ => fallback_tac false
in SELECT_GOAL tac end

end

# File ‹congruences.ML›

signature CONGRUENCES = sig
type rule =
{rule: thm,
concl: term,
prems: term list,
proper: bool}

type ctx = (string * typ) list * term list
datatype ctx_tree =
Tree of (term * (rule * (ctx * ctx_tree) list) option)

val export_term_ctx: ctx -> term -> term

val import_rule: Proof.context -> thm -> rule
val import_term: Proof.context -> rule list -> term -> ctx_tree

val fold_tree:
(term -> 'a) ->
(term -> rule -> (ctx * 'a) list -> 'a) ->
ctx_tree -> 'a
end

structure Congruences: CONGRUENCES = struct

type rule =
{rule: thm,
concl: term,
prems: term list,
proper: bool}

type ctx = (string * typ) list * term list

fun export_term_ctx (fixes, assumes) =
fold_rev (curry Logic.mk_implies) assumes
#> fold_rev (Logic.all o Free) fixes

datatype ctx_tree =
Tree of (term * (rule * (ctx * ctx_tree) list) option)

fun fold_tree atom cong t =
let
fun go (Tree (t, NONE)) = atom t
| go (Tree (t, SOME (r, ctxs))) = cong t r (map (apsnd go) ctxs)
in go t end

fun raw_import_rule {check: bool, proper: bool} ctxt thm =
let
val concl = Thm.concl_of thm
val (lhs, rhs) = Logic.dest_equals concl
val prems = Thm.prems_of thm
val rule = {rule = thm, concl = concl, prems = prems, proper = proper}
in
if check then
let
val ((f_lhs, _), _) = strip_comb lhs |>> dest_Const
val ((r_lhs, _), _) = strip_comb rhs |>> dest_Const
in
if f_lhs <> r_lhs then
error ("invalid cong rule " ^ Syntax.string_of_term ctxt (Thm.prop_of thm))
else
rule
end
else
rule
end

val import_rule = raw_import_rule {check = true, proper = true}

fun mk_cong n = if n <= 1 then @{thm cong} else mk_cong (n - 1) OF [@{thm cong}]

fun cong_rule n =
raw_import_rule {check = false, proper = false} @{context} (mk_cong n RS @{thm eq_reflection})

val ext_rule =
raw_import_rule {check = false, proper = false} @{context} (@{thm ext} RS @{thm eq_reflection})

fun import_term ctxt rules t =
let
val thy = Proof_Context.theory_of ctxt

val lhs_of = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop

fun go ctxt t =
let
(* FIXME eventually, this should be the arity of fst (strip_comb t) *)
val arity = length (snd (strip_comb t))

val rules = rules @ [cong_rule arity, ext_rule]

fun mk_branch subst t =
let
val ((params, impl), ctxt') = Variable.focus NONE t ctxt
val (assms, concl) = Logic.strip_horn impl
val assms = map (Envir.subst_term subst) assms
in
((map #2 params, assms), (ctxt', lhs_of concl))
end

fun match concl =
try (Pattern.match thy (concl, Logic.mk_equals (t, t))) (Vartab.empty, Vartab.empty)

fun apply_rule (rule as {concl, prems, ...}) =
case match concl of
NONE => NONE
| SOME subst =>
SOME (rule, map (mk_branch subst o Envir.beta_norm o Envir.subst_term subst) prems)

val is_atom = is_Const t orelse is_Free t
in
if is_atom then
Tree (t, NONE)
else
case get_first apply_rule rules of
NONE => error "this shouldn't happen"
| SOME (rule, branches) => Tree (t, SOME (rule, map (apsnd (uncurry go)) branches))
end
in
go ctxt t
end

end

# File ‹side_conditions.ML›

signature SIDE_CONDITIONS = sig
type predicate =
{f: term,
index: int,
inductive: Inductive.result,
alt: thm option}

val transform_predicate: morphism -> predicate -> predicate
val get_predicate: Proof.context -> term -> predicate option
val set_alt: term -> thm -> Context.generic -> Context.generic
val is_total: Proof.context -> term -> bool

val mk_side: thm list -> thm list option -> local_theory -> predicate list * local_theory

val time_limit: real Config.T
end

structure Side_Conditions : SIDE_CONDITIONS = struct

open Dict_Construction_Util

val time_limit = Attrib.setup_config_real @{binding side_conditions_time_limit} (K 5.0)

val inductive_config =
{quiet_mode = true, verbose = true, alt_name = Binding.empty, coind = false,
no_elim = false, no_ind = false, skip_mono = false}

type predicate =
{f: term,
index: int,
inductive: Inductive.result,
alt: thm option}

fun transform_predicate phi {f, index, inductive, alt} =
{f = Morphism.term phi f,
index = index,
inductive = Inductive.transform_result phi inductive,
alt = Option.map (Morphism.thm phi) alt}

structure Predicates = Generic_Data
(
type T = predicate Item_Net.T
val empty = Item_Net.init (op aconv o apply2 #f) (single o #f)
val merge = Item_Net.merge
val extend = I
)

fun get_predicate ctxt t =
Item_Net.retrieve (Predicates.get (Context.Proof ctxt)) t
|> try hd
|> Option.map (transform_predicate (Morphism.transfer_morphism (Proof_Context.theory_of ctxt)))

fun is_total ctxt t =
let
val SOME {alt = SOME alt, ...} = get_predicate ctxt t
val (_, rhs) = Logic.dest_equals (Thm.prop_of alt)
in rhs = @{term True} end

(* must be of the form [f_side ?x ?y = True] *)
fun set_alt t thm context =
let
val thm = safe_mk_meta_eq thm
val (lhs, _) = Logic.dest_equals (Thm.prop_of thm)
val {f, index, inductive, ...} = hd (Item_Net.retrieve (Predicates.get context) t)
val pred = nth (#preds inductive) index
val (arg_typs, _) = strip_type (fastype_of pred)
val args =
Name.invent_names (Variable.names_of (Context.proof_of context)) "x" arg_typs
|> map Free
val new_pred = {f = f, index = index, inductive = inductive, alt = SOME thm}
in
if Pattern.matches (Context.theory_of context) (lhs, list_comb (pred, args)) then
Predicates.map (Item_Net.update new_pred) context
else
error "Alternative is not fully general"
end

fun apply_simps ctxt clear thms t =
let
val ctxt' =
Context_Position.not_really ctxt
|> clear ? put_simpset HOL_ss
in conv_result (Simplifier.asm_full_rewrite (ctxt' addsimps thms)) t end

fun apply_alts ctxt =
Item_Net.content (Predicates.get (Context.Proof ctxt))
|> map #alt
|> cat_options
|> apply_simps ctxt true

fun apply_intros ctxt =
Item_Net.content (Predicates.get (Context.Proof ctxt))
|> map #inductive
|> maps #intrs
|> apply_simps ctxt false

fun dest_head (Free (name, typ)) = (name, typ)
| dest_head (Const (name, typ)) = (Long_Name.base_name name, typ)

val sideN = "_side"

fun mk_side simps inducts lthy =
let
val thy = Proof_Context.theory_of lthy

val ((_, simps), names) =
Variable.import true simps lthy
||> Variable.names_of

val (lhss, rhss) =
map (HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) simps
|> split_list

fun mk_typ t = binder_types t ---> @{typ bool}

val sides = map (apfst (suffix sideN) o apsnd mk_typ o fst) heads

fun mk_pred_app pred (f, xs) =
let
val pred_typs = binder_types (fastype_of pred)
val exp_param_count = length pred_typs

val f_typs = take exp_param_count (binder_types (fastype_of f))

val pred' =
Envir.subst_term_types (fold (Sign.typ_match thy) (pred_typs ~~ f_typs) Vartab.empty) pred

val diff = exp_param_count - length xs
in
if diff > 0 then
let
val bounds = map Bound (0 upto diff - 1)
val alls = map (K ("x", dummyT)) (0 upto diff - 1)
val prop = Logic.list_all (alls, HOLogic.mk_Trueprop (list_comb (pred', xs @ bounds)))
in
prop (* fishy *)
end
else
HOLogic.mk_Trueprop (list_comb (pred', take exp_param_count xs))
end

fun mk_cond f xs =
if is_Abs f then (* do not look this up in the Item_Net, it'll only end in tears *)
NONE
else
case get_predicate lthy f of
NONE =>
(case find_index (equal f o snd) heads of
~1 => NONE (* in this case we don't know anything about f; it may be a constructor *)
| index => SOME (mk_pred_app (Free (nth sides index)) (f, xs)))
| SOME {index, inductive = {preds, ...}, ...} =>
SOME (mk_pred_app (nth preds index) (f, xs))

fun mk_atom f =
(* in this branch, if f has a non-const-true predicate, it is most likely that there is a
missing congruence rule *)
the_list (mk_cond f [])

fun mk_cong t _ cs =
let
val cs' = maps (fn (ctx, ts) => map (Congruences.export_term_ctx ctx) ts) (tl cs)
val (f, xs) = strip_comb t
val cs = mk_cond f xs
in
the_list cs @ cs'
end

val rules = map (Congruences.import_rule lthy) (Function.get_congs lthy)
val premss =
map (Congruences.import_term lthy rules) rhss
|> map (Congruences.fold_tree mk_atom mk_cong)

val concls =
map Free sides ~~ map (snd o strip_comb) lhss
|> map (HOLogic.mk_Trueprop o list_comb)

val time = Time.fromReal (Config.get lthy time_limit)

val intros =
map Logic.list_implies (premss ~~ concls)
|> Syntax.check_terms lthy
|> map (apply_alts lthy o Thm.cterm_of lthy)
|> Par_List.map (with_timeout time (apply_intros lthy o Thm.cterm_of lthy))

val inds = map (rpair NoSyn o apfst Binding.name) (distinct op = sides)
val (result, lthy') =
(map (pair (Binding.empty, [])) intros) [] lthy

fun mk_impartial_goal pred names =
let
val param_typs = binder_types (fastype_of pred)
val (args, names) = fold_map (fn typ => apfst (Free o rpair typ) o Name.variant "x") param_typs names
val goal = HOLogic.mk_Trueprop (list_comb (pred, args))
in ((goal, args), names) end

val ((props, instss), _) =
fold_map mk_impartial_goal (#preds result) names
|>> split_list
val frees = flat instss |> map (fst o dest_Free)

fun tactic {context = ctxt, ...} =
let
val simp_context =
put_simpset HOL_ss (Context_Position.not_really ctxt) addsimps (#intrs result)
in
maybe_induct_tac inducts instss [] ctxt THEN
PARALLEL_ALLGOALS (Nitpick_Util.DETERM_TIMEOUT time o asm_full_simp_tac simp_context)
end

val alts =
try (Goal.prove_common lthy' NONE frees [] props) tactic
|> Option.map (map (mk_eq o Thm.close_derivation ⌂))

val _ =
if is_none alts then
Pretty.str "Potentially underspecified function(s): " ::
Pretty.commas (map (Syntax.pretty_term lthy o snd) (distinct op = heads))
|> Pretty.block
|> Pretty.string_of
|> warning
else
()

fun mk_pred n t =
{f = t, index = n, inductive = result,
alt = Option.map (fn alts => nth alts n) alts}

val preds = map_index (fn (n, (_, t)) => mk_pred n t) (distinct op = heads)

val lthy'' =
Local_Theory.declaration {pervasive = false, syntax = false}
(fn phi => fold (Predicates.map o Item_Net.update o transform_predicate phi) preds) lthy'
in
(preds, lthy'')
end

end

# File ‹class_graph.ML›

signature CLASS_GRAPH =
sig
type selector = typ -> term

type node =
{class: string,
qname: string,
selectors: selector Symtab.table,
make: typ -> term,
data_thms: thm list,
cert: typ -> term,
cert_thms: thm * thm * thm list}

val dict_typ: node -> typ -> typ

type edge =
{super_selector: selector,
subclass: thm}

type path = edge list

type ev

val class_of: ev -> class
val node_of: ev -> node
val parents_of: ev -> (edge * ev) Symtab.table

val find_path': ev -> (ev -> 'a option) -> (path * 'a) option
val find_path: ev -> class -> path option
val fold_path: path -> typ -> term -> term

val ensure_class: class -> local_theory -> (ev * local_theory)

val edges: local_theory -> class -> edge Symtab.table option
val node: local_theory -> class -> node option
val all_edges: local_theory -> edge Symreltab.table
val all_nodes: local_theory -> node Symtab.table

val pretty_ev: Proof.context -> ev -> Pretty.T

(* utilities *)
val mangle: string -> string
val param_sorts: string -> class -> theory -> class list list
val super_classes: class -> theory -> string list
end

structure Class_Graph: CLASS_GRAPH =
struct

open Dict_Construction_Util

val mangle =
translate_string (fn x =>
if x = "." then
"_"
else if x = "_" then
"__"
else
x)
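(* Added illustration, not part of the original construction: mangle rewrites a name
character by character, replacing each "." by "_" and each "_" by "__". The load-time check
below (using the Isabelle/ML @{assert} antiquotation) only documents this behaviour on one
sample name; it makes no claim that the mapping is injective in general. *)
val _ = @{assert} (mangle "Groups.plus_class" = "Groups_plus__class")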

fun param_sorts tyco class thy =
let val algebra = Sign.classes_of thy in
Sorts.mg_domain algebra tyco [class] |> map (filter (Class.is_class thy))
end

fun super_classes class thy =
let val algebra = Sign.classes_of thy in
Sorts.super_classes algebra class |>
Sorts.minimize_sort algebra |>
filter (Class.is_class thy) |>
sort fast_string_ord
end

type selector = typ -> term

type node =
{class: string,
qname: string,
selectors: selector Symtab.table,
make: typ -> term,
data_thms: thm list,
cert: typ -> term,
cert_thms: thm * thm * thm list}

type edge =
{super_selector: selector,
subclass: thm}

type path = edge list

abstype ev = Evidence of class * node * (edge * ev) Symtab.table
with

fun class_of (Evidence (class, _, _)) = class
fun node_of (Evidence (_, node, _)) = node
fun parents_of (Evidence (_, _, tab)) = tab

fun mk_evidence class node tab = Evidence (class, node, tab)

fun find_path' ev is_goal =
case is_goal ev of
SOME a =>
SOME ([], a)
| NONE =>
let
fun f (_, (edge, ev)) = Option.map (apfst (cons edge)) (find_path' ev is_goal)
in Symtab.get_first f (parents_of ev) end

fun find_path ev goal =
find_path' ev (fn ev => if class_of ev = goal then SOME () else NONE) |> Option.map fst

fun pretty_ev ctxt (Evidence (class, {qname, ...}, tab)) =
let
val typ = @{typ 'a}
fun mk_super ({super_selector, ...}, super_ev) = Pretty.block
[Pretty.str "selector:",
Pretty.brk 1,
Syntax.pretty_term ctxt (super_selector typ),
Pretty.fbrk,
pretty_ev ctxt super_ev]
val supers = Symtab.dest tab
|> map (fn (_, super) => mk_super super)
|> Pretty.big_list "super classes"
in
Pretty.block
[Pretty.str "Evidence for ",
Syntax.pretty_sort ctxt [class],
Pretty.str ": ",
Syntax.pretty_typ ctxt (Type (qname, [typ])),
Pretty.str (" (qname = " ^ qname ^ ")"),
Pretty.fbrk,
supers]
end

end

structure Classes = Generic_Data
(
type T = (edge Symtab.table * node) Symtab.table
val empty = Symtab.empty
fun merge (t1, t2) =
if Symtab.is_empty t1 andalso Symtab.is_empty t2 then
Symtab.empty
else
error "merging not supported"
val extend = I
)

fun node lthy class =
Symtab.lookup (Classes.get (Context.Proof lthy)) class |> Option.map snd

fun edges lthy class =
Symtab.lookup (Classes.get (Context.Proof lthy)) class |> Option.map fst

val all_nodes =
Context.Proof #> Classes.get #> Symtab.map (K snd)

val all_edges =
Context.Proof #> Classes.get #> Symtab.map (K fst) #> symreltab_of_symtab

fun dict_typ {qname, ...} typ =
Type (qname, [typ])

fun fold_path path typ =
fold (fn {super_selector = s, ...} => fn acc => s typ $ acc) path

fun mk_super_selector' qualified qname super_ev typ =
let
val {class = super_class, qname = super_qname, ...} = node_of super_ev
val raw_name = mangle super_class ^ "__super"
val name = if qualified then Long_Name.append qname raw_name else raw_name
in (name, Type (qname, [typ]) --> Type (super_qname, [typ])) end

fun mk_node class info super_evs lthy =
let
fun print_info ctxt =
Pretty.block [Pretty.str "Defining record for class ", Syntax.pretty_sort ctxt [class]]
|> Pretty.writeln

val name = mangle class ^ "__dict"
val qname = Local_Theory.full_name lthy (Binding.name name)
val tvar = @{typ 'a}
val typ = Type (qname, [tvar])

fun mk_field name ftyp = (Binding.name name, ftyp)

val params = #params info
|> map (fn (name', ftyp) =>
let
val ftyp' = typ_subst_atomic [(TFree ("'a", [class]), @{typ 'a})] ftyp
val field_name = mangle name' ^ "__field"
val field = mk_field field_name ftyp'
fun sel tvar' =
Const (Long_Name.append qname field_name,
typ_subst_atomic [(tvar, tvar')] (typ --> ftyp'))
in (field, (name', sel)) end)

val (fields, selectors) = split_list params

val super_params = Symtab.dest super_evs
|> map (fn (_, super_ev) =>
let
val {cert = raw_super_cert, qname = super_qname, ...} = node_of super_ev
val (field_name, _) = mk_super_selector' false qname super_ev tvar
val field = mk_field field_name (Type (super_qname, [tvar]))
fun sel typ = Const (mk_super_selector' true qname super_ev typ)
fun super_cert dict = raw_super_cert tvar $ (sel tvar $ dict)
val raw_edge = (class_of super_ev, sel)
in (field, raw_edge, super_cert) end)

val (super_fields, raw_edges, super_certs) = split_list3 super_params

val all_fields = super_fields @ fields

fun make typ' =
Const (Long_Name.append qname "Dict",
typ_subst_atomic [(tvar, typ')] (map #2 all_fields ---> typ))

val cert_name = name ^ "__cert"
val cert_binding = Binding.name cert_name

val cert_body =
let
fun local_param_eq ((_, typ), (name, sel)) dict =
HOLogic.mk_eq (sel tvar $ dict, Const (name, typ))
in
map local_param_eq params @ super_certs
end
val cert_var_name = "dict"
val cert_term =
Abs (cert_var_name, typ,
List.foldr HOLogic.mk_conj @{term True} (map (fn x => x (Bound 0)) cert_body))

fun prove_thms (cert, cert_def) lthy =
let
val var = Free (cert_var_name, typ)
fun tac ctxt = Local_Defs.unfold_tac ctxt [cert_def] THEN blast_tac ctxt 1
fun prove prop =
Goal.prove_future lthy [cert_var_name] [] prop (fn {context, ...} => tac context)

fun mk_dest_props raw_prop =
HOLogic.mk_Trueprop (cert $ var) ==> HOLogic.mk_Trueprop (raw_prop var)

fun mk_intro_cond raw_prop =
HOLogic.mk_Trueprop (raw_prop var)

val dests =
map (fn raw_prop => prove (mk_dest_props raw_prop)) cert_body

val intro =
prove (map mk_intro_cond cert_body ===> HOLogic.mk_Trueprop (cert $ var))

val (dests', (intro', lthy')) =
note_thms Binding.empty dests lthy ||> note_thm Binding.empty intro

val (param_dests, super_dests) = chop (length params) dests'

fun pre_edges phi =
let
fun mk_edge thm (sc, sel) =
(sc, {super_selector = sel, subclass = Morphism.thm phi thm})
in Symtab.make (map2 mk_edge super_dests raw_edges) end
in
((param_dests, pre_edges, intro'), lthy')
end

val constructor =
(((Binding.empty, Binding.name "Dict"), all_fields), NoSyn)
val datatyp =
(([(NONE, (@{typ 'a}, @{sort type}))], Binding.name name), NoSyn)

val dtspec =
(Ctr_Sugar.default_ctr_options,
[(((datatyp, [constructor]), (Binding.empty, Binding.empty, Binding.empty)), [])])

val (((raw_cert, raw_cert_def), (param_dests, pre_edges, intro)), (lthy', lthy)) = lthy
|> tap print_info
|> BNF_FP_Def_Sugar.co_datatypes BNF_Util.Least_FP BNF_LFP.construct_lfp dtspec
(* FIXME ideally BNF would return a fp_sugar value right here so that I can avoid constructing
long names by hand above *)
|> (snd o Local_Theory.begin_nested)
|> Local_Theory.define ((cert_binding, NoSyn), ((Thm.def_binding cert_binding, []), cert_term))
|>> apsnd snd
|> (fn (raw_cert, lthy) => prove_thms raw_cert lthy |>> pair raw_cert)
||> Local_Theory.end_nested

val phi = Proof_Context.export_morphism lthy lthy'
fun cert typ = subst_TVars [(("'a", 0), typ)] (Morphism.term phi raw_cert)
val cert_def = Morphism.thm phi raw_cert_def
val edges = pre_edges phi
val param_dests' = map (Morphism.thm phi) param_dests
val intro' = Morphism.thm phi intro

val data_thms =
BNF_FP_Def_Sugar.fp_sugar_of lthy' qname
|> the |> #fp_ctr_sugar |> #ctr_sugar |> #sel_thmss |> flat
|> map safe_mk_meta_eq

val node =
{class = class,
qname = qname,
selectors = Symtab.make selectors,
make = make,
data_thms = data_thms,
cert = cert,
cert_thms = (cert_def, intro', param_dests')}
in (node, edges, lthy') end

fun ensure_class class lthy =
if not (Class.is_class (Proof_Context.theory_of lthy) class) then
error ("not a proper class: " ^ class)
else
let
val thy = Proof_Context.theory_of lthy
val super_classes = super_classes class thy
fun collect_super mk_node =
let
val (super_evs, lthy') = fold_map ensure_class super_classes lthy
val raw_tab = Symtab.make (super_classes ~~ super_evs)
val (node, edges, lthy'') = mk_node raw_tab lthy'
val tab = zip_symtabs pair edges raw_tab
val ev = mk_evidence class node tab
in (ev, edges, lthy'') end
in
case Symtab.lookup (Classes.get (Context.Proof lthy)) class of
SOME (edge_tab, node) =>
if super_classes = Symtab.keys edge_tab then
let val (ev, _, lthy') = collect_super (fn _ => fn lthy => (node, edge_tab, lthy)) in
(ev, lthy')
end
else
(* This happens when a new subclass relationship is established which subsumes or
augments previous superclasses. *)
error "class with different super classes"
| NONE =>
let
val ax_info = Axclass.get_info thy class
val (ev, edges, lthy') = collect_super (mk_node class ax_info)
val upd = Symtab.update_new (class, (edges, node_of ev))
in
(ev, Local_Theory.declaration {pervasive = false, syntax = false} (K (Classes.map upd)) lthy')
end
end

end


# File ‹dict_construction.ML›

signature DICT_CONSTRUCTION =
sig
datatype cert_proof = Cert | Skip

type const

type 'a sccs = (string * 'a) list list

val annotate_code_eqs: local_theory -> string list -> (const sccs * local_theory)
val new_names: local_theory -> const sccs -> (string * const) sccs
val symtab_of_sccs: 'a sccs -> 'a Symtab.table

val axclass: class -> local_theory -> Class_Graph.node * local_theory
val instance: (string * const) Symtab.table -> string -> class -> local_theory -> term * local_theory
val term: term Symreltab.table -> (string * const) Symtab.table -> term -> local_theory -> (term * local_theory)
val consts: (string * const) Symtab.table -> cert_proof -> (string * const) list -> local_theory -> local_theory

(* certification *)

type const_info =
{fun_info: Function.info option,
inducts: thm list option,
base_thms: thm list,
base_certs: thm list,
simps: thm list,
code_thms: thm list, (* old defining theorems *)
congs: thm list option}

type fun_target = (string * class) list * (term * term)

type dict_thms =
{base_thms: thm list,
def_thm: thm}

type dict_target = (string * class) list * (term * string * class)

val prove_fun_cert: fun_target list -> const_info -> cert_proof -> local_theory -> thm list
val prove_dict_cert: dict_target -> dict_thms -> local_theory -> thm

val the_info: Proof.context -> string -> const_info

(* utilities *)

val normalizer_conv: Proof.context -> conv
val cong_of_const: Proof.context -> string -> thm option
val get_code_eqs: Proof.context -> string -> thm list
val group_code_eqs: Proof.context -> string list ->
(string * (((string * sort) list * typ) * ((term list * term) * thm option) list)) list list
end

structure Dict_Construction: DICT_CONSTRUCTION =
struct

open Class_Graph
open Dict_Construction_Util

(* FIXME copied from skip_proof.ML *)

val (_, make_thm_cterm) =
Context.>>>

fun make_thm ctxt prop = make_thm_cterm (Thm.cterm_of ctxt prop)

fun cheat_tac ctxt i st =
resolve_tac ctxt [make_thm ctxt (Var (("A", 0), propT))] i st

(** utilities **)

fun cong_of_const ctxt name =
let
val head =
Thm.concl_of
#> Logic.dest_equals #> fst
#> strip_comb #> fst
#> dest_Const #> fst
fun applicable thm =
try head thm = SOME name
in
Function_Context_Tree.get_function_congs ctxt
|> filter applicable
|> try hd
end

fun group_code_eqs ctxt consts =
let
val thy = Proof_Context.theory_of ctxt
val graph = #eqngr (Code_Preproc.obtain true { ctxt = ctxt, consts = consts, terms = [] })

fun mk_eqs name = name
|> Code_Preproc.cert graph
|> Code.equations_of_cert thy ||> these
||> map (apsnd fst o apfst (apsnd fst o apfst (map fst)))
|> pair name
in
map (map mk_eqs) (rev (Graph.strong_conn graph))
end

fun get_code_eqs ctxt const =
AList.lookup op = (flat (group_code_eqs ctxt [const])) const
|> the |> snd
|> map snd
|> cat_options
|> map (Conv.fconv_rule (normalizer_conv ctxt))

(** certification **)

datatype cert_proof = Cert | Skip

type const_info =
{fun_info: Function.info option,
inducts: thm list option,
base_thms: thm list,
base_certs: thm list,
simps: thm list,
code_thms: thm list,
congs: thm list option}

fun map_const_info f1 f2 f3 f4 f5 f6 f7 {fun_info, inducts, base_thms, base_certs, simps, code_thms, congs} =
{fun_info = f1 fun_info,
inducts = f2 inducts,
base_thms = f3 base_thms,
base_certs = f4 base_certs,
simps = f5 simps,
code_thms = f6 code_thms,
congs = f7 congs}

fun morph_const_info phi =
map_const_info
(Option.map (Function_Common.transform_function_data phi))
(Option.map (map (Morphism.thm phi)))
(map (Morphism.thm phi))
(map (Morphism.thm phi))
(map (Morphism.thm phi))
I (* sic *)
(Option.map (map (Morphism.thm phi)))

type fun_target = (string * class) list * (term * term)

type dict_thms =
{base_thms: thm list,
def_thm: thm}

type dict_target = (string * class) list * (term * string * class)

fun fun_cert_tac base_thms base_certs simps code_thms =
SOLVED' o Subgoal.FOCUS (fn {prems, context = ctxt, concl, ...} =>
let
val _ =
if_debug ctxt (fn () =>
tracing ("Proving " ^ Syntax.string_of_term ctxt (Thm.term_of concl)))

fun is_ih prem =
Thm.prop_of prem |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> can HOLogic.dest_eq

val (ihs, certs) = partition is_ih prems
val super_certs = all_edges ctxt |> Symreltab.dest |> map (#subclass o snd)
val param_dests = all_nodes ctxt |> Symtab.dest |> maps (#3 o #cert_thms o snd)
val congs = Function_Context_Tree.get_function_congs ctxt @ map safe_mk_meta_eq @{thms cong}

val simp_context = (clear_simpset ctxt) addsimps (certs @ super_certs @ base_certs @ base_thms @ param_dests)

val ihs = map (Simplifier.asm_full_simplify simp_context) ihs

val ih_tac =
resolve_tac ctxt ihs THEN_ALL_NEW
(TRY' (SOLVED' (Simplifier.asm_full_simp_tac simp_context)))

val unfold_new =
ANY' (map (CONVERSION o rewr_lhs_head_conv) simps)

val normalize =
CONVERSION (normalizer_conv ctxt)

val unfold_old =
ANY' (map (CONVERSION o rewr_rhs_head_conv) code_thms)

val simp =
CONVERSION (lhs_conv (Simplifier.asm_full_rewrite simp_context))

fun walk_congs i = i |>
((resolve_tac ctxt @{thms refl} ORELSE'
SOLVED' (Simplifier.asm_full_simp_tac simp_context) ORELSE'
ih_tac ORELSE'
Method.assm_tac ctxt ORELSE'
(resolve_tac ctxt @{thms meta_eq_to_obj_eq} THEN'
fo_resolve_tac congs ctxt)) THEN_ALL_NEW walk_congs)

val tacs =
[unfold_new, normalize, unfold_old, simp, walk_congs]
in
EVERY' tacs 1
end)

fun dict_cert_tac class def_thm base_thms =
SOLVED' o Subgoal.FOCUS (fn {prems, context = ctxt, ...} =>
let
val (intro, sels) = case node ctxt class of
SOME {cert_thms = (_, intro, _), data_thms = sels, ...} => (intro, sels)
| NONE => error ("class " ^ class ^ " is not defined")

val apply_intro =
resolve_tac ctxt [intro]

val unfold_dict =
CONVERSION (Conv.rewr_conv def_thm |> Conv.arg_conv |> lhs_conv)

val normalize =
CONVERSION (normalizer_conv ctxt)

val smash_sels =
CONVERSION (lhs_conv (Conv.rewrs_conv sels))

val solve =
resolve_tac ctxt (@{thm HOL.refl} :: base_thms)

val finally =
resolve_tac ctxt prems

val tacs =
[apply_intro, unfold_dict, normalize, smash_sels, solve, finally]
in
EVERY (map (ALLGOALS' ctxt) tacs)
end)

fun prepare_dicts classes names lthy =
let
val sorts = Symtab.make_list classes

fun mk_dicts (param_name, (tvar, class)) =
case node lthy class of
NONE =>
error ("unknown class " ^ class)
| SOME {cert, qname, ...} =>
let
val sort = the (Symtab.lookup sorts tvar)
val param = Free (param_name, Type (qname, [TFree (tvar, sort)]))
in
(param, HOLogic.mk_Trueprop (cert dummyT $ param))
end

val dict_names = Name.invent_names names "a" classes
val names = fold Name.declare (map fst dict_names) names
val (dict_params, prems) = split_list (map mk_dicts dict_names)
in (dict_params, prems, names) end

fun prepare_fun_goal targets lthy =
let
fun mk_eq (classes, (lhs, rhs)) names =
let
val (lhs_name, _) = dest_Const lhs
val (rhs_name, rhs_typ) = dest_Const rhs

val (dict_params, prems, names) = prepare_dicts classes names lthy

val param_names = fst (strip_type rhs_typ) |> map (K dummyT) |> Name.invent_names names "a"
val names = fold Name.declare (map fst param_names) names
val params = map Free param_names

val lhs = list_comb (Const (lhs_name, dummyT), dict_params @ params)
val rhs = list_comb (Const (rhs_name, dummyT), params)

val eq = Const (@{const_name HOL.eq}, dummyT) $ lhs $ rhs

val all_params = dict_params @ params
val eq :: rest = Syntax.check_terms lthy (eq :: prems @ all_params)
val (prems, all_params) = unappend (prems, all_params) rest

val eq =
if is_some (Axclass.inst_of_param (Proof_Context.theory_of lthy) rhs_name) then
Thm.cterm_of lthy eq |> conv_result (Conv.arg_conv (normalizer_conv lthy))
else
eq

val prop = prems ===> HOLogic.mk_Trueprop eq
in ((all_params, prop), names) end
in
fold_map mk_eq targets Name.context
|> fst
|> split_list
end

fun prepare_dict_goal (classes, (term, _, class)) lthy =
let
val cert = case node lthy class of
NONE =>
error ("unknown class " ^ class)
| SOME {cert, ...} =>
cert dummyT

val names = Name.context
val (dict_params, prems, _) = prepare_dicts classes names lthy

val (term_name, _) = dest_Const term
val dict = list_comb (Const (term_name, dummyT), dict_params)

val prop = prems ===> HOLogic.mk_Trueprop (cert $ dict)
val prop :: dict_params = Syntax.check_terms lthy (prop :: dict_params)
in
(dict_params, prop)
end

fun prove_fun_cert targets {inducts, base_thms, base_certs, simps, code_thms, ...} proof lthy =
let
(* The props contain dictionary certs as premises. We can't exclude them from the induction
because the dicts are part of the function definition; excluding them would make applying
the induction rules tricky or impossible. A proper fix would be for fun, akin to inductive,
to support a "for" clause that marks parameters as "not changing". *)
val (argss, props) = prepare_fun_goal targets lthy
val frees = flat argss |> map (fst o dest_Free)

(* We first prove the extensional variant (which is easier to prove) and then derive the
contracted variant. abs_def can't deal with premises, so we use our own version here. *)
val tac =
case proof of
Cert => fun_cert_tac base_thms base_certs simps code_thms
| Skip => cheat_tac

val long_thms =
prove_common' lthy frees [] props (fn {context, ...} =>
maybe_induct_tac inducts argss [] context THEN
ALLGOALS' context (tac context))
in
map (contract lthy) long_thms
end

fun prove_dict_cert target {base_thms, def_thm} lthy =
let
val (args, prop) = prepare_dict_goal target lthy
val frees = map (fst o dest_Free) args
val (_, (_, _, class)) = target
in
prove' lthy frees [] prop (fn {context, ...} =>
dict_cert_tac class def_thm base_thms context 1)
end

(** background data **)

type definitions =
{instantiations: (term * thm) Symreltab.table, (* key: (class, tyco) *)
constants: (string * (thm option * const_info)) Symtab.table (* key: constant name *) }

structure Definitions = Generic_Data
(
type T = definitions
val empty = {instantiations = Symreltab.empty, constants = Symtab.empty}
val extend = I
fun merge ({instantiations = i1, constants = c1}, {instantiations = i2, constants = c2}) =
if Symreltab.is_empty i1 andalso Symtab.is_empty c1 andalso
Symreltab.is_empty i2 andalso Symtab.is_empty c2 then
empty
else
error "merging not supported"
)

fun map_definitions map_insts map_consts =
Definitions.map (fn {instantiations, constants} =>
{instantiations = map_insts instantiations,
constants = map_consts constants})

fun the_info ctxt name =
Symtab.lookup (#constants (Definitions.get (Context.Proof ctxt))) name
|> the
|> snd
|> snd

fun add_instantiation (class, tyco) term cert =
let
fun upd phi =
map_definitions
(fn tab =>
if Symreltab.defined tab (class, tyco) then
error ("Duplicate instantiation " ^ quote tyco ^ " :: " ^ quote class)
else
tab
|> Symreltab.update ((class, tyco), (Morphism.term phi term, Morphism.thm phi cert))) I
in
Local_Theory.declaration {pervasive = false, syntax = false} upd
end

fun add_constant name name' (cert, info) lthy =
let
val qname = Local_Theory.full_name lthy (Binding.name name')
fun upd phi =
map_definitions I
(fn tab =>
if Symtab.defined tab name then
error ("Duplicate constant " ^ quote name)
else
tab
|> Symtab.update (name,
(qname, (Option.map (Morphism.thm phi) cert, morph_const_info phi info))))
in
Local_Theory.declaration {pervasive = false, syntax = false} upd lthy
|> Local_Theory.note ((Binding.empty, @{attributes [dict_construction_specs]}), #simps info)
|> snd
end

(** classes **)

fun axclass class =
ensure_class class
#>> node_of

(** grouping and annotating constants **)

datatype const =
Fun of
{dicts: ((string * class) * typ) list,
certs: term list,
param_typs: typ list,
typ: typ, (* typified *)
new_typ: typ,
eqs: {params: term list, rhs: term, thm: thm} list,
info: Function_Common.info option,
cong: thm option} |
Constructor |
Classparam of
{class: class,
typ: typ, (* varified *)
selector: term (* varified *)}

type 'a sccs = (string * 'a) list list

fun symtab_of_sccs x = Symtab.make (flat x)

fun raw_dict_params tparams lthy =
let
fun mk_dict tparam class lthy =
let
val (node, lthy') = axclass class lthy
val targ = TFree (tparam, @{sort type})
val typ = dict_typ node targ
val cert = #cert node targ
in ((((tparam, class), typ), cert), lthy') end
fun mk_dicts (tparam, sort) = fold_map
(mk_dict tparam)
(filter (Class.is_class (Proof_Context.theory_of lthy)) sort)
in fold_map mk_dicts tparams lthy |>> flat end

fun dict_params context dicts =
let
fun dict_param ((_, class), typ) =
Name.variant (mangle class) #>> rpair typ #>> Free
in
fold_map dict_param dicts context
end

fun get_sel class param typ lthy =
let
val ({selectors, ...}, lthy') = axclass class lthy
in
case Symtab.lookup selectors param of
NONE => error ("unknown class parameter " ^ param)
| SOME sel => (sel typ, lthy')
end

fun annotate_const name ((tparams, typ), raw_eqs) lthy =
if Code.is_constr (Proof_Context.theory_of lthy) name then
((name, Constructor), lthy)
else if null raw_eqs then
(* this detection is reliable, because code equations with overloaded heads are not allowed *)
let
val (_, class) = the_single tparams ||> the_single
val (selector, thy') = get_sel class name (TVar (("'a", 0), @{sort type})) lthy
val typ = range_type (fastype_of selector)
in
((name, Classparam {class = class, typ = typ, selector = selector}), thy')
end
else
let
val info = try (Function.get_info lthy) (Const (name, typ))
val cong = cong_of_const lthy name
val ((raw_dicts, certs), lthy') = raw_dict_params tparams lthy |>> split_list
val dict_typs = map snd raw_dicts
val typ' = typify_typ typ
fun mk_eq ((raw_params, rhs), SOME thm) =
let
val norm = normalizer_conv lthy'
val transform = Thm.cterm_of lthy' #> conv_result norm #> typify
val params = map transform raw_params
in
if has_duplicates (op =) (flat (map all_frees' params)) then
(warning "ignoring code equation with non-linear pattern"; NONE)
else
SOME {params = params, rhs = rhs, thm = Conv.fconv_rule norm thm}
end
| mk_eq _ =
error "no theorem"
val const =
Fun
{dicts = raw_dicts, certs = certs, typ = typ', param_typs = binder_types typ',
new_typ = dict_typs ---> typ', eqs = map_filter mk_eq raw_eqs, info = info, cong = cong}
in ((name, const), lthy') end

fun annotate_code_eqs lthy consts =
fold_map (fold_map (uncurry annotate_const)) (group_code_eqs lthy consts) lthy

(** instances and terms **)

fun mk_path [] _ _ lthy = (NONE, lthy)
| mk_path ((class, term) :: rest) typ goal lthy =
let
val (ev, lthy') = ensure_class class lthy
in
case find_path ev goal of
SOME path => (SOME (fold_path path typ term), lthy')
| NONE => mk_path rest typ goal lthy'
end

fun instance consts tyco class lthy =
case Symreltab.lookup (#instantiations (Definitions.get (Context.Proof lthy))) (class, tyco) of
SOME (inst, _) =>
(inst, lthy)
| NONE =>
let
val thy = Proof_Context.theory_of lthy
val tparam_sorts = param_sorts tyco class thy

fun print_info ctxt =
let
val tvars =
Name.invent_list [] Name.aT (length tparam_sorts) ~~ tparam_sorts
|> map TFree
in
[Pretty.str "Defining instance ", Syntax.pretty_typ ctxt (Type (tyco, tvars)),
Pretty.str " :: ", Syntax.pretty_sort ctxt [class]]
|> Pretty.block
|> Pretty.writeln
end

val ({make, ...}, lthy) = axclass class lthy

val name = mangle class ^ "__instance__" ^ mangle tyco
val tparams = Name.invent_names Name.context Name.aT tparam_sorts
val ((dict_params, _), lthy) = raw_dict_params tparams lthy
|>> map fst
|>> dict_params (Name.make_context [name])
val dict_context = Symreltab.make (flat_right tparams ~~ dict_params)

val {params, ...} = Axclass.get_info thy class

val (super_fields, lthy) = fold_map
(obtain_dict dict_context consts (Type (tyco, map TFree tparams)))
(super_classes class thy)
lthy

val tparams' = map (TFree o rpair @{sort type} o fst) tparams
val typ_inst = (TFree ("'a", [class]), Type (tyco, tparams'))

fun mk_field (field, typ) =
let
val param = Axclass.param_of_inst thy (field, tyco)
(* if not: abort (else we would run into an infinite loop) *)
val _ = case Symtab.lookup (#constants (Definitions.get (Context.Proof lthy))) param of
NONE =>
(* necessary for zero_nat *)
if Code.is_constr thy param then
()
else
error ("cyclic dependency: " ^ param ^ " not yet defined in the definition of " ^ tyco ^ " :: " ^ class)
| SOME _ => ()
in
term dict_context consts (Const (param, typ_subst_atomic [typ_inst] typ))
end

val (fields, lthy) = fold_map mk_field params lthy

val rhs = list_comb (make (Type (tyco, tparams')), super_fields @ fields)
val typ = map fastype_of dict_params ---> fastype_of rhs
val head = Free (name, typ)
val lhs = list_comb (head, dict_params)
val term = Logic.mk_equals (lhs, rhs)

val (def, (lthy', lthy)) = lthy
|> tap print_info
|> (snd o Local_Theory.begin_nested)
|> define_params_nosyn term
||> Local_Theory.end_nested
val phi = Proof_Context.export_morphism lthy lthy'
val def = Morphism.thm phi def

val base_thms =
Definitions.get (Context.Proof lthy') |> #constants |> Symtab.dest
|> map (apsnd fst o snd)
|> map_filter snd

val target = (flat_right tparams, (Morphism.term phi head, tyco, class))
val args = {base_thms = base_thms, def_thm = def}
val thm = prove_dict_cert target args lthy'

val const = Const (Local_Theory.full_name lthy' (Binding.name name), typ)
in
(const, add_instantiation (class, tyco) const thm lthy')
end
and obtain_dict dict_context consts =
let
val dict_context' = Symreltab.dest dict_context
fun for_class (Type (tyco, args)) class lthy =
let
val inst_param_sorts = param_sorts tyco class (Proof_Context.theory_of lthy)
val (raw_inst, lthy') = instance consts tyco class lthy
val (const_name, _) = dest_Const raw_inst
val (inst_args, lthy'') = fold_map for_sort (inst_param_sorts ~~ args) lthy'
val head = Sign.mk_const (Proof_Context.theory_of lthy'') (const_name, args)
in
(list_comb (head, flat inst_args), lthy'')
end
| for_class (TFree (name, _)) class lthy =
let
val available = map_filter
(fn ((tp, class), term) => if tp = name then SOME (class, term) else NONE)
dict_context'
val (path, lthy') = mk_path available (TFree (name, @{sort type})) class lthy
in
case path of
SOME term => (term, lthy')
| NONE => error "no path found"
end
| for_class (TVar _) _ _ = error "unexpected type variable"
and for_sort (sort, typ) =
fold_map (for_class typ) sort
in for_class end
and term dict_context consts term lthy =
let
fun traverse (t as Const (name, typ)) lthy =
(case Symtab.lookup consts name of
NONE => error ("unknown constant " ^ name)
| SOME (_, Constructor) =>
(typify t, lthy)
| SOME (_, Classparam {class, typ = typ', selector}) =>
let
val subst = Sign.typ_match (Proof_Context.theory_of lthy) (typ', typ) Vartab.empty
val (_, targ) = the (Vartab.lookup subst ("'a", 0))
val (dict, lthy') = obtain_dict dict_context consts targ class lthy
in
(subst_TVars [(("'a", 0), targ)] selector $ dict, lthy')
end
| SOME (name', Fun {dicts = dicts, typ = typ', new_typ, ...}) =>
let
val subst =
Type.raw_match (Logic.varifyT_global typ', typ) Vartab.empty
|> Vartab.dest
|> map (apsnd snd)

fun lookup tparam =
the (AList.lookup (op =) subst (tparam, 0))

val (dicts, lthy') =
fold_map (uncurry (obtain_dict dict_context consts o lookup)) (map fst dicts) lthy

val typ = typ_subst_TVars subst (Logic.varifyT_global new_typ)

val head =
case Symtab.lookup (#constants (Definitions.get (Context.Proof lthy))) name of
NONE => Free (name', typ)
| SOME (n, _) => Const (n, typ)

val res = list_comb (head, dicts)
in
(res, lthy')
end)
| traverse (f $ x) lthy =
let
val (f', lthy')  = traverse f lthy
val (x', lthy'') = traverse x lthy'
in (f' $ x', lthy'') end
| traverse (Abs (name, typ, term)) lthy =
let
val (term', lthy') = traverse term lthy
in (Abs (name, typify_typ typ, term'), lthy') end
| traverse (Free (name, typ)) lthy = (Free (name, typify_typ typ), lthy)
| traverse (Var (name, typ)) lthy  = (Var (name, typify_typ typ), lthy)
| traverse (Bound n) lthy = (Bound n, lthy)
in
traverse term lthy
end

(** group of constants **)

fun new_names lthy consts =
let
val (all_names, all_consts) = split_list (flat consts)
val all_frees = map (fn Fun {eqs, ...} => eqs | _ => []) all_consts |> flat
|> map #params |> flat
|> map all_frees' |> flat
val context = fold Name.declare (all_names @ all_frees) (Variable.names_of lthy)

fun new_name (name, const) context =
let val (name', context') = Name.variant (mangle name) context in
((name, (name', const)), context')
end
in
fst (fold_map (fold_map new_name) consts context)
end

fun consts consts proof group lthy =
let
val fun_config = Function_Common.FunctionConfig
{sequential=true, default=NONE, domintros=false, partials=false}
fun pat_completeness_auto ctxt =
Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt

val all_names = map fst group

val pretty_consts = map (pretty_const lthy) all_names |> Pretty.commas

fun print_info msg =
Pretty.str (msg ^ " ") :: pretty_consts
|> Pretty.block
|> Pretty.writeln

val _ = print_info "Redefining constant(s)"

fun process_eqs (name, Fun {dicts, param_typs, new_typ, eqs, info, cong, ...}) lthy =
let
val new_name = case Symtab.lookup consts name of
NONE => error ("no new name for " ^ name)
| SOME (n, _) => n

val all_frees = map #params eqs |> flat |> map all_frees' |> flat
val context = Name.make_context (all_names @ all_frees)
val (dict_params, context') = dict_params context dicts

fun adapt_params param_typs params =
let
val real_params = dict_params @ params
val ext_params = drop (length params) param_typs
|> map typify_typ
|> Name.invent_names context' "e0" |> map Free
in (real_params, ext_params) end

fun mk_eq {params, rhs, thm} lthy =
let
val (real_params, ext_params) = adapt_params param_typs params
val lhs' = list_comb (Free (new_name, new_typ), real_params @ ext_params)
val (rhs', lthy') = term (Symreltab.make (map fst dicts ~~ dict_params)) consts rhs lthy
val rhs'' = list_comb (rhs', ext_params)
in
((HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'')), thm), lthy')
end

val is_fun = length param_typs + length dicts > 0
in
fold_map mk_eq eqs lthy
|>> rpair (new_typ, is_fun)
|>> SOME
|>> pair ((name, new_name, map fst dicts), {info = info, cong = cong})
end
| process_eqs (name, _) lthy =
((((name, name, []), {info = NONE, cong = NONE}), NONE), lthy)

val (items, lthy') = fold_map process_eqs group lthy

val ((metas, infos), ((eqs, code_thms), (new_typs, is_funs))) = items
|> map_filter (fn (meta, eqs) => Option.map (pair meta) eqs)
|> split_list
||> split_list ||> apfst (flat #> split_list #>> map typify)
||> apsnd split_list
|>> split_list

val _ = if_debug lthy (fn () =>
if null code_thms then ()
else
map (Syntax.pretty_term lthy o Thm.prop_of) code_thms
|> Pretty.big_list "Equations:"
|> Pretty.string_of
|> tracing)

val is_fun =
case distinct (op =) is_funs of
[b] => b
| [] => false
| _ => error "unsupported feature: mixed non-function and function definitions"
fun mk_binding (_, new_name, _) typ =
(Binding.name new_name, SOME typ, NoSyn)
val bindings = map2 mk_binding metas new_typs

val {constants, instantiations} = Definitions.get (Context.Proof lthy')
val base_thms = Symtab.dest constants |> map (apsnd fst o snd) |> map_filter snd
val base_certs = Symreltab.dest instantiations |> map (snd o snd)

val consts = Sign.consts_of (Proof_Context.theory_of lthy')

fun prove_eq_fun (info as {simps = SOME simps, fs, inducts = SOME inducts, ...}) lthy =
let
fun mk_target (name, _, classes) new =
(classes, (new, Const (Consts.the_const consts name)))
val targets = map2 mk_target metas fs
val args =
{fun_info = SOME info, inducts = SOME inducts, simps = simps, base_thms = base_thms,
base_certs = base_certs, code_thms = code_thms, congs = NONE}
in
(prove_fun_cert targets args proof lthy, args)
end

fun prove_eq_def defs lthy =
let
fun mk_target (name, _, classes) new =
(classes, (new, Const (Consts.the_const consts name)))
val targets = map2 mk_target metas (map (fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) defs)
val args =
{fun_info = NONE, inducts = NONE, simps = defs,
base_thms = base_thms, base_certs = base_certs, code_thms = code_thms, congs = NONE}
in
(prove_fun_cert targets args proof lthy, args)
end

fun add_constants ((((name, name', _), _), SOME _) :: xs) ((thm :: thms), info) =
| add_constants ((((name, name', _), _), NONE) :: xs) (thms, info) =
I

fun prove_termination new_info ctxt =
let
val termination_ctxt =
ctxt addsimps (@{thms equal} @ base_thms)
val fallback_tac =
Function_Common.termination_prover_tac true termination_ctxt

val tac = case try hd (cat_options (map #info infos)) of
SOME old_info => HEADGOAL (Transfer_Termination.termination_tac new_info old_info ctxt)
| NONE => no_tac

in Function.prove_termination NONE (tac ORELSE fallback_tac) ctxt end

fun prove_cong data lthy =
let
fun rewr_cong thm cong =
if Thm.nprems_of thm > 0 then
(warning "No fundef_cong rule can be derived; this will likely not work later"; NONE)
else
(print_info "Porting fundef_cong rule for";
SOME (Local_Defs.fold lthy [thm] cong))

val congs' =
map2 (Option.mapPartial o rewr_cong) (fst data) (map #cong infos)
|> cat_options

fun add_congs phi =
fold Function_Context_Tree.add_function_cong (map (Morphism.thm phi) congs')

val data' =
apsnd (map_const_info I I I I I I (K (SOME congs'))) data
in
(data', Local_Theory.declaration {pervasive = false, syntax = false} add_congs lthy)
end

fun mk_fun lthy =
let
val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) eqs
val (info, lthy') =
Function.add_function bindings specs fun_config pat_completeness_auto lthy
|-> prove_termination
val simps = the (#simps info)
val (_, lthy'') =
(* [simp del] is required because otherwise non-matching function definitions
(e.g. divmod_nat) make the simplifier loop; this is done as a separate step
because otherwise we would get many warnings, since the psimp rules are not
added to the simpset *)
Local_Theory.note ((Binding.empty, @{attributes [simp del]}), simps) lthy'
fun prove_eq phi =
prove_eq_fun (Function_Common.transform_function_data phi info)
in
(((simps, #inducts info), prove_eq), lthy'')
end

fun mk_def lthy =
let
val (defs, lthy') = fold_map define_params_nosyn eqs lthy
fun prove_eq phi = prove_eq_def (map (Morphism.thm phi) defs)
in
(((defs, NONE), prove_eq), lthy')
end
in
if null eqs then
lthy'
else
let
(* the redefinition itself doesn't have a sort constraint, but the equality prop may have
one; hence the proof needs to happen after exiting the local theory target.
Conceptually, doing everything locally would be preferable, but the type checker does
not allow us to add sort constraints to TFrees after they have been declared *)
val ((side, prove_eq), (lthy', lthy)) = lthy'
|> (snd o Local_Theory.begin_nested)
|> (if is_fun then mk_fun else mk_def)
|-> (fn ((simps, inducts), prove_eq) =>
apfst (rpair prove_eq) o Side_Conditions.mk_side simps inducts)
||> `Local_Theory.end_nested
val phi = Proof_Context.export_morphism lthy lthy'
in
lthy'
|> (prove_eq phi)
|>> apfst (on_thms_complete (fn () => print_info "Proved equivalence for"))
|-> prove_cong
end
end

fun const_raw (binding, raw_consts) proof lthy =
let
val _ =
if proof = Skip then
warning "Skipping certificate proofs"
else ()

val (name, _) = Syntax.read_terms lthy raw_consts |> map dest_Const |> split_list

val (eqs, lthy) = annotate_code_eqs lthy name
val tab = symtab_of_sccs (new_names lthy eqs)

val lthy = fold (consts tab proof) eqs lthy

val {instantiations, constants} = Definitions.get (Context.Proof lthy)
val thms =
map (snd o snd) (Symreltab.dest instantiations) @
map_filter (fst o snd o snd) (Symtab.dest constants)
in
snd (Local_Theory.note (binding, thms) lthy)
end

(** setup **)

val parse_flags =
Scan.optional (Args.parens (Parse.reserved "skip" >> K Skip)) Cert

val _ =
Outer_Syntax.local_theory
@{command_keyword "declassify"}
"redefines a constant after applying the dictionary construction"
(parse_flags -- Parse_Spec.opt_thm_name ":" -- Scan.repeat1 Parse.const >>
(fn ((flags, def_binding), consts) => const_raw (def_binding, consts) flags))

end

# Theory Termination

section ‹Termination heuristics›
text_raw ‹\label{sec:termination}›

theory Termination
imports "../Dict_Construction"
begin

text ‹
As indicated in the introduction, the newly-defined functions must be proven terminating. In
general, we cannot reuse the original termination proof, as the following example illustrates:
›

fun f :: "nat ⇒ nat" where
"f 0 = 0" |
"f (Suc n) = f n"

lemma [code]: "f x = f x" ..

text ‹
The invocation of @{theory_text ‹declassify f›} would fail because the code equation of
@{const f} does not terminate.

Hence, in the general case where users have modified the code equations, we need to fall back
to an (automated) attempt to prove termination.

In the remainder of this section, we will illustrate the special case where the user has not
modified the code equations, i.e., the original termination proof should ``morally'' still be
applicable. For this, we will perform the dictionary construction manually.
›

― ‹Some ML incantations to ensure that the dictionary types are present›
local_setup ‹Class_Graph.ensure_class @{class plus} #> snd›
local_setup ‹Class_Graph.ensure_class @{class zero} #> snd›

fun sum_list :: "'a::{plus,zero} list ⇒ 'a" where
"sum_list [] = 0" |
"sum_list (x # xs) = x + sum_list xs"

text ‹
The above function carries two distinct class constraints, which are translated into two
dictionary parameters:
›

function sum_list' where
"sum_list' d_plus d_zero [] = Groups_zero__class_zero__field d_zero" |
"sum_list' d_plus d_zero (x # xs) = Groups_plus__class_plus__field d_plus x (sum_list' d_plus d_zero xs)"
by pat_completeness auto
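
text ‹
Operationally, the two dictionary parameters of @{const sum_list'} are plain records of class
operations that are passed along unchanged. The following Python sketch is only a rough analogue
under stated assumptions. The record types ‹PlusDict› and ‹ZeroDict› and their field names are
hypothetical stand-ins for the generated ‹Groups_plus__dict› and ‹Groups_zero__dict› types. Field
access plays the role of the ‹Groups_plus__class_plus__field› and ‹Groups_zero__class_zero__field›
selectors.

```python
from dataclasses import dataclass
from typing import Callable, Generic, List, TypeVar

A = TypeVar("A")


@dataclass(frozen=True)
class PlusDict(Generic[A]):
    # Hypothetical record for the plus class: one field per class operation.
    plus: Callable[[A, A], A]


@dataclass(frozen=True)
class ZeroDict(Generic[A]):
    # Hypothetical record for the zero class.
    zero: A


def sum_list_dict(d_plus: PlusDict[A], d_zero: ZeroDict[A], xs: List[A]) -> A:
    # Mirrors the two equations of sum_list': the dictionaries are passed
    # unchanged to the recursive call; only the list argument shrinks.
    if not xs:
        return d_zero.zero
    return d_plus.plus(xs[0], sum_list_dict(d_plus, d_zero, xs[1:]))


# Two "instances": integers under addition, strings under concatenation.
int_plus, int_zero = PlusDict(lambda a, b: a + b), ZeroDict(0)
str_plus, str_zero = PlusDict(lambda a, b: a + b), ZeroDict("")
```

The same function works for any type for which suitable records are supplied. This is the
flexibility that the explicit parameters provide in place of the class constraints.
›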

text ‹
Now, we need to carry out the termination proof of @{const sum_list'}. The @{theory_text function}
package analyzes the function definition and discovers one recursive call. In pseudo-notation:

@{text [display] ‹(d_plus, d_zero, x # xs) ↝ (d_plus, d_zero, xs)›}

The result of this analysis is captured in the inductive predicate @{const sum_list'_rel}. Its
introduction rules look as follows:
›

thm sum_list'_rel.intros
― ‹@{thm sum_list'_rel.intros}›

text ‹Compare this to the relation for @{const sum_list}:›

thm sum_list_rel.intros
― ‹@{thm sum_list_rel.intros}›

text ‹
Except for the additional (unchanging) dictionary arguments, these relations are more or less
equivalent to each other. There is an important difference, though: @{const sum_list_rel} has
sort constraints, @{const sum_list'_rel} does not. (This will become important later on.)
›

context
notes [[show_sorts]]
begin

term sum_list_rel
― ‹@{typ ‹'a::{plus,zero} list ⇒ 'a::{plus,zero} list ⇒ bool›}›

term sum_list'_rel
― ‹@{typ ‹'a::type Groups_plus__dict × 'a::type Groups_zero__dict × 'a::type list ⇒ 'a::type Groups_plus__dict × 'a::type Groups_zero__dict × 'a::type list ⇒ bool›}›

end

text ‹
Let us now discuss the rough concept of the termination proof for @{const sum_list'}. The goal is
to show that @{const sum_list'_rel} is well-founded. Usually, this is proved by specifying a
∗‹measure function› that
▸ maps the arguments to natural numbers
▸ decreases for each recursive call.
›
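text ‹
For @{const sum_list'}, a natural measure would be the length of the list argument. The Python
sketch below only illustrates the proof obligation on finitely many inputs and does not discharge
it. The tuple encoding of the arguments and all names are assumptions made for the illustration.

```python
from typing import Any, List, Tuple

Args = Tuple[Any, Any, list]  # (d_plus, d_zero, xs), mirroring sum_list'


def recursive_calls(args: Args) -> List[Args]:
    # Hard-coded from the equations of sum_list': only the x # xs case
    # recurses, passing both dictionaries along unchanged.
    d_plus, d_zero, xs = args
    if not xs:
        return []
    return [(d_plus, d_zero, xs[1:])]


def measure(args: Args) -> int:
    # Maps the arguments to a natural number: the length of the list.
    return len(args[2])


def measure_decreases(args: Args) -> bool:
    # The measure must drop strictly for every recursive call.
    return all(measure(call) < measure(args) for call in recursive_calls(args))
```

Here ‹recursive_calls› hard-codes what the call analysis of the @{theory_text function} package
discovers automatically.
›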
text ‹
Here, however, we want to instead show that each recursive call in @{const sum_list'} has a
corresponding recursive call in @{const sum_list}. In other words, we want to show that the
existing proof of well-foundedness of @{const sum_list_rel} can be lifted to a proof of
well-foundedness of @{const sum_list'_rel}. This is what the theorem
@{thm [source=true] wfP_simulate_simple} states:

@{thm [display=true] wfP_simulate_simple}

Given any well-founded relation ‹r› and a function ‹g› that maps every ‹r'›-step to an ‹r›-step
(i.e., ‹r' x y› implies ‹r (g x) (g y)›), we can deduce that ‹r'› is also well-founded.

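Informally, read ‹r x y› as ``‹x› is below ‹y›'' and characterise well-foundedness as the absence
of infinite descending chains. The argument then runs as follows:

```latex
\begin{align*}
  &\text{suppose } r'\,x_{i+1}\,x_i \text{ for all } i \in \mathbb{N}
    && \text{(infinite descending } r'\text{-chain)}\\
  &\implies r\,(g\,x_{i+1})\,(g\,x_i) \text{ for all } i \in \mathbb{N}
    && \text{(simulation premise)}\\
  &\implies g\,x_0,\ g\,x_1,\ g\,x_2,\ \ldots \text{ is an infinite descending } r\text{-chain}\\
  &\implies \text{contradiction with } \mathrm{wfP}\ r.
\end{align*}
```
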
For our example, we need to provide a function ‹g› of type
@{typ ‹'b Groups_plus__dict × 'b Groups_zero__dict × 'b list ⇒ 'a list›}.
Because the dictionary parameters are not changing, they can safely be dropped by ‹g›.
However, because of the sort constraint in @{const sum_list_rel}, the term @{term "snd ∘ snd"}
is not a well-typed instantiation for ‹g›.

Instead (this is where the heuristic comes in), we assume that the original function
@{const sum_list} is parametric, i.e., termination does not depend on the elements of the list
passed to it, but only on the structure of the list. Additionally, we assume that all involved
type classes have at least one instantiation.

With this in mind, we can use @{term "map (λ_. undefined) ∘ snd ∘ snd"} as ‹g›:
›

thm wfP_simulate_simple[where
r = sum_list_rel and
r' = sum_list'_rel and
g = "map (λ_. undefined) ∘ snd ∘ snd"]
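
text ‹
The simulation premise for this choice of ‹g› can be spot-checked on sample arguments. In the
Python sketch below, the tuple encoding of arguments and all names are assumptions. Python's
‹None› stands in for HOL's ‹undefined›. Since ‹g› forgets the dictionaries and overwrites every
element with the same dummy value, only the list structure survives.

```python
from typing import Any, Tuple

Args = Tuple[Any, Any, list]  # (d_plus, d_zero, xs), mirroring sum_list'
UNDEFINED = None  # stand-in for HOL's undefined; its value never matters


def sum_list_rel(x: list, y: list) -> bool:
    # sum_list y makes the recursive call sum_list x iff y = h # x
    return len(y) > 0 and y[1:] == x


def sum_list_prime_rel(x: Args, y: Args) -> bool:
    # sum_list' (dp, dz, h # t) calls sum_list' (dp, dz, t)
    return len(y[2]) > 0 and x == (y[0], y[1], y[2][1:])


def g(args: Args) -> list:
    # Corresponds to map (λ_. undefined) ∘ snd ∘ snd:
    # drop both dictionaries, keep only the shape of the list.
    return [UNDEFINED for _ in args[2]]


def simulation_holds(y: Args) -> bool:
    # Premise of the theorem for the (at most one) call made from y:
    # sum_list'_rel x y  ==>  sum_list_rel (g x) (g y)
    if not y[2]:
        return True  # no recursive call, nothing to check
    x = (y[0], y[1], y[2][1:])
    return sum_list_prime_rel(x, y) and sum_list_rel(g(x), g(y))
```

Because Python is dynamically typed, the sort-constraint problem that rules out ‹snd ∘ snd› in
HOL does not arise here. The sketch only illustrates the shape of ‹g›, not why ‹undefined› is needed.
›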

text ‹
Finally, we can prove the termination of @{const sum_list'}.
›

termination sum_list'
proof -
have "wfP sum_list'_rel"
proof (rule wfP_simulate_simple)
― ‹We first need to obtain the well-foundedness theorem for @{const sum_list_rel} from the ML
guts of the @{theory_text function} package.›
show "wfP sum_list_rel"
apply (rule accp_wfPI)
apply (tactic ‹resolve_tac @{context} [Function.get_info @{context} @{term sum_list} |> #totality |> the] 1›)
done

define g :: "'b Groups_plus__dict × 'b Groups_zero__dict × 'b list ⇒ 'c::{plus,zero} list" where
"g = map (λ_. undefined) ∘ snd ∘ snd"

― ‹Prove the simulation of @{const sum_list'_rel} by @{const sum_list_rel} by rule induction.›
show "sum_list_rel (g x) (g y)" if "sum_list'_rel x y" for x y
using that
proof (induction x y rule: sum_list'_rel.induct)
case (1 d_plus d_zero x xs)
show ?case
― ‹Unfold the constituent parts of @{term g}:›
apply (simp only: g_def comp_apply snd_conv list.map)
― ‹Use the corresponding introduction rule of @{const sum_list_rel} and hope for the best:›
apply (rule sum_list_rel.intros(1))
done
qed
qed

― ‹This is the goal that the @{theory_text function} package expects.›
then show "∀x. sum_list'_dom x"
by (rule wfP_implies_dom)
qed

text ‹This can be automated with a special tactic:›

experiment
begin

termination sum_list'
apply (tactic ‹
Transfer_Termination.termination_tac
(Function.get_info @{context} @{term sum_list'})
(Function.get_info @{context} @{term sum_list})
@{context}
1›; fail)
done

end

text ‹
A similar technique can be used for making functions defined in locales executable when, for some
reason, the definition of a ``defs'' locale is not feasible.
›

locale foo =
fixes A :: "nat"
assumes "A > 0"
begin

fun f where
"f 0 = A" |
"f (Suc n) = Suc (f n)"

― ‹We carry out this proof in the locale for simplicity; a real implementation would probably
have to set up a local theory properly.›
lemma f_total: "wfP f_rel"
apply (rule accp_wfPI)
apply (tactic ‹resolve_tac @{context} [Function.get_info @{context} @{term f} |> #totality |> the] 1›)
done

end

― ‹The dummy interpretation serves the same purpose as the assumption that class constraints have
at least one instantiation.›
interpretation dummy: foo 1 by standard simp

function f' where
"f' A 0 = A" |
"f' A (Suc n) = Suc (f' A n)"
by pat_completeness auto

termination f'
apply (rule wfP_implies_dom)
apply (rule wfP_simulate_simple[where g = "snd"])
apply (rule dummy.f_total)
subgoal for x y
apply (induction x y rule: f'_rel.induct)
subgoal
apply (simp only: snd_conv)
apply (rule dummy.f_rel.intros)
done
done
done
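
text ‹
The same simulation argument can be sketched in Python under assumptions (tuple encoding, names).
Here @{const f'} receives the former locale parameter ‹A› as an explicit first argument, and
‹g = snd› forgets it. Every call step of @{const f'} therefore maps onto a call step of ‹dummy.f›,
whose recursion does not depend on ‹A›.

```python
from typing import Tuple

Args = Tuple[int, int]  # (A, n), mirroring the arguments of f'


def f_prime(a: int, n: int) -> int:
    # f' A 0 = A;  f' A (Suc n) = Suc (f' A n)
    return a if n == 0 else 1 + f_prime(a, n - 1)


def f_prime_rel(x: Args, y: Args) -> bool:
    # f' A (Suc n) calls f' A n: the parameter A is passed unchanged
    return y[1] > 0 and x == (y[0], y[1] - 1)


def dummy_f_rel(m: int, n: int) -> bool:
    # dummy.f (Suc n) calls dummy.f n, independently of the locale parameter
    return n > 0 and m == n - 1


def g(args: Args) -> int:
    # g = snd: forget the locale parameter
    return args[1]
```
›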

text ‹Automatic:›

experiment
begin

termination f'
apply (tactic ‹
Transfer_Termination.termination_tac
(Function.get_info @{context} @{term f'})
(Function.get_info @{context} @{term dummy.f})
@{context}
1›; fail)
done

end

end

# Theory Test_Dict_Construction

section ‹Test cases for dictionary construction›

theory Test_Dict_Construction
imports
Dict_Construction
"HOL-Library.ListVector"
begin

subsection ‹Code equations with different number of explicit arguments›

lemma [code]: "fold f [] = id" "fold f (x # xs) s = fold f xs (f x s)" "fold f [x, y] u ≡ f y (f x u)"
by auto

experiment begin

declassify valid: fold
thm valid
lemma "List_fold = fold" by (rule valid)

end
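
text ‹
The equations above have different numbers of explicit arguments: ‹fold f [] = id› has two, the
others have three. Before redefinition, each equation is padded to the full arity of the constant.
In the ML code this is done by ‹adapt_params› and ‹mk_eq›. They prepend the dictionary parameters,
add fresh variables for the missing arguments, and apply the right-hand side to those variables.
The Python sketch below models terms as strings. The fresh names ‹e0›, ‹e1›, ... are a
hypothetical naming scheme. ‹fold› needs no dictionary parameters because it carries no class
constraints.

```python
from typing import List, Tuple

Equation = Tuple[List[str], str]  # (argument patterns of the lhs, rhs as text)


def pad_equation(eq: Equation, arity: int, dict_params: List[str]) -> Equation:
    """Hypothetical analogue of adapt_params/mk_eq on a string representation."""
    params, rhs = eq
    # Fresh variables for the arguments the equation does not mention
    # (naming scheme assumed, not taken from Name.invent_names).
    ext = ["e" + str(i) for i in range(arity - len(params))]
    # Parenthesise compound right-hand sides before applying extra arguments.
    head = rhs if " " not in rhs else "(" + rhs + ")"
    new_rhs = " ".join([head] + ext) if ext else rhs
    return dict_params + params + ext, new_rhs


fold_eqs = [
    (["f", "[]"], "id"),
    (["f", "(x # xs)", "s"], "fold f xs (f x s)"),
    (["f", "[x, y]", "u"], "f y (f x u)"),
]
padded = [pad_equation(eq, 3, []) for eq in fold_eqs]
```
›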

subsection ‹Complex class hierarchies›

local_setup ‹Class_Graph.ensure_class @{class zero} #> snd›
local_setup ‹Class_Graph.ensure_class @{class plus} #> snd›

experiment begin

local_setup ‹Class_Graph.ensure_class @{class comm_monoid_add} #> snd›
local_setup ‹Class_Graph.ensure_class @{class ring} #> snd›

typ "nat Rings_ring__dict"

end

text ‹Check that ‹Class_Graph› does not leak out of locales›

ML‹@{assert} (is_none (Class_Graph.node @{context} @{class ring}))›

subsection ‹Instances with non-trivial arity›

fun f :: "'a::plus ⇒ 'a" where
"f x = x + x"

definition g :: "'a::{plus,zero} list ⇒ 'a list" where
"g x = f x"

datatype natt = Z | S natt

instantiation natt :: "{zero,plus}" begin
definition zero_natt where
"zero_natt = Z"

fun plus_natt where
"plus_natt Z x = x" |
"plus_natt (S m) n = S (plus_natt m n)"

instance ..
end

definition h :: "natt list" where
"h = g [Z,S Z]"

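text ‹
Here ‹'a list› is an instance of @{class plus} only if ‹'a› belongs to @{class zero} and
@{class plus}. After the dictionary construction, the list dictionary therefore becomes a function
of the element dictionaries. The Python sketch below shows that shape under two assumptions.
First, dictionaries are plain records (the class names are hypothetical). Second, list addition is
taken to be pointwise, padding the shorter list with zero. This is our reading of the
‹HOL-Library.ListVector› instance, not its verbatim definition. ‹natt› is modelled by
non-negative Python integers.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class PlusDict:
    plus: Callable[[Any, Any], Any]


@dataclass(frozen=True)
class ZeroDict:
    zero: Any


def plus_natt(m: int, n: int) -> int:
    # plus_natt Z x = x;  plus_natt (S m) n = S (plus_natt m n)
    return n if m == 0 else 1 + plus_natt(m - 1, n)


natt_plus, natt_zero = PlusDict(plus_natt), ZeroDict(0)


def list_plus_dict(d_zero: ZeroDict, d_plus: PlusDict) -> PlusDict:
    # An instance with non-trivial arity becomes a dictionary-producing function.
    def add(xs, ys):
        n = max(len(xs), len(ys))
        xs = xs + [d_zero.zero] * (n - len(xs))
        ys = ys + [d_zero.zero] * (n - len(ys))
        return [d_plus.plus(x, y) for x, y in zip(xs, ys)]
    return PlusDict(add)


def f(d_plus: PlusDict, x):
    # f x = x + x, with the plus dictionary made explicit
    return d_plus.plus(x, x)


def g(d_plus: PlusDict, d_zero: ZeroDict, xs):
    # g x = f x at type 'a list: build the list dictionary from the element ones
    return f(list_plus_dict(d_zero, d_plus), xs)


h = g(natt_plus, natt_zero, [0, 1])  # h = g [Z, S Z]
```
›
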
experiment begin

(* FIXME problem with smart_tac *)
declassify valid: h
thm valid
lemma "Test__Dict__Construction_h = h" by (fact valid)

ML‹Dict_Construction.the_info @{context} @{const_name plus_natt_inst.plus_natt}›

end
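
text ‹
The new constant names follow a recognisable pattern. ‹Test_Dict_Construction.h› becomes
‹Test__Dict__Construction_h› here, and ‹List.fold› becomes ‹List_fold› in the first test. The
Python sketch below reproduces this pattern under an assumption inferred only from these examples:
underscores are doubled and dots turn into single underscores. The sketch also mimics the structure
of the ML function ‹new_names›, which picks a fresh variant of each mangled name. The real context
is seeded with the original constant names and the free variables of the equations. The
numeric-suffix scheme of ‹variant› is made up and is not that of Isabelle's ‹Name.variant›.

```python
from typing import Iterable, List, Set, Tuple


def mangle(name: str) -> str:
    # Assumed rule, inferred from the generated names in these theories:
    # double existing underscores, then turn each dot into one underscore.
    return name.replace("_", "__").replace(".", "_")


def variant(name: str, used: Set[str]) -> str:
    # Made-up clash avoidance (not Isabelle's Name.variant):
    # append the smallest positive number that yields an unused name.
    candidate, i = name, 0
    while candidate in used:
        i += 1
        candidate = name + str(i)
    return candidate


def new_names(groups: List[List[str]], used: Iterable[str]) -> List[List[Tuple[str, str]]]:
    # Like the ML new_names: every constant in every group gets a fresh
    # mangled name, and the set of used names is threaded through.
    taken = set(used)
    result = []
    for group in groups:
        renamed = []
        for name in group:
            fresh = variant(mangle(name), taken)
            taken.add(fresh)
            renamed.append((name, fresh))
        result.append(renamed)
    return result
```

Doubling the underscores presumably keeps most mangled names distinguishable from one another.
Any clashes that remain are resolved by the variant step.
›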

text ‹Check that @{command declassify} does not leak out of locales›

ML‹
can (Dict_Construction.the_info @{context}) @{const_name plus_natt_inst.plus_natt}
|> not |> @{assert}
›

subsection ‹[@{attribute fundef_cong}] rules›

datatype 'a seq = Cons 'a "'a seq" | Nil

experiment begin

declassify map_seq

text ‹Check presence of derived [@{attribute fundef_cong}] rule›

ML‹
Dict_Construction.the_info @{context} @{const_name map_seq}
|> #fun_info
|> the
|> #fs
|> the_single
|> dest_Const
|> fst
|> Dict_Construction.cong_of_const @{context}
|> the
›

end

subsection ‹Mutual recursion›

fun odd :: "nat ⇒ bool" and even where
"odd 0 ⟷ False" |
"even 0 ⟷ True" |
"odd (Suc n) ⟷ even n" |
"even (Suc n) ⟷ odd n"

experiment begin

declassify valid: odd even
thm valid

end

datatype 'a bin_tree = Leaf | Node 'a "'a bin_tree" "'a bin_tree"

experiment begin

declassify valid: map_bin_tree rel_bin_tree
thm valid

end

datatype 'v env = Env "'v list"
datatype v = Closure "v env"

context
notes is_measure_trivial[where f = "size_env size", measure_function]
begin

(* FIXME order is important! *)
fun test_v :: "v ⇒ bool" and test_w :: "v env ⇒ bool" where
"test_v (Closure env) ⟷ test_w env" |
"test_w (Env vs) ⟷ list_all test_v vs"

fun test_v1 :: "v ⇒ 'a::{one,monoid_add}" and test_w1 :: "v env ⇒ 'a" where
"test_v1 (Closure env) = 1 + test_w1 env" |
"test_w1 (Env vs) = sum_list (map test_v1 vs)"

end

experiment begin

declassify valid: test_w test_v
thm valid

end

experiment begin

(* FIXME derive fundef_cong rule for sum_list *)
declassify valid: test_w1 test_v1
thm valid

end

subsection ‹Non-trivial code dependencies; code equations where the head is not fully general›

definition "c ≡ 0 :: nat"
definition "d x ≡ if x = 0 then 0 else x"

lemma contrived[code]: "c = d 0" unfolding c_def d_def by simp

experiment begin

declassify valid: c
thm valid
lemma "Test__Dict__Construction_c = c" by (fact valid)

end

subsection ‹Pattern matching on @{term "0::nat"}›

definition j where "j (n::nat) = (0::nat)"

lemma [code]: "j 0 = 0" "j (Suc n) = j n"
unfolding j_def by auto

fun k where
"k 0 = (0::nat)" |
"k (Suc n) = k n"

lemma f_code[code]: "k n = 0"
by (induct n) simp+

experiment begin

declassify valid: j k
thm valid
lemma
"Test__Dict__Construction_j = j"
"Test__Dict__Construction_k = k"
by (fact valid)+

end

subsection ‹Complex termination arguments›

fun fac :: "nat ⇒ nat" where
"fac n = (if n ≤ 1 then 1 else n * fac (n - 1))"

experiment begin

declassify valid: fac

end

subsection ‹Combination of various things›

experiment begin

declassify valid: sum_list

end

subsection ‹Interaction with the code generator›

declassify h
export_code Test__Dict__Construction_h in SML

end

# Theory Test_Side_Conditions

subsection ‹Contrived side conditions›

theory Test_Side_Conditions
imports Dict_Construction
begin

ML ‹
fun assert_alt_total ctxt term = @{assert} (Side_Conditions.is_total ctxt term)
›

fun head where
"head (x # _) = x"

local_setup ‹snd o Side_Conditions.mk_side @{thms head.simps} NONE›

fun map where
"map f [] = []" |
"map f (x # xs) = f x # map f xs"

local_setup ‹snd o Side_Conditions.mk_side @{thms map.simps} (SOME @{thms map.induct})›
thm map_side.intros

ML ‹assert_alt_total @{context} @{term map}›

experiment begin

text ‹Functions that use partial functions always in their domain are processed correctly.›

fun tail where
"tail (_ # xs) = xs"

local_setup ‹snd o Side_Conditions.mk_side @{thms tail.simps} NONE›

lemma tail_side_eq: "tail_side xs ⟷ xs ≠ []"
by (cases xs) (auto intro: tail_side.intros elim: tail_side.cases)

declaration ‹K (Side_Conditions.set_alt @{term tail} @{thm tail_side_eq})›

function map' where
"map' f xs = (if xs = [] then [] else f (head xs) # map' f (tail xs))"
by auto

termination
apply (relation "measure (size ∘ snd)")
apply rule
subgoal for f xs by (cases xs) auto
done

local_setup ‹snd o Side_Conditions.mk_side @{thms map'.simps} (SOME @{thms map'.induct})›
thm map'_side.intros

ML ‹assert_alt_total @{context} @{term map'}›
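
text ‹
Intuitively, the side condition ‹map'_side f xs› states that evaluating ‹map' f xs› only ever
applies the partial functions ‹head› and ‹tail› inside their domains. The Python sketch below is an
informal model, not how ‹Side_Conditions.mk_side› computes the predicate, and all names are
hypothetical. It evaluates that condition recursively, taking the guard ‹xs = []› into account.
It finds the condition true for every sampled input, which matches the totality check above.

```python
def head_side(xs: list) -> bool:
    # head is only specified for non-empty lists
    return len(xs) > 0


def tail_side(xs: list) -> bool:
    # cf. tail_side_eq: tail_side xs <-> xs ~= []
    return len(xs) > 0


def map_prime_side(xs: list) -> bool:
    # map' f xs = (if xs = [] then [] else f (head xs) # map' f (tail xs)).
    # The guard is taken into account: head and tail are only reached in
    # the else-branch, where xs is non-empty. The function f plays no role.
    if not xs:
        return True
    return head_side(xs) and tail_side(xs) and map_prime_side(xs[1:])


def map_prime(f, xs: list) -> list:
    # Executable counterpart of map', using the same guard.
    return [] if not xs else [f(xs[0])] + map_prime(f, xs[1:])
```
›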

end

lemma map_cong:
assumes "xs = ys" "⋀x. x ∈ set ys ⟹ f x = g x"
shows "map f xs = map g ys"
unfolding assms(1)
using assms(2)
by (induction ys) auto

experiment begin

declare map_cong[fundef_cong]

local_setup ‹snd o Side_Conditions.mk_side @{thms map_head_def} NONE›

lemma "map_head_side xs ⟷ (∀x ∈ set xs. x ≠ [])"

local_setup ‹snd o Side_Conditions.mk_side @{thms map_head'_def} NONE›

lemma "map_head'_side xss ⟷ (∀xs ∈ set xss. ∀x ∈ set xs. x ≠ [])"

end

experiment begin

local_setup ‹snd o Side_Conditions.mk_side @{thms map_head_def} NONE›

end

local_setup ‹snd o Side_Conditions.mk_side @{thms head_known_def} NONE›

fun odd :: "nat ⇒ bool" and even where
"odd 0 ⟷ False" |
"even 0 ⟷ True" |
"odd (Suc n) ⟷ even n" |
"even (Suc n) ⟷ odd n"

local_setup ‹snd o Side_Conditions.mk_side @{thms odd.simps even.simps} (SOME @{thms odd_even.induct})›
thm odd_side_even_side.intros

ML‹assert_alt_total @{context} @{term odd}›
ML‹assert_alt_total @{context} @{term even}›

definition odd_known where
"odd_known = odd (Suc 0)"

local_setup ‹snd o Side_Conditions.mk_side @{thms odd_known_def} NONE›
thm odd_known_side.intros

ML‹assert_alt_total @{context} @{term odd_known}›

end

# Theory Test_Lazy_Case

subsection ‹Interaction with ‹Lazy_Case››

theory Test_Lazy_Case
imports
Dict_Construction
Lazy_Case.Lazy_Case
Show.Show_Instances
begin

datatype 'a tree = Node | Fork 'a "'a tree list"

lemma map_tree[code]:
"map_tree f t = (case t of Node ⇒ Node | Fork x ts ⇒ Fork (f x) (map (map_tree f) ts))"
by (induction t) auto

experiment begin

(* FIXME proper qualified path *)
text ‹
Dictionary construction of @{const map_tree} requires the [@{attribute fundef_cong}] rule of
@{const Test_Lazy_Case.tree.case_lazy}.
›

declassify valid: map_tree
thm valid

lemma "Test__Lazy__Case_tree_map__tree = map_tree" by (fact valid)

end

definition i :: "(unit × (bool list × string × nat option) list) option ⇒ string" where
"i = show"

experiment begin

text ‹This currently requires @{theory Lazy_Case.Lazy_Case} because of @{const divmod_nat}.›

(* FIXME get rid of Lazy_Case dependency *)
declassify valid: i
thm valid

lemma "Test__Lazy__Case_i = i" by (fact valid)

end

end