Theory Introduction
section ‹Dictionary Construction›
theory Introduction
imports Main
begin
subsection ‹Introduction›
text ‹
Isabelle's logic features \emph{type classes}~\cite{haftmann2007typeclasses,wenzel1997typeclasses}.
These are built into the kernel and are used extensively in theory developments.
The existing \emph{code generator}, when targeting Standard ML, performs the well-known dictionary
construction or \emph{dictionary translation}~\cite{haftmann2010codegeneration}.
This works by replacing type classes with records, instances with values, and occurrences with
explicit parameters.
Haftmann and Nipkow give a pen-and-paper correctness proof of this construction
\cite[‹§›4.1]{haftmann2010codegeneration}, based on a notion of \emph{higher-order rewrite
systems.}
The resulting theorem then states that any well-typed term is reduction-equivalent before and after
class elimination.
In this work, the dictionary construction is performed in a certified fashion, that is, the
equivalence is a theorem inside the logic.
›
subsection ‹Encoding classes›
text ‹
The choice of representation of a dictionary itself is straightforward: We model it as a
@{command datatype}, along with functions returning values of that type. The alternative here
would have been to use the @{command record} package. The obvious advantage is that we could
easily model subclass relationships through record inheritance. However, records do not support
multiple inheritance. Since records offer no advantage over datatypes in that regard, we opted for
the more modern @{command datatype} package.
›
text ‹Consider the following example:›
class plus =
fixes plus :: "'a ⇒ 'a ⇒ 'a"
text ‹
This will get translated to a @{command datatype} with a single constructor taking a single
argument:
›
datatype 'a dict_plus =
mk_plus (param_plus: "'a ⇒ 'a ⇒ 'a")
text ‹A function using the @{class plus} constraint:›
definition double :: "'a::plus ⇒ 'a" where
"double x = plus x x"
definition double' :: "'a dict_plus ⇒ 'a ⇒ 'a" where
"double' dict x = param_plus dict x x"
subsection ‹Encoding instances›
text ‹
A more controversial design decision is how to represent dictionary certificates. For example,
given a value of type @{typ "nat dict_plus"}, how do we know that this is a faithful representation
of the @{class plus} instance for @{typ nat}?
›
text ‹
▪ Florian Haftmann proposed a ``shallow encoding''. It works by exploiting the internal treatment
of constants with sort constraints in the Isabelle kernel. Constants themselves do not carry
sort constraints, only their definitional equations. The fact that a constant only appears with
these constraints on the surface of the system is a feature of type inference.
Instead, we can instruct the system to ignore these constraints. However, any attempt at
``hiding'' the constraints behind a type definition ultimately does not work: The nonemptiness
proof requires a witness of a valid dictionary for an arbitrary, but fixed type @{typ 'a}, which
is of course not possible (see ‹§›\ref{sec:impossibility} for details).
▪ The certificates contain the class axioms directly. For example, the @{class semigroup_add}
class requires @{term "(a + b) + c = a + (b + c)"}.
Translated into a definition, this would look as follows:
@{term
"cert_plus dict ⟷
(∀a b c. param_plus dict (param_plus dict a b) c = param_plus dict a (param_plus dict b c))"}
Proving that instances satisfy this certificate is trivial.
However, the equality proof of ‹f'› and ‹f› is impossible: they are simply not equal in general.
Nothing would prevent someone from defining an alternative dictionary using multiplication
instead of addition and the certificate would still hold; but obviously functions using
@{const plus} on numbers would expect addition.
Intuitively, this makes sense: the above notion of ``certificate'' establishes no connection
between original instantiation and newly-generated dictionaries.
Instead of proving equality, one would have to ``lift'' all existing theorems over the old
constants to the new constants.
▪ In order for equality between new and old constants to hold, the certificate needs to capture
that the dictionary corresponds exactly to the class constants. This is achieved by the
representation below.
It literally states that the fields of the dictionary are equal to the class constants.
The condition of the resulting equation can only be instantiated with dictionaries corresponding
to existing class instances. This constitutes a ∗‹closed world› assumption, i.e., callers of
generated code may not invent own instantiations.
›
definition cert_plus :: "'a::plus dict_plus ⇒ bool" where
"cert_plus dict ⟷ (param_plus dict = plus)"
text ‹
Based on that definition, we can prove that @{const double} and @{const double'} are equivalent:
›
lemma "cert_plus dict ⟹ double' dict = double"
unfolding cert_plus_def double'_def double_def
by auto
text ‹
An unconditional equation can be obtained by specializing the theorem to a ground type and
supplying a valid dictionary.
›
subsection ‹Implementation›
text ‹
When translating a constant ‹f›, we use existing mechanisms in Isabelle to obtain its
∗‹code graph›. The graph contains the code equations of all transitive dependencies (i.e.,
other constants) of ‹f›. In general, we have to re-define each of these dependencies. For that,
we use the internal interface of the @{command function} package and feed it the code equations
after performing the dictionary construction. In the standard case, where the user has not
performed a custom code setup, the resulting function looks similar to its original definition.
But the user may have also changed the implementation of a function significantly afterwards.
This imposes some restrictions:
▪ The new constant needs to be proven terminating. We apply some heuristics to transfer the
original termination proof to the new definition. This only works when the termination condition
does not rely on class axioms. (See ‹§›\ref{sec:termination} for details.)
▪ Pattern matching must be performed on datatypes, instead of the more general
@{command code_datatype}s.
▪ The set of code equations must be exhaustive and non-overlapping.
›
end
Theory Impossibility
subsection ‹Impossibility of hiding sort constraints›
text_raw ‹\label{sec:impossibility}›
text ‹Coauthor of this section: Florian Haftmann›
theory Impossibility
imports Main
begin
axiomatization of_prop :: "prop ⇒ bool" where
of_prop_Trueprop [simp]: "of_prop (Trueprop P) ⟷ P" and
Trueprop_of_prop [simp]: "Trueprop (of_prop Q) ≡ PROP Q"
text ‹A type satisfies the certificate if there is an instance of the class.›
definition is_sg :: "'a itself ⇒ bool" where
"is_sg TYPE('a) = of_prop OFCLASS('a, semigroup_add_class)"
text ‹We trick the parser into ignoring the sort constraint of @{const plus}.›
setup ‹Sign.add_const_constraint (@{const_name plus}, SOME @{typ "'a::{} => 'a ⇒ 'a"})›
definition sg :: "('a ⇒ 'a ⇒ 'a) ⇒ bool" where
"sg plus ⟷ plus = Groups.plus ∧ is_sg TYPE('a)" for plus
text ‹Attempt: Define a type that contains all legal @{const plus} functions.›
typedef (overloaded) 'a Sg = "Collect sg :: ('a ⇒ 'a ⇒ 'a) set"
morphisms the_plus Sg
unfolding sg_def[abs_def]
apply (simp add: is_sg_def)
text ‹We need to prove @{term "OFCLASS('a, semigroup_add_class)"} for arbitrary @{typ 'a}, which is
impossible.›
oops
end
Theory Dict_Construction
section ‹Setup›
theory Dict_Construction
imports Automatic_Refinement.Refine_Util
keywords "declassify" :: thy_decl
begin
definition set_of :: "('a ⇒ 'b ⇒ bool) ⇒ ('a × 'b) set" where
"set_of P = {(x, y). P x y}"
lemma wfP_implies_wf_set_of: "wfP P ⟹ wf (set_of P)"
unfolding wfP_def set_of_def .
lemma wf_set_of_implies_wfP: "wf (set_of P) ⟹ wfP P"
unfolding wfP_def set_of_def .
lemma wf_simulate_simple:
assumes "wf r"
assumes "⋀x y. (x, y) ∈ r' ⟹ (g x, g y) ∈ r"
shows "wf r'"
using assms
by (metis in_inv_image wf_eq_minimal wf_inv_image)
lemma set_ofI: "P x y ⟹ (x, y) ∈ set_of P"
unfolding set_of_def by simp
lemma set_ofD: "(x, y) ∈ set_of P ⟹ P x y"
unfolding set_of_def by simp
lemma wfP_simulate_simple:
assumes "wfP r"
assumes "⋀x y. r' x y ⟹ r (g x) (g y)"
shows "wfP r'"
apply (rule wf_set_of_implies_wfP)
apply (rule wf_simulate_simple[where g = g])
apply (rule wfP_implies_wf_set_of)
apply (fact assms)
using assms(2) by (auto intro: set_ofI dest: set_ofD)
lemma wf_implies_dom: "wf (set_of R) ⟹ All (Wellfounded.accp R)"
apply (rule allI)
apply (rule accp_wfPD)
apply (rule wf_set_of_implies_wfP) .
lemma wfP_implies_dom: "wfP R ⟹ All (Wellfounded.accp R)"
by (metis wfP_implies_wf_set_of wf_implies_dom)
named_theorems dict_construction_specs
ML_file ‹dict_construction_util.ML›
ML_file ‹transfer_termination.ML›
ML_file ‹congruences.ML›
ML_file ‹side_conditions.ML›
ML_file ‹class_graph.ML›
ML_file ‹dict_construction.ML›
method_setup fo_cong_rule = ‹
Attrib.thm >> (fn thm => fn ctxt => SIMPLE_METHOD' (Dict_Construction_Util.fo_cong_tac ctxt thm))
› "resolve congruence rule using first-order matching"
declare [[code drop: "(∧)"]]
lemma [code]: "True ∧ p ⟷ p" "False ∧ p ⟷ False" by auto
declare [[code drop: "(∨)"]]
lemma [code]: "True ∨ p ⟷ True" "False ∨ p ⟷ p" by auto
declare comp_cong[fundef_cong del]
declare fun.map_cong[fundef_cong]
end
File ‹dict_construction_util.ML›
infixr 5 ==>
infixr ===>
infix 1 CONTINUE_WITH CONTINUE_WITH_FW
signature DICT_CONSTRUCTION_UTIL = sig
val split_list3: ('a * 'b * 'c) list -> 'a list * 'b list * 'c list
val symreltab_of_symtab: 'a Symtab.table Symtab.table -> 'a Symreltab.table
val zip_symtabs: ('a -> 'b -> 'c) -> 'a Symtab.table -> 'b Symtab.table -> 'c Symtab.table
val cat_options: 'a option list -> 'a list
val partition: ('a -> bool) -> 'a list -> 'a list * 'a list
val unappend: 'a list * 'b -> 'c list -> 'c list * 'c list
val flat_right: ('a * 'b list) list -> ('a * 'b) list
val ===> : term list * term -> term
val ==> : term * term -> term
val sortify: sort -> term -> term
val sortify_typ: sort -> typ -> typ
val typify: term -> term
val typify_typ: typ -> typ
val all_frees: term -> (string * typ) list
val all_frees': term -> string list
val all_tfrees: typ -> (string * sort) list
val pretty_const: Proof.context -> string -> Pretty.T
val ANY: tactic list -> tactic
val ANY': ('a -> tactic) list -> 'a -> tactic
val CONTINUE_WITH: (int -> tactic) * (int -> tactic) list -> int -> thm -> thm Seq.seq
val CONTINUE_WITH_FW: (int -> tactic) * (int -> tactic) list -> int -> thm -> thm Seq.seq
val SOLVED: tactic -> tactic
val TRY': ('a -> tactic) -> 'a -> tactic
val descend_fun_conv: conv -> conv
val lhs_conv: conv -> conv
val rhs_conv: conv -> conv
val rewr_lhs_head_conv: thm -> conv
val rewr_rhs_head_conv: thm -> conv
val conv_result: ('a -> thm) -> 'a -> term
val changed_conv: ('a -> thm) -> 'a -> thm
val maybe_induct_tac: thm list option -> term list list -> term list list -> Proof.context -> tactic
val multi_induct_tac: thm list -> term list list -> term list list -> Proof.context -> tactic
val print_tac': Proof.context -> string -> int -> tactic
val fo_cong_tac: Proof.context -> thm -> int -> tactic
val contract: Proof.context -> thm -> thm
val on_thms_complete: (unit -> 'a) -> thm list -> thm list
val define_params_nosyn: term -> local_theory -> thm * local_theory
val note_thm: binding -> thm -> local_theory -> thm * local_theory
val note_thms: binding -> thm list -> local_theory -> thm list * local_theory
val with_timeout: Time.time -> ('a -> 'a) -> 'a -> 'a
val debug: bool Config.T
val if_debug: Proof.context -> (unit -> unit) -> unit
val ALLGOALS': Proof.context -> (int -> tactic) -> tactic
val prove': Proof.context -> string list -> term list -> term ->
({prems: thm list, context: Proof.context} -> tactic) -> thm
val prove_common': Proof.context -> string list -> term list -> term list ->
({prems: thm list, context: Proof.context} -> tactic) -> thm list
end
structure Dict_Construction_Util : DICT_CONSTRUCTION_UTIL = struct
fun symreltab_of_symtab tab =
Symtab.map (K Symtab.dest) tab |>
Symtab.dest |>
maps (fn (k, kvs) => map (apfst (pair k)) kvs) |>
Symreltab.make
fun split_list3 [] = ([], [], [])
| split_list3 ((x, y, z) :: rest) =
let val (xs, ys, zs) = split_list3 rest in
(x :: xs, y :: ys, z :: zs)
end
fun zip_symtabs f t1 t2 =
let
open Symtab
val ord = fast_string_ord
fun aux acc [] [] = acc
| aux acc ((k1, x) :: xs) ((k2, y) :: ys) =
(case ord (k1, k2) of
EQUAL => aux (update_new (k1, f x y) acc) xs ys
| LESS => raise UNDEF k1
| GREATER => raise UNDEF k2)
| aux _ ((k, _) :: _) [] =
raise UNDEF k
| aux _ [] ((k, _) :: _) =
raise UNDEF k
in aux empty (dest t1) (dest t2) end
fun cat_options [] = []
| cat_options (SOME x :: xs) = x :: cat_options xs
| cat_options (NONE :: xs) = cat_options xs
fun partition f xs = (filter f xs, filter_out f xs)
fun unappend (xs, _) = chop (length xs)
fun flat_right [] = []
| flat_right ((x, ys) :: rest) = map (pair x) ys @ flat_right rest
fun x ==> y = Logic.mk_implies (x, y)
val op ===> = Library.foldr op ==>
fun sortify_typ sort (Type (tyco, args)) = Type (tyco, map (sortify_typ sort) args)
| sortify_typ sort (TFree (name, _)) = TFree (name, sort)
| sortify_typ _ (TVar _) = error "TVar encountered"
fun sortify sort (Const (name, typ)) = Const (name, sortify_typ sort typ)
| sortify sort (Free (name, typ)) = Free (name, sortify_typ sort typ)
| sortify sort (t $ u) = sortify sort t $ sortify sort u
| sortify sort (Abs (name, typ, term)) = Abs (name, sortify_typ sort typ, sortify sort term)
| sortify _ (Bound n) = Bound n
| sortify _ (Var _) = error "Var encountered"
val typify_typ = sortify_typ @{sort type}
val typify = sortify @{sort type}
fun all_frees (Free (name, typ)) = [(name, typ)]
| all_frees (t $ u) = union (op =) (all_frees t) (all_frees u)
| all_frees (Abs (_, _, t)) = all_frees t
| all_frees _ = []
val all_frees' = map fst o all_frees
fun all_tfrees (TFree (name, sort)) = [(name, sort)]
| all_tfrees (Type (_, ts)) = fold (union (op =)) (map all_tfrees ts) []
| all_tfrees _ = []
fun pretty_const ctxt const =
Syntax.pretty_term ctxt (Const (const, Sign.the_const_type (Proof_Context.theory_of ctxt) const))
fun ANY tacs = fold (curry op APPEND) tacs no_tac
fun ANY' tacs n = fold (curry op APPEND) (map (fn t => t n) tacs) no_tac
fun TRY' tac n = TRY (tac n)
fun descend_fun_conv cv =
cv else_conv (fn ct =>
case Thm.term_of ct of
_ $ _ => Conv.fun_conv (descend_fun_conv cv) ct
| _ => Conv.no_conv ct)
fun lhs_conv cv =
cv |> Conv.arg1_conv |> Conv.arg_conv
fun rhs_conv cv =
cv |> Conv.arg_conv |> Conv.arg_conv
fun rewr_lhs_head_conv thm =
safe_mk_meta_eq thm |> Conv.rewr_conv |> descend_fun_conv |> lhs_conv
fun rewr_rhs_head_conv thm =
safe_mk_meta_eq thm |> Conv.rewr_conv |> descend_fun_conv |> rhs_conv
fun conv_result cv ct =
Thm.prop_of (cv ct) |> Logic.dest_equals |> snd
fun changed_conv cv = fn ct =>
let
val res = cv ct
val (lhs, rhs) = Thm.prop_of res |> Logic.dest_equals
in
if lhs aconv rhs then
raise CTERM ("no conversion", [])
else
res
end
fun multi_induct_tac rules insts arbitrary ctxt =
let
val insts' = map (map (SOME o pair NONE o rpair false)) insts
val arbitrary' = map (map dest_Free) arbitrary
in
DETERM (Induct.induct_tac ctxt false insts' arbitrary' [] (SOME rules) [] 1)
end
fun maybe_induct_tac (SOME rules) insts arbitrary = multi_induct_tac rules insts arbitrary
| maybe_induct_tac NONE _ _ = K all_tac
fun (tac CONTINUE_WITH tacs) i st =
st |> (tac i THEN (fn st' =>
let
val n' = Thm.nprems_of st'
val n = Thm.nprems_of st
fun aux [] _ = all_tac
| aux (tac :: tacs) i = tac i THEN aux tacs (i - 1)
in
if n' - n + 1 <> length tacs then
raise THM ("CONTINUE_WITH: unexpected number of emerging subgoals", 0, [st'])
else
aux (rev tacs) (i + n' - n) st'
end))
fun (tac CONTINUE_WITH_FW tacs) i st =
st |> (tac i THEN (fn st' =>
let
val n' = Thm.nprems_of st'
val n = Thm.nprems_of st
fun aux [] _ st = all_tac st
| aux (tac :: tacs) i st = st |>
(tac i THEN (fn st' =>
aux tacs (i + 1 + Thm.nprems_of st' - Thm.nprems_of st) st'))
in
if n' - n + 1 <> length tacs then
raise THM ("unexpected number of emerging subgoals", 0, [st'])
else
aux tacs i st'
end))
fun SOLVED tac = tac THEN ALLGOALS (K no_tac)
fun print_tac' ctxt str = SELECT_GOAL (print_tac ctxt str)
fun fo_cong_tac ctxt thm = SUBGOAL (fn (concl, i) =>
let
val lhs_of = HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst
val concl_pat = lhs_of (Thm.concl_of thm) |> Thm.cterm_of ctxt
val concl = lhs_of concl |> Thm.cterm_of ctxt
val insts = Thm.first_order_match (concl_pat, concl)
in
resolve_tac ctxt [Drule.instantiate_normalize insts thm] i
end handle Pattern.MATCH => no_tac)
fun contract ctxt thm =
let
val (((_, frees), [thm']), ctxt') = Variable.import true [thm] ctxt
val prop = Thm.prop_of thm'
val prems = Logic.strip_imp_prems prop
val (lhs, rhs) =
Logic.strip_imp_concl prop
|> HOLogic.dest_Trueprop
|> HOLogic.dest_eq
fun used x =
exists (exists_subterm (fn t => t = x)) prems
val (f, xs) = strip_comb lhs
val (g, ys) = strip_comb rhs
fun loop [] ys = (0, (f, list_comb (g, rev ys)))
| loop xs [] = (0, (list_comb (f, rev xs), g))
| loop (x :: xs) (y :: ys) =
if x = y andalso is_Free x andalso not (used x) then
loop xs ys |> apfst (fn x => x + 1)
else
(0, (list_comb (f, rev (x :: xs)), list_comb (g, rev (y :: ys))))
val (count, (lhs', rhs')) = loop (rev xs) (rev ys)
val concl' = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'))
fun tac ctxt 0 = resolve_tac ctxt [thm] THEN_ALL_NEW (Method.assm_tac ctxt)
| tac ctxt n = resolve_tac ctxt @{thms ext} THEN' tac ctxt (n - 1)
val prop = prems ===> concl'
in
Goal.prove_future ctxt' [] [] prop (fn {context, ...} => HEADGOAL (tac context count))
|> singleton (Variable.export ctxt' ctxt)
end
fun on_thms_complete f thms =
(Future.fork (fn () => (Thm.consolidate thms; f ())); thms)
fun define_params_nosyn term =
Specification.definition NONE [] [] ((Binding.empty, []), term)
#>> snd #>> snd
fun note_thm binding thm =
Local_Theory.note ((binding, []), [thm]) #>> snd #>> the_single
fun note_thms binding thms =
Local_Theory.note ((binding, []), thms) #>> snd
fun with_timeout time f x =
case Exn.interruptible_capture (Timeout.apply time f) x of
Exn.Res y => y
| Exn.Exn (Timeout.TIMEOUT _) => x
| Exn.Exn e => Exn.reraise e
val debug = Attrib.setup_config_bool @{binding "dict_construction_debug"} (K false)
fun if_debug ctxt f =
if Config.get ctxt debug then f () else ()
fun ALLGOALS' ctxt = if Config.get ctxt debug then ALLGOALS else PARALLEL_ALLGOALS
fun prove' ctxt = if Config.get ctxt debug then Goal.prove ctxt else Goal.prove_future ctxt
fun prove_common' ctxt = Goal.prove_common ctxt (if Config.get ctxt debug then NONE else SOME ~1)
end
File ‹transfer_termination.ML›
signature TRANSFER_TERMINATION = sig
val termination_tac: Function.info -> Function.info -> Proof.context -> int -> tactic
end
structure Transfer_Termination : TRANSFER_TERMINATION = struct
open Dict_Construction_Util
fun termination_tac ({R = new_R, ...}: Function.info) (old_info: Function.info) ctxt =
let
fun fallback_tac warn _ =
(if warn then warning "Falling back to another termination proof" else ();
Seq.empty)
fun expand_term_const defs =
let
val eqs = map ((fn ((x, U), u) => (x, (U, u))) o apfst dest_Const) defs;
val get = fn Const (x, _) => AList.lookup (op =) eqs x | _ => NONE;
in Envir.expand_term get end
fun map_comp_bnf typ =
let
val lthy =
Proof_Context.theory_of ctxt
|> Named_Target.theory_init
|> Config.put BNF_Comp.typedef_threshold ~1
val live_As = all_tfrees typ
fun flatten_tyargs Ass =
live_As
|> filter (fn T => exists (fn Ts => member (op =) Ts T) Ass)
val ((bnf, _), ((_, {map_unfolds, ...}), _)) =
BNF_Comp.bnf_of_typ false BNF_Def.Do_Inline I flatten_tyargs live_As [] typ
((BNF_Comp.empty_comp_cache, BNF_Comp.empty_unfolds), lthy)
val subst = map (Logic.dest_equals o Thm.prop_of) map_unfolds
val t = BNF_Def.map_of_bnf bnf
in
(live_As, expand_term_const subst t)
end
val tac = case old_info of
{R = old_R, totality = SOME totality, ...} =>
let
fun get_eq R =
Inductive.the_inductive ctxt R
|> snd |> #eqs
|> the_single
|> Local_Defs.abs_def_rule ctxt
val (old_R_eq, new_R_eq) = apply2 get_eq (old_R, new_R)
fun get_typ R =
fastype_of R
|> strip_type
|> fst |> hd
|> Type.legacy_freeze_type
val (old_R_typ, new_R_typ) = apply2 get_typ (old_R, new_R)
val simple_tac =
let
val totality' = Local_Defs.unfold ctxt [old_R_eq] totality
in
Local_Defs.unfold_tac ctxt [new_R_eq] THEN
HEADGOAL (SOLVED' (resolve_tac ctxt [totality']))
end
val smart_tac = Exn.interruptible_capture (fn st =>
let
val old_R_stripped =
Thm.prop_of old_R_eq
|> Logic.dest_equals |> snd
|> map_types (K dummyT)
|> Syntax.check_term ctxt
val futile =
old_R_stripped |> exists_type (exists_subtype
(fn TFree (_, sort) => sort <> @{sort type}
| TVar (_, sort) => sort <> @{sort type}
| _ => false))
fun costrip_prodT new_t old_t =
if Type.could_match (old_t, new_t) then
(0, new_t)
else
case costrip_prodT (snd (HOLogic.dest_prodT new_t)) old_t of
(n, stripped_t) => (n + 1, stripped_t)
fun construct_inner_proj new_t old_t =
let
val (diff, stripped_t) = costrip_prodT new_t old_t
val (tfrees, f_head) = map_comp_bnf stripped_t
val f_args = map (K (Abs ("x", dummyT, Const (@{const_name undefined}, dummyT)))) tfrees
fun add_snd 0 = list_comb (map_types (K dummyT) f_head, f_args)
| add_snd n = Const (@{const_name comp}, dummyT) $ add_snd (n - 1) $ Const (@{const_name snd}, dummyT)
in
add_snd diff
end
fun construct_outer_proj new_t old_t = case (new_t, old_t) of
(Type (@{type_name sum}, new_ts), Type (@{type_name sum}, old_ts)) =>
let
val ps = map2 construct_outer_proj new_ts old_ts
in list_comb (Const (@{const_name map_sum}, dummyT), ps) end
| _ => construct_inner_proj new_t old_t
val outer_proj = construct_outer_proj new_R_typ old_R_typ
val old_R_typ_imported =
yield_singleton Variable.importT_terms old_R ctxt
|> fst |> get_typ
val c =
outer_proj
|> Type.constraint (new_R_typ --> old_R_typ_imported)
|> Syntax.check_term ctxt
|> Thm.cterm_of ctxt
val wf_simulate =
Drule.infer_instantiate ctxt [(("g", 0), c)] @{thm wf_simulate_simple}
val old_wf = (@{thm wfP_implies_wf_set_of} OF [@{thm accp_wfPI} OF [totality]])
val inner_tac =
match_tac ctxt @{thms wf_implies_dom} THEN'
match_tac ctxt [wf_simulate] CONTINUE_WITH_FW
[resolve_tac ctxt [old_wf],
match_tac ctxt @{thms set_ofI} THEN'
dmatch_tac ctxt @{thms set_ofD} THEN'
SELECT_GOAL (Local_Defs.unfold_tac ctxt [old_R_eq, new_R_eq]) THEN'
TRY'
(REPEAT_ALL_NEW (ematch_tac ctxt @{thms conjE exE}) THEN'
hyp_subst_tac_thin true ctxt THEN'
REPEAT_ALL_NEW (match_tac ctxt @{thms conjI exI}))]
val unfold_tac =
Local_Defs.unfold_tac ctxt @{thms comp_apply id_apply prod.sel} THEN
auto_tac ctxt
val tac = SOLVED (HEADGOAL inner_tac THEN unfold_tac)
in
if futile then
(warning "Termination relation has sort constraints; termination proof is unlikely to be automatic or may even be impossible";
Seq.empty)
else
(tracing "Trying to re-use termination proof";
tac st)
end)
#> Exn.get_res
#> the_default Seq.empty
in
simple_tac ORELSE smart_tac ORELSE fallback_tac true
end
| _ => fallback_tac false
in SELECT_GOAL tac end
end
File ‹congruences.ML›
signature CONGRUENCES = sig
type rule =
{rule: thm,
concl: term,
prems: term list,
proper: bool}
type ctx = (string * typ) list * term list
datatype ctx_tree =
Tree of (term * (rule * (ctx * ctx_tree) list) option)
val export_term_ctx: ctx -> term -> term
val import_rule: Proof.context -> thm -> rule
val import_term: Proof.context -> rule list -> term -> ctx_tree
val fold_tree:
(term -> 'a) ->
(term -> rule -> (ctx * 'a) list -> 'a) ->
ctx_tree -> 'a
end
structure Congruences: CONGRUENCES = struct
type rule =
{rule: thm,
concl: term,
prems: term list,
proper: bool}
type ctx = (string * typ) list * term list
fun export_term_ctx (fixes, assumes) =
fold_rev (curry Logic.mk_implies) assumes
#> fold_rev (Logic.all o Free) fixes
datatype ctx_tree =
Tree of (term * (rule * (ctx * ctx_tree) list) option)
fun fold_tree atom cong t =
let
fun go (Tree (t, NONE)) = atom t
| go (Tree (t, SOME (r, ctxs))) = cong t r (map (apsnd go) ctxs)
in go t end
fun raw_import_rule {check: bool, proper: bool} ctxt thm =
let
val concl = Thm.concl_of thm
val (lhs, rhs) = Logic.dest_equals concl
val prems = Thm.prems_of thm
val rule = {rule = thm, concl = concl, prems = prems, proper = proper}
in
if check then
let
val ((f_lhs, _), _) = strip_comb lhs |>> dest_Const
val ((r_lhs, _), _) = strip_comb rhs |>> dest_Const
in
if f_lhs <> r_lhs then
error ("invalid cong rule " ^ Syntax.string_of_term ctxt (Thm.prop_of thm))
else
rule
end
else
rule
end
val import_rule = raw_import_rule {check = true, proper = true}
fun mk_cong n = if n <= 1 then @{thm cong} else mk_cong (n - 1) OF [@{thm cong}]
fun cong_rule n =
raw_import_rule {check = false, proper = false} @{context} (mk_cong n RS @{thm eq_reflection})
val ext_rule =
raw_import_rule {check = false, proper = false} @{context} (@{thm ext} RS @{thm eq_reflection})
fun import_term ctxt rules t =
let
val thy = Proof_Context.theory_of ctxt
val lhs_of = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop
fun go ctxt t =
let
val arity = length (snd (strip_comb t))
val rules = rules @ [cong_rule arity, ext_rule]
fun mk_branch subst t =
let
val ((params, impl), ctxt') = Variable.focus NONE t ctxt
val (assms, concl) = Logic.strip_horn impl
val assms = map (Envir.subst_term subst) assms
in
((map #2 params, assms), (ctxt', lhs_of concl))
end
fun match concl =
try (Pattern.match thy (concl, Logic.mk_equals (t, t))) (Vartab.empty, Vartab.empty)
fun apply_rule (rule as {concl, prems, ...}) =
case match concl of
NONE => NONE
| SOME subst =>
SOME (rule, map (mk_branch subst o Envir.beta_norm o Envir.subst_term subst) prems)
val is_atom = is_Const t orelse is_Free t
in
if is_atom then
Tree (t, NONE)
else
case get_first apply_rule rules of
NONE => error "this shouldn't happen"
| SOME (rule, branches) => Tree (t, SOME (rule, map (apsnd (uncurry go)) branches))
end
in
go ctxt t
end
end
File ‹side_conditions.ML›
signature SIDE_CONDITIONS = sig
type predicate =
{f: term,
index: int,
inductive: Inductive.result,
alt: thm option}
val transform_predicate: morphism -> predicate -> predicate
val get_predicate: Proof.context -> term -> predicate option
val set_alt: term -> thm -> Context.generic -> Context.generic
val is_total: Proof.context -> term -> bool
val mk_side: thm list -> thm list option -> local_theory -> predicate list * local_theory
val time_limit: real Config.T
end
structure Side_Conditions : SIDE_CONDITIONS = struct
open Dict_Construction_Util
val time_limit = Attrib.setup_config_real @{binding side_conditions_time_limit} (K 5.0)
val inductive_config =
{quiet_mode = true, verbose = true, alt_name = Binding.empty, coind = false,
no_elim = false, no_ind = false, skip_mono = false}
type predicate =
{f: term,
index: int,
inductive: Inductive.result,
alt: thm option}
fun transform_predicate phi {f, index, inductive, alt} =
{f = Morphism.term phi f,
index = index,
inductive = Inductive.transform_result phi inductive,
alt = Option.map (Morphism.thm phi) alt}
structure Predicates = Generic_Data
(
type T = predicate Item_Net.T
val empty = Item_Net.init (op aconv o apply2 #f) (single o #f)
val merge = Item_Net.merge
val extend = I
)
fun get_predicate ctxt t =
Item_Net.retrieve (Predicates.get (Context.Proof ctxt)) t
|> try hd
|> Option.map (transform_predicate (Morphism.transfer_morphism (Proof_Context.theory_of ctxt)))
fun is_total ctxt t =
let
val SOME {alt = SOME alt, ...} = get_predicate ctxt t
val (_, rhs) = Logic.dest_equals (Thm.prop_of alt)
in rhs = @{term True} end
fun set_alt t thm context =
let
val thm = safe_mk_meta_eq thm
val (lhs, _) = Logic.dest_equals (Thm.prop_of thm)
val {f, index, inductive, ...} = hd (Item_Net.retrieve (Predicates.get context) t)
val pred = nth (#preds inductive) index
val (arg_typs, _) = strip_type (fastype_of pred)
val args =
Name.invent_names (Variable.names_of (Context.proof_of context)) "x" arg_typs
|> map Free
val new_pred = {f = f, index = index, inductive = inductive, alt = SOME thm}
in
if Pattern.matches (Context.theory_of context) (lhs, list_comb (pred, args)) then
Predicates.map (Item_Net.update new_pred) context
else
error "Alternative is not fully general"
end
fun apply_simps ctxt clear thms t =
let
val ctxt' =
Context_Position.not_really ctxt
|> clear ? put_simpset HOL_ss
in conv_result (Simplifier.asm_full_rewrite (ctxt' addsimps thms)) t end
fun apply_alts ctxt =
Item_Net.content (Predicates.get (Context.Proof ctxt))
|> map #alt
|> cat_options
|> apply_simps ctxt true
fun apply_intros ctxt =
Item_Net.content (Predicates.get (Context.Proof ctxt))
|> map #inductive
|> maps #intrs
|> apply_simps ctxt false
fun dest_head (Free (name, typ)) = (name, typ)
| dest_head (Const (name, typ)) = (Long_Name.base_name name, typ)
val sideN = "_side"
fun mk_side simps inducts lthy =
let
val thy = Proof_Context.theory_of lthy
val ((_, simps), names) =
Variable.import true simps lthy
||> Variable.names_of
val (lhss, rhss) =
map (HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) simps
|> split_list
val heads = map (`dest_head o (fst o strip_comb)) lhss
fun mk_typ t = binder_types t ---> @{typ bool}
val sides = map (apfst (suffix sideN) o apsnd mk_typ o fst) heads
fun mk_pred_app pred (f, xs) =
let
val pred_typs = binder_types (fastype_of pred)
val exp_param_count = length pred_typs
val f_typs = take exp_param_count (binder_types (fastype_of f))
val pred' =
Envir.subst_term_types (fold (Sign.typ_match thy) (pred_typs ~~ f_typs) Vartab.empty) pred
val diff = exp_param_count - length xs
in
if diff > 0 then
let
val bounds = map Bound (0 upto diff - 1)
val alls = map (K ("x", dummyT)) (0 upto diff - 1)
val prop = Logic.list_all (alls, HOLogic.mk_Trueprop (list_comb (pred', xs @ bounds)))
in
prop
end
else
HOLogic.mk_Trueprop (list_comb (pred', take exp_param_count xs))
end
fun mk_cond f xs =
if is_Abs f then
NONE
else
case get_predicate lthy f of
NONE =>
(case find_index (equal f o snd) heads of
~1 => NONE
| index => SOME (mk_pred_app (Free (nth sides index)) (f, xs)))
| SOME {index, inductive = {preds, ...}, ...} =>
SOME (mk_pred_app (nth preds index) (f, xs))
fun mk_atom f =
the_list (mk_cond f [])
fun mk_cong t _ cs =
let
val cs' = maps (fn (ctx, ts) => map (Congruences.export_term_ctx ctx) ts) (tl cs)
val (f, xs) = strip_comb t
val cs = mk_cond f xs
in
the_list cs @ cs'
end
val rules = map (Congruences.import_rule lthy) (Function.get_congs lthy)
val premss =
map (Congruences.import_term lthy rules) rhss
|> map (Congruences.fold_tree mk_atom mk_cong)
val concls =
map Free sides ~~ map (snd o strip_comb) lhss
|> map (HOLogic.mk_Trueprop o list_comb)
val time = Time.fromReal (Config.get lthy time_limit)
val intros =
map Logic.list_implies (premss ~~ concls)
|> Syntax.check_terms lthy
|> map (apply_alts lthy o Thm.cterm_of lthy)
|> Par_List.map (with_timeout time (apply_intros lthy o Thm.cterm_of lthy))
val inds = map (rpair NoSyn o apfst Binding.name) (distinct op = sides)
val (result, lthy') =
Inductive.add_inductive inductive_config inds []
(map (pair (Binding.empty, [])) intros) [] lthy
fun mk_impartial_goal pred names =
let
val param_typs = binder_types (fastype_of pred)
val (args, names) = fold_map (fn typ => apfst (Free o rpair typ) o Name.variant "x") param_typs names
val goal = HOLogic.mk_Trueprop (list_comb (pred, args))
in ((goal, args), names) end
val ((props, instss), _) =
fold_map mk_impartial_goal (#preds result) names
|>> split_list
val frees = flat instss |> map (fst o dest_Free)
fun tactic {context = ctxt, ...} =
let
val simp_context =
put_simpset HOL_ss (Context_Position.not_really ctxt) addsimps (#intrs result)
in
maybe_induct_tac inducts instss [] ctxt THEN
PARALLEL_ALLGOALS (Nitpick_Util.DETERM_TIMEOUT time o asm_full_simp_tac simp_context)
end
val alts =
try (Goal.prove_common lthy' NONE frees [] props) tactic
|> Option.map (map (mk_eq o Thm.close_derivation ⌂))
val _ =
if is_none alts then
Pretty.str "Potentially underspecified function(s): " ::
Pretty.commas (map (Syntax.pretty_term lthy o snd) (distinct op = heads))
|> Pretty.block
|> Pretty.string_of
|> warning
else
()
fun mk_pred n t =
{f = t, index = n, inductive = result,
alt = Option.map (fn alts => nth alts n) alts}
val preds = map_index (fn (n, (_, t)) => mk_pred n t) (distinct op = heads)
val lthy'' =
Local_Theory.declaration {pervasive = false, syntax = false}
(fn phi => fold (Predicates.map o Item_Net.update o transform_predicate phi) preds) lthy'
in
(preds, lthy'')
end
end
File ‹class_graph.ML›
signature CLASS_GRAPH =
sig
type selector = typ -> term
type node =
{class: string,
qname: string,
selectors: selector Symtab.table,
make: typ -> term,
data_thms: thm list,
cert: typ -> term,
cert_thms: thm * thm * thm list}
val dict_typ: node -> typ -> typ
type edge =
{super_selector: selector,
subclass: thm}
type path = edge list
type ev
val class_of: ev -> class
val node_of: ev -> node
val parents_of: ev -> (edge * ev) Symtab.table
val find_path': ev -> (ev -> 'a option) -> (path * 'a) option
val find_path: ev -> class -> path option
val fold_path: path -> typ -> term -> term
val ensure_class: class -> local_theory -> (ev * local_theory)
val edges: local_theory -> class -> edge Symtab.table option
val node: local_theory -> class -> node option
val all_edges: local_theory -> edge Symreltab.table
val all_nodes: local_theory -> node Symtab.table
val pretty_ev: Proof.context -> ev -> Pretty.T
val mangle: string -> string
val param_sorts: string -> class -> theory -> class list list
val super_classes: class -> theory -> string list
end
structure Class_Graph: CLASS_GRAPH =
struct
open Dict_Construction_Util
val mangle =
translate_string (fn x =>
if x = "." then
"_"
else if x = "_" then
"__"
else
x)
fun param_sorts tyco class thy =
let val algebra = Sign.classes_of thy in
Sorts.mg_domain algebra tyco [class] |> map (filter (Class.is_class thy))
end
fun super_classes class thy =
let val algebra = Sign.classes_of thy in
Sorts.super_classes algebra class |>
Sorts.minimize_sort algebra |>
filter (Class.is_class thy) |>
sort fast_string_ord
end
type selector = typ -> term
type node =
{class: string,
qname: string,
selectors: selector Symtab.table,
make: typ -> term,
data_thms: thm list,
cert: typ -> term,
cert_thms: thm * thm * thm list}
type edge =
{super_selector: selector,
subclass: thm}
type path = edge list
abstype ev = Evidence of class * node * (edge * ev) Symtab.table
with
fun class_of (Evidence (class, _, _)) = class
fun node_of (Evidence (_, node, _)) = node
fun parents_of (Evidence (_, _, tab)) = tab
fun mk_evidence class node tab = Evidence (class, node, tab)
fun find_path' ev is_goal =
case is_goal ev of
SOME a =>
SOME ([], a)
| NONE =>
let
fun f (_, (edge, ev)) = Option.map (apfst (cons edge)) (find_path' ev is_goal)
in Symtab.get_first f (parents_of ev) end
fun find_path ev goal =
find_path' ev (fn ev => if class_of ev = goal then SOME () else NONE) |> Option.map fst
fun pretty_ev ctxt (Evidence (class, {qname, ...}, tab)) =
let
val typ = @{typ 'a}
fun mk_super ({super_selector, ...}, super_ev) = Pretty.block
[Pretty.str "selector:",
Pretty.brk 1,
Syntax.pretty_term ctxt (super_selector typ),
Pretty.fbrk,
pretty_ev ctxt super_ev]
val supers = Symtab.dest tab
|> map (fn (_, super) => mk_super super)
|> Pretty.big_list "super classes"
in
Pretty.block
[Pretty.str "Evidence for ",
Syntax.pretty_sort ctxt [class],
Pretty.str ": ",
Syntax.pretty_typ ctxt (Type (qname, [typ])),
Pretty.str (" (qname = " ^ qname ^ ")"),
Pretty.fbrk,
supers]
end
end
structure Classes = Generic_Data
(
type T = (edge Symtab.table * node) Symtab.table
val empty = Symtab.empty
fun merge (t1, t2) =
if Symtab.is_empty t1 andalso Symtab.is_empty t2 then
Symtab.empty
else
error "merging not supported"
val extend = I
)
fun node lthy class =
Symtab.lookup (Classes.get (Context.Proof lthy)) class |> Option.map snd
fun edges lthy class =
Symtab.lookup (Classes.get (Context.Proof lthy)) class |> Option.map fst
val all_nodes =
Context.Proof #> Classes.get #> Symtab.map (K snd)
val all_edges =
Context.Proof #> Classes.get #> Symtab.map (K fst) #> symreltab_of_symtab
fun dict_typ {qname, ...} typ =
Type (qname, [typ])
fun fold_path path typ =
fold (fn {super_selector = s, ...} => fn acc => s typ $ acc) path
fun mk_super_selector' qualified qname super_ev typ =
let
val {class = super_class, qname = super_qname, ...} = node_of super_ev
val raw_name = mangle super_class ^ "__super"
val name = if qualified then Long_Name.append qname raw_name else raw_name
in (name, Type (qname, [typ]) --> Type (super_qname, [typ])) end
fun mk_node class info super_evs lthy =
let
fun print_info ctxt =
Pretty.block [Pretty.str "Defining record for class ", Syntax.pretty_sort ctxt [class]]
|> Pretty.writeln
val name = mangle class ^ "__dict"
val qname = Local_Theory.full_name lthy (Binding.name name)
val tvar = @{typ 'a}
val typ = Type (qname, [tvar])
fun mk_field name ftyp = (Binding.name name, ftyp)
val params = #params info
|> map (fn (name', ftyp) =>
let
val ftyp' = typ_subst_atomic [(TFree ("'a", [class]), @{typ 'a})] ftyp
val field_name = mangle name' ^ "__field"
val field = mk_field field_name ftyp'
fun sel tvar' =
Const (Long_Name.append qname field_name,
typ_subst_atomic [(tvar, tvar')] (typ --> ftyp'))
in (field, (name', sel)) end)
val (fields, selectors) = split_list params
val super_params = Symtab.dest super_evs |>
map (fn (_, super_ev) =>
let
val {cert = raw_super_cert, qname = super_qname, ...} = node_of super_ev
val (field_name, _) = mk_super_selector' false qname super_ev tvar
val field = mk_field field_name (Type (super_qname, [tvar]))
fun sel typ = Const (mk_super_selector' true qname super_ev typ)
fun super_cert dict = raw_super_cert tvar $ (sel tvar $ dict)
val raw_edge = (class_of super_ev, sel)
in (field, raw_edge, super_cert) end)
val (super_fields, raw_edges, super_certs) = split_list3 super_params
val all_fields = super_fields @ fields
fun make typ' =
Const (Long_Name.append qname "Dict",
typ_subst_atomic [(tvar, typ')] (map #2 all_fields ---> typ))
val cert_name = name ^ "__cert"
val cert_binding = Binding.name cert_name
val cert_body =
let
fun local_param_eq ((_, typ), (name, sel)) dict =
HOLogic.mk_eq (sel tvar $ dict, Const (name, typ))
in
map local_param_eq params @ super_certs
end
val cert_var_name = "dict"
val cert_term =
Abs (cert_var_name, typ,
List.foldr HOLogic.mk_conj @{term True} (map (fn x => x (Bound 0)) cert_body))
fun prove_thms (cert, cert_def) lthy =
let
val var = Free (cert_var_name, typ)
fun tac ctxt = Local_Defs.unfold_tac ctxt [cert_def] THEN blast_tac ctxt 1
fun prove prop =
Goal.prove_future lthy [cert_var_name] [] prop (fn {context, ...} => tac context)
fun mk_dest_props raw_prop =
HOLogic.mk_Trueprop (cert $ var) ==> HOLogic.mk_Trueprop (raw_prop var)
fun mk_intro_cond raw_prop =
HOLogic.mk_Trueprop (raw_prop var)
val dests =
map (fn raw_prop => prove (mk_dest_props raw_prop)) cert_body
val intro =
prove (map mk_intro_cond cert_body ===> HOLogic.mk_Trueprop (cert $ var))
val (dests', (intro', lthy')) =
note_thms Binding.empty dests lthy ||> note_thm Binding.empty intro
val (param_dests, super_dests) = chop (length params) dests'
fun pre_edges phi =
let
fun mk_edge thm (sc, sel) =
(sc, {super_selector = sel, subclass = Morphism.thm phi thm})
in Symtab.make (map2 mk_edge super_dests raw_edges) end
in
((param_dests, pre_edges, intro'), lthy')
end
val constructor =
(((Binding.empty, Binding.name "Dict"), all_fields), NoSyn)
val datatyp =
(([(NONE, (@{typ 'a}, @{sort type}))], Binding.name name), NoSyn)
val dtspec =
(Ctr_Sugar.default_ctr_options,
[(((datatyp, [constructor]), (Binding.empty, Binding.empty, Binding.empty)), [])])
val (((raw_cert, raw_cert_def), (param_dests, pre_edges, intro)), (lthy', lthy)) = lthy
|> tap print_info
|> BNF_FP_Def_Sugar.co_datatypes BNF_Util.Least_FP BNF_LFP.construct_lfp dtspec
|> (snd o Local_Theory.begin_nested)
|> Local_Theory.define ((cert_binding, NoSyn), ((Thm.def_binding cert_binding, []), cert_term))
|>> apsnd snd
|> (fn (raw_cert, lthy) => prove_thms raw_cert lthy |>> pair raw_cert)
||> `Local_Theory.end_nested
val phi = Proof_Context.export_morphism lthy lthy'
fun cert typ = subst_TVars [(("'a", 0), typ)] (Morphism.term phi raw_cert)
val cert_def = Morphism.thm phi raw_cert_def
val edges = pre_edges phi
val param_dests' = map (Morphism.thm phi) param_dests
val intro' = Morphism.thm phi intro
val data_thms =
BNF_FP_Def_Sugar.fp_sugar_of lthy' qname
|> the |> #fp_ctr_sugar |> #ctr_sugar |> #sel_thmss |> flat
|> map safe_mk_meta_eq
val node =
{class = class,
qname = qname,
selectors = Symtab.make selectors,
make = make,
data_thms = data_thms,
cert = cert,
cert_thms = (cert_def, intro', param_dests')}
in (node, edges, lthy') end
fun ensure_class class lthy =
if not (Class.is_class (Proof_Context.theory_of lthy) class) then
error ("not a proper class: " ^ class)
else
let
val thy = Proof_Context.theory_of lthy
val super_classes = super_classes class thy
fun collect_super mk_node =
let
val (super_evs, lthy') = fold_map ensure_class super_classes lthy
val raw_tab = Symtab.make (super_classes ~~ super_evs)
val (node, edges, lthy'') = mk_node raw_tab lthy'
val tab = zip_symtabs pair edges raw_tab
val ev = mk_evidence class node tab
in (ev, edges, lthy'') end
in
case Symtab.lookup (Classes.get (Context.Proof lthy)) class of
SOME (edge_tab, node) =>
if super_classes = Symtab.keys edge_tab then
let val (ev, _, lthy') = collect_super (fn _ => fn lthy => (node, edge_tab, lthy)) in
(ev, lthy')
end
else
error "class with different super classes"
| NONE =>
let
val ax_info = Axclass.get_info thy class
val (ev, edges, lthy') = collect_super (mk_node class ax_info)
val upd = Symtab.update_new (class, (edges, node_of ev))
in
(ev, Local_Theory.declaration {pervasive = false, syntax = false} (K (Classes.map upd)) lthy')
end
end
end
File ‹dict_construction.ML›
signature DICT_CONSTRUCTION =
sig
datatype cert_proof = Cert | Skip
type const
type 'a sccs = (string * 'a) list list
val annotate_code_eqs: local_theory -> string list -> (const sccs * local_theory)
val new_names: local_theory -> const sccs -> (string * const) sccs
val symtab_of_sccs: 'a sccs -> 'a Symtab.table
val axclass: class -> local_theory -> Class_Graph.node * local_theory
val instance: (string * const) Symtab.table -> string -> class -> local_theory -> term * local_theory
val term: term Symreltab.table -> (string * const) Symtab.table -> term -> local_theory -> (term * local_theory)
val consts: (string * const) Symtab.table -> cert_proof -> (string * const) list -> local_theory -> local_theory
type const_info =
{fun_info: Function.info option,
inducts: thm list option,
base_thms: thm list,
base_certs: thm list,
simps: thm list,
code_thms: thm list,
congs: thm list option}
type fun_target = (string * class) list * (term * term)
type dict_thms =
{base_thms: thm list,
def_thm: thm}
type dict_target = (string * class) list * (term * string * class)
val prove_fun_cert: fun_target list -> const_info -> cert_proof -> local_theory -> thm list
val prove_dict_cert: dict_target -> dict_thms -> local_theory -> thm
val the_info: Proof.context -> string -> const_info
val normalizer_conv: Proof.context -> conv
val cong_of_const: Proof.context -> string -> thm option
val get_code_eqs: Proof.context -> string -> thm list
val group_code_eqs: Proof.context -> string list ->
(string * (((string * sort) list * typ) * ((term list * term) * thm option) list)) list list
end
structure Dict_Construction: DICT_CONSTRUCTION =
struct
open Class_Graph
open Dict_Construction_Util
val (_, make_thm_cterm) =
Context.>>>
(Context.map_theory_result (Thm.add_oracle (@{binding cert_oracle}, I)))
fun make_thm ctxt prop = make_thm_cterm (Thm.cterm_of ctxt prop)
fun cheat_tac ctxt i st =
resolve_tac ctxt [make_thm ctxt (Var (("A", 0), propT))] i st
val normalizer_conv = Axclass.overload_conv
fun cong_of_const ctxt name =
let
val head =
Thm.concl_of
#> Logic.dest_equals #> fst
#> strip_comb #> fst
#> dest_Const #> fst
fun applicable thm =
try head thm = SOME name
in
Function_Context_Tree.get_function_congs ctxt
|> filter applicable
|> try hd
end
fun group_code_eqs ctxt consts =
let
val thy = Proof_Context.theory_of ctxt
val graph = #eqngr (Code_Preproc.obtain true { ctxt = ctxt, consts = consts, terms = [] })
fun mk_eqs name = name
|> Code_Preproc.cert graph
|> Code.equations_of_cert thy ||> these
||> map (apsnd fst o apfst (apsnd fst o apfst (map fst)))
|> pair name
in
map (map mk_eqs) (rev (Graph.strong_conn graph))
end
fun get_code_eqs ctxt const =
AList.lookup op = (flat (group_code_eqs ctxt [const])) const
|> the |> snd
|> map snd
|> cat_options
|> map (Conv.fconv_rule (normalizer_conv ctxt))
datatype cert_proof = Cert | Skip
type const_info =
{fun_info: Function.info option,
inducts: thm list option,
base_thms: thm list,
base_certs: thm list,
simps: thm list,
code_thms: thm list,
congs: thm list option}
fun map_const_info f1 f2 f3 f4 f5 f6 f7 {fun_info, inducts, base_thms, base_certs, simps, code_thms, congs} =
{fun_info = f1 fun_info,
inducts = f2 inducts,
base_thms = f3 base_thms,
base_certs = f4 base_certs,
simps = f5 simps,
code_thms = f6 code_thms,
congs = f7 congs}
fun morph_const_info phi =
map_const_info
(Option.map (Function_Common.transform_function_data phi))
(Option.map (map (Morphism.thm phi)))
(map (Morphism.thm phi))
(map (Morphism.thm phi))
(map (Morphism.thm phi))
I
(Option.map (map (Morphism.thm phi)))
type fun_target = (string * class) list * (term * term)
type dict_thms =
{base_thms: thm list,
def_thm: thm}
type dict_target = (string * class) list * (term * string * class)
fun fun_cert_tac base_thms base_certs simps code_thms =
SOLVED' o Subgoal.FOCUS (fn {prems, context = ctxt, concl, ...} =>
let
val _ =
if_debug ctxt (fn () =>
tracing ("Proving " ^ Syntax.string_of_term ctxt (Thm.term_of concl)))
fun is_ih prem =
Thm.prop_of prem |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> can HOLogic.dest_eq
val (ihs, certs) = partition is_ih prems
val super_certs = all_edges ctxt |> Symreltab.dest |> map (#subclass o snd)
val param_dests = all_nodes ctxt |> Symtab.dest |> maps (#3 o #cert_thms o snd)
val congs = Function_Context_Tree.get_function_congs ctxt @ map safe_mk_meta_eq @{thms cong}
val simp_context = (clear_simpset ctxt) addsimps (certs @ super_certs @ base_certs @ base_thms @ param_dests)
addloop ("overload", CONVERSION o changed_conv o Axclass.overload_conv)
val ihs = map (Simplifier.asm_full_simplify simp_context) ihs
val ih_tac =
resolve_tac ctxt ihs THEN_ALL_NEW
(TRY' (SOLVED' (Simplifier.asm_full_simp_tac simp_context)))
val unfold_new =
ANY' (map (CONVERSION o rewr_lhs_head_conv) simps)
val normalize =
CONVERSION (normalizer_conv ctxt)
val unfold_old =
ANY' (map (CONVERSION o rewr_rhs_head_conv) code_thms)
val simp =
CONVERSION (lhs_conv (Simplifier.asm_full_rewrite simp_context))
fun walk_congs i = i |>
((resolve_tac ctxt @{thms refl} ORELSE'
SOLVED' (Simplifier.asm_full_simp_tac simp_context) ORELSE'
ih_tac ORELSE'
Method.assm_tac ctxt ORELSE'
(resolve_tac ctxt @{thms meta_eq_to_obj_eq} THEN'
fo_resolve_tac congs ctxt)) THEN_ALL_NEW walk_congs)
val tacs =
[unfold_new, normalize, unfold_old, simp, walk_congs]
in
EVERY' tacs 1
end)
fun dict_cert_tac class def_thm base_thms =
SOLVED' o Subgoal.FOCUS (fn {prems, context = ctxt, ...} =>
let
val (intro, sels) = case node ctxt class of
SOME {cert_thms = (_, intro, _), data_thms = sels, ...} => (intro, sels)
| NONE => error ("class " ^ class ^ " is not defined")
val apply_intro =
resolve_tac ctxt [intro]
val unfold_dict =
CONVERSION (Conv.rewr_conv def_thm |> Conv.arg_conv |> lhs_conv)
val normalize =
CONVERSION (normalizer_conv ctxt)
val smash_sels =
CONVERSION (lhs_conv (Conv.rewrs_conv sels))
val solve =
resolve_tac ctxt (@{thm HOL.refl} :: base_thms)
val finally =
resolve_tac ctxt prems
val tacs =
[apply_intro, unfold_dict, normalize, smash_sels, solve, finally]
in
EVERY (map (ALLGOALS' ctxt) tacs)
end)
fun prepare_dicts classes names lthy =
let
val sorts = Symtab.make_list classes
fun mk_dicts (param_name, (tvar, class)) =
case node lthy class of
NONE =>
error ("unknown class " ^ class)
| SOME {cert, qname, ...} =>
let
val sort = the (Symtab.lookup sorts tvar)
val param = Free (param_name, Type (qname, [TFree (tvar, sort)]))
in
(param, HOLogic.mk_Trueprop (cert dummyT $ param))
end
val dict_names = Name.invent_names names "a" classes
val names = fold Name.declare (map fst dict_names) names
val (dict_params, prems) = split_list (map mk_dicts dict_names)
in (dict_params, prems, names) end
fun prepare_fun_goal targets lthy =
let
fun mk_eq (classes, (lhs, rhs)) names =
let
val (lhs_name, _) = dest_Const lhs
val (rhs_name, rhs_typ) = dest_Const rhs
val (dict_params, prems, names) = prepare_dicts classes names lthy
val param_names = fst (strip_type rhs_typ) |> map (K dummyT) |> Name.invent_names names "a"
val names = fold Name.declare (map fst param_names) names
val params = map Free param_names
val lhs = list_comb (Const (lhs_name, dummyT), dict_params @ params)
val rhs = list_comb (Const (rhs_name, dummyT), params)
val eq = Const (@{const_name HOL.eq}, dummyT) $ lhs $ rhs
val all_params = dict_params @ params
val eq :: rest = Syntax.check_terms lthy (eq :: prems @ all_params)
val (prems, all_params) = unappend (prems, all_params) rest
val eq =
if is_some (Axclass.inst_of_param (Proof_Context.theory_of lthy) rhs_name) then
Thm.cterm_of lthy eq |> conv_result (Conv.arg_conv (normalizer_conv lthy))
else
eq
val prop = prems ===> HOLogic.mk_Trueprop eq
in ((all_params, prop), names) end
in
fold_map mk_eq targets Name.context
|> fst
|> split_list
end
fun prepare_dict_goal (classes, (term, _, class)) lthy =
let
val cert = case node lthy class of
NONE =>
error ("unknown class " ^ class)
| SOME {cert, ...} =>
cert dummyT
val names = Name.context
val (dict_params, prems, _) = prepare_dicts classes names lthy
val (term_name, _) = dest_Const term
val dict = list_comb (Const (term_name, dummyT), dict_params)
val prop = prems ===> HOLogic.mk_Trueprop (cert $ dict)
val prop :: dict_params = Syntax.check_terms lthy (prop :: dict_params)
in
(dict_params, prop)
end
fun prove_fun_cert targets {inducts, base_thms, base_certs, simps, code_thms, ...} proof lthy =
let
val (argss, props) = prepare_fun_goal targets lthy
val frees = flat argss |> map (fst o dest_Free)
val tac =
case proof of
Cert => fun_cert_tac base_thms base_certs simps code_thms
| Skip => cheat_tac
val long_thms =
prove_common' lthy frees [] props (fn {context, ...} =>
maybe_induct_tac inducts argss [] context THEN
ALLGOALS' context (tac context))
in
map (contract lthy) long_thms
end
fun prove_dict_cert target {base_thms, def_thm} lthy =
let
val (args, prop) = prepare_dict_goal target lthy
val frees = map (fst o dest_Free) args
val (_, (_, _, class)) = target
in
prove' lthy frees [] prop (fn {context, ...} =>
dict_cert_tac class def_thm base_thms context 1)
end
type definitions =
{instantiations: (term * thm) Symreltab.table,
constants: (string * (thm option * const_info)) Symtab.table }
structure Definitions = Generic_Data
(
type T = definitions
val empty = {instantiations = Symreltab.empty, constants = Symtab.empty}
val extend = I
fun merge ({instantiations = i1, constants = c1}, {instantiations = i2, constants = c2}) =
if Symreltab.is_empty i1 andalso Symtab.is_empty c1 andalso
Symreltab.is_empty i2 andalso Symtab.is_empty c2 then
empty
else
error "merging not supported"
)
fun map_definitions map_insts map_consts =
Definitions.map (fn {instantiations, constants} =>
{instantiations = map_insts instantiations,
constants = map_consts constants})
fun the_info ctxt name =
Symtab.lookup (#constants (Definitions.get (Context.Proof ctxt))) name
|> the
|> snd
|> snd
fun add_instantiation (class, tyco) term cert =
let
fun upd phi =
map_definitions
(fn tab =>
if Symreltab.defined tab (class, tyco) then
error ("Duplicate instantiation " ^ quote tyco ^ " :: " ^ quote class)
else
tab
|> Symreltab.update ((class, tyco), (Morphism.term phi term, Morphism.thm phi cert))) I
in
Local_Theory.declaration {pervasive = false, syntax = false} upd
end
fun add_constant name name' (cert, info) lthy =
let
val qname = Local_Theory.full_name lthy (Binding.name name')
fun upd phi =
map_definitions I
(fn tab =>
if Symtab.defined tab name then
error ("Duplicate constant " ^ quote name)
else
tab
|> Symtab.update (name,
(qname, (Option.map (Morphism.thm phi) cert, morph_const_info phi info))))
in
Local_Theory.declaration {pervasive = false, syntax = false} upd lthy
|> Local_Theory.note ((Binding.empty, @{attributes [dict_construction_specs]}), #simps info)
|> snd
end
fun axclass class =
ensure_class class
#>> node_of
datatype const =
Fun of
{dicts: ((string * class) * typ) list,
certs: term list,
param_typs: typ list,
typ: typ,
new_typ: typ,
eqs: {params: term list, rhs: term, thm: thm} list,
info: Function_Common.info option,
cong: thm option} |
Constructor |
Classparam of
{class: class,
typ: typ,
selector: term }
type 'a sccs = (string * 'a) list list
fun symtab_of_sccs x = Symtab.make (flat x)
fun raw_dict_params tparams lthy =
let
fun mk_dict tparam class lthy =
let
val (node, lthy') = axclass class lthy
val targ = TFree (tparam, @{sort type})
val typ = dict_typ node targ
val cert = #cert node targ
in ((((tparam, class), typ), cert), lthy') end
fun mk_dicts (tparam, sort) = fold_map
(mk_dict tparam)
(filter (Class.is_class (Proof_Context.theory_of lthy)) sort)
in fold_map mk_dicts tparams lthy |>> flat end
fun dict_params context dicts =
let
fun dict_param ((_, class), typ) =
Name.variant (mangle class) #>> rpair typ #>> Free
in
fold_map dict_param dicts context
end
fun get_sel class param typ lthy =
let
val ({selectors, ...}, lthy') = axclass class lthy
in
case Symtab.lookup selectors param of
NONE => error ("unknown class parameter " ^ param)
| SOME sel => (sel typ, lthy')
end
fun annotate_const name ((tparams, typ), raw_eqs) lthy =
if Code.is_constr (Proof_Context.theory_of lthy) name then
((name, Constructor), lthy)
else if null raw_eqs then
let
val (_, class) = the_single tparams ||> the_single
val (selector, thy') = get_sel class name (TVar (("'a", 0), @{sort type})) lthy
val typ = range_type (fastype_of selector)
in
((name, Classparam {class = class, typ = typ, selector = selector}), thy')
end
else
let
val info = try (Function.get_info lthy) (Const (name, typ))
val cong = cong_of_const lthy name
val ((raw_dicts, certs), lthy') = raw_dict_params tparams lthy |>> split_list
val dict_typs = map snd raw_dicts
val typ' = typify_typ typ
fun mk_eq ((raw_params, rhs), SOME thm) =
let
val norm = normalizer_conv lthy'
val transform = Thm.cterm_of lthy' #> conv_result norm #> typify
val params = map transform raw_params
in
if has_duplicates (op =) (flat (map all_frees' params)) then
(warning "ignoring code equation with non-linear pattern"; NONE)
else
SOME {params = params, rhs = rhs, thm = Conv.fconv_rule norm thm}
end
| mk_eq _ =
error "no theorem"
val const =
Fun
{dicts = raw_dicts, certs = certs, typ = typ', param_typs = binder_types typ',
new_typ = dict_typs ---> typ', eqs = map_filter mk_eq raw_eqs, info = info, cong = cong}
in ((name, const), lthy') end
fun annotate_code_eqs lthy consts =
fold_map (fold_map (uncurry annotate_const)) (group_code_eqs lthy consts) lthy
fun mk_path [] _ _ lthy = (NONE, lthy)
| mk_path ((class, term) :: rest) typ goal lthy =
let
val (ev, lthy') = ensure_class class lthy
in
case find_path ev goal of
SOME path => (SOME (fold_path path typ term), lthy')
| NONE => mk_path rest typ goal lthy'
end
fun instance consts tyco class lthy =
case Symreltab.lookup (#instantiations (Definitions.get (Context.Proof lthy))) (class, tyco) of
SOME (inst, _) =>
(inst, lthy)
| NONE =>
let
val thy = Proof_Context.theory_of lthy
val tparam_sorts = param_sorts tyco class thy
fun print_info ctxt =
let
val tvars =
Name.invent_list [] Name.aT (length tparam_sorts) ~~ tparam_sorts
|> map TFree
in
[Pretty.str "Defining instance ", Syntax.pretty_typ ctxt (Type (tyco, tvars)),
Pretty.str " :: ", Syntax.pretty_sort ctxt [class]]
|> Pretty.block
|> Pretty.writeln
end
val ({make, ...}, lthy) = axclass class lthy
val name = mangle class ^ "__instance__" ^ mangle tyco
val tparams = Name.invent_names Name.context Name.aT tparam_sorts
val ((dict_params, _), lthy) = raw_dict_params tparams lthy
|>> map fst
|>> dict_params (Name.make_context [name])
val dict_context = Symreltab.make (flat_right tparams ~~ dict_params)
val {params, ...} = Axclass.get_info thy class
val (super_fields, lthy) = fold_map
(obtain_dict dict_context consts (Type (tyco, map TFree tparams)))
(super_classes class thy)
lthy
val tparams' = map (TFree o rpair @{sort type} o fst) tparams
val typ_inst = (TFree ("'a", [class]), Type (tyco, tparams'))
fun mk_field (field, typ) =
let
val param = Axclass.param_of_inst thy (field, tyco)
val _ = case Symtab.lookup (#constants (Definitions.get (Context.Proof lthy))) param of
NONE =>
if Code.is_constr thy param then
()
else
error ("cyclic dependency: " ^ param ^ " not yet defined in the definition of " ^ tyco ^ " :: " ^ class)
| SOME _ => ()
in
term dict_context consts (Const (param, typ_subst_atomic [typ_inst] typ))
end
val (fields, lthy) = fold_map mk_field params lthy
val rhs = list_comb (make (Type (tyco, tparams')), super_fields @ fields)
val typ = map fastype_of dict_params ---> fastype_of rhs
val head = Free (name, typ)
val lhs = list_comb (head, dict_params)
val term = Logic.mk_equals (lhs, rhs)
val (def, (lthy', lthy)) = lthy
|> tap print_info
|> (snd o Local_Theory.begin_nested)
|> define_params_nosyn term
||> `Local_Theory.end_nested
val phi = Proof_Context.export_morphism lthy lthy'
val def = Morphism.thm phi def
val base_thms =
Definitions.get (Context.Proof lthy') |> #constants |> Symtab.dest
|> map (apsnd fst o snd)
|> map_filter snd
val target = (flat_right tparams, (Morphism.term phi head, tyco, class))
val args = {base_thms = base_thms, def_thm = def}
val thm = prove_dict_cert target args lthy'
val const = Const (Local_Theory.full_name lthy' (Binding.name name), typ)
in
(const, add_instantiation (class, tyco) const thm lthy')
end
and obtain_dict dict_context consts =
let
val dict_context' = Symreltab.dest dict_context
fun for_class (Type (tyco, args)) class lthy =
let
val inst_param_sorts = param_sorts tyco class (Proof_Context.theory_of lthy)
val (raw_inst, lthy') = instance consts tyco class lthy
val (const_name, _) = dest_Const raw_inst
val (inst_args, lthy'') = fold_map for_sort (inst_param_sorts ~~ args) lthy'
val head = Sign.mk_const (Proof_Context.theory_of lthy'') (const_name, args)
in
(list_comb (head, flat inst_args), lthy'')
end
| for_class (TFree (name, _)) class lthy =
let
val available = map_filter
(fn ((tp, class), term) => if tp = name then SOME (class, term) else NONE)
dict_context'
val (path, lthy') = mk_path available (TFree (name, @{sort type})) class lthy
in
case path of
SOME term => (term, lthy')
| NONE => error "no path found"
end
| for_class (TVar _) _ _ = error "unexpected type variable"
and for_sort (sort, typ) =
fold_map (for_class typ) sort
in for_class end
and term dict_context consts term lthy =
let
fun traverse (t as Const (name, typ)) lthy =
(case Symtab.lookup consts name of
NONE => error ("unknown constant " ^ name)
| SOME (_, Constructor) =>
(typify t, lthy)
| SOME (_, Classparam {class, typ = typ', selector}) =>
let
val subst = Sign.typ_match (Proof_Context.theory_of lthy) (typ', typ) Vartab.empty
val (_, targ) = the (Vartab.lookup subst ("'a", 0))
val (dict, lthy') = obtain_dict dict_context consts targ class lthy
in
(subst_TVars [(("'a", 0), targ)] selector $ dict, lthy')
end
| SOME (name', Fun {dicts = dicts, typ = typ', new_typ, ...}) =>
let
val subst = Type.raw_match (Logic.varifyT_global typ', typ) Vartab.empty
|> Vartab.dest |> map (apsnd snd)
fun lookup tparam = the (AList.lookup (op =) subst (tparam, 0))
val (dicts, lthy') =
fold_map (uncurry (obtain_dict dict_context consts o lookup)) (map fst dicts) lthy
val typ = typ_subst_TVars subst (Logic.varifyT_global new_typ)
val head =
case Symtab.lookup (#constants (Definitions.get (Context.Proof lthy))) name of
NONE => Free (name', typ)
| SOME (n, _) => Const (n, typ)
val res = list_comb (head, dicts)
in
(res, lthy')
end)
| traverse (f $ x) lthy =
let
val (f', lthy') = traverse f lthy
val (x', lthy'') = traverse x lthy'
in (f' $ x', lthy'') end
| traverse (Abs (name, typ, term)) lthy =
let
val (term', lthy') = traverse term lthy
in (Abs (name, typify_typ typ, term'), lthy') end
| traverse (Free (name, typ)) lthy = (Free (name, typify_typ typ), lthy)
| traverse (Var (name, typ)) lthy = (Var (name, typify_typ typ), lthy)
| traverse (Bound n) lthy = (Bound n, lthy)
in
traverse term lthy
end
fun new_names lthy consts =
let
val (all_names, all_consts) = split_list (flat consts)
val all_frees = map (fn Fun {eqs, ...} => eqs | _ => []) all_consts |> flat
|> map #params |> flat
|> map all_frees' |> flat
val context = fold Name.declare (all_names @ all_frees) (Variable.names_of lthy)
fun new_name (name, const) context =
let val (name', context') = Name.variant (mangle name) context in
((name, (name', const)), context')
end
in
fst (fold_map (fold_map new_name) consts context)
end
fun consts consts proof group lthy =
let
val fun_config = Function_Common.FunctionConfig
{sequential=true, default=NONE, domintros=false, partials=false}
fun pat_completeness_auto ctxt =
Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
val all_names = map fst group
val pretty_consts = map (pretty_const lthy) all_names |> Pretty.commas
fun print_info msg =
Pretty.str (msg ^ " ") :: pretty_consts
|> Pretty.block
|> Pretty.writeln
val _ = print_info "Redefining constant(s)"
fun process_eqs (name, Fun {dicts, param_typs, new_typ, eqs, info, cong, ...}) lthy =
let
val new_name = case Symtab.lookup consts name of
NONE => error ("no new name for " ^ name)
| SOME (n, _) => n
val all_frees = map #params eqs |> flat |> map all_frees' |> flat
val context = Name.make_context (all_names @ all_frees)
val (dict_params, context') = dict_params context dicts
fun adapt_params param_typs params =
let
val real_params = dict_params @ params
val ext_params = drop (length params) param_typs
|> map typify_typ
|> Name.invent_names context' "e0" |> map Free
in (real_params, ext_params) end
fun mk_eq {params, rhs, thm} lthy =
let
val (real_params, ext_params) = adapt_params param_typs params
val lhs' = list_comb (Free (new_name, new_typ), real_params @ ext_params)
val (rhs', lthy') = term (Symreltab.make (map fst dicts ~~ dict_params)) consts rhs lthy
val rhs'' = list_comb (rhs', ext_params)
in
((HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'')), thm), lthy')
end
val is_fun = length param_typs + length dicts > 0
in
fold_map mk_eq eqs lthy
|>> rpair (new_typ, is_fun)
|>> SOME
|>> pair ((name, new_name, map fst dicts), {info = info, cong = cong})
end
| process_eqs (name, _) lthy =
((((name, name, []), {info = NONE, cong = NONE}), NONE), lthy)
val (items, lthy') = fold_map process_eqs group lthy
val ((metas, infos), ((eqs, code_thms), (new_typs, is_funs))) = items
|> map_filter (fn (meta, eqs) => Option.map (pair meta) eqs)
|> split_list
||> split_list ||> apfst (flat #> split_list #>> map typify)
||> apsnd split_list
|>> split_list
val _ = if_debug lthy (fn () =>
if null code_thms then ()
else
map (Syntax.pretty_term lthy o Thm.prop_of) code_thms
|> Pretty.big_list "Equations:"
|> Pretty.string_of
|> tracing)
val is_fun =
case distinct (op =) is_funs of
[b] => b
| [] => false
| _ => error "unsupported feature: mixed non-function and function definitions"
fun mk_binding (_, new_name, _) typ =
(Binding.name new_name, SOME typ, NoSyn)
val bindings = map2 mk_binding metas new_typs
val {constants, instantiations} = Definitions.get (Context.Proof lthy')
val base_thms = Symtab.dest constants |> map (apsnd fst o snd) |> map_filter snd
val base_certs = Symreltab.dest instantiations |> map (snd o snd)
val consts = Sign.consts_of (Proof_Context.theory_of lthy')
fun prove_eq_fun (info as {simps = SOME simps, fs, inducts = SOME inducts, ...}) lthy =
let
fun mk_target (name, _, classes) new =
(classes, (new, Const (Consts.the_const consts name)))
val targets = map2 mk_target metas fs
val args =
{fun_info = SOME info, inducts = SOME inducts, simps = simps, base_thms = base_thms,
base_certs = base_certs, code_thms = code_thms, congs = NONE}
in
(prove_fun_cert targets args proof lthy, args)
end
fun prove_eq_def defs lthy =
let
fun mk_target (name, _, classes) new =
(classes, (new, Const (Consts.the_const consts name)))
val targets = map2 mk_target metas (map (fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) defs)
val args =
{fun_info = NONE, inducts = NONE, simps = defs,
base_thms = base_thms, base_certs = base_certs, code_thms = code_thms, congs = NONE}
in
(prove_fun_cert targets args proof lthy, args)
end
fun add_constants ((((name, name', _), _), SOME _) :: xs) ((thm :: thms), info) =
add_constant name name' (SOME thm, info) #> add_constants xs (thms, info)
| add_constants ((((name, name', _), _), NONE) :: xs) (thms, info) =
add_constant name name' (NONE, info) #> add_constants xs (thms, info)
| add_constants [] _ =
I
fun prove_termination new_info ctxt =
let
val termination_ctxt =
ctxt addsimps (@{thms equal} @ base_thms)
addloop ("overload", CONVERSION o changed_conv o Axclass.overload_conv)
val fallback_tac =
Function_Common.termination_prover_tac true termination_ctxt
val tac = case try hd (cat_options (map #info infos)) of
SOME old_info => HEADGOAL (Transfer_Termination.termination_tac new_info old_info ctxt)
| NONE => no_tac
in Function.prove_termination NONE (tac ORELSE fallback_tac) ctxt end
fun prove_cong data lthy =
let
fun rewr_cong thm cong =
if Thm.nprems_of thm > 0 then
(warning "No fundef_cong rule can be derived; this will likely not work later"; NONE)
else
(print_info "Porting fundef_cong rule for ";
SOME (Local_Defs.fold lthy [thm] cong))
val congs' =
map2 (Option.mapPartial o rewr_cong) (fst data) (map #cong infos)
|> cat_options
fun add_congs phi =
fold Function_Context_Tree.add_function_cong (map (Morphism.thm phi) congs')
val data' =
apsnd (map_const_info I I I I I I (K (SOME congs'))) data
in
(data', Local_Theory.declaration {pervasive = false, syntax = false} add_congs lthy)
end
fun mk_fun lthy =
let
val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) eqs
val (info, lthy') =
Function.add_function bindings specs fun_config pat_completeness_auto lthy
|-> prove_termination
val simps = the (#simps info)
val (_, lthy'') =
Local_Theory.note ((Binding.empty, @{attributes [simp del]}), simps) lthy'
fun prove_eq phi =
prove_eq_fun (Function_Common.transform_function_data phi info)
in
(((simps, #inducts info), prove_eq), lthy'')
end
fun mk_def lthy =
let
val (defs, lthy') = fold_map define_params_nosyn eqs lthy
fun prove_eq phi = prove_eq_def (map (Morphism.thm phi) defs)
in
(((defs, NONE), prove_eq), lthy')
end
in
if null eqs then
lthy'
else
let
val ((side, prove_eq), (lthy', lthy)) = lthy'
|> (snd o Local_Theory.begin_nested)
|> (if is_fun then mk_fun else mk_def)
|-> (fn ((simps, inducts), prove_eq) =>
apfst (rpair prove_eq) o Side_Conditions.mk_side simps inducts)
||> `Local_Theory.end_nested
val phi = Proof_Context.export_morphism lthy lthy'
in
lthy'
|> `(prove_eq phi)
|>> apfst (on_thms_complete (fn () => print_info "Proved equivalence for"))
|-> prove_cong
|-> add_constants items
end
end
fun const_raw (binding, raw_consts) proof lthy =
let
val _ =
if proof = Skip then
warning "Skipping certificate proofs"
else ()
val (name, _) = Syntax.read_terms lthy raw_consts |> map dest_Const |> split_list
val (eqs, lthy) = annotate_code_eqs lthy name
val tab = symtab_of_sccs (new_names lthy eqs)
val lthy = fold (consts tab proof) eqs lthy
val {instantiations, constants} = Definitions.get (Context.Proof lthy)
val thms =
map (snd o snd) (Symreltab.dest instantiations) @
map_filter (fst o snd o snd) (Symtab.dest constants)
in
snd (Local_Theory.note (binding, thms) lthy)
end
val parse_flags =
Scan.optional (Args.parens (Parse.reserved "skip" >> K Skip)) Cert
val _ =
Outer_Syntax.local_theory
@{command_keyword "declassify"}
"redefines a constant after applying the dictionary construction"
(parse_flags -- Parse_Spec.opt_thm_name ":" -- Scan.repeat1 Parse.const >>
(fn ((flags, def_binding), consts) => const_raw (def_binding, consts) flags))
end
Theory Termination
section ‹Termination heuristics›
text_raw ‹\label{sec:termination}›
theory Termination
imports "../Dict_Construction"
begin
text ‹
As indicated in the introduction, the newly-defined functions must be proven terminating. In
general, we cannot reuse the original termination proof, as the following example illustrates:
›
fun f :: "nat ⇒ nat" where
"f 0 = 0" |
"f (Suc n) = f n"
lemma [code]: "f x = f x" ..
text ‹
The invocation of @{theory_text ‹declassify f›} would fail, because @{const f}'s code equations
are not terminating.
Hence, in the general case where users have modified the code equations, we need to fall back
to an (automated) attempt to prove termination.
In the remainder of this section, we will illustrate the special case where the user has not
modified the code equations, i.e., the original termination proof should ``morally'' be still
applicable. For this, we will perform the dictionary construction manually.
›
local_setup ‹Class_Graph.ensure_class @{class plus} #> snd›
local_setup ‹Class_Graph.ensure_class @{class zero} #> snd›
fun sum_list :: "'a::{plus,zero} list ⇒ 'a" where
"sum_list [] = 0" |
"sum_list (x # xs) = x + sum_list xs"
text ‹
The above function carries two distinct class constraints, which are translated into two
dictionary parameters:
›
function sum_list' where
"sum_list' d_plus d_zero [] = Groups_zero__class_zero__field d_zero" |
"sum_list' d_plus d_zero (x # xs) = Groups_plus__class_plus__field d_plus x (sum_list' d_plus d_zero xs)"
by pat_completeness auto
text ‹
Now, we need to carry out the termination proof of @{const sum_list'}. The @{theory_text function}
package analyzes the function definition and discovers one recursive call. In pseudo-notation:
@{text [display] ‹(d_plus, d_zero, x # xs) ↝ (d_plus, d_zero, xs)›}
The result of this analysis is captured in the inductive predicate @{const sum_list'_rel}. Its
introduction rules look as follows:
›
thm sum_list'_rel.intros
text ‹Compare this to the relation for @{const sum_list}:›
thm sum_list_rel.intros
text ‹
Except for the additional (unchanging) dictionary arguments, these relations are more or less
equivalent to each other. There is an important difference, though: @{const sum_list_rel} has
sort constraints, @{const sum_list'_rel} does not. (This will become important later on.)
›
context
notes [[show_sorts]]
begin
term sum_list_rel
term sum_list'_rel
end
text ‹
Let us know discuss the rough concept of the termination proof for @{const sum_list'}. The goal is
to show that @{const sum_list'_rel} is well-founded. Usually, this is proved by specifying a
∗‹measure function› that
▸ maps the arguments to natural numbers
▸ decreases for each recursive call.
›
text ‹
Here, however, we want to instead show that each recursive call in @{const sum_list'} has a
corresponding recursive call in @{const sum_list}. In other words, we want to show that the
existing proof of well-foundedness of @{const sum_list_rel} can be lifted to a proof of
well-foundedness of @{const sum_list'_rel}. This is what the theorem
@{thm [source=true] wfP_simulate_simple} states:
@{thm [display=true] wfP_simulate_simple}
Given any well-founded relation ‹r› and a function ‹g› that maps function arguments from ‹r'› to
‹r›, we can deduce that ‹r'› is also well-founded.
For our example, we need to provide a function ‹g› of type
@{typ ‹'b Groups_plus__dict × 'b Groups_zero__dict × 'b list ⇒ 'a list›}.
Because the dictionary parameters are not changing, they can safely be dropped by ‹g›.
However, because of the sort constraint in @{const sum_list_rel}, the term @{term "snd ∘ snd"}
is not a well-typed instantiation for ‹g›.
Instead (this is where the heuristic comes in), we assume that the original function
@{const sum_list} is parametric, i.e., termination does not depend on the elements of the list
passed to it, but only on the structure of the list. Additionally, we assume that all involved
type classes have at least one instantiation.
With this in mind, we can use @{term "map (λ_. undefined) ∘ snd ∘ snd"} as ‹g›:
›
thm wfP_simulate_simple[where
r = sum_list_rel and
r' = sum_list'_rel and
g = "map (λ_. undefined) ∘ snd ∘ snd"]
text ‹
Finally, we can prove the termination of @{const sum_list'}.
›
termination sum_list'
proof -
have "wfP sum_list'_rel"
proof (rule wfP_simulate_simple)
show "wfP sum_list_rel"
apply (rule accp_wfPI)
apply (tactic ‹resolve_tac @{context} [Function.get_info @{context} @{term sum_list} |> #totality |> the] 1›)
done
define g :: "'b Groups_plus__dict × 'b Groups_zero__dict × 'b list ⇒ 'c::{plus,zero} list" where
"g = map (λ_. undefined) ∘ snd ∘ snd"
show "sum_list_rel (g x) (g y)" if "sum_list'_rel x y" for x y
using that
proof (induction x y rule: sum_list'_rel.induct)
case (1 d_plus d_zero x xs)
show ?case
apply (simp only: g_def comp_apply snd_conv list.map)
apply (rule sum_list_rel.intros(1))
done
qed
qed
then show "∀x. sum_list'_dom x"
by (rule wfP_implies_dom)
qed
text ‹This can be automated with a special tactic:›
experiment
begin
termination sum_list'
apply (tactic ‹
Transfer_Termination.termination_tac
(Function.get_info @{context} @{term sum_list'})
(Function.get_info @{context} @{term sum_list})
@{context}
1›; fail)
done
end
text ‹
A similar technique can be used for making functions defined in locales executable when, for some
reason, the definition of a ``defs'' locale is not feasible.
›
locale foo =
fixes A :: "nat"
assumes "A > 0"
begin
fun f where
"f 0 = A" |
"f (Suc n) = Suc (f n)"
lemma f_total: "wfP f_rel"
apply (rule accp_wfPI)
apply (tactic ‹resolve_tac @{context} [Function.get_info @{context} @{term f} |> #totality |> the] 1›)
done
end
interpretation dummy: foo 1 by standard simp
function f' where
"f' A 0 = A" |
"f' A (Suc n) = Suc (f' A n)"
by pat_completeness auto
termination f'
apply (rule wfP_implies_dom)
apply (rule wfP_simulate_simple[where g = "snd"])
apply (rule dummy.f_total)
subgoal for x y
apply (induction x y rule: f'_rel.induct)
subgoal
apply (simp only: snd_conv)
apply (rule dummy.f_rel.intros)
done
done
done
text ‹Automatic:›
experiment
begin
termination f'
apply (tactic ‹
Transfer_Termination.termination_tac
(Function.get_info @{context} @{term f'})
(Function.get_info @{context} @{term dummy.f})
@{context}
1›; fail)
done
end
end
Theory Test_Dict_Construction
section ‹Test cases for dictionary construction›
theory Test_Dict_Construction
imports
Dict_Construction
"HOL-Library.ListVector"
begin
subsection ‹Code equations with different number of explicit arguments›
lemma [code]: "fold f [] = id" "fold f (x # xs) s = fold f xs (f x s)" "fold f [x, y] u ≡ f y (f x u)"
by auto
experiment begin
declassify valid: fold
thm valid
lemma "List_fold = fold" by (rule valid)
end
subsection ‹Complex class hierarchies›
local_setup ‹Class_Graph.ensure_class @{class zero} #> snd›
local_setup ‹Class_Graph.ensure_class @{class plus} #> snd›
experiment begin
local_setup ‹Class_Graph.ensure_class @{class comm_monoid_add} #> snd›
local_setup ‹Class_Graph.ensure_class @{class ring} #> snd›
typ "nat Rings_ring__dict"
end
text ‹Check that ‹Class_Graph› does not leak out of locales›
ML‹@{assert} (is_none (Class_Graph.node @{context} @{class ring}))›
subsection ‹Instances with non-trivial arity›
fun f :: "'a::plus ⇒ 'a" where
"f x = x + x"
definition g :: "'a::{plus,zero} list ⇒ 'a list" where
"g x = f x"
datatype natt = Z | S natt
instantiation natt :: "{zero,plus}" begin
definition zero_natt where
"zero_natt = Z"
fun plus_natt where
"plus_natt Z x = x" |
"plus_natt (S m) n = S (plus_natt m n)"
instance ..
end
definition h :: "natt list" where
"h = g [Z,S Z]"
experiment begin
declassify valid: h
thm valid
lemma "Test__Dict__Construction_h = h" by (fact valid)
ML‹Dict_Construction.the_info @{context} @{const_name plus_natt_inst.plus_natt}›
end
text ‹Check that @{command declassify} does not leak out of locales›
ML‹
can (Dict_Construction.the_info @{context}) @{const_name plus_natt_inst.plus_natt}
|> not |> @{assert}
›
subsection ‹[@{attribute fundef_cong}] rules›
datatype 'a seq = Cons 'a "'a seq" | Nil
experiment begin
declassify map_seq
text ‹Check presence of derived [@{attribute fundef_cong}] rule›
ML‹
Dict_Construction.the_info @{context} @{const_name map_seq}
|> #fun_info
|> the
|> #fs
|> the_single
|> dest_Const
|> fst
|> Dict_Construction.cong_of_const @{context}
|> the
›
end
subsection ‹Mutual recursion›
fun odd :: "nat ⇒ bool" and even where
"odd 0 ⟷ False" |
"even 0 ⟷ True" |
"odd (Suc n) ⟷ even n" |
"even (Suc n) ⟷ odd n"
experiment begin
declassify valid: odd even
thm valid
end
datatype 'a bin_tree = Leaf | Node 'a "'a bin_tree" "'a bin_tree"
experiment begin
declassify valid: map_bin_tree rel_bin_tree
thm valid
end
datatype 'v env = Env "'v list"
datatype v = Closure "v env"
context
notes is_measure_trivial[where f = "size_env size", measure_function]
begin
fun test_v :: "v ⇒ bool" and test_w :: "v env ⇒ bool" where
"test_v (Closure env) ⟷ test_w env" |
"test_w (Env vs) ⟷ list_all test_v vs"
fun test_v1 :: "v ⇒ 'a::{one,monoid_add}" and test_w1 :: "v env ⇒ 'a" where
"test_v1 (Closure env) = 1 + test_w1 env" |
"test_w1 (Env vs) = sum_list (map test_v1 vs)"
end
experiment begin
declassify valid: test_w test_v
thm valid
end
experiment begin
declassify valid: test_w1 test_v1
thm valid
end
subsection ‹Non-trivial code dependencies; code equations where the head is not fully general›
definition "c ≡ 0 :: nat"
definition "d x ≡ if x = 0 then 0 else x"
lemma contrived[code]: "c = d 0" unfolding c_def d_def by simp
experiment begin
declassify valid: c
thm valid
lemma "Test__Dict__Construction_c = c" by (fact valid)
end
subsection ‹Pattern matching on @{term "0::nat"}›
definition j where "j (n::nat) = (0::nat)"
lemma [code]: "j 0 = 0" "j (Suc n) = j n"
unfolding j_def by auto
fun k where
"k 0 = (0::nat)" |
"k (Suc n) = k n"
lemma f_code[code]: "k n = 0"
by (induct n) simp+
experiment begin
declassify valid: j k
thm valid
lemma
"Test__Dict__Construction_j = j"
"Test__Dict__Construction_k = k"
by (fact valid)+
end
subsection ‹Complex termination arguments›
fun fac :: "nat ⇒ nat" where
"fac n = (if n ≤ 1 then 1 else n * fac (n - 1))"
experiment begin
declassify valid: fac
end
subsection ‹Combination of various things›
experiment begin
declassify valid: sum_list
end
subsection ‹Interaction with the code generator›
declassify h
export_code Test__Dict__Construction_h in SML
end
Theory Test_Side_Conditions
subsection ‹Contrived side conditions›
theory Test_Side_Conditions
imports Dict_Construction
begin
ML ‹
fun assert_alt_total ctxt term = @{assert} (Side_Conditions.is_total ctxt term)
›
fun head where
"head (x # _) = x"
local_setup ‹snd o Side_Conditions.mk_side @{thms head.simps} NONE›
lemma head_side_eq: "head_side xs ⟷ xs ≠ []"
by (cases xs) (auto intro: head_side.intros elim: head_side.cases)
declaration ‹K (Side_Conditions.set_alt @{term head} @{thm head_side_eq})›
fun map where
"map f [] = []" |
"map f (x # xs) = f x # map f xs"
local_setup ‹snd o Side_Conditions.mk_side @{thms map.simps} (SOME @{thms map.induct})›
thm map_side.intros
ML ‹assert_alt_total @{context} @{term map}›
experiment begin
text ‹Functions that use partial functions always in their domain are processed correctly.›
fun tail where
"tail (_ # xs) = xs"
local_setup ‹snd o Side_Conditions.mk_side @{thms tail.simps} NONE›
lemma tail_side_eq: "tail_side xs ⟷ xs ≠ []"
by (cases xs) (auto intro: tail_side.intros elim: tail_side.cases)
declaration ‹K (Side_Conditions.set_alt @{term tail} @{thm tail_side_eq})›
function map' where
"map' f xs = (if xs = [] then [] else f (head xs) # map' f (tail xs))"
by auto
termination
apply (relation "measure (size ∘ snd)")
apply rule
subgoal for f xs by (cases xs) auto
done
local_setup ‹snd o Side_Conditions.mk_side @{thms map'.simps} (SOME @{thms map'.induct})›
thm map'_side.intros
ML ‹assert_alt_total @{context} @{term map'}›
end
lemma map_cong:
assumes "xs = ys" "⋀x. x ∈ set ys ⟹ f x = g x"
shows "map f xs = map g ys"
unfolding assms(1)
using assms(2)
by (induction ys) auto
definition map_head where
"map_head xs = map head xs"
experiment begin
declare map_cong[fundef_cong]
local_setup ‹snd o Side_Conditions.mk_side @{thms map_head_def} NONE›
thm map_head_side.intros
lemma "map_head_side xs ⟷ (∀x ∈ set xs. x ≠ [])"
by (auto intro: map_head_side.intros elim: map_head_side.cases)
definition map_head' where
"map_head' xss = map (map head) xss"
local_setup ‹snd o Side_Conditions.mk_side @{thms map_head'_def} NONE›
thm map_head'_side.intros
lemma "map_head'_side xss ⟷ (∀xs ∈ set xss. ∀x ∈ set xs. x ≠ [])"
by (auto intro: map_head'_side.intros elim: map_head'_side.cases)
end
experiment begin
local_setup ‹snd o Side_Conditions.mk_side @{thms map_head_def} NONE›
term map_head_side
thm map_head_side.intros
lemma "¬ map_head_side xs"
by (auto elim: map_head_side.cases)
end
definition head_known where
"head_known xs = head (3 # xs)"
local_setup ‹snd o Side_Conditions.mk_side @{thms head_known_def} NONE›
thm head_known_side.intros
ML‹assert_alt_total @{context} @{term head_known}›
fun odd :: "nat ⇒ bool" and even where
"odd 0 ⟷ False" |
"even 0 ⟷ True" |
"odd (Suc n) ⟷ even n" |
"even (Suc n) ⟷ odd n"
local_setup ‹snd o Side_Conditions.mk_side @{thms odd.simps even.simps} (SOME @{thms odd_even.induct})›
thm odd_side_even_side.intros
ML‹assert_alt_total @{context} @{term odd}›
ML‹assert_alt_total @{context} @{term even}›
definition odd_known where
"odd_known = odd (Suc 0)"
local_setup ‹snd o Side_Conditions.mk_side @{thms odd_known_def} NONE›
thm odd_known_side.intros
ML‹assert_alt_total @{context} @{term odd_known}›
end
Theory Test_Lazy_Case
subsection ‹Interaction with ‹Lazy_Case››
theory Test_Lazy_Case
imports
Dict_Construction
Lazy_Case.Lazy_Case
Show.Show_Instances
begin
datatype 'a tree = Node | Fork 'a "'a tree list"
lemma map_tree[code]:
"map_tree f t = (case t of Node ⇒ Node | Fork x ts ⇒ Fork (f x) (map (map_tree f) ts))"
by (induction t) auto
experiment begin
text ‹
Dictionary construction of @{const map_tree} requires the [@{attribute fundef_cong}] rule of
@{const Test_Lazy_Case.tree.case_lazy}.
›
declassify valid: map_tree
thm valid
lemma "Test__Lazy__Case_tree_map__tree = map_tree" by (fact valid)
end
definition i :: "(unit × (bool list × string × nat option) list) option ⇒ string" where
"i = show"
experiment begin
text ‹This currently requires @{theory Lazy_Case.Lazy_Case} because of @{const divmod_nat}.›
declassify valid: i
thm valid
lemma "Test__Lazy__Case_i = i" by (fact valid)
end
end