Session Deriving

Theory Derive_Manager

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Derive Manager›

theory Derive_Manager
imports Main
keywords "print_derives" :: diag and "derive" :: thy_decl
begin

text ‹
  The derive manager allows the user to register various derive-hooks, e.g., for orders,
  pretty-printers, hash-functions, etc. All registered hooks are accessible via the derive command.

  @{rail ‹@'derive' ('(' param ')')? sort (datatype+)
  ›}

  \begin{description}
  \item ‹derive (param) sort datatype› calls the hook for deriving ‹sort› (that
  may depend on the optional ‹param›) on ‹datatype› (if such a hook is registered).

  E.g., ‹derive compare_order list› will derive a comparator for datatype @{type list}
  which is also used to define a linear order on @{type list}s.
  \end{description}

  There is also the diagnostic command ‹print_derives› that shows the list of currently
  registered hooks.
›

ML_file ‹derive_manager.ML›

end

File ‹derive_manager.ML›

signature DERIVE_MANAGER =
sig
  (* register_derive identifier description f:
     registers the derive hook f (fn dtyp_name => fn param => derive-method) under
     identifier, together with a human-readable description;
     fails if identifier is already in use *)
  val register_derive : string -> string -> (string -> string -> theory -> theory) -> theory -> theory
  (* derive identifier dtyp_name param:
     invokes the hook registered under identifier on dtyp_name (passed unchanged)
     and param; fails if no hook is registered under identifier *)
  val derive : string -> string -> string -> theory -> theory
  (* derive_cmd identifier param dtyp_name:
     NB: param comes before the datatype name here; dtyp_name is parsed as a type
     in the theory's context before the registered hook is invoked *)
  val derive_cmd : string -> string -> string -> theory -> theory
  (* print all registered deriving-methods together with their descriptions *)
  val print_info : theory -> unit
end

structure Derive_Manager : DERIVE_MANAGER =
struct

(* Registered derive hooks, indexed by their identifier (the sort name used in "derive").
   Each entry stores a description and the derive function. *)
structure Derive_Data = Theory_Data
  (
    type T =
      (string * (string -> string -> theory -> theory)) Symtab.table  (* descr * derive-fun *)

    val empty = Symtab.empty
    val extend = I
    (* K true: key collisions are treated as equal entries, so merging never raises DUP *)
    val merge = Symtab.merge (K true)
  )

(* (identifier, description) pairs of all registered hooks *)
val derive_options =
  Derive_Data.get #> Symtab.dest #> map (fn (id, (descr, _)) => (id, descr))

(* FIXME: possibly use Pretty function for presentation. *)
(* print one line "id: description" per registered hook, sorted by identifier *)
fun print_info thy =
  let
    val _ = writeln "The following sorts can be derived"
  in
    derive_options thy
    |> sort_by fst
    |> List.app (fn (id, descr) => writeln (id ^ ": " ^ descr))
  end

(* register a new hook; identifiers must be unique *)
fun register_derive id descr f thy =
  if Symtab.defined (Derive_Data.get thy) id then
    error ("Identifier " ^ quote id ^ " already in use for " ^ quote "deriving")
  else
    Derive_Data.map (Symtab.update_new (id, (descr, f))) thy

(* look up the hook for id and apply it to the (prepared) datatype name and param;
   prep turns the user-supplied datatype name into the name passed to the hook *)
fun gen_derive prep id dtname param thy =
  (case Symtab.lookup (Derive_Data.get thy) id of
    NONE => error ("No handler to derive sort " ^ quote id ^
      " is registered. Try " ^ quote "print_derives" ^ " to see available sorts.")
  | SOME (_, f) => f (prep thy dtname) param thy)

(* datatype name is passed to the hook unchanged *)
val derive = gen_derive (K I)

(* datatype name is parsed as a type and replaced by the name of its type constructor *)
fun derive_cmd id param dtname = gen_derive
  (fn thy => fst o dest_Type o Syntax.parse_typ (Proof_Context.init_global thy)) id dtname param

(* NB: Proof_Context.read_type_name_proper ctxt false could be an alternative to
   parsing the type and destructing it. *)

val _ =
  Outer_Syntax.command @{command_keyword print_derives} "lists all registered sorts which can be derived"
    (Scan.succeed (Toplevel.theory (tap print_info)))

(* derive (param)? sort tyco+ : derive the given sort for each listed type constructor *)
val _ =
  Outer_Syntax.command @{command_keyword derive} "derives some sort"
    (Parse.parname -- Parse.name -- Scan.repeat1 Parse.type_const >> (fn ((param, s), tycons) =>
      Toplevel.theory (fold (derive_cmd s param) tycons)))

end

Theory Generator_Aux

section ‹Shared Utilities for all Generators›

text ‹In this theory we mainly provide some Isabelle/ML infrastructure
  that is used by several generators. It consists of a uniform interface
  to access all the theorems, terms, etc.\ from the BNF package, and 
  some auxiliary functions which provide recursors on datatypes, common tactics, etc.›

theory Generator_Aux
imports 
  Main
begin

ML_file ‹bnf_access.ML›
ML_file ‹generator_aux.ML›

(* unfold membership and bounded quantification over explicit (short) lists *)
lemma in_set_simps: 
  "x ∈ set (y # z # ys) = (x = y ∨ x ∈ set (z # ys))"
  "x ∈ set ([y]) = (x = y)"
  "x ∈ set [] = False" 
  "Ball (set []) P = True" 
  "Ball (set [x]) P = P x" 
  "Ball (set (x # y # zs)) P = (P x ∧ Ball (set (y # zs)) P)" 
  by auto
  
(* congruence for conjunction without using one conjunct as context for the other *)
lemma conj_weak_cong: "a = b ⟹ c = d ⟹ (a ∧ c) = (b ∧ d)" by auto

(* rewrite a reflexive equation to True *)
lemma refl_True: "(x = x) = True" by (rule eqTrueI) (rule refl)

end

File ‹bnf_access.ML›

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
signature BNF_ACCESS =
sig
  (* Uniform access to the theorems, terms and types that the BNF package
     registers for (co)datatypes; type constructors are passed by name. *)

  (* theorems *)
  val induct_thms : Proof.context -> string list -> thm list
  val case_thms : Proof.context -> string list -> thm list
  val distinct_thms : Proof.context -> string list -> thm list list
  val inject_thms : Proof.context -> string list -> thm list list
  val case_simps : Proof.context -> string list -> thm list list
  val set_simps : Proof.context -> string list -> thm list list
  val map_simps : Proof.context -> string list -> thm list list
  val map_comps : Proof.context -> string list -> thm list

  (* terms *)
  val constr_terms : Proof.context -> string -> term list
  val case_consts : Proof.context -> string list -> term list
  val map_terms : Proof.context -> string list -> term list
  val set_terms : Proof.context -> string list -> term list list

  (* types *)
  val bnf_types : Proof.context -> string list -> typ list
  val constr_argument_types : Proof.context -> string list -> typ list list list
end

structure Bnf_Access : BNF_ACCESS =
struct

(* fp_sugar registered for tyco; raises Option if there is none *)
fun sugar_of lthy tyco = the (BNF_FP_Def_Sugar.fp_sugar_of lthy tyco)

(* constructor sugar registered for tyco *)
fun ctr_sugar_of lthy tyco = #ctr_sugar (#fp_ctr_sugar (sugar_of lthy tyco))

(* BNFs stored in the fp_res of the FIRST type name of tycos (raises Empty on []) *)
fun group_bnfs lthy tycos = #bnfs (#fp_res (sugar_of lthy (hd tycos)))

(* terms *)

fun constr_terms lthy tyco = #ctrs (ctr_sugar_of lthy tyco)

fun case_consts lthy = map (#casex o ctr_sugar_of lthy)

fun map_terms lthy tycos = map BNF_Def.map_of_bnf (group_bnfs lthy tycos)

fun set_terms lthy tycos = map BNF_Def.sets_of_bnf (group_bnfs lthy tycos)

(* theorems *)

fun induct_thms lthy =
  map (fn tyco => sugar_of lthy tyco |> #fp_co_induct_sugar |> the |> #co_inducts |> hd)

fun case_thms lthy = map (#exhaust o ctr_sugar_of lthy)

fun distinct_thms lthy = map (#distincts o ctr_sugar_of lthy)

fun inject_thms lthy = map (#injects o ctr_sugar_of lthy)

fun case_simps lthy = map (#case_thms o ctr_sugar_of lthy)

fun set_simps lthy = map (#set_thms o #fp_bnf_sugar o sugar_of lthy)

fun map_simps lthy = map (#map_thms o #fp_bnf_sugar o sugar_of lthy)

fun map_comps lthy tycos = map BNF_Def.map_comp_of_bnf (group_bnfs lthy tycos)

(* types *)

fun constr_argument_types lthy = map (#ctrXs_Tss o #fp_ctr_sugar o sugar_of lthy)

fun bnf_types lthy = map (#X o sugar_of lthy)

end

File ‹generator_aux.ML›

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann
    License:     LGPL
*)
signature GENERATOR_AUX =
sig
  (* the keys of an association list, rendered as "(k1, ..., kn)" *)
  val alist_to_string : (string * 'a)list -> string

  (* put a string in sub-script *)
  val sub : string -> string

  (* append the name of a type (variable or constructor) in sub-script to a given name *)
  val subT : string -> typ -> string

  (* [a,..,n] -> _{a_.._n} *)
  val ints_to_subscript : int list -> string

  (* drop last element of a non-empty list; raises Empty on [] *)
  val drop_last : 'a list -> 'a list

  (* rename old types to new types in a term; the typ-lists (fst and snd components)
     are assumed to be distinct *)
  val rename_types : (typ * typ) list -> term -> term

  (* λ x1 ... xn. t *)
  val lambdas : term list -> term -> term

  (* thm[OF _(NONE) foo (SOME foo) ...] *)
  val OF_option : thm -> thm option list -> thm

  (* aim: treat a possibly empty conjunction and handle each subcase
     by a specific tactic (indexed from 0 onwards) *)
  val conjI_tac : thm list ->
    Proof.context -> 'a list -> (Proof.context -> int -> tactic) -> tactic

  (* check whether a given type constructor is an instance of a given sort *)
  val is_class_instance : theory -> string -> sort -> bool

  (* case analysis via the given cases-rule and instantiations; each resulting subgoal is
     handled by the given tactic, which receives the 0-based case index together with the
     focused context, the case premises, and the case parameters *)
  val mk_case_tac :
    Proof.context ->
    term option list list ->
    thm ->
    (int -> Proof.context * thm list * (string * cterm) list -> tactic) ->
    tactic

  (* type-inference after replacing all schematic type-variables by dummyT *)
  val infer_type : Proof.context -> term -> term

  (* like the old prove_multi, but in a non-blocking future-version *)
  val prove_multi_future : Proof.context -> string list -> term list -> term list ->
    ({prems: thm list, context: Proof.context} -> tactic) -> thm list

  (* determines all mutually recursive types of a given BNF least-fixpoint type
     (names and types, with schematic type variables turned into free ones);
     fails for dead type parameters and for non-least fixpoints *)
  val mutual_recursive_types : string -> Proof.context -> string list * typ list

  (* a fold over types, differentiating mutually recursive types and other type-constructors *)
  val recursor :
    (string -> 'info) * (* accessing the information of a datatype *)
    ('info -> bool list) * (* which arguments are used of a datatype *)
    string list -> (* information on used arguments and recursive types *)
    bool -> (*recursion over all types (or only used types)*)
    (typ -> 'a) -> (* how to treat TFrees *)
    (typ -> 'a) -> (* how to treat TVars *)
    (typ -> 'a) -> (* how to treat recursive case *)
    ((typ * 'a option) list * 'info -> 'a) -> (* how to treat non-rec.-case; NONE result for unused types, if all = false *)
    typ -> 'a

  (* split global list of induction hypotheses according to list of argument types *)
  val split_IHs : (string -> 'info) * ('info -> bool list) * string list -> typ list -> thm list -> thm list list

  (* a standard tactic to solve proof obligation with recursion on types:
     variables and rec.-cases are handled via IHs and preconditions,
     for non-rec.-type constructors a partial soundness thm has to be generated from the info *)
  val std_recursor_tac : (string -> 'info) * ('info -> bool list) * string list ->
    typ list ->
    ('info -> thm) ->
    thm list -> typ -> thm list -> Proof.context -> tactic

  (* delivers a full type from a type name by instantiating the type-variables of that
   type with different variables of a given sort, also returns the chosen variables
   as second component *)
  val typ_and_vs_of_typname : theory -> string -> sort -> typ * (string * sort) list

  (* similar to typ_and_vs_of_typname, but the sort constraint is only enforced for used
     type parameters (all others get sort type) *)
  val typ_and_vs_of_used_typname : string -> bool list -> string list -> typ * (string * string list) list

  (* turn schematic type variables into free ones, dropping their index *)
  val freeify_tvars : typ -> typ

  (* starting from a (co)datatype, add all type-constructor names occurring in the
     argument types of its constructors *)
  val add_used_tycos : Proof.context -> string -> string list -> string list

  (* the type parameters of a type Type (tyco, TFrees), together with those parameters
     that are actually used, i.e., are not (indirect) phantom types *)
  val type_parameters : typ -> Proof.context -> (string * sort) list * typ list

  (* define a constant via an equation "c = rhs" (c a free variable) with the code attribute;
     the definitional theorem is exported to the global context *)
  val define_overloaded : (string * term) -> local_theory -> thm * local_theory

  (* like define_overloaded, but with an explicit binding and attributes *)
  val define_overloaded_generic : (Attrib.binding * term) -> local_theory -> thm * local_theory

  (* identity function on a given type *)
  val mk_id : typ -> term

  (* Const (c, T) ≡ rhs *)
  val mk_def : typ -> string -> term -> term

  (* apply the constant of the given name to a term and infer the type *)
  val mk_infer_const : string -> Proof.context -> term -> term

  (* map a running index over all entries of a list of lists to its (outer, inner) position *)
  val ind_case_to_idxs : 'a list list -> int -> int * int

  (* like create_map (same parameters except the first two), but at each non-recursive
     type constructor the partial operation is composed with the map function *)
  val create_partial :
    'a ->
    (typ -> bool) ->
    (local_theory -> string -> bool list) ->
    (local_theory -> string -> term) ->
    (local_theory -> string -> 'a -> term) ->
    string list ->
    (local_theory -> string -> 'a) ->
    typ ->
    local_theory -> term

  val create_map :
    (typ -> term) -> (*default operation*)
    (string * typ -> 'a -> term) -> (*recursive occurrences of operation*)
    'a -> (*initial state*)
    (typ -> bool) -> (*early abort*)
    (local_theory -> string -> bool list) -> (*used positions*)
    (local_theory -> string -> term) -> (*map function*)
    (local_theory -> string -> 'a -> term) -> (*partial operation*)
    string list -> (*mutually recursive types*)
    (local_theory -> string -> 'a) -> (*next state function*)
    typ -> (*type for which map should be created*)
    local_theory -> term

end

structure Generator_Aux : GENERATOR_AUX =
struct

(* "(k1, ..., kn)" for the keys of an association list *)
fun alist_to_string al = map fst al |> commas |> enclose "(" ")"

(* typ_name applied to fresh type variables; a variable gets sort "sort" if its
   position is marked as used in used_pos, and sort type otherwise *)
fun typ_and_vs_of_used_typname typ_name used_pos sort =
  let
    val sorts = map (fn b => if b then sort else @{sort type}) used_pos
    val ty_vars = Name.invent_names (Name.make_context [typ_name]) "a" sorts
    val ty = Type (typ_name,map TFree ty_vars)
  in
    (ty,ty_vars)
  end

fun is_class_instance thy tname class =
  Sorts.has_instance (Sign.classes_of thy) tname class

(* split a conjunction via conj_thms and hand the k-th resulting subgoal (0-based)
   to tac; succeeds without change if xs is empty *)
fun conjI_tac conj_thms ctxt xs tac =
  if null xs then all_tac
  else
    (K (Method.try_intros_tac ctxt conj_thms [])
    THEN_ALL_NEW (fn k' =>
      Subgoal.SUBPROOF (fn {context = ctxt', ...} => tac ctxt' (k' - 1)) ctxt k'))
    1

(* λx. x on type T *)
fun mk_id T =
  let val x = Free ("x", T)
  in lambda x x end

local
  (* traverse T: recursive type constructors (tycos) are handled by mk_p; other type
     constructors combine pfun with mfun applied to the results for their arguments
     (unused arguments are handled by dfun); type variables and aborted types become
     the identity *)
  fun create_gen mk_comp dfun mk_p x early_abort up mfun pfun tycos read T lthy =
    let
      fun create x (T as (Type (tyco, Ts))) =
            if early_abort T then mk_id dummyT
            else if member (op =) tycos tyco then mk_p (tyco, T) x
            else
              let
                val x' = read lthy tyco
                val ts = (up lthy tyco ~~ Ts) |> map (fn (used, T) =>
                  if used then create x' T else dfun dummyT)
              in mk_comp (pfun lthy tyco x, list_comb (mfun lthy tyco, ts)) end
        | create _ _ = mk_id dummyT
    in create x T end
in
  fun create_partial x = create_gen HOLogic.mk_comp mk_id ((K o K) (mk_id dummyT)) x
  fun create_map dfun mk_p x = create_gen snd dfun mk_p x
end

fun drop_last [] = raise Empty
  | drop_last (x::xs) =
      let
        fun init _ [] = []
          | init x (y::ys) = x :: init y ys
      in init x xs end

(* each pair (t1, t2) swaps t1 and t2 in the term; later pairs are adjusted accordingly *)
fun rename_types [] t = t
  | rename_types ((t1, t2) :: ts) t =
      if t1 = t2 then rename_types ts t
      else
        let val swap = [(t1, t2), (t2, t1)]
        in
          rename_types
            (map (apfst (typ_subst_atomic swap)) ts)
            (subst_atomic_types swap t)
        end

fun sub s = Symbol.explode s |> map (fn c => "⇩" ^ c) |> implode

fun subT name T = name ^ sub
      (case T of
        TVar (xi, _) => Term.string_of_vname xi
      | TFree (x, _) => x
      | Type (tyco, _) => tyco)

val ints_to_subscript = sub o foldr1 (fn (x, y) => x ^ "_" ^ y) o map string_of_int

(* running index n over all inner entries -> (outer index, inner index);
   raises Option for indices out of range *)
fun ind_case_to_idxs cTys =
  let
    fun number n (i, j) ((_ :: xs) :: ys) = (n, (i, j)) :: number (n+1) (i, j+1) (xs :: ys)
      | number n (i, _) ([] :: ys) = number n (i+1, 0) ys
      | number _ _ [] = []
  in AList.lookup (op =) (number 0 (0, 0) cTys) #> the end

(*recursively compute free type variables that are actually used by a given ((co)data)type, i.e.,
are not (indirect) phantom types.*)
fun add_used_tfrees ctxt =
      let
        val thy = Proof_Context.theory_of ctxt

        fun err_schematic T =
              error ("illegal schematic type variable " ^ quote (Syntax.string_of_typ ctxt T))

        fun add _ (T as TVar _) = err_schematic T
          | add _ (TFree (x, _)) = insert (op =) x
          | add skip (Type (tyco, Ts)) =
              if member (op =) skip tyco then I
              else
                (case BNF_FP_Def_Sugar.fp_sugar_of ctxt tyco of
                  NONE => fold (add skip) Ts
                | SOME _ => (*(co)datatype*)
                  BNF_LFP_Compat.the_spec thy tyco
                  |>> map TFree |>> (fn x => x ~~ Ts) ||> map snd ||> flat
                  |> uncurry (map o typ_subst_atomic)
                  |> fold (add (insert (op =) tyco skip)))
      in add [] end

(*collect all Type(constructor) names occurring in a given type*)
fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
  | add_tycos _ = I

(*starting from a (co)datatype "tyco", collect all Type(constructor) names that are
involved in its construction*)
fun add_used_tycos ctxt tyco =
      (case BNF_FP_Def_Sugar.fp_sugar_of ctxt tyco of
        NONE => I
      | SOME sugar => #fp_ctr_sugar sugar |> #ctrXs_Tss |> flat |> fold add_tycos)

fun infer_type ctxt =
      map_types (map_type_tvar (K dummyT))
      #> singleton (Type_Infer_Context.infer_types ctxt)

val lambdas = fold_rev lambda

fun mk_def T c rhs = Logic.mk_equals (Const (c, T), rhs)

(* NONE positions are filled with the trivial rule P ⟹ P, i.e., left open *)
fun OF_option thm thms = thm OF map (the_default @{lemma "P ⟹ P" by simp}) thms

fun typ_and_vs_of_typname thy typ_name sort =
  let
    val ar = Sign.arity_number thy typ_name
    val sorts = replicate ar sort
    val ty_vars = Name.invent_names (Name.make_context [typ_name]) "a" sorts
    val ty = Type (typ_name,map TFree ty_vars)
  in (ty,ty_vars) end


(* code copied from HOL/SPARK/TOOLS *)
fun define_overloaded_generic (binding,eq) lthy =
  let
    val ((c, _), rhs) = eq |> Syntax.check_term lthy |>
      Logic.dest_equals |>> dest_Free;
    val ((_, (_, thm)), lthy') = Local_Theory.define
      ((Binding.name c, NoSyn), (binding, rhs)) lthy
    val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
    val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm
  in (thm', lthy') end

fun define_overloaded (name,eq) = define_overloaded_generic ((Binding.name name, @{attributes [code]}),eq)


fun mk_case_tac ctxt insts thm sub_case_tac =
  (DETERM o Induct.cases_tac ctxt false insts (SOME thm) []
  THEN_ALL_NEW (fn i =>
    Subgoal.SUBPROOF (fn {context = ctxt, prems = hyps, params = params, ...} =>
      sub_case_tac (i-1) (ctxt, hyps, params)) ctxt i))
  1

val freeify_tvars = map_type_tvar (TFree o apfst fst)

fun mutual_recursive_types tyco lthy =
      (case BNF_FP_Def_Sugar.fp_sugar_of lthy tyco of
        SOME sugar =>
          if Sign.arity_number (Proof_Context.theory_of lthy) tyco -
            BNF_Def.live_of_bnf (#fp_bnf sugar) > 0
          then error "only datatypes without dead type parameters are supported"
          else if #fp sugar = BNF_Util.Least_FP then
            sugar |> #fp_res |> #Ts |> `(map (fst o dest_Type))
            ||> map freeify_tvars
          else error "only least fixpoints are supported"
      | NONE => error ("type " ^ quote tyco ^ " does not appear to be a new style datatype"))

fun type_parameters T lthy =
      let
        val tfrees = T |> dest_Type |> snd |> map dest_TFree
        val used_tfrees = (*type parameters of type tyco that are used (maintain original order)*)
          inter (op =) (add_used_tfrees lthy T []) (map fst tfrees)
          |> map (fn a => TFree (a, AList.lookup (op =) tfrees a |> the))
      in (tfrees, used_tfrees) end

fun sum_list xs = fold (curry (op +)) xs 0

fun mk_infer_const name ctxt c = infer_type ctxt (Const (name, dummyT) $ c)

fun prove_multi_future ctxt = Goal.prove_common ctxt (SOME ~1)

fun recursor rec_info all free tvar r typ (T as Type (tyco,Ts)) =
      if member (op =) (#3 rec_info) tyco then r T
      else
        let
          val (get_info,get_used,_) = rec_info
          val info = get_info tyco
          val up = get_used info
          val recs = (if all then map (pair true) Ts else up ~~ Ts) |> map (fn (b, T) =>
            if b then (T, SOME (recursor rec_info all free tvar r typ T)) else (T, NONE))
        in typ (recs, info) end
  | recursor _ _ free _ _ _ (T as TFree _) = free T
  | recursor _ _ _ tvar _ _ (T as TVar _) = tvar T

(* use the recursor to compute the number of IHs, in order to split them *)
fun num_IHs rec_info = recursor rec_info true (K 0) (K 0) (K 1)
      (fn (xs, _) => sum_list (map (the o snd) xs))

fun split_IHs rec_info (ty :: tys : typ list) (IHs : thm list) : thm list list =
      let
        val n = num_IHs rec_info ty
        val _ = if n > length IHs then error "split IH error: too few" else ()
      in
        take n IHs :: split_IHs rec_info tys (drop n IHs)
      end
  | split_IHs _ [] [] = []
  | split_IHs _ [] (_ :: _) = error "split IH error: too many"

fun std_recursor_tac rec_info used_tfrees info_to_pthm assms = recursor rec_info false
      (* TFrees via pre_condition and blast *)
      (fn T => fn IH => fn ctxt =>
          if null IH then
            (resolve_tac ctxt [nth assms (find_index (equal T) used_tfrees)] THEN_ALL_NEW blast_tac ctxt) 1
          else error "error 1 in distributing IHs in recursor_tac")
      (* TVars may not occur *)
      (fn ty => error ("error in recursor_tac for " ^ @{make_string} ty))
      (* recursive case via IH and blast *)
      (K (fn IHs => fn ctxt =>
        if length IHs = 1 then
          (resolve_tac ctxt [hd IHs] THEN_ALL_NEW blast_tac ctxt) 1
        else error "error 2 in distributing IHs in recursor_tac"))
      (* non-rec.-case distributed IHs and invokes partial soundness thm *)
      (fn (tys_tactics, info) => fn IH => fn ctxt =>
        let
          val IHs = split_IHs rec_info (map fst tys_tactics) IH
          val tactics = tys_tactics ~~ IHs |> map_filter (fn ((_, tac_opt), IH) =>
            Option.map (fn f => f IH) tac_opt)
          val pthm = info_to_pthm info
        in
          HEADGOAL (
            resolve_tac ctxt [pthm]
            THEN_ALL_NEW (fn k => Subgoal.SUBPROOF (fn {prems, context = ctxt', ...} =>
              Method.insert_tac ctxt' prems 1
              THEN (nth tactics (k - 1) ctxt')) ctxt k))
        end)

end

Theory Comparator

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Comparisons›

subsection ‹Comparators and Linear Orders›

theory Comparator
imports Main
begin

text ‹Instead of having to define a strict and a weak linear order, @{term "((<), (≤))"},
 one can alternatively use a comparator to define the linear order, which may deliver 
 three possible outcomes when comparing two values.›

(* three-valued comparison result; the constructor order (Eq, Lt, Gt) fixes the
   argument order of case_order, as used in two_comparisons_into_case_order below *)
datatype order = Eq | Lt | Gt

(* a comparator relates its first argument to its second via Eq, Lt or Gt *)
type_synonym 'a comparator = "'a ⇒ 'a ⇒ order"

text ‹In the following, we provide the obvious definitions of how to switch between 
  linear orders and comparators.›

(* strict order induced by a comparator: exactly the Lt results *)
definition lt_of_comp :: "'a comparator ⇒ 'a ⇒ 'a ⇒ bool" where
  "lt_of_comp acomp x y = (case acomp x y of Lt ⇒ True | _ ⇒ False)"

(* weak order induced by a comparator: everything except the Gt results *)
definition le_of_comp :: "'a comparator ⇒ 'a ⇒ 'a ⇒ bool" where
  "le_of_comp acomp x y = (case acomp x y of Gt ⇒ False | _ ⇒ True)"
  
(* comparator built from a weak order le and a strict order lt *)
definition comp_of_ords :: "('a ⇒ 'a ⇒ bool) ⇒ ('a ⇒ 'a ⇒ bool) ⇒ 'a comparator" where
  "comp_of_ords le lt x y = (if lt x y then Lt else if le x y then Eq else Gt)"

(* turning a comparator into orders and back yields the original comparator *)
lemma comp_of_ords_of_le_lt[simp]: "comp_of_ords (le_of_comp c) (lt_of_comp c) = c"
  by (intro ext) (auto simp: comp_of_ords_def le_of_comp_def lt_of_comp_def split: order.split)

(* the strict order is recovered from comp_of_ords without further assumptions *)
lemma lt_of_comp_of_ords: "lt_of_comp (comp_of_ords le lt) = lt"
  by (intro ext) (auto simp: comp_of_ords_def le_of_comp_def lt_of_comp_def split: order.split)

(* the weak order is recovered from comp_of_ords whenever lt implies le *)
lemma le_of_comp_of_ords_gen: "(⋀ x y. lt x y ⟹ le x y) ⟹ le_of_comp (comp_of_ords le lt) = le"
  by (intro ext, auto simp: comp_of_ords_def le_of_comp_def lt_of_comp_def split: order.split)

(* for a linear order, le is recovered from the comparator built by comp_of_ords *)
lemma le_of_comp_of_ords_linorder: assumes "class.linorder le lt"
  shows "le_of_comp (comp_of_ords le lt) = le"
proof -
  interpret linorder le lt by fact
  show ?thesis by (rule le_of_comp_of_ords_gen) simp
qed

(* swaps Lt and Gt and keeps Eq *)
fun invert_order:: "order ⇒ order" where
  "invert_order Lt = Gt" |
  "invert_order Gt = Lt" |
  "invert_order Eq = Eq"

(* a comparator: swapping arguments inverts the result, Eq only for equal elements,
   and Lt is transitive; it induces a linear order (see sublocale linorder below) *)
locale comparator =
  fixes comp :: "'a comparator"
  assumes sym: "invert_order (comp x y) = comp y x"
    and weak_eq: "comp x y = Eq ⟹ x = y"
    and comp_trans: "comp x y = Lt ⟹ comp y z = Lt ⟹ comp x z = Lt"
begin 

lemma eq: "(comp x y = Eq) = (x = y)"
proof
  assume "x = y"
  with sym [of y y] show "comp x y = Eq" by (cases "comp x y") auto
qed (rule weak_eq)

lemma comp_same [simp]:
  "comp x x = Eq"
  by (simp add: eq)

abbreviation "lt ≡ lt_of_comp comp"
abbreviation "le ≡ le_of_comp comp"

sublocale ordering le lt
proof
  note [simp] = lt_of_comp_def le_of_comp_def
  fix x y z :: 'a
  show "le x x" using eq [of x x] by (simp)
  show "le x y ⟹ le y z ⟹ le x z"
    by (cases "comp x y" "comp y z" rule: order.exhaust [case_product order.exhaust])
       (auto dest: comp_trans simp: eq)
  show "le x y ⟹ le y x ⟹ x = y"
    using sym [of x y] by (cases "comp x y") (simp_all add: eq)
  show "lt x y ⟷ le x y ∧ x ≠ y"
    using eq [of x y] by (cases "comp x y") (simp_all)
qed

lemma linorder: "class.linorder le lt"
proof (rule class.linorder.intro)
  interpret order le lt
    using ordering_axioms by (rule ordering_orderI)
  show ‹class.order le lt›
    by (fact order_axioms)
  show ‹class.linorder_axioms le›
  proof
    note [simp] = lt_of_comp_def le_of_comp_def
    fix x y :: 'a
    show "le x y ∨ le y x"
      using sym [of x y] by (cases "comp x y") (simp_all)
  qed
qed

sublocale linorder le lt
  by (rule linorder)

(* conversions between comparator results and the induced orders *)
lemma Gt_lt_conv: "comp x y = Gt ⟷ lt y x" 
  unfolding lt_of_comp_def sym[of x y, symmetric] 
  by (cases "comp x y", auto)
lemma Lt_lt_conv: "comp x y = Lt ⟷ lt x y" 
  unfolding lt_of_comp_def by (cases "comp x y", auto)
lemma eq_Eq_conv: "comp x y = Eq ⟷ x = y" 
  by (rule eq)
lemma nGt_le_conv: "comp x y ≠ Gt ⟷ le x y" 
  unfolding le_of_comp_def by (cases "comp x y", auto)
lemma nLt_le_conv: "comp x y ≠ Lt ⟷ le y x" 
  unfolding le_of_comp_def sym[of x y, symmetric] by (cases "comp x y", auto)
lemma nEq_neq_conv: "comp x y ≠ Eq ⟷ x ≠ y" 
  using eq_Eq_conv[of x y] by simp

lemmas le_lt_convs =  nLt_le_conv nGt_le_conv Gt_lt_conv Lt_lt_conv eq_Eq_conv nEq_neq_conv

(* two consecutive tests via le, lt or = merged into a single case distinction on comp *)
lemma two_comparisons_into_case_order: 
  "(if le x y then (if x = y then P else Q) else R) = (case_order P Q R (comp x y))" 
  "(if le x y then (if y = x then P else Q) else R) = (case_order P Q R (comp x y))" 
  "(if le x y then (if le y x then P else Q) else R) = (case_order P Q R (comp x y))" 
  "(if le x y then (if lt x y then Q else P) else R) = (case_order P Q R (comp x y))" 
  "(if lt x y then Q else (if le x y then P else R)) = (case_order P Q R (comp x y))" 
  "(if lt x y then Q else (if lt y x then R else P)) = (case_order P Q R (comp x y))" 
  "(if lt x y then Q else (if x = y then P else R)) = (case_order P Q R (comp x y))" 
  "(if lt x y then Q else (if y = x then P else R)) = (case_order P Q R (comp x y))" 
  "(if x = y then P else (if lt y x then R else Q)) = (case_order P Q R (comp x y))" 
  "(if x = y then P else (if lt x y then Q else R)) = (case_order P Q R (comp x y))" 
  "(if x = y then P else (if le y x then R else Q)) = (case_order P Q R (comp x y))" 
  "(if x = y then P else (if le x y then Q else R)) = (case_order P Q R (comp x y))"
  by (auto simp: le_lt_convs split: order.splits)

end

(* every linear order yields a valid comparator via comp_of_ords *)
lemma comp_of_ords: assumes "class.linorder le lt"
  shows "comparator (comp_of_ords le lt)"
proof -
  interpret linorder le lt by fact
  show ?thesis
    by (unfold_locales, auto simp: comp_of_ords_def split: if_splits)
qed

(* the comparator induced by the orders of a linorder instance *)
definition (in linorder) comparator_of :: "'a comparator" where
  "comparator_of x y = (if x < y then Lt else if x = y then Eq else Gt)"

(* comparator_of satisfies the comparator axioms *)
lemma comparator_of: "comparator comparator_of"
  by unfold_locales (auto split: if_splits simp: comparator_of_def)

end

Theory Compare

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
subsection ‹Compare›

theory Compare
imports Comparator
keywords "compare_code" :: thy_decl
begin

text ‹This introduces a type class for having a proper comparator, similar to @{class linorder}.
  Since most of the Isabelle/HOL algorithms work on the latter, we also provide a method which 
  turns linear-order based algorithms into comparator-based algorithms, where two consecutive 
  invocations of linear orders and equality are merged into one comparator invocation.
  We further define a class which defines both a linear order and a comparator, and where the
  induced orders coincide.›

(* types equipped with a comparator *)
class compare =
  fixes compare :: "'a comparator"
  assumes comparator_compare: "comparator compare"
begin

lemma compare_Eq_is_eq [simp]:
  "compare x y = Eq ⟷ x = y"
  by (rule comparator.eq [OF comparator_compare])
  
lemma compare_refl [simp]:
  "compare x x = Eq"
  by simp

end

(* the orders induced by comparator_of are the original orders *)
lemma (in linorder) le_lt_comparator_of:
  "le_of_comp comparator_of = (≤)" "lt_of_comp comparator_of = (<)"
  by (intro ext, auto simp: comparator_of_def le_of_comp_def lt_of_comp_def)+

(* Types with both ord operations and a comparator, where the orders derived
   from compare (le_of_comp / lt_of_comp) are exactly (≤) and (<). *)
class compare_order = ord + compare +
  assumes ord_defs: "le_of_comp compare = (≤) " "lt_of_comp compare = (<)"

text ‹@{class compare_order} is @{class compare} and @{class linorder}, where comparator and orders 
  define the same ordering.›

(* Every compare_order is a linorder. By ord_defs, (≤) and (<) are the orders
   induced by compare, and these form a linorder by comparator.linorder. *)
subclass (in compare_order) linorder
  by (unfold ord_defs[symmetric], rule comparator.linorder, rule comparator_compare)

context compare_order
begin

(* In a compare_order, compare coincides with the generic comparator_of of the
   induced linear order. *)
lemma compare_is_comparator_of: 
  "compare = comparator_of" 
proof (intro ext)
  fix x y :: 'a
  show "compare x y = comparator_of x y"
    by  (unfold comparator_of_def, unfold ord_defs[symmetric] lt_of_comp_def, 
      cases "compare x y", auto)
qed

(* Rewrite rules that merge two consecutive comparisons via (≤)/(<) and (=)
   into one case distinction on compare. They are unfolded by the compare_code
   command (compare_code.ML). *)
lemmas two_comparisons_into_compare = 
  comparator.two_comparisons_into_case_order[OF comparator_compare, unfolded ord_defs]

end

ML_file ‹compare_code.ML›

text ‹‹Compare_Code.change_compare_code const ty-vars› changes the code equations of some constant such that
  two consecutive comparisons via @{term "(<=)"}, @{term "(<)"}, or @{term "(=)"} are turned into one
  invocation of @{const compare}. 
  The difference to a standard ‹code_unfold› is that here we change the code-equations
  where an additional sort-constraint on @{class compare_order} can be added. Otherwise, there would
  be no @{const compare}-function.›

end

File ‹compare_code.ML›

signature COMPARE_CODE =
sig
  
  (* changes the code equations of some constant such that
     two consecutive comparisons via <=, <, or = are turned into one
     invocation of the comparator. 
     The difference to a standard code_unfold is that here we change the code-equations
     where an additional sort-constraint on compare_order can be added. Otherwise, there would
     be no compare-function.
     The new equations are noted as "<base name of constant>_compare_code" with
     the [code] attribute; a warning is printed if no equation changed. *)
  val change_compare_code : 
    term                      (* the constant *) 
    -> string list     (* names of the type parameters which should be constrained to compare_order;
                          a leading "?" is accepted and ignored *) 
    -> local_theory -> local_theory

end

structure Compare_Code : COMPARE_CODE =
struct

(* Removes one leading "?" from a name, so that a type variable may be given
   either as "'a" or as "?'a". *)
fun drop_leading_qmark s = 
  if String.isPrefix "?" s then String.substring (s,1,String.size s - 1) else s

(* See the signature. Outline:
   1. fetch the current code equations of the constant (error if there are none);
   2. for each schematic type variable of the head constant of the first
      equation whose name occurs in inst_names, remove the classes
      ord/linorder/order from its sort and add compare_order;
   3. unfold two_comparisons_into_compare in the instantiated equations and
      register the results as code equations. *)
fun change_compare_code const inst_names lthy = 
  let
    val inst_names = map drop_leading_qmark inst_names
    val const_string = quote (Pretty.string_of (Syntax.pretty_term lthy const))
    (* only constants are accepted as input *)
    val cname = (case const of Const (cname,_) => cname | _ => 
      error ("expected constant as input, but got " ^ const_string))
    (* current code equations of cname; map_filter I keeps only the entries
       that actually carry a theorem *)
    val cert = Code.get_cert lthy [] cname
    val code_eqs = cert |> Code.equations_of_cert (Proof_Context.theory_of lthy)
    |> snd |> these |> map (fst o snd) |> map_filter I
    val _ = if null code_eqs then error "could not find code equations" else ()

    (* adding sort-constraint compare_order within code equations*)
    (* head constant of the first code equation; its schematic type variables
       are the candidates for re-sorting *)
    val const' = hd code_eqs |> Thm.concl_of |> Logic.dest_equals |> fst |> strip_comb |> fst
    val types = Term.add_tvars const' []
    val ctyp_of = TVar #> Thm.ctyp_of lthy
    (* keeps a class unless it is ord, linorder or order; compare_order is a
       subclass of linorder, so these become redundant *)
    fun filt s = not (member (op =) (@{sort ord} @ @{sort linorder} @ @{sort order}) s)
    (* instantiation ?'a::S ~> ?'a::(filtered S @ compare_order), only for the
       type variables whose name (ignoring the index) is listed in inst_names *)
    val map_types = maps (fn ty => if List.exists (fn tn => tn = (fst o fst) ty) inst_names then 
      [(ty, ctyp_of (apsnd (fn ss => filter filt ss @ @{sort compare_order}) ty))] else 
      []) types
    val code_eqs = map (Thm.instantiate (map_types, [])) code_eqs 

    (* replace comparisons and register code eqns *)
    val new_code_eqs = map (Local_Defs.unfold lthy @{thms two_comparisons_into_compare}) code_eqs
    (* nothing changed: warn (including debugging data) but still register the equations *)
    val _ = if map Thm.prop_of new_code_eqs = map Thm.prop_of code_eqs then
      warning ("Code equations for " ^ const_string ^ " did not change\n" ^
      "Perhaps you have to provide some type variables which should be restricted to compare_order\n" ^
      (@{make_string} (map TVar types ~~ map snd types, const', code_eqs))) 
      else ()
    (* note the new equations as <base name>_compare_code with attribute [code] *)
    val lthy = Local_Theory.note (
      (Binding.name (Long_Name.base_name cname ^ "_compare_code"), @{attributes [code]}), new_code_eqs) lthy 
      |> snd    
  in
    lthy
  end
  
(* Command variant: reads the constant from a string; tnames_option is the
   (possibly empty) list of type-variable names. *)
fun change_compare_code_cmd const tnames_option lthy = 
  change_compare_code (Syntax.read_term lthy const) tnames_option lthy

(* Outer syntax: compare_code ("'a", ...) const -- the parenthesized list is optional *)
val _ =
  Outer_Syntax.local_theory @{command_keyword compare_code} 
    "turn comparisons via <= and < into compare within code-equations"
    (Scan.optional (@{keyword "("} |-- (Parse.list Parse.string) --| @{keyword ")"}) [] --
      Parse.term >> (fn (inst, c) => change_compare_code_cmd c inst))

end

Theory RBT_Compare_Order_Impl

subsection ‹Example: Modifying the Code-Equations of Red-Black-Trees›

theory RBT_Compare_Order_Impl
imports
  Compare
  "HOL-Library.RBT_Impl"
begin

text ‹In the following, we modify all code-equations of the red-black-tree 
  implementation that perform comparisons. As a positive result, they now each require
  only one invocation of the comparator, where before two comparisons were performed.
  The disadvantage of this simple solution is the additional class constraint on
  @{class compare_order}.›

(* Constrain 'a to compare_order in the code equations of the comparison-based
   RBT_Impl functions, so that consecutive comparisons become one compare call. *)
compare_code ("'a") rbt_ins
compare_code ("'a") rbt_lookup
compare_code ("'a") rbt_del
compare_code ("'a") rbt_map_entry
compare_code ("'a") sunion_with
compare_code ("'a") sinter_with
compare_code ("'a") rbt_split

(* Generate Haskell code for the main operations; presumably a sanity check
   that the modified code equations are still usable. *)
export_code rbt_ins rbt_lookup rbt_del rbt_map_entry rbt_union_with_key rbt_inter_with_key rbt_minus in Haskell

end

Theory RBT_Comparator_Impl

subsection ‹A Comparator-Interface to Red-Black-Trees›

theory RBT_Comparator_Impl
imports 
  "HOL-Library.RBT_Impl"
  Comparator
begin

text ‹For all of the main algorithms of red-black trees, we
  provide alternatives which are completely based on comparators,
  and which are provably equivalent. At the time of writing,
  this interface is used in the Container AFP-entry.
  
  It does not rely on the modifications of code-equations as in 
  the previous subsection.›

(* All definitions in this context are parameterized by an arbitrary comparator c.
   Equivalence with the (<)-based functions of RBT_Impl is shown further below,
   under the assumption "comparator c". *)
context 
  fixes c :: "'a comparator"
begin

(* Lookup of key k; one call c k x per node decides the direction. *)
primrec rbt_comp_lookup :: "('a, 'b) rbt  'a  'b"
where
  "rbt_comp_lookup RBT_Impl.Empty k = None"
| "rbt_comp_lookup (Branch _ l x y r) k = 
   (case c k x of Lt  rbt_comp_lookup l k 
     | Gt  rbt_comp_lookup r k 
     | Eq  Some y)"

(* Insertion of k with value v. On an existing key the stored value y is
   replaced by f k y v. Black nodes are rebalanced with balance. *)
fun
  rbt_comp_ins :: "('a  'b  'b  'b)  'a  'b  ('a,'b) rbt  ('a,'b) rbt"
where
  "rbt_comp_ins f k v RBT_Impl.Empty = Branch RBT_Impl.R RBT_Impl.Empty k v  RBT_Impl.Empty" |
  "rbt_comp_ins f k v (Branch RBT_Impl.B l x y r) = (case c k x of 
      Lt  balance (rbt_comp_ins f k v l) x y r
    | Gt  balance l x y (rbt_comp_ins f k v r)
    | Eq  Branch RBT_Impl.B l x (f k y v) r)" |
  "rbt_comp_ins f k v (Branch RBT_Impl.R l x y r) = (case c k x of 
      Lt  Branch RBT_Impl.R (rbt_comp_ins f k v l) x y r
    | Gt  Branch RBT_Impl.R l x y (rbt_comp_ins f k v r)
    | Eq  Branch RBT_Impl.R l x (f k y v) r)"

(* Insertion followed by painting the root black. *)
definition rbt_comp_insert_with_key :: "('a  'b  'b  'b)  'a  'b  ('a,'b) rbt  ('a,'b) rbt"
where "rbt_comp_insert_with_key f k v t = paint RBT_Impl.B (rbt_comp_ins f k v t)"

(* Plain insertion: the new value overwrites an existing one. *)
definition rbt_comp_insert :: "'a  'b  ('a, 'b) rbt  ('a, 'b) rbt" where
  "rbt_comp_insert = rbt_comp_insert_with_key (λ_ _ nv. nv)"

(* Deletion of key x, mutually recursive with helpers that descend into the
   left or right subtree and rebalance (balance_left / balance_right). *)
fun
  rbt_comp_del_from_left :: "'a  ('a,'b) rbt  'a  'b  ('a,'b) rbt  ('a,'b) rbt" and
  rbt_comp_del_from_right :: "'a  ('a,'b) rbt  'a  'b  ('a,'b) rbt  ('a,'b) rbt" and
  rbt_comp_del :: "'a ('a,'b) rbt  ('a,'b) rbt"
where
  "rbt_comp_del x RBT_Impl.Empty = RBT_Impl.Empty" |
  "rbt_comp_del x (Branch _ a y s b) = 
   (case c x y of 
        Lt  rbt_comp_del_from_left x a y s b 
      | Gt  rbt_comp_del_from_right x a y s b
      | Eq  combine a b)" |
  "rbt_comp_del_from_left x (Branch RBT_Impl.B lt z v rt) y s b = balance_left (rbt_comp_del x (Branch RBT_Impl.B lt z v rt)) y s b" |
  "rbt_comp_del_from_left x a y s b = Branch RBT_Impl.R (rbt_comp_del x a) y s b" |
  "rbt_comp_del_from_right x a y s (Branch RBT_Impl.B lt z v rt) = balance_right a y s (rbt_comp_del x (Branch RBT_Impl.B lt z v rt))" | 
  "rbt_comp_del_from_right x a y s b = Branch RBT_Impl.R a y s (rbt_comp_del x b)"

(* Deletion followed by painting the root black. *)
definition "rbt_comp_delete k t = paint RBT_Impl.B (rbt_comp_del k t)"

(* Builds a tree from a list of pairs by foldr-insertion into the empty tree.
   Earlier list elements are inserted later and so win on duplicate keys. *)
definition "rbt_comp_bulkload xs = foldr (λ(k, v). rbt_comp_insert k v) xs RBT_Impl.Empty"

(* Applies f to the value stored under key k, if any. The shape and colors of
   the tree are unchanged. *)
primrec
  rbt_comp_map_entry :: "'a  ('b  'b)  ('a, 'b) rbt  ('a, 'b) rbt"
where
  "rbt_comp_map_entry k f RBT_Impl.Empty = RBT_Impl.Empty"
| "rbt_comp_map_entry k f (Branch cc lt x v rt) =
    (case c k x of 
        Lt  Branch cc (rbt_comp_map_entry k f lt) x v rt
      | Gt  Branch cc lt x v (rbt_comp_map_entry k f rt)
      | Eq  Branch cc lt x (f v) rt)"

(* Merge of two association lists (presumably sorted by key w.r.t. c). The
   smaller head key is emitted first; on equal keys the values are combined
   to f k v v'. *)
function comp_sunion_with :: "('a  'b  'b  'b)  ('a × 'b) list  ('a × 'b) list  ('a × 'b) list" 
where
  "comp_sunion_with f ((k, v) # as) ((k', v') # bs) =
   (case c k' k of 
        Lt  (k', v') # comp_sunion_with f ((k, v) # as) bs
      | Gt  (k, v) # comp_sunion_with f as ((k', v') # bs)
      | Eq  (k, f k v v') # comp_sunion_with f as bs)"
| "comp_sunion_with f [] bs = bs"
| "comp_sunion_with f as [] = as"
by pat_completeness auto
termination by lexicographic_order

(* Intersection of two association lists (presumably sorted by key w.r.t. c).
   Only keys occurring in both lists are kept, with value f k v v'. *)
function comp_sinter_with :: "('a  'b  'b  'b)  ('a × 'b) list  ('a × 'b) list  ('a × 'b) list"
where
  "comp_sinter_with f ((k, v) # as) ((k', v') # bs) =
  (case c k' k of 
      Lt  comp_sinter_with f ((k, v) # as) bs
    | Gt  comp_sinter_with f as ((k', v') # bs)
    | Eq  (k, f k v v') # comp_sinter_with f as bs)"
| "comp_sinter_with f [] _ = []"
| "comp_sinter_with f _ [] = []"
by pat_completeness auto
termination by lexicographic_order

(* Splits a tree at key x into (entries below x, value stored at x if any,
   entries above x), re-assembling subtrees with rbt_join. *)
fun rbt_split_comp :: "('a, 'b) rbt  'a  ('a, 'b) rbt × 'b option × ('a, 'b) rbt" where
  "rbt_split_comp RBT_Impl.Empty k = (RBT_Impl.Empty, None, RBT_Impl.Empty)"
| "rbt_split_comp (RBT_Impl.Branch _ l a b r) x = (case c x a of
    Lt  (case rbt_split_comp l x of (l1, β, l2)  (l1, β, rbt_join l2 a b r))
  | Gt  (case rbt_split_comp r x of (r1, β, r2)  (rbt_join l a b r1, β, r2))
  | Eq  (l, Some b, r))"

(* Splitting does not increase the total size; used in the termination proofs
   of the recursive union/intersection functions below. *)
lemma rbt_split_comp_size: "(l2, b, r2) = rbt_split_comp t2 a  size l2 + size r2  size t2"
  by (induction t2 a arbitrary: l2 b r2 rule: rbt_split_comp.induct)
     (auto split: order.splits if_splits prod.splits)

(* Divide-and-conquer union. If flip_rbt t2 t1 holds, the trees swap roles and
   f is flipped. Then a small t2 is inserted entry-wise into t1. Otherwise t2 is
   split at the root key of t1, the halves are united recursively, and the
   results are joined with rbt_join. For a key in both trees the value is
   f k v1 v2, with v1 from t1 and v2 from t2. *)
function rbt_comp_union_rec :: "('a  'b  'b  'b)  ('a, 'b) rbt  ('a, 'b) rbt  ('a, 'b) rbt" where
  "rbt_comp_union_rec f t1 t2 = (let (f, t2, t1) =
    if flip_rbt t2 t1 then (λk v v'. f k v' v, t1, t2) else (f, t2, t1) in
    if small_rbt t2 then RBT_Impl.fold (rbt_comp_insert_with_key f) t2 t1
    else (case t1 of RBT_Impl.Empty  t2
      | RBT_Impl.Branch _ l1 a b r1 
        case rbt_split_comp t2 a of (l2, β, r2) 
          rbt_join (rbt_comp_union_rec f l1 l2) a (case β of None  b | Some b'  f a b b') (rbt_comp_union_rec f r1 r2)))"
  by pat_completeness auto
(* terminates: the sum of both tree sizes decreases (rbt_split_comp_size) *)
termination
  using rbt_split_comp_size
  by (relation "measure (λ(f,t1,t2). size t1 + size t2)") (fastforce split: if_splits)+

declare rbt_comp_union_rec.simps[simp del]

(* Same algorithm as rbt_comp_union_rec, but swaps are recorded in the flag γ.
   The possibly flipped function f' is rebuilt from f and γ at each call, and
   recursive calls pass the original f. *)
function rbt_comp_union_swap_rec :: "('a  'b  'b  'b)  bool  ('a, 'b) rbt  ('a, 'b) rbt  ('a, 'b) rbt" where
  "rbt_comp_union_swap_rec f γ t1 t2 = (let (γ, t2, t1) =
    if flip_rbt t2 t1 then (¬γ, t1, t2) else (γ, t2, t1);
    f' = (if γ then (λk v v'. f k v' v) else f) in
    if small_rbt t2 then RBT_Impl.fold (rbt_comp_insert_with_key f') t2 t1
    else case t1 of rbt.Empty  t2
      | Branch x l1 a b r1 
        case rbt_split_comp t2 a of (l2, β, r2) 
          rbt_join (rbt_comp_union_swap_rec f γ l1 l2) a (case β of None  b | Some x  f' a b x) (rbt_comp_union_swap_rec f γ r1 r2))"
  by pat_completeness auto
(* terminates: the sum of both tree sizes decreases (rbt_split_comp_size) *)
termination
  using rbt_split_comp_size
  by (relation "measure (λ(f,γ,t1, t2). size t1 + size t2)") (fastforce split: if_splits)+

declare rbt_comp_union_swap_rec.simps[simp del]

(* The flag-based variant equals rbt_comp_union_rec with f flipped iff γ holds. *)
lemma rbt_comp_union_swap_rec: "rbt_comp_union_swap_rec f γ t1 t2 =
  rbt_comp_union_rec (if γ then (λk v v'. f k v' v) else f) t1 t2"
proof (induction f γ t1 t2 rule: rbt_comp_union_swap_rec.induct)
  case (1 f γ t1 t2)
  show ?case
    using 1[OF refl _ refl refl _ refl _ refl]
    unfolding rbt_comp_union_swap_rec.simps[of _ _ t1] rbt_comp_union_rec.simps[of _ t1]
    by (auto simp: Let_def split: rbt.splits prod.splits option.splits) (* slow *)
qed

(* Code equation: flip_rbt and small_rbt are unfolded into comparisons of the
   black heights bh1 and bh2, which are bound once per call via let. *)
lemma rbt_comp_union_swap_rec_code[code]: "rbt_comp_union_swap_rec f γ t1 t2 = (
    let bh1 = bheight t1; bh2 = bheight t2; (γ, t2, bh2, t1, bh1) =
    if bh1 < bh2 then (¬γ, t1, bh1, t2, bh2) else (γ, t2, bh2, t1, bh1);
    f' = (if γ then (λk v v'. f k v' v) else f) in
    if bh2 < 4 then RBT_Impl.fold (rbt_comp_insert_with_key f') t2 t1
    else case t1 of rbt.Empty  t2
      | Branch x l1 a b r1 
        case rbt_split_comp t2 a of (l2, β, r2) 
          rbt_join (rbt_comp_union_swap_rec f γ l1 l2) a (case β of None  b | Some x  f' a b x) (rbt_comp_union_swap_rec f γ r1 r2))"
  by (auto simp: rbt_comp_union_swap_rec.simps flip_rbt_def small_rbt_def)

(* Top-level union: starts unswapped (γ = False) and paints the root black. *)
definition "rbt_comp_union_with_key f t1 t2 = paint RBT_Impl.B (rbt_comp_union_swap_rec f False t1 t2)"

(* Intersection for small inputs. Each entry (k, v) of t2 is looked up in t1,
   and k is kept with value f k v' v when t1 maps k to v'. *)
definition "map_filter_comp_inter f t1 t2 = List.map_filter (λ(k, v).
  case rbt_comp_lookup t1 k of None  None
  | Some v'  Some (k, f k v' v)) (RBT_Impl.entries t2)"

(* Divide-and-conquer intersection, structured like rbt_comp_union_rec.
   Small inputs go through map_filter_comp_inter and rbtreeify. Otherwise t2 is
   split at t1's root key; the root entry is kept only if the key occurs in t2
   (β = Some b'), and rbt_join2 is used otherwise. *)
function rbt_comp_inter_rec :: "('a  'b  'b  'b)  ('a, 'b) rbt  ('a, 'b) rbt  ('a, 'b) rbt" where
  "rbt_comp_inter_rec f t1 t2 = (let (f, t2, t1) =
    if flip_rbt t2 t1 then (λk v v'. f k v' v, t1, t2) else (f, t2, t1) in
    if small_rbt t2 then rbtreeify (map_filter_comp_inter f t1 t2)
    else case t1 of RBT_Impl.Empty  RBT_Impl.Empty
    | RBT_Impl.Branch _ l1 a b r1 
      case rbt_split_comp t2 a of (l2, β, r2)  let l' = rbt_comp_inter_rec f l1 l2; r' = rbt_comp_inter_rec f r1 r2 in
      (case β of None  rbt_join2 l' r' | Some b'  rbt_join l' a (f a b b') r'))"
  by pat_completeness auto
(* terminates: the sum of both tree sizes decreases (rbt_split_comp_size) *)
termination
  using rbt_split_comp_size
  by (relation "measure (λ(f,t1,t2). size t1 + size t2)") (fastforce split: if_splits)+

declare rbt_comp_inter_rec.simps[simp del]

(* Variant of rbt_comp_inter_rec that records argument swaps in the flag γ and
   derives the possibly flipped f' from f and γ at each call. *)
function rbt_comp_inter_swap_rec :: "('a  'b  'b  'b)  bool  ('a, 'b) rbt  ('a, 'b) rbt  ('a, 'b) rbt" where
  "rbt_comp_inter_swap_rec f γ t1 t2 = (let (γ, t2, t1) =
    if flip_rbt t2 t1 then (¬γ, t1, t2) else (γ, t2, t1);
    f' = if γ then (λk v v'. f k v' v) else f in
    if small_rbt t2 then rbtreeify (map_filter_comp_inter f' t1 t2)
    else case t1 of rbt.Empty  rbt.Empty
    | Branch x l1 a b r1 
      (case rbt_split_comp t2 a of (l2, β, r2)  let l' = rbt_comp_inter_swap_rec f γ l1 l2; r' = rbt_comp_inter_swap_rec f γ r1 r2 in
      (case β of None  rbt_join2 l' r' | Some b'  rbt_join l' a (f' a b b') r')))"
  by pat_completeness auto
(* terminates: the sum of both tree sizes decreases (rbt_split_comp_size) *)
termination
  using rbt_split_comp_size
  by (relation "measure (λ(f,γ,t1,t2). size t1 + size t2)") (fastforce split: if_splits)+

declare rbt_comp_inter_swap_rec.simps[simp del]

(* The flag-based variant equals rbt_comp_inter_rec with f flipped iff γ holds. *)
lemma rbt_comp_inter_swap_rec: "rbt_comp_inter_swap_rec f γ t1 t2 =
  rbt_comp_inter_rec (if γ then (λk v v'. f k v' v) else f) t1 t2"
proof (induction f γ t1 t2 rule: rbt_comp_inter_swap_rec.induct)
  case (1 f γ t1 t2)
  show ?case
    using 1[OF refl _ refl refl _ refl _ refl]
    unfolding rbt_comp_inter_swap_rec.simps[of _ _ t1] rbt_comp_inter_rec.simps[of _ t1]
    by (auto simp: Let_def split: rbt.splits prod.splits option.splits)
qed

(* Code equation for rbt_comp_inter_swap_rec: flip_rbt and small_rbt are
   unfolded into comparisons of black heights bound once via let. *)
lemma comp_inter_with_key_code[code]: "rbt_comp_inter_swap_rec f γ t1 t2 = (
  let bh1 = bheight t1; bh2 = bheight t2; (γ, t2, bh2, t1, bh1) =
  if bh1 < bh2 then (¬γ, t1, bh1, t2, bh2) else (γ, t2, bh2, t1, bh1);
  f' = (if γ then (λk v v'. f k v' v) else f) in
  if bh2 < 4 then rbtreeify (map_filter_comp_inter f' t1 t2)
  else case t1 of rbt.Empty  rbt.Empty
    | Branch x l1 a b r1 
      (case rbt_split_comp t2 a of (l2, β, r2)  let l' = rbt_comp_inter_swap_rec f γ l1 l2; r' = rbt_comp_inter_swap_rec f γ r1 r2 in
      (case β of None  rbt_join2 l' r' | Some b'  rbt_join l' a (f' a b b') r')))"
  by (auto simp: rbt_comp_inter_swap_rec.simps flip_rbt_def small_rbt_def)

(* Top-level intersection: starts unswapped (γ = False) and paints the root black. *)
definition "rbt_comp_inter_with_key f t1 t2 = paint RBT_Impl.B (rbt_comp_inter_swap_rec f False t1 t2)"

(* Entries of t1 whose key does not occur in t2. *)
definition "filter_comp_minus t1 t2 =
  filter (λ(k, _). rbt_comp_lookup t2 k = None) (RBT_Impl.entries t1)"

(* Difference t1 - t2. If t2 is small, its keys are deleted from t1 one by one;
   if t1 is small, it is filtered and rebuilt. Otherwise t1 is split at the root
   key of t2, the halves are subtracted recursively and joined with rbt_join2. *)
fun comp_minus :: "('a, 'b) rbt  ('a, 'b) rbt  ('a, 'b) rbt" where
  "comp_minus t1 t2 = (if small_rbt t2 then RBT_Impl.fold (λk _ t. rbt_comp_delete k t) t2 t1
    else if small_rbt t1 then rbtreeify (filter_comp_minus t1 t2)
    else case t2 of RBT_Impl.Empty  t1
      | RBT_Impl.Branch _ l2 a b r2 
        case rbt_split_comp t1 a of (l1, _, r1)  rbt_join2 (comp_minus l1 l2) (comp_minus r1 r2))"

declare comp_minus.simps[simp del]

(* Top-level difference, painting the root black. *)
definition "rbt_comp_minus t1 t2 = paint RBT_Impl.B (comp_minus t1 t2)"

(* From here on c is assumed to be a comparator. Each comparator-based function
   is shown equal to the corresponding function of RBT_Impl's ord locale,
   instantiated with lt_of_comp c. *)
context
  assumes c: "comparator c"
begin

(* Proof pattern used repeatedly below: unfold both definitions, merge the two
   comparisons of the ord version via two_comparisons_into_case_order, then
   case split on the order result. *)
lemma rbt_comp_lookup: "rbt_comp_lookup = ord.rbt_lookup (lt_of_comp c)" 
proof (intro ext)
  fix k and t :: "('a,'b)rbt"
  show "rbt_comp_lookup t k = ord.rbt_lookup (lt_of_comp c) t k"
    by (induct t, unfold rbt_comp_lookup.simps ord.rbt_lookup.simps
      comparator.two_comparisons_into_case_order[OF c]) 
      (auto split: order.splits)
qed  

lemma rbt_comp_ins: "rbt_comp_ins = ord.rbt_ins (lt_of_comp c)" 
proof (intro ext)
  fix f k v and t :: "('a,'b)rbt"
  show "rbt_comp_ins f k v t = ord.rbt_ins (lt_of_comp c) f k v t"
    by (induct f k v t rule: rbt_comp_ins.induct, unfold rbt_comp_ins.simps ord.rbt_ins.simps
      comparator.two_comparisons_into_case_order[OF c]) 
      (auto split: order.splits)
qed  

(* the derived insertion functions follow by unfolding their definitions *)
lemma rbt_comp_insert_with_key: "rbt_comp_insert_with_key = ord.rbt_insert_with_key (lt_of_comp c)"
  unfolding rbt_comp_insert_with_key_def[abs_def] ord.rbt_insert_with_key_def[abs_def]
  unfolding rbt_comp_ins ..

lemma rbt_comp_insert: "rbt_comp_insert = ord.rbt_insert (lt_of_comp c)"
  unfolding rbt_comp_insert_def[abs_def] ord.rbt_insert_def[abs_def]
  unfolding rbt_comp_insert_with_key ..

(* Deletion: the three mutually recursive functions are related simultaneously
   via the joint induction rule, then the goal follows by extensionality. *)
lemma rbt_comp_del: "rbt_comp_del = ord.rbt_del (lt_of_comp c)" 
proof - {
  fix k a b and s t :: "('a,'b)rbt"
  have 
    "rbt_comp_del_from_left k t a b s = ord.rbt_del_from_left (lt_of_comp c) k t a b s"
    "rbt_comp_del_from_right k t a b s = ord.rbt_del_from_right (lt_of_comp c) k t a b s"
    "rbt_comp_del k t = ord.rbt_del (lt_of_comp c) k t"
  by (induct k t a b s and k t a b s and k t rule: rbt_comp_del_from_left_rbt_comp_del_from_right_rbt_comp_del.induct,
    unfold 
      rbt_comp_del.simps ord.rbt_del.simps
      rbt_comp_del_from_left.simps ord.rbt_del_from_left.simps
      rbt_comp_del_from_right.simps ord.rbt_del_from_right.simps
      comparator.two_comparisons_into_case_order[OF c],
    auto split: order.split) 
  }
  thus ?thesis by (intro ext)
qed  

lemma rbt_comp_delete: "rbt_comp_delete = ord.rbt_delete (lt_of_comp c)"
  unfolding rbt_comp_delete_def[abs_def] ord.rbt_delete_def[abs_def]
  unfolding rbt_comp_del ..

lemma rbt_comp_bulkload: "rbt_comp_bulkload = ord.rbt_bulkload (lt_of_comp c)"
  unfolding rbt_comp_bulkload_def[abs_def] ord.rbt_bulkload_def[abs_def]
  unfolding rbt_comp_insert ..

lemma rbt_comp_map_entry: "rbt_comp_map_entry = ord.rbt_map_entry (lt_of_comp c)" 
proof (intro ext)
  fix f k and t :: "('a,'b)rbt"
  show "rbt_comp_map_entry f k t = ord.rbt_map_entry (lt_of_comp c) f k t"
    by (induct t, unfold rbt_comp_map_entry.simps ord.rbt_map_entry.simps
      comparator.two_comparisons_into_case_order[OF c]) 
      (auto split: order.splits)
qed  

lemma comp_sunion_with: "comp_sunion_with = ord.sunion_with (lt_of_comp c)"
proof (intro ext)
  fix f and as bs :: "('a × 'b)list"
  show "comp_sunion_with f as bs = ord.sunion_with (lt_of_comp c) f as bs"
    by (induct f as bs rule: comp_sunion_with.induct,
      unfold comp_sunion_with.simps ord.sunion_with.simps
      comparator.two_comparisons_into_case_order[OF c]) 
      (auto split: order.splits)
qed

(* lt_of_comp c is asymmetric: a and x cannot each be below the other. *)
lemma anti_sym: "lt_of_comp c a x  lt_of_comp c x a  False"
  by (metis c comparator.Gt_lt_conv comparator.Lt_lt_conv order.distinct(5))

lemma rbt_split_comp: "rbt_split_comp t x = ord.rbt_split (lt_of_comp c) t x"
  by (induction t x rule: rbt_split_comp.induct)
     (auto simp: ord.rbt_split.simps comparator.le_lt_convs[OF c]
      split: order.splits prod.splits dest: anti_sym)

(* Recursive union agrees with ord.rbt_union_rec. The proof follows the
   function's induction rule and case splits on the (possibly swapped) first
   tree t1'. *)
lemma comp_union_with_key: "rbt_comp_union_rec f t1 t2 = ord.rbt_union_rec (lt_of_comp c) f t1 t2"
proof (induction f t1 t2 rule: rbt_comp_union_rec.induct)
  case (1 f t1 t2)
  obtain f' t1' t2' where flip: "(f', t2', t1') =
    (if flip_rbt t2 t1 then (λk v v'. f k v' v, t1, t2) else (f, t2, t1))"
    by fastforce
  show ?case
  proof (cases t1')
    case (Branch _ l1 a b r1)
    have t1_not_Empty: "t1'  RBT_Impl.Empty"
      by (auto simp: Branch)
    obtain l2 β r2 where split: "rbt_split_comp t2' a = (l2, β, r2)"
      by (cases "rbt_split_comp t2' a") auto
    show ?thesis
      using 1[OF flip refl _ _ Branch]
      unfolding rbt_comp_union_rec.simps[of _ t1] ord.rbt_union_rec.simps[of _ _ t1] flip[symmetric]
      by (auto simp: Branch split rbt_split_comp[symmetric] rbt_comp_insert_with_key
          split: prod.splits)
  qed (auto simp: rbt_comp_union_rec.simps[of _ t1] ord.rbt_union_rec.simps[of _ _ t1] flip[symmetric]
       rbt_comp_insert_with_key rbt_split_comp[symmetric])
qed

lemma comp_sinter_with: "comp_sinter_with = ord.sinter_with (lt_of_comp c)"
proof (intro ext)
  fix f and as bs :: "('a × 'b)list"
  show "comp_sinter_with f as bs = ord.sinter_with (lt_of_comp c) f as bs"
    by (induct f as bs rule: comp_sinter_with.induct,
      unfold comp_sinter_with.simps ord.sinter_with.simps
      comparator.two_comparisons_into_case_order[OF c]) 
      (auto split: order.splits)
qed

(* top-level union, via the swap-flag lemma and comp_union_with_key *)
lemma rbt_comp_union_with_key: "rbt_comp_union_with_key = ord.rbt_union_with_key (lt_of_comp c)"
  by (rule ext)+
     (auto simp: rbt_comp_union_with_key_def rbt_comp_union_swap_rec ord.rbt_union_with_key_def
      ord.rbt_union_swap_rec comp_union_with_key)

(* Recursive intersection agrees with ord.rbt_inter_rec; same proof structure
   as comp_union_with_key. *)
lemma comp_inter_with_key: "rbt_comp_inter_rec f t1 t2 = ord.rbt_inter_rec (lt_of_comp c) f t1 t2"
proof (induction f t1 t2 rule: rbt_comp_inter_rec.induct)
  case (1 f t1 t2)
  obtain f' t1' t2' where flip: "(f', t2', t1') =
    (if flip_rbt t2 t1 then (λk v v'. f k v' v, t1, t2) else (f, t2, t1))"
    by fastforce
  show ?case
  proof (cases t1')
    case (Branch _ l1 a b r1)
    have t1_not_Empty: "t1'  RBT_Impl.Empty"
      by (auto simp: Branch)
    obtain l2 β r2 where split: "rbt_split_comp t2' a = (l2, β, r2)"
      by (cases "rbt_split_comp t2' a") auto
    show ?thesis
      using 1[OF flip refl _ _ Branch]
      unfolding rbt_comp_inter_rec.simps[of _ t1] ord.rbt_inter_rec.simps[of _ _ t1] flip[symmetric]
      by (auto simp: Branch split rbt_split_comp[symmetric] rbt_comp_lookup
          ord.map_filter_inter_def map_filter_comp_inter_def split: prod.splits)
  qed (auto simp: rbt_comp_inter_rec.simps[of _ t1] ord.rbt_inter_rec.simps[of _ _ t1] flip[symmetric]
       ord.map_filter_inter_def map_filter_comp_inter_def rbt_comp_lookup rbt_split_comp[symmetric])
qed

(* top-level intersection, via the swap-flag lemma and comp_inter_with_key *)
lemma rbt_comp_inter_with_key: "rbt_comp_inter_with_key = ord.rbt_inter_with_key (lt_of_comp c)"
  by (rule ext)+
     (auto simp: rbt_comp_inter_with_key_def rbt_comp_inter_swap_rec
      ord.rbt_inter_with_key_def ord.rbt_inter_swap_rec comp_inter_with_key)

(* Recursive difference agrees with ord.rbt_minus_rec; case split on t2. *)
lemma comp_minus: "comp_minus t1 t2 = ord.rbt_minus_rec (lt_of_comp c) t1 t2"
proof (induction t1 t2 rule: comp_minus.induct)
  case (1 t1 t2)
  show ?case
  proof (cases t2)
    case (Branch _ l2 a u r2)
    have t2_not_Empty: "t2  RBT_Impl.Empty"
      by (auto simp: Branch)
    obtain l1 β r1 where split: "rbt_split_comp t1 a = (l1, β, r1)"
      by (cases "rbt_split_comp t1 a") auto
    show ?thesis
      using 1[OF _ _ Branch]
      unfolding comp_minus.simps[of t1 t2] ord.rbt_minus_rec.simps[of _ t1 t2]
      by (auto simp: Branch split rbt_split_comp[symmetric] rbt_comp_delete rbt_comp_lookup
          filter_comp_minus_def ord.filter_minus_def split: prod.splits)
  qed (auto simp: comp_minus.simps[of t1] ord.rbt_minus_rec.simps[of _ t1]
       filter_comp_minus_def ord.filter_minus_def
       rbt_comp_delete rbt_comp_lookup rbt_split_comp[symmetric])
qed

lemma rbt_comp_minus: "rbt_comp_minus = ord.rbt_minus (lt_of_comp c)"
  by (rule ext)+ (auto simp: rbt_comp_minus_def ord.rbt_minus_def comp_minus)

(* Collection of the equivalences for the main top-level operations. *)
lemmas rbt_comp_simps = 
  rbt_comp_insert
  rbt_comp_lookup
  rbt_comp_delete
  rbt_comp_bulkload
  rbt_comp_map_entry
  rbt_comp_union_with_key
  rbt_comp_inter_with_key
  rbt_comp_minus
end
end

end

Theory Comparator_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Generating Comparators›

theory Comparator_Generator
imports
  "../Generator_Aux"
  "../Derive_Manager"
  Comparator
begin

(* Auxiliary type, used only in the following text to show the shape of
   generated comparators; its name is hidden again below. *)
typedecl ('a,'b,'c,'z)type

text ‹In the following, we define a generator which for a given datatype @{typ "('a,'b,'c,'z)type"}
  constructs a comparator of type 
  @{typ "'a comparator ⇒ 'b comparator ⇒ 'c comparator ⇒ 'z comparator ⇒ ('a,'b,'c,'z)type comparator"}.
  To this end, we first compare the index of the constructors, then for equal constructors, we
  compare the arguments recursively and combine the results lexicographically.›

(* the auxiliary type was only needed for the text above *)
hide_type "type"

subsection ‹Lexicographic combination of @{typ order}

(* Lexicographic combination: the first non-Eq result decides; Eq if all
   results are Eq (in particular for the empty list). *)
fun comp_lex :: "order list  order"
where
  "comp_lex (c # cs) = (case c of Eq  comp_lex cs | _  c)" |
  "comp_lex [] = Eq"

subsection ‹Improved code for non-lazy languages›

text ‹The following equations will eliminate all occurrences of @{term comp_lex}
  in the generated code of the comparators.›

(* Unfolding rules for comp_lex on explicit lists; a singleton list needs no
   case distinction. *)
lemma comp_lex_unfolds: 
  "comp_lex [] = Eq"
  "comp_lex [c] = c"
  "comp_lex (c # d # cs) = (case c of Eq  comp_lex (d # cs) | z  z)"
  by (cases c, auto)+

subsection ‹Pointwise properties for equality, symmetry, and transitivity› 


text ‹The pointwise properties are important during inductive proofs of soundness of comparators.
  They are defined in a way that makes them combinable with @{const comp_lex}.›

(* comp_lex yields Eq iff every element of the list is Eq. *)
lemma comp_lex_eq: "comp_lex os = Eq  ( ord  set os. ord = Eq)" 
  by (induct os) (auto split: order.splits)
  
(* Transitivity condition on three order results (used as
   trans_order (comp x y) (comp y z) (comp x z)): if x and y are not Gt, then
   z is not Gt, and z is Lt when x or y is Lt. *)
definition trans_order :: "order  order  order  bool" where
  "trans_order x y z  x  Gt  y  Gt  z  Gt  ((x = Lt  y = Lt)  z = Lt)"

lemma trans_orderI:
  "(x  Gt  y  Gt  z  Gt  ((x = Lt  y = Lt)  z = Lt))  trans_order x y z"
  by (simp add: trans_order_def)

lemma trans_orderD:
  assumes "trans_order x y z" and "x  Gt" and "y  Gt"
  shows "z  Gt" and "x = Lt  y = Lt  z = Lt"
  using assms by (auto simp: trans_order_def)

(* Splits a quantifier over i < Suc x into the case i = 0 and the shifted rest. *)
lemma All_less_Suc:
  "(i < Suc x. P i)  P 0  (i < x. P (Suc i))"
  using less_Suc_eq_0_disj by force

(* comp_lex preserves pointwise transitivity of three equally long result lists. *)
lemma comp_lex_trans:
  assumes "length xs = length ys"
    and "length ys = length zs"
    and " i < length zs. trans_order (xs ! i) (ys ! i) (zs ! i)"
  shows "trans_order (comp_lex xs) (comp_lex ys) (comp_lex zs)"
using assms
proof (induct xs ys zs rule: list_induct3)
  case (Cons x xs y ys z zs)
  then show ?case
    by (intro trans_orderI)
       (cases x y z rule: order.exhaust [case_product order.exhaust order.exhaust],
        auto simp: All_less_Suc dest: trans_orderD)
qed (simp add: trans_order_def)

(* comp_lex commutes with invert_order when applied pointwise. *)
lemma comp_lex_sym:
  assumes "length xs = length ys"
    and " i < length ys. invert_order (xs ! i) = ys ! i"
  shows "invert_order (comp_lex xs) = comp_lex ys"
  using assms by (induct xs ys rule: list_induct2, simp, case_tac x) fastforce+

(* from here on comp_lex is handled via the lemmas above, not by unfolding *)
declare comp_lex.simps [simp del]

(* Pointwise equality property at a fixed left argument x:
   acomp x y = Eq only if x = y. *)
definition peq_comp :: "'a comparator  'a  bool"
where
  "peq_comp acomp x  ( y. acomp x y = Eq  x = y)"

lemma peq_compD: "peq_comp acomp x  acomp x y = Eq  x = y"
  unfolding peq_comp_def by auto

lemma peq_compI: "( y. acomp x y = Eq  x = y)  peq_comp acomp x"
  unfolding peq_comp_def by auto

(* Pointwise symmetry at x: invert_order (acomp x y) = acomp y x for all y. *)
definition psym_comp :: "'a comparator  'a  bool" where
  "psym_comp acomp x  ( y. invert_order (acomp x y) = (acomp y x))"

lemma psym_compD:
  assumes "psym_comp acomp x"
  shows "invert_order (acomp x y) = (acomp y x)"
  using assms unfolding psym_comp_def by blast+

lemma psym_compI:
  assumes " y. invert_order (acomp x y) = (acomp y x)"
  shows "psym_comp acomp x"
  using assms unfolding psym_comp_def by blast


(* Pointwise transitivity with x as the left-most element:
   trans_order (acomp x y) (acomp y z) (acomp x z) for all y, z. *)
definition ptrans_comp :: "'a comparator  'a  bool" where
  "ptrans_comp acomp x  ( y z. trans_order (acomp x y) (acomp y z) (acomp x z))"

lemma ptrans_compD:
  assumes "ptrans_comp acomp x"
  shows "trans_order (acomp x y) (acomp y z) (acomp x z)"
  using assms unfolding ptrans_comp_def by blast+

lemma ptrans_compI:
  assumes " y z. trans_order (acomp x y) (acomp y z) (acomp x z)"
  shows "ptrans_comp acomp x"
  using assms unfolding ptrans_comp_def by blast

subsection ‹Separate properties of comparators›

(* Global versions of the pointwise properties: they hold for every x.
   The *D2/*I2 rules convert between global and pointwise form, and the
   *D/*I rules between global and fully unfolded form. *)
definition eq_comp :: "'a comparator  bool" where
  "eq_comp acomp  ( x. peq_comp acomp x)"

lemma eq_compD2: "eq_comp acomp  peq_comp acomp x"
  unfolding eq_comp_def by blast

lemma eq_compI2: "( x. peq_comp acomp x)  eq_comp acomp" 
  unfolding eq_comp_def by blast
    
definition trans_comp :: "'a comparator  bool" where
  "trans_comp acomp  ( x. ptrans_comp acomp x)"
  
lemma trans_compD2: "trans_comp acomp  ptrans_comp acomp x"
  unfolding trans_comp_def by blast

lemma trans_compI2: "( x. ptrans_comp acomp x)  trans_comp acomp" 
  unfolding trans_comp_def by blast

  
definition sym_comp :: "'a comparator  bool" where
  "sym_comp acomp  ( x. psym_comp acomp x)"

lemma sym_compD2:
  "sym_comp acomp  psym_comp acomp x"
  unfolding sym_comp_def by blast

lemma sym_compI2: "( x. psym_comp acomp x)  sym_comp acomp" 
  unfolding sym_comp_def by blast

lemma eq_compD: "eq_comp acomp  acomp x y = Eq  x = y"
  by (rule peq_compD[OF eq_compD2])

lemma eq_compI: "( x y. acomp x y = Eq  x = y)  eq_comp acomp"
  by (intro eq_compI2 peq_compI)

lemma trans_compD: "trans_comp acomp  trans_order (acomp x y) (acomp y z) (acomp x z)"
  by (rule ptrans_compD[OF trans_compD2])

lemma trans_compI: "( x y z. trans_order (acomp x y) (acomp y z) (acomp x z))  trans_comp acomp"
  by (intro trans_compI2 ptrans_compI)

lemma sym_compD:
  "sym_comp acomp  invert_order (acomp x y) = (acomp y x)" 
  by (rule psym_compD[OF sym_compD2])
  
lemma sym_compI: "( x y. invert_order (acomp x y) = (acomp y x))  sym_comp acomp"
  by (intro sym_compI2 psym_compI)

(* The three separate properties together yield the comparator axioms ... *)
lemma eq_sym_trans_imp_comparator:
  assumes "eq_comp acomp" and "sym_comp acomp" and "trans_comp acomp"
  shows "comparator acomp"
proof
  fix x y z
  show "invert_order (acomp x y) = acomp y x"
    using sym_compD [OF ‹sym_comp acomp] .
  {
    assume "acomp x y = Eq"
    with eq_compD [OF ‹eq_comp acomp]
    show "x = y" by blast
  }
  {
    assume "acomp x y = Lt" and "acomp y z = Lt"
    with trans_orderD [OF trans_compD [OF ‹trans_comp acomp], of x y z]
    show "acomp x z = Lt" by auto
  }
qed

(* ... and conversely, every comparator has all three properties. *)
lemma comparator_imp_eq_sym_trans:
  assumes "comparator acomp"
  shows "eq_comp acomp" "sym_comp acomp" "trans_comp acomp" 
proof -
  interpret comparator acomp by fact
  show "eq_comp acomp" using eq by (intro eq_compI, auto)
  show "sym_comp acomp" using sym by (intro sym_compI, auto)
  show "trans_comp acomp"
  proof (intro trans_compI trans_orderI)
    fix x y z
    assume "acomp x y  Gt" "acomp y z  Gt"
    thus "acomp x z  Gt  (acomp x y = Lt  acomp y z = Lt  acomp x z = Lt)"
      using comp_trans [of x y z] and eq [of x y] and eq [of y z]
      by (cases "acomp x y" "acomp y z" rule: order.exhaust [case_product order.exhaust]) auto
  qed
qed

(* Pointwise properties of an arbitrary comparator acomp, at every x. *)
context
  fixes acomp :: "'a comparator"
  assumes c: "comparator acomp"
begin
lemma comp_to_psym_comp: "psym_comp acomp x"
  using comparator_imp_eq_sym_trans[OF c]
  by (intro sym_compD2)

lemma comp_to_peq_comp: "peq_comp acomp x" 
  using comparator_imp_eq_sym_trans [OF c] 
  by (intro eq_compD2)
  
lemma comp_to_ptrans_comp: "ptrans_comp acomp x" 
  using comparator_imp_eq_sym_trans [OF c] 
  by (intro trans_compD2)
end

subsection ‹Auxiliary Lemmas for Comparator Generator›

(* Unrolls bounded quantification over i < n for concrete numerals n. *)
lemma forall_finite: "( i < (0 :: nat). P i) = True"
   "( i < Suc 0. P i) = P 0"
   "( i < Suc (Suc x). P i) = (P 0  ( i < Suc x. P (Suc i)))"
  by (auto, case_tac i, auto)
  
(* trans_order holds trivially if z = Lt, x = Gt, or y = Gt. *)
lemma trans_order_different:
  "trans_order a b Lt"
  "trans_order Gt b c"
  "trans_order a Gt c"
  by (intro trans_orderI, auto)+

(* simp rules for length and nth on explicit lists *)
lemma length_nth_simps: 
  "length [] = 0" "length (x # xs) = Suc (length xs)" 
  "(x # xs) ! 0 = x" "(x # xs) ! (Suc n) = xs ! n" by auto

subsection ‹The Comparator Generator›

ML_file ‹comparator_generator.ML›
                 
end

File ‹comparator_generator.ML›

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann
    License:     LGPL
*)
signature COMPARATOR_GENERATOR =
sig

  (* comparator information registered per type constructor *)
  type info =
   {map : term,                    (* take % x. x, if there is no map *)
    pcomp : term,                  (* partial comparator *)
    comp : term,                   (* full comparator *)
    comp_def : thm option,         (* definition of comparator, important for nesting *)
    map_comp : thm option,         (* compositionality of map, important for nesting *)
    partial_comp_thms : thm list,  (* first eq, then sym, finally trans *)
    comp_thm : thm,               (* comparator acomp ⟹ … ⟹ comparator (full_comp acomp …) *)
    used_positions : bool list}    (* presumably: which type arguments are used -- TODO confirm *)

  (* registers @{term comparator_of :: "some_type :: linorder comparator"}
     where some_type must just be a type without type-arguments *)
  val register_comparator_of : string -> local_theory -> local_theory

  (* registers a user-supplied comparator for a type without type variables *)
  val register_foreign_comparator :
    typ -> (* type-constant without type-variables *)
    term -> (* comparator for type *)
    thm -> (* comparator thm for provided comparator *)
    local_theory -> local_theory

  (* registers user-supplied partial and full comparators for a type constructor *)
  val register_foreign_partial_and_full_comparator :
    string -> (* long type name *)
    term -> (* map function, should be λx. x, if there is no map *)
    term -> (* partial comparator of type ('a => order, 'b)ty => ('a,'b)ty => order,
      where 'a is used, 'b is unused *)
    term -> (* (full) comparator of type ('a ⇒ 'a ⇒ order) ⇒ ('a,'b)ty ⇒ ('a,'b)ty ⇒ order,
      where 'a is used, 'b is unused *)
    thm option -> (* comp_def, should be full_comp = pcomp o map acomp ..., important for nesting *)
    thm option -> (* map compositionality, important for nesting *)
    thm -> (* partial eq thm for full comparator *)
    thm -> (* partial sym thm for full comparator *)
    thm -> (* partial trans thm for full comparator *)
    thm -> (* full thm: comparator a-comp => comparator (full_comp a-comp) *)
    bool list -> (*used positions*)
    local_theory -> local_theory

  (* kind of comparator to generate for a type *)
  datatype comparator_type = Linorder | BNF

  val generate_comparators_from_bnf_fp :
    string ->                 (* name of type *)
    local_theory ->
    ((term * thm list) list * (* partial comparators + simp-rules *)
    (term * thm) list) *      (* non-partial comparator + def_rule *)
    local_theory

  val generate_comparator :
    comparator_type ->
    string -> (* name of type *)
    local_theory -> local_theory

  (* registered comparator info for a type name, if any *)
  val get_info : Proof.context -> string -> info option

  (* ensures that the info will be available on later requests *)
  val ensure_info : comparator_type -> string -> local_theory -> local_theory

end

structure Comparator_Generator : COMPARATOR_GENERATOR =
struct

open Generator_Aux

(* Constructors listed in the same order as in the signature COMPARATOR_GENERATOR. *)
datatype comparator_type = Linorder | BNF

(* set to true to print debugging messages via debug_out *)
val debug = false
fun debug_out s = if debug then writeln s else ()

(* the HOL type of order results *)
val orderT = @{typ order}
(* compT T = T => T => order, the type of a comparator on T *)
fun compT T = T --> T --> orderT
(* replaces each atomic type (type variable) 'a by 'a => order *)
val orderify = map_atyps (fn T => T --> orderT)
(* type of a partial comparator, e.g. ('a => order, 'b => order) ty => ('a,'b) ty => order *)
fun pcompT T = orderify T --> T --> orderT

(* Comparator information registered for one type constructor. *)
type info =
 {map : term,                    (* map function of the type, identity if there is none *)
  pcomp : term,                  (* partial comparator *)
  comp : term,                   (* full comparator *)
  comp_def : thm option,         (* full comparator = partial comparator o map ... *)
  map_comp : thm option,         (* map compositionality *)
  partial_comp_thms : thm list,  (* partial eq, sym and trans theorems, in this order *)
  comp_thm : thm,                (* comparator theorem for the full comparator *)
  used_positions : bool list};   (* for every type parameter: is it used? *)

(* Table from type-constructor names to their comparator information; two entries
   for the same name are compatible when their full comparators coincide. *)
structure Data = Generic_Data (
  type T = info Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  val merge =
    Symtab.merge (fn ({comp = comp1, ...} : info, {comp = comp2, ...} : info) => comp1 = comp2);
);

(* Adds the info for tyco; raises Symtab.DUP if tyco is already present. *)
fun add_info tyco info context = Data.map (Symtab.update_new (tyco, info)) context

(* Looks up the comparator information of a type constructor, if registered. *)
fun get_info ctxt tyco = Symtab.lookup (Data.get (Context.Proof ctxt)) tyco

(* Like get_info, but raises an error when no comparator is registered for tyco. *)
fun the_info ctxt tyco =
  get_info ctxt tyco
  |> (fn SOME found => found
       | NONE => error ("no comparator information available for type " ^ quote tyco))

(* Registers comparator information for tyco through a (non-syntax, non-pervasive)
   local-theory declaration; all stored terms and theorems are transformed by the
   declaration morphism. *)
fun declare_info tyco m p c c_def m_comp p_thms c_thm used_pos =
  let
    fun transformed phi =
      {map = Morphism.term phi m,
       pcomp = Morphism.term phi p,
       comp = Morphism.term phi c,
       comp_def = Option.map (Morphism.thm phi) c_def,
       map_comp = Option.map (Morphism.thm phi) m_comp,
       partial_comp_thms = Morphism.fact phi p_thms,
       comp_thm = Morphism.thm phi c_thm,
       used_positions = used_pos}
  in
    Local_Theory.declaration {syntax = false, pervasive = false}
      (fn phi => add_info tyco (transformed phi))
  end

(* positions of the partial eq, sym and trans theorems within partial_comp_thms *)
val (EQ, SYM, TRANS) = (0, 1, 2)

(* Registers externally provided comparator information for tyco; the three partial
   theorems are stored in the order eq, sym, trans. *)
fun register_foreign_partial_and_full_comparator tyco m p c c_def m_comp eq_thm sym_thm
  trans_thm c_thm used_pos lthy =
  let
    val partial_thms = [eq_thm, sym_thm, trans_thm]
  in
    declare_info tyco m p c c_def m_comp partial_thms c_thm used_pos lthy
  end

(* Applies the constant called name to c and lets type inference in ctxt fix the
   type of the constant. *)
fun mk_infer_const name ctxt c =
  let val applied = Const (name, dummyT) $ c
  in infer_type ctxt applied end

(* the (partial) eq/sym/trans predicates and comparator, applied to a comparator term *)
fun mk_eq_comp ctxt = mk_infer_const @{const_name eq_comp} ctxt
fun mk_peq_comp ctxt = mk_infer_const @{const_name peq_comp} ctxt
fun mk_sym_comp ctxt = mk_infer_const @{const_name sym_comp} ctxt
fun mk_psym_comp ctxt = mk_infer_const @{const_name psym_comp} ctxt
fun mk_trans_comp ctxt = mk_infer_const @{const_name trans_comp} ctxt
fun mk_ptrans_comp ctxt = mk_infer_const @{const_name ptrans_comp} ctxt
fun mk_comp ctxt = mk_infer_const @{const_name comparator} ctxt

(* the trivial comparator  %_ _. Eq  on T, used where no real comparator is needed *)
fun default_comp T = @{term Eq} |> absdummy T |> absdummy T

(* Registers a comparator comp (with comparator theorem comp_thm) for a nullary type T.
   The comparator doubles as its own partial comparator, the identity on T serves as
   map function, and the partial theorems are derived from comp_thm. *)
fun register_foreign_comparator T comp comp_thm lthy =
  let
    val tyco =
      (case T of
        Type (name, []) => name
      | _ => error "expected type constant")
    fun from_comp_thm rule = rule OF [comp_thm]
    val eq = from_comp_thm @{thm comp_to_peq_comp}
    val sym = from_comp_thm @{thm comp_to_psym_comp}
    val trans = from_comp_thm @{thm comp_to_ptrans_comp}
  in
    lthy
    |> register_foreign_partial_and_full_comparator
         tyco (HOLogic.id_const T) comp comp NONE NONE eq sym trans comp_thm []
  end

(* Registers comparator_of at the nullary type tyco as its comparator; the comparator
   theorem is comparator_of instantiated to that type. *)
fun register_comparator_of tyco lthy =
  let
    val T = Type (tyco, [])
  in
    register_foreign_comparator T
      (Const (@{const_name comparator_of}, compT T))
      (Thm.instantiate' [SOME (Thm.ctyp_of lthy T)] [] @{thm comparator_of})
      lthy
  end

(* Generates comparators for tyco and for all types returned together with it by
   mutual_recursive_types.  For each of these types it
   - defines a partial comparator partial_comparator_<name> (one joint primrec),
   - defines comparator_<name> as the partial comparator composed with the map function
     and notes the applied equation as comparator_<name>_def,
   - proves constructor-wise equations and notes their comp_lex-unfolded forms as
     comparator_<name>_simps with attributes simp and code,
   - proves the partial eq/sym/trans theorems (noted as comparator_<name>_pointwise) and
     the comparator theorem (noted as comparator_<name>),
   and finally registers everything via declare_info.
   Returns ((partial comparators paired with their simp rules,
             comparators paired with their _def theorems), updated local theory). *)
fun generate_comparators_from_bnf_fp tyco lthy =
  let
    val (tycos, Ts) = mutual_recursive_types tyco lthy
    val _ = map (fn tyco => "generating comparator for type " ^ quote tyco) tycos
      |> cat_lines |> writeln
    (* used_tfrees: presumably the type parameters actually occurring in the datatypes;
       only these receive a comparator argument *)
    val (tfrees, used_tfrees) = type_parameters (hd Ts) lthy
    (* one flag per type parameter: is it among used_tfrees? *)
    val used_positions = map (member (op =) used_tfrees o TFree) tfrees
    (* free comparator variables, one per used type parameter *)
    val cs = map (subT "comp") used_tfrees
    val comp_Ts = map compT used_tfrees
    val arg_comps = map Free (cs ~~ comp_Ts)
    (* type constructors the datatypes depend on; their comp_def and map_comp theorems
       are unfolded when proving the comparator simp rules *)
    val dep_tycos = fold (add_used_tycos lthy) tycos []

    (* constructor argument types, instantiated via the substitution XTys ~~ Ts; used to
       translate induction case numbers into (type, constructor) indices *)
    val XTys = Bnf_Access.bnf_types lthy tycos
    val inst_types = typ_subst_atomic (XTys ~~ Ts)
    val cTys = map (map (map inst_types)) (Bnf_Access.constr_argument_types lthy tycos)

    val map_simps = Bnf_Access.map_simps lthy tycos
    val case_simps = Bnf_Access.case_simps lthy tycos
    val maps = Bnf_Access.map_terms lthy tycos
    val map_comp_thms = Bnf_Access.map_comps lthy tycos

    (* indices 0 .. length Ts - 1 of the generated types *)
    val t_ixs = 0 upto (length Ts - 1)

    (* names under which the comparator theorems are noted at the end *)
    val compNs =
      (*TODO: clashes in presence of same type names in different theories*)
      map (Long_Name.base_name) tycos
      |> map (fn s => "comparator_" ^ s)

    (* one free variable per generated type, named prefix with the type index as subscript *)
    fun gen_vars prefix = map (fn (i, pty) => Free (prefix ^ ints_to_subscript [i], pty))
      (t_ixs ~~ Ts)

    (* primrec definitions of partial comparators *)

    (* name and type of the partial comparator of (tyco, T) *)
    fun mk_pcomp (tyco, T) = ("partial_comparator_" ^ Long_Name.base_name tyco, pcompT T)

    (* constructors of a type as (name, argument types), where freeify_tvars is applied
       to the argument types *)
    fun constr_terms lthy =
      Bnf_Access.constr_terms lthy
      #> map (apsnd (map freeify_tvars o fst o strip_type) o dest_Const)

    (* primrec equations of the partial comparator of (tyco, T), one per constructor *)
    fun generate_pcomp_eqs lthy (tyco, T) =
      let
        val constrs = constr_terms lthy tyco

        (* comparison term  p (m x) y  for one pair of constructor arguments of type T,
           built with Generator_Aux.create_map and Generator_Aux.create_partial *)
        fun comp_arg T x y =
          let
            val m = Generator_Aux.create_map default_comp (K o Free o mk_pcomp) () (K false)
              (#used_positions oo the_info) (#map oo the_info) (K o #pcomp oo the_info)
              tycos ((K o K) ()) T lthy
            val p = Generator_Aux.create_partial () (K false)
              (#used_positions oo the_info) (#map oo the_info) (K o #pcomp oo the_info)
              tycos ((K o K) ()) T lthy
          in p $ (m $ x) $ y |> infer_type lthy end

        (* equation  pcomp (c xs) y = case y of ...  for constructor c at index k:
           constructors before c yield Gt, constructors after c yield Lt, and c itself
           compares the arguments lexicographically via comp_lex.  The xs carry the
           orderified argument types, the ys the plain argument types. *)
        fun generate_eq lthy (c_T as (cN, Ts)) =
          let
            val arg_Ts' = map orderify Ts
            val c = Const (cN, arg_Ts' ---> orderify T)
            val (y, (xs, ys)) = Name.variant "y" (Variable.names_of lthy) |>> Free o rpair T
              ||> (fn ctxt => Name.invent_names ctxt "x" (arg_Ts' @ Ts) |> map Free)
              ||> chop (length Ts)
            val k = find_index (curry (op =) c_T) constrs
            val cases = constrs |> map_index (fn (i, (_, Ts')) =>
              if i < k then fold_rev absdummy Ts' @{term Gt}
              else if k < i then fold_rev absdummy Ts' @{term Lt}
              else
                @{term comp_lex} $ HOLogic.mk_list orderT (@{map 3} comp_arg Ts xs ys)
                |> lambdas ys)
            val lhs = Free (mk_pcomp (tyco, T)) $ list_comb (c, xs) $ y
            val rhs = list_comb (singleton (Bnf_Access.case_consts lthy) tyco, cases) $ y
          in HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)) |> infer_type lthy end
      in map (generate_eq lthy) constrs end

    val eqs = map (generate_pcomp_eqs lthy) (tycos ~~ Ts) |> flat
    val bindings = tycos ~~ Ts |> map mk_pcomp
      |> map (fn (name, T) => (Binding.name name, SOME T, NoSyn))

    (* all partial comparators are defined by a single primrec inside a nested target;
       constants and simp rules are exported through the resulting morphism *)
    val ((pcomps, pcomp_simps), lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> (BNF_LFP_Rec_Sugar.primrec false [] bindings
          (map (fn t => ((Binding.empty_atts, t), [], [])) eqs))
      |> Local_Theory.end_nested_result (fn phi => fn (pcomps, _, pcomp_simps) => (map (Morphism.term phi) pcomps, map (Morphism.fact phi) pcomp_simps))

    (* definitions of comparators via partial comparators and maps *)

    (* comparator_<name> comps = pcomp o map ts, where ts holds the comparator argument
       for each used type parameter and default_comp for each unused one; the applied
       equation is proved from the definition and noted as comparator_<name>_def.
       Returns ((comparator constant, noted theorem), lthy). *)
    fun generate_comp_def tyco lthy =
      let
        val cs = map (subT "comp") used_tfrees
        val arg_Ts = map compT used_tfrees
        val args = map Free (cs ~~ arg_Ts)
        val (pcomp, m) = AList.lookup (op =) (tycos ~~ (pcomps ~~ maps)) tyco |> the
        val ts = tfrees |> map TFree |> map (fn T =>
          AList.lookup (op =) (used_tfrees ~~ args) T |> the_default (default_comp T))
        val rhs = HOLogic.mk_comp (pcomp, list_comb (m, ts)) |> infer_type lthy
        val abs_def = lambdas args rhs
        val name = "comparator_" ^ Long_Name.base_name tyco
        val ((comp, (_, prethm)), lthy) =
          Local_Theory.define ((Binding.name name, NoSyn), (Binding.empty_atts, abs_def)) lthy
        val eq = Logic.mk_equals (list_comb (comp, args), rhs)
        val thm = Goal.prove lthy (map (fst o dest_Free) args) [] eq (K (unfold_tac lthy [prethm]))
      in
        Local_Theory.note ((Binding.name (name ^ "_def"), []), [thm]) lthy
        |>> the_single o snd
        |>> `(K comp)
      end

    val ((comps, comp_defs), lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> fold_map generate_comp_def tycos
      |>> split_list
      |> Local_Theory.end_nested_result 
          (fn phi => fn (comps, comp_defs) => (map (Morphism.term phi) comps, map (Morphism.thm phi) comp_defs))

    (* alternative simp-rules for comparators *)

    (* every comparator applied to the comparator arguments *)
    val full_comps = map (list_comb o rpair arg_comps) comps

    (* Constructor-wise equations for the comparator of (tyco, T).  The raw equations
       are returned; their comp_lex-unfolded forms are noted as comparator_<name>_simps
       with attributes simp and code. *)
    fun generate_comp_simps (tyco, T) lthy =
      let
        val constrs = constr_terms lthy tyco

        (* comparison term for a pair of constructor arguments of type T: a comparator
           argument for a used type parameter (identity for an unused one), the new
           comparator for types of this group, the registered comparator for other
           types; schematic type variables are rejected *)
        fun comp_arg T x y =
          let
            fun create_comp (T as TFree _) =
                  AList.lookup (op =) (used_tfrees ~~ arg_comps) T
                  |> the_default (HOLogic.id_const dummyT)
              | create_comp (Type (tyco, Ts)) =
                  (case AList.lookup (op =) (tycos ~~ comps) tyco of
                    SOME c => list_comb (c, arg_comps)
                  | NONE =>
                      let
                        val {comp = c, used_positions = up, ...} = the_info lthy tyco
                        val ts = (up ~~ Ts) |> map_filter (fn (b, T) =>
                          if b then SOME (create_comp T) else NONE)
                      in list_comb (c, ts) end)
              | create_comp T =
                  error ("unexpected schematic variable " ^ quote (Syntax.string_of_typ lthy T))
            val comp = create_comp T
          in comp $ x $ y |> infer_type lthy end

        (* For constructor c_T at index k, one equation per constructor d at index i:
           comparator (c xs) (d ys) is Gt if i < k, Lt if k < i, and the comp_lex
           combination of the argument comparisons if i = k.  The equations are proved
           by unfolding the definitions and simp rules collected above. *)
        fun generate_eq_thm lthy (c_T as (_, Ts)) =
          let
            val (xs, ctxt) = Variable.names_of lthy
              |> fold_map (fn T => Name.variant "x" #>> Free o rpair T) Ts
            fun mk_const (c, Ts) = Const (c, Ts ---> T)
            val comp_const = AList.lookup (op =) (tycos ~~ comps) tyco |> the
            val lhs = list_comb (comp_const, arg_comps) $ list_comb (mk_const c_T, xs)
            val k = find_index (curry (op =) c_T) constrs

            fun mk_eq c ys rhs =
              let
                val y = list_comb (mk_const c, ys)
                val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs $ y, rhs))
              in (ys, eq |> infer_type lthy) end

            val ((ys, eqs), _) = fold_map (fn (i, c as (_, Ts')) => fn ctxt =>
              let
                val (ys, ctxt) = fold_map (fn T => Name.variant "y" #>> Free o rpair T) Ts' ctxt
              in
                (if i < k then mk_eq c ys @{term Gt}
                else if k < i then mk_eq c ys @{term Lt}
                else
                  @{term comp_lex} $ HOLogic.mk_list orderT (@{map 3} comp_arg Ts xs ys)
                  |> mk_eq c ys,
                ctxt)
              end) (tag_list 0 constrs) ctxt
              |> apfst (apfst flat o split_list)

            val dep_comp_defs = map_filter (#comp_def o the_info lthy) dep_tycos
            val dep_map_comps = map_filter (#map_comp o the_info lthy) dep_tycos
            val thms = prove_multi_future lthy (map (fst o dest_Free) (xs @ ys) @ cs) [] eqs
              (fn {context = ctxt, ...} =>
                Goal.conjunction_tac 1
                THEN unfold_tac ctxt
                  (@{thms id_apply o_def} @
                    flat case_simps @
                    flat pcomp_simps @
                    dep_map_comps @ comp_defs @ dep_comp_defs @ flat map_simps))
          in thms end

        val thms = map (generate_eq_thm lthy) constrs |> flat
        val simp_thms = map (Local_Defs.unfold lthy @{thms comp_lex_unfolds}) thms
        val name = "comparator_" ^ Long_Name.base_name tyco
      in
        lthy
        |> Local_Theory.note ((Binding.name (name ^ "_simps"), @{attributes [simp, code]}), simp_thms)
        |> snd
        |> (fn lthy => (thms, lthy))
      end

    val (comp_simps, lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> fold_map generate_comp_simps (tycos ~~ Ts)
      |> Local_Theory.end_nested_result (fn phi => map (Morphism.fact phi))

    (* partial theorems *)

    val set_funs = Bnf_Access.set_terms lthy tycos
    val x_vars = gen_vars "x"
    val free_names = map (fst o dest_Free) (x_vars @ arg_comps)
    (* element variables x_i_j, one per generated type i and used type parameter j *)
    val xi_vars = map_index (fn (i, _) =>
      map_index (fn (j, pty) => Free ("x" ^ ints_to_subscript [i, j], pty)) used_tfrees) Ts

    (* For every type i the goal
         (for all xi_j. xi_j in set_j x ==> P comp_j xi_j) ==> ... ==> P full_comp_i x,
       with one premise per used type parameter j, where P is built by mk_eq_sym_trans'
       (mk_peq_comp, mk_psym_comp or mk_ptrans_comp). *)
    fun mk_eq_sym_trans_thm' mk_eq_sym_trans' = map_index (fn (i, ((set_funs, x), xis)) =>
      let
        fun create_cond ((set_t, xi), c) =
          let
            val rhs = mk_eq_sym_trans' lthy c $ xi |> HOLogic.mk_Trueprop
            val lhs = HOLogic.mk_mem (xi, set_t $ x) |> HOLogic.mk_Trueprop
          in Logic.all xi (Logic.mk_implies (lhs, rhs)) end

        val used_sets = map (the o AList.lookup (op =) (map TFree tfrees ~~ set_funs)) used_tfrees
        val conds = map create_cond (used_sets ~~ xis ~~ arg_comps)
        val concl = mk_eq_sym_trans' lthy (nth full_comps i) $ x |> HOLogic.mk_Trueprop
      in Logic.list_implies (conds, concl) |> infer_type lthy end) (set_funs ~~ x_vars ~~ xi_vars)

    val induct_thms = Bnf_Access.induct_thms lthy tycos
    val set_simps = Bnf_Access.set_simps lthy tycos
    val case_thms = Bnf_Access.case_thms lthy tycos
    val distinct_thms = Bnf_Access.distinct_thms lthy tycos
    val inject_thms = Bnf_Access.inject_thms lthy tycos

    val rec_info = (the_info lthy, #used_positions, tycos)
    val split_IHs = split_IHs rec_info

    val unknown_value = false (* effect of choosing false / true not yet visible *)

    (* Induction over x_vars with induct_thms; the i-th resulting subgoal (1-based) is
       solved by f (i - 1) inside a subproof, receiving its context, premises and
       parameters. *)
    fun induct_tac ctxt f =
      ((DETERM o Induction.induction_tac ctxt false
        (map (fn x => [SOME (NONE, (x, unknown_value))]) x_vars) [] [] (SOME induct_thms) [])
      THEN_ALL_NEW (fn i =>
        Subgoal.SUBPROOF (fn {context = ctxt, prems = prems, params = iparams, ...} =>
          f (i - 1) ctxt prems iparams) ctxt i)) 1

    (* recursive step for the partial theorem at index kind (EQ, SYM or TRANS); for
       registered types the corresponding element of partial_comp_thms is used *)
    fun recursor_tac kind = std_recursor_tac rec_info used_tfrees
      (fn info => nth (#partial_comp_thms info) kind)

    (* discharges the trailing premises of every induction hypothesis with pre_conds *)
    fun instantiate_IHs IHs pre_conds = map (fn IH =>
      OF_option IH (replicate (Thm.nprems_of IH - length pre_conds) NONE @ map SOME pre_conds)) IHs

    (* SOME of the cterm of the k-th parameter in vs, for use with infer_instantiate' *)
    fun get_v_i vs k = nth vs k |> snd |> SOME

    (* partial eq-theorem *)
    val _ = debug_out "Partial equality"
    val eq_thms' = mk_eq_sym_trans_thm' mk_peq_comp

    (* Induction case i: the premises split into induction hypotheses and the trailing
       pre-conditions (one per comparator argument).  After peq_compI a case analysis
       on y follows: equal constructors are compared argumentwise, different ones are
       refuted via the distinctness theorems. *)
    fun eq_solve_tac i ctxt IH_prems xs =
      let
        val (i, j) = ind_case_to_idxs cTys i
        val k = length IH_prems - length arg_comps
        val pre_conds = drop k IH_prems
        val IH = take k IH_prems
        val comp_simps = nth comp_simps i
        val case_thm = nth case_thms i
        val distinct_thms = nth distinct_thms i
        val inject_thms = nth inject_thms i
        val set_thms = nth set_simps i
      in
        (* after induction *)
        resolve_tac ctxt @{thms peq_compI} 1
        THEN Subgoal.FOCUS (fn focus =>
          let
            val y = #params focus |> hd
            val yt = y |> snd |> Thm.term_of
            val ctxt = #context focus
            val pre_cond = map (fn pre_cond => Local_Defs.unfold ctxt set_thms pre_cond) pre_conds
            val IH = instantiate_IHs IH pre_cond
            val xs_tys = map (fastype_of o Thm.term_of o snd) xs
            val IHs = split_IHs xs_tys IH

            fun sub_case_tac j' (ctxt, y_simp, _) =
              if j = j' then
                unfold_tac ctxt (y_simp @ comp_simps)
                THEN unfold_tac ctxt @{thms comp_lex_eq}
                THEN unfold_tac ctxt (@{thms in_set_simps} @ inject_thms @ @{thms refl_True})
                THEN conjI_tac @{thms conj_weak_cong} ctxt xs (fn ctxt' => fn k =>
                  (* NOTE(review): resolves with the case context ctxt, whereas the sym and
                     trans proofs use ctxt' at this point; confirm this is intended *)
                  resolve_tac ctxt @{thms peq_compD} 1
                  THEN recursor_tac EQ pre_cond (nth xs_tys k) (nth IHs k) ctxt')
              else
                (* different constructors *)
                unfold_tac ctxt (y_simp @ distinct_thms @ comp_simps @ @{thms order.simps})
          in
            mk_case_tac ctxt [[SOME yt]] case_thm sub_case_tac
          end
        ) ctxt 1
      end

    val eq_thms' = prove_multi_future lthy free_names [] eq_thms' (fn {context = ctxt, ...} =>
      induct_tac ctxt eq_solve_tac)
    val _ = debug_out (@{make_string} eq_thms')

    (* partial symmetry-theorem *)
    val _ = debug_out "Partial symmetry"
    val sym_thms' = mk_eq_sym_trans_thm' mk_psym_comp

    (* Same structure as eq_solve_tac, starting from psym_compI: equal constructors are
       handled by comp_lex_sym and psym_compD per argument, different ones by
       invert_order.simps. *)
    fun sym_solve_tac i ctxt IH_prems xs =
      let
        val (i, j) = ind_case_to_idxs cTys i
        val k = length IH_prems - length arg_comps
        val pre_conds = drop k IH_prems
        val IH = take k IH_prems
        val comp_simps = nth comp_simps i
        val case_thm = nth case_thms i
        val set_thms = nth set_simps i
      in
        (* after induction *)
        resolve_tac ctxt @{thms psym_compI} 1
        THEN Subgoal.FOCUS (fn focus =>
          let
            val y = #params focus |> hd
            val yt = y |> snd |> Thm.term_of
            val ctxt = #context focus
            val pre_cond = map (fn pre_cond => Local_Defs.unfold ctxt set_thms pre_cond) pre_conds
            val IH = instantiate_IHs IH pre_cond
            val xs_tys = map (fastype_of o Thm.term_of o snd) xs
            val IHs = split_IHs xs_tys IH

            fun sub_case_tac j' (ctxt, y_simp, ys) =
              if j = j' then
                unfold_tac ctxt (y_simp @ comp_simps)
                THEN resolve_tac ctxt @{thms comp_lex_sym} 1
                THEN unfold_tac ctxt (@{thms length_nth_simps forall_finite})
                THEN conjI_tac @{thms conjI} ctxt xs (fn ctxt' => fn k =>
                  resolve_tac ctxt' [infer_instantiate' ctxt'
                    [NONE, get_v_i xs k, get_v_i ys k] @{thm psym_compD}] 1
                  THEN recursor_tac SYM pre_cond (nth xs_tys k) (nth IHs k) ctxt')
              else
                (* different constructors *)
                unfold_tac ctxt (y_simp @ comp_simps @ @{thms invert_order.simps})
          in
            mk_case_tac ctxt [[SOME yt]] case_thm sub_case_tac
          end
        ) ctxt 1
      end

    val sym_thms' = prove_multi_future lthy free_names [] sym_thms' (fn {context = ctxt, ...} =>
      induct_tac ctxt sym_solve_tac)
    val _ = debug_out (@{make_string} sym_thms')

    (* partial transitivity-theorem *)
    val _ = debug_out "Partial transitivity"

    val trans_thms' = mk_eq_sym_trans_thm' mk_ptrans_comp

    (* Same structure as eq_solve_tac, starting from ptrans_compI, but with case analyses
       on both y and z: if all three constructors coincide, comp_lex_trans and
       ptrans_compD are used per argument, otherwise trans_order_different. *)
    fun trans_solve_tac i ctxt IH_prems xs =
      let
        val (i, j) = ind_case_to_idxs cTys i
        val k = length IH_prems - length arg_comps
        val pre_conds = drop k IH_prems
        val IH = take k IH_prems
        val comp_simps = nth comp_simps i
        val case_thm = nth case_thms i
        val set_thms = nth set_simps i
      in
        (* after induction *)
        resolve_tac ctxt @{thms ptrans_compI} 1
        THEN Subgoal.FOCUS (fn focus =>
          let
            val y = nth (#params focus) 0
            val z = nth (#params focus) 1
            val yt = y |> snd |> Thm.term_of
            val zt = z |> snd |> Thm.term_of
            val ctxt = #context focus
            val pre_cond = map (fn pre_cond => Local_Defs.unfold ctxt set_thms pre_cond) pre_conds
            val IH = instantiate_IHs IH pre_cond
            val xs_tys = map (fastype_of o Thm.term_of o snd) xs
            val IHs = split_IHs xs_tys IH

            fun sub_case_tac j' (ctxt, y_simp, ys) =
              let
                fun sub_case_tac' j'' (ctxt, z_simp, zs) =
                      if j = j' andalso j = j'' then
                        unfold_tac ctxt (y_simp @ z_simp @ comp_simps)
                        THEN resolve_tac ctxt @{thms comp_lex_trans} 1
                        THEN unfold_tac ctxt (@{thms length_nth_simps forall_finite})
                        THEN conjI_tac @{thms conjI} ctxt xs (fn ctxt' => fn k =>
                          resolve_tac ctxt' [infer_instantiate' ctxt'
                            [NONE, get_v_i xs k, get_v_i ys k, get_v_i zs k] @{thm ptrans_compD}] 1
                          THEN recursor_tac TRANS pre_cond (nth xs_tys k) (nth IHs k) ctxt')
                      else
                        (* different constructors *)
                        unfold_tac ctxt
                          (y_simp @ z_simp @ comp_simps @ @{thms trans_order_different})
              in
                mk_case_tac ctxt [[SOME zt]] case_thm sub_case_tac'
              end
          in
            mk_case_tac ctxt [[SOME yt]] case_thm sub_case_tac
          end
        ) ctxt 1
      end

    val trans_thms' = prove_multi_future lthy free_names [] trans_thms' (fn {context = ctxt, ...} =>
      induct_tac ctxt trans_solve_tac)
    val _ = debug_out (@{make_string} trans_thms')

    (* total theorems *)
    (* From P comp_j for every comparator argument, derive P for each full comparator,
       using the introduction rule compI2 together with the partial theorems thms' and
       closing with the elimination rule compE2; P is built by mk_eq_sym_trans. *)
    fun mk_eq_sym_trans_thm mk_eq_sym_trans compI2 compE2 thms' =
      let
        val conds = map (fn c => mk_eq_sym_trans lthy c |> HOLogic.mk_Trueprop) arg_comps
        val thms = map (fn i =>
           mk_eq_sym_trans lthy (nth full_comps i)
           |> HOLogic.mk_Trueprop
           |> (fn concl => Logic.list_implies (conds,concl)))
           t_ixs
        val thms = prove_multi_future lthy free_names [] thms (fn {context = ctxt, ...} =>
          ALLGOALS Goal.conjunction_tac
          THEN Method.intros_tac ctxt (@{thm conjI} :: compI2 :: thms') []
          THEN ALLGOALS (eresolve_tac ctxt [compE2]))
      in thms end

    val eq_thms = mk_eq_sym_trans_thm mk_eq_comp @{thm eq_compI2} @{thm eq_compD2} eq_thms'
    val sym_thms = mk_eq_sym_trans_thm mk_sym_comp @{thm sym_compI2} @{thm sym_compD2} sym_thms'
    val trans_thms = mk_eq_sym_trans_thm mk_trans_comp @{thm trans_compI2} @{thm trans_compD2}
      trans_thms'

    val _ = debug_out "full comparator thms"
    (* comparator theorem for full comparator i: all premises of the eq/sym/trans theorems
       e, s, t are discharged by the element EQ/SYM/TRANS of comparator_imp_eq_sym_trans,
       and the results are combined via eq_sym_trans_imp_comparator *)
    fun mk_comp_thm (i, ((e, s), t)) =
      let
        val conds = map (fn c => mk_comp lthy c |> HOLogic.mk_Trueprop) arg_comps
        fun from_comp thm i = thm OF replicate (Thm.prems_of thm |> length)
          (nth @{thms comparator_imp_eq_sym_trans} i)
        val nearly_thm = @{thm eq_sym_trans_imp_comparator} OF
          [from_comp e EQ, from_comp s SYM, from_comp t TRANS]

        val thm =
           mk_comp lthy (nth full_comps i)
           |> HOLogic.mk_Trueprop
           |> (fn concl => Logic.list_implies (conds, concl))
      in
        Goal.prove_future lthy free_names [] thm
          (K (resolve_tac lthy [nearly_thm] 1 THEN ALLGOALS (assume_tac lthy)))
      end
    val comp_thms = map_index mk_comp_thm (eq_thms ~~ sym_thms ~~ trans_thms)

    val (_, lthy) = fold_map (fn (thm, cname) =>
      Local_Theory.note ((Binding.name cname, []), [thm])) (comp_thms ~~ compNs) lthy

    val _ = debug_out (@{make_string} comp_thms)

    (* partial eq, sym and trans theorems per type, in the order of EQ, SYM, TRANS *)
    val pcomp_thms = map (fn ((e, s), t) => [e, s, t]) (eq_thms' ~~ sym_thms' ~~ trans_thms')
    val (_, lthy) = fold_map (fn (thms, cname) =>
      Local_Theory.note ((Binding.name (cname ^ "_pointwise"), []), thms)) (pcomp_thms ~~ compNs) lthy

  in
    (* register the generated comparator information for every type of the group *)
    ((pcomps ~~ pcomp_simps, comps ~~ comp_defs), lthy)
    ||> fold (fn (((((((tyco, map), pcomp), comp), comp_def), map_comp), pcomp_thms), comp_thm) =>
          declare_info tyco map pcomp comp (SOME comp_def) (SOME map_comp)
            pcomp_thms comp_thm used_positions)
         (tycos ~~ maps ~~ pcomps ~~ comps ~~ comp_defs ~~ map_comp_thms ~~ pcomp_thms ~~ comp_thms)
  end

(* Generates (BNF) or registers (Linorder) the comparator of tyco.
   Fails if a comparator for tyco is already registered. *)
fun generate_comparator gen_type tyco lthy =
  if is_some (get_info lthy tyco) then
    error ("type " ^ quote tyco ^ " does already have a comparator")
  else
    (case gen_type of
      Linorder => register_comparator_of tyco lthy
    | BNF => snd (generate_comparators_from_bnf_fp tyco lthy))

(* Makes comparator information for tyco available, generating it according to
   gen_type only when it is still missing. *)
fun ensure_info gen_type tyco lthy =
  if is_some (get_info lthy tyco) then lthy
  else generate_comparator gen_type tyco lthy

(* Derive-hook for comparators: the empty parameter selects the BNF generator,
   "linorder" selects comparator_of; any other parameter is rejected. *)
fun generate_comparator_cmd tyco param =
  let
    val gen_type =
      (case param of
        "linorder" => Linorder
      | "" => BNF
      | _ => error ("unknown parameter, expecting no parameter for BNF-datatypes, " ^
          "or \"linorder\" for types which are already in linorder"))
  in
    Named_Target.theory_map (generate_comparator gen_type tyco)
  end

(* registers generate_comparator_cmd with the derive manager under the name comparator *)
val _ =
  Derive_Manager.register_derive "comparator"
    "generate comparators for given types, options: (linorder) or ()"
    generate_comparator_cmd
  |> Theory.setup

end

Theory Compare_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
subsection ‹Compare Generator›

theory Compare_Generator
imports 
  Comparator_Generator
  Compare
begin

text ‹We provide a generator which takes the comparators of the comparator generator
  to synthesize suitable @{const compare}-functions from the @{class compare}-class.

These comparison functions can further be used to derive an instance of the
@{class compare_order}-class, and therefore also of @{class linorder}. In total, we provide the following three
‹derive›-methods, where the example type @{type prod} can be replaced by any other datatype.

\begin{itemize}
\item ‹derive compare prod› creates an instance @{type prod} :: (@{class compare}, @{class compare}) @{class compare}.
\item ‹derive compare_order prod› creates an instance @{type prod} :: (@{class compare}, @{class compare}) @{class compare_order}.
\item ‹derive linorder prod› creates an instance @{type prod} :: (@{class linorder}, @{class linorder}) @{class linorder}.
\end{itemize}

Usually, the use of ‹derive linorder› is not recommended if comparators are available:
internally, the linear orders will directly be converted into comparators, so a direct use of the
comparators will result in more efficient generated code. This command is mainly provided as a convenience method
for types where comparators are not yet present. For example, at the time of writing, the Container Framework
has partly been adapted to use comparators internally, whereas in other AFP-entries we did not
integrate comparators.
›

(* The five linorder class axioms, extracted from class.linorder.  The order of the facts
   matters: linear_tac in compare_generator.ML uses the (i - 1)-th fact for subgoal i. *)
lemma linorder_axiomsD: assumes "class.linorder le lt"
  shows 
  "lt x y = (le x y ∧ ¬ le y x)" (is ?a)
  "le x x" (is ?b)
  "le x y ⟹ le y z ⟹ le x z" (is "?c1 ⟹ ?c2 ⟹ ?c3") 
  "le x y ⟹ le y x ⟹ x = y" (is "?d1 ⟹ ?d2 ⟹ ?d3")
  "le x y ∨ le y x" (is ?e)
proof -
  interpret linorder le lt by fact
  show ?a ?b "?c1 ⟹ ?c2 ⟹ ?c3" "?d1 ⟹ ?d2 ⟹ ?d3" ?e by auto
qed
 
(* collection that, among others, receives the compare_<name>_def theorems generated
   by compare_instance in compare_generator.ML *)
named_theorems compare_simps "simp theorems to derive \"compare = comparator_of\""

ML_file ‹compare_generator.ML›

end

File ‹compare_generator.ML›

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
signature COMPARE_GENERATOR =
sig
(* derives a compare-instance for a given type (specified by its type name); fails if the
   type already is an instance of compare. depending on the comparator_type, the underlying
   comparator will just be comparator_of (via linorder), or it will be a comparator
   constructed for BNF datatypes *)
val compare_instance : Comparator_Generator.comparator_type -> string -> theory -> theory

(* derives an instance for class compare_order via linorder, 
   where the main comparator will be comparator_of *)
val compare_order_instance_via_comparator_of : string -> theory -> theory

(* derives an instance for class compare_order via compare (a compare-instance is created
   first if the type is not yet one), where the orders are defined via le_of_comp and
   lt_of_comp *)
val compare_order_instance_via_compare : string -> theory -> theory
end

structure Compare_Generator : COMPARE_GENERATOR =
struct

open Generator_Aux

(* sorts of the classes handled by this generator *)
val cmpS = @{sort compare}
val cmpoS = @{sort compare_order}
val linordS = @{sort linorder}

(* constant names *)
val cmpN = @{const_name compare}
val lessN = @{const_name less}
val less_eqN = @{const_name less_eq}

(* T => T => order  and  T => T => bool *)
fun cmpT T = T --> (T --> @{typ order})
fun ordT T = T --> (T --> @{typ bool})

(* compare resp. comparator_of at the given (full function) type *)
fun cmp_const T = Const (cmpN, T)
fun cmp_of_const T = Const (@{const_name comparator_of}, T)


(* Returns the registered comparator of tname, its comparator theorem, and the types
   of the comparator's leading arguments (all argument types but the last two, i.e.,
   the comparators for the type parameters). *)
fun dest_comp ctxt tname =
  (case Comparator_Generator.get_info ctxt tname of
    NONE => error ("no order info for type " ^ quote tname)
  | SOME {comp, comp_thm, ...} =>
      let
        val (arg_Ts, _) = strip_type (fastype_of comp)
        val param_Ts = take (length arg_Ts - 2) arg_Ts
      in (comp, comp_thm, param_Ts) end)

(* Type renaming that maps the type arguments of the compared type to free_types; the
   compared type is the second-to-last argument type of comp. *)
fun all_tys comp free_types =
  let
    val (arg_Ts, _) = strip_type (fastype_of comp)
    val (_, type_args) = dest_Type (List.last (drop_last arg_Ts))
  in rename_types (type_args ~~ free_types) end

local
  (* applies c to one constant (built by mk_const) per type in Ts *)
  fun apply_to_consts mk_const c Ts = list_comb (c, map mk_const Ts)
in
(* c applied to compare at each of the types Ts *)
fun mk_cmp_rhs c Ts = apply_to_consts cmp_const c Ts

(* c applied to comparator_of at each of the types Ts *)
fun mk_cmp_rhs_comparator_of c Ts = apply_to_consts cmp_of_const c Ts
end


(* compare == rhs, with compare at type T => T => order *)
fun mk_cmp_def T rhs = Logic.mk_equals (cmp_const (cmpT T), rhs)

(* less == lt_of_comp rhs  if strict, otherwise  less_eq == le_of_comp rhs *)
fun mk_ord_def T strict rhs =
  let
    val (ordN, of_compN) =
      if strict then (lessN, @{const_name lt_of_comp})
      else (less_eqN, @{const_name le_of_comp})
  in
    Logic.mk_equals (Const (ordN, ordT T), Const (of_compN, cmpT T --> ordT T) $ rhs)
  end

(* binop == rhs, for a binary predicate binop on T *)
fun mk_binop_def binop T rhs = Logic.mk_equals (Const (binop, ordT T), rhs)

local
  (* resolves with the comparator theorem of tname and discharges all emerging
     subgoals with arg_rules *)
  fun comp_thm_tac arg_rules ctxt tname i =
    let val (_, c_thm, _) = dest_comp ctxt tname
    in (resolve_tac ctxt [c_thm] THEN_ALL_NEW resolve_tac ctxt arg_rules) i end
in
(* comparator property of tname's comparator instantiated with compare, on goal 1 *)
fun comparator_tac ctxt tname = comp_thm_tac @{thms comparator_compare} ctxt tname 1

(* comparator property of tname's comparator instantiated with comparator_of, on goal i *)
fun comparator_tac_comparator_of ctxt tname i =
  comp_thm_tac @{thms comparator_of} ctxt tname i
end

(* Derives an instance of class compare for tname (failing if it already is one):
   ensures that a comparator exists (generated according to gen_type if necessary),
   defines compare as that comparator applied to compare at every used type parameter
   (noted as compare_<name>_def with attributes code and compare_simps), and proves the
   class axiom from the comparator theorem. *)
fun compare_instance gen_type tname thy =
  let
    val _ =
      if is_class_instance thy tname cmpS
      then error ("type " ^ quote tname ^ " is already an instance of class \"compare\"")
      else ()
    val _ = writeln ("deriving \"compare\" instance for type " ^ quote tname)
    val thy' = Named_Target.theory_map (Comparator_Generator.ensure_info gen_type tname) thy
    val {used_positions = used, ...} =
      the (Comparator_Generator.get_info (Named_Target.theory_init thy') tname)
    val (_, sorted_vars) = typ_and_vs_of_used_typname tname used cmpS

    fun define_compare lthy =
      let
        val (c, _, arg_Ts) = dest_comp lthy tname
        val rename = all_tys c (map TFree sorted_vars)
        val def_eq = rename (mk_cmp_def dummyT (mk_cmp_rhs c arg_Ts))
        val def_binding = Binding.name ("compare_" ^ Long_Name.base_name tname ^ "_def")
      in
        Generator_Aux.define_overloaded_generic
          ((def_binding, @{attributes [code, compare_simps]}), def_eq) lthy
      end

    val (cmp_def_thm, lthy) =
      define_compare (Class.instantiation ([tname], sorted_vars, cmpS) thy')
  in
    lthy
    |> Class.prove_instantiation_exit (fn ctxt =>
         Class.intro_classes_tac ctxt []
         THEN unfold_tac ctxt [cmp_def_thm]
         THEN comparator_tac ctxt tname)
  end
  
(* Derives a linorder instance for tname (failing if it already is one).  After making
   sure comparator information exists (generated according to gen_type if necessary),
   less and less_eq are defined as lt_of_comp and le_of_comp of the registered
   comparator applied to comparator_of at every used type parameter; the class axioms
   are then obtained via linorder_axiomsD, comparator.linorder and the comparator
   theorem. *)
fun linorder_instance gen_type tname thy =
  let
    val _ = is_class_instance thy tname linordS
      andalso error ("type " ^ quote tname ^ " is already an instance of class \"linorder\"")
    val _ = writeln ("deriving \"linorder\" instance for type " ^ quote tname)
    val thy = Named_Target.theory_map (Comparator_Generator.ensure_info gen_type tname) thy
    val {used_positions = us, ...} = the (Comparator_Generator.get_info 
        (Named_Target.theory_init thy) tname) 

    val (_, xs) = typ_and_vs_of_used_typname tname us linordS
    val ((less_thm, (less_eq_thm, lthy))) =
      Class.instantiation ([tname], xs, linordS) thy
      |> (fn ctxt =>
        let
          val (c, _, Ts) = dest_comp ctxt tname
          val typ_mapping = all_tys c (map TFree xs)
          val cmp = mk_cmp_rhs_comparator_of c Ts
          (* NOTE(review): only less_def is passed through infer_type, less_eq_def is
             not; confirm whether this asymmetry is intended *)
          val less_def = mk_ord_def dummyT true cmp |> typ_mapping |> infer_type ctxt
          val less_eq_def = mk_ord_def dummyT false cmp |> typ_mapping
          val base_name = Long_Name.base_name tname
        in
          (ctxt 
          |> Generator_Aux.define_overloaded_generic
            ((Binding.name ("less_" ^ base_name ^ "_def"), @{attributes [code]}), less_def)            
          ||> Generator_Aux.define_overloaded_generic
            ((Binding.name ("less_eq_" ^ base_name ^ "_def"), @{attributes [code]}), less_eq_def))

        end)
    (* proves subgoal i with the (i - 1)-th fact of linorder_axiomsD, reducing it to
       class.linorder, which follows from comparator.linorder and the comparator theorem
       instantiated with comparator_of *)
    fun linear_tac ctxt i = 
      resolve_tac ctxt [nth @{thms linorder_axiomsD} (i - 1)] i
      THEN resolve_tac ctxt @{thms comparator.linorder} i
      THEN comparator_tac_comparator_of ctxt tname i
  in
    (* subgoals are handled from 5 down to 1, presumably so that closing one subgoal does
       not shift the numbers of those still to be treated; auto_tac finishes the rest *)
    Class.prove_instantiation_exit ( fn ctxt => 
      Class.intro_classes_tac ctxt []
      THEN unfold_tac ctxt [less_thm, less_eq_thm]
      THEN linear_tac ctxt 5
      THEN linear_tac ctxt 4
      THEN linear_tac ctxt 3
      THEN linear_tac ctxt 2
      THEN linear_tac ctxt 1
      THEN auto_tac ctxt 
    ) lthy
  end


(* Parses the derive parameter for a compare instance: empty selects the BNF comparator
   generator, "linorder" selects comparator_of; anything else is an error. *)
fun compare_instance_param tname param =
  let
    val gen_type =
      (case param of
        "" => Comparator_Generator.BNF
      | "linorder" => Comparator_Generator.Linorder
      | _ => error "unknown parameter for compare instance")
  in
    compare_instance gen_type tname
  end

(* Entry point for "derive (param) linorder": no parameter selects the BNF-based
   comparator generator, "linorder" selects Comparator_Generator.Linorder; every other
   parameter is rejected with an error naming the linorder derive and the supported
   parameters (the message previously referred to "compare" by copy-paste). *)
fun linorder_instance_param tname param =
  let
    val gen_type = if param = "" then Comparator_Generator.BNF
      else if param = "linorder" then Comparator_Generator.Linorder
      else error "unknown parameter for linorder instance, supported are (no parameter) and \"linorder\""
  in linorder_instance gen_type tname end


(* Instantiates class compare for tname (with a comparator generated according to
   gen_type), unless tname already is an instance of compare; then thy is unchanged. *)
fun maybe_instantiate_compare gen_type tname thy =
  if not (is_class_instance thy tname cmpS)
  then compare_instance gen_type tname thy
  else thy

(* Derives compare_order for tname from a BNF-generated comparator: instantiates class
   compare first if necessary, defines < as lt_of_comp compare and <= as
   le_of_comp compare, and discharges the class axioms by unfolding these definitions. *)
fun compare_order_instance_via_compare tname thy =
  let
    val gen_type = Comparator_Generator.BNF
    val thy = maybe_instantiate_compare gen_type tname thy
    val {used_positions = us, ...} = the (Comparator_Generator.get_info 
      (Named_Target.theory_init thy) tname)
    val (T, xs) = typ_and_vs_of_used_typname tname us cmpS
    
    (* the compare constant at type T *)
    val cmp = cmp_const (cmpT T)
    val (le_thm, less_thm, lthy) =
      Class.instantiation ([tname], xs, cmpoS) thy
      |> (fn lthy =>
        let 
          val less_def = mk_binop_def @{const_name less} T 
            (Const (@{const_name lt_of_comp}, cmpT T --> T --> T --> @{typ bool}) $ cmp)
          val le_def = mk_binop_def @{const_name less_eq} T 
            (Const (@{const_name le_of_comp}, cmpT T --> T --> T --> @{typ bool}) $ cmp)
          val (less_thm, lthy) =  Generator_Aux.define_overloaded 
            ("less_" ^ Long_Name.base_name tname ^ "_def", less_def) lthy
          val (le_thm, lthy) = Generator_Aux.define_overloaded 
            ("less_eq_" ^ Long_Name.base_name tname ^ "_def", le_def) lthy         
        in (le_thm, less_thm, lthy) end)
  in
    (* the two class goals are solved by unfolding the definitions just made *)
    Class.prove_instantiation_exit (fn ctxt =>
      Class.intro_classes_tac ctxt []
      THEN unfold_tac ctxt [le_thm]
      THEN unfold_tac ctxt [less_thm]) lthy
  end

(* Derives compare_order for tname when compare is (or will be) built from the existing
   linear order (gen_type Linorder). No new definitions are made: the class axioms are
   proved by unfolding the facts collected in the named theorems compare_simps and
   applying le_lt_comparator_of(1) and (2). *)
fun compare_order_instance_via_comparator_of tname thy =
  let
    val gen_type = Comparator_Generator.Linorder
    val thy = maybe_instantiate_compare gen_type tname thy
    val xs = Generator_Aux.typ_and_vs_of_typname thy tname cmpS |> snd
    val lthy = Class.instantiation ([tname], xs, cmpoS) thy
  in
    Class.prove_instantiation_exit (fn ctxt =>
      Class.intro_classes_tac ctxt []
      THEN unfold_tac ctxt (Named_Theorems.get ctxt @{named_theorems compare_simps})
      THEN resolve_tac ctxt @{thms le_lt_comparator_of(1)} 1
      THEN resolve_tac ctxt @{thms le_lt_comparator_of(2)} 1) lthy
  end
  
(* Entry point for "derive (param) compare_order": refuses types that already are
   compare_order instances, then dispatches on the parameter: no parameter defines the
   orders from a (BNF-generated) comparator, "linorder" reuses the existing linear order. *)
fun compare_order_instance tname param thy =
  let
    val _ =
      if is_class_instance thy tname cmpoS
      then error ("type " ^ quote tname ^ " is already an instance of class \"compare_order\"")
      else ()
    val _ = writeln ("deriving \"compare_order\" instance for type " ^ quote tname)
  in
    (case param of
      "" => compare_order_instance_via_compare tname thy
    | "linorder" => compare_order_instance_via_comparator_of tname thy
    | _ => error "unknown parameter, supported are (no parameter) and \"linorder\"")
  end

(* Registers the three derive hooks (compare, compare_order, linorder) with the
   derive manager, in this order. *)
val _ =
  Theory.setup
    (fold (fn (name, descr, derive_fun) => Derive_Manager.register_derive name descr derive_fun)
      [("compare",
        "register types in class compare, options: (linorder) or ()",
        compare_instance_param),
       ("compare_order",
        "register types in class compare_order, options: (linorder) or ()",
        compare_order_instance),
       ("linorder",
        "register types in class linorder, options: (linorder) or ()",
        linorder_instance_param)])

end

Theory Compare_Instances

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
subsection ‹Defining Comparators and Compare-Instances for Common Types›

theory Compare_Instances
imports
  Compare_Generator
  "HOL-Library.Char_ord"
begin


text ‹For all of the following types, we define comparators and register them in the class 
  @{class compare}:
  @{type int}, @{type integer}, @{type nat}, @{type char}, @{type bool}, @{type unit}, @{type sum}, @{type option}, @{type list},
  and @{type prod}. We do not register those types in @{class compare_order} for which
  no linear order has been defined so far, in particular if there are conflicting orders, like pairwise or
  lexicographic comparison on pairs.›

text ‹For @{type int}, @{type nat}, @{type integer} and @{type char} we just use their linear orders as comparators.›
(* With parameter (linorder), the comparators are obtained from the existing linear orders,
   and the types are registered in class compare_order. *)
derive (linorder) compare_order int integer nat char

text ‹For @{type sum}, @{type list}, @{type prod}, and @{type option} we generate comparators 
  which, however, are not used to instantiate @{class linorder}.›
(* No parameter: the comparators are generated by the BNF-based comparator generator;
   only class compare is instantiated. *)
derive compare sum list prod option

text ‹We do not use the linear order to define the comparator for @{typ bool} and @{typ unit}, 
  but implement more efficient ones.›

(* Comparator for unit: always Eq. *)
fun comparator_unit :: "unit comparator" where
  "comparator_unit x y = Eq"

(* Comparator for bool by pattern matching: False is smaller than True. *)
fun comparator_bool :: "bool comparator" where
  "comparator_bool False False = Eq"
| "comparator_bool False True = Lt"
| "comparator_bool True True = Eq"
| "comparator_bool True False = Gt"

(* comparator_unit satisfies the comparator axioms. *)
lemma comparator_unit: "comparator comparator_unit"
  by (unfold_locales, auto)

(* comparator_bool satisfies the comparator axioms; each of the three goals produced by
   the initial proof step is shown by case analysis on the boolean arguments. *)
lemma comparator_bool: "comparator comparator_bool"
proof
  fix x y z :: bool
  show "invert_order (comparator_bool x y) = comparator_bool y x" by (cases x, (cases y, auto)+)
  show "comparator_bool x y = Eq ⟹ x = y" by (cases x, (cases y, auto)+)
  show "comparator_bool x y = Lt ⟹ comparator_bool y z = Lt ⟹ comparator_bool x z = Lt"
    by (cases x, (cases y, auto), cases y, (cases z, auto)+)
qed


(* Register the hand-written comparators (with their validity theorems) for unit and bool
   with the comparator generator, then derive the compare instances. *)
local_setup ‹
  Comparator_Generator.register_foreign_comparator @{typ unit}
    @{term comparator_unit}
    @{thm comparator_unit}
›

local_setup ‹
  Comparator_Generator.register_foreign_comparator @{typ bool}
    @{term comparator_bool}
    @{thm comparator_bool}
›

derive compare bool unit

text ‹It is not directly possible to ‹derive (linorder) bool unit›, since 
  @{term "compare :: bool comparator"}
  was not defined as @{term "comparator_of :: bool comparator"}, but as
  @{const comparator_bool}.
  However, we can manually prove this equivalence
  and then use this knowledge to prove the instance of @{class compare_order}.›

(* The hand-written comparators coincide with comparator_of. The [compare_simps] attribute
   makes these equations available to "derive (linorder) compare_order", which unfolds
   the compare_simps facts when proving the class axioms. *)
lemma comparator_bool_comparator_of [compare_simps]:
  "comparator_bool = comparator_of"
proof (intro ext)
  fix a b 
  show "comparator_bool a b = comparator_of a b"
    unfolding comparator_of_def
    by (cases a, (cases b, auto))
qed

lemma comparator_unit_comparator_of [compare_simps]:
  "comparator_unit = comparator_of"
proof (intro ext)
  fix a b 
  show "comparator_unit a b = comparator_of a b"
    unfolding comparator_of_def by auto
qed

derive (linorder) compare_order bool unit
end

Theory Compare_Order_Instances

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
subsection ‹Defining Compare-Order-Instances for Common Types›

theory Compare_Order_Instances
imports
  Compare_Instances
  "HOL-Library.List_Lexorder"
  "HOL-Library.Product_Lexorder"
  "HOL-Library.Option_ord"
begin

text ‹We now also instantiate class @{class compare_order} and not only @{class compare}.
  Here, we also prove that our definitions do not clash with existing orders on
  @{type list}, @{type option}, and @{type prod}.
  
  For @{type sum} we just define the linear orders via their comparator.›

(* No parameter: < and ≤ on sum are defined from the comparator (lt_of_comp / le_of_comp). *)
derive compare_order sum

(* The generated list comparator induces exactly the orders ≤ and < imported from
   List_Lexorder; both equations are shown by induction on the first list. *)
instance list :: (compare_order)compare_order
proof
  note [simp] = le_of_comp_def lt_of_comp_def comparator_of_def
  show "le_of_comp (compare :: 'a list comparator) = (≤)" 
    unfolding compare_list_def compare_is_comparator_of 
  proof (intro ext)
    fix xs ys :: "'a list"
    show "le_of_comp (comparator_list comparator_of) xs ys = (xs ≤ ys)"
    proof (induct xs arbitrary: ys)
      case (Nil ys)
      show ?case
        by (cases ys, simp_all)
    next
      case (Cons x xs yys) note IH = this
      thus ?case
      proof (cases yys)
        case Nil
        thus ?thesis by auto
      next
        case (Cons y ys)
        show ?thesis unfolding Cons
          using IH[of ys]
          by (cases x y rule: linorder_cases, auto)
      qed
    qed
  qed
  show "lt_of_comp (compare :: 'a list comparator) = (<)" 
    unfolding compare_list_def compare_is_comparator_of 
  proof (intro ext)
    fix xs ys :: "'a list"
    show "lt_of_comp (comparator_list comparator_of) xs ys = (xs < ys)"
    proof (induct xs arbitrary: ys)
      case (Nil ys)
      show ?case
        by (cases ys, simp_all)
    next
      case (Cons x xs yys) note IH = this
      thus ?case
      proof (cases yys)
        case Nil
        thus ?thesis by auto
      next
        case (Cons y ys)
        show ?thesis unfolding Cons
          using IH[of ys]
          by (cases x y rule: linorder_cases, auto)
      qed
    qed
  qed
qed

(* The generated pair comparator induces exactly the orders ≤ and < imported from
   Product_Lexorder; shown by case analysis on both pairs. *)
instance prod :: (compare_order, compare_order)compare_order
proof
  note [simp] = le_of_comp_def lt_of_comp_def comparator_of_def
  show "le_of_comp (compare :: ('a,'b)prod comparator) = (≤)" 
    unfolding compare_prod_def compare_is_comparator_of 
  proof (intro ext)
    fix xy1 xy2 :: "('a,'b)prod"
    show "le_of_comp (comparator_prod comparator_of comparator_of) xy1 xy2 = (xy1 ≤ xy2)"
      by (cases xy1, cases xy2, auto)
  qed
  show "lt_of_comp (compare :: ('a,'b)prod comparator) = (<)" 
    unfolding compare_prod_def compare_is_comparator_of 
  proof (intro ext)
    fix xy1 xy2 :: "('a,'b)prod"
    show "lt_of_comp (comparator_prod comparator_of comparator_of) xy1 xy2 = (xy1 < xy2)"
      by (cases xy1, cases xy2, auto)
  qed
qed

(* The generated option comparator induces exactly the orders ≤ and < imported from
   Option_ord; shown by case analysis on both arguments. *)
instance option :: (compare_order)compare_order
proof
  note [simp] = le_of_comp_def lt_of_comp_def comparator_of_def
  show "le_of_comp (compare :: 'a option comparator) = (≤)" 
    unfolding compare_option_def compare_is_comparator_of 
  proof (intro ext)
    fix xy1 xy2 :: "'a option"
    show "le_of_comp (comparator_option comparator_of) xy1 xy2 = (xy1 ≤ xy2)"
      by (cases xy1, (cases xy2, auto split: if_splits)+)
  qed
  show "lt_of_comp (compare :: 'a option comparator) = (<)" 
    unfolding compare_option_def compare_is_comparator_of 
  proof (intro ext)
    fix xy1 xy2 :: "'a option"
    show "lt_of_comp (comparator_option comparator_of) xy1 xy2 = (xy1 < xy2)"
      by (cases xy1, (cases xy2, auto split: if_splits)+)
  qed
qed

end

Theory Compare_Rat

(*  
    Author:      René Thiemann 
    License:     LGPL
*)
subsection ‹Compare Instance for Rational Numbers›

theory Compare_Rat
imports
  Compare_Generator
  HOL.Rat
begin
  
(* The comparator on rat is obtained from its existing linear order. *)
derive (linorder) compare_order rat

end

Theory Compare_Real

(*
    Author:      René Thiemann
    License:     LGPL
*)
subsection ‹Compare Instance for Real Numbers›

theory Compare_Real
imports
  Compare_Generator
  HOL.Real
begin
  
(* The comparator on real is obtained from its existing linear order. *)
derive (linorder) compare_order real

(* Symmetry of the real comparator as a simp rule. *)
lemma invert_order_compare_real[simp]: "⋀ x y :: real. invert_order (compare x y) = compare y x" 
  by (simp add: comparator_of_def compare_is_comparator_of)


end

Theory Equality_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Checking Equality Without "="›

theory Equality_Generator
imports
  "../Generator_Aux"
  "../Derive_Manager"
begin

(* Dummy type constructor, only used to state a generic type in the following text;
   it is hidden again afterwards via hide_type. *)
typedecl ('a,'b,'c,'z)type

text ‹In the following, we define a generator which for a given datatype @{typ "('a,'b,'c,'z)type"}
  constructs an equality-test function of type 
  @{typ "('a ⇒ 'a ⇒ bool) ⇒ ('b ⇒ 'b ⇒ bool) ⇒ ('c ⇒ 'c ⇒ bool) ⇒ ('z ⇒ 'z ⇒ bool) ⇒ 
    (('a,'b,'c,'z)type ⇒ ('a,'b,'c,'z)type ⇒ bool)"}.
  These functions are essential to synthesize conditional equality functions in the container framework,
  where a strict membership in the @{class equal}-class must not be enforced.
›

(* The dummy type was only needed to state the generic type in the text. *)
hide_type "type"

text ‹A constant for the conjunction of a list of booleans. It is used to combine
  the results of comparing the arguments of identical
  constructors.›

(* Conjunction of all booleans in a list. *)
definition list_all_eq :: "bool list ⇒ bool" where
  "list_all_eq = list_all id "

subsection ‹Improved Code for Non-Lazy Languages›

text ‹The following equations will eliminate all occurrences of @{term list_all_eq}
  in the generated code of the equality functions.›

(* Structural equations for list_all_eq; a singleton list yields its element directly,
   so no trailing "∧ True" appears in generated code. *)
lemma list_all_eq_unfold: 
  "list_all_eq [] = True"
  "list_all_eq [b] = b"
  "list_all_eq (b1 # b2 # bs) = (b1 ∧ list_all_eq (b2 # bs))"
  unfolding list_all_eq_def
  by auto

(* Set-based characterization of list_all_eq. *)
lemma list_all_eq: "list_all_eq bs ⟷ (∀ b ∈ set bs. b)" 
  unfolding list_all_eq_def list_all_iff by auto  

subsection ‹Partial Equality Property›

text ‹We require a partial property which can be used in inductive proofs.›

(* An equality test is a binary predicate. *)
type_synonym 'a equality = "'a ⇒ 'a ⇒ bool"

(* aeq behaves like (=) for the fixed left argument x. *)
definition pequality :: "'a equality ⇒ 'a ⇒ bool"
where
  "pequality aeq x ⟷ (∀ y. aeq x y ⟷ x = y)"

lemma pequalityD: "pequality aeq x ⟹ aeq x y ⟷ x = y"
  unfolding pequality_def by auto

lemma pequalityI: "(⋀ y. aeq x y ⟷ x = y) ⟹ pequality aeq x"
  unfolding pequality_def by auto


subsection ‹Global equality property›

(* aeq behaves like (=) for every left argument. *)
definition equality :: "'a equality ⇒ bool" where
  "equality aeq ⟷ (∀ x. pequality aeq x)"

lemma equalityD2: "equality aeq ⟹ pequality aeq x"
  unfolding equality_def by blast

lemma equalityI2: "(⋀ x. pequality aeq x) ⟹ equality aeq" 
  unfolding equality_def by blast
    
lemma equalityD: "equality aeq ⟹ aeq x y ⟷ x = y"
  by (rule pequalityD[OF equalityD2])

lemma equalityI: "(⋀ x y. aeq x y ⟷ x = y) ⟹ equality aeq"
  by (intro equalityI2 pequalityI)

(* A valid equality test is extensionally equal to (=). *)
lemma equality_imp_eq:
  "equality aeq ⟹ aeq = (=)" 
  by (intro ext, auto dest: equalityD)

lemma eq_equality: "equality (=)"
  by (rule equalityI, simp)

lemma equality_def': "equality f = (f = (=))" 
  using equality_imp_eq eq_equality by blast


subsection ‹The Generator›

ML_file ‹equality_generator.ML›

(* Hide the base name equalityI (only the qualified name remains accessible),
   presumably to avoid a clash with HOL's set-theoretic equalityI. *)
hide_fact (open) equalityI

end

File ‹equality_generator.ML›

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann
    License:     LGPL
*)
signature EQUALITY_GENERATOR =
sig

  (* information stored for each type constructor that has an equality *)
  type info =
   {map : term,                    (* take % x. x, if there is no map *)
    pequality : term,                  (* partial equality *)
    equality : term,                   (* full equality *)
    equality_def : thm option,         (* definition of equality, important for nesting *)
    map_comp : thm option,         (* compositionality of map, important for nesting *)
    partial_equality_thm : thm,       (* partial version of equality thm *)
    equality_thm : thm,               (* equality acomp ⟹ … ⟹ equality (full_comp acomp …) *)
    used_positions : bool list}    (* for each type parameter: whether it is used *)

  (* registers HOL equality (=) as the equality of the given type name;
     the type must be a type constant without type arguments *)
  val register_equality_of : string -> local_theory -> local_theory

  (* registers a user-provided equality for a type constant without type arguments *)
  val register_foreign_equality :
    typ -> (* type-constant without type-variables *)
    term -> (* equality for type *)
    thm -> (* equality thm for provided equality *)
    local_theory -> local_theory



  val register_foreign_partial_and_full_equality :
    string -> (* long type name *)
    term -> (* map function, should be λx. x, if there is no map *)
    term -> (* partial equality, presumably of type ('a ⇒ bool, …)ty ⇒ ('a,'b)ty ⇒ bool
      as built by pcompT, where 'a is used, 'b is unused *)
    term -> (* (full) equality of type ('a ⇒ 'a ⇒ bool) ⇒ ('a,'b)ty ⇒ ('a,'b)ty ⇒ bool,
      where 'a is used, 'b is unused *)
    thm option -> (* comp_def, should be full_comp = pcomp o map acomp ..., important for nesting *)
    thm option -> (* map compositionality, important for nesting *)
    thm -> (* partial eq thm for full equality *)
    thm -> (* full thm: equality a-comp => equality (full_comp a-comp) *)
    bool list -> (*used positions*)
    local_theory -> local_theory

  (* construction modes; the exact semantics are implemented outside this chunk *)
  datatype equality_type = EQ | BNF

  val generate_equalitys_from_bnf_fp :
    string ->                 (* name of type *)
    local_theory ->
    ((term * thm list) list * (* partial equalities + simp-rules *)
    (term * thm) list) *      (* non-partial equality + def_rule *)
    local_theory

  val generate_equality :
    equality_type ->
    string -> (* name of type *)
    local_theory -> local_theory

  (* looks up the stored info of a type name *)
  val get_info : Proof.context -> string -> info option

  (* ensures that the info will be available on later requests *)
  val ensure_info : equality_type -> string -> local_theory -> local_theory

end

structure Equality_Generator : EQUALITY_GENERATOR =
struct

open Generator_Aux

datatype equality_type = BNF | EQ

(* debug switch: while false, debug_out prints nothing *)
val debug = false
fun debug_out msg = if not debug then () else writeln msg

(* Type builders:
   compT ty  = ty ⇒ ty ⇒ bool          (full equality on ty)
   orderify  replaces every atomic type 'x by 'x ⇒ bool
   pcompT ty = orderify ty ⇒ ty ⇒ bool (partial equality on ty) *)
val boolT = @{typ bool}
fun compT ty = ty --> ty --> boolT
val orderify = map_atyps (fn aty => aty --> boolT)
fun pcompT ty = orderify ty --> ty --> boolT

(* Per-type-constructor record stored in Data; fields as described in the signature. *)
type info =
 {map : term,                   (* map function, identity if there is none *)
  pequality : term,             (* partial equality *)
  equality : term,              (* full equality *)
  equality_def : thm option,    (* definition of the full equality, for nesting *)
  map_comp : thm option,        (* map compositionality, for nesting *)
  partial_equality_thm : thm,
  equality_thm : thm,
  used_positions : bool list};  (* which type parameters are used *)

(* Generic context data: type name -> info. Merging two tables is only accepted when
   entries for the same type name carry the same equality term; otherwise Symtab.merge
   raises its duplicate-key exception. *)
structure Data = Generic_Data (
  type T = info Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  val merge = Symtab.merge (fn (info1 : info, info2 : info) => #equality info1 = #equality info2);
);

(* Adds the info for type name tyco; Symtab.update_new raises DUP if an entry exists. *)
fun add_info tyco info = Data.map (Symtab.update_new (tyco, info))

(* Looks up the info of a type name in a proof context. *)
fun get_info ctxt tyco = Symtab.lookup (Data.get (Context.Proof ctxt)) tyco

(* Like get_info, but fails with an error if nothing is stored for tyco. *)
fun the_info ctxt tyco =
  get_info ctxt tyco
  |> (fn SOME info => info
       | NONE => error ("no equality information available for type " ^ quote tyco))

(* Stores the info for tyco via a (non-syntax, non-pervasive) local-theory declaration;
   the declaration morphism phi is applied to every term and theorem of the record.
   add_info uses Symtab.update_new, so a second entry for the same tyco is rejected. *)
fun declare_info tyco m p c c_def m_comp p_thm c_thm used_pos =
  Local_Theory.declaration {syntax = false, pervasive = false} (fn phi =>
    add_info tyco
     {map = Morphism.term phi m,
      pequality = Morphism.term phi p,
      equality = Morphism.term phi c,
      equality_def = Option.map (Morphism.thm phi) c_def,
      map_comp = Option.map (Morphism.thm phi) m_comp,
      partial_equality_thm = Morphism.thm phi p_thm,
      equality_thm = Morphism.thm phi c_thm,
      used_positions = used_pos})

(* Public registration entry point; forwards to declare_info (the remaining arguments,
   used positions and the local theory, are passed through by partial application). *)
fun register_foreign_partial_and_full_equality tyco m p c c_def m_comp eq_thm c_thm =
  declare_info tyco m p c c_def m_comp eq_thm c_thm

(* Builders for "equality f" and "pequality f" terms; they take a context first
   (e.g. mk_equality lthy c), presumably for type inference. *)
val mk_equality = mk_infer_const @{const_name equality}
val mk_pequality = mk_infer_const @{const_name pequality}

(* The always-true predicate of type T ⇒ T ⇒ bool; used as a placeholder argument
   for type parameters that are not used. *)
fun default_comp T = absdummy T (absdummy T @{term True}) (*%_ _. True*)

(* Registers comp (with validity theorem comp_thm : equality comp) as the equality of
   the type constant T. The identity serves as map, comp is both the partial and the
   full equality, and the partial theorem is derived from comp_thm via equalityD2. *)
fun register_foreign_equality T comp comp_thm lthy =
  let
    val tyco =
      (case T of
        Type (name, []) => name
      | _ => error "expected type constant with no arguments")
    val pointwise_thm = @{thm equalityD2} OF [comp_thm]
    val identity_map = HOLogic.id_const T
  in
    lthy
    |> register_foreign_partial_and_full_equality
         tyco identity_map comp comp NONE NONE pointwise_thm comp_thm []
  end

(* Registers HOL equality (=) as the equality of type tyco, using eq_equality
   instantiated at that type as validity theorem. *)
fun register_equality_of tyco lthy =
  let
    val thy = Proof_Context.theory_of lthy
    val T = fst (typ_and_vs_of_typname thy tyco @{sort type})
    val eq_const = HOLogic.eq_const T
    val eq_thm = @{thm eq_equality} |> Thm.instantiate' [SOME (Thm.ctyp_of lthy T)] []
  in
    register_foreign_equality T eq_const eq_thm lthy
  end


fun generate_equalitys_from_bnf_fp tyco lthy =
  let
    val (tycos, Ts) = mutual_recursive_types tyco lthy
    val _ = map (fn tyco => "generating equality for type " ^ quote tyco) tycos
      |> cat_lines |> writeln
    val (tfrees, used_tfrees) = type_parameters (hd Ts) lthy
    val used_positions = map (member (op =) used_tfrees o TFree) tfrees
    val cs = map (subT "eq") used_tfrees
    val comp_Ts = map compT used_tfrees
    val arg_comps = map Free (cs ~~ comp_Ts)
    val dep_tycos = fold (add_used_tycos lthy) tycos []

    val XTys = Bnf_Access.bnf_types lthy tycos
    val inst_types = typ_subst_atomic (XTys ~~ Ts)
    val cTys = map (map (map inst_types)) (Bnf_Access.constr_argument_types lthy tycos)

    val map_simps = Bnf_Access.map_simps lthy tycos
    val case_simps = Bnf_Access.case_simps lthy tycos
    val maps = Bnf_Access.map_terms lthy tycos
    val map_comp_thms = Bnf_Access.map_comps lthy tycos

    val t_ixs = 0 upto (length Ts - 1)

    val compNs =
      (*TODO: clashes in presence of same type names in different theories*)
      map (Long_Name.base_name) tycos
      |> map (fn s => "equality_" ^ s)

    fun gen_vars prefix = map (fn (i, pty) => Free (prefix ^ ints_to_subscript [i], pty))
      (t_ixs ~~ Ts)

    (* primrec definitions of partial equalitys *)

    fun mk_pcomp (tyco, T) = ("partial_equality_" ^ Long_Name.base_name tyco, pcompT T)

    fun constr_terms lthy =
      Bnf_Access.constr_terms lthy
      #> map (apsnd (map freeify_tvars o fst o strip_type) o dest_Const)

    fun generate_pcomp_eqs lthy (tyco, T) =
      let
        val constrs = constr_terms lthy tyco

        fun comp_arg T x y =
          let
            val m = Generator_Aux.create_map default_comp (K o Free o mk_pcomp) () (K false)
              (#used_positions oo the_info) (#map oo the_info) (K o #pequality oo the_info)
              tycos ((K o K) ()) T lthy
            val p = Generator_Aux.create_partial () (K false)
              (#used_positions oo the_info) (#map oo the_info) (K o #pequality oo the_info)
              tycos ((K o K) ()) T lthy
          in p $ (m $ x) $ y |> infer_type lthy end

        fun generate_eq lthy (c_T as (cN, Ts)) =
          let
            val arg_Ts' = map orderify Ts
            val c = Const (cN, arg_Ts' ---> orderify T)
            val (y, (xs, ys)) = Name.variant "y" (Variable.names_of lthy) |>> Free o rpair T
              ||> (fn ctxt => Name.invent_names ctxt "x" (arg_Ts' @ Ts) |> map Free)
              ||> chop (length Ts)
            val k = find_index (curry (op =) c_T) constrs
            val cases = constrs |> map_index (fn (i, (_, Ts')) =>
              if i < k then fold_rev absdummy Ts' @{term False}
              else if k < i then fold_rev absdummy Ts' @{term False}
              else
                @{term list_all_eq} $ HOLogic.mk_list boolT (@{map 3} comp_arg Ts xs ys)
                |> lambdas ys)
            val lhs = Free (mk_pcomp (tyco, T)) $ list_comb (c, xs) $ y
            val rhs = list_comb (singleton (Bnf_Access.case_consts lthy) tyco, cases) $ y
          in HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)) |> infer_type lthy end
      in map (generate_eq lthy) constrs end

    val eqs = map (generate_pcomp_eqs lthy) (tycos ~~ Ts) |> flat
    val bindings = tycos ~~ Ts |> map mk_pcomp
      |> map (fn (name, T) => (Binding.name name, SOME T, NoSyn))

    val ((pcomps, pcomp_simps), lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> (BNF_LFP_Rec_Sugar.primrec false [] bindings
          (map (fn t => ((Binding.empty_atts, t), [], [])) eqs))
      |> Local_Theory.end_nested_result
          (fn phi => fn (pcomps, _, pcomp_simps) => (map (Morphism.term phi) pcomps, map (Morphism.fact phi) pcomp_simps))

    (* definitions of equalitys via partial equalitys and maps *)

    fun generate_comp_def tyco lthy =
      let
        val cs = map (subT "eq") used_tfrees
        val arg_Ts = map compT used_tfrees
        val args = map Free (cs ~~ arg_Ts)
        val (pcomp, m) = AList.lookup (op =) (tycos ~~ (pcomps ~~ maps)) tyco |> the
        val ts = tfrees |> map TFree |> map (fn T =>
          AList.lookup (op =) (used_tfrees ~~ args) T |> the_default (default_comp T))
        val rhs = HOLogic.mk_comp (pcomp, list_comb (m, ts)) |> infer_type lthy
        val abs_def = lambdas args rhs
        val name = "equality_" ^ Long_Name.base_name tyco
        val ((comp, (_, prethm)), lthy) =
          Local_Theory.define ((Binding.name name, NoSyn), (Binding.empty_atts, abs_def)) lthy
        val eq = Logic.mk_equals (list_comb (comp, args), rhs)
        val thm = Goal.prove lthy (map (fst o dest_Free) args) [] eq (K (unfold_tac lthy [prethm]))
      in
        Local_Theory.note ((Binding.name (name ^ "_def"), []), [thm]) lthy
        |>> the_single o snd
        |>> `(K comp)
      end

    val ((comps, comp_defs), lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> fold_map generate_comp_def tycos
      |>> split_list
      |> Local_Theory.end_nested_result
        (fn phi => fn (comps, comp_defs) => (map (Morphism.term phi) comps, map (Morphism.thm phi) comp_defs))

    (* alternative simp-rules for equalitys *)

    val full_comps = map (list_comb o rpair arg_comps) comps

    fun generate_comp_simps (tyco, T) lthy =
      let
        val constrs = constr_terms lthy tyco

        fun comp_arg T x y =
          let
            fun create_comp (T as TFree _) =
                  AList.lookup (op =) (used_tfrees ~~ arg_comps) T
                  |> the_default (HOLogic.id_const dummyT)
              | create_comp (Type (tyco, Ts)) =
                  (case AList.lookup (op =) (tycos ~~ comps) tyco of
                    SOME c => list_comb (c, arg_comps)
                  | NONE =>
                      let
                        val {equality = c, used_positions = up, ...} = the_info lthy tyco
                        val ts = (up ~~ Ts) |> map_filter (fn (b, T) =>
                          if b then SOME (create_comp T) else NONE)
                      in list_comb (c, ts) end)
              | create_comp T =
                  error ("unexpected schematic variable " ^ quote (Syntax.string_of_typ lthy T))
            val comp = create_comp T
          in comp $ x $ y |> infer_type lthy end

        fun generate_eq_thm lthy (c_T as (_, Ts)) =
          let
            val (xs, ctxt) = Variable.names_of lthy
              |> fold_map (fn T => Name.variant "x" #>> Free o rpair T) Ts
            fun mk_const (c, Ts) = Const (c, Ts ---> T)
            val comp_const = AList.lookup (op =) (tycos ~~ comps) tyco |> the
            val lhs = list_comb (comp_const, arg_comps) $ list_comb (mk_const c_T, xs)
            val k = find_index (curry (op =) c_T) constrs

            fun mk_eq c ys rhs =
              let
                val y = list_comb (mk_const c, ys)
                val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs $ y, rhs))
              in (ys, eq |> infer_type lthy) end

            val ((ys, eqs), _) = fold_map (fn (i, c as (_, Ts')) => fn ctxt =>
              let
                val (ys, ctxt) = fold_map (fn T => Name.variant "y" #>> Free o rpair T) Ts' ctxt
              in
                (if i < k then mk_eq c ys @{term False}
                else if k < i then mk_eq c ys @{term False}
                else
                  @{term list_all_eq} $ HOLogic.mk_list boolT (@{map 3} comp_arg Ts xs ys)
                  |> mk_eq c ys,
                ctxt)
              end) (tag_list 0 constrs) ctxt
              |> apfst (apfst flat o split_list)

            val dep_comp_defs = map_filter (#equality_def o the_info lthy) dep_tycos
            val dep_map_comps = map_filter (#map_comp o the_info lthy) dep_tycos
            val thms = prove_multi_future lthy (map (fst o dest_Free) (xs @ ys) @ cs) [] eqs
              (fn {context = ctxt, ...} =>
                Goal.conjunction_tac 1
                THEN unfold_tac ctxt
                  (@{thms id_apply o_def} @
                    flat case_simps @
                    flat pcomp_simps @
                    dep_map_comps @ comp_defs @ dep_comp_defs @ flat map_simps))
          in thms end

        val thms = map (generate_eq_thm lthy) constrs |> flat
        val simp_thms = map (Local_Defs.unfold lthy @{thms list_all_eq_unfold}) thms

        val name = "equality_" ^ Long_Name.base_name tyco
      in
        lthy
        |> Local_Theory.note ((Binding.name (name ^ "_simps"), @{attributes [simp, code]}), simp_thms)
        |> snd
        |> (fn lthy => (thms, lthy))
      end

    val (comp_simps, lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> fold_map generate_comp_simps (tycos ~~ Ts)
      |> Local_Theory.end_nested_result (fn phi => map (Morphism.fact phi))

    (* partial theorems *)

    val set_funs = Bnf_Access.set_terms lthy tycos
    val x_vars = gen_vars "x"
    val free_names = map (fst o dest_Free) (x_vars @ arg_comps)
    val xi_vars = map_index (fn (i, _) =>
      map_index (fn (j, pty) => Free ("x" ^ ints_to_subscript [i, j], pty)) used_tfrees) Ts

    fun mk_eq_thm' mk_eq' = map_index (fn (i, ((set_funs, x), xis)) =>
      let
        fun create_cond ((set_t, xi), c) =
          let
            val rhs = mk_eq' lthy c $ xi |> HOLogic.mk_Trueprop
            val lhs = HOLogic.mk_mem (xi, set_t $ x) |> HOLogic.mk_Trueprop
          in Logic.all xi (Logic.mk_implies (lhs, rhs)) end

        val used_sets = map (the o AList.lookup (op =) (map TFree tfrees ~~ set_funs)) used_tfrees
        val conds = map create_cond (used_sets ~~ xis ~~ arg_comps)
        val concl = mk_eq' lthy (nth full_comps i) $ x |> HOLogic.mk_Trueprop
      in Logic.list_implies (conds, concl) |> infer_type lthy end) (set_funs ~~ x_vars ~~ xi_vars)

    val induct_thms = Bnf_Access.induct_thms lthy tycos
    val set_simps = Bnf_Access.set_simps lthy tycos
    val case_thms = Bnf_Access.case_thms lthy tycos
    val distinct_thms = Bnf_Access.distinct_thms lthy tycos
    val inject_thms = Bnf_Access.inject_thms lthy tycos

    val rec_info = (the_info lthy, #used_positions, tycos)
    val split_IHs = split_IHs rec_info

    val unknown_value = false (* effect of choosing false / true not yet visible *)

    fun induct_tac ctxt f =
      ((DETERM o Induction.induction_tac ctxt false
        (map (fn x => [SOME (NONE, (x, unknown_value))]) x_vars) [] [] (SOME induct_thms) [])
      THEN_ALL_NEW (fn i =>
        Subgoal.SUBPROOF (fn {context = ctxt, prems = prems, params = iparams, ...} =>
          f (i - 1) ctxt prems iparams) ctxt i)) 1

    val recursor_tac = std_recursor_tac rec_info used_tfrees
      (fn info => #partial_equality_thm info)

    fun instantiate_IHs IHs pre_conds = map (fn IH =>
      OF_option IH (replicate (Thm.nprems_of IH - length pre_conds) NONE @ map SOME pre_conds)) IHs


    (* partial eq-theorem *)
    val _ = debug_out "Partial equality"
    val eq_thms' = mk_eq_thm' mk_pequality

    fun eq_solve_tac i ctxt IH_prems xs =
      let
        val (i, j) = ind_case_to_idxs cTys i
        val k = length IH_prems - length arg_comps
        val pre_conds = drop k IH_prems
        val IH = take k IH_prems
        val comp_simps = nth comp_simps i
        val case_thm = nth case_thms i
        val distinct_thms = nth distinct_thms i
        val inject_thms = nth inject_thms i
        val set_thms = nth set_simps i
      in
        (* after induction *)
        resolve_tac ctxt @{thms pequalityI} 1
        THEN Subgoal.FOCUS (fn focus =>
          let
            val y = #params focus |> hd
            val yt = y |> snd |> Thm.term_of
            val ctxt = #context focus
            val pre_cond = map (fn pre_cond => Local_Defs.unfold ctxt set_thms pre_cond) pre_conds
            val IH = instantiate_IHs IH pre_cond
            val xs_tys = map (fastype_of o Thm.term_of o snd) xs
            val IHs = split_IHs xs_tys IH

            fun sub_case_tac j' (ctxt, y_simp, _) =
              if j = j' then
                unfold_tac ctxt (y_simp @ comp_simps)
                THEN unfold_tac ctxt @{thms list_all_eq}
                THEN unfold_tac ctxt (@{thms in_set_simps} @ inject_thms @ @{thms refl_True})
                THEN conjI_tac @{thms conj_weak_cong} ctxt xs (fn ctxt' => fn k =>
                  resolve_tac ctxt @{thms pequalityD} 1
                  THEN recursor_tac pre_cond (nth xs_tys k) (nth IHs k) ctxt')
              else
                (* different constructors *)
                unfold_tac ctxt (y_simp @ distinct_thms @ comp_simps @ @{thms bool.simps})
          in
            mk_case_tac ctxt [[SOME yt]] case_thm sub_case_tac
          end
        ) ctxt 1
      end

    val eq_thms' = prove_multi_future lthy free_names [] eq_thms' (fn {context = ctxt, ...} =>
      induct_tac ctxt eq_solve_tac)
    val _ = debug_out (@{make_string} eq_thms')

    (* total theorems *)
    fun mk_eq_sym_trans_thm mk_eq_sym_trans compI2 compE2 thms' =
      let
        val conds = map (fn c => mk_eq_sym_trans lthy c |> HOLogic.mk_Trueprop) arg_comps
        val thms = map (fn i =>
           mk_eq_sym_trans lthy (nth full_comps i)
           |> HOLogic.mk_Trueprop
           |> (fn concl => Logic.list_implies (conds,concl)))
           t_ixs
        val thms = prove_multi_future lthy free_names [] thms (fn {context = ctxt, ...} =>
          ALLGOALS Goal.conjunction_tac
          THEN Method.intros_tac ctxt (@{thm conjI} :: compI2 :: thms') []
          THEN ALLGOALS (eresolve_tac ctxt [compE2]))
      in thms end

    val eq_thms = mk_eq_sym_trans_thm mk_equality @{thm equalityI2} @{thm equalityD2} eq_thms'


    val _ = debug_out "full equality thms"
    fun mk_comp_thm (i, e) =
      let
        val conds = map (fn c => mk_equality lthy c |> HOLogic.mk_Trueprop) arg_comps
        val nearly_thm = e

        val thm =
           mk_equality lthy (nth full_comps i)
           |> HOLogic.mk_Trueprop
           |> (fn concl => Logic.list_implies (conds, concl))
      in
        Goal.prove_future lthy free_names [] thm
          (K (resolve_tac lthy [nearly_thm] 1 THEN ALLGOALS (assume_tac lthy)))
      end
    val comp_thms = map_index mk_comp_thm eq_thms

    val (_, lthy) = fold_map (fn (thm, cname) =>
      Local_Theory.note ((Binding.name cname, []), [thm])) (comp_thms ~~ compNs) lthy

    val _ = debug_out (@{make_string} comp_thms)

    val (_, lthy) = fold_map (fn (thm, cname) =>
      Local_Theory.note ((Binding.name (cname ^ "_pointwise"), []), [thm])) (eq_thms' ~~ compNs) lthy

  in
    ((pcomps ~~ pcomp_simps, comps ~~ comp_defs), lthy)
    ||> fold (fn (((((((tyco, map), pcomp), comp), comp_def), map_comp) , peq_thm), comp_thm) =>
          declare_info tyco map pcomp comp (SOME comp_def) (SOME map_comp)
            peq_thm comp_thm used_positions)
         (tycos ~~ maps ~~ pcomps ~~ comps ~~ comp_defs ~~ map_comp_thms ~~ eq_thms' ~~ comp_thms)
  end

(* Set up equality support for tyco: either generate equality functions from the
   BNF datatype (BNF) or register the built-in equality (EQ).
   Fails if equality information for tyco is already registered. *)
fun generate_equality gen_type tyco lthy =
  if is_some (get_info lthy tyco) then
    error ("type " ^ quote tyco ^ " does already have a equality")
  else
    (case gen_type of
      EQ => register_equality_of tyco lthy
    | BNF => snd (generate_equalitys_from_bnf_fp tyco lthy))

(* Make sure that equality information for tyco is available, creating it with
   the given generation mode only if it is missing. *)
fun ensure_info gen_type tyco lthy =
  if is_some (get_info lthy tyco) then lthy
  else generate_equality gen_type tyco lthy

(* Derive-hook for "equality": the empty parameter selects BNF generation and
   "eq" selects the built-in equality; any other parameter is rejected as soon as
   the hook is applied to tyco and param. *)
fun generate_equality_cmd tyco param =
  let
    val gen_type =
      (case param of
        "" => BNF
      | "eq" => EQ
      | _ => error ("unknown parameter, expecting no parameter for BNF-datatypes, " ^
             "or \"eq\" for types where the built-in equality \"=\" should be used."))
  in Named_Target.theory_map (generate_equality gen_type tyco) end

(* register the hook under the name "equality" with the derive manager *)
val _ =
  generate_equality_cmd
  |> Derive_Manager.register_derive "equality"
       "generate an equality function, options are () and (eq)"
  |> Theory.setup

end

Theory Equality_Instances

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
subsection ‹Defining Equality-Functions for Common Types›

theory Equality_Instances
imports
  Equality_Generator
begin

text ‹For all of the following types, we register equality-functions.
  @{type int}, @{type integer}, @{type nat}, @{type char}, @{type bool}, @{type unit}, @{type sum}, @{type option}, @{type list},
  and @{type prod}. For types without type parameters, we use plain @{term "(=)"}, and for the 
  others we use generated ones. These functions will be essential when the generator is later
  invoked on types whose definitions use one of these types.›

(* types without type parameters: register the built-in equality (=) via the (eq) option *)
derive (eq) equality int integer nat char bool unit
(* types with type parameters: no option, i.e., generate equality functions from the datatype *)
derive equality sum list prod option

end

Theory Hash_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Generating Hash-Functions›

theory Hash_Generator
imports
  "../Generator_Aux"
  "../Derive_Manager"
  Collections.HashCode
begin

text ‹As usual, in the generator we use a dedicated function to combine the results
  from evaluating the hash-function of the arguments of a constructor, to deliver
  the global hash-value.›

fun hash_combine :: "hashcode list  hashcode list  hashcode" where
  "hash_combine [] [x] = x"
| "hash_combine (y # ys) (z # zs) = y * z + hash_combine ys zs"
| "hash_combine _ _ = 0"

text ‹The first argument of @{const hash_combine} originates from evaluating the hash-function 
  on the arguments of a constructor, and the second argument of @{const hash_combine} will be static \emph{magic} numbers
  which are generated within the generator.›

subsection ‹Improved Code for Non-Lazy Languages›

(* The first two defining equations of hash_combine as a separate fact; the generator
   unfolds the generated simp rules with it (hash_generator.ML). *)
lemma hash_combine_unfold: 
  "hash_combine [] [x] = x"
  "hash_combine (y # ys) (z # zs) = y * z + hash_combine ys zs" 
  by auto

subsection ‹The Generator›

ML_file ‹hash_generator.ML›

end

File ‹hash_generator.ML›

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
signature HASHCODE_GENERATOR =
sig

  (* information stored by the generator for each registered type constructor *)
  type info =
   {map : term,                    (* take % x. x, if there is no map *)
    phash : term,                  (* partial hash *)
    hash : term,                   (* full hash *)
    hash_def : thm option,         (* definition of hash, important for nesting *) 
    map_comp : thm option,         (* compositionality of map, important for nesting *)
    used_positions : bool list}    (* which type parameters are used *)

  (* registers a type that is already an instance of class hashable with the hash
     generator, using the class constant hashcode; the type must not have type
     arguments (otherwise register_foreign_hash reports an error) *)
  val register_hash_of : string -> local_theory -> local_theory

  (* registers the given function as both partial and full hash; the map is the identity *)
  val register_foreign_hash :
    typ -> (* type-constant without type-variables *)
    term -> (* hash-function for type *)
    local_theory -> local_theory


  val register_foreign_partial_and_full_hash :
    string -> (* long type name *)
    term -> (* map function, should be λx. x, if there is no map *)
    term -> (* partial hash-function of type (hashcode, 'b)ty => hashcode, 
      where 'a is used, 'b is unused *)
    term -> (* (full) hash-function of type ('a ⇒ hashcode) ⇒ ('a,'b)ty ⇒ hashcode,
      where 'a is used, 'b is unused *)
    thm option -> (* hash_def, should be full_hash = phash o map ahash ..., important for nesting *)
    thm option -> (* map compositionality, important for nesting *)
    bool list -> (*used positions*)
    local_theory -> local_theory

  (* HASHCODE: use the hashcode function of an existing hashable instance;
     BNF: generate hash functions from the datatype definition *)
  datatype hash_type = HASHCODE | BNF

  val generate_hashs_from_bnf_fp : 
    string ->                 (* name of type *)
    local_theory -> 
    ((term * thm list) list * (* partial hashs + simp-rules *)
    (term * thm) list) *      (* non-partial hash + def_rule *)
    local_theory

  (* fails if the type already has hash information *)
  val generate_hash : 
    hash_type -> 
    string -> (* name of type *)
    local_theory -> local_theory

  (* construct a hashable class instance for a datatype (derive hashable) *)
  val hashable_instance : string -> theory -> theory

  (* NONE if no hash information is registered for the type *)
  val get_info : Proof.context -> string -> info option

  (* ensures that the info will be available on later requests *)
  val ensure_info : hash_type -> string -> local_theory -> local_theory
    
end

structure Hashcode_Generator : HASHCODE_GENERATOR =
struct

open Generator_Aux

(* HASHCODE: use the hashcode function of an existing hashable instance;
   BNF: generate hash functions from the datatype definition.
   Constructor order matches the signature. *)
datatype hash_type = HASHCODE | BNF

(* the class constant hashcode, the class hashable, and the result type *)
val hash_name = @{const_name "hashcode"}
val hashS = @{sort hashable}
val hashT = @{typ hashcode}

(* T => hashcode *)
fun hashfunT T = T --> hashT

(* replace every atomic type (type variable) in T by hashcode *)
fun hashify T = map_atyps (K hashT) T

(* type of a partial hash: (hashified T) => hashcode *)
fun phashfunT T = hashfunT (hashify T)

(* all magic numbers are reduced modulo 2 ^ 31 *)
val hash_modulus = 2147483648

(* deterministic hash of a string, folding over its characters from left to right;
   the multipliers here and in create_factor are described by the authors as
   31-bit primes *)
fun int_of_string s =
  let
    fun step ch acc = (1792318057 * acc + Char.ord ch) mod hash_modulus
  in fold step (String.explode s) 0 end

(* magic factor for position i of constructor con_name of type ty_name *)
fun create_factor ty_name con_name i =
  let
    val mixed =
      1444315237 * int_of_string ty_name
      + 1336760419 * int_of_string con_name
      + 2044890737 * (i + 1)
  in mixed mod hash_modulus end

(* HOL list of length Ts + 1 magic numbers (positions 0 .. length Ts): one factor per
   constructor argument plus a final constructor-specific summand for hash_combine *)
fun create_hashes ty_name con_name Ts =
  (0 upto length Ts)
  |> map (HOLogic.mk_number hashT o create_factor ty_name con_name)
  |> HOLogic.mk_list hashT

(* value used for def_hashmap_size in derived hashable instances *)
fun create_def_size _ = 10

(* information stored per type constructor *)
type info =
 {map : term,                (* map function; identity if the type has no map *)
  phash : term,              (* partial hash *)
  hash : term,               (* full hash *)
  hash_def : thm option,     (* definition of hash, needed for nesting *)
  map_comp : thm option,     (* compositionality of map, needed for nesting *)
  used_positions : bool list};  (* which type parameters are used *)

(* generic data: info table indexed by type-constructor name; on merge, entries
   for the same key are accepted if their hash terms coincide (presumably a
   mismatch raises Symtab.DUP -- confirm) *)
structure Data = Generic_Data (
  type T = info Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  val merge = Symtab.merge (fn (info1 : info, info2 : info) => #hash info1 = #hash info2);
);

(* add the info for type constructor tyco; Symtab.update_new rejects existing keys *)
fun add_info tyco info = Data.map (Symtab.update_new (tyco, info))

(* look up the info registered for a type constructor, if any *)
fun get_info ctxt tyco = Symtab.lookup (Data.get (Context.Proof ctxt)) tyco

(* like get_info, but report an error when nothing is registered *)
fun the_info ctxt tyco =
  (case get_info ctxt tyco of
    NONE => error ("no hash_code information available for type " ^ quote tyco)
  | SOME info => info)

(* declare the info for tyco in the local theory; terms and theorems are
   transported along the declaration morphism phi *)
fun declare_info tyco m p c c_def m_hash used_pos =
  let
    fun mk_info phi =
      let
        val tm = Morphism.term phi
        val thm_opt = Option.map (Morphism.thm phi)
      in
        {map = tm m, phash = tm p, hash = tm c,
         hash_def = thm_opt c_def, map_comp = thm_opt m_hash,
         used_positions = used_pos}
      end
  in
    Local_Theory.declaration {syntax = false, pervasive = false}
      (fn phi => add_info tyco (mk_info phi))
  end

(* Register externally provided map, partial hash and full hash for tyco.
   The last two arguments are the used-positions list (bool list) and the local
   theory, as in the signature; they were previously misnamed eq_thm and c_thm. *)
fun register_foreign_partial_and_full_hash tyco m p c c_def m_hash used_pos lthy =
  declare_info tyco m p c c_def m_hash used_pos lthy

(* constant-zero hash on T (%_. 0); used as fallback, e.g. for unused type parameters *)
fun default_hash T = @{term "0 :: hashcode"} |> absdummy T

(* register hash as both partial and full hash for a type T without type arguments;
   the map function is the identity on T *)
fun register_foreign_hash T hash lthy =
  (case T of
    Type (tyco, []) =>
      register_foreign_partial_and_full_hash
        tyco (HOLogic.id_const T) hash hash NONE NONE [] lthy
  | _ => error "expected type constant with no arguments")

(* register the class constant hashcode of an existing hashable instance of tyco *)
fun register_hash_of tyco lthy =
  let
    val thy = Proof_Context.theory_of lthy
    val _ =
      if is_class_instance thy tyco hashS then ()
      else error ("type " ^ quote tyco ^ " is not an instance of class \"hashable\"")
    val T = fst (typ_and_vs_of_typname thy tyco @{sort type})
    val class_hash = Const (hash_name, hashfunT T)
  in register_foreign_hash T class_hash lthy end
                       

(* Generate hash functions for the BNF datatype tyco and all types of its
   mutual-recursion group:
   - partial hashes "partial_hash_code_<name>" defined by primrec,
   - full hashes "hash_code_<name>" defined as the partial hash composed with the
     map function applied to the argument hashes (fact "hash_code_<name>_def"),
   - rules "hash_code_<name>_simps" [simp, code] stated via hash_combine,
   and declare the resulting info for every type of the group.
   Returns ((partial hashes with their simp rules, full hashes with their
   definitions), updated local theory). *)
fun generate_hashs_from_bnf_fp tyco lthy =
  let
    (* presumably: the type constructors of the mutual-recursion group of tyco and
       their types (see Generator_Aux.mutual_recursive_types) *)
    val (tycos, Ts) = mutual_recursive_types tyco lthy
    val _ = map (fn tyco => "generating hash-function for type " ^ quote tyco) tycos
      |> cat_lines |> writeln
    (* all type parameters and the ones actually used *)
    val (tfrees, used_tfrees) = type_parameters (hd Ts) lthy
    val used_positions = map (member (op =) used_tfrees o TFree) tfrees
    (* one hash-function variable per used type parameter *)
    val cs = map (subT "h") used_tfrees
    val hash_Ts = map hashfunT used_tfrees
    val arg_hashs = map Free (cs ~~ hash_Ts)
    (* other type constructors the group uses; their hash_def/map_comp facts
       are needed for the simp-rule proofs below *)
    val dep_tycos = fold (add_used_tycos lthy) tycos []

    val map_simps = Bnf_Access.map_simps lthy tycos
    val case_simps = Bnf_Access.case_simps lthy tycos
    val maps = Bnf_Access.map_terms lthy tycos
    val map_comp_thms = Bnf_Access.map_comps lthy tycos
    

    (* primrec definitions of partial hashs *)

    (* name and type of the partial hash for (tyco, T) *)
    fun mk_phash (tyco, T) = ("partial_hash_code_" ^ Long_Name.base_name tyco, phashfunT T)

    (* constructors of a type as (name, argument types); presumably freeify_tvars
       turns schematic type variables into free ones *)
    fun constr_terms lthy =  
      Bnf_Access.constr_terms lthy 
      #> map (apsnd (map freeify_tvars o fst o strip_type) o dest_Const)

    (* one primrec equation per constructor C:
       phash (C x1 ... xn) = hash_combine [hashes of the xi] (create_hashes ...) *)
    fun generate_phash_eqs lthy (tyco, T) =
      let
        val constrs = constr_terms lthy tyco 

        (* hash of an argument x of type T: partial hash applied after mapping *)
        fun hash_arg T x =
          let
            val m = Generator_Aux.create_map default_hash (K o Free o mk_phash) () (K false)
              (#used_positions oo the_info) (#map oo the_info) (K o #phash oo the_info)
              tycos ((K o K) ()) T lthy
            val p = Generator_Aux.create_partial () (K false)
              (#used_positions oo the_info) (#map oo the_info) (K o #phash oo the_info)
              tycos ((K o K) ()) T lthy
          in p $ (m $ x) |> infer_type lthy end

        fun generate_eq lthy (cN, Ts) =
          let
            val arg_Ts' = map hashify Ts
            val c = Const (cN, arg_Ts' ---> hashify T)
            val xs = Name.invent_names (Variable.names_of lthy) "x" (arg_Ts') |> map Free
            val lhs = Free (mk_phash (tyco, T)) $ list_comb (c, xs)
            val rhs = @{term hash_combine} $ HOLogic.mk_list hashT (@{map 2} hash_arg Ts xs) $ create_hashes tyco cN Ts
          in HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)) |> infer_type lthy end
      in map (generate_eq lthy) constrs end

    val eqs = map (generate_phash_eqs lthy) (tycos ~~ Ts) |> flat
    val bindings = tycos ~~ Ts |> map mk_phash
      |> map (fn (name, T) => (Binding.name name, SOME T, NoSyn))
    (* define all partial hashes by one primrec in a nested target and export
       the constants and simp rules via the resulting morphism *)
    val ((phashs, phash_simps), lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> BNF_LFP_Rec_Sugar.primrec false [] bindings
          (map (fn t => ((Binding.empty_atts, t), [], [])) eqs)
      |> Local_Theory.end_nested_result
          (fn phi => fn (phashs, _, phash_simps) => (map (Morphism.term phi) phashs, map (Morphism.fact phi) phash_simps))

    (* definitions of hashs via partial hashs and maps *)

    (* hash_code_<name> h1 ... hk = phash o map ts, where ts uses the argument
       hashes for used type parameters and default_hash for unused ones *)
    fun generate_hash_def tyco lthy =
      let
        val cs = map (subT "h") used_tfrees
        val arg_Ts = map hashfunT used_tfrees
        val args = map Free (cs ~~ arg_Ts)
        val (phash, m) = AList.lookup (op =) (tycos ~~ (phashs ~~ maps)) tyco |> the
        val ts = tfrees |> map TFree |> map (fn T =>
          AList.lookup (op =) (used_tfrees ~~ args) T |> the_default (default_hash T))
        val rhs = HOLogic.mk_comp (phash, list_comb (m, ts)) |> infer_type lthy
        val abs_def = lambdas args rhs
        val name = "hash_code_" ^ Long_Name.base_name tyco
        val ((hash, (_, prethm)), lthy) =
          Local_Theory.define ((Binding.name name, NoSyn), (Binding.empty_atts, abs_def)) lthy
        (* applied form of the definition, noted as <name>_def *)
        val eq = Logic.mk_equals (list_comb (hash, args), rhs)
        val thm = Goal.prove lthy (map (fst o dest_Free) args) [] eq (K (unfold_tac lthy [prethm]))
      in
        Local_Theory.note ((Binding.name (name ^ "_def"), []), [thm]) lthy
        |>> the_single o snd
        |>> `(K hash)
      end
    val ((hashs, hash_defs), lthy) =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> fold_map generate_hash_def tycos
      |>> split_list
      |> Local_Theory.end_nested_result
          (fn phi => fn (hashs, hash_defs) => (map (Morphism.term phi) hashs, map (Morphism.thm phi) hash_defs))

    (* alternative simp-rules for hashs *)

    (* for each constructor C prove
         hash_code_<name> h1 ... hk (C xs) = hash_combine [hashes of xs] (create_hashes ...)
       by unfolding definitions; note the hash_combine_unfold-simplified versions
       as <name>_simps [simp, code] and return the unsimplified theorems *)
    fun generate_hash_simps (tyco, T) lthy =
      let
        val constrs = constr_terms lthy tyco

        (* full hash of an argument x of type T, built from the argument hashes,
           the hashes of this group, and registered infos of other types *)
        fun hash_arg T x =
          let
            fun create_hash (T as TFree _) =
                  AList.lookup (op =) (used_tfrees ~~ arg_hashs) T
                  |> the_default (HOLogic.id_const dummyT)
              | create_hash (Type (tyco, Ts)) =
                  (case AList.lookup (op =) (tycos ~~ hashs) tyco of
                    SOME c => list_comb (c, arg_hashs)
                  | NONE =>
                      let
                        val {hash = c, used_positions = up, ...} = the_info lthy tyco
                        val ts = (up ~~ Ts) |> map_filter (fn (b, T) =>
                          if b then SOME (create_hash T) else NONE)
                      in list_comb (c, ts) end)
              | create_hash T =
                  error ("unexpected schematic variable " ^ quote (Syntax.string_of_typ lthy T))
            val hash = create_hash T
          in hash $ x |> infer_type lthy end

        fun generate_eq_thm lthy (c_T as (cN, Ts)) =
          let
            val xs = Variable.names_of lthy
              |> fold_map (fn T => Name.variant "x" #>> Free o rpair T) Ts |> fst
            fun mk_const (c, Ts) = Const (c, Ts ---> T)
            val hash_const = AList.lookup (op =) (tycos ~~ hashs) tyco |> the
            val lhs = list_comb (hash_const, arg_hashs) $ list_comb (mk_const c_T, xs)
            val rhs = @{term hash_combine} $ HOLogic.mk_list hashT (@{map 2} hash_arg Ts xs) $ create_hashes tyco cN Ts
            val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)) |> infer_type lthy

            val dep_hash_defs = map_filter (#hash_def o the_info lthy) dep_tycos
            val dep_map_comps = map_filter (#map_comp o the_info lthy) dep_tycos
            val thms = prove_multi_future lthy (map (fst o dest_Free) xs @ cs) [] [eq]
              (fn {context = ctxt, ...} =>
                Goal.conjunction_tac 1
                THEN unfold_tac ctxt
                  (@{thms id_apply o_def} @
                    flat case_simps @
                    flat phash_simps @
                    dep_map_comps @ hash_defs @ dep_hash_defs @ flat map_simps))
          in thms end

        val thms = map (generate_eq_thm lthy) constrs |> flat
        val simp_thms = map (Local_Defs.unfold lthy @{thms hash_combine_unfold}) thms
        
        val name = "hash_code_" ^ Long_Name.base_name tyco
      in
        lthy
        |> Local_Theory.note ((Binding.name (name ^ "_simps"), @{attributes [simp, code]}), simp_thms)
        |> snd
        |> (fn lthy => (thms, lthy))
      end

    val lthy =
      lthy
      |> Local_Theory.begin_nested
      |> snd
      |> fold_map generate_hash_simps (tycos ~~ Ts)
      |> snd
      |> Local_Theory.end_nested

  in
    (* declare the info for every type of the group *)
    ((phashs ~~ phash_simps, hashs ~~ hash_defs), lthy)
    ||> fold (fn (((((tyco, map), phash), hash), hash_def), map_comp) =>
          declare_info tyco map phash hash (SOME hash_def) (SOME map_comp) 
            used_positions)
         (tycos ~~ maps ~~ phashs ~~ hashs ~~ hash_defs ~~ map_comp_thms)
  end

(* Register hash functions for tyco: either generated from the BNF datatype (BNF)
   or taken from an existing hashable instance (HASHCODE).
   Fails if hash information for tyco is already registered. *)
fun generate_hash gen_type tyco lthy =
  if is_some (get_info lthy tyco) then
    error ("type " ^ quote tyco ^ " does already have a hash")
  else
    (case gen_type of
      HASHCODE => register_hash_of tyco lthy
    | BNF => snd (generate_hashs_from_bnf_fp tyco lthy))

(* create the hash information for tyco only if it is not yet available *)
fun ensure_info gen_type tyco lthy =
  if is_some (get_info lthy tyco) then lthy
  else generate_hash gen_type tyco lthy

(* the registered full hash c of tname together with all of its argument types
   except the last one, i.e. the types of the hashes for the type parameters *)
fun dest_hash ctxt tname =
  (case get_info ctxt tname of
    NONE => error ("no hash info for type " ^ quote tname)
  | SOME {hash = c, ...} =>
      let
        val (arg_Ts, _) = strip_type (fastype_of c)
        val Ts = take (length arg_Ts - 1) arg_Ts
      in (c, Ts) end)

(* use rename_types to map the type arguments of the hashed datatype (the last
   argument type of hash) to free_types *)
fun all_tys hash free_types =
  let
    val (arg_Ts, _) = strip_type (fastype_of hash)
    val (_, Ts) = dest_Type (List.last arg_Ts)
  in rename_types (Ts ~~ free_types) end

(* c applied to the class constant hashcode at each of the types Ts *)
fun mk_hash_rhs c Ts =
  let val class_hashes = map (fn T => Const (hash_name, T)) Ts
  in list_comb (c, class_hashes) end

(* definitional equation: hashcode :: T => hashcode == rhs *)
fun mk_hash_def T rhs =
  let val lhs = Const (hash_name, hashfunT T)
  in Logic.mk_equals (lhs, rhs) end

(* Derive-hook "hashable": instantiate class hashable for the datatype tname.
   Hash functions are generated first if not yet available (ensure_info BNF).
   The class constant hashcode is defined as the generated hash applied to
   hashcode at the used type parameters, and def_hashmap_size is defined as the
   constant create_def_size tname.
   Fails if tname already is an instance of class hashable. *)
fun hashable_instance tname thy =
  let
    (* the class is hashable (hashS) *)
    val _ = is_class_instance thy tname hashS
      andalso error ("type " ^ quote tname ^ " is already an instance of class \"hashable\"")
    val _ = writeln ("deriving \"hashable\" instance for type " ^ quote tname)
    val thy = Named_Target.theory_map (ensure_info BNF tname) thy
    (* after ensure_info the info exists; the_info reports a readable error otherwise *)
    val {used_positions = us, ...} = the_info (Named_Target.theory_init thy) tname

    (* presumably only the used type parameters get sort hashable *)
    val (_, xs) = typ_and_vs_of_used_typname tname us hashS
    val (_, (hashs_thm,lthy)) =
      Class.instantiation ([tname], xs, hashS) thy
      |> (fn ctxt =>
        let
          val (c, Ts) = dest_hash ctxt tname
          val typ_mapping = all_tys c (map TFree xs)
          (* hashcode == c hashcode ... hashcode *)
          val hash_rhs = mk_hash_rhs c Ts
          val hash_def = mk_hash_def dummyT hash_rhs |> typ_mapping |> infer_type ctxt

          (* ty: the instantiated datatype, i.e. the domain of the defining right-hand side *)
          val ty = Term.fastype_of (snd (Logic.dest_equals hash_def)) |> Term.dest_Type |> snd |> hd
          val ty_it = Type (@{type_name itself}, [ty])
          val hashs_rhs = lambda (Free ("x",ty_it)) (HOLogic.mk_number @{typ nat} (create_def_size tname))
          val hashs_def = mk_def (ty_it --> @{typ nat}) @{const_name def_hashmap_size} hashs_rhs

          val basename = Long_Name.base_name tname
        in
          Generator_Aux.define_overloaded_generic
           ((Binding.name ("hashcode_" ^ basename ^ "_def"),
            @{attributes [code]}),
            hash_def) ctxt
          ||> define_overloaded ("def_hashmap_size_" ^ basename ^ "_def", hashs_def)
        end)
  in
    (* instance proof: unfold the def_hashmap_size definition and simplify *)
    Class.prove_instantiation_exit (fn ctxt =>
      Class.intro_classes_tac ctxt []
      THEN unfold_tac ctxt [hashs_thm]
      THEN simp_tac ctxt 1
      ) lthy
  end

(* Derive-hook "hash_code": the empty parameter selects BNF generation and
   "hashcode" reuses an existing class instance; any other parameter is rejected
   as soon as the hook is applied to tyco and param. *)
fun generate_hash_cmd tyco param =
  let
    val gen_type =
      (case param of
        "" => BNF
      | "hashcode" => HASHCODE
      | _ => error ("unknown parameter, expecting no parameter for BNF-datatypes, " ^
             "or \"hashcode\" for types where the class-instance hashcode should be used."))
  in Named_Target.theory_map (generate_hash gen_type tyco) end

(* register the hooks "hash_code" and "hashable" with the derive manager;
   the hashable hook ignores its parameter *)
val _ =
  let
    val hash_code_descr = "generate a hash function, options are () and (hashcode)"
    val hashable_descr = "register types in class hashable"
    fun derive_hashable tyname _ = hashable_instance tyname
  in
    Theory.setup
      (Derive_Manager.register_derive "hash_code" hash_code_descr generate_hash_cmd
       #> Derive_Manager.register_derive "hashable" hashable_descr derive_hashable)
  end

end

Theory Hash_Instances

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
subsection ‹Defining Hash-Functions for Common Types›

theory Hash_Instances
imports
  Hash_Generator
begin

text ‹For all of the following types, we register hashcode-functions.
  @{type int}, @{type integer}, @{type nat}, @{type char}, @{type bool}, @{type unit}, @{type sum}, @{type option}, @{type list},
  and @{type prod}. For types without type parameters, we use plain @{const "hashcode"}, and for the 
  others we use generated ones.›

(* types without type parameters: use the class constant hashcode via the (hashcode) option *)
derive (hashcode) hash_code int integer bool char unit nat

(* types with type parameters: no option, i.e., generate hash functions from the datatype *)
derive hash_code prod sum option list 

text ‹There is no need to ‹derive hashable prod sum option list›, since all of these types
  are already instances of class @{class hashable}. Still, the above command is necessary to register
  these types with the generator.›

end

Theory Countable_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Countable Datatypes›

theory Countable_Generator
imports 
  "HOL-Library.Countable"
  "../Derive_Manager"
begin

text ‹
Brian Huffman and Alexander Krauss (old datatype), and Jasmin Blanchette (BNF datatype)
have developed tactics that can automatically prove that a datatype is countable.
We just make this tactic available in the derive-manager so that
one can conveniently write \texttt{derive countable some-datatype}.
›

subsection "Installing the tactic"

text ‹
There is nothing more to do than to write some boilerplate ML code
for the class instantiation.
›

setup ‹
  let
    (* Derive-hook "countable": instantiate class countable for the datatype
       dtyp_name, requiring every type parameter to be countable as well.
       The derive parameter (second argument) is ignored. *)
    fun derive dtyp_name _ thy =
      let
        val base_name = Long_Name.base_name dtyp_name
        val _ = writeln ("proving that datatype " ^ base_name ^ " is countable")
        val sort = @{sort countable}
        (* all type parameters of the datatype get sort countable *)
        val vs =
          let val i = BNF_LFP_Compat.the_spec thy dtyp_name |> #1
          in map (fn (n,_) => (n, sort)) i end
        (* the instance proof is done by countable_tac *)
        val thy' = Class.instantiation ([dtyp_name],vs,sort) thy
          |> Class.prove_instantiation_exit (fn ctxt => countable_tac ctxt 1)
        val _ = writeln ("registered " ^ base_name ^ " in class countable")
      in thy' end
  in
    Derive_Manager.register_derive "countable" "register datatypes in class countable" derive
  end
›

end

Theory Derive

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section ‹Loading Existing Derive-Commands›
theory Derive
imports 
  "Comparator_Generator/Compare_Instances"
  "Equality_Generator/Equality_Instances"
  "Hash_Generator/Hash_Instances"
  "Countable_Generator/Countable_Generator"
begin

text‹
We just load the commands to derive comparators, equality-functions, hash-functions, and the
command to show that a datatype is countable, so that now all of them are available.
There are further generators available in the AFP entries Containers and Show.
›

(* diagnostic: list the currently registered derive hooks *)
print_derives

end

Theory Derive_Examples

(*  Title:       Deriving class instances for datatypes
    Author:      Christian Sternagel and René Thiemann  <christian.sternagel|rene.thiemann@uibk.ac.at>
    Maintainer:  Christian Sternagel and René Thiemann 
    License:     LGPL
*)
section Examples

theory Derive_Examples
imports 
  Derive
  "Comparator_Generator/Compare_Order_Instances"
  "Equality_Generator/Equality_Instances"
  HOL.Rat
begin

subsection "Rational Numbers"

text ‹The rational numbers are not a datatype, so it will not be possible to derive 
  corresponding instances of comparators, hashcodes, etc. via the generators. But we can and should
  still register the existing instances, so that later datatypes are supported 
  which use rational numbers.› 

text ‹Use the linear order on rationals to define the @{class compare_order}-instance.›
derive (linorder) compare_order rat

text ‹Use @{term "(=) :: rat => rat => bool"} as equality function.›
derive (eq) equality rat

text ‹First manually define a hashcode function.›

(* rat is not a datatype, so its hashable instance is written by hand:
   the hash of r is the hash of quotient_of r *)
instantiation rat :: hashable
begin
definition "def_hashmap_size = (λ_ :: rat itself. 10)"
definition "hashcode (r :: rat) = hashcode (quotient_of r)"
instance
  by (intro_classes)(simp_all add: def_hashmap_size_rat_def)
end

text ‹And then register it at the generator.›

(* (hashcode): register the class constant hashcode on rat with the hash generator *)
derive (hashcode) hash_code rat

subsection "A Datatype Without Nested Recursion"

(* directly recursive datatype: no nesting, no mutual recursion *)
datatype 'a bintree = BEmpty | BNode "'a bintree" 'a "'a bintree"

derive compare_order bintree
derive countable bintree
derive equality bintree
derive hashable bintree

subsection "Using Other Datatypes"

(* uses list, prod, option, nat and rat; these rely on the registrations made above
   and in the *_Instances theories *)
datatype nat_list_list = NNil | CCons "nat list × rat option" nat_list_list

derive compare_order nat_list_list
derive countable nat_list_list
(* (eq): use the built-in equality (=) instead of a generated equality function *)
derive (eq) equality nat_list_list
derive hashable nat_list_list

subsection "Mutual Recursion"

(* mutually recursive datatypes mtree and mtree_list *)
datatype
  'a mtree = MEmpty | MNode 'a "'a mtree_list" and
  'a mtree_list = MNil | MCons "'a mtree" "'a mtree_list"

derive compare_order mtree mtree_list
derive countable mtree mtree_list
derive hashable mtree mtree_list

text ‹For ‹derive (equality|comparator|hash_code) mutual_recursive_type›, 
  it suffices to mention only one of the mutually recursive types in
  order to register all of them. So mentioning either @{type mtree} or @{type mtree_list} is enough.›

(* also covers mtree_list, as explained in the text above *)
derive equality mtree 
 
subsection "Nested Recursion"

(* tree is nested through list; ttree is nested through tree and list *)
datatype 'a tree = Empty | Node 'a "'a tree list"
datatype 'a ttree = TEmpty | TNode 'a "'a ttree list tree"

derive compare_order tree ttree
derive countable tree ttree
derive equality tree ttree
derive hashable tree ttree

subsection ‹Examples from \isafor›

(* the type name term is quoted here and below (presumably because it clashes
   with the outer-syntax keyword term) *)
datatype ('f,'v) "term" = Var 'v | Fun 'f "('f,'v) term list"
datatype ('f, 'l) lab =
  Lab "('f, 'l) lab" 'l
| FunLab "('f, 'l) lab" "('f, 'l) lab list"
| UnLab 'f
| Sharp "('f, 'l) lab"

derive compare_order "term" lab
derive countable "term" lab
derive equality "term" lab
derive hashable "term" lab

subsection "A Complex Datatype"
text ‹
The following datatype has nested and mutual recursion, and
uses other datatypes.
›

(* mutual recursion (complex/complex2) combined with nesting through list, tree,
   ttree, sum and prod, and use of nat and rat *)
datatype ('a, 'b) complex = 
  C1 nat "'a ttree × rat + ('a,'b) complex list" |
  C2 "('a, 'b) complex list tree tree" 'b "('a, 'b) complex" "('a, 'b) complex2 ttree list"
and ('a, 'b) complex2 = D1 "('a, 'b) complex ttree"

text ‹On this last example type we illustrate the differences between the various comparator- and order-generators.

  For @{type complex} we create an instance of @{class compare_order} which also defines
  a linear order. Note however that the instance will 
  be @{type complex} :: (@{class compare}, @{class compare}) @{class compare_order}, i.e., the 
  argument types have to be in class @{class compare}. 

  For @{type complex2} we only derive @{class compare} which is not a subclass of @{class linorder}.
  The instance will be @{type complex2} :: (@{class compare}, @{class compare}) @{class compare}, i.e., 
  again the argument types have to be in class @{class compare}.

  To avoid the dependence on @{class compare}, we can also instruct ‹derive› to be based on 
  @{class linorder}. Here, the command ‹derive linorder complex2› will create the instance
  @{type complex2} :: (@{class linorder}, @{class linorder}) @{class linorder}, i.e., 
  here the argument types have to be in class @{class linorder}.
  ›
(* compare_order: comparator plus linear order, arguments in class compare *)
derive compare_order complex 
(* compare only: arguments in class compare *)
derive compare complex2
(* linorder: arguments in class linorder instead of compare *)
derive linorder complex2
derive countable complex complex2
derive equality complex
derive hashable complex complex2

end