# Theory Tagged_Prod_Sum

chapter ‹Tagged Sum-of-Products Representation›
text ‹
This theory sets up a version of the sum-of-products representation that includes constructor and
selector names. For an example of a type class that uses this representation, see Derive\_Show.
›

theory Tagged_Prod_Sum
imports Main
begin

(* Tagged variants of product and sum: besides the payload, every constructor carries
   "string option" fields for names (selector names in Prod, a constructor name in Inl/Inr). *)
context begin

qualified datatype ('a, 'b) prod = Prod "string option" "string option" 'a 'b
qualified datatype ('a, 'b) sum = Inl "string option" 'a | Inr "string option" 'b

(* Projections onto the two payload components of a tagged product. *)
qualified definition fst where "fst p = (case p of (Prod _ _ a _) ⇒ a)"
qualified definition snd where "snd p = (case p of (Prod _ _ _ b) ⇒ b)"
(* The name tags stored for the first and the second component. *)
qualified definition sel_name_fst where "sel_name_fst p = (case p of (Prod s _ _ _) ⇒ s)"
qualified definition sel_name_snd where "sel_name_snd p = (case p of (Prod _ s _ _) ⇒ s)"

(* The name tag stored in either injection of a tagged sum. *)
qualified definition constr_name where "constr_name x = (case x of (Inl s _) ⇒ s | (Inr s _) ⇒ s)"

end

(* Declare measures through the projections as [measure_function] candidates,
   presumably so that automatic termination proofs can try them. *)
lemma measure_tagged_fst[measure_function]: "is_measure f ⟹ is_measure (λ p. f (Tagged_Prod_Sum.fst p))"
by (rule is_measure_trivial)

lemma measure_tagged_snd[measure_function]: "is_measure f ⟹ is_measure (λ p. f (Tagged_Prod_Sum.snd p))"
by (rule is_measure_trivial)

(* Size of a tagged product in terms of its payload components.
   fst/snd are ordinary definitions (not simp rules), so they must be unfolded explicitly. *)
lemma size_tagged_prod_simp:
"Tagged_Prod_Sum.prod.size_prod f g p = f (Tagged_Prod_Sum.fst p) + g (Tagged_Prod_Sum.snd p) + Suc 0"
apply (induct p)
by (auto simp add: Tagged_Prod_Sum.fst_def Tagged_Prod_Sum.snd_def)

(* Size of a tagged sum by case distinction on the injection. *)
lemma size_tagged_sum_simp:
"Tagged_Prod_Sum.sum.size_sum f g x = (case x of Tagged_Prod_Sum.Inl _ a ⇒ f a + Suc 0 | Tagged_Prod_Sum.Inr _ b ⇒ g b + Suc 0)"
apply (induct x)
by auto

lemma size_tagged_prod_measure:
"is_measure f ⟹ is_measure g ⟹ is_measure (Tagged_Prod_Sum.prod.size_prod f g)"
by (rule is_measure_trivial)

lemma size_tagged_sum_measure:
"is_measure f ⟹ is_measure g ⟹ is_measure (Tagged_Prod_Sum.sum.size_sum f g)"
by (rule is_measure_trivial)

end

# Theory Derive

chapter ‹Derive›
text ‹
This theory includes the Isabelle/ML code needed for the derivation and exports the two keywords
\texttt{derive\_generic} and \texttt{derive\_generic\_setup}.
›

theory Derive
imports Main Tagged_Prod_Sum
keywords "derive_generic" "derive_generic_setup" :: thy_goal
begin

context begin

(* "iso from to" holds iff from and to are mutually inverse functions. *)
qualified definition iso :: "('a ⇒ 'b) ⇒ ('b ⇒ 'a) ⇒ bool" where
"iso from to = ((∀ a. to (from a) = a) ∧ (∀ b . from (to b) = b))"

(* Introduction rule: establish both inverse laws pointwise. *)
lemma iso_intro: "(⋀a. to (from a) = a) ⟹ (⋀b. from (to b) = b) ⟹ iso from to"
by (simp add: iso_def)

end

(* Isabelle/ML implementation: shared utilities, law/transfer proofs, the setup command,
   and the main derivation command. *)
ML_file ‹derive_util.ML›
ML_file ‹derive_laws.ML›
ML_file ‹derive_setup.ML›
ML_file ‹derive.ML›

end

# File ‹derive_util.ML›

signature DERIVE_UTIL =
sig
  (* constructor information: a list of (name, [(constructor name, argument types)]) entries;
     NOTE(review): presumably one entry per (mutually recursive) type -- confirm *)
  type ctr_info = (string * (string * typ list) list) list

  (* the product-sum representation type and its conversion functions *)
  type rep_type_info =
    { repname : string
    , rep_type : typ
    , tFrees_mapping : (typ * typ) list
    , from_info : Function_Common.info option
    , to_info : Function_Common.info option }

  (* the Mu-combinator type and its In constructor *)
  type comb_type_info =
    { combname : string
    , combname_full : string
    , comb_type : typ
    , ctr_type : typ
    , inConst : term
    , inConst_free : term
    , inConst_type : typ
    , rep_type_instantiated : typ }

  (* everything recorded about a datatype during derivation *)
  type type_info =
    { tname : string
    , tfrees : (typ * sort) list
    , mutual_tnames : string list
    , mutual_Ts : typ list
    , mutual_ctrs : ctr_info
    , mutual_sels : (string * string list list) list
    , is_rec : bool
    , is_mutually_rec : bool
    , rep_info : rep_type_info
    , comb_info : comb_type_info option
    , iso_thm : thm option }

  (* everything recorded about a type class prepared for derivation *)
  type class_info =
    { classname : string
    , class : sort
    , params : (class * (string * typ)) list option
    , class_law : thm option
    , class_law_const : term option
    , ops : term list option
    , transfer_law : (string * thm list) list option
    , axioms : thm list option
    , axioms_def : thm option
    , class_def : thm option
    , equivalence_thm : thm option }

  (* definitions of the operations of a derived instance *)
  type instance_info =
    { defs : thm list }

  (* terms, types and names *)
  val is_typeT : typ -> bool
  val insert_application : term -> term -> term
  val add_tvars : string -> string list -> string
  val replace_tfree : string list -> string -> string -> string

  (* constructor information *)
  val ctrs_arguments : ctr_info -> typ list
  val collect_tfrees : ctr_info -> (typ * sort) list
  val collect_tfree_names : ctr_info -> string list

  val combs_to_list : term -> term list
  val get_tvar : typ list -> typ
  val not_instantiated : theory -> string -> class -> bool

  (* version of add_fun that doesn't throw away info *)
  val add_fun' : (binding * typ option * mixfix) list ->
    Specification.multi_specs -> Function_Common.function_config ->
    local_theory -> (Function_Common.info * Proof.context)

  (* functional updates of type_info *)
  val add_conversion_info : Function_Common.info -> Function_Common.info -> type_info -> type_info
  val add_iso_info : thm option -> type_info -> type_info

  (* classes *)
  val has_class_law : string -> theory -> bool
  val get_superclasses : sort -> string -> theory -> string list

  val zero_tvarsT : typ -> typ
  val zero_tvars : term -> term
  val tagged_function_termination_tac : Proof.context -> Function.info * local_theory
  val get_mapping_function : Proof.context -> typ -> term
  val is_polymorphic : typ -> bool

  (* determines all mutual recursive types of a given BNF-least-fixpoint-type *)
  val mutual_recursive_types : string -> Proof.context -> string list * typ list
  val freeify_tvars : typ -> typ
  (* delivers a full type from a type name by instantiating the type-variables of that
     type with different variables of a given sort, also returns the chosen variables
     as second component *)
  val typ_and_vs_of_typname : theory -> string -> sort -> typ * (string * sort) list

  val constr_terms : Proof.context -> string -> term list
end

structure Derive_Util : DERIVE_UTIL =
struct

(* The types below must stay identical to the declarations in DERIVE_UTIL. *)

type ctr_info = (string * (string * typ list) list) list

type rep_type_info =
  {repname : string,
   rep_type : typ,
   tFrees_mapping : (typ * typ) list,
   from_info : Function_Common.info option,
   to_info : Function_Common.info option}

type comb_type_info =
  {combname : string,
   combname_full : string,
   comb_type : typ,
   ctr_type : typ,
   inConst : term,
   inConst_free : term,
   inConst_type : typ,
   rep_type_instantiated : typ}

type type_info =
  {tname : string,
   tfrees : (typ * sort) list,
   mutual_tnames : string list,
   mutual_Ts : typ list,
   mutual_ctrs : ctr_info,
   mutual_sels : (string * string list list) list,
   is_rec : bool,
   is_mutually_rec : bool,
   rep_info : rep_type_info,
   comb_info : comb_type_info option,
   iso_thm : thm option}

type class_info =
  {classname : string,
   class : sort,
   params : (class * (string * typ)) list option,
   class_law : thm option,
   class_law_const : term option,
   ops : term list option,
   transfer_law : (string * thm list) list option,
   axioms : thm list option,
   axioms_def : thm option,
   class_def : thm option,
   equivalence_thm : thm option}

type instance_info =
  {defs : thm list}

(* True iff the type is a type-constructor application, false for TFree/TVar. *)
val is_typeT = fn (Type _) => true | _ => false

(* Re-nests the application spine of the first argument to the right and puts t3 in
   innermost position, e.g. insert_application (f $ a $ b) x = f $ (a $ (b $ x)). *)
fun insert_application (t1 $ t2) t3 = insert_application t1 (insert_application t2 t3)
  | insert_application t t3 = t $ t3

(* Renders a type name applied to type-variable names, e.g.
   add_tvars "T" ["'a", "'b"] = "('a, 'b) T"; with no variables the name is returned unchanged. *)
fun add_tvars tname tvar_names =
  (case tvar_names of
    [] => tname
  | xs => "(" ^ space_implode ", " xs ^ ") " ^ tname)

(* Replaces tfree by replacement_name if it occurs in tfree_names. *)
fun replace_tfree tfree_names replacement_name tfree =
  if member (op =) tfree_names tfree then replacement_name else tfree

(* Operations on constructor information *)

(* All constructor argument types, concatenated in the order they appear in ctrs. *)
fun ctrs_arguments (ctrs : ctr_info) = maps (fn (_, cs) => maps snd cs) ctrs

(* The type variables occurring in the constructor arguments, each paired with its sort.
   Order is that of Term.add_tfreesT accumulation, matching collect_tfree_names. *)
fun collect_tfrees ctrs =
  fold Term.add_tfreesT (ctrs_arguments ctrs) []
  |> map (fn (a, S) => (TFree (a, S), S))

(* The names of the type variables occurring in the constructor arguments. *)
fun collect_tfree_names ctrs = fold Term.add_tfree_namesT (ctrs_arguments ctrs) []

(* True iff the theory has no arity declaration for (tname, class). *)
fun not_instantiated thy tname class =
  null (Thm.thynames_of_arity thy (tname, class))

(* Head of an application followed by its arguments: f $ a1 $ ... $ an ~> [f, a1, ..., an]. *)
fun combs_to_list t =
  let val (head, args) = strip_comb t in head :: args end

(* The first TFree met in a depth-first, left-to-right traversal of ts (TVars are skipped);
   'a::type if there is none. *)
fun get_tvar ts =
  (case ts of
    [] => TFree ("'a", \<^sort>‹type›)
  | (T as TFree _) :: _ => T
  | Type (_, Ts) :: rest => get_tvar (Ts @ rest)
  | _ :: rest => get_tvar rest)

(* Defines a function (pattern completeness by pat_completeness + auto), proves termination
   with the default termination prover and returns the resulting info with the new context. *)
fun add_fun' binding specs config lthy =
  let
    fun pat_completeness_auto ctxt =
      Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
    fun prove_termination lthy =
      Function.prove_termination NONE (Function_Common.termination_prover_tac false lthy) lthy
  in
    lthy
    |> Function.add_function binding specs config pat_completeness_auto
    |> snd
    |> prove_termination
  end

(* Records the from/to conversion functions in the rep_info of a type_info. *)
fun add_conversion_info from_info to_info (ty_info : type_info) =
  let
    val {tname, tfrees, mutual_tnames, mutual_Ts, mutual_ctrs, mutual_sels, is_rec,
      is_mutually_rec, rep_info, comb_info, iso_thm} = ty_info
    val {repname, rep_type, tFrees_mapping, ...} = rep_info
  in
    {tname = tname, tfrees = tfrees, mutual_tnames = mutual_tnames, mutual_Ts = mutual_Ts,
     mutual_ctrs = mutual_ctrs, mutual_sels = mutual_sels, is_rec = is_rec,
     is_mutually_rec = is_mutually_rec,
     rep_info =
      {repname = repname, rep_type = rep_type, tFrees_mapping = tFrees_mapping,
       from_info = SOME from_info, to_info = SOME to_info},
     comb_info = comb_info, iso_thm = iso_thm} : type_info
  end

(* Replaces the iso_thm field of a type_info. *)
fun add_iso_info iso_thm (ty_info : type_info) =
  let
    val {tname, tfrees, mutual_tnames, mutual_Ts, mutual_ctrs, mutual_sels, is_rec,
      is_mutually_rec, rep_info, comb_info, ...} = ty_info
  in
    {tname = tname, tfrees = tfrees, mutual_tnames = mutual_tnames, mutual_Ts = mutual_Ts,
     mutual_ctrs = mutual_ctrs, mutual_sels = mutual_sels, is_rec = is_rec,
     is_mutually_rec = is_mutually_rec, rep_info = rep_info, comb_info = comb_info,
     iso_thm = iso_thm} : type_info
  end

(* True iff Class.rules yields a first component (presumably the class predicate
   definition, i.e. the class has axioms). *)
fun has_class_law classname thy =
  let
    val class = Syntax.parse_sort (Proof_Context.init_global thy) classname |> hd
  in
    is_some (Class.rules thy class |> fst)
  end

(* Resets the index of every schematic type variable to 0. *)
fun zero_tvarsT (Type (s, ts)) = Type (s, map zero_tvarsT ts)
  | zero_tvarsT (TVar ((n, _), s)) = TVar ((n, 0), s)
  | zero_tvarsT T = T

fun zero_tvars t = map_types zero_tvarsT t

(* The classes that own parameters of the given sort, excluding classname itself,
   without duplicates (first occurrence kept). *)
fun get_superclasses class classname thy =
  Class.these_params thy class
  |> map (snd #> fst)
  |> filter_out (curry (op =) classname)
  |> distinct (op =)

(* Proves termination using the relation "measure size" and then auto, with
   size_tagged_prod_simp as an additional simp rule. *)
fun tagged_function_termination_tac ctxt =
  let
    val prod_simp_thm = @{thm size_tagged_prod_simp}
    fun measure_tac ctxt =
      Function_Relation.relation_infer_tac ctxt
        (Const (\<^const_name>‹measure›, dummyT) $ Const (\<^const_name>‹size›, dummyT))
    fun prove_termination ctxt = auto_tac (Simplifier.add_simp prod_simp_thm ctxt)
  in
    Function.prove_termination NONE
      (HEADGOAL (measure_tac ctxt) THEN prove_termination ctxt) ctxt
  end

(* The head constant of the first map theorem for T, with its type replaced by dummyT. *)
fun get_mapping_function lthy T =
  let
    val map_thms = BNF_GFP_Rec_Sugar.map_thms_of_type lthy T
  in
    Thm.full_prop_of (hd map_thms) |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst
    |> strip_comb |> fst |> dest_Const |> apsnd (K dummyT) |> Const
  end

fun is_polymorphic T = not (null (Term.add_tfreesT T []))

(* Code copied from generator_aux.ML in AFP entry Deriving by Sternagel and Thiemann *)

(* NOTE(review): the invented names are based on "a" rather than "'a"; type-variable names
   usually start with an apostrophe -- confirm against the original source. *)
fun typ_and_vs_of_typname thy typ_name sort =
  let
    val ar = Sign.arity_number thy typ_name
    val sorts = map (K sort) (1 upto ar)
    val ty_vars = Name.invent_names (Name.make_context [typ_name]) "a" sorts
    val ty = Type (typ_name, map TFree ty_vars)
  in (ty, ty_vars) end

val freeify_tvars = map_type_tvar (TFree o apfst fst)

fun mutual_recursive_types tyco lthy =
  (case BNF_FP_Def_Sugar.fp_sugar_of lthy tyco of
    SOME sugar =>
      if Sign.arity_number (Proof_Context.theory_of lthy) tyco -
        BNF_Def.live_of_bnf (#fp_bnf sugar) > 0
      then error "only datatypes without dead type parameters are supported"
      else if #fp sugar = BNF_Util.Least_FP then
        (* the backtick combinator pairs f x with x: (type names, types) *)
        sugar |> #fp_res |> #Ts |> `(map (fst o dest_Type))
        ||> map freeify_tvars
      else error "only least fixpoints are supported"
  | NONE => error ("type " ^ quote tyco ^ " does not appear to be a new style datatype"))

(* Code copied from bnf_access.ML in AFP entry Deriving by Sternagel and Thiemann *)

fun constr_terms lthy = BNF_FP_Def_Sugar.fp_sugar_of lthy
  #> the #> #fp_ctr_sugar #> #ctr_sugar #> #ctrs

end

(* Per-theory table of type_info records, keyed by a pair of strings;
   on merge, entries are combined with Symreltab.merge (K true). *)
structure Type_Data = Theory_Data
(
  type T = Derive_Util.type_info Symreltab.table
  val empty = Symreltab.empty
  val extend = I
  val merge : T * T -> T = Symreltab.merge (K true)
);

(* Per-theory table of class_info records, keyed by class name;
   on merge, entries are combined with Symtab.merge (K true). *)
structure Class_Data = Theory_Data
(
  type T = Derive_Util.class_info Symtab.table
  val empty = Symtab.empty
  val extend = I
  val merge : T * T -> T = Symtab.merge (K true)
);

(* Per-theory table of instance_info records, keyed by a pair of strings;
   on merge, entries are combined with Symreltab.merge (K true). *)
structure Instance_Data = Theory_Data
(
  type T = Derive_Util.instance_info Symreltab.table
  val empty = Symreltab.empty
  val extend = I
  val merge : T * T -> T = Symreltab.merge (K true)
);

# File ‹derive_laws.ML›

open Derive_Util

signature DERIVE_LAWS =
sig
  (* prove the iso theorem for the conversion functions stored in the type_info *)
  val prove_isomorphism :
    type_info -> Proof.context -> thm option * Proof.context

  (* tactic for the class instance goals of the given type *)
  val prove_instance_tac :
    typ -> class_info -> instance_info -> type_info -> Proof.context -> tactic

  (* prove the class law applied to the operations defined in the instance_info *)
  val prove_equivalence_law :
    class_info -> instance_info -> Proof.context -> thm

  (* start a proof of the class laws of a combinator instantiation *)
  val prove_combinator_instance :
    (thm list list -> local_theory -> Proof.context) -> local_theory -> Proof.state
end

structure Derive_Laws : DERIVE_LAWS =
struct

(* Class information stored in Class_Data for classname, if any. *)
fun get_class_info thy classname = Symtab.lookup (Class_Data.get thy) classname

(* Head of the left-hand side of a theorem "c x1 ... xn = rhs" (under Trueprop). *)
val lhs_head =
  Thm.full_prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst #> strip_comb #> fst

(* Drops the type of a constant so that type inference can re-instantiate it. *)
val dummy_const = dest_Const #> apsnd (K dummyT) #> Const

(* Class introduction, then insert thms into every goal and close them with blast. *)
fun class_tac thms ctxt =
  Class.intro_classes_tac ctxt []
  THEN ALLGOALS (Method.insert_tac ctxt thms)
  THEN ALLGOALS (blast_tac ctxt)

(* Proves "iso from to" for the first from/to conversion functions recorded in type_info
   (iso_intro, induction on both sides, simp), notes it as conversion_iso_<type> and
   returns the theorem exported to the global theory context with the new local theory. *)
fun prove_isomorphism (type_info : type_info) lthy =
  let
    val tname_short = Long_Name.base_name (#tname type_info)
    val from_info = the (#from_info (#rep_info type_info))
    val to_info = the (#to_info (#rep_info type_info))

    val from_f = hd (#fs from_info)
    val to_f = hd (#fs to_info)

    val from_induct = hd (the (#inducts from_info))
    val to_induct = hd (the (#inducts to_info))

    val iso_thm =
      HOLogic.mk_Trueprop (Const (\<^const_name>‹Derive.iso›, dummyT) $ from_f $ to_f)
      |> Syntax.check_term lthy
    (* iso_intro leaves two goals: subgoal 1 quantifies over a, subgoal 2 over b *)
    val induct_tac_to = Induct_Tacs.induct_tac lthy [[SOME "b"]] (SOME [to_induct]) 2
    val induct_tac_from = Induct_Tacs.induct_tac lthy [[SOME "a"]] (SOME [from_induct]) 1

    val iso_thm_proved = Goal.prove lthy [] [] iso_thm
      (fn _ => resolve_tac lthy [@{thm Derive.iso_intro}] 1
        THEN induct_tac_to THEN induct_tac_from
        THEN ALLGOALS (asm_full_simp_tac lthy))

    val ((_, thms), lthy') =
      Local_Theory.note
        ((Binding.name ("conversion_iso_" ^ tname_short), []), [iso_thm_proved]) lthy
    val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy')
    val thm = singleton (Proof_Context.export lthy' ctxt_thy) (hd thms)
  in
    (SOME thm, lthy')
  end

(* Proves the class law applied to the operations defined by inst_info. *)
fun prove_equivalence_law (cl_info : class_info) (inst_info : instance_info) ctxt =
  let
    val thy = Proof_Context.theory_of ctxt
    val class = #class cl_info
    val classname = #classname cl_info
    val class_law = the (#class_law cl_info)
    val op_defs = #defs inst_info
    val axioms = the (#axioms cl_info)
    val axioms_def = the_list (#axioms_def cl_info)
    val class_def = the (#class_def cl_info)

    val ops = map (lhs_head #> dummy_const) op_defs
    val class_law_const_dummy = dummy_const (lhs_head class_law)
    val axioms_thms = map (Local_Defs.unfold ctxt (class_def :: axioms_def)) axioms
    val superclasses = get_superclasses class classname thy
    val superclass_laws =
      map (get_class_info thy #> the #> #equivalence_thm #> the) superclasses

    val t = list_comb (class_law_const_dummy, ops) |> HOLogic.mk_Trueprop |> Syntax.check_term ctxt
  in
    (* NOTE(review): in the previous version the tactic stopped after unfold_tac with an
       unbalanced parenthesis, and axioms_thms / superclass_laws were computed but unused.
       The continuation below (insert those facts, then auto) is a reconstruction --
       confirm against the original source. *)
    Goal.prove ctxt [] [] t
      (fn _ => Local_Defs.unfold_tac ctxt [class_law]
        THEN HEADGOAL (Method.insert_tac ctxt (axioms_thms @ superclass_laws))
        THEN auto_tac ctxt)
  end

(* Tactic for the instance goals of type T: instantiates the class law with the instance
   operations, proves it from the transfer theorem (applied to the iso and equivalence
   theorems), and hands the unfolded law to class_tac. *)
fun prove_instance_tac T (cl_info : class_info) (inst_info : instance_info)
    (ty_info : type_info) ctxt =
  let
    val transfer_thms = the (#transfer_law cl_info) |> maps snd
    val iso_thm = the (#iso_thm ty_info)
    val op_defs = #defs inst_info
    val class_law = the (#class_law cl_info)
    val equivalence_thm = the (#equivalence_thm cl_info)

    val ops = map (lhs_head #> dummy_const) op_defs
    val class_law_const_dummy = dummy_const (lhs_head class_law)

    (* infer types, then instantiate the first remaining schematic type variable with T *)
    val class_law_inst =
      HOLogic.mk_Trueprop (list_comb (class_law_const_dummy, ops))
      |> singleton (Type_Infer_Context.infer_types ctxt)
      |> (fn t => subst_TVars [(hd (Term.add_tvar_names t []), T)] t)

    val transfer_thm_inst = hd transfer_thms OF [iso_thm, equivalence_thm]

    val class_law_inst_proved = Goal.prove ctxt [] [] class_law_inst
      (fn _ => Local_Defs.unfold_tac ctxt op_defs
        THEN Proof_Context.fact_tac ctxt [transfer_thm_inst] 1)

    val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of ctxt)
    val class_law_inst_def = singleton (Proof_Context.export ctxt ctxt_thy) class_law_inst_proved
    val class_law_unfolded = Local_Defs.unfold ctxt [class_law] class_law_inst_def
  in
    class_tac [class_law_unfolded] ctxt
  end

(* Opens the instance proof of a class instantiation: the goals left by intro_classes are
   stated as a new theorem; once proven, the instantiation is closed with class_tac and
   after_qed is applied to the resulting theory. *)
fun prove_combinator_instance after_qed lthy =
  let
    fun prove_class_laws_manually st ctxt =
      let
        val thm = #goal (Proof.simple_goal st)
        val goal = Class.intro_classes_tac ctxt [] thm |> Seq.hd
        val goals = Thm.prems_of goal
        val goals_formatted = map (fn g => [(g, [])]) goals
        fun prove_class thms =
          let
            val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of ctxt)
            val thms' = Proof_Context.export ctxt ctxt_thy (flat thms)
          in
            Class.prove_instantiation_exit (class_tac thms') ctxt
          end
        fun after_qed' thms _ = prove_class thms |> Named_Target.theory_init |> after_qed []
      in
        Proof.theorem NONE after_qed' goals_formatted ctxt
      end

    val st = Class.instantiation_instance I lthy
  in
    prove_class_laws_manually st lthy
  end

end

# File ‹derive_setup.ML›

open Derive_Util

signature DERIVE_SETUP =
sig
  (* define the class law of the named class and start the proof of its transfer theorem *)
  val prove_class_transfer : string -> theory -> Proof.state

  (* define <class>_class_law from the class definition; returns
     (law theorem, class definition, axioms definition if any, parameters, new context) *)
  val define_class_law :
    string -> Proof.context -> (thm * thm * thm option * term list * local_theory)
end

structure Derive_Setup : DERIVE_SETUP =
struct

(* Class information stored in Class_Data for classname, if any. *)
fun get_class_info thy classname = Symtab.lookup (Class_Data.get thy) classname

(* Replaces every constant whose base name parses as a class with stored class information
   by that class's class-law constant; everything else is left unchanged. *)
fun replace_superclasses lthy (s $ t) = replace_superclasses lthy s $ replace_superclasses lthy t
  | replace_superclasses lthy (Const (n, T)) =
      let
        val is_class = Long_Name.base_name n
        val class = Syntax.parse_sort lthy is_class handle ERROR _ => []
      in
        if null class then Const (n, T)
        else
          (case get_class_info (Proof_Context.theory_of lthy) (hd class) of
            SOME class_data => the (#class_law_const class_data)
          | NONE => Const (n, T))
      end
  | replace_superclasses _ t = t

(* True iff the term mentions a constant with base name <cn>_axioms. *)
fun contains_axioms cn (s $ t) = contains_axioms cn s orelse contains_axioms cn t
  | contains_axioms cn (Const (n, _)) = Long_Name.base_name n = cn ^ "_axioms"
  | contains_axioms _ _ = false

(* Defines <classname>_class_law x1 ... xn as the body of class.<classname>_def (with
   <classname>_axioms unfolded if present and superclass predicates replaced by their laws). *)
fun define_class_law classname lthy =
  let
    val class_def = Proof_Context.get_thm lthy ("class." ^ classname ^ "_def")
    val has_axioms =
      contains_axioms classname
        (class_def |> Thm.full_prop_of |> Logic.unvarify_global |> Logic.dest_equals |> snd)
    val (axioms_def, (vars, class_law)) =
      class_def
      |> (if has_axioms then
            let
              val axioms_def = Proof_Context.get_thm lthy ("class." ^ classname ^ "_axioms_def")
            in Local_Defs.unfold lthy [axioms_def] #> pair (SOME axioms_def) end
          else pair NONE)
      ||> Thm.full_prop_of
      ||> Logic.unvarify_global
      ||> Logic.dest_equals
      ||> apfst (strip_comb #> snd)
      ||> apsnd (replace_superclasses lthy)
    val class_law_name = classname ^ "_class_law"
    val class_law_lhs =
      list_comb (Free (class_law_name, (map (dest_Free #> snd) vars) ---> \<^typ>‹bool›), vars)
    val class_law_eq = HOLogic.Trueprop $ HOLogic.mk_eq (class_law_lhs, class_law)
    val ((_, (_, class_law_thm)), lthy') =
      Specification.definition NONE [] [] ((Binding.empty, []), class_law_eq) lthy

    val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy')
    val class_law_thm_export = singleton (Proof_Context.export lthy' ctxt_thy) class_law_thm
  in
    (class_law_thm_export, class_def, axioms_def, vars, lthy')
  end

(* Wraps the operation var as  %x1 ... xn. [to] (var [from] x1 ... [from] xn),  where "from" is
   applied to arguments of TFree type and "to" to a result of TFree type. *)
fun transfer_op lthy from to var =
  let
    fun convert_arg (T, i) =
      (case T of
        TFree _ => from $ Bound i
      | _ => Bound i)
    fun abstract [] t = t
      | abstract (x :: xs) t = Abs (x, dummyT, abstract xs t)
    val (v, T) = dest_Free var
    val (binders, body) = strip_type T
    val argnames = Name.invent_names (Variable.names_of lthy) "x" binders |> map fst
    val arity = length binders
    (* the first binder is the outermost abstraction, hence the highest de Bruijn index *)
    val args_converted =
      map convert_arg (binders ~~ List.tabulate (arity, fn k => arity - (k + 1)))
    val op_call = list_comb (Free (v, T), args_converted)
    val op_converted =
      (case body of
        TFree _ => to $ op_call
      | _ => op_call)
  in
    abstract argnames op_converted
  end

(* Defines the class law of classname and states its transfer theorem
   "iso from to ==> law ops ==> law (transferred ops)"; once proven, it is noted as
   <classname>_transfer and the class information is stored in Class_Data. *)
fun prove_class_transfer classname thy =
  let
    fun add_info (info : class_info) thy =
      Class_Data.put (Symtab.update (#classname info, info) (Class_Data.get thy)) thy
    val class = Syntax.parse_sort (Proof_Context.init_global thy) classname
    val classname_full = hd class
    val axioms = Axclass.get_info thy classname_full |> #axioms
    val (class_law, class_def, axioms_def, vars, lthy) =
      define_class_law classname (Named_Target.theory_init thy)
    (* Exit so that class law is defined properly before the next step
       FIXME use begin / end block instead (?) *)
    val thy' = Local_Theory.exit_global lthy
    val lthy' = Named_Target.theory_init thy'

    val tfree_dt = get_tvar (map (dest_Free #> snd) vars)
    val tfree_rep =
      let val (n, s) = dest_TFree tfree_dt
      in Name.invent_names (Name.make_context [n]) "'a" [s] |> hd |> TFree end
    val from = Free ("from", tfree_rep --> tfree_dt)
    val to = Free ("to", tfree_dt --> tfree_rep)

    val class_law_const =
      Thm.full_prop_of class_law |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst
      |> strip_comb |> fst
    val class_law_const_dummy = dest_Const class_law_const |> apsnd (K dummyT) |> Const
    val class_law_var = Term.add_tvars class_law_const [] |> hd |> fst
    val class_law_const_dt = subst_vars ([(class_law_var, tfree_dt)], []) class_law_const
    val class_law_const_rep = subst_vars ([(class_law_var, tfree_rep)], []) class_law_const

    val assm_iso = HOLogic.mk_Trueprop (Const (\<^const_name>‹Derive.iso›, dummyT) $ from $ to)
    val assm_class = HOLogic.mk_Trueprop (list_comb (class_law_const_dt, vars))
    val vars_transfer = map (transfer_op lthy' from to) vars
    val transfer_concl = HOLogic.mk_Trueprop (list_comb (class_law_const_rep, vars_transfer))
    val transfer_term = Logic.mk_implies (assm_iso, Logic.mk_implies (assm_class, transfer_concl))
    val transfer_term_inf = Type_Infer_Context.infer_types lthy' [transfer_term] |> hd

    fun after_qed thms lthy =
      fold_map (fn ths => fn acc =>
          Local_Theory.note ((Binding.name (classname ^ "_transfer"), []), ths) acc)
        thms lthy
      |> (fn (notes, lthy) =>
          Local_Theory.background_theory
            (add_info
              {classname = classname_full, class = class, params = NONE,
               class_law = SOME class_law, class_law_const = SOME class_law_const_dummy,
               ops = SOME vars, transfer_law = SOME notes, axioms = SOME axioms,
               axioms_def = axioms_def, class_def = SOME class_def, equivalence_thm = NONE})
            lthy)
      |> Local_Theory.exit
  in
    Proof.theorem NONE after_qed [[(transfer_term_inf, [])]] lthy'
  end

val _ =
  Outer_Syntax.command \<^command_keyword>‹derive_generic_setup› "prepare a class for derivation"
    (Parse.name >> (fn c =>
      Toplevel.theory_to_proof (fn thy =>
        if has_class_law c thy
        then prove_class_transfer c thy
        else error ("Class " ^ c ^ " has no associated laws, no need to call derive_generic_setup"))))

end

# File ‹derive.ML›

open Derive_Util

signature DERIVE =
sig
  (* Adds functions that convert to and from a product-sum representation *)
  val define_prod_sum_conv :
    type_info -> bool -> Proof.context ->
      (Function_Common.info * Function_Common.info * Proof.context)

  (* define product-sum-representation type synonym *)
  val define_rep_type :
    string list -> ctr_info -> bool -> local_theory -> rep_type_info * local_theory

  (* define Mu-combinator type *)
  val define_combinator_type :
    string list -> (typ * class list) list -> rep_type_info -> local_theory ->
      comb_type_info option * local_theory

  (* instantiate a typeclass *)
  val generate_instance : string -> sort -> bool -> theory -> Proof.state

  (* record instance definitions for a (class, type) pair *)
  val add_inst_info : string -> string -> thm list -> theory -> theory
end

structure Derive : DERIVE =
struct

(* Lookups into the theory data tables Type_Data, Class_Data and Instance_Data.
   Type information is keyed by the type name together with the constr_names flag. *)
fun get_type_info thy tname constr_names =
  let
    val key = (tname, Bool.toString constr_names)
  in
    Symreltab.lookup (Type_Data.get thy) key
  end
fun get_class_info thy classname =
  Class_Data.get thy |> (fn tab => Symtab.lookup tab classname)
fun get_inst_info thy classname tname =
  Instance_Data.get thy |> (fn tab => Symreltab.lookup tab (classname, tname))
(* Copy of cl_info with the params field set to SOME params. *)
fun add_params_cl_info (cl_info : class_info) params =
  let
    val {classname, class, class_law, class_law_const, ops, transfer_law, axioms,
      axioms_def, class_def, equivalence_thm, ...} = cl_info
  in
    {classname = classname, class = class, params = SOME params, class_law = class_law,
     class_law_const = class_law_const, ops = ops, transfer_law = transfer_law,
     axioms = axioms, axioms_def = axioms_def, class_def = class_def,
     equivalence_thm = equivalence_thm}
  end
(* Wraps the instance definitions into an instance_info-shaped record. *)
fun mk_inst_info defs = {defs = defs}
(* Copy of cl_info with the equivalence_thm field set to SOME equivalence_thm. *)
fun add_equivalence_cl_info (cl_info : class_info) equivalence_thm =
  let
    val {classname, class, params, class_law, class_law_const, ops, transfer_law, axioms,
      axioms_def, class_def, ...} = cl_info
  in
    {classname = classname, class = class, params = params, class_law = class_law,
     class_law_const = class_law_const, ops = ops, transfer_law = transfer_law,
     axioms = axioms, axioms_def = axioms_def, class_def = class_def,
     equivalence_thm = SOME equivalence_thm}
  end

(* Builds the term for argument bname of type btype: a type-variable argument becomes
   conv_func applied to Free (bname, T); a polymorphic type application gets conv_func
   mapped over it via get_mapping_function; anything else stays Free (bname, btype).
   The result is passed through type inference. *)
fun make_rep T lthy conv_func (btype, bname) =
  let
    fun build (TFree _) = conv_func $ Free (bname, T)
      | build (Type _) =
          if is_polymorphic btype
          then get_mapping_function lthy btype $ conv_func $ Free (bname, dummyT)
          else Free (bname, btype)
      | build _ = Free (bname, btype)
  in
    hd (Type_Infer_Context.infer_types lthy [build btype])
  end

(* Applies conv_func to inner_term according to its type T: directly for a type variable,
   via get_mapping_function for a polymorphic type application, not at all otherwise.
   The result is passed through type inference. *)
fun from_rep T lthy conv_func inner_term =
  let
    val unchecked =
      if is_typeT T then
        (if is_polymorphic T
         then get_mapping_function lthy T $ conv_func $ inner_term
         else inner_term)
      else
        (case T of
          TFree _ => conv_func $ inner_term
        | _ => inner_term)
  in
    hd (Type_Infer_Context.infer_types lthy [unchecked])
  end

(* Generate instance for Mu-combinator type *)
fun instantiate_combinator_type (ty_info : type_info) (cl_info : class_info) constr_names lthy =
let
val tname = #tname ty_info
val ctr_type =  #rep_type_instantiated (the (#comb_info ty_info))
val rep_type = #rep_type (#rep_info ty_info)
val comb_type_name = #combname (the (#comb_info ty_info))
val comb_type_name_full = #combname_full (the (#comb_info ty_info))
val class = #class cl_info
val params = the (#params cl_info)

val sum_type_name = if constr_names then \<^type_name>‹Tagged_Prod_Sum.sum› else \<^type_name>‹Sum_Type.sum›
val prod_type_name = if constr_names then \<^type_name>‹Tagged_Prod_Sum.prod› else \<^type_name>‹Product_Type.prod›

val _ = ("Generating instance for type " ^ quote comb_type_name) |> writeln

fun define_modular_sum_prod def vars opname opT opT_var is_sum lthy =
let
fun get_tvars vars =
if null vars
then ((TVar (("'a",0),class)),(TVar (("'b",0),class)))
else
let
val var = hd vars
val varTname = if is_typeT var then var |> dest_Type |> fst else ""
in
if is_sum then
(if varTname = sum_type_name
then var |> dest_Type |> snd |> (fn Ts => (hd Ts, hd (tl Ts)))
else get_tvars ((var |> dest_Type |> snd)@(tl vars)))
else
(if varTname = prod_type_name
then var |> dest_Type |> snd |> (fn Ts => (hd Ts, hd (tl Ts)))
else get_tvars ((var |> dest_Type |> snd)@(tl vars)))
end
fun replace_tfree tfree replacement T =
if T = tfree then replacement else T
fun replace_op_call opname T replacement t =
if t = Const (opname,T) then replacement else t
fun is_TFree (TFree _) = true | is_TFree _ = false
fun remove_constraints T = if is_TFree T then dest_TFree T |> apsnd (K \<^sort>‹type›) |> TFree else T
val varTs = map (dest_Var #> snd) vars
val (left,right) = get_tvars (varTs @ [strip_type opT_var |> snd])
val op_tfree = Term.add_tfreesT opT [] |> hd |> TFree
val left_opT = map_atyps (replace_tfree op_tfree left) opT
val right_opT = map_atyps (replace_tfree op_tfree right) opT
val opname_left = (Long_Name.base_name opname) ^ "_left"
val opname_right = (Long_Name.base_name opname) ^ "_right"
val var_left = (Var ((opname_left,0),left_opT))
val var_right = (Var ((opname_right,0),right_opT))
val def_left = map_aterms (replace_op_call opname left_opT var_left) def
val def_right = map_aterms (replace_op_call opname right_opT var_right) def_left
val return_type = strip_type opT_var |> snd
val def_name = (Long_Name.base_name opname) ^ (if is_sum then "_sum_modular" else "_prod_modular")
val eq_head = Free (def_name, ([left_opT,right_opT]@varTs) ---> return_type)
|> Logic.unvarify_types_global
val args = map Logic.unvarify_global ([var_left,var_right] @ vars)
val eq = HOLogic.Trueprop $HOLogic.mk_eq ((list_comb (eq_head,args)), Logic.unvarify_global def_right) val eq' = map_types (map_atyps remove_constraints) eq val ((_,(_,def_thm)),lthy') = Specification.definition NONE [] [] ((Binding.empty, []), eq') lthy val left = left |> dest_TVar |> (fn ((s,_),_) => TFree (s,\<^sort>‹type›)) val right = right |> dest_TVar |> (fn ((s,_),_) => TFree (s,\<^sort>‹type›)) val def_const = Thm.hyps_of def_thm |> hd |> Logic.dest_equals |> snd val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy') val def_thm = singleton (Proof_Context.export lthy' ctxt_thy) def_thm in (left,right,def_const,def_thm,lthy') end fun op_instance opname opT T = let fun replace_tfree tfree replacement T = if T = tfree then replacement else T val op_tfree = Term.add_tfreesT opT [] |> hd |> TFree val opT_new = map_atyps (replace_tfree op_tfree T) opT in Const (opname,opT_new) end fun sum_prod_instance term lvar rvar lT rT = let fun replace_vars var1 var2 replacement1 replacement2 T = if T = var1 then replacement1 else if T = var2 then replacement2 else T in map_types (map_atyps (replace_vars lvar rvar lT rT)) term end fun define_modular_instance T (lvs,rvs,sum_term) (lvp,rvp,prod_term) opname opT = if is_typeT T then let val (tname,Ts) = dest_Type T in if tname = sum_type_name then (sum_prod_instance sum_term lvs rvs (hd Ts) (hd (tl Ts)))$ (define_modular_instance (hd Ts) (lvs,rvs,sum_term) (lvp,rvp,prod_term) opname opT)
$(define_modular_instance (hd (tl Ts)) (lvs,rvs,sum_term) (lvp,rvp,prod_term) opname opT) else if tname = prod_type_name then (sum_prod_instance prod_term lvp rvp (hd Ts) (hd (tl Ts)))$ (define_modular_instance (hd Ts) (lvs,rvs,sum_term) (lvp,rvp,prod_term) opname opT)
$(define_modular_instance (hd (tl Ts)) (lvs,rvs,sum_term) (lvp,rvp,prod_term) opname opT) else (op_instance opname opT T) end else (op_instance opname opT T) fun change_constraints constrT term = let val constraint = Term.add_tfreesT constrT [] |> hd |> snd val unconstr_tfrees = Term.add_tfrees term [] |> map TFree val constr_names = Name.invent_names (Variable.names_of lthy) "a" (replicate (length unconstr_tfrees) constraint) val constr_tfrees = map TFree constr_names in subst_atomic_types (unconstr_tfrees ~~ constr_tfrees) term end fun define_operation_rec T (opname,t) lthy = let fun get_comb_params [] = [] | get_comb_params (ty::tys) = (case ty of Type (n,tys') => if n = comb_type_name_full then tys' else get_comb_params (tys@tys') | _ => get_comb_params tys) fun make_arg inConst ctr_type x = let val T = dest_Free x |> snd in if T = ctr_type then inConst$ x else (x |> dest_Free |> apsnd (K dummyT) |> Free)
end

val short_opname = Long_Name.base_name opname
val fun_name = short_opname ^ "_" ^ comb_type_name
val prod_def_name = short_opname ^ "_prod_def"
val prod_thm = Proof_Context.get_thm lthy prod_def_name
val prod_hd = (Thm.full_prop_of prod_thm) |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst
val prod_const = prod_hd |> combs_to_list |> hd
val prod_opT = prod_const |> dest_Const |> snd
val prod_vars_raw = prod_hd |> combs_to_list |> tl
val prod_def = (Thm.full_prop_of prod_thm) |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> snd
val (lvp,rvp,prod_def_term,prod_def_thm,lthy') = define_modular_sum_prod prod_def prod_vars_raw opname t prod_opT false lthy

val sum_def_name = short_opname ^ "_sum_def"
val sum_thm = Proof_Context.get_thm lthy sum_def_name
val sum_hd = (Thm.full_prop_of sum_thm) |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst
val sum_const = sum_hd |> combs_to_list |> hd
val sum_opT = sum_const |> dest_Const |> snd
val sum_vars_raw = sum_hd |> combs_to_list |> tl
val sum_def = (Thm.full_prop_of sum_thm) |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> snd
val (lvs,rvs,sum_def_term,sum_def_thm,lthy'')  = define_modular_sum_prod sum_def sum_vars_raw opname t sum_opT true lthy'

val (binders, body) = (strip_type t)
val tvar = get_tvar (body :: binders)
val comb_params = get_comb_params [ctr_type]
val T_params = dest_Type T |> snd
val ctr_type' = typ_subst_atomic (comb_params ~~ T_params) ctr_type
val body' = typ_subst_atomic [(tvar,dummyT)] body
val binders' = map (typ_subst_atomic [(tvar,ctr_type')]) binders
val opT = (replicate (length binders) dummyT) ---> body'
val vars = (Name.invent_names (Variable.names_of lthy) "x" binders')
val xs = map Free vars
val inConst_name = Long_Name.qualify comb_type_name_full "In"
val inConst = Const (inConst_name, ctr_type' --> dummyT)
val xs_lhs = map (make_arg inConst ctr_type') xs

val modular_instance = define_modular_instance rep_type (lvs,rvs,sum_def_term) (lvp,rvp,prod_def_term) opname t
val xs_modular = map (apsnd (K dummyT) #> Free) vars
val modular_folded_name = short_opname ^ "_modular_folded"
val modularT = typ_subst_atomic [(tvar,rep_type)] t
val modular_instance_eq = HOLogic.Trueprop $HOLogic.mk_eq (Free (modular_folded_name, modularT),modular_instance) val modular_eq_constr = change_constraints t modular_instance_eq val ((_,(_,modular_folded_thm)),lthy''') = Specification.definition NONE [] [] ((Binding.empty, []), modular_eq_constr) lthy'' val modular_unfolded = Local_Defs.unfold lthy''' [prod_def_thm,sum_def_thm] modular_folded_thm |> Thm.full_prop_of |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> snd |> map_types (K dummyT) val lhs = list_comb (Free (fun_name, opT), xs_lhs) val rhs = if is_polymorphic body then inConst$ list_comb (modular_unfolded,xs_modular)
else list_comb (modular_unfolded,xs_modular)
val eq = HOLogic.Trueprop $HOLogic.mk_eq (lhs, rhs) in if xs = [] then Specification.definition NONE [] [] ((Binding.empty, []), eq) lthy |> snd else if constr_names then Function.add_function [(Binding.name fun_name, NONE, NoSyn)] [((Binding.empty_atts, eq), [], [])] Function_Fun.fun_config (fn ctxt => Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt) lthy |> snd |> tagged_function_termination_tac |> snd else Function_Fun.add_fun [(Binding.name fun_name, NONE, NoSyn)] [((Binding.empty_atts, eq), [], [])] Function_Fun.fun_config lthy end fun instantiate_comb_type lthy = let val thy = Local_Theory.exit_global lthy val (T,xs) = Derive_Util.typ_and_vs_of_typname thy comb_type_name_full class val filtered_params = params |> filter (fn (c,_) => not_instantiated thy comb_type_name_full c andalso not_instantiated thy tname c) fun define_operations_rec_aux _ [] lthy = lthy | define_operations_rec_aux ty (p::ps) lthy = define_operations_rec_aux ty ps (define_operation_rec ty p lthy) in Class.instantiation ([comb_type_name_full], xs, class) thy |> (define_operations_rec_aux T (map snd filtered_params)) end in instantiate_comb_type lthy end fun define_operation lthy tname T (opname,t) = let val from_name = "from_" ^ Long_Name.base_name tname val from_term = (Proof_Context.read_const {proper = true, strict = true} lthy from_name) |> dest_Const |> fst val from_func = Const (from_term, dummyT) val to_name = "to_" ^ Long_Name.base_name tname val to_term = (Proof_Context.read_const {proper = true, strict = true} lthy to_name) |> dest_Const |> fst val to_func = Const (to_term, dummyT) val (binders, body) = (strip_type t) val tvar = get_tvar (body :: binders) val body' = typ_subst_atomic [(tvar,dummyT)] body val binders' = map (typ_subst_atomic [(tvar,T)]) binders val newT = binders' ---> body' val vars = (Name.invent_names (Variable.names_of lthy) "x" binders') val names = map fst vars val xs = map Free vars val lhs = list_comb (Const (opname, newT), xs) val 
rhs_inner = list_comb (Const (opname, dummyT), map (make_rep T lthy from_func) (binders ~~ names)) val rhs = from_rep body lthy to_func rhs_inner val eq = HOLogic.Trueprop$ HOLogic.mk_eq (lhs, rhs)
val ((_,(_,thm)),lthy') = (Specification.definition NONE [] [] ((Binding.empty, []), eq) lthy)
in
(thm,lthy')
end

(* Define (or reuse) the class operations for one type.
   ps: class parameters as (class name, (operation name, operation type)) pairs;
   ty: (type name, type) pair; lthy: target context.
   For every parameter whose class is not yet instantiated for the type, a new definition is
   made via define_operation. Otherwise the existing definition theorem recorded in the
   instance data is reused. Returns the collected theorems (in reverse parameter order) and
   the updated context. *)
fun define_operations ps ty lthy =
  let
    fun define_operations_aux [] _ thms lthy = (thms, lthy)
      | define_operations_aux (p :: ps) (tname, T) thms lthy =
          let
            val thy = Proof_Context.theory_of lthy
            val (cl, op_spec as (opname, _)) = p
          in
            if not_instantiated thy tname cl
            then
              let
                val (thm, lthy') = define_operation lthy tname T op_spec
              in
                define_operations_aux ps (tname, T) (thm :: thms) lthy'
              end
            else
              let
                val defs = #defs (the_default {defs = []} (get_inst_info thy cl tname))
                (* name of the constant defined by each recorded definition theorem *)
                val def_names =
                  map (Thm.full_prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst
                       #> strip_comb #> fst #> dest_Const #> fst) defs
                val thm =
                  (case AList.lookup (op =) (def_names ~~ defs) opname of
                    SOME thm => thm
                  | NONE =>
                      error ("No existing definition of " ^ quote opname ^
                        " found for type " ^ quote tname ^ " in class " ^ quote cl))
              in
                define_operations_aux ps (tname, T) (thm :: thms) lthy
              end
          end
  in
    define_operations_aux ps ty [] lthy
  end

(* Bind the given Free variables in t by nested Pure.all quantifiers, first variable
   outermost. Every occurrence of a variable (matched by name only) is replaced by the
   de Bruijn index of its own binder. Binder types are dummyT and are left to type inference.
   Returns t unchanged when vars is empty. *)
fun abstract_over_vars vars t =
  let
    val varnames = map (fst o dest_Free) vars
    val n = length vars
    (* abstract below nests the binder of the i-th name inside those of names 0..i-1, so in
       the innermost body that binder is reached via index n - 1 - i *)
    val varmapping = varnames ~~ List.tabulate (n, fn i => n - 1 - i)
    val increment_bounds = map (fn (v, i) => (v, i + 1))
    fun insert_bounds mapping (Free (s, T)) =
          (case AList.lookup (op =) mapping s of
            NONE => Free (s, T)
          | SOME i => Bound i)
      | insert_bounds mapping (Abs (x, T, body)) =
          Abs (x, T, insert_bounds (increment_bounds mapping) body)
      | insert_bounds mapping (f $ u) = (insert_bounds mapping f) $ (insert_bounds mapping u)
      | insert_bounds _ u = u
    fun abstract [] u = insert_bounds varmapping u
      | abstract (x :: xs) u =
          Const (\<^const_name>‹Pure.all›, dummyT) $ Abs (x, dummyT, abstract xs u)
  in
    if null vars then t else abstract varnames t
  end

(* Adds functions that convert a type to and from its product-sum representation.
   For every type in the mutual group, functions from_<type> and to_<type> are defined from
   generated equations. When constr_names is set, the tagged product/sum types of
   Tagged_Prod_Sum are used and carry constructor and selector names; otherwise the plain
   HOL prod/sum types are used. For recursive types, recursive arguments are converted
   through the combinator type recorded in #comb_info.
   Returns the function-package info for the from_ and to_ functions and the new context. *)
fun define_prod_sum_conv (ty_info : type_info) constr_names lthy =
  let
    val tnames = #mutual_tnames ty_info
    val Ts = #mutual_Ts ty_info
    val ctrs = #mutual_ctrs ty_info
    val sels = #mutual_sels ty_info
    val is_recursive = #is_rec ty_info
    val is_mutually_recursive = #is_mutually_rec ty_info
    val _ = map (fn tyco => "Generating conversions for type " ^ quote tyco) tnames
      |> cat_lines |> writeln

    (* Functions to deal with tagged products and sums *)
    val str_optT = \<^typ>‹string option›
    val none_str_opt = Const (\<^const_name>‹None›, str_optT)
    fun some_str s = Const (\<^const_name>‹Some›, \<^typ>‹string› --> str_optT) $ HOLogic.mk_string s
    val dummy_str_opt = Term.dummy_pattern \<^typ>‹string option›
    val sum_type_name = if constr_names then \<^type_name>‹Tagged_Prod_Sum.sum› else \<^type_name>‹Sum_Type.sum›
    val prod_type_name = if constr_names then \<^type_name>‹Tagged_Prod_Sum.prod› else \<^type_name>‹Product_Type.prod›
    val prod_constr_name = if constr_names then \<^const_name>‹Tagged_Prod_Sum.Prod› else \<^const_name>‹Product_Type.Pair›
    val inl_constr_name = if constr_names then \<^const_name>‹Tagged_Prod_Sum.Inl› else \<^const_name>‹Sum_Type.Inl›
    val inr_constr_name = if constr_names then \<^const_name>‹Tagged_Prod_Sum.Inr› else \<^const_name>‹Sum_Type.Inr›
    fun mk_tagged_prodT (T1, T2) = Type (prod_type_name, [T1, T2])
    fun mk_tagged_sumT LT RT = Type (sum_type_name, [LT, RT])
    fun tagged_pair_const T1 T2 =
      Const (prod_constr_name, str_optT --> str_optT --> T1 --> T2 --> mk_tagged_prodT (T1, T2))
    (* pair two terms, tagging each with its selector name ("" means None) *)
    fun mk_tagged_prod ((t1, s1), (t2, s2)) =
      let
        val T1 = fastype_of t1
        val T2 = fastype_of t2
        val S1 = if s1 = "" then none_str_opt else some_str s1
        val S2 = if s2 = "" then none_str_opt else some_str s2
      in
        (tagged_pair_const T1 T2 $ S1 $ S2 $ t1 $ t2, "")
      end
    (* pair two terms with dummy patterns in place of the selector tags *)
    fun mk_tagged_prod_dummy (t1, t2) =
      tagged_pair_const (fastype_of t1) (fastype_of t2) $ dummy_str_opt $ dummy_str_opt $ t1 $ t2
    fun Inl_const LT RT =
      if constr_names then Const (inl_constr_name, str_optT --> LT --> mk_tagged_sumT LT RT)
      else BNF_FP_Util.Inl_const LT RT
    fun mk_tagged_Inl n RT t = Inl_const (fastype_of t) RT $ n $ t
    fun Inr_const LT RT =
      if constr_names then Const (inr_constr_name, str_optT --> RT --> mk_tagged_sumT LT RT)
      else BNF_FP_Util.Inr_const LT RT
    fun mk_tagged_tuple _ [] = HOLogic.unit
      | mk_tagged_tuple sels ts = fst (foldr1 mk_tagged_prod (ts ~~ sels))
    fun mk_tagged_tuple_dummy [] = HOLogic.unit
      | mk_tagged_tuple_dummy ts = foldr1 mk_tagged_prod_dummy ts
    (* Replace the name argument of a tagged injection by a dummy pattern. Terms that are not
       applications (such as the In constant of the combinator type) are kept unchanged. *)
    fun add_dummy_patterns (c $ _) = c $ dummy_str_opt
      | add_dummy_patterns t = t

    (* Helpers for injection prefixes. A prefix is a list of injection terms applied
       outermost-first to the inner representation; a prefix whose head is the unit term
       stands for "no injection yet". *)
    fun inr_step LT RT = if constr_names then Inr_const LT RT $ none_str_opt else Inr_const LT RT
    fun inl_step LT RT = if constr_names then Inl_const LT RT $ none_str_opt else Inl_const LT RT
    fun extend_prefix prefix step =
      if HOLogic.is_unit (hd prefix) then [step] else prefix @ [step]
    fun apply_prefix prefix inner =
      if HOLogic.is_unit (hd prefix) then inner else Library.foldr (op $) (prefix, inner)
    (* With tags, the last constructor's name is attached to the final injection of the prefix *)
    fun tag_last_prefix cN_opt prefix =
      if constr_names andalso not (HOLogic.is_unit (hd prefix))
      then
        (case split_last prefix of
          (init, inj $ _) => init @ [inj $ cN_opt]
        | _ => prefix)
      else prefix
    fun constr_name_opt cN = some_str (Long_Name.base_name cN)
    fun mk_dummy_eq (lhs, rhs) = HOLogic.Trueprop $ (HOLogic.eq_const dummyT $ lhs $ rhs)

    (* simple version for non-recursive types *)
    fun generate_conversion_eqs lthy prefix ((tyco, ctrs), T) sels =
      let
        fun mk_prod_listT [] = HOLogic.unitT
          | mk_prod_listT [x] = x
          | mk_prod_listT (x :: xs) = mk_tagged_prodT (x, mk_prod_listT xs)
        fun generate_sum_prodT [] = HOLogic.unitT
          | generate_sum_prodT [x] = mk_prod_listT x
          | generate_sum_prodT (x :: xs) = mk_tagged_sumT (mk_prod_listT x) (generate_sum_prodT xs)

        fun generate_conversion_eq lthy prefix (cN, Ts) sels tail_ctrs =
          let
            val c = Const (cN, Ts ---> T)
            val sels = if null sels then replicate (length Ts) "" else sels
            val cN_opt = constr_name_opt cN
            val xs = map Free (Name.invent_names (Variable.names_of lthy) "x" Ts)
            val conv_inner =
              case tail_ctrs of
                [] => if constr_names then mk_tagged_tuple sels xs else HOLogic.mk_tuple xs
              | _ =>
                  if constr_names
                  then mk_tagged_Inl cN_opt (generate_sum_prodT tail_ctrs) (mk_tagged_tuple sels xs)
                  else BNF_FP_Util.mk_Inl (generate_sum_prodT tail_ctrs) (HOLogic.mk_tuple xs)
            val prefix = if null tail_ctrs then tag_last_prefix cN_opt prefix else prefix
            val conv = apply_prefix prefix conv_inner
            (* left-hand side pattern of the to_ equation: with tags, constructor and selector
               tags are matched by dummy patterns, for the last constructor too *)
            val conv_dummy =
              if constr_names
              then
                apply_prefix (map add_dummy_patterns prefix)
                  (case tail_ctrs of
                    [] => mk_tagged_tuple_dummy xs
                  | _ => mk_tagged_Inl dummy_str_opt (generate_sum_prodT tail_ctrs) (mk_tagged_tuple_dummy xs))
              else conv
            val from_const = Free ("from_" ^ Long_Name.base_name tyco, dummyT)
            val to_const = Free ("to_" ^ Long_Name.base_name tyco, dummyT)
          in
            (abstract_over_vars xs (mk_dummy_eq (from_const $ list_comb (c, xs), conv)),
             abstract_over_vars xs (mk_dummy_eq (to_const $ conv_dummy, list_comb (c, xs))))
          end
      in
        case ctrs of
          [] => ([], [])
        | (c :: cs) =>
            let
              val s = if null sels then replicate (length (snd c)) "" else hd sels
              val ss = if null sels then [] else tl sels
              val new_prefix =
                extend_prefix prefix (inr_step (generate_sum_prodT [snd c]) (generate_sum_prodT (map snd cs)))
              val (from_eq, to_eq) = generate_conversion_eq lthy prefix c s (map snd cs)
              val (from_eqs, to_eqs) = generate_conversion_eqs lthy new_prefix ((tyco, cs), T) ss
            in
              (from_eq :: from_eqs, to_eq :: to_eqs)
            end
      end

    (* version for recursive types: recursive arguments go through the combinator type *)
    fun generate_conversion_eqs_rec lthy Ts comb_type (((tyco, ctrs), T), prefix) sels =
      let
        fun replace_type_with_comb T = if member (op =) Ts T then comb_type else T
        fun mk_prod_listT [] = HOLogic.unitT
          | mk_prod_listT [x] = replace_type_with_comb x
          | mk_prod_listT (x :: xs) = mk_tagged_prodT (replace_type_with_comb x, mk_prod_listT xs)
        fun generate_sum_prodT [] = HOLogic.unitT
          | generate_sum_prodT [x] = mk_prod_listT x
          | generate_sum_prodT (x :: xs) = mk_tagged_sumT (mk_prod_listT x) (generate_sum_prodT xs)

        fun generate_conversion_eq lthy prefix (cN, Ts) sels tail_ctrs =
          let
            fun get_type_name (Type (s, _)) = s
              | get_type_name _ = ""
            (* NONE unless T is one of the mutually defined types *)
            fun find_type T tycos = List.find (fn x => x = get_type_name T andalso x <> "") tycos
            fun mk_comb_type v = Free (fst (dest_Free v), comb_type)
            fun from_const tyco = Free ("from_" ^ Long_Name.base_name tyco, dummyT --> dummyT)
            fun to_const tyco = Free ("to_" ^ Long_Name.base_name tyco, dummyT --> dummyT)
            val c = Const (cN, Ts ---> T)
            (* pad missing selector names, as in the non-recursive version *)
            val sels = if null sels then replicate (length Ts) "" else sels
            val cN_opt = constr_name_opt cN
            val xs = map Free (Name.invent_names (Variable.names_of lthy) "x" Ts)
            (* apply f to the recursive arguments, keep the others *)
            fun map_rec_args f =
              map (fn (v, t) => if is_some (find_type t tnames) then f (v, t) else v) (xs ~~ Ts)
            val xs_from = map_rec_args (fn (v, t) => from_const (get_type_name t) $ v)
            val xs_to = map_rec_args (fn (v, t) => to_const (get_type_name t) $ mk_comb_type v)
            val xs_to' = map_rec_args (mk_comb_type o fst)
            val prefix = if null tail_ctrs then tag_last_prefix cN_opt prefix else prefix
            val conv_inner_from =
              case tail_ctrs of
                [] => if constr_names then mk_tagged_tuple sels xs_from else HOLogic.mk_tuple xs_from
              | _ =>
                  if constr_names
                  then mk_tagged_Inl cN_opt (generate_sum_prodT tail_ctrs) (mk_tagged_tuple sels xs_from)
                  else BNF_FP_Util.mk_Inl (generate_sum_prodT tail_ctrs) (HOLogic.mk_tuple xs_from)
            val conv_inner_to =
              case tail_ctrs of
                [] => if constr_names then mk_tagged_tuple_dummy xs_to' else HOLogic.mk_tuple xs_to'
              | _ =>
                  if constr_names
                  then mk_tagged_Inl dummy_str_opt (generate_sum_prodT tail_ctrs) (mk_tagged_tuple_dummy xs_to')
                  else BNF_FP_Util.mk_Inl (generate_sum_prodT tail_ctrs) (HOLogic.mk_tuple xs_to')
            val conv_from = apply_prefix prefix conv_inner_from
            (* left-hand side pattern of the to_ equation; with tags the prefix tags are dummies *)
            val conv_dummy =
              if constr_names then apply_prefix (map add_dummy_patterns prefix) conv_inner_to
              else apply_prefix prefix conv_inner_to
          in
            (abstract_over_vars xs (mk_dummy_eq (from_const tyco $ list_comb (c, xs), conv_from)),
             abstract_over_vars xs (mk_dummy_eq (to_const tyco $ conv_dummy, list_comb (c, xs_to))))
          end
      in
        case ctrs of
          [] => ([], [])
        | (c :: cs) =>
            let
              val s = if null sels then replicate (length (snd c)) "" else hd sels
              val ss = if null sels then [] else tl sels
              val new_prefix =
                extend_prefix prefix (inr_step (generate_sum_prodT [snd c]) (generate_sum_prodT (map snd cs)))
              val (from_eq, to_eq) = generate_conversion_eq lthy prefix c s (map snd cs)
              val (from_eqs, to_eqs) =
                generate_conversion_eqs_rec lthy Ts comb_type (((tyco, cs), T), new_prefix) ss
            in
              (from_eq :: from_eqs, to_eq :: to_eqs)
            end
      end

    (* For mutually recursive types: the prefix for type number i is the In constant followed
       by the injections selecting the i-th summand of the nested sum of representation types. *)
    fun generate_mutual_prefixes inConst Ts mutual_rep_types =
      let
        fun generate_sumT [] = HOLogic.unitT
          | generate_sumT [x] = x
          | generate_sumT (x :: xs) = mk_tagged_sumT x (generate_sumT xs)
        fun generate_mutual_prefix_aux rep_types 0 =
              [inl_step (hd rep_types) (generate_sumT (tl rep_types))]
          | generate_mutual_prefix_aux rep_types n =
              let
                val inr = inr_step (hd rep_types) (generate_sumT (tl rep_types))
              in
                if length rep_types > 2 then inr :: generate_mutual_prefix_aux (tl rep_types) (n - 1)
                else [inr]
              end
        fun generate_mutual_prefix rep_types index =
          inConst :: generate_mutual_prefix_aux rep_types index
      in
        map (generate_mutual_prefix mutual_rep_types) (List.tabulate (length Ts, I))
      end

    (* Define one function per mutual type, named name_prefix ^ <base type name>, from the
       given equations. Pattern completeness and termination are proved the same way the
       fun command does it. Returns the final function info and the new context. *)
    fun add_conversion_functions name_prefix eqs lthy =
      let
        val fixes =
          map (fn tname => (Binding.name (name_prefix ^ Long_Name.base_name tname), NONE, NoSyn)) tnames
        val specs = map (fn t => ((Binding.empty_atts, t), [], [])) eqs
        fun pat_completeness_auto ctxt =
          Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
        val lthy' =
          lthy
          |> Function.add_function fixes specs Function_Fun.fun_config pat_completeness_auto
          |> snd
      in
        Function.prove_termination NONE (Function_Common.termination_prover_tac false lthy') lthy'
      end

    val eqs =
      if is_recursive
      then
        let
          (* split the first n - 1 summands off a nested sum type *)
          fun get_mutual_rep_types ty n =
            if n = 1 then [ty]
            else
              (case ty of
                Type (tname, [LT, RT]) =>
                  if tname = sum_type_name then LT :: get_mutual_rep_types RT (n - 1) else [ty]
              | T => [T])
          val comb_info = the (#comb_info ty_info)
          val comb_type = #comb_type comb_info
          val inConst_free = #inConst_free comb_info
          val rep_type_inst = #rep_type_instantiated comb_info
          val mutual_rep_types =
            if is_mutually_recursive then get_mutual_rep_types rep_type_inst (length Ts)
            else [rep_type_inst]
          val prefixes =
            if is_mutually_recursive then generate_mutual_prefixes inConst_free Ts mutual_rep_types
            else replicate (length Ts) [inConst_free]
        in
          map2 (generate_conversion_eqs_rec lthy Ts comb_type) ((ctrs ~~ Ts) ~~ prefixes) (map snd sels)
        end
      else
        map2 (generate_conversion_eqs lthy [HOLogic.unit]) (ctrs ~~ Ts) (map snd sels)
    val from_eqs = maps fst eqs
    val to_eqs = maps snd eqs
    val (from_info, lthy') = add_conversion_functions "from_" from_eqs lthy
    val (to_info, lthy'') = add_conversion_functions "to_" to_eqs lthy'
  in
    (from_info, to_info, lthy'')
  end

(* Declare the representation type for the (mutually) defined types tnames as a type
   abbreviation named <base names>_rep.
   The representation is a nested sum, one summand per type, of nested sums of products, one
   summand per constructor. Each occurrence of one of the defined types is replaced by a fresh
   type variable (see tFree_renaming). With constr_names, the tagged types of Tagged_Prod_Sum
   are used; otherwise the plain HOL sum/prod types are used.
   Returns the rep_type_info record and the updated context. *)
fun define_rep_type tnames ctrs constr_names lthy =
  let
    val sum_type_name = if constr_names then \<^type_name>‹Tagged_Prod_Sum.sum› else \<^type_name>‹Sum_Type.sum›
    val prod_type_name = if constr_names then \<^type_name>‹Tagged_Prod_Sum.prod› else \<^type_name>‹Product_Type.prod›
    fun collect_tfree_names ctrs =
      fold Term.add_tfree_namesT (ctrs |> map (fn l => map snd (snd l)) |> flat |> flat) []

    (* one fresh type variable per defined type, keyed by the type constructor applied to no
       arguments; the names avoid all type variables already used by the constructors *)
    val tFree_renaming =
      let
        val used_tfrees = collect_tfree_names ctrs
        val ctxt = Name.make_context used_tfrees
        val ts = map Type ((map fst ctrs) ~~ (replicate (length ctrs) []))
        val names = Name.invent_names ctxt "'a" ts |> map fst
      in
        ts ~~ (map TFree (names ~~ (replicate (length names) \<^sort>‹type›)))
      end

    (* Replace an occurrence of one of the defined types by its type variable. Any other type
       is returned unchanged, including its type arguments (e.g. 'a option stays 'a option). *)
    fun replace_types_tvars recTs T =
      (case T of
        Type (s, _) => the_default T (AList.lookup (op =) recTs (Type (s, [])))
      | _ => T)

    fun mk_tagged_prodT (T1, T2) = Type (prod_type_name, [T1, T2])
    fun mk_tagged_sumT LT RT = Type (sum_type_name, [LT, RT])

    val rep_type =
      let
        val ctrs' = (ctrs |> map (fn l => map snd (snd l)))
        fun mk_prodT [] = HOLogic.unitT
          | mk_prodT [x] = replace_types_tvars tFree_renaming x
          | mk_prodT (x :: xs) = mk_tagged_prodT (replace_types_tvars tFree_renaming x, mk_prodT xs)
        fun mk_sumT [] = HOLogic.unitT
          | mk_sumT [x] = x
          | mk_sumT (x :: xs) = mk_tagged_sumT x (mk_sumT xs)
        fun generate_rep_type_aux [] = HOLogic.unitT
          | generate_rep_type_aux [x] = mk_prodT x
          | generate_rep_type_aux (x :: xs) = mk_tagged_sumT (mk_prodT x) (generate_rep_type_aux xs)
      in
        mk_sumT (map generate_rep_type_aux ctrs')
      end

    (* note: fold (curry (op ^)) concatenates the base names in reverse order *)
    val rep_type_name = fold (curry (op ^)) (map Long_Name.base_name tnames) "" ^ "_rep"
    val tfrees = (collect_tfree_names ctrs) @
      (map (snd #> (dest_TFree #> fst)) tFree_renaming)
    val _ = writeln ("Defining representation type " ^ rep_type_name)
    val (full_rep_name, lthy') = Typedecl.abbrev (Binding.name rep_type_name, tfrees, NoSyn) rep_type lthy
  in
    ({repname = full_rep_name,
      rep_type = rep_type,
      tFrees_mapping = tFree_renaming,
      from_info = NONE,
      to_info = NONE} : rep_type_info
    , lthy')
  end

(* Collect information about a freshly defined combinator datatype.
   Reads the constant "In" from lthy. Its type, with schematic type variables turned into
   free ones, determines the combinator type (the result type of In) and the instantiated
   representation type (the argument type of In). *)
fun get_combinator_info comb_type_name ctr_type lthy =
  let
    val in_const = Proof_Context.read_const {proper = true, strict = true} lthy "In"
    val (in_name, raw_in_type) = dest_Const in_const
    val in_type = Derive_Util.freeify_tvars raw_in_type
    val comb_type = body_type in_type
  in
    {combname = comb_type_name,
     combname_full = fst (dest_Type comb_type),
     comb_type = comb_type,
     ctr_type = ctr_type,
     inConst = in_const,
     inConst_free = Const (in_name, in_type),
     inConst_type = in_type,
     rep_type_instantiated = hd (binder_types in_type)} : comb_type_info
  end

(* Define the combinator ("mu") datatype for a recursive type group.
   The new datatype is named "mu" ^ <base names> ^ "F" and has a single constructor In.
   In's argument type is the representation type, with the type variables that stand for the
   recursive types replaced by the combinator type itself.
   tfrees: the type parameters of the original types (pairs whose first component is a TFree);
   rep_info: result of define_rep_type.
   Returns SOME comb_info (from get_combinator_info) and the updated context. *)
fun define_combinator_type tnames tfrees (rep_info : rep_type_info) lthy =
let
(* note: fold (curry (op ^)) concatenates the base names in reverse order of tnames *)
val comb_type_name = "mu" ^ (fold (curry (op ^)) (map Long_Name.base_name tnames) "") ^ "F"
(* type variables standing for the recursive types inside the representation type *)
val rec_tfrees = map (dest_TFree o snd) (#tFrees_mapping rep_info)
val rec_tfree_names = map fst rec_tfrees
val rep_tfree_names = (map (fst o dest_TFree o fst) tfrees)
val comb_type_name_tvars = add_tvars comb_type_name rep_tfree_names
(* NOTE(review): rec_type uses the unqualified comb_type_name rather than the full name of the
   datatype defined below; it only flows into ctr_type, which is passed to get_combinator_info *)
val rec_type = (Type (comb_type_name,map fst tfrees))
(* type arguments of the rep type abbreviation: the original parameters, then the combinator
   type (as a string with its parameters) for every recursive type variable *)
val ctr_tfree_names = rep_tfree_names @ (replicate (length rec_tfree_names) comb_type_name_tvars)
val ctr_type_name = add_tvars (#repname rep_info) ctr_tfree_names
val ctr_type = map_type_tfree (fn (tfree,s) => case List.find (curry (op =) tfree) rec_tfree_names of
SOME _ => rec_type |
NONE   => TFree (tfree,s))
(#rep_type rep_info)
(* datatype specification: unnamed type parameters, one constructor In with one argument *)
val ctr_typarams = ((replicate (length tfrees) (SOME Binding.empty)) ~~ (rep_tfree_names ~~ (replicate (length tfrees) NONE)))
val ctr_specs = [(((Binding.empty, Binding.name "In"), [(Binding.empty, ctr_type_name)]), NoSyn)]
val _ = writeln ("Defining combinator type " ^ comb_type_name)
(* least fixed point, i.e. an ordinary inductive datatype *)
val lthy' =
BNF_FP_Def_Sugar.co_datatype_cmd BNF_Util.Least_FP BNF_LFP.construct_lfp
((K Plugin_Name.default_filter, false),
[(((((ctr_typarams, Binding.name comb_type_name), NoSyn),
ctr_specs)
,(Binding.empty, Binding.empty, Binding.empty))
,[])]) lthy
val comb_info = get_combinator_info comb_type_name ctr_type lthy'
in
(SOME comb_info
, lthy')
end

(* Gather everything the derivation needs to know about type tname and the types mutually
   recursive with it. This covers constructors (with frozen type variables), selector base
   names, type parameters, and recursion flags. It also defines the representation type and,
   for recursive types, the combinator type.
   constr_names: whether the tagged (metadata-carrying) representation is used; it is
   recorded as uses_metadata. Returns the type_info record and the updated context. *)
fun generate_type_info tname constr_names lthy =
  let
    val (tnames, Ts) = Derive_Util.mutual_recursive_types tname lthy
    (* get constructor and selector information from the BNF package *)
    fun get_ctrs t = (t, map (apsnd (map Derive_Util.freeify_tvars o fst o strip_type) o dest_Const)
      (Derive_Util.constr_terms lthy t))
    fun get_sels t = (t, Ctr_Sugar.ctr_sugar_of lthy t |> the |> #selss
      |> (map (map (dest_Const #> fst #> Long_Name.base_name))))
    val ctrs = map get_ctrs tnames
    val sels = map get_sels tnames
    val tfrees = collect_tfrees ctrs
    val is_mutually_rec = (length tnames) > 1
    (* look for recursive constructor arguments *)
    val is_rec = ctrs_arguments ctrs |> filter is_typeT |> map (dest_Type #> fst)
      |> List.exists (member (op =) tnames)
    val (rep_info, lthy') = define_rep_type tnames ctrs constr_names lthy
    val (comb_info, lthy'') =
      if is_rec then define_combinator_type tnames tfrees rep_info lthy'
      else (NONE, lthy')
  in
    ({tname = tname,
      uses_metadata = constr_names,
      tfrees = tfrees,
      mutual_tnames = tnames,
      mutual_Ts = Ts,
      mutual_ctrs = ctrs,
      mutual_sels = sels,
      is_rec = is_rec,
      is_mutually_rec = is_mutually_rec,
      rep_info = rep_info,
      comb_info = comb_info,
      iso_thm = NONE} : type_info
    , lthy'')
  end

(* Fresh class information for a class that has not been set up for derivation.
   The class name is the first element of the sort; every optional component starts as NONE. *)
fun generate_class_info class =
  let
    val name = hd class
  in
    {classname = name,
     class = class,
     params = NONE,
     class_law = NONE,
     class_law_const = NONE,
     ops = NONE,
     transfer_law = NONE,
     axioms = NONE,
     axioms_def = NONE,
     class_def = NONE,
     equivalence_thm = NONE}
  end

(* Store the results of a derivation in the theory data:
   - the type information, once per type of the mutual group (with tname set to that type),
     keyed by (type name, uses_metadata);
   - the class information, keyed by class name;
   - the instance information, keyed by (class name, type name) for every mutual type. *)
fun record_type_class_info ty_info cl_info inst_info thy =
  let
    fun add_type_info (info : type_info) thy =
      Type_Data.put (Symreltab.update ((#tname info, Bool.toString (#uses_metadata info)),info) (Type_Data.get thy)) thy
    fun add_inst_info classname tname thy =
      Instance_Data.put (Symreltab.update ((classname, tname), inst_info) (Instance_Data.get thy)) thy
    fun update_tname ({tname = _, uses_metadata = uses_metadata, tfrees = tfrees, mutual_tnames = mutual_tnames,
                       mutual_Ts = mutual_Ts, mutual_ctrs = mutual_ctrs, mutual_sels = mutual_sels, is_rec = is_rec,
                       is_mutually_rec = is_mutually_rec, rep_info = rep_info, comb_info = comb_info, iso_thm = iso_thm} : type_info)
                     tname =
      {tname = tname, uses_metadata = uses_metadata, tfrees = tfrees, mutual_tnames = mutual_tnames,
       mutual_Ts = mutual_Ts, mutual_ctrs = mutual_ctrs, mutual_sels = mutual_sels, is_rec = is_rec,
       is_mutually_rec = is_mutually_rec, rep_info = rep_info, comb_info = comb_info, iso_thm = iso_thm} : type_info
    val infos = map (update_tname ty_info) (#mutual_tnames ty_info)
  in
    thy
    |> fold add_type_info infos
    |> Class_Data.put (Symtab.update ((#classname cl_info),cl_info) (Class_Data.get thy))
    |> fold (add_inst_info (#classname cl_info)) (#mutual_tnames ty_info)
  end

(* Derive an instance of class (a sort whose head is the class name) for type tname and the
   types mutually recursive with it.
   Type information and the from_/to_ conversions are reused if already recorded for the same
   constr_names setting; otherwise they are generated (and, if the class has a law, the
   isomorphism is proved). For recursive types, the combinator type is instantiated first.
   Returns the proof state of the (empty-goal) theorem whose completion performs the class
   instantiation and records the results. *)
fun generate_instance tname class constr_names thy =
  let
    val (T,xs) = Derive_Util.typ_and_vs_of_typname thy tname class
    val has_law = has_class_law (hd class) thy
    val cl_info =
      case get_class_info thy (hd class) of
        NONE =>
          if has_law
          then error ("Class " ^ (hd class) ^ " not set up for derivation, call derive_generic_setup first")
          else generate_class_info class
      | SOME info => info
    val raw_params = map snd (Class.these_params thy class)
    val cl_info' = add_params_cl_info cl_info raw_params
    val thy' = Class_Data.put (Symtab.update ((#classname cl_info'),cl_info') (Class_Data.get thy)) thy
    val lthy = Named_Target.theory_init thy'
    val (ty_info,lthy') =
      case get_type_info thy tname constr_names of
        SOME info =>
          let
            val _ = writeln ("Using existing type information for " ^ tname)
          in
            (info,lthy)
          end
      | NONE =>
          let
            val (t_info,lthy) = generate_type_info tname constr_names lthy
            val (from_info,to_info,lthy') = define_prod_sum_conv t_info constr_names lthy
            val t_info' = add_conversion_info from_info to_info t_info
            val (iso_thm,lthy'') =
              if has_law then Derive_Laws.prove_isomorphism t_info' lthy'
              else (NONE,lthy')
            val t_info'' = add_iso_info iso_thm t_info'
          in
            (t_info'',lthy'')
          end

    val tnames = #mutual_tnames ty_info
    val Ts = (map (fn tn => Derive_Util.typ_and_vs_of_typname thy tn class) tnames) |> map fst
    fun define_operations_all_Ts _ lthy =
      let
        val (thms,lthy') = fold_map (define_operations raw_params) (tnames ~~ Ts) lthy |> apfst flat
        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy')
        val thms_export = Proof_Context.export lthy' ctxt_thy thms
        val inst_info = mk_inst_info thms_export
      in
        (inst_info,lthy')
      end

    fun instantiate_and_prove _ lthy =
      Local_Theory.exit_global lthy
      |> (Class.instantiation (tnames, xs, class))
      |> define_operations_all_Ts cl_info'
      |> (fn (inst_info,lthy) =>
           (if has_law
            then
              let
                val equivalence_thm = (Derive_Laws.prove_equivalence_law cl_info inst_info lthy)
                val cl_info'' = add_equivalence_cl_info cl_info' equivalence_thm
              in
                (Class.prove_instantiation_exit (Derive_Laws.prove_instance_tac T cl_info'' inst_info ty_info) lthy)
                |> pair cl_info'' |> pair inst_info
              end
            else
              (* record the class info including its parameters (cl_info'), as stored in thy' *)
              (Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt []) lthy)
              |> pair cl_info' |> pair inst_info))
      |> (fn (inst_info,(cl_info,thy)) => record_type_class_info ty_info cl_info inst_info thy)
      |> Proof_Context.init_global

    val empty_goal = [[]]

    (* Generate instance for Mu-combinator type if there is recursion *)
    val thy' =
      if (#is_rec ty_info)
      then
        instantiate_combinator_type ty_info cl_info' constr_names lthy'
        |> (if is_some (#class_law cl_info')
            then Derive_Laws.prove_combinator_instance instantiate_and_prove
            else
              Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt [])
              #> Named_Target.theory_init
              #> Proof.theorem NONE instantiate_and_prove empty_goal)
      else Proof.theorem NONE instantiate_and_prove empty_goal lthy'
  in
    thy'
  end

(* Outer-syntax entry point: parse the class (as a sort) and the type constructor name, then
   run generate_instance with the fully qualified type name. *)
fun generate_instance_cmd classname tyco constr_names thy =
  let
    val ctxt = Proof_Context.init_global thy
    val tname = fst (dest_Type (Syntax.parse_typ ctxt tyco))
    val sort = Syntax.parse_sort ctxt classname
  in
    generate_instance tname sort constr_names thy
  end

(* Syntax: derive_generic [(metadata)] <class> <type constructor> *)
val parse_cmd =
  Scan.optional (Args.parens (Parse.reserved "metadata")) "" -- Parse.name -- Parse.type_const

(* The optional "(metadata)" flag selects the tagged representation with constructor and
   selector names. *)
val _ =
  Outer_Syntax.command \<^command_keyword>‹derive_generic› "derives some sort"
    (parse_cmd >> (fn ((flag, classname), tyco) =>
      Toplevel.theory_to_proof (generate_instance_cmd classname tyco (flag = "metadata"))))

(* Record the definition theorems thms as the instance information for (classname, tname). *)
fun add_inst_info classname tname thms thy =
  let
    val key = (classname, tname)
    val table = Instance_Data.get thy
  in
    Instance_Data.put (Symreltab.update (key, {defs = thms}) table) thy
  end

end


# Theory Derive_Datatypes

chapter "Examples"

section "Example Datatypes"

theory Derive_Datatypes
imports Main
begin

(* Example datatypes used by the Derive_* example theories: a non-recursive type, a type with
   parameters, a recursive type, and two groups of mutually recursive types. *)

(* Simple type without recursion or parameters *)
datatype simple = A (num: nat) | B (left:nat) (right:nat) | C

(* type with parameters *)
datatype ('a,'b) either = L 'a | R 'b

(* recursive type *)
datatype 'a tree = Leaf | Node 'a "'a tree" "'a tree"

(* mutually recursive types *)

datatype even_nat = Even_Zero | Even_Succ odd_nat
and   odd_nat  = Odd_Succ even_nat

(* NOTE(review): the selector names left/right are also used by the type simple above *)
datatype ('a,'b) exp = Term "('a,'b) trm" | Sum (left:"('a,'b) trm") (right:"('a,'b) exp")
and      ('a,'b) trm = Factor "('a,'b) fct " | Prod "('a,'b) fct " "('a,'b) trm "
and      ('a,'b) fct = Const 'a | Var (v:'b) | Expr "('a,'b) exp"

end

# Theory Derive_Eq

section "Equality"

theory Derive_Eq
imports Main "../Derive" Derive_Datatypes
begin

(* A simple equality-test class. derive_generic eq <type> below instantiates it for the
   example datatypes; a derived operation is evaluated on the product-sum representation of
   its arguments, so it relies on the manual nat/unit/prod/sum instances. *)
class eq =
fixes eq :: "'a ⇒ 'a ⇒ bool"

(* Manual instances for nat, unit, prod, and sum *)
instantiation nat and unit:: eq
begin
definition eq_nat : "eq (x::nat) y ⟷ x = y"
definition eq_unit_def: "eq (x::unit) y ⟷ True"
instance ..
end

(* NOTE(review): keep the theorem names eq_prod_def and eq_sum_def; for recursive types the
   derivation code looks up "<operation>_prod_def" and "<operation>_sum_def" by name. *)
instantiation prod and sum :: (eq, eq) eq
begin
definition eq_prod_def: "eq x y ⟷ (eq (fst x) (fst y)) ∧ (eq (snd x) (snd y))"
definition eq_sum_def: "eq x y = (case x of Inl a ⇒ (case y of Inl b ⇒ eq a b | Inr b ⇒ False)
| Inr a ⇒ (case y of Inl b ⇒ False | Inr b ⇒ eq a b))"

instance ..
end

(* The lemmas below check the derived instances by evaluation (by eval / by code_simp). *)

(* nonrecursive test *)

derive_generic eq simple .

(* some tests *)
lemma "eq (A 4) (A 4)" by eval
lemma "eq (A 6) (A 4) ⟷ False" by eval
lemma "eq C C" by eval
lemma "eq (B 4 5) (B 4 5)" by eval
lemma "eq (B 4 4) (A 3) ⟷ False" by eval
lemma "eq C (A 4) ⟷ False" by eval

(* type with parameter *)

derive_generic eq either .

lemma "eq (L (3::nat)) (R 3) ⟷ False" by code_simp
lemma "eq (L (3::nat)) (L 3)" by code_simp
lemma "eq (L (3::nat)) (L 4) ⟷ False" by code_simp

(* recursive types *)
derive_generic eq list .

lemma "eq ([]::(nat list)) []" by eval
lemma "eq ([1,2,3]:: (nat list)) [1,2,3]" by eval
lemma "eq [(1::nat)] [1,2] ⟷ False" by eval

derive_generic eq tree .

lemma "eq Leaf Leaf" by code_simp
lemma "eq (Node (1::nat) Leaf Leaf) Leaf ⟷ False" by eval
lemma "eq (Node (1::nat) Leaf Leaf) (Node (1::nat) Leaf Leaf)" by eval
lemma "eq (Node (1::nat) (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)) (Node (1::nat) (Node 2 Leaf Leaf) (Node 4 Leaf Leaf))
⟷ False" by eval

(* mutually recursive types *)

derive_generic eq even_nat .
derive_generic eq exp .

lemma "eq Even_Zero Even_Zero" by eval
lemma "eq Even_Zero (Even_Succ (Odd_Succ Even_Zero)) ⟷ False" by eval
lemma "eq (Odd_Succ (Even_Succ (Odd_Succ Even_Zero))) (Odd_Succ (Even_Succ (Odd_Succ Even_Zero)))" by eval
lemma "eq (Odd_Succ (Even_Succ (Odd_Succ Even_Zero))) (Odd_Succ (Even_Succ (Odd_Succ (Even_Succ (Odd_Succ Even_Zero)))))
⟷ False" by eval

lemma "eq (Const (1::nat)) (Const (1::nat))" by code_simp
lemma "eq (Const (1::nat)) (Var (1::nat)) ⟷ False" by eval
lemma "eq (Term (Prod (Const (1::nat)) (Factor (Const (2::nat))))) (Term (Prod (Const (1::nat)) (Factor (Const (2::nat)))))"
by code_simp
lemma "eq (Term (Prod (Const (1::nat)) (Factor (Const (2::nat))))) (Term (Prod (Const (1::nat)) (Factor (Const (3::nat)))))
⟷ False" by code_simp

end

# Theory Derive_Encode

section "Encoding"

theory Derive_Encode
imports Main "../Derive" Derive_Datatypes
begin

(* Types whose values can be encoded as a list of booleans.
   No laws are required of encode. *)
class encodeable =
fixes encode :: "'a ⇒ bool list"

(* Manual instances for nat, unit, prod, and sum *)
(* Base cases: nat is encoded in unary (n becomes n copies of True; the
   recursive call goes through the overloaded encode at type nat), and the
   single unit value is encoded as the empty list. *)
instantiation nat and unit:: encodeable
begin
fun encode_nat :: "nat ⇒ bool list" where
"encode_nat 0 = []" |
"encode_nat (Suc n) = True # (encode n)"

definition encode_unit: "encode (x::unit) = []"
instance ..
end

(* Pairs: concatenate the encodings of both components.
   Sums: prefix a tag bit, False for Inl and True for Inr.
   NOTE: the pair encoding is not self-delimiting; e.g. (1, 2) and (2, 1)
   over nat both encode to [True, True, True]. *)
instantiation prod and sum :: (encodeable, encodeable) encodeable
begin
definition encode_prod_def: "encode x = append (encode (fst x)) (encode (snd x))"
definition encode_sum_def:  "encode x = (case x of Inl a ⇒ False # encode a
| Inr a ⇒ True # encode a)"
instance ..
end

(* Non-recursive types. The expected lists are consistent with the sum
   instance's tag bits followed by unary nat encodings; e.g. B 3 4 yields
   [True, False] followed by three Trues and then four Trues. The layout
   depends on the constructor representation chosen by derive_generic. *)
derive_generic encodeable simple .
derive_generic encodeable either .

lemma "encode (B 3 4) = [True, False, True, True, True, True, True, True, True]" by eval
lemma "encode C = [True, True]" by eval
lemma "encode (R (3::nat)) = [True, True, True, True]" by code_simp

(* recursive types *)
(* The expected lists are consistent with each list cell / tree node
   contributing a True tag, the unary encoding of its element and the
   encodings of its children, and with [] / Leaf contributing one False. *)

derive_generic encodeable list .
derive_generic encodeable tree .

lemma "encode [1,2,3,4::nat]
= [True, True, True, True, True, True, True, True, True, True, True, True, True, True, False]" by eval
lemma "encode (Node (3::nat) (Node 1 Leaf Leaf) (Node 2 Leaf Leaf))
= [True, True, True, True, True, True, False, False, True, True, True, False, False]" by eval

(* mutually recursive types *)
(* The expected bit strings depend on the sum-of-products layout that
   derive_generic chooses for these mutually recursive types. *)

derive_generic encodeable even_nat .
derive_generic encodeable exp .

lemma "encode (Odd_Succ (Even_Succ (Odd_Succ Even_Zero)))
= [True, False, True, True, False, False]" by eval
lemma "encode (Term (Prod (Const (1::nat)) (Factor (Const (2::nat)))))
= [False, False, True, False, True, True, True, False, True, True, False, False, True, True, False, True, True]"
by code_simp

end

# Theory Derive_Algebra

section "Algebraic Classes"

theory Derive_Algebra
imports Main "../Derive" Derive_Datatypes
begin

(* Algebraic classes without laws: the law assumptions are kept only as
   comments, so instances need only provide the operations. *)
class semigroup =
fixes mult :: "'a ⇒ 'a ⇒ 'a" (infixl "⊗" 70)
(*  assumes assoc: "(x ⊗ y) ⊗ z = x ⊗ (y ⊗ z)" *)

class monoidl = semigroup +
fixes neutral :: 'a ("𝟭")
(* assumes neutl : "𝟭 ⊗ x = x" *)

class group = monoidl +
fixes inverse :: "'a ⇒ 'a"
(* assumes invl: "(inverse x) ⊗ x = 𝟭" *)

(* Manual instances for nat, unit, prod, and sum *)
(* nat: mult is addition and the neutral element is 0. *)
instantiation nat and unit:: semigroup
begin
definition mult_nat : "mult (x::nat) y = x + y"
definition mult_unit_def: "mult (x::unit) y = x"
instance ..
end
instantiation nat and unit:: monoidl
begin
definition neutral_nat : "neutral = (0::nat)"
definition neutral_unit_def: "neutral = ()"
instance ..
end

(* NOTE: on nat, inverse i = 𝟭 - i = 0 - i is truncated subtraction and hence
   always 0, so nat would not satisfy the (commented-out) group law invl.
   This is harmless only because these classes carry no laws. *)
instantiation nat and unit:: group
begin
definition inverse_nat : "inverse (i::nat) = 𝟭 - i"
definition inverse_unit_def: "inverse u = ()"
instance ..
end

(* Pairs are combined component-wise. For sums, equal tags combine the
   payloads, while mixed tags return the left operand unchanged.
   NOTE: this mult on sums is not associative, e.g.
   (Inl a ⊗ Inr b) ⊗ Inl c = Inl (a ⊗ c) but Inl a ⊗ (Inr b ⊗ Inl c) = Inl a.
   That is acceptable here only because the assoc law is commented out. *)
instantiation prod and sum :: (semigroup, semigroup) semigroup
begin
definition mult_prod_def: "x ⊗ y = (fst x ⊗ fst y, snd x ⊗ snd y)"
definition mult_sum_def: "x ⊗ y = (case x of Inl a ⇒ (case y of Inl b ⇒ Inl (a ⊗ b) | Inr b ⇒ Inl a)
| Inr a ⇒ (case y of Inl b ⇒ Inr a | Inr b ⇒ Inr (a ⊗ b)))"
instance ..
end

(* Neutral element: pair of neutrals; for sums, the neutral of the left summand. *)
instantiation prod and sum :: (monoidl, monoidl) monoidl
begin
definition neutral_prod_def: "neutral = (neutral,neutral)"
definition neutral_sum_def: "neutral = Inl neutral"
instance ..
end

(* Inverse is taken component-wise for pairs and keeps the tag for sums. *)
instantiation prod and sum :: (group, group) group
begin
definition inverse_prod_def: "inverse p = (inverse (fst p), inverse (snd p))"
definition inverse_sum_def: "inverse x = (case x of Inl a ⇒ (Inl (inverse a))
| Inr b ⇒ Inr (inverse b))"
instance ..
end

(* Simple test *)
(* Each class is derived after its superclass (semigroup, monoidl, group). *)

derive_generic semigroup simple .
derive_generic monoidl simple .
derive_generic group simple .

(* The third check is consistent with 𝟭 being built from the Inl summand
   (neutral_sum_def) while B lives in another summand, so the mixed-tag
   case of mult returns the left operand. *)
lemma "(B 𝟭 6) ⊗ (B 4 5) = B 4 11" by eval
lemma "(A 2) ⊗ (A 3) = A 5" by eval
lemma "(B 𝟭 6) ⊗ 𝟭 = B 0 6" by eval

(* Type with parameter. class group extends monoidl, which extends
   semigroup, so the superclass instances are derived first, in the same
   order used for the other datatypes in this theory. *)

derive_generic semigroup either .
derive_generic monoidl either .
derive_generic group either .

(* Equal tags combine payloads; mixed tags keep the left operand. *)
lemma "(L 3) ⊗ ((L 4)::(nat,nat) either) = L 7" by eval
lemma "(R (2::nat)) ⊗ (L (3::nat)) = R 2" by eval

(* recursive types *)
(* The list check is consistent with elements being combined pairwise and
   the surplus tail of the left list being kept (mixed tags return the left
   operand). inverse maps every nat element to 0 (see inverse_nat). *)

derive_generic semigroup list .
derive_generic monoidl list .
derive_generic group list .
derive_generic semigroup tree .
derive_generic monoidl tree .
derive_generic group tree .

lemma "[1,2,3,4::nat] ⊗ [1,2,3] = [2,4,6,4]" by eval
lemma "inverse [1,2,3::nat] = [0,0,0]" by eval

(* mutually recursive types *)

derive_generic semigroup even_nat .
derive_generic monoidl even_nat .
derive_generic group even_nat .
derive_generic semigroup exp .

(* instantiate monoidl manually *)
(* The neutral element is built bottom-up: Const of the parameter's neutral
   for fct, wrapped by Factor for trm and by Term for exp. *)
instantiation exp and trm and fct  :: (monoidl,monoidl) monoidl
begin
definition neutral_fct where "neutral_fct = Const neutral"
definition neutral_trm where "neutral_trm = Factor neutral"
definition neutral_exp where "neutral_exp = Term neutral"
instance ..
end

(* Manually defined instances need to be added to the theory context.
   All three hand-written monoidl instances (fct, trm, exp) are registered
   with their defining theorems, presumably for use by the group derivation
   for exp below. The registrations are chained with #>, so the last one
   must not be followed by another #>. *)
setup ‹
(Derive.add_inst_info \<^class>‹monoidl› \<^type_name>‹fct› [@{thm neutral_fct_def}]) #>
(Derive.add_inst_info \<^class>‹monoidl› \<^type_name>‹trm› [@{thm neutral_trm_def}]) #>
(Derive.add_inst_info \<^class>‹monoidl› \<^type_name>‹exp› [@{thm neutral_exp_def}])
›

(* group for exp is derived on top of the derived semigroup and the manually
   instantiated and registered monoidl instances. *)
derive_generic group exp .

lemma "(Odd_Succ (Even_Succ (Odd_Succ Even_Zero))) ⊗ (Odd_Succ Even_Zero)
= Odd_Succ (Even_Succ (Odd_Succ Even_Zero))" by eval
lemma "inverse (Odd_Succ Even_Zero) = Odd_Succ Even_Zero" by eval
lemma "(Term (Prod ((Const 1)::(nat, nat) fct) (Factor (Const (2::nat)))))
⊗ (Term (Prod (Const (2::nat)) (Factor ((Const 2)::(nat, nat) fct))))
= Term (Prod (Const 3) (Factor (Const 4)))" by eval

end

# Theory Derive_Show

section "Show"

theory Derive_Show
imports Main "../Derive" Derive_Datatypes
begin

(* Types whose values can be rendered as a string. *)
class showable =
fixes print :: "'a ⇒ string"

(* Decimal rendering of a natural number. Digits come out most significant
   first: recurse on n div 10, then append the character for n mod 10.
   48 is the character code of '0'. *)
fun string_of_nat :: "nat ⇒ string"
where
"string_of_nat n = (if n < 10 then [(char_of :: nat ⇒ char) (48 + n)] else
string_of_nat (n div 10) @ [(char_of :: nat ⇒ char) (48 + (n mod 10))])"

(* Manual instances for nat, unit, prod, and sum *)
(* nat is printed in decimal via string_of_nat; unit prints as the empty
   string. Here prod and sum are the tagged versions from Tagged_Prod_Sum. *)
instantiation nat and unit:: showable
begin
definition print_nat: "print (n::nat) = string_of_nat n"
definition print_unit: "print (x::unit) = ''''"
instance ..
end

(* Tagged products print both components separated by a space. A component
   that has a selector name is printed as "(name: value)".
   Tagged sums print their payload, wrapped as "(Name payload)" when a
   constructor name is present. *)
instantiation Tagged_Prod_Sum.prod and Tagged_Prod_Sum.sum :: (showable, showable) showable
begin
definition print_prod_def:
"print (x::('a,'b) Tagged_Prod_Sum.prod) =
(case Tagged_Prod_Sum.sel_name_fst x of
None ⇒ (print (Tagged_Prod_Sum.fst x))
| Some s ⇒ ''('' @ s @ '': '' @ (print (Tagged_Prod_Sum.fst x)) @ '')'')
@
'' ''
@
(case Tagged_Prod_Sum.sel_name_snd x of
None ⇒ (print (Tagged_Prod_Sum.snd x))
| Some s ⇒ ''('' @ s @ '': '' @ (print (Tagged_Prod_Sum.snd x)) @ '')'')"

definition print_sum_def:  "print (x::('a,'b) Tagged_Prod_Sum.sum) =
(case x of (Tagged_Prod_Sum.Inl s a) ⇒ (case s of None ⇒ print a | Some c ⇒ ''('' @ c @ '' '' @ (print a) @ '')'')
| (Tagged_Prod_Sum.Inr s b) ⇒ (case s of None ⇒ print b | Some c ⇒ ''('' @ c @ '' '' @ (print b) @ '')''))"
instance ..
end

(* simple types *)
(* NOTE(review): no derive_generic showable commands are visible in this
   theory, but the value commands below need showable instances for simple,
   either, list, tree, even_nat and exp. Confirm they are derived somewhere
   (with the tagged representation) before these commands run. *)

declare [[ML_print_depth=30]]

value "print (A 3)"
value "print (B 1 2)"
value [simp] "print (L (2::nat))"
value "print C"

(* recursive types *)

value "print [1,2::nat]"
value "print (Node (3::nat) (Node 1 Leaf Leaf) (Node 2 Leaf Leaf))"

(* mutually recursive types *)

value "print (Odd_Succ (Even_Succ (Odd_Succ Even_Zero)))"
value [simp] "print (Sum (Factor (Const (0::nat))) (Term (Prod (Const (1::nat)) (Factor (Const (2::nat))))))"

end

# Theory Derive_Eq_Laws

section "Classes with Laws"

subsection "Equality"

theory Derive_Eq_Laws
imports Main "../Derive" Derive_Datatypes
begin

(* eq with laws: eq must be reflexive, symmetric and transitive. *)
class eq =
fixes eq :: "'a ⇒ 'a ⇒ bool"
assumes refl: "eq x x" and
sym: "eq x y ⟹ eq y x" and
trans: "eq x y ⟹ eq y z ⟹ eq x z"

(* Register eq for derivation with laws. The resulting obligation is closed
   by unfolding eq_class_law_def (a constant presumably introduced by this
   command to bundle the eq laws). *)
derive_generic_setup eq
unfolding eq_class_law_def
by blast

(* The class axioms of eq establish eq_class_law for eq itself. *)
lemma eq_law_eq: "eq_class_law eq"
unfolding eq_class_law_def
using eq_class.axioms unfolding class.eq_def .

(* Manual instances for nat, unit, prod, and sum *)
(* nat: eq is HOL equality. unit: everything is related.
   In both cases each law follows by simp after unfolding the definition. *)
instantiation nat and unit :: eq
begin
definition eq_nat_def : "eq (x::nat) y ⟷ x = y"
definition eq_unit_def: "eq (x::unit) y ⟷ True"
instance proof
fix x y z :: nat
show "eq x x" unfolding eq_nat_def by simp
show "eq x y ⟹ eq y x" unfolding eq_nat_def by simp
show "eq x y ⟹ eq y z ⟹ eq x z" unfolding eq_nat_def by simp
next
fix x y z :: unit
show "eq x x" unfolding eq_unit_def by simp
show "eq x y ⟹ eq y x" unfolding eq_unit_def by simp
show "eq x y ⟹ eq y z ⟹ eq x z" unfolding eq_unit_def by simp
qed
end

(* Pairs are related component-wise. Sums are related only when both values
   carry the same tag and their payloads are related. The laws are inherited
   from the component types. *)
instantiation prod and sum :: (eq, eq) eq
begin
definition eq_prod_def: "eq x y ⟷ (eq (fst x) (fst y)) ∧ (eq (snd x) (snd y))"
definition eq_sum_def: "eq x y = (case x of Inl a ⇒ (case y of Inl b ⇒ eq a b | Inr b ⇒ False)
| Inr a ⇒ (case y of Inl b ⇒ False | Inr b ⇒ eq a b))"
instance proof
fix x y z :: "('a::eq) × ('b::eq)"
show "eq x x" unfolding eq_prod_def by (simp add: eq_class.refl)
show "eq x y ⟹ eq y x" unfolding eq_prod_def by (simp add: eq_class.sym)
show "eq x y ⟹ eq y z ⟹ eq x z" unfolding eq_prod_def by (meson eq_class.trans)
next
fix x y z :: "('a::eq) + ('b::eq)"
show "eq x x" unfolding eq_sum_def by (simp add: sum.case_eq_if eq_class.refl)
show "eq x y ⟹ eq y x" unfolding eq_sum_def by (metis eq_class.sym sum.case_eq_if)
show "eq x y ⟹ eq y z ⟹ eq x z"
unfolding eq_sum_def
apply (simp only: sum.case_eq_if)
apply (cases "isl x"; cases "isl y"; cases "isl z")
(* In the mixed-tag cases one premise reduces to False. The two same-tag
   cases follow from transitivity of the component eq. A terminal method is
   required here: an apply script cannot be closed by qed. *)
by (auto intro: eq_class.trans)
qed
end

(* nonrecursive test *)
(* Closed with "." - unlike the recursive types below, no manual proof of
   the laws is written for these non-recursive types. *)
derive_generic eq simple .

(* some tests *)
lemma "eq (A 4) (A 4)" by eval
lemma "eq (A 6) (A 4) ⟷ False" by eval
lemma "eq C C" by eval
lemma "eq (B 4 5) (B 4 5)" by eval
lemma "eq (B 4 4) (A 3) ⟷ False" by eval
lemma "eq C (A 4) ⟷ False" by eval

(* type with parameter *)

derive_generic eq either .

lemma "eq (L (3::nat)) (R 3) ⟷ False" by code_simp
lemma "eq (L (3::nat)) (L 3)" by code_simp
lemma "eq (L (3::nat)) (L 4) ⟷ False" by code_simp

(* recursive types *)
(* For list, the laws of the derived eq on the recursive representation
   (eq_mulistF) are proved by hand. There is one goal per law: refl (1),
   sym (2) and trans (3). Each goal is proved by induction over the
   representation (case In). *)
derive_generic eq list
proof goal_cases
(* Reflexivity. *)
case (1 x)
then show ?case
proof (induction x)
case (In y)
then show ?case
apply(cases y)
by (auto simp add: Derive_Eq_Laws.eq_mulistF.simps eq_unit_def eq_class.refl)
qed
next
(* Symmetry. *)
case (2 x y)
then show ?case
proof (induction y arbitrary: x)
case (In y)
then show ?case
apply(cases x; cases y; hypsubst_thin)
apply (simp add: Derive_Eq_Laws.eq_mulistF.simps sum.case_eq_if eq_unit_def)
apply(metis old.sum.simps(5))
(* Unfold the BNF set functions, presumably so the induction hypothesis
   applies to the recursive component. *)
unfolding sum_set_defs prod_set_defs
using eq_class.sym by fastforce
qed
next
(* Transitivity. *)
case (3 x y z)
then show ?case
proof (induction x arbitrary: y z)
case (In x')
then show ?case
apply(cases x')
apply (cases y; cases z; hypsubst_thin)
apply (simp add: Derive_Eq_Laws.eq_mulistF.simps sum.case_eq_if eq_unit_def)
apply (metis sum.case_eq_if)
apply(cases y; cases z; hypsubst_thin)
unfolding sum_set_defs prod_set_defs
apply (simp add: Derive_Eq_Laws.eq_mulistF.simps eq_unit_def snds.intros)
apply (simp only: sum.case_eq_if)
by (meson eq_class.trans)
qed
qed

(* Evaluation checks for the derived eq on lists. *)
lemma "eq ([]::(nat list)) []" by eval
lemma "eq ([1,2,3]:: (nat list)) [1,2,3]" by eval
lemma "eq [(1::nat)] [1,2] ⟷ False" by eval

(* Same proof structure as for list, using the tree representation's
   equations (eq_mutreeF): refl (1), sym (2), trans (3), each by induction
   over the representation (case In). *)
derive_generic eq tree
proof goal_cases
(* Reflexivity. *)
case (1 x)
then show ?case
proof (induction x)
case (In y)
then show ?case
apply(cases y)
by (auto simp add: Derive_Eq_Laws.eq_mutreeF.simps eq_unit_def eq_class.refl)
qed
next
(* Symmetry. *)
case (2 x y)
then show ?case
proof (induction y arbitrary: x)
case (In y)
then show ?case
apply(cases x; cases y; hypsubst_thin)
apply (simp add: Derive_Eq_Laws.eq_mutreeF.simps sum.case_eq_if eq_unit_def)
apply(metis old.sum.simps(5))
unfolding sum_set_defs prod_set_defs
using eq_class.sym by fastforce
qed
next
(* Transitivity. *)
case (3 x y z)
then show ?case
proof (induction x arbitrary: y z)
case (In x')
then show ?case
apply(cases x')
apply (cases y; cases z; hypsubst_thin)
apply (simp add: Derive_Eq_Laws.eq_mutreeF.simps sum.case_eq_if eq_unit_def)
apply (metis sum.case_eq_if)
apply(cases y; cases z; hypsubst_thin)
unfolding sum_set_defs prod_set_defs
apply (simp add: Derive_Eq_Laws.eq_mutreeF.simps eq_unit_def snds.intros)
apply (simp only: sum.case_eq_if)
by (meson eq_class.trans)
qed
qed

(* Evaluation checks for the derived eq on trees. *)
lemma "eq Leaf Leaf" by code_simp
lemma "eq (Node (1::nat) Leaf Leaf) Leaf ⟷ False" by eval
lemma "eq (Node (1::nat) Leaf Leaf) (Node (1::nat) Leaf Leaf)" by eval
lemma "eq (Node (1::nat) (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)) (Node (1::nat) (Node 2 Leaf Leaf) (Node 4 Leaf Leaf))
⟷ False" by eval
end

# Theory Derive_Algebra_Laws

subsection "Algebraic Classes"

theory Derive_Algebra_Laws
imports Main "../Derive" Derive_Datatypes
begin

(* Variant of the test datatype with int payloads, so that the group law
   (which needs genuine inverses) can hold for its components. *)
datatype simple_int =
    A int
  | B int int
  | C

(* Algebraic classes with laws: associativity, left identity and left
   inverse. *)
class semigroup =
fixes mult :: "'a ⇒ 'a ⇒ 'a" (infixl "⊗" 70)
assumes assoc: "(x ⊗ y) ⊗ z = x ⊗ (y ⊗ z)"

class monoidl = semigroup +
fixes neutral :: 'a ("𝟭")
assumes neutl : "𝟭 ⊗ x = x"

class group = monoidl +
fixes inverse :: "'a ⇒ 'a"
assumes invl: "(inverse x) ⊗ x = 𝟭"

(* Law predicates over explicitly passed operations. Each predicate also
   includes the laws of its superclass predicate. *)
definition semigroup_law :: "('a ⇒ 'a ⇒ 'a) ⇒ bool" where
  "semigroup_law mul_fn =
     (∀a b c. mul_fn (mul_fn a b) c = mul_fn a (mul_fn b c))"
definition monoidl_law :: "'a ⇒ ('a ⇒ 'a ⇒ 'a) ⇒ bool" where
  "monoidl_law one_el mul_fn =
     ((∀a. mul_fn one_el a = a) ∧ semigroup_law mul_fn)"
definition group_law :: "('a ⇒ 'a) ⇒ 'a ⇒ ('a ⇒ 'a ⇒ 'a) ⇒ bool" where
  "group_law inv_fn one_el mul_fn =
     ((∀a. mul_fn (inv_fn a) a = one_el) ∧ monoidl_law one_el mul_fn)"

(* Transfer lemmas: under Derive.iso f g, each law carries over to the
   conjugated operations (arguments mapped with f, results mapped with g).
   The exact content of Derive.iso is given by Derive.iso_def in Derive;
   presumably f and g are mutually inverse. *)
lemma transfer_semigroup:
assumes "Derive.iso f g"
shows "semigroup_law MULT ⟹ semigroup_law (λx y. g (MULT (f x) (f y)))"
unfolding semigroup_law_def
using assms unfolding Derive.iso_def by simp

lemma transfer_monoidl:
assumes "Derive.iso f g"
shows "monoidl_law NEUTRAL MULT ⟹ monoidl_law (g NEUTRAL) (λx y. g (MULT (f x) (f y)))"
unfolding monoidl_law_def semigroup_law_def
using assms unfolding Derive.iso_def by simp

lemma transfer_group:
assumes "Derive.iso f g"
shows "group_law INVERSE NEUTRAL MULT ⟹ group_law (λ x. g (INVERSE (f x))) (g NEUTRAL) (λx y. g (MULT (f x) (f y)))"
unfolding group_law_def monoidl_law_def semigroup_law_def
using assms unfolding Derive.iso_def by simp

(* The class axioms imply the corresponding law predicates for the class
   operations; each lemma reuses the one for the superclass. *)
lemma semigroup_law_semigroup: "semigroup_law mult"
unfolding semigroup_law_def
using semigroup_class.axioms unfolding class.semigroup_def .

lemma monoidl_law_monoidl: "monoidl_law neutral mult"
unfolding monoidl_law_def
using monoidl_class.axioms semigroup_law_semigroup
unfolding class.monoidl_axioms_def by simp

lemma group_law_group: "group_law inverse neutral mult"
unfolding group_law_def
using group_class.axioms monoidl_law_monoidl
unfolding class.group_axioms_def by simp

(* Register the classes for derivation with laws. Each obligation is closed
   by unfolding the generated *_class_law definitions (including those of
   the superclasses) and Derive.iso_def. *)
derive_generic_setup semigroup
unfolding semigroup_class_law_def
Derive.iso_def
by simp

derive_generic_setup monoidl
unfolding monoidl_class_law_def semigroup_class_law_def Derive.iso_def
by simp

derive_generic_setup group
unfolding group_class_law_def monoidl_class_law_def semigroup_class_law_def Derive.iso_def
by simp

(* Manual instances for int, unit, prod, and sum *)
(* int: mult is addition. unit: the trivial operation. *)
instantiation int and unit:: semigroup
begin
definition mult_int_def : "mult (x::int) y = x + y"
definition mult_unit_def: "mult (x::unit) y = x"
instance proof
fix x y z :: int
show "x ⊗ y ⊗ z = x ⊗ (y ⊗ z)"
unfolding mult_int_def by simp
next
fix x y z :: unit
show "x ⊗ y ⊗ z = x ⊗ (y ⊗ z)"
unfolding mult_unit_def by simp
qed
end
(* Neutral elements: 0 for int, () for unit. *)
instantiation int and unit:: monoidl
begin
definition neutral_int_def : "neutral = (0::int)"
definition neutral_unit_def: "neutral = ()"
instance proof
fix x :: int
show "𝟭 ⊗ x = x" unfolding neutral_int_def mult_int_def by simp
next
fix x :: unit
show "𝟭 ⊗ x = x" unfolding neutral_unit_def mult_unit_def by simp
qed
end

(* Inverses: 𝟭 - i (that is, 0 - i) for int; () for unit. *)
instantiation int and unit:: group
begin
definition inverse_int_def : "inverse (i::int) = 𝟭 - i"
definition inverse_unit_def: "inverse u = ()"
instance proof
fix x :: int
show "inverse x ⊗ x = 𝟭" unfolding inverse_int_def mult_int_def by simp
next
fix x :: unit
show "inverse x ⊗ x = 𝟭" unfolding inverse_unit_def mult_unit_def by simp
qed
end

(* Pairs are combined component-wise. For sums, equal tags combine the
   payloads, and with mixed tags the Inr operand wins
   (Inl a ⊗ Inr b = Inr b, Inr a ⊗ Inl b = Inr a), which keeps ⊗
   associative. *)
instantiation prod and sum :: (semigroup, semigroup) semigroup
begin
definition mult_prod_def: "x ⊗ y = (fst x ⊗ fst y, snd x ⊗ snd y)"
definition mult_sum_def: "x ⊗ y = (case x of Inl a ⇒ (case y of Inl b ⇒ Inl (a ⊗ b) | Inr b ⇒ Inr b)
| Inr a ⇒ (case y of Inl b ⇒ Inr a | Inr b ⇒ Inr (a ⊗ b)))"
instance proof
fix x y z :: "('a::semigroup) × ('b::semigroup)"
show "x ⊗ y ⊗ z = x ⊗ (y ⊗ z)" unfolding mult_prod_def by (simp add: assoc)
next
fix x y z :: "('a::semigroup) + ('b::semigroup)"
(* Split on the tags of all three operands. Each of the eight cases reduces
   to associativity in a component type or to a syntactic identity. *)
show "x ⊗ y ⊗ z = x ⊗ (y ⊗ z)" unfolding mult_sum_def
by (cases x; cases y; cases z) (simp_all add: assoc)
qed
end

(* Neutral element: the pair of neutrals, and Inl 𝟭 for sums. Inl 𝟭 is a
   left identity because mixed-tag products return the Inr operand in this
   theory's mult on sums. *)
instantiation prod and sum :: (monoidl, monoidl) monoidl
begin
definition neutral_prod_def: "neutral = (neutral,neutral)"
definition neutral_sum_def: "neutral = Inl neutral"
instance proof
fix x :: "('a::monoidl) × ('b::monoidl)"
show "𝟭 ⊗ x = x" unfolding neutral_prod_def mult_prod_def by (simp add: neutl)
next
fix x :: "('a::monoidl) + ('b::monoidl)"
show "𝟭 ⊗ x = x" unfolding neutral_sum_def mult_sum_def
by (simp add: neutl sum.case_eq_if sum.exhaust_sel)
qed
end

(* Only prod is made a group: inverses are taken component-wise, and the
   left-inverse law follows from invl in each component. No group instance
   is given for sum here. *)
instantiation prod :: (group, group) group
begin
definition inverse_prod_def: "inverse p = (inverse (fst p), inverse (snd p))"
instance proof
fix x :: "('a::group) × ('b::group)"
show "inverse x ⊗ x = 𝟭" unfolding inverse_prod_def mult_prod_def neutral_prod_def
by (simp add: invl)
qed
end

(* Non-recursive types: only semigroup and monoidl are derived (no group).
   Each derivation is closed with "." - no manual proof is written. *)
derive_generic semigroup simple_int .
derive_generic monoidl simple_int .

derive_generic semigroup either .
derive_generic monoidl either .

lemma "(B 𝟭 6) ⊗ (B 4 5) = B 4 11" by eval
lemma "(A 2) ⊗ (A 3) = A 5" by eval
lemma "(B 𝟭 6) ⊗ 𝟭 = B 0 6" by eval

lemma "(L 3) ⊗ ((L 4)::(int,int) either) = L 7" by eval
lemma "(R (2::int)) ⊗ (L (3::int)) = R 2" by eval

(* Semigroup for list: associativity of the derived mult on the recursive
   representation (mult_mulistF) is proved by induction over the
   representation (case In) and case analysis on the operands. *)
derive_generic semigroup list
proof goal_cases
case (1 x y z)
then show ?case
proof (induction x arbitrary: y z)
case (In x')
then show ?case
apply(cases x')
(* first summand: evaluate mult directly *)
apply (cases y; cases z; hypsubst_thin)
apply (simp add: Derive_Algebra_Laws.mult_mulistF.simps sum.case_eq_if mult_unit_def)
(* second summand: unfold the set functions so the induction hypothesis
   applies to the recursive component, then use associativity of the
   element type. An apply script must end in a terminal method, not qed. *)
apply(cases y; cases z; hypsubst_thin)
unfolding sum_set_defs prod_set_defs
by (auto simp add: Derive_Algebra_Laws.mult_mulistF.simps sum.case_eq_if mult_unit_def snds.intros assoc)
qed
qed

(* Semigroup for tree: same proof structure as for list, using the tree
   representation's equations (mult_mutreeF). *)
derive_generic semigroup tree
proof goal_cases
case (1 x y z)
then show ?case
proof (induction x arbitrary: y z)
case (In x')
then show ?case
apply(cases x')
(* first summand: evaluate mult directly *)
apply (cases y; cases z; hypsubst_thin)
apply (simp add: Derive_Algebra_Laws.mult_mutreeF.simps sum.case_eq_if mult_unit_def)
(* second summand: unfold the set functions so the induction hypotheses
   apply to the recursive components, then use associativity of the
   element type. An apply script must end in a terminal method, not qed. *)
apply(cases y; cases z; hypsubst_thin)
unfolding sum_set_defs prod_set_defs
by (auto simp add: Derive_Algebra_Laws.mult_mutreeF.simps sum.case_eq_if mult_unit_def snds.intros assoc)
qed
qed

(* Monoidl for list: the left-identity law of the derived neutral element
   (neutral_mulistF) is proved by induction over the representation. *)
derive_generic monoidl list
proof goal_cases
case (1 x)
then show ?case
proof (induction x)
case (In x')
then show ?case
apply(cases x')
by (auto simp add: Derive_Algebra_Laws.neutral_mulistF_def sum.case_eq_if neutral_unit_def)
qed
qed

(* Monoidl for tree: same proof structure as for list, using
   neutral_mutreeF. *)
derive_generic monoidl tree
proof goal_cases
case (1 x)
then show ?case
proof (induction x)
case (In x')
then show ?case
apply(cases x')
by (auto simp add: Derive_Algebra_Laws.neutral_mutreeF_def sum.case_eq_if neutral_unit_def)
qed
qed

(* Evaluation checks for the derived semigroup operations on lists and
   trees. The list check is consistent with elements being combined pairwise
   and the surplus tail of the left list being kept. *)
lemma "[1,2,3,4::int] ⊗ [1,2,3] = [2,4,6,4]" by eval
lemma "(Node (3::int) Leaf Leaf) ⊗ (Node (1::int) Leaf Leaf) = (Node 4 Leaf Leaf)" by eval

end