Session Datatype_Order_Generator

Theory Derive_Aux

(* Loads the ML structure Derive_Aux (file derive_aux.ML) on top of the derive manager. *)
theory Derive_Aux
imports 
  Deriving.Derive_Manager
begin

ML_file ‹derive_aux.ML›

end

File ‹derive_aux.ML›

(* auxiliary functions which might be useful when deriving something for
   datatypes via recursors *)

signature DERIVE_AUX =
sig
  (* split_last [a,...,z] -> ([a,...,y],z) *)
  val split_last : 'a list -> 'a list * 'a

  (* p1 ⟶ p2 ⟶ … ⟶ p_n ⟶ r *)
  val HOLogic_list_implies : term list * term -> term

  (* p1 /\ ... /\ pn *)
  val HOLogic_list_conj : term list -> term

  (* ∀ x1 ... xn . p, where x1 ... xn have to be free variables *)
  val HOLogic_list_all : term list * term -> term

  (* rulifies P1 .. Pn in a thm P1 ==> ... ==> Pn ==> Q, the conclusion Q is not rulified *)
  val rulify_only_asm : Proof.context -> thm -> thm

(* takes a list of implications, an induction theorem, and a tactic to handle each case,
   and delivers the major implication.
   Example: if imp_list is [([p1 x y, q1 x y], [x,y]), ([p2 x' y', q2 x' y'], [x',y'])]
            and ind_thm is ... ==> P1 x /\ P2 x'
            then it encodes p1 x y ==> q1 x y &&& p2 x' y' ==> q2 x' y'
            which is converted to HOL (! y. p1 x y --> q1 x y) /\ (! y'. p2 x' y' --> q2 x' y')
            which is proven by the instantiated ind_thm where, e.g. P1 = % x. (! y. p1 x y --> q1 x y)
            and where in the IH all HOL-constructs are rulified again.
            As a result, only the first implication is returned: "p1 x y ==> q1 x y"
   Purpose: encountered problems with induct-tac
     - P1 and P2 are too large for internal unification, so they must be provided
     - if P1 and P2 are provided, then one has to use HOL-constructs (for arbitrary choice of y and y')
     - in induct_tac, then the IH are not converted to nice meta-implications/quantifications
*)
  val inductive_thm : theory -> (term list * term list) list -> thm -> sort ->
      (*                idx    ih_hyps     ih_prems    case_vars    arbi_vars *)
      (Proof.context -> int -> thm list -> thm list -> term list -> term list -> tactic) -> thm

(* delivers a typ substitution which constrains all free type variables in datatype by sort *)
  val typ_subst_for_sort : theory -> Old_Datatype_Aux.info -> sort -> typ -> typ

(* delivers a full type from a type name by instantiating the type-variables of that
   type with different variables of a given sort, also returns the chosen variables
   as second component *)
  val typ_and_vs_of_typname : theory -> string -> sort -> typ * (string * sort) list

(* identify and number recursive occurrences of datatypes:
   the result is the count of recursive occurrences together with a list of pairs
   (index of the recursive datatype, number of the occurrence) *)
  val dt_number_recs : Old_Datatype_Aux.dtyp list -> int * (int * int) list

(* like print_tac, but is turned off by default to not exceed tracing limit *)
  val my_print_tac : Proof.context -> string -> tactic

(* generates a theorem over two variables, where induction over the first one is performed,
   and then in every case one performs immediately a case analysis on the second variable.
   It is assumed that if the constructors are different, then the goal is proven by some
   standard tactic, whereas for same constructors, one has to provide a tactic *)
  val mk_binary_thm :
    (theory -> Old_Datatype_Aux.info -> sort -> 'a -> (term list * term list) list) -> (* mk_prop_trm *)
    (theory -> Old_Datatype_Aux.info -> sort -> (int -> term) * ('b * int * int) list) -> (* mk_bin_idx *)
    string -> (* bin_const_name *)
    theory ->
    Old_Datatype_Aux.info ->
    'a -> (* property_generator *)
    sort ->
    (* same_constructor_tac *)
    (Proof.context -> thm list -> thm list -> thm -> (Proof.context -> thm list -> tactic) -> int -> term list ->
     term list -> string * Old_Datatype_Aux.dtyp list -> (Old_Datatype_Aux.dtyp -> term -> term -> term) -> tactic)
    -> thm

(* performs a case analysis on the first subgoal by means of the given exhaust theorem
   and afterwards runs the given tactic on each of the resulting cases;
   the case index handed to that tactic starts at 0 *)
  val mk_case_tac : Proof.context ->
    term option list list -> (* usually [[SOME term_to_perform_the_case]] *)
    thm -> (* exhaust theorem *)
    (*               i-th case, prems, newly obtained arguments *)
    (Proof.context * int * thm list * (string * cterm) list -> tactic)
    -> tactic

(* takes the first entry ([p1,...,pn], v) of the list and
   delivers (p1 ==> ... ==> pn, v), where each p_i is wrapped into Trueprop *)
  val prop_trm_to_major_imp : (term list * 'a) list -> term * 'a

(* delivers "x_i" of corresponding datatype of idx-th type for datatype *)
(*                                                 idx    i *)
  val mk_xs : theory -> Old_Datatype_Aux.info -> sort -> int -> int -> term

(* create Some t *)
  val mk_Some : term -> term

(* my_simp_set should be HOL_ss + the other simplification stuff for orders like simprocs, ... *)
  val my_simp_set : simpset

(* inserts the theorems into the first subgoal and then applies the given tactic,
   the combination is wrapped into SOLVE *)
  val mk_solve_with_tac : Proof.context -> thm list -> tactic -> tactic

(* like define_overloaded_generic, where the binding consists of the given name
   and the [code] attribute *)
  val define_overloaded : (string * term) -> local_theory -> thm * local_theory

(* takes an equation c == rhs where c is a free variable, defines c via Local_Theory.define,
   and delivers the defining theorem, exported into the global theory context *)
  val define_overloaded_generic : (Attrib.binding * term) -> local_theory -> thm * local_theory

(* mk_def T c rhs creates the equation Const (c,T) == rhs *)
  val mk_def : typ -> string -> term -> term

end


structure Derive_Aux : DERIVE_AUX =
struct

(* global switch: set to true to get tracing output of the tactics below *)
val printing = false

(* print_tac if tracing is switched on, otherwise a tactic which leaves the state unchanged *)
fun my_print_tac ctxt msg =
  if printing then print_tac ctxt msg else Seq.single

(* split_last [a,...,z] = ([a,...,y],z); raises Empty on the empty list *)
fun split_last [] = raise List.Empty
  | split_last [x] = ([], x)
  | split_last (x :: xs) =
      let val (front, last) = split_last xs
      in (x :: front, last) end

(* FIXME: reconsolidate with similar functions in the Isabelle repository and move to HOLogic *)
(* ([p1,...,pn], r) is turned into the HOL-implication p1 ⟶ ... ⟶ pn ⟶ r *)
fun HOLogic_list_implies (prems, concl) =
  fold_rev (fn prem => fn acc => HOLogic.mk_imp (prem, acc)) prems concl
(* [p1,...,pn] is turned into the HOL-conjunction p1 ∧ ... ∧ pn.
   The empty conjunction is the HOL-constant True
   (the identifier "true" would be parsed as a free variable, not as the constant). *)
fun HOLogic_list_conj [] = @{term True}
  | HOLogic_list_conj [x] = x
  | HOLogic_list_conj (x :: xs) = HOLogic.mk_conj (x, HOLogic_list_conj xs)
(* ([x1,...,xn], p) is turned into ∀ x1 ... xn. p; every x_i has to be a free variable *)
fun HOLogic_list_all (xs, body) =
  let
    val frees = map dest_Free xs
    fun quantify (name, ty) inner = HOLogic.mk_all (name, ty, inner)
  in
    fold_rev quantify frees body
  end

(* wraps a term t into the option-constructor: Some t *)
fun mk_Some t =
  let
    val elemT = fastype_of t
    val optionT = Type (@{type_name option}, [elemT])
    val some = Const (@{const_name Some}, elemT --> optionT)
  in
    some $ t
  end

(* rulifies the assumptions of a theorem, but not its conclusion *)
fun rulify_only_asm ctxt thm =
  let
    (* put the theorem below a conjunction so that rulify does not touch the conclusion *)
    val guarded = @{thm conjI[OF TrueI]} OF [thm]
    (* rulify everything, which now only affects the assumptions *)
    val rulified = Object_Logic.rulify ctxt guarded
  in
    (* remove the auxiliary conjunction again *)
    @{thm conjunct2} OF [rulified]
  end

(* Orders the instantiation terms ps @ xs according to the order in which the schematic
   variables occur in the proposition of ind_thm. The variables in the conclusion of
   ind_thm are expected to alternate between a predicate P and its argument x:
   the i-th predicate is mapped to nth ps i and the i-th argument to nth xs i. *)
fun permute_for_ind_thm ps xs ind_thm =
  let
    val offset = length ps
    fun vars_of t = rev (Term.add_vars t [])
    val prop_vars = vars_of (Thm.prop_of ind_thm)
    fun number_ih_vars i vs =
      (case vs of
        [] => []
      | P :: x :: rest => (P, i) :: (x, i + offset) :: number_ih_vars (i + 1) rest
      | _ => error "odd number of vars in ind-thm")
    val concl_positions = number_ih_vars 0 (vars_of (Thm.concl_of ind_thm))
    fun position_of v = the (AList.lookup (op =) concl_positions v)
    val permutation = map position_of prop_vars
    val candidates = ps @ xs
  in
    map (fn pos => nth candidates pos) permutation
  end


(* Proves the conjunction of all implications in imp_list by induction with ind_thm
   and delivers the first implication in meta-logic form.
   Each entry (imps, xs) of imp_list consists of the premises followed by the conclusion (imps)
   and of the induction variable followed by the arbitrary variables (xs).
   ind_tac is invoked for every case of the induction with the 0-based case index,
   the induction hypotheses, the premises of the implication, the variables of the case,
   and the arbitrary variables. The sort argument is not used. *)
fun inductive_thm thy (imp_list : (term list * term list) list) ind_thm sort ind_tac =
  let
    (* one HOL-formula per entry: ∀ arbitrary vars. p1 ⟶ ... ⟶ q *)
    val imps = map
      (fn (imps,xs) => HOLogic_list_all (tl xs, HOLogic_list_implies (split_last imps)))
      imp_list
    val ind_term =
      HOLogic_list_conj imps
      |> HOLogic.mk_Trueprop
    (* number of premises and of arbitrary variables, taken from the first entry *)
    val nr_prems = length (hd imp_list |> fst) - 1
    val nr_arbi = length (hd imp_list |> snd) - 1
    (* the induction variables and the predicates % x. imp for the instantiation of ind_thm *)
    val xs = map (snd #> hd) imp_list
    val ps = xs ~~ imps
      |> map (fn (x,imp) => lambda x imp)
    val insts = permute_for_ind_thm ps xs ind_thm
    val xs_strings = map (dest_Free #> fst) xs
    val conjunctive_thm = Goal.prove_global_future thy xs_strings [] ind_term
      (fn {context = ctxt, ...} =>
        let
          val ind_thm_inst = infer_instantiate' ctxt (map (SOME o Thm.cterm_of ctxt) insts) ind_thm
          val ind_thm' = rulify_only_asm (Proof_Context.init_global thy) ind_thm_inst
        in
          (DETERM o Induct.induct_tac ctxt false [] [] [] (SOME [ind_thm']) [])
          THEN_ALL_NEW
          (fn i => Subgoal.SUBPROOF
            (fn {context = ctxt, prems = prems, params = iparams, ...} =>
              let
                (* the last nr_prems premises stem from the implication, the others are IHs *)
                val m = length prems - nr_prems
                val ih_prems = drop m prems
                val ih_hyps = take m prems
                val tparams = map (snd #> Thm.term_of) iparams
                (* the last nr_arbi parameters are the arbitrary variables *)
                val m' = length tparams - nr_arbi
                val arbi_vars = drop m' tparams
                val case_vars = take m' tparams
              in
                ind_tac ctxt (i-1) ih_hyps ih_prems case_vars arbi_vars
              end
            ) ctxt i
          )
        end 1
      )
    (* extract first conjunct *)
    val first_conj = if length imp_list > 1 then @{thm conjunct1} OF [conjunctive_thm] else conjunctive_thm
    (* and replace ⟶ and ∀ by meta-logic (for those ⟶ and ∀ which have been constructed) *)
    val elim_spec = funpow nr_arbi (fn thm => @{thm spec} OF [thm]) first_conj
    val elim_imp = funpow nr_prems (fn thm => @{thm mp} OF [thm]) elim_spec
  in elim_imp end

(* type substitution which replaces the sort of every type variable in the specification
   of the first datatype of info by the given sort *)
fun typ_subst_for_sort thy info sort =
  let
    val (_, (dty_name, _, _)) = hd (#descr info)
    val (tfrees, _) = BNF_LFP_Compat.the_spec thy dty_name
    fun constrain (name, old_sort) = (TFree (name, old_sort), TFree (name, sort))
  in
    Term.typ_subst_atomic (map constrain tfrees)
  end

(* builds the type typ_name applied to fresh type variables of the given sort,
   one per argument of the type constructor; the variables are returned as second component *)
fun typ_and_vs_of_typname thy typ_name sort =
  let
    val arity = Sign.arity_number thy typ_name
    val sorts = map (fn _ => sort) (1 upto arity)
    val names = Name.make_context [typ_name]
    val tvars = Name.invent_names names "a" sorts
  in
    (Type (typ_name, map TFree tvars), tvars)
  end


(* Numbers the recursive occurrences (DtRec) in a list of datatype descriptors from left
   to right, starting with 0 and descending into the arguments of DtType.
   Delivers the number of occurrences and the list of pairs
   (index of the DtRec, number of the occurrence). *)
fun dt_number_recs dtys =
  let
    (* next is the number for the next recursive occurrence *)
    fun number next [] = (next, [])
      | number next (dty :: rest) =
          (case dty of
            Old_Datatype_Aux.DtTFree _ => number next rest
          | Old_Datatype_Aux.DtRec i =>
              let
                val (next', pairs) = number (next + 1) rest
              in (next', (i, next) :: pairs) end
          | Old_Datatype_Aux.DtType (_, args) =>
              let
                val (next', inner) = number next args
                val (next'', outer) = number next' rest
              in (next'', inner @ outer) end)
  in
    number 0 dtys
  end

(* code copied from HOL/SPARK/TOOLS *)
(* eq has to be of the form c == rhs where c is a free variable: c is defined via
   Local_Theory.define and the resulting theorem is exported to the global theory context *)
fun define_overloaded_generic (binding, eq) lthy =
  let
    val (lhs, rhs) = Logic.dest_equals (Syntax.check_term lthy eq)
    val (c, _) = dest_Free lhs
    val ((_, (_, def_thm)), lthy') =
      Local_Theory.define ((Binding.name c, NoSyn), (binding, rhs)) lthy
    val thy_ctxt = Proof_Context.init_global (Proof_Context.theory_of lthy')
    val exported = singleton (Proof_Context.export lthy' thy_ctxt) def_thm
  in
    (exported, lthy')
  end

(* define_overloaded_generic, where the theorem is stored under the given name with attribute [code] *)
fun define_overloaded (name, eq) =
  let
    val code_binding = (Binding.name name, @{attributes [code]})
  in
    define_overloaded_generic (code_binding, eq)
  end


(* the meta-equation Const (c,T) == rhs *)
fun mk_def T c rhs =
  let
    val lhs = Const (c, T)
  in
    Logic.mk_equals (lhs, rhs)
  end

(* construct free variable x_i, where typ_subst is applied on the given type *)
fun mk_free_tysubst_i typ_subst x i ty =
  let
    val name = x ^ string_of_int i
  in
    Free (name, typ_subst ty)
  end


(* free variable "x_<idx>_<i>" whose type is the idx-th datatype of info,
   with type variables constrained by sort *)
fun mk_xs thy info sort idx i =
  let
    val typ_subst = typ_subst_for_sort thy info sort
    val rec_ty =
      typ_subst (Old_Datatype_Aux.typ_of_dtyp (#descr info) (Old_Datatype_Aux.DtRec idx))
    val prefix = "x_" ^ string_of_int idx ^ "_"
  in
    mk_free_tysubst_i typ_subst prefix i rec_ty
  end

(* takes the first entry ([p1,...,pn], v) and delivers (p1 ==> ... ==> pn, v),
   where every p_i is wrapped into Trueprop *)
fun prop_trm_to_major_imp prop =
  let
    val (ps, v) = hd prop
    val major = Logic.list_implies (split_last (map HOLogic.mk_Trueprop ps))
  in
    (major, v)
  end


(* case analysis on subgoal 1 via the exhaust theorem thm; afterwards sub_case_tac is invoked
   on every case with the context, the 0-based case index, the premises and the parameters *)
fun mk_case_tac (ctxt : Proof.context)
  (insts : term option list list)
  (thm : thm)
  (sub_case_tac : Proof.context * int * thm list * (string * cterm) list -> tactic) =
  let
    val split_tac = DETERM o Induct.cases_tac ctxt false insts (SOME thm) []
    fun handle_case i =
      Subgoal.SUBPROOF
        (fn {context = ctxt', prems = hyps, params = params, ...} =>
          sub_case_tac (ctxt', i - 1, hyps, params))
        ctxt i
  in
    (split_tac THEN_ALL_NEW handle_case) 1
  end

(* inserts thms into subgoal 1, then runs solver_tac; the combination is wrapped into SOLVE *)
fun mk_solve_with_tac ctxt thms solver_tac =
  let
    val insert_facts = Method.insert_tac ctxt thms 1
  in
    SOLVE (insert_facts THEN solver_tac)
  end

(* the case-, recursor-, injectivity- and distinctness-theorems of a datatype, in this order *)
fun simps_of_info info =
  List.concat [#case_rewrites info, #rec_rewrites info, #inject info, #distinct info]

(* simpset of the current context where all simp-rules are deleted
   and only HOL.simp_thms are added afterwards *)
val my_simp_set =
  let
    val default_simps = map snd (#simps (dest_ss (simpset_of @{context})))
    val stripped = @{context} delsimps default_simps
  in
    simpset_of (stripped addsimps @{thms HOL.simp_thms})
  end

(* Proves a property over two variables x and y: induction on x via inductive_thm, and in
   every case of the induction a case analysis on y (the first arbitrary variable).
   If the constructors of x and y differ, the goal is handed to force_tac with the
   datatype simp-rules; for identical constructors same_constructor_tac is invoked. *)
fun mk_binary_thm mk_prop_trm mk_bin_idx bin_const_name thy (info : Old_Datatype_Aux.info) prop_gen sort same_constructor_tac =
  let
    fun bin_const ty = Const (bin_const_name, ty --> ty --> @{typ bool})
    val prop_props = mk_prop_trm thy info sort prop_gen
    val (mk_rec,nrec_args) = mk_bin_idx thy info sort
    val typ_subst = typ_subst_for_sort thy info sort
    val descr = #descr info
    fun typ_of dty = Old_Datatype_Aux.typ_of_dtyp descr dty |> typ_subst
    (* the binary relation on a type: the recursor term for recursive types,
       the constant bin_const_name otherwise *)
    fun mk_binary_term (Old_Datatype_Aux.DtRec i) = mk_rec i
      | mk_binary_term dty =
          let
            val ty = typ_of dty
          in bin_const ty end;
    fun mk_binary dty x y = mk_binary_term dty $ x $ y;
    val ind_thm = #induct info
    val prop_thm_of_tac = inductive_thm thy prop_props ind_thm sort
    (* tactic for the i-th case of the induction *)
    fun ind_case_tac ctxt i hyps ihprems params_x ys =
      let
        val y = hd ys
        (* j: constructor number of x, idx: index of the datatype in descr *)
        val (j,idx) = nth nrec_args i |> (fn (_,j,idx) => (j,idx))
        val linfo = nth descr idx |> (fn (_,(ty_name,_,_)) => ty_name)
          |> BNF_LFP_Compat.the_info thy []
        fun solve_with_tac ctxt thms =
          let val simp_ctxt =
            (ctxt
              |> Context_Position.set_visible false
              |> put_simpset my_simp_set)
              addsimps (simps_of_info info @ simps_of_info linfo)
          in mk_solve_with_tac ctxt thms (force_tac simp_ctxt 1) end
        fun case_tac ctxt = mk_case_tac ctxt [[SOME y]] (#exhaust linfo)
        (* tactic for the k-th case of the case analysis on y *)
        fun sub_case_tac (ctxt,k,prems,iparams_y) =
          let
            val case_hyp_y = hd prems
          in
            if not (j = k)
            then my_print_tac ctxt ("different constructors ") THEN solve_with_tac ctxt (case_hyp_y :: ihprems) (* different constructor *)
            else
              let
                val params_y = map (snd #> Thm.term_of) iparams_y
                val c_info = nth descr idx |> snd |> (fn (_,_,info) => nth info j)
              in
                my_print_tac ctxt ("consider constructor " ^ string_of_int k)
                THEN same_constructor_tac ctxt hyps ihprems case_hyp_y solve_with_tac j params_x params_y c_info mk_binary
              end
          end
      in my_print_tac ctxt ("start induct " ^ string_of_int i) THEN case_tac ctxt sub_case_tac end
    val prop_thm = prop_thm_of_tac ind_case_tac
  in prop_thm end

(* Second definition of mk_case_tac: its body coincides with the definition of mk_case_tac
   given earlier in this structure. As the later binding it shadows the earlier one from
   here on. It performs a case analysis on subgoal 1 via the exhaust theorem thm and
   invokes sub_case_tac on every case with the 0-based case index. *)
fun mk_case_tac ctxt
  (insts : term option list list)
  (thm : thm)
  (sub_case_tac : Proof.context * int * thm list * (string * cterm) list -> tactic) =
    (
      DETERM o Induct.cases_tac ctxt false insts (SOME thm) []
      THEN_ALL_NEW (fn i => Subgoal.SUBPROOF (fn {context = ctxt, prems = hyps, params = params, ...}
        => sub_case_tac (ctxt, i-1, hyps, params)) ctxt i)
    )
    1

end

Theory Order_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      René Thiemann       <rene.thiemann@uibk.ac.at>
    Maintainer:  René Thiemann
    License:     LGPL
*)

(*
Copyright 2013 René Thiemann

This file is part of IsaFoR/CeTA.

IsaFoR/CeTA is free software: you can redistribute it and/or modify it under the
terms of the GNU Lesser General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

IsaFoR/CeTA is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along
with IsaFoR/CeTA. If not, see <http://www.gnu.org/licenses/>.
*)

section ‹Generating linear orders for datatypes›

theory Order_Generator
imports 
  Derive_Aux
begin

subsection Introduction

text ‹

The order generator registers itself at the derive-manager for the classes @{class ord},
@{class order}, and @{class linorder}.
To be more precise,
it automatically generates the two functions @{term "(≤)"} and @{term "(<)"} for some datatype 
\texttt{dtype} and
proves the following instantiations.

\begin{itemize}
\item \texttt{instantiation dtype :: (ord,\ldots,ord) ord}
\item \texttt{instantiation dtype :: (order,\ldots,order) order}
\item \texttt{instantiation dtype :: (linorder,\ldots,linorder) linorder}
\end{itemize}

All the non-recursive types that are used in the datatype must have similar instantiations.
For recursive type-dependencies this is automatically generated.

For example, for the \texttt{datatype tree = Leaf nat | Node "tree list"} we require that
@{type nat} is already in @{class linorder}, whereas for @{type list} nothing is required, since for the 
\texttt{tree}
datatype the @{type list} is only used recursively.

However, if we define \texttt{datatype tree = Leaf "nat list" | Node tree tree} then 
@{type list} must
provide the above instantiations.

Note that when calling the generator for @{class linorder}, it will automatically also derive the instantiations 
for @{class order}, which in turn invokes the generator for @{class ord}. 
A later invocation of @{class linorder}
after @{class order} or @{class ord} is not possible.
›

subsection "Implementation Notes"

text ‹
The generator uses the recursors from the datatype package to define a lexicographic order.
E.g., for a declaration 
\texttt{datatype 'a tree = Empty | Node "'a tree" 'a "'a tree"}
this will semantically result in
\begin{verbatim}
(Empty < Node _ _ _) = True
(Node l1 l2 l3 < Node r1 r2 r3) = 
  (l1 < r1 || l1 = r1 && (l2 < r2 || l2 = r2 && l3 < r3))
(_ < _) = False
(l <= r) = (l < r || l = r)
\end{verbatim}

The desired properties (like @{term "x < y ⟹ y < z ⟹ x < z"}) 
of the orders are all proven using induction (with the induction theorem from the datatype on @{term x}),
and afterwards there is a case distinction on the remaining variables, i.e., here @{term y} and @{term z}.
If the constructors of @{term x}, @{term y}, and @{term z} are different, some basic tactic is always invoked. 
In the other case (identical constructors) for each property a dedicated tactic was designed.
›

subsection "Features and Limitations"

text ‹
The order generator has been developed mainly for datatypes without explicit mutual recursion. 
For mutually recursive datatypes---like
\texttt{datatype a = C b and b = D a a}---only
for the first mentioned datatype---here \texttt{a}---the instantiations of the order-classes are
derived.

Indirect recursion like in \texttt{datatype tree = Leaf nat | Node "tree list"} should work 
without problems.
›

subsection "Installing the generator"

(* trichotomy of linear orders: two elements are equal or one of them is smaller than the other *)
lemma linear_cases: "(x :: 'a :: linorder) = y ∨ x < y ∨ y < x" by auto

ML_file ‹order_generator.ML› 

end

File ‹order_generator.ML›

signature ORDER_GENERATOR =
sig
  (* 1st component of the result: given an idx, it creates the term for the < operator
        of the idx-th datatype, idx 0 results in < *)
  (* 2nd component of the result: the list of arguments for the recursor, each paired with
        first the constructor number
        second the index number
  *)
  (*                          dtyp_info            order   idx        *)
  val mk_less_idx : theory -> Old_Datatype_Aux.info -> sort -> (int -> term) * (term * int * int) list;

  (* given an idx, x, and y, it creates x <= y, encoded as x < y ∨ x = y *)
  (*                             dtyp_info            order   idx     x       y   *)
  val mk_less_eq_idx : theory -> Old_Datatype_Aux.info -> sort -> (int -> term -> term -> term);

  (* proves the transitivity theorems,
     the first component of the result is for <, the second one for ≤ *)
  val mk_transitivity_thms : theory -> Old_Datatype_Aux.info -> thm * thm;

  (* proves the theorem (x < y) = (x ≤ y ∧ ¬ y ≤ x) *)
  val mk_less_le_not_le_thm : theory -> Old_Datatype_Aux.info -> thm

  (* proves the theorem (x ≤ x) *)
  val mk_le_refl_thm : theory -> Old_Datatype_Aux.info -> thm

  (* proves the theorem (x ≤ y ⟹ y ≤ x ⟹ x = y) *)
  (* takes as input the transitivity thm for < and the less_le_not_le thm *)
  val mk_antisym_thm : theory -> Old_Datatype_Aux.info -> thm -> thm -> thm

  (* proves the theorem (x ≤ y ∨ y ≤ x) *)
  val mk_linear_thm : theory -> Old_Datatype_Aux.info -> thm

  (* proves all four theorems which are required for orders: trans, refl, antisym, less_le_not_le *)
  val mk_order_thms : theory -> Old_Datatype_Aux.info -> thm list

  (* creates and registers (0 = ord, 1 = order, 2 = linear-order) for datatype *)
  val derive : int -> string -> string -> theory -> theory
end


structure Order_Generator : ORDER_GENERATOR =
struct

open Derive_Aux

(* name of the constant for the strict order <, used for all non-recursive argument types *)
val less_name = @{const_name "Orderings.less"}

(* construct free variable x_i, where typ_subst is applied on the given type *)
fun mk_free_tysubst_i typ_subst x i ty =
  let
    val name = x ^ string_of_int i
  in
    Free (name, typ_subst ty)
  end

(* Constructs the < operator for the datatypes in info by means of the recursors.
   Result: (mk_rec, nrec_args) where mk_rec i is the recursor of the i-th datatype applied
   to all recursor arguments, and nrec_args is the list of these arguments, each paired
   with its constructor number and the index of its datatype. *)
fun mk_less_idx thy info sort =
  let
    val typ_subst = typ_subst_for_sort thy info sort
    val descr = #descr info
    fun typ_of dty = Old_Datatype_Aux.typ_of_dtyp descr dty |> typ_subst
    val rec_names = #rec_names info
    val mk_free_i = mk_free_tysubst_i typ_subst
    (* number of recursive occurrences among the first i entries of dtys *)
    fun rec_idx i dtys = dt_number_recs (take i dtys) |> fst
    (* the recursor arguments for all constructors of the idx-th datatype *)
    fun mk_rhss (idx,(ty_name,_,cons)) =
      let
        val ty = typ_of (Old_Datatype_Aux.DtRec idx)
        val linfo = BNF_LFP_Compat.the_info thy [] ty_name
        val case_name = #case_name linfo
        (* the recursor argument for the i-th constructor, which is the left argument of < *)
        fun mk_rhs (i,(_,dtysi)) =
          let
            (* x_0, x_1, ...: the arguments of the i-th constructor *)
            val lvars = map_index (fn (i,dty) => mk_free_i "x_" i (typ_of dty)) dtysi
            (* res_oc: variable for the oc-th recursive occurrence, of type (i-th datatype) => bool *)
            fun res_var (i,oc) = mk_free_i "res_" oc (typ_of (Old_Datatype_Aux.DtRec i) --> @{typ bool});
            val res_vars = dt_number_recs dtysi
                     |> snd
                     |> map res_var
            (* the case for the j-th constructor in the case analysis on the right argument of < *)
            fun mk_case (j,(_,dtysj)) =
              let
                (* y_0, y_1, ...: the arguments of the j-th constructor *)
                val rvars = map_index (fn (i,dty) => mk_free_i "y_" i (typ_of dty)) dtysj
                val x = nth lvars
                val y = nth rvars
                (* lexicographic combination: c_0 ∨ (x_0 = y_0 ∧ (c_1 ∨ ...)) *)
                fun combine_dts [] = @{term False}
                  | combine_dts ((_,c) :: []) = c
                  | combine_dts ((i,c) :: ics) = HOLogic.mk_disj (c, HOLogic.mk_conj (HOLogic.mk_eq (x i, y i), combine_dts ics))
                (* comparison of the i-th arguments: via the res-variable for recursive
                   arguments, via the constant less_name otherwise *)
                fun less_of_dty (i,Old_Datatype_Aux.DtRec j) = res_var (j,rec_idx i dtysj) $ y i
                  | less_of_dty (i,_) =
                      let
                        val xi = x i
                        val ty = Term.type_of xi
                        val less = Const (less_name, ty --> ty --> @{typ bool})
                      in less $ xi $ y i end
                (* a constructor with a smaller number is smaller, a constructor with a larger
                   number is not; identical constructors compare their arguments *)
                val rhs =
                  if i < j then @{term True}
                  else if i > j then @{term False}
                  else map_index less_of_dty dtysi
                    |> map_index I
                    |> combine_dts
                val lam_rhs = fold lambda (rev rvars) rhs
              in lam_rhs end
            val cases = map_index mk_case cons
            val case_ty = (map type_of cases @ [ty]) ---> @{typ bool}
            val rhs_case = list_comb (Const (case_name, case_ty), cases)
            val rhs = fold lambda (rev (lvars @ res_vars)) rhs_case
          in rhs end
        val rec_args = map_index (fn (i,c) => (mk_rhs (i,c),i,idx)) cons
      in rec_args end
    val nrec_args = maps mk_rhss descr
    val rec_args = map #1 nrec_args
    (* the recursor of the i-th datatype applied to all recursor arguments *)
    fun mk_rec i =
      let
        val ty = typ_of (Old_Datatype_Aux.DtRec i)
        val rec_ty = map type_of rec_args @ [ty,ty] ---> @{typ bool}
        val rec_name = nth rec_names i
        val rhs = list_comb (Const (rec_name, rec_ty), rec_args)
      in rhs end
  in (mk_rec,nrec_args) end

(* x <= y for the idx-th datatype, encoded as x < y ∨ x = y *)
fun mk_less_eq_idx thy info sort idx x y =
  let
    val (mk_less, _) = mk_less_idx thy info sort
    val strict = mk_less idx $ x $ y
    val equal = HOLogic.mk_eq (x, y)
  in
    HOLogic.mk_disj (strict, equal)
  end

(* For every datatype in info, gen is applied on the variable generator of that datatype
   and on the list [less, less_eq] of the operators of that datatype. *)
fun mk_prop_trm thy info sort
  (gen : (int -> term) -> (term -> term -> term)list -> term list * term list) =
  let
    fun props_of idx =
      let
        fun less a b = fst (mk_less_idx thy info sort) idx $ a $ b
        val less_eq = mk_less_eq_idx thy info sort idx
        val xs = mk_xs thy info sort idx
      in
        gen xs [less, less_eq]
      end
  in
    map (fn (idx, _) => props_of idx) (#descr info)
  end

(* the property of the first datatype as meta-implication, together with its variables *)
fun mk_prop_major_trm thy info sort gen =
  prop_trm_to_major_imp (mk_prop_trm thy info sort gen)


(* transitivity of <: premises x < y and y < z, conclusion x < z, variables x, y, z *)
fun mk_trans_thm_trm thy info =
  let
    fun gen xs [less, _] =
      let
        val x = xs 1
        val y = xs 2
        val z = xs 3
      in
        ([less x y, less y z, less x z], [x, y, z])
      end
  in
    mk_prop_trm thy info @{sort "order"} gen
  end

(* transitivity of ≤ for the first datatype as meta-implication, together with x, y, z *)
fun mk_trans_eq_thm_trm thy info =
  let
    fun gen xs [_, lesseq] =
      let
        val x = xs 1
        val y = xs 2
        val z = xs 3
      in
        ([lesseq x y, lesseq y z, lesseq x z], [x, y, z])
      end
  in
    mk_prop_major_trm thy info @{sort "order"} gen
  end

(* the lexicographic comparison of two argument lists as proposition:
   x1 < y1 ∨ (x1 = y1 ∧ (x2 < y2 ∨ (x2 = y2 ∧ ... False))) *)
fun mk_less_disj mk_less px py dtys =
  let
    fun build_disj [] _ _ = @{term False}
      | build_disj (x :: xs) (y :: ys) (dty :: rest) =
          let
            val strict = mk_less dty x y
            val equal_and_rest = HOLogic.mk_conj (HOLogic.mk_eq (x, y), build_disj xs ys rest)
          in
            HOLogic.mk_disj (strict, equal_and_rest)
          end
  in
    HOLogic.mk_Trueprop (build_disj px py dtys)
  end;

(* the case-, recursor-, injectivity- and distinctness-theorems of a datatype, in this order *)
fun simps_of_info info =
  List.concat [#case_rewrites info, #rec_rewrites info, #inject info, #distinct info]

(* Proves the transitivity theorems: the first component of the result is transitivity
   of <, the second component is transitivity of ≤.
   Transitivity of < (x < y ==> y < z ==> x < z) is proven by induction on x via
   inductive_thm, followed by case analyses on y and on z. Transitivity of ≤ is
   afterwards obtained from the instantiated theorem for < by blast_tac. *)
fun mk_transitivity_thms thy (info : Old_Datatype_Aux.info) =
  let
    val ctxt = Proof_Context.init_global thy
    (* first prove transitivity of < *)
    val trans_props = mk_trans_thm_trm thy info
    val sort = @{sort "order"}
    val (mk_rec,nrec_args) = mk_less_idx thy info sort
    val typ_subst = typ_subst_for_sort thy info sort
    val descr = #descr info
    fun typ_of dty = Old_Datatype_Aux.typ_of_dtyp descr dty |> typ_subst
    (* the < operator on a type: the recursor term for recursive types, the constant otherwise *)
    fun mk_less_term (Old_Datatype_Aux.DtRec i) = mk_rec i
      | mk_less_term dty =
          let
            val ty = typ_of dty
          in Const (less_name, ty --> ty --> @{typ bool}) end;
    fun mk_less dty x y = mk_less_term dty $ x $ y;
    val ind_thm = #induct info
    val trans_thm_of_tac = inductive_thm thy trans_props ind_thm sort
    (* tactic for the i-th case of the induction on x;
       xy and yz are the premises x < y and y < z *)
    fun ind_case_tac ctxt i hyps [xy,yz] params_x [y,z] =
      let
        (* j: constructor number of x, idx: index of the datatype in descr *)
        val (j,idx) = nth nrec_args i |> (fn (_,j,idx) => (j,idx))
        val linfo = nth descr idx |> (fn (_,(ty_name,_,_)) => ty_name)
          |> BNF_LFP_Compat.the_info thy []
        (* simplification with the datatype simp-rules and the given theorems *)
        fun solve_with_tac ctxt thms =
          let
            val simp_ctxt =
              (ctxt
                |> Context_Position.set_visible false
                |> put_simpset my_simp_set)
                addsimps (simps_of_info info @ simps_of_info linfo @ thms)
          in mk_solve_with_tac simp_ctxt thms (asm_full_simp_tac simp_ctxt 1) end

        fun case_tac ctxt y_z = mk_case_tac ctxt [[SOME y_z]] (#exhaust linfo)
        (* tactic for the k-th case of the case analysis on y *)
        fun sub_case_tac (ctxt,k,prems,iparams_y) =
          let
            val case_hyp_y = hd prems
            (* tactic for the l-th case of the case analysis on z *)
            fun sub_sub_case_tac (ctxt,l,prems,iparams_z) =
              let
                val case_hyp_z = hd prems
                val comp_eq = [case_hyp_z, case_hyp_y, xy, yz]
              in
                (if not (j = l andalso l = k)
                then
                  (* x, y, and z are not built by the same constructor *)
                  K (solve_with_tac ctxt comp_eq)
                else
                  let
                    val params_y = map (snd #> Thm.term_of) iparams_y
                    val params_z = map (snd #> Thm.term_of) iparams_z
                    val c_info = nth descr idx |> snd |> (fn (_,_,info) => nth info j)
                    val pdtys = snd c_info
                    val build_disj = mk_less_disj mk_less
                    (* x < y and y < z, expressed as lexicographic comparison of the arguments *)
                    val xy' = build_disj params_x params_y pdtys
                    val yz' = build_disj params_y params_z pdtys
                    fun disj_thm t = Goal.prove_future ctxt [] [] t (K (solve_with_tac ctxt comp_eq))
                    val xy_disj = disj_thm xy'
                    val yz_disj = disj_thm yz'
                    (* processes the arguments from left to right: xy and yz are the disjunctions
                       for the remaining arguments, eqs are the equalities obtained so far,
                       ihyps are the remaining induction hypotheses *)
                    fun solve_tac xy _ [] _ _ _ _ _ = K (solve_with_tac ctxt [xy])
                      | solve_tac xy yz (px :: pxs) (py :: pys) (pz :: pzs) (dty :: dtys) eqs ihyps =
                          let
                            fun case_tac_disj ctxt disj tac =
                              mk_case_tac ctxt [] (@{thm disjE} OF [disj]) (fn (ctxt,ii,hyps,_) => tac ii ctxt (List.last hyps))
                            fun yz_case_tac ctxt = case_tac_disj ctxt yz
                            val rec_type = (fn Old_Datatype_Aux.DtRec _ => true | _ => false) dty
                            fun xy_tac ii ctxt hyp_xy =
                              if ii = 1 (* right branch, px = py and pxs < pys *)
                              then
                                let
                                  val eq_term = HOLogic.mk_eq (px,py) |> HOLogic.mk_Trueprop
                                  val eq_xy_thm = Goal.prove_future ctxt [] [] eq_term (K (solve_with_tac ctxt [hyp_xy]))
                                  val xy'_thm = Goal.prove_future ctxt [] [] (build_disj pxs pys dtys) (K (solve_with_tac ctxt [hyp_xy]))
                                  fun yz_tac jj ctxt hyp_yz =
                                  if jj = 1 (* right branch, py = pz and pys < pzs *)
                                    (* = and = *)
                                  then
                                    let
                                      val eq_term = HOLogic.mk_eq (px,pz) |> HOLogic.mk_Trueprop
                                      val eq_thm  = Goal.prove_future ctxt [] [] eq_term (K (solve_with_tac ctxt [eq_xy_thm,hyp_yz]))
                                      val yz'_thm = Goal.prove_future ctxt [] [] (build_disj pys pzs dtys) (K (solve_with_tac ctxt [hyp_yz]))
                                      (* a recursive argument consumes one induction hypothesis *)
                                      val drop_hyps = if rec_type then tl else I
                                    in
                                      solve_tac xy'_thm yz'_thm pxs pys pzs dtys (eq_thm :: eqs) (drop_hyps ihyps) 1
                                    end
                                  else (* left branch, py < pz *)
                                    (* = and < *)
                                    let
                                      val xz_term = mk_less dty px pz |> HOLogic.mk_Trueprop
                                      val xz_thm  = Goal.prove_future ctxt [] [] xz_term (K (solve_with_tac ctxt [hyp_xy,hyp_yz]))
                                    in
                                      solve_with_tac ctxt (xz_thm :: case_hyp_z :: eqs)
                                    end
                                in
                                  yz_case_tac ctxt yz_tac
                                end
                              else (* left branch, px < py *)
                                let
                                  val xz_term = mk_less dty px pz |> HOLogic.mk_Trueprop
                                  fun yz_tac jj ctxt hyp_yz =
                                    if jj = 1 (* right branch, py = pz *)
                                      (* < and = *)
                                    then
                                      let
                                        val xz_thm = Goal.prove_future ctxt [] [] xz_term (K (solve_with_tac ctxt [hyp_xy,hyp_yz]))
                                      in
                                        solve_with_tac ctxt (xz_thm :: case_hyp_z :: eqs)
                                      end
                                    else (* left branch, py < pz *)
                                      (* < and < *)
                                      let
                                        (* recursive argument: use the induction hypothesis,
                                           otherwise transitivity of < on the argument type *)
                                        val trans_thm = if rec_type then hd ihyps else @{thm less_trans}
                                        val tac = resolve_tac ctxt [trans_thm OF [hyp_xy,hyp_yz]] 1
                                        val xz_thm = Goal.prove_future ctxt [] [] xz_term (K tac)
                                      in
                                        solve_with_tac ctxt (xz_thm :: case_hyp_z :: eqs)
                                      end
                                in
                                  yz_case_tac ctxt yz_tac
                                end;
                            val xy_case_tac = case_tac_disj ctxt xy xy_tac
                          in
                            K (my_print_tac ctxt ("another case: ") THEN xy_case_tac)
                          end
                  in
                    K (my_print_tac ctxt "recursive case: ")
                    THEN' solve_tac xy_disj yz_disj params_x params_y params_z pdtys [] hyps
                  end
                ) 1
              end
          in
            my_print_tac ctxt ("consider constructor " ^ string_of_int k) THEN
            (* the case analysis on z is only performed if the constructor number
               of y is not smaller than the one of x *)
            (if k >= j then case_tac ctxt z sub_sub_case_tac else
               solve_with_tac ctxt [case_hyp_y,xy])
          end (* end sub_case tac *)
      in
        my_print_tac ctxt ("start induct " ^ string_of_int i) THEN case_tac ctxt y sub_case_tac
      end (* end ind_case tac *)
    val trans_thm =  trans_thm_of_tac ind_case_tac
    (* then prove transitivity of ≤ from transitivity of < *)
    val (trans_eq_trm,vars) = mk_trans_eq_thm_trm thy info
    val inst_trans = infer_instantiate' ctxt (map (SOME o Thm.cterm_of ctxt) vars) trans_thm
    val trans_eq_vars_string = map (dest_Free #> fst) vars
    fun tac_to_eq_thm tac = Goal.prove_global_future thy trans_eq_vars_string [] trans_eq_trm (K tac)
    val eq_tac = mk_solve_with_tac ctxt [inst_trans] (blast_tac ctxt 1)
    val trans_eq_thm = tac_to_eq_thm eq_tac
  in
    (trans_thm,trans_eq_thm)
  end

(* mk_binary_thm with its first three arguments fixed to mk_prop_trm,
   mk_less_idx and less_name; the callers in this file supply the remaining
   arguments as: thy info prop_gen sort main_tac *)
val mk_binary_less_thm = mk_binary_thm mk_prop_trm mk_less_idx less_name

(* proves, for the derived order on the datatype described by info,
     (x < y) = (x ≤ y ∧ ¬ y ≤ x)
   The real work is the auxiliary theorem main_thm (x < y ⟹ ¬ y ≤ x), which
   is obtained from mk_binary_less_thm with the tactic main_tac below.  The
   final statement is then proven by inserting main_thm, instantiated to the
   variables of the goal, and calling blast. *)
fun mk_less_le_not_le_thm thy info =
  let
    val sort = @{sort "order"}
    (* main property: x < y ⟹ ¬ y ≤ x *)
    (* returned as ([premise, conclusion], [variables]) *)
    fun prop_gen xs [less,lesseq] =
      let
        val (x,y) = (xs 1, xs 2)
      in ([less x y, lesseq y x |> HOLogic.mk_not], [x,y]) end
    (* tactic passed to mk_binary_less_thm; the sixth argument is not used here.
       params_x / params_y are walked in parallel with the argument dtyps
       (snd c_info), so they are presumably the constructor arguments of x and y
       - TODO confirm against mk_binary_thm *)
    fun main_tac ctxt ih_hyps ih_prems y_prem solve_with_tac _ params_x params_y c_info mk_less =
      let
        val pdtys = snd c_info
        val comp_eq = y_prem :: ih_prems
        val build_disj = mk_less_disj mk_less
        (* xy' is the term built by mk_less_disj for the full argument lists;
           it is established from y_prem and ih_prems *)
        val xy' = build_disj params_x params_y pdtys
        val xy_disj = Goal.prove_future ctxt [] [] xy' (K (solve_with_tac ctxt comp_eq))
        (* solve_tac ctxt xy pxs pys dtys eqs ihyps
           walks through the argument lists from left to right;
           xy is the theorem for the remaining arguments, eqs collects the facts
           obtained for the arguments already passed, and ihyps are the induction
           hypotheses that have not been consumed yet *)
        fun solve_tac ctxt xy [] _ _ _ _ = solve_with_tac ctxt [xy]
          | solve_tac ctxt xy (px :: pxs) (py :: pys) (dty :: dtys) eqs ihyps =
              let
                val xs_ys = build_disj pxs pys dtys
                val x_eq_y = HOLogic.mk_eq (px,py)
                val x_less_y = mk_less dty px py
                (* disj2: px < py ∨ (¬ px < py ∧ px = py ∧ xs_ys), derived from xy by blast *)
                val disj2 =
                  HOLogic.mk_disj (x_less_y, HOLogic.mk_conj (HOLogic.mk_not x_less_y, HOLogic.mk_conj( x_eq_y, HOLogic.dest_Trueprop xs_ys)))
                  |> HOLogic.mk_Trueprop
                val disj2_thm = Goal.prove_future ctxt [] [] disj2 (K (Method.insert_tac ctxt [xy] 1 THEN blast_tac ctxt 1))
                (* case analysis via disjE on the given disjunction theorem; tac receives
                   the branch number, the context and the last hypothesis provided by
                   mk_case_tac *)
                fun case_tac_disj disj tac =
                  mk_case_tac ctxt [] (@{thm disjE} OF [disj]) (fn (ctxt,ii,hyps,_) => tac ii ctxt (List.last hyps))
                (* true iff the current argument is a recursive occurrence (DtRec);
                   only then the head of ihyps belongs to this argument *)
                val rec_type = (fn Old_Datatype_Aux.DtRec _ => true | _ => false) dty
                fun xy_tac ii ctxt hyp_xy =
                  if ii = 1
                  then (* right branch, px = py and ¬ px < py and pxs < pys *)
                    let
                      (* split the conjunction hyp_xy into the facts needed below *)
                      val eq_term = x_eq_y |> HOLogic.mk_Trueprop
                      val eq_xy_thm = Goal.prove_future ctxt [] [] eq_term (K (solve_with_tac ctxt [hyp_xy]))
                      val xy'_thm = Goal.prove_future ctxt [] [] xs_ys (K (solve_with_tac ctxt [hyp_xy]))
                      val yx_thm =
                        Goal.prove_future ctxt [] [] (mk_less dty py px |> HOLogic.mk_not |> HOLogic.mk_Trueprop)
                        (K (solve_with_tac ctxt [hyp_xy]))
                      (* drop the induction hypothesis of this argument, if it has one *)
                      val ihyps' = if rec_type then tl ihyps else ihyps
                      (* continue with the remaining arguments, remembering px = py and ¬ py < px *)
                      val solve_rec = solve_tac ctxt xy'_thm pxs pys dtys (eq_xy_thm :: yx_thm :: eqs) ihyps'
                    in
                      solve_rec
                    end
                  else (* left branch, px < py *)
                       (* hence ¬ py ≤ px (yx_thm) *)
                    let
                      (* ≤ is spelled out here as the disjunction of < and = *)
                      val yx = HOLogic.mk_disj (mk_less dty py px, HOLogic.mk_eq (py,px)) |> HOLogic.mk_not |> HOLogic.mk_Trueprop
                      (* for a recursive argument the induction hypothesis is applied to px < py *)
                      val tac = if rec_type then solve_with_tac ctxt [hd ihyps OF [hyp_xy]] else solve_with_tac ctxt [hyp_xy]
                      val yx_thm = Goal.prove_future ctxt [] [] yx (K tac)
                    in
                      solve_with_tac ctxt (yx_thm :: y_prem :: eqs)
                    end;
              in
                case_tac_disj disj2_thm xy_tac
              end (* end solve tac *)
      in
        (solve_tac ctxt xy_disj params_x params_y pdtys [] ih_hyps : tactic)
      end (* end main_tac *)
    val main_thm = mk_binary_less_thm thy info prop_gen sort main_tac
    val ctxt = Proof_Context.init_global thy
    (* final statement: (x < y) = (x ≤ y ∧ ¬ y ≤ x) *)
    val (thm_trm,vars) = mk_prop_major_trm thy info sort (fn xs => fn [less,lesseq] =>
      let val (x,y) = (xs 1, xs 2)
      in
        ([HOLogic.mk_eq (less x y,HOLogic.mk_conj (lesseq x y, lesseq y x |> HOLogic.mk_not))], [x,y])
      end)
    (* instantiate main_thm to exactly the free variables of the final statement *)
    val inst_thm = infer_instantiate' ctxt (map (SOME o Thm.cterm_of ctxt) vars) main_thm
    val vars_strings = map (dest_Free #> fst) vars
    val thm =
      Goal.prove_future ctxt vars_strings [] thm_trm
        (K (Method.insert_tac ctxt [inst_thm] 1 THEN blast_tac ctxt 1))
  in
    thm
  end

(* reflexivity of the derived order, stated w.r.t. sort order: x ≤ x
   The goal term and its free variables come from mk_prop_major_trm;
   the statement is discharged by blast alone. *)
fun mk_le_refl_thm thy info =
  let
    val ctxt = Proof_Context.init_global thy
    (* only the ≤-relation (second list entry) occurs in the statement *)
    fun refl_prop xs [_,lesseq] =
      let val x = xs 1
      in ([lesseq x x],[x]) end
    val (goal,frees) = mk_prop_major_trm thy info @{sort "order"} refl_prop
    (* names of the variables that are generalized in the result *)
    val fixes = map (fst o dest_Free) frees
    val close_goal = blast_tac ctxt 1
  in
    Goal.prove_future ctxt fixes [] goal (K close_goal)
  end

(* antisymmetry of the derived order: x ≤ y ⟹ y ≤ x ⟹ x = y
   The proof inserts trans_thm instantiated at (x, y, x) and less_thm
   instantiated at (x, x) and then calls blast. *)
fun mk_antisym_thm thy info trans_thm less_thm =
  let
    val ctxt = Proof_Context.init_global thy
    (* only the ≤-relation (second list entry) occurs in the statement *)
    fun antisym_prop xs [_,lesseq] =
      let
        val x = xs 1
        val y = xs 2
      in
        ([lesseq x y, lesseq y x, HOLogic.mk_eq (x,y)],[x,y])
      end
    val (goal,frees) = mk_prop_major_trm thy info @{sort "order"} antisym_prop
    val fixes = map (fst o dest_Free) frees
    val x = hd frees
    (* instantiate the leading schematic variables of a theorem with the given terms *)
    fun instantiate ts thm =
      infer_instantiate' ctxt (map (fn t => SOME (Thm.cterm_of ctxt t)) ts) thm
    val trans_xyx = instantiate (frees @ [x]) trans_thm
    val less_xx = instantiate [x,x] less_thm
    val insert_facts = Method.insert_tac ctxt [trans_xyx,less_xx] 1
  in
    Goal.prove_future ctxt fixes [] goal (K (insert_facts THEN blast_tac ctxt 1))
  end

(* collects the four theorems used for the order instance, in the order
   [transitivity of ≤, characterisation of <, reflexivity, antisymmetry];
   the first transitivity theorem returned by mk_transitivity_thms is only
   used as input for the antisymmetry proof *)
fun mk_order_thms thy info =
  let
    val (trans_thm,trans_eq_thm) = mk_transitivity_thms thy info
    val less_thm = mk_less_le_not_le_thm thy info
    val refl_thm = mk_le_refl_thm thy info
    val antisym_thm = mk_antisym_thm thy info trans_thm less_thm
  in
    [trans_eq_thm, less_thm, refl_thm, antisym_thm]
  end


(* proves linearity of the derived order on the datatype described by info:
     x ≤ y ∨ y ≤ x
   The real work is the auxiliary theorem main_thm (x = y ∨ x < y ∨ y < x),
   which is obtained from mk_binary_less_thm with the tactic main_tac below.
   The final statement is then proven by inserting main_thm, instantiated to
   the variables of the goal, and calling blast. *)
fun mk_linear_thm thy info =
  let
    val sort = @{sort "linorder"}
    (* main property: x = y ∨ x < y ∨ y < x *)
    (* no premises: the first list only contains the conclusion *)
    fun prop_gen xs [less,_] =
      let val (x,y) = (xs 1, xs 2)
      in
        ([HOLogic.mk_disj (HOLogic.mk_eq (x,y),HOLogic.mk_disj(less x y, less y x))],[x,y])
      end
    (* tactic passed to mk_binary_less_thm; the third and sixth argument are
       not used here *)
    fun main_tac ctxt ih_hyps _ y_prem solve_with_tac _ params_x params_y c_info mk_less =
      let
        val pdtys = snd c_info
        (* solve_tac pxs pys dtys eqs ihyps
           walks through the argument lists from left to right;
           eqs collects the facts obtained so far (initially y_prem) and ihyps
           are the induction hypotheses that have not been consumed yet *)
        fun solve_tac [] _ _ eqs _ = solve_with_tac ctxt eqs
          | solve_tac (px :: pxs) (py :: pys) (dty :: dtys) eqs ihyps =
              let
                val less = mk_less dty
                val x_eq_y = HOLogic.mk_eq (px,py)
                (* px = py ∨ px < py ∨ py < px for the current argument *)
                val disj_trm = HOLogic.mk_disj (x_eq_y,HOLogic.mk_disj(less px py, less py px)) |> HOLogic.mk_Trueprop
                (* true iff the current argument is a recursive occurrence (DtRec) *)
                val rec_type = (fn Old_Datatype_Aux.DtRec _ => true | _ => false) dty
                (* recursive arguments use their induction hypothesis,
                   all other arguments the library fact linear_cases *)
                val disj_thm' = if rec_type then hd ihyps else @{thm linear_cases}
                val disj_tac = resolve_tac ctxt [disj_thm'] 1
                val disj_thm = Goal.prove_future ctxt [] [] disj_trm (K (disj_tac))
                (* case analysis via disjE on the given disjunction theorem; tac receives
                   the branch number, the context and the last hypothesis provided by
                   mk_case_tac *)
                fun case_tac_disj disj tac =
                  mk_case_tac ctxt [] (@{thm disjE} OF [disj]) (fn (ctxt,ii,hyps,_) => tac ii ctxt (List.last hyps))
                (* the context handed over by mk_case_tac is ignored here;
                   the outer ctxt is used instead *)
                fun eq_less_less_tac ii _ eq_less =
                  if ii = 0
                  then (* left branch, px = py *)
                    let
                      (* drop the induction hypothesis of this argument, if it has one *)
                      val ihyps' = if rec_type then tl ihyps else ihyps
                      (* continue with the remaining arguments, remembering px = py *)
                      val solve_rec = solve_tac pxs pys dtys (eq_less :: eqs) ihyps'
                    in
                      solve_rec
                    end
                  else (* right branch, px < py ∨ py < px *)
                    let
                      (* both branches of the inner disjunction are closed in the same way *)
                      fun less_tac _ _ less = solve_with_tac ctxt (less :: eqs)
                    in
                      case_tac_disj eq_less less_tac
                    end;
              in
                case_tac_disj disj_thm eq_less_less_tac
              end (* end solve tac *)
      in
        (solve_tac params_x params_y pdtys [y_prem] ih_hyps : tactic)
      end (* end main tac *)
    val main_thm = mk_binary_less_thm thy info prop_gen sort main_tac
    val ctxt = Proof_Context.init_global thy
    (* x ≤ y ∨ y ≤ x *)
    val (thm_trm,vars) = mk_prop_major_trm thy info sort (fn xs => fn [_,lesseq] =>
      let val (x,y) = (xs 1, xs 2)
      in
        ([HOLogic.mk_disj (lesseq x y,lesseq y x)],[x,y])
      end)
    (* instantiate main_thm to exactly the free variables of the final statement *)
    val inst_thm = infer_instantiate' ctxt (map (SOME o Thm.cterm_of ctxt) vars) main_thm
    val vars_strings = map (dest_Free #> fst) vars
    val thm =
      Goal.prove_future ctxt vars_strings [] thm_trm
        (K (Method.insert_tac ctxt [inst_thm] 1 THEN blast_tac ctxt 1))
  in
    thm
  end

(* derive kind dtyp_name _ thy
   registers the datatype dtyp_name in class ord and, depending on kind,
   additionally in class order (kind ≥ 1) and class linorder (kind ≥ 2).
   The third argument is ignored.  Returns the extended theory.
   A progress message is written after each completed instantiation. *)
fun derive kind dtyp_name _ thy =
  let
    val tyco = dtyp_name

    (* first register in class ord *)
    val base_name = Long_Name.base_name tyco
    val _ = writeln ("creating orders for datatype " ^ base_name)
    val sort = @{sort ord}
    val info = BNF_LFP_Compat.the_info thy [] tyco
    (* assigns the given sort to every entry of the first component of the
       datatype's spec (used below as the type arguments of the instantiation) *)
    val vs_of_sort =
      let val i = BNF_LFP_Compat.the_spec thy tyco |> #1
      in fn sort => map (fn (n,_) => (n, sort)) i end
    val vs = vs_of_sort sort
    (* the term defining < : first component of mk_less_idx, taken at index 0 *)
    val less_rhs = mk_less_idx thy info sort |> fst |> (fn x => x 0)
    (* ty is the first type argument of the type of less_rhs, i.e. its domain *)
    val ty = Term.fastype_of less_rhs |> Term.dest_Type |> snd |> hd
    fun mk_binrel_def T = mk_def (T --> T --> HOLogic.boolT)
    val less_def = mk_binrel_def ty @{const_name less} less_rhs
    val x = Free ("x",ty)
    val y = Free ("y",ty)
    (* x ≤ y is defined as x < y ∨ x = y *)
    val less_eq_rhs = lambda x (lambda y (HOLogic.mk_disj (less_rhs $ x $ y, HOLogic.mk_eq (x,y))))
    val less_eq_def = mk_binrel_def ty @{const_name less_eq} less_eq_rhs
    val ((less_thm,less_eq_thm),lthy) = Class.instantiation ([tyco],vs,sort) thy
      |> define_overloaded ("less_" ^ base_name ^ "_def", less_def)
      ||>> define_overloaded ("less_eq_" ^ base_name ^ "_def", less_eq_def)
    (* the two definitional theorems; unfolded in the instance proofs below *)
    val less_thms = [less_thm, less_eq_thm]

    val thy' = Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt []) lthy
    val _ = writeln ("registered " ^ base_name ^ " in class ord")

    (* next register in class order *)
    val thy'' =
      if kind < 1 then thy'
      else
        let
          val sort = @{sort order}
          val vs = vs_of_sort sort
          (* note: the theorems are generated w.r.t. the initial theory thy,
             whereas the instantiation is opened in thy' *)
          val [trans_eq,less,refl,antisym] = mk_order_thms thy info
          val lthy = Class.instantiation ([tyco],vs,sort) thy'

          (* after unfolding the definitions, the goals left by intro_classes are
             resolved one after the other with less, refl, trans_eq and antisym;
             the premises of the latter two are closed by assumption *)
          fun order_tac ctxt =
            my_print_tac ctxt "enter order" THEN
            unfold_tac ctxt less_thms THEN
            my_print_tac ctxt "after unfolding" THEN
            resolve_tac ctxt [less] 1 THEN
            my_print_tac ctxt "after less" THEN
            resolve_tac ctxt [refl] 1 THEN
            my_print_tac ctxt "after refl" THEN
            resolve_tac ctxt [trans_eq] 1 THEN assume_tac ctxt 1 THEN assume_tac ctxt 1 THEN
            my_print_tac ctxt "after trans" THEN
            resolve_tac ctxt [antisym] 1 THEN assume_tac ctxt 1 THEN assume_tac ctxt 1
          val thy'' =
            Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt [] THEN order_tac ctxt) lthy
          val _ = writeln ("registered " ^ base_name ^ " in class order")
        in thy'' end

    (* next register in class linorder *)
    val thy''' =
      if kind < 2 then thy''
      else
        let
          val sort = @{sort linorder}
          val vs = vs_of_sort sort
          val lthy = Class.instantiation ([tyco],vs,sort) thy''
          (* as above, the theorem is generated w.r.t. the initial theory thy *)
          val linear = mk_linear_thm thy info
          fun order_tac ctxt =
            unfold_tac ctxt less_thms THEN
            resolve_tac ctxt [linear] 1
          val thy''' = Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt [] THEN order_tac ctxt) lthy
          val _ = writeln ("registered " ^ base_name ^ " in class linorder")
        in thy''' end

  in thy''' end

(* registers the three order generators with the derive manager; the integer
   handed to derive selects how far the class hierarchy is instantiated
   (0: ord only, 1: up to order, 2: up to linorder) *)
val _ =
  let
    val register_ord =
      Derive_Manager.register_derive "ord" "derives ord for a datatype" (derive 0)
    val register_order =
      Derive_Manager.register_derive "order" "derives an order for a datatype" (derive 1)
    val register_linorder =
      Derive_Manager.register_derive "linorder" "derives a linear order for a datatype" (derive 2)
  in
    Theory.setup (register_ord #> register_order #> register_linorder)
  end

end

Theory Hash_Generator

(*  Title:       Deriving class instances for datatypes
    Author:      René Thiemann       <rene.thiemann@uibk.ac.at>
    Maintainer:  René Thiemann
    License:     LGPL
*)

(*
Copyright 2013 René Thiemann

This file is part of IsaFoR/CeTA.

IsaFoR/CeTA is free software: you can redistribute it and/or modify it under the
terms of the GNU Lesser General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

IsaFoR/CeTA is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along
with IsaFoR/CeTA. If not, see <http://www.gnu.org/licenses/>.
*)

section ‹Hash functions›

theory Hash_Generator
imports 
  Collections.HashCode
  Derive_Aux
begin

subsection "Introduction"

text ‹
The interface for hash-functions is defined in the class @{class hashable} which has been developed
as part of the Isabelle Collection Framework \cite{rbt}. It requires a hash-function
(@{const hashcode}), a bounded hash-function (@{const bounded_hashcode}),
and a default hash-table size (@{const def_hashmap_size}).

The @{const hashcode} function for each datatype is created by instantiating the recursors of that 
datatype appropriately. E.g., for \texttt{datatype 'a test = C1 'a 'a | C2 "'a test list"} 
we get a hash-function which is equivalent to 
\begin{verbatim}
hashcode (C1 a b) = c1 * hashcode a + c2 * hashcode b
hashcode (C2 Nil) = c3
hashcode (C2 (a # as)) = c4 * hashcode a + c5 * hashcode as
\end{verbatim}
where each \texttt{c$_{i}$} is a non-negative 32-bit number which is dependent on the
datatype name, the constructor name, and the occurrence of the argument (i.e., 
in the example \texttt{c1} and \texttt{c2} will usually be different numbers.)
These parameters are used in linear combination with prime numbers to hopefully
get some useful hash-function.

The @{const bounded_hashcode} functions are constructed in the same way, except that after each
arithmetic operation a modulo operation is performed.

Finally, the default hash-table size is just set to 10, following Java's default
hash-table constructor.
›

subsection "Features and Limitations"

text ‹
We get the same limitations as for the order generator.
For mutually recursive datatypes, the instantiations of the @{class hashable}-class are
derived only for the first mentioned datatype.
›

subsection "Installing the generator"

lemma hash_mod_lemma: "1 < (n :: nat) ⟹ x mod n < n" by auto

ML_file ‹hash_generator.ML›

end

File ‹hash_generator.ML›

signature HASH_GENERATOR =
sig
  (* creates the term of the hash function for the datatype described by the
     given info; the original note says the function is "possibly bounded by
     some parameter", but the type below takes no such parameter *)
  (*                          dtyp_info                *)
  val mk_hash : theory -> Old_Datatype_Aux.info -> term;

  (* creates and registers hash-functions for datatype;
     first argument: name of the datatype, second argument: ignored by the
     implementation in this file *)
  val derive : string -> string -> theory -> theory
end

structure Hash_Generator : HASH_GENERATOR =
struct

open Derive_Aux

(* modulus of all hash arithmetic in this structure: 2 ^^ 31 *)
val max_int = 2147483648

(* folds the characters of s from left to right into a number below max_int:
   the accumulator starts at 0 and is multiplied by a fixed factor before the
   character code is added *)
fun int_of_string s =
  let
    fun step (c, acc) = (1792318057 * acc + Char.ord c) mod max_int
  in
    List.foldl step 0 (String.explode s)
  end

(* according to the original author, all multipliers in int_of_string and
   create_factor are primes (31-bit) *)

(* combines the datatype name, the constructor name and the three indices
   idx, i, j (each shifted by one) into a single number below max_int *)
fun create_factor ty_name con_name idx i j =
  let
    val weighted =
      [(1444315237, int_of_string ty_name),
       (1336760419, int_of_string con_name),
       (2044890737, idx + 1),
       (1622892797, i + 1),
       (2140823281, j + 1)]
    val total = List.foldl (fn ((p, v), acc) => acc + p * v) 0 weighted
  in
    total mod max_int
  end

(* default hash-table size: constantly 10, the argument is ignored *)
fun create_def_size _ = 10

(* name of the constant hashcode, used in mk_hash to build the hash of
   non-recursive constructor arguments *)
val hash_name = @{const_name "hashcode"}

(* free variable whose name is x followed by the index i and whose type is ty
   after applying typ_subst *)
fun mk_free_tysubst_i typ_subst x i ty =
  let
    val name = x ^ string_of_int i
  in
    Free (name, typ_subst ty)
  end

(* builds the term of the hash function for the first datatype of info:
   the recursor with index 0 (mk_rec 0) applied to one argument per
   constructor of every entry of the descriptor.  Each such argument is a
   lambda abstraction over the constructor arguments and the results of the
   recursive calls, whose body is a sum of products "factor * hash" plus a
   constant, with all factors generated by create_factor. *)
fun mk_hash thy info =
  let
    val sort = @{sort hashable}
    val typ_subst = typ_subst_for_sort thy info sort
    val descr = #descr info
    (* name stored in the first descriptor entry; it seeds all factors *)
    val ty_name = info |> #descr |> hd |> snd |> #1
    val cons_hash = create_factor ty_name
    val mk_num = HOLogic.mk_number @{typ hashcode}
    fun typ_of dty = Old_Datatype_Aux.typ_of_dtyp descr dty |> typ_subst
    val rec_names = #rec_names info
    val mk_free_i = mk_free_tysubst_i typ_subst
    (* first component of dt_number_recs on the first i dtyps; presumably the
       number of recursive occurrences among them - TODO confirm in Derive_Aux *)
    fun rec_idx i dtys = dt_number_recs (take i dtys) |> fst
    (* recursor arguments for all constructors of the descriptor entry idx *)
    fun mk_rhss (idx,(_,_,cons)) =
      let
        (* recursor argument for the i-th constructor cname with argument dtyps dtysi *)
        fun mk_rhs (i,(cname,dtysi)) =
          let
            (* one variable x_0, x_1, ... per constructor argument
               (the i bound in the inner function shadows the constructor index) *)
            val lvars = map_index (fn (i,dty) => mk_free_i "x_" i (typ_of dty)) dtysi
            (* variable res_oc of type hashcode; stands for the result of a recursive call *)
            fun res_var (_,oc) = mk_free_i "res_" oc (@{typ hashcode});
            val res_vars = dt_number_recs dtysi
              |> snd
              |> map res_var
            val x = nth lvars
            (* sums up the second components of the list; the sum ends with the
               constant factor of this constructor (last index 0) *)
            fun combine_dts [] = mk_num (cons_hash cname idx i 0)
              | combine_dts ((_,c) :: ics) = @{term "(+) :: hashcode => hashcode => hashcode"} $ c $ combine_dts ics
            (* multiplies t with the factor of argument position j; since j+1 is
               used, the value 0 stays reserved for the constant in combine_dts.
               Here i still refers to the constructor index of mk_rhs. *)
            fun multiply j t =
              let
                val mult = mk_num (cons_hash cname idx i (j+1))
              in @{term "(*) :: hashcode => hashcode => hashcode"} $ mult $ t end
            (* hash of a single constructor argument; in both clauses i is the
               argument position and shadows the constructor index.
               Recursive arguments use the result variable of their recursive call,
               all other arguments are hashed by the constant hashcode. *)
            fun hash_of_dty (i,Old_Datatype_Aux.DtRec j) = res_var (j,rec_idx i dtysi) |> multiply i
              | hash_of_dty (i,_) =
                  let
                    val xi = x i
                    val ty = Term.type_of xi
                    val hash = Const (hash_name, ty --> @{typ hashcode}) $ xi
                  in hash |> multiply i end
            (* map_index I only pairs every summand with its position, because
               combine_dts expects pairs (and ignores the first component) *)
            val pre_rhs = map_index hash_of_dty dtysi
              |> map_index I
              |> combine_dts
            (* abstract over the argument variables followed by the result variables *)
            val rhs = fold lambda (rev (lvars @ res_vars)) pre_rhs
          in rhs end
        val rec_args = map_index (fn (i,c) => (mk_rhs (i,c),i,idx)) cons
      in rec_args end
    val nrec_args = maps mk_rhss descr
    (* only the terms are needed; the two indices of each triple are dropped *)
    val rec_args = map #1 nrec_args
    (* the i-th recursor applied to all recursor arguments *)
    fun mk_rec i =
      let
        val ty = typ_of (Old_Datatype_Aux.DtRec i)
        val rec_ty = map type_of rec_args @ [ty] ---> @{typ hashcode}
        val rec_name = nth rec_names i
        val rhs = list_comb (Const (rec_name, rec_ty), rec_args)
      in rhs end
  in mk_rec 0 end


(* derive dtyp_name _ thy
   instantiates class hashable for the datatype dtyp_name: hashcode is defined
   via mk_hash, def_hashmap_size via create_def_size, and the instance proof
   unfolds both definitions and finishes with the simplifier.
   The second argument is ignored.  Returns the resulting theory. *)
fun derive dtyp_name _ thy =
  let
    val short_name = Long_Name.base_name dtyp_name
    val _ = writeln ("creating hashcode for datatype " ^ short_name)
    val hashable = @{sort hashable}
    val info = BNF_LFP_Compat.the_info thy [] dtyp_name

    (* every entry of the first component of the datatype's spec gets sort hashable *)
    val tparams = #1 (BNF_LFP_Compat.the_spec thy dtyp_name)
    val vs = map (fn (a,_) => (a, hashable)) tparams

    (* dty is the first type argument of the type of the hash function, i.e. its domain *)
    val hash_rhs = mk_hash thy info
    val dty = hd (snd (Term.dest_Type (Term.fastype_of hash_rhs)))
    val dty_itself = Type (@{type_name itself}, [dty])
    val size_num = HOLogic.mk_number @{typ nat} (create_def_size dty)
    val size_rhs = lambda (Free ("x",dty_itself)) size_num

    val hash_def = mk_def (dty --> @{typ hashcode}) @{const_name hashcode} hash_rhs
    val size_def = mk_def (dty_itself --> @{typ nat}) @{const_name def_hashmap_size} size_rhs

    (* open the instantiation and add the two overloaded definitions in turn *)
    val lthy0 = Class.instantiation ([dtyp_name],vs,hashable) thy
    val (hash_thm,lthy1) =
      define_overloaded ("hashcode_" ^ short_name ^ "_def", hash_def) lthy0
    val (size_thm,lthy2) =
      define_overloaded ("def_hashmap_size_" ^ short_name ^ "_def", size_def) lthy1

    fun hash_tac ctxt =
      EVERY
        [my_print_tac ctxt "enter hash ",
         unfold_tac ctxt [hash_thm, size_thm],
         my_print_tac ctxt "after unfolding",
         simp_tac ctxt 1]
    fun instance_tac ctxt = Class.intro_classes_tac ctxt [] THEN hash_tac ctxt
    val thy' = Class.prove_instantiation_exit instance_tac lthy2
    val _ = writeln ("registered " ^ short_name ^ " in class hashable")

  in thy' end

(* registers derive from this structure with the derive manager
   under the name "hashable" *)
val _ =
  Theory.setup
    (Derive_Manager.register_derive "hashable" "derives a hash function for a datatype" derive)

end

Theory Derive

(*  Title:       Deriving class instances for datatypes
    Author:      René Thiemann       <rene.thiemann@uibk.ac.at>
    Maintainer:  René Thiemann
    License:     LGPL
*)

(*
Copyright 2013 René Thiemann

This file is part of IsaFoR/CeTA.

IsaFoR/CeTA is free software: you can redistribute it and/or modify it under the
terms of the GNU Lesser General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

IsaFoR/CeTA is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along
with IsaFoR/CeTA. If not, see <http://www.gnu.org/licenses/>.
*)

section ‹Loading derive-commands›
theory Derive
imports 
  Order_Generator
  Hash_Generator
  Deriving.Countable_Generator
begin

text‹
We just load the commands to derive (linear) orders, hash-functions, and the
command to show that a datatype is countable, so that now all of them are available.
There are further generators available in the AFP entries of lightweight containers and Show.
›

print_derives

end

Theory Derive_Examples

(*  Title:       Deriving class instances for datatypes
    Author:      René Thiemann       <rene.thiemann@uibk.ac.at>
    Maintainer:  René Thiemann
    License:     LGPL
*)

(*
Copyright 2013 René Thiemann

This file is part of IsaFoR/CeTA.

IsaFoR/CeTA is free software: you can redistribute it and/or modify it under the
terms of the GNU Lesser General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

IsaFoR/CeTA is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along
with IsaFoR/CeTA. If not, see <http://www.gnu.org/licenses/>.
*)

section Examples

theory Derive_Examples
imports 
  Derive
  HOL.Rat
begin

subsection "Register standard existing types"

derive linorder list sum prod

subsection "Without nested recursion"

datatype 'a bintree = BEmpty | BNode "'a bintree" 'a "'a bintree"

derive linorder bintree
derive hashable bintree
derive countable bintree

subsection "Using other datatypes"

datatype nat_list_list = NNil | CCons "nat list" nat_list_list

derive linorder nat_list_list
derive hashable nat_list_list
derive countable nat_list_list

subsection "Explicit mutual recursion"

datatype
  'a mtree = MEmpty | MNode 'a "'a mtree_list" and
  'a mtree_list = MNil | MCons "'a mtree" "'a mtree_list"

derive linorder mtree
derive hashable mtree
derive countable mtree

subsection "Implicit mutual recursion"

datatype 'a tree = Empty | Node 'a "'a tree list"

datatype_compat tree

derive linorder tree
derive hashable tree
derive countable tree

datatype 'a ttree = TEmpty | TNode 'a "'a ttree list tree"

datatype_compat ttree

derive linorder ttree
derive hashable ttree
derive countable ttree

subsection "Examples from IsaFoR"

datatype ('f,'v) "term" = Var 'v | Fun 'f "('f,'v) term list"

datatype_compat "term"

datatype ('f, 'l) lab =
  Lab "('f, 'l) lab" 'l
| FunLab "('f, 'l) lab" "('f, 'l) lab list"
| UnLab 'f
| Sharp "('f, 'l) lab"

datatype_compat lab

derive linorder "term" lab
derive countable "term" lab
derive hashable "term" lab

subsection "A complex datatype"
text ‹
The following datatype has nested indirect recursion, mutual recursion and
uses other datatypes.
›

datatype ('a, 'b) complex = 
  C1 nat "'a ttree" |
  C2 "('a, 'b) complex list tree tree" 'b "('a, 'b) complex" "('a, 'b) complex2 ttree list"
and ('a, 'b) complex2 = D1 "('a, 'b) complex ttree"

datatype_compat complex complex2

derive linorder complex
derive hashable complex
derive countable complex

end