Session Tycon

Theory TypeApp

section ‹Type Application›

theory TypeApp
imports HOLCF
begin

subsection ‹Class of type constructors›

text ‹In HOLCF, the type @{typ "udom defl"} consists of deflations
over the universal domain---each value of type @{typ "udom defl"}
represents a bifinite domain. In turn, values of the continuous
function type @{typ "udom defl → udom defl"} represent functions from
domains to domains, i.e.~type constructors.›

text ‹Class ‹tycon›, defined below, will be populated with
dummy types: For example, if the type ‹foo› is an instance of
class ‹tycon›, then users will never deal with any values
‹x::foo› in practice. Such types are only used with the overloaded
constant ‹tc›, which associates each type ‹'a::tycon›
with a value of type @{typ "udom defl → udom defl"}. \medskip›

(* Class of dummy types that stand for type constructors: each instance
   supplies tc, a continuous function on deflations, indexed by the
   instance type through an itself-argument. *)
class tycon =
  fixes tc :: "('a::type) itself ⇒ udom defl → udom defl"

text ‹Type @{typ "'a itself"} is defined in Isabelle's meta-logic;
it is inhabited by a single value, written @{term "TYPE('a)"}. We
define the syntax ‹TC('a)› to abbreviate ‹tc
TYPE('a)›. \medskip›

(* Notation TC('a) for tc TYPE('a); the translation works in both
   directions (parsing and printing). *)
syntax  "_TC" :: "type ⇒ logic"  ("(1TC/(1'(_')))")

translations "TC('a)" ⇌ "CONST tc TYPE('a)"


subsection ‹Type constructor for type application›

text ‹We now define a binary type constructor that models type
application: Type ‹('a, 't) app› is the result of applying the
type constructor ‹'t› (from class ‹tycon›) to the type
argument ‹'a› (from class ‹domain›).›

text ‹We define type ‹('a, 't) app› using ‹domaindef›,
a low-level type-definition command provided by HOLCF (similar to
‹typedef› in Isabelle/HOL) that defines a new domain type
represented by the given deflation. Note that in HOLCF, ‹DEFL('a)›
is an abbreviation for ‹defl TYPE('a)›, where
‹defl :: ('a::domain) itself ⇒ udom defl› is an overloaded
function from the ‹domain› type class that yields the deflation
representing the given type. \medskip›

(* ('a, 't) app is the domain represented by the deflation obtained by
   applying the type constructor's tc to the deflation of 'a. *)
domaindef ('a,'t) app = "TC('t::tycon)⋅DEFL('a::domain)"

text ‹We define the infix syntax ‹'a⋅'t› for the type
‹('a,'t) app›. Note that for consistency with Isabelle's existing
type syntax, we have used postfix order for type application: type
argument on the left, type constructor on the right. \medskip›

(* Infix type syntax 'a⋅'t for ('a, 't) app: argument on the left,
   type constructor on the right. *)
type_notation app ("(_⋅_)" [999,1000] 999)

text ‹The ‹domaindef› command generates the theorem ‹DEFL_app›:
@{thm DEFL_app [where 'a="'a::domain" and 't="'t::tycon"]},
which we can use to derive other useful lemmas. \medskip›

(* DEFL_app read right-to-left: folds an applied tc back into the
   deflation of the app type. *)
lemma TC_DEFL: "TC('t::tycon)⋅DEFL('a) = DEFL('a⋅'t)"
by (rule DEFL_app [symmetric])

(* Type application is monotone in its argument with respect to the
   ordering on deflations. *)
lemma DEFL_app_mono [simp, intro]:
  "DEFL('a) ⊑ DEFL('b) ⟹ DEFL('a⋅'t::tycon) ⊑ DEFL('b⋅'t)"
 apply (simp add: DEFL_app)
 apply (erule monofun_cfun_arg)
done

end

Theory Coerce

section ‹Coercion Operator›

theory Coerce
imports HOLCF
begin

subsection ‹Coerce›

text ‹The ‹domain› type class, which is the default type class
in HOLCF, fixes two overloaded functions: ‹emb::'a → udom› and
‹prj::udom → 'a›. By composing the ‹prj› and ‹emb›
functions together, we can coerce values between any two types in
class ‹domain›. \medskip›

(* Coercion between two domain types: embed into the universal domain,
   then project out at the target type. *)
definition coerce :: "'a → 'b"
  where "coerce ≡ prj oo emb"

text ‹When working with proofs involving ‹emb›, ‹prj›,
and ‹coerce›, it is often difficult to tell at which types those
constants are being used. To alleviate this problem, we define special
input and output syntax to indicate the types. \medskip›

(* Type-annotated input syntax EMB('a), PRJ('a), COERCE('a,'b).
   The translations are parse-only; printing is handled by the
   typed print translation that follows. *)
syntax
  "_emb" :: "type ⇒ logic" ("(1EMB/(1'(_')))")
  "_prj" :: "type ⇒ logic" ("(1PRJ/(1'(_')))")
  "_coerce" :: "type ⇒ type ⇒ logic" ("(1COERCE/(1'(_,/ _')))")

translations
  "EMB('a)" ⇀ "CONST emb :: 'a → udom"
  "PRJ('a)" ⇀ "CONST prj :: udom → 'a"
  "COERCE('a,'b)" ⇀ "CONST coerce :: 'a → 'b"

(* Output syntax: print unapplied emb, prj and coerce as EMB(T),
   PRJ(T) and COERCE(T, U), reading T and U off the two type arguments
   of the constant's type.  The patterns only match the constant with
   no arguments ([]). *)
typed_print_translation ‹
let
  fun emb_tr' (ctxt : Proof.context) (Type(_, [T, _])) [] =
    Syntax.const @{syntax_const "_emb"} $ Syntax_Phases.term_of_typ ctxt T
  fun prj_tr' ctxt (Type(_, [_, T])) [] =
    Syntax.const @{syntax_const "_prj"} $ Syntax_Phases.term_of_typ ctxt T
  fun coerce_tr' ctxt (Type(_, [T, U])) [] =
    Syntax.const @{syntax_const "_coerce"} $
      Syntax_Phases.term_of_typ ctxt T $ Syntax_Phases.term_of_typ ctxt U
in
  [(@{const_syntax emb}, emb_tr'),
   (@{const_syntax prj}, prj_tr'),
   (@{const_syntax coerce}, coerce_tr')]
end
›

(* Unfolding coerce at an argument. *)
lemma beta_coerce: "coerce⋅x = prj⋅(emb⋅x)"
by (simp add: coerce_def)

(* The same equation oriented to fold prj-after-emb into coerce. *)
lemma prj_emb: "prj⋅(emb⋅x) = coerce⋅x"
by (simp add: coerce_def)

(* coerce is strict. *)
lemma coerce_strict [simp]: "coerce⋅⊥ = ⊥"
by (simp add: coerce_def)

text ‹Certain type instances of ‹coerce› may reduce to the
identity function, ‹emb›, or ‹prj›. \medskip›

(* Coercing a type to itself is the identity. *)
lemma coerce_eq_ID [simp]: "COERCE('a, 'a) = ID"
by (rule cfun_eqI, simp add: beta_coerce)

(* Coercing into udom is emb. *)
lemma coerce_eq_emb [simp]: "COERCE('a, udom) = EMB('a)"
by (rule cfun_eqI, simp add: beta_coerce)

(* Coercing out of udom is prj. *)
lemma coerce_eq_prj [simp]: "COERCE(udom, 'a) = PRJ('a)"
by (rule cfun_eqI, simp add: beta_coerce)

text "Cancellation rules"

(* Cancellation rules; each requires that the deflation of 'a lies
   below the deflation of 'b. *)
lemma emb_coerce:
  "DEFL('a) ⊑ DEFL('b)
   ⟹ EMB('b)⋅(COERCE('a,'b)⋅x) = EMB('a)⋅x"
by (simp add: beta_coerce emb_prj_emb)

lemma coerce_prj:
  "DEFL('a) ⊑ DEFL('b)
   ⟹ COERCE('b,'a)⋅(PRJ('b)⋅x) = PRJ('a)⋅x"
by (simp add: beta_coerce prj_emb_prj)

lemma coerce_idem [simp]:
  "DEFL('a) ⊑ DEFL('b)
   ⟹ COERCE('b,'c)⋅(COERCE('a,'b)⋅x) = COERCE('a,'c)⋅x"
by (simp add: beta_coerce emb_prj_emb)

subsection ‹More lemmas about emb and prj›

(* Casting by a type's own deflation is absorbed by that type's prj ... *)
lemma prj_cast_DEFL [simp]: "PRJ('a)⋅(cast⋅DEFL('a)⋅x) = PRJ('a)⋅x"
by (simp add: cast_DEFL)

(* ... and leaves values in the image of that type's emb unchanged. *)
lemma cast_DEFL_emb [simp]: "cast⋅DEFL('a)⋅(EMB('a)⋅x) = EMB('a)⋅x"
by (simp add: cast_DEFL)

text ‹@{term "DEFL(udom)"}›

(* Every deflation lies below the deflation of udom. *)
lemma below_DEFL_udom [simp]: "A ⊑ DEFL(udom)"
apply (rule cast_below_imp_below)
apply (rule cast.belowI)
apply (simp add: cast_DEFL)
done

subsection ‹Coercing various datatypes›

text ‹Coercing from the strict product type @{typ "'a ⊗ 'b"} to
another strict product type @{typ "'c ⊗ 'd"} is equivalent to mapping
the ‹coerce› function separately over each component using
‹sprod_map :: ('a → 'c) → ('b → 'd) → 'a ⊗ 'b → 'c ⊗ 'd›. Each
of the several type constructors defined in HOLCF satisfies a similar
property, with respect to its own map combinator. \medskip›

(* For each HOLCF type constructor, coerce at that type equals the
   constructor's map combinator applied to coerce on the components. *)

(* lifting *)
lemma coerce_u: "coerce = u_map⋅coerce"
apply (rule cfun_eqI, simp add: coerce_def)
apply (simp add: emb_u_def prj_u_def liftemb_eq liftprj_eq)
apply (subst ep_pair.e_inverse [OF ep_pair_u])
apply (simp add: u_map_map cfcomp1)
done

(* strict function space *)
lemma coerce_sfun: "coerce = sfun_map⋅coerce⋅coerce"
apply (rule cfun_eqI, simp add: coerce_def)
apply (simp add: emb_sfun_def prj_sfun_def)
apply (subst ep_pair.e_inverse [OF ep_pair_sfun])
apply (simp add: sfun_map_map cfcomp1)
done

(* continuous function space; proved via coerce_sfun and coerce_u *)
lemma coerce_cfun': "coerce = cfun_map⋅coerce⋅coerce"
apply (rule cfun_eqI, simp add: prj_emb [symmetric])
apply (simp add: emb_cfun_def prj_cfun_def)
apply (simp add: prj_emb coerce_sfun coerce_u)
apply (simp add: encode_cfun_map [symmetric])
done

(* strict sum *)
lemma coerce_ssum: "coerce = ssum_map⋅coerce⋅coerce"
apply (rule cfun_eqI, simp add: coerce_def)
apply (simp add: emb_ssum_def prj_ssum_def)
apply (subst ep_pair.e_inverse [OF ep_pair_ssum])
apply (simp add: ssum_map_map cfcomp1)
done

(* strict product *)
lemma coerce_sprod: "coerce = sprod_map⋅coerce⋅coerce"
apply (rule cfun_eqI, simp add: coerce_def)
apply (simp add: emb_sprod_def prj_sprod_def)
apply (subst ep_pair.e_inverse [OF ep_pair_sprod])
apply (simp add: sprod_map_map cfcomp1)
done

(* cartesian product *)
lemma coerce_prod: "coerce = prod_map⋅coerce⋅coerce"
apply (rule cfun_eqI, simp add: coerce_def)
apply (simp add: emb_prod_def prj_prod_def)
apply (subst ep_pair.e_inverse [OF ep_pair_prod])
apply (simp add: prod_map_map cfcomp1)
done

subsection ‹Simplifying coercions›

text ‹When simplifying applications of the ‹coerce› function,
rewrite rules are always oriented to replace ‹coerce› at complex
types with other applications of ‹coerce› at simpler types.›

text ‹The safest rewrite rules for ‹coerce› are given the
‹[simp]› attribute. For other rules that do not belong in the
global simpset, we use a dynamic theorem list called ‹coerce_simp›,
which will collect additional rules for simplifying coercions. \medskip›

(* Dynamic theorem list; lemmas join it through the [coerce_simp]
   attribute and are passed to simp explicitly where needed. *)
named_theorems coerce_simp "rule for simplifying coercions"

text ‹The ‹coerce› function commutes with data constructors
for various HOLCF datatypes. \medskip›

(* coerce pushes through the data constructors of the HOLCF types. *)
lemma coerce_up [simp]: "coerce⋅(up⋅x) = up⋅(coerce⋅x)"
by (simp add: coerce_u)

lemma coerce_sinl [simp]: "coerce⋅(sinl⋅x) = sinl⋅(coerce⋅x)"
by (simp add: coerce_ssum ssum_map_sinl')

lemma coerce_sinr [simp]: "coerce⋅(sinr⋅x) = sinr⋅(coerce⋅x)"
by (simp add: coerce_ssum ssum_map_sinr')

lemma coerce_spair [simp]: "coerce⋅(:x, y:) = (:coerce⋅x, coerce⋅y:)"
by (simp add: coerce_sprod sprod_map_spair')

lemma coerce_Pair [simp]: "coerce⋅(x, y) = (coerce⋅x, coerce⋅y)"
by (simp add: coerce_prod)

(* A coerced function applied to x: coerce the argument, apply f,
   coerce the result. *)
lemma beta_coerce_cfun [simp]: "coerce⋅f⋅x = coerce⋅(f⋅(coerce⋅x))"
by (simp add: coerce_cfun')

(* Point-free form of the rule above. *)
lemma coerce_cfun: "coerce⋅f = coerce oo f oo coerce"
by (simp add: cfun_eqI)

(* Lambda form; kept out of the global simpset and collected in
   coerce_simp instead. *)
lemma coerce_cfun_app [coerce_simp]:
  "coerce⋅f = (Λ x. coerce⋅(f⋅(coerce⋅x)))"
by (simp add: cfun_eqI)

end

Theory Functor

section ‹Functor Class›

theory Functor
imports TypeApp Coerce
keywords "tycondef" :: thy_defn and "⋅"
begin

subsection ‹Class definition›

text ‹Here we define the ‹functor› class, which models the
Haskell class \texttt{Functor}. For technical reasons, we split the
definition of ‹functor› into two separate classes: First, we
introduce ‹prefunctor›, which only requires ‹fmap› to
preserve the identity function, and not function composition.›

text ‹The Haskell class \texttt{Functor f} fixes a polymorphic
function \texttt{fmap :: (a -> b) -> f a -> f b}. Since functions in
Isabelle type classes can only mention one type variable, we have the
‹prefunctor› class fix a function ‹fmapU› that fixes both
of the polymorphic types to be the universal domain. We will use the
coercion operator to recover a polymorphic ‹fmap›.›

text ‹The single axiom of the ‹prefunctor› class is stated in
terms of the HOLCF constant ‹isodefl›, which relates a function
‹f :: 'a → 'a› with a deflation ‹t :: udom defl›:
@{thm isodefl_def [of f t, no_vars]}.›

(* A tycon with a map function fmapU at the universal domain.  The
   axiom ties fmapU applied to a cast to the deflation that the type
   constructor assigns. *)
class prefunctor = "tycon" +
  fixes fmapU :: "(udom → udom) → udom⋅'a → udom⋅'a::tycon"
  assumes isodefl_fmapU:
    "isodefl (fmapU⋅(cast⋅t)) (TC('a::tycon)⋅t)"

text ‹The ‹functor› class extends ‹prefunctor› with an
axiom stating that ‹fmapU› preserves composition.›

(* prefunctor plus the composition law for fmapU; the law is also
   added to the coerce_simp list. *)
class "functor" = prefunctor +
  assumes fmapU_fmapU [coerce_simp]:
    "⋀f g (xs::udom⋅'a::tycon).
      fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"

text ‹We define the polymorphic ‹fmap› by coercion from ‹fmapU›,
then we proceed to derive the polymorphic versions of the
functor laws.›

(* Polymorphic fmap, obtained by coercing the udom-typed fmapU. *)
definition fmap :: "('a → 'b) → 'a⋅'f → 'b⋅'f::functor"
  where "fmap = coerce⋅(fmapU :: _ → udom⋅'f → udom⋅'f)"

subsection ‹Polymorphic functor laws›

(* At the universal domain, fmapU and fmap coincide. *)
lemma fmapU_eq_fmap: "fmapU = fmap"
by (simp add: fmap_def eta_cfun)

lemma fmap_eq_fmapU: "fmap = fmapU"
  by (simp only: fmapU_eq_fmap)

(* The prefunctor axiom with isodefl unfolded. *)
lemma cast_TC:
  "cast⋅(TC('f)⋅t) = emb oo fmapU⋅(cast⋅t) oo PRJ(udom⋅'f::prefunctor)"
by (rule isodefl_fmapU [unfolded isodefl_def])

lemma isodefl_cast: "isodefl (cast⋅t) t"
by (simp add: isodefl_def)

(* Composing casts of comparable deflations yields the cast of the
   smaller one, in either order. *)
lemma cast_cast_below1: "A ⊑ B ⟹ cast⋅A⋅(cast⋅B⋅x) = cast⋅A⋅x"
by (intro deflation_below_comp1 deflation_cast monofun_cfun_arg)

lemma cast_cast_below2: "A ⊑ B ⟹ cast⋅B⋅(cast⋅A⋅x) = cast⋅A⋅x"
by (intro deflation_below_comp2 deflation_cast monofun_cfun_arg)

(* Polymorphic version of the prefunctor axiom: if d is represented by
   deflation t, then fmap⋅d is represented by TC('f)⋅t.  Later
   registered as a domain_isodefl rule. *)
lemma isodefl_fmap:
  assumes "isodefl d t"
  shows "isodefl (fmap⋅d :: 'a⋅'f → _) (TC('f::functor)⋅t)"
proof -
  have deflation_d: "deflation d"
    using assms by (rule isodefl_imp_deflation)
  have cast_t: "cast⋅t = emb oo d oo prj"
    using assms unfolding isodefl_def .
  have t_below: "t ⊑ DEFL('a)"
    apply (rule cast_below_imp_below)
    apply (simp only: cast_t cast_DEFL)
    apply (simp add: cfun_below_iff deflation.below [OF deflation_d])
    done
  have fmap_eq: "fmap⋅d = PRJ('a⋅'f) oo cast⋅(TC('f)⋅t) oo emb"
    by (simp add: fmap_def coerce_cfun cast_TC cast_t prj_emb cfcomp1)
  show ?thesis
    apply (simp add: fmap_eq isodefl_def cfun_eq_iff emb_prj)
    apply (simp add: DEFL_app)
    apply (simp add: cast_cast_below1 monofun_cfun t_below)
    apply (simp add: cast_cast_below2 monofun_cfun t_below)
    done
qed

(* Functor identity law. *)
lemma fmap_ID [simp]: "fmap⋅ID = ID"
apply (rule isodefl_DEFL_imp_ID)
apply (subst DEFL_app)
apply (rule isodefl_fmap)
apply (rule isodefl_ID_DEFL)
done

(* The same law with the identity written as a lambda. *)
lemma fmap_ident [simp]: "fmap⋅(Λ x. x) = ID"
by (simp add: ID_def [symmetric])

(* A round trip from udom⋅'f through 'a⋅'f equals fmapU of the cast by
   DEFL('a). *)
lemma coerce_coerce_eq_fmapU_cast [coerce_simp]:
  fixes xs :: "udom⋅'f::functor"
  shows "COERCE('a⋅'f, udom⋅'f)⋅(COERCE(udom⋅'f, 'a⋅'f)⋅xs) =
    fmapU⋅(cast⋅DEFL('a))⋅xs"
by (simp add: coerce_def emb_prj DEFL_app cast_TC)

(* Functor composition law, applied form. *)
lemma fmap_fmap:
  fixes xs :: "'a⋅'f::functor" and g :: "'a → 'b" and f :: "'b → 'c"
  shows "fmap⋅f⋅(fmap⋅g⋅xs) = fmap⋅(Λ x. f⋅(g⋅x))⋅xs"
unfolding fmap_def
by (simp add: coerce_simp)

(* Functor composition law, point-free form. *)
lemma fmap_cfcomp: "fmap⋅(f oo g) = fmap⋅f oo fmap⋅g"
by (simp add: cfcomp1 fmap_fmap eta_cfun)

subsection ‹Derived properties of fmap›

text ‹Other theorems about ‹fmap› can be derived using only
the abstract functor laws.›

(* fmap preserves deflations; later registered as a domain_deflation
   rule. *)
lemma deflation_fmap:
  "deflation d ⟹ deflation (fmap⋅d)"
 apply (rule deflation.intro)
  apply (simp add: fmap_fmap deflation.idem eta_cfun)
 apply (subgoal_tac "fmap⋅d⋅x ⊑ fmap⋅ID⋅x", simp)
 apply (rule monofun_cfun_fun, rule monofun_cfun_arg)
 apply (erule deflation.below_ID)
done

(* fmap preserves embedding-projection pairs. *)
lemma ep_pair_fmap:
  "ep_pair e p ⟹ ep_pair (fmap⋅e) (fmap⋅p)"
 apply (rule ep_pair.intro)
  apply (simp add: fmap_fmap ep_pair.e_inverse)
 apply (simp add: fmap_fmap)
 apply (rule_tac y="fmap⋅ID⋅y" in below_trans)
  apply (rule monofun_cfun_fun)
  apply (rule monofun_cfun_arg)
  apply (rule cfun_belowI, simp)
  apply (erule ep_pair.e_p_below)
 apply simp
done

(* fmap of a strict function is strict.  The calculation bounds
   fmap⋅f⋅⊥ above by fmap⋅ID⋅⊥, which is ⊥. *)
lemma fmap_strict:
  fixes f :: "'a → 'b"
  assumes "f⋅⊥ = ⊥" shows "fmap⋅f⋅⊥ = (⊥::'b⋅'f::functor)"
proof (rule bottomI)
  have "fmap⋅f⋅(⊥::'a⋅'f) ⊑ fmap⋅f⋅(fmap⋅⊥⋅(⊥::'b⋅'f))"
    by (simp add: monofun_cfun)
  also have "... = fmap⋅(Λ x. f⋅(⊥⋅x))⋅(⊥::'b⋅'f)"
    by (simp add: fmap_fmap)
  also have "... ⊑ fmap⋅ID⋅⊥"
    by (simp add: monofun_cfun assms del: fmap_ID)
  also have "... = ⊥"
    by simp
  finally show "fmap⋅f⋅⊥ ⊑ (⊥::'b⋅'f::functor)" .
qed

subsection ‹Proving that fmap⋅coerce = coerce›

(* fmapU of a cast, expressed through the cast of the applied type
   constructor. *)
lemma fmapU_cast_eq:
  "fmapU⋅(cast⋅A) =
    PRJ(udom⋅'f) oo cast⋅(TC('f::functor)⋅A) oo emb"
by (subst cast_TC, rule cfun_eqI, simp)

(* Instance of the above at A = DEFL('a), folded with DEFL_app. *)
lemma fmapU_cast_DEFL:
  "fmapU⋅(cast⋅DEFL('a)) =
    PRJ(udom⋅'f) oo cast⋅DEFL('a⋅'f::functor) oo emb"
by (simp add: fmapU_cast_eq DEFL_app)

(* Coercion between two applications of the same functor is fmap of
   coerce. *)
lemma coerce_functor: "COERCE('a⋅'f, 'b⋅'f::functor) = fmap⋅coerce"
apply (rule cfun_eqI, rename_tac xs)
apply (simp add: fmap_def coerce_cfun)
apply (simp add: coerce_def)
apply (simp add: cfcomp1)
apply (simp only: emb_prj)
apply (subst fmapU_fmapU [symmetric])
apply (simp add: fmapU_cast_DEFL)
apply (simp add: emb_prj)
apply (simp add: cast_cast_below1 cast_cast_below2)
done

subsection ‹Lemmas for reasoning about coercion›

(* fmapU of the cast by DEFL('a) is absorbed by a coercion out of
   'a⋅'f. *)
lemma fmapU_cast_coerce [coerce_simp]:
  fixes m :: "'a⋅'f::functor"
  shows "fmapU⋅(cast⋅DEFL('a))⋅(COERCE('a⋅'f, udom⋅'f)⋅m) =
    COERCE('a⋅'f, udom⋅'f)⋅m"
by (simp add: coerce_functor cast_DEFL fmapU_eq_fmap fmap_fmap eta_cfun)

(* A coercion applied after fmap moves inside the mapped function. *)
lemma coerce_fmap [coerce_simp]:
  fixes xs :: "'a⋅'f::functor" and f :: "'a → 'b"
  shows "COERCE('b⋅'f, 'c⋅'f)⋅(fmap⋅f⋅xs) = fmap⋅(Λ x. COERCE('b,'c)⋅(f⋅x))⋅xs"
by (simp add: coerce_functor fmap_fmap)

(* A coercion applied before fmap moves inside the mapped function. *)
lemma fmap_coerce [coerce_simp]:
  fixes xs :: "'a⋅'f::functor" and f :: "'b → 'c"
  shows "fmap⋅f⋅(COERCE('a⋅'f, 'b⋅'f)⋅xs) = fmap⋅(Λ x. f⋅(COERCE('a,'b)⋅x))⋅xs"
by (simp add: coerce_functor fmap_fmap)

subsection ‹Configuration of Domain package›

text ‹We make various theorem declarations to enable Domain
  package definitions that involve ‹tycon› application.›

(* Register the app type with the Domain package, passing one boolean
   flag per type argument ([true, false]). *)
setup ‹Domain_Take_Proofs.add_rec_type (@{type_name app}, [true, false])›

(* Add the app/fmap lemmas to the Domain package's rule collections. *)
declare DEFL_app [domain_defl_simps]
declare fmap_ID [domain_map_ID]
declare deflation_fmap [domain_deflation]
declare isodefl_fmap [domain_isodefl]

subsection ‹Configuration of the Tycon package›

text ‹We now set up a new type definition command, which is used for
  defining new ‹tycon› instances. The ‹tycondef› command
  is implemented using much of the same code as the Domain package,
  and supports a similar input syntax. It automatically generates a
  ‹prefunctor› instance for each new type. (The user must
  provide a proof of the composition law to obtain a ‹functor›
  class instance.)›

ML_file ‹tycondef.ML›

end

File ‹tycondef.ML›

(* Version: Isabelle2012 *)

(* Interface of the tycondef implementation.  Both entry points take a
   list of specifications and transform the theory.  add_tycon_cmd
   receives sorts and types as unparsed strings; add_tycon receives
   them as sort and typ values.
   NOTE(review): the meaning of the individual tuple components is
   fixed by the implementations, which are not shown here. *)
signature TYCON =
sig
  val add_tycon_cmd:
      (string * (string * string option) list * binding * mixfix *
       (binding * (bool * binding option * string) list * mixfix) list) list
      -> theory -> theory

  val add_tycon:
      (string * (string * sort) list * binding * mixfix *
       (binding * (bool * binding option * typ) list * mixfix) list) list
      -> theory -> theory
end

structure Tycon : TYCON =
struct

(* Turns a meta-equation between two functions on type tokens into an
   object-level equation at TYPE('a). *)
val TC_simp =
  @{lemma "f ≡ g ⟹ f TYPE('a::{}) = g TYPE('a)" by simp}

(* Builds the type (argT, tyconT) app. *)
fun mk_appT argT tyconT = Type (@{type_name app}, [argT, tyconT])

(* Splits an app type into its two arguments; raises TYPE for any
   other type. *)
fun dest_appT (Type (@{type_name app}, [argT, tyconT])) = (argT, tyconT)
  | dest_appT other = raise TYPE ("dest_appT", [other], [])

open HOLCF_Library

(* Selectors for the first and second component of a triple. *)
fun first (a, _, _) = a
fun second (_, b, _) = b

(* Minimal simpset: HOL_basic_ss extended with simp_thms and the
   beta_cfun_proc simproc.  Used by add_fixdefs to discharge the
   applied forms of the new definitions. *)
val beta_ss =
  simpset_of (put_simpset HOL_basic_ss @{context}
    addsimps @{thms simp_thms} addsimprocs [@{simproc beta_cfun_proc}])

(* True iff the type has sort cpo in the given theory. *)
fun is_cpo thy ty = Sign.of_sort thy (ty, @{sort cpo})

(******************************************************************************)
(****************************** defining tycons *******************************)
(******************************************************************************)

(* Introduces a new type by typedef over the set UNIV :: unit set,
   proving non-emptiness with UNIV_witness.  `typ` gives the binding,
   the type parameters with their sorts, and the mixfix syntax.  Writes
   a "Defining type ..." message and returns the typedef result paired
   with the updated theory. *)
fun define_singleton_type
    (typ : binding * (string * sort) list * mixfix)
    (thy : theory) : (string * Typedef.info) * theory =
  let
    val set = @{term "UNIV :: unit set"}
    (* NONE: no custom morphism names are supplied *)
    val opt_morphs = NONE
    fun tac ctxt = resolve_tac ctxt [UNIV_witness] 1
    val _ = writeln ("Defining type " ^ Binding.print (first typ))
  in
    thy
    |> Named_Target.theory_map_result (apsnd o Typedef.transform_info)
      (Typedef.add_typedef {overloaded = false} typ set opt_morphs tac)
  end    

(******************************************************************************)
(************************** building types and terms **************************)
(******************************************************************************)

(* Fixities for the type-building operators ->> and -->>.  The
   operators themselves are not defined in this block -- presumably
   they come from the opened HOLCF_Library; confirm. *)
infixr 6 ->>
infixr -->>

(* Frequently used types: the universal domain, deflations over it,
   and deflations over the lifted universal domain. *)
val udomT = @{typ udom}
val deflT = @{typ "udom defl"}
val udeflT = @{typ "udom u defl"}

(* Term for the constant defl applied to the type token of T. *)
fun mk_DEFL T =
  let val defl = Const (@{const_name defl}, Term.itselfT T --> deflT)
  in defl $ Logic.mk_type T end

(* Recovers the type from a defl application; raises TERM otherwise. *)
fun dest_DEFL (Const (@{const_name defl}, _) $ tok) = Logic.dest_type tok
  | dest_DEFL other = raise TERM ("dest_DEFL", [other])

(* Term for the constant liftdefl applied to the type token of T. *)
fun mk_LIFTDEFL T =
  let val liftdefl = Const (@{const_name liftdefl}, Term.itselfT T --> udeflT)
  in liftdefl $ Logic.mk_type T end

(* Recovers the type from a liftdefl application; raises TERM
   otherwise. *)
fun dest_LIFTDEFL (Const (@{const_name liftdefl}, _) $ tok) = Logic.dest_type tok
  | dest_LIFTDEFL other = raise TERM ("dest_LIFTDEFL", [other])

(* The constant tc at instance type T. *)
fun tc_const T =
  let val tcT = Term.itselfT T --> deflT ->> deflT
  in Const (@{const_name tc}, tcT) end

(* tc applied to the type token of T. *)
fun mk_TC T = tc_const T $ Logic.mk_type T

(* Type arguments of a type.  For an app type whose second argument is
   itself a type constructor application, the result is that
   constructor's arguments followed by the app's first argument.  For
   any other constructor application it is the constructor's argument
   list.  Anything else yields []. *)
fun argumentTs (Type (@{type_name app}, [argT, Type (_, tyconArgs)])) = tyconArgs @ [argT]
  | argumentTs (Type (_, args)) = args
  | argumentTs _ = []

(* Continuous application of the constant u_defl to a term. *)
fun mk_u_defl defl = mk_capply (@{const "u_defl"}, defl)

(* emb and prj at type ty, and their composition prj oo emb going from
   srcT to dstT. *)
fun emb_const ty = Const (@{const_name emb}, ty ->> udomT)
fun prj_const ty = Const (@{const_name prj}, udomT ->> ty)
fun coerce_const (srcT, dstT) = mk_cfcomp (prj_const dstT, emb_const srcT)

(* The constant isodefl at type T. *)
fun isodefl_const T =
  Const (@{const_name isodefl}, [T ->> T, deflT] ---> HOLogic.boolT)

(* The constant isodefl' at type T; its second argument is a deflation
   over the lifted universal domain. *)
fun isodefl'_const T =
  Const (@{const_name isodefl'}, [T ->> T, udeflT] ---> HOLogic.boolT)

(* The proposition "deflation tm" (without Trueprop). *)
fun mk_deflation tm =
  let val predT = Term.fastype_of tm --> boolT
  in Const (@{const_name deflation}, predT) $ tm end

(* Splits a term of the form Trueprop (lhs = rhs) into the pair
   (lhs, rhs), using HOLogic.dest_Trueprop and HOLogic.dest_eq. *)
fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t)

(* Inverse direction: builds Trueprop (t = u) from the pair (t, u). *)
fun mk_eqs (t, u) = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, u))

(******************************************************************************)
(****************************** isomorphism info ******************************)
(******************************************************************************)

(* Instantiates the theorem deflation_abs_rep with the abs_inverse and
   rep_inverse facts of an isomorphism, then normalises schematic
   variable indexes. *)
fun deflation_abs_rep (info : Domain_Take_Proofs.iso_info) : thm =
  Drule.zero_var_indexes
    (@{thm deflation_abs_rep} OF [#abs_inverse info, #rep_inverse info])

(******************************************************************************)
(*************** fixed-point definitions and unfolding theorems ***************)
(******************************************************************************)

(* Pairs each name with the matching component of a right-nested tuple
   term: every name but the last takes mk_fst of the remaining tuple,
   and the last name takes whatever remains. *)
fun mk_projs [] _ = []
  | mk_projs [name] tup = [(name, tup)]
  | mk_projs (name :: rest) tup =
      (name, mk_fst tup) :: mk_projs rest (mk_snd tup)

(* Defines a family of constants as the components of one joint fixed
   point.

   spec: one (binding, equation) pair per constant; each equation is
   split into lhs and rhs with dest_eqs.

   Steps: build the functional mapping the tuple of left-hand sides to
   the tuple of right-hand sides; take its fixed point; define each
   constant ("<name>_def") as the matching projection of that fixed
   point; prove the applied form of each definition; prove continuity
   of the functional; derive and register one unfolding theorem per
   constant ("<name>_unfold").

   Returns ((proj_thms, unfold_thms, cont_thm), thy). *)
fun add_fixdefs
    (spec : (binding * term) list)
    (thy : theory) : (thm list * thm list * thm) * theory =
  let
    val binds = map fst spec
    val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec)
    val functional = lambda_tuple lhss (mk_tuple rhss)
    val fixpoint = mk_fix (mk_cabs functional)

    (* project components of fixpoint *)
    val projs = mk_projs lhss fixpoint

    (* convert parameters to lambda abstractions *)
    (* Peels arguments off the lhs -- continuous application to a free
       variable, or application to a type token -- until a bare
       constant remains; any other shape raises TERM. *)
    fun mk_eqn (lhs, rhs) =
        case lhs of
          Const (@{const_name Rep_cfun}, _) $ f $ (x as Free _) =>
            mk_eqn (f, big_lambda x rhs)
        | f $ Const (@{const_name Pure.type}, T) =>
            mk_eqn (f, Abs ("t", T, rhs))
        | Const _ => Logic.mk_equals (lhs, rhs)
        | _ => raise TERM ("lhs not of correct form", [lhs, rhs])
    val eqns = map mk_eqn projs

    (* register constant definitions *)
    val (fixdef_thms, thy) =
      (Global_Theory.add_defs false o map Thm.no_attributes)
        (map (Binding.suffix_name "_def") binds ~~ eqns) thy

    (* prove applied version of definitions *)
    fun prove_proj (lhs, rhs) =
      let
        fun tac ctxt = rewrite_goals_tac ctxt fixdef_thms THEN
          simp_tac (put_simpset beta_ss ctxt) 1
        val goal = Logic.mk_equals (lhs, rhs)
      in Goal.prove_global thy [] [] goal (tac o #context) end
    val proj_thms = map prove_proj projs

    (* mk_tuple lhss == fixpoint *)
    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2]
    val tuple_fixdef_thm = foldr1 pair_equalI proj_thms

    (* continuity of the functional, proved by matching against the
       cont2cont rules in reverse order *)
    val cont_thm =
      let
        val ctxt = Proof_Context.init_global thy
        val prop = mk_trp (mk_cont functional)
        val rules = Named_Theorems.get ctxt @{named_theorems cont2cont}
        val tac = REPEAT_ALL_NEW (match_tac ctxt (rev rules)) 1
      in
        Goal.prove_global thy [] [] prop (K tac)
      end

    (* unfolding equation for the whole tuple *)
    val tuple_unfold_thm =
      (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm])
      |> Local_Defs.unfold (Proof_Context.init_global thy) @{thms split_conv}

    (* split the tuple equation into one equation per constant *)
    fun mk_unfold_thms [] _ = []
      | mk_unfold_thms (n::[]) thm = [(n, thm)]
      | mk_unfold_thms (n::ns) thm = let
          val thmL = thm RS @{thm Pair_eqD1}
          val thmR = thm RS @{thm Pair_eqD2}
        in (n, thmL) :: mk_unfold_thms ns thmR end
    val unfold_binds = map (Binding.suffix_name "_unfold") binds

    (* register unfold theorems *)
    val (unfold_thms, thy) =
      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
        (mk_unfold_thms unfold_binds tuple_unfold_thm) thy
  in
    ((proj_thms, unfold_thms, cont_thm), thy)
  end

(******************************************************************************)
(****************** deflation combinators and map functions *******************)
(******************************************************************************)

(* Computes a deflation term for type T by rewriting DEFL(T).  The
   rewrite rules are the equations registered as domain_defl_simps,
   followed by the caller-supplied rules'.  Two procedures handle type
   variables: DEFL('a) becomes the free variable "da" and
   LIFTDEFL('a) becomes the free variable "pa". *)
fun defl_of_typ
    (thy : theory)
    (rules' : (term * term) list)
    (T : typ) : term =
  let
    val simp_eqns =
      Global_Theory.get_thms thy "domain_defl_simps"
      |> map (HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.concl_of)
    (* Shared shape of both procedures: destruct the term, and if the
       type found is a type variable, return a free variable named
       after it; a TERM exception from the destructor means "no match". *)
    fun tfree_proc dest prefix varT t =
      (case dest t of
        TFree (a, _) => SOME (Free (prefix ^ Library.unprefix "'" a, varT))
      | _ => NONE) handle TERM _ => NONE
    val procs =
      [tfree_proc dest_DEFL "d" deflT, tfree_proc dest_LIFTDEFL "p" udeflT]
  in
    Pattern.rewrite_term thy (simp_eqns @ rules') procs (mk_DEFL T)
  end

(******************************************************************************)
(********************* declaring definitions and theorems *********************)
(******************************************************************************)

(* Declares a new global constant named bind, with the type of rhs and
   no syntax, and adds the definition "bind == rhs" under the name
   "<bind>_def".  Returns the constant and its definitional theorem
   together with the updated theory. *)
fun define_const
    (bind : binding, rhs : term)
    (thy : theory)
    : (term * thm) * theory =
  let
    val (const, thy1) =
      Sign.declare_const_global ((bind, Term.fastype_of rhs), NoSyn) thy
    val def_spec =
      Thm.no_attributes
        (Binding.suffix_name "_def" bind, Logic.mk_equals (const, rhs))
    val (def_thm, thy2) =
      yield_singleton (Global_Theory.add_defs false) def_spec thy1
  in
    ((const, def_thm), thy2)
  end

(* Stores thm, without attributes, under the binding obtained by
   qualifying dbind with name. *)
fun add_qualified_thm name (dbind, thm) =
  let
    val qualified = Binding.qualify_name true dbind name
  in
    yield_singleton Global_Theory.add_thms ((qualified, thm), [])
  end

(******************************************************************************)
(*************************** defining map functions ***************************)
(******************************************************************************)

(* Defines one map function per domain in spec and proves that each
   preserves deflations.

   spec: pairs of a binding and the isomorphism info for a domain.

   Steps: declare constants "<name>_map"; build their defining
   equations as abs oo body oo rep, where body comes from
   Domain_Take_Proofs.map_of_typ on the representation type; define
   them jointly with add_fixdefs; prove the conjunction of deflation
   goals by fixed-point induction; store the conjuncts as
   "deflation_<name>_map" and register them with
   Domain_Take_Proofs.add_deflation_thm.

   Returns a record (map_consts, map_apply_thms, map_unfold_thms,
   map_cont_thm, deflation_map_thms) paired with the updated theory.

   NOTE(review): hd dom_eqns and foldr1 presuppose a non-empty spec. *)
fun define_map_functions
    (spec : (binding * Domain_Take_Proofs.iso_info) list)
    (thy : theory) =
  let

    (* retrieve components of spec *)
    val dbinds = map fst spec
    val iso_infos = map snd spec
    val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos
    val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos

    (* type of a map function: one endofunction argument per cpo-sorted
       type argument of T, returning an endofunction on T *)
    fun mapT T =
      map (fn T => T ->> T) (filter (is_cpo thy) (argumentTs T)) -->> (T ->> T)

    (* declare map functions *)
    fun declare_map_const (tbind, (lhsT, _)) thy =
      let
        val map_type = mapT lhsT
        val map_bind = Binding.suffix_name "_map" tbind
      in
        Sign.declare_const_global ((map_bind, map_type), NoSyn) thy
      end
    val (map_consts, thy) = thy |>
      fold_map declare_map_const (dbinds ~~ dom_eqns)

    (* defining equations for map functions *)
    local
      fun unprime a = Library.unprefix "'" a
      (* free variable standing for the function mapped over type
         variable T, named after T without its leading prime *)
      fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T)
      fun map_lhs (map_const, lhsT) =
        (lhsT, list_ccomb (map_const, map mapvar (filter (is_cpo thy) (argumentTs lhsT))))
      val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns)
      val Ts = (argumentTs o fst o hd) dom_eqns
      (* lookup table for map_of_typ: type variables first, then the
         domains being defined *)
      val tab = (Ts ~~ map mapvar Ts) @ tab1
      fun mk_map_spec (((rep_const, abs_const), _), (lhsT, rhsT)) =
        let
          val lhs = Domain_Take_Proofs.map_of_typ thy tab lhsT
          val body = Domain_Take_Proofs.map_of_typ thy tab rhsT
          val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const))
        in mk_eqs (lhs, rhs) end
    in
      val map_specs =
          map mk_map_spec (rep_abs_consts ~~ map_consts ~~ dom_eqns)
    end

    (* register recursive definition of map functions *)
    val map_binds = map (Binding.suffix_name "_map") dbinds
    val ((map_apply_thms, map_unfold_thms, map_cont_thm), thy) =
      add_fixdefs (map_binds ~~ map_specs) thy

    (* prove deflation theorems for map functions *)
    val deflation_abs_rep_thms = map deflation_abs_rep iso_infos
    val deflation_map_thm =
      let
        fun unprime a = Library.unprefix "'" a
        fun mk_f T = Free (unprime (fst (dest_TFree T)), T ->> T)
        fun mk_assm T = mk_trp (mk_deflation (mk_f T))
        fun mk_goal (map_const, (lhsT, _)) =
          let
            val Ts = argumentTs lhsT
            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts))
          in mk_deflation map_term end
        (* one deflation assumption per cpo-sorted type argument of the
           first domain *)
        val assms = (map mk_assm o filter (is_cpo thy) o argumentTs o fst o hd) dom_eqns
        val goals = map mk_goal (map_consts ~~ dom_eqns)
        val goal = mk_trp (foldr1 HOLogic.mk_conj goals)
        val adm_rules =
          @{thms adm_conj adm_subst [OF _ adm_deflation]
                 cont2cont_fst cont2cont_snd cont_id}
        val bottom_rules =
          @{thms fst_strict snd_strict deflation_bottom simp_thms}
        val tuple_rules =
          @{thms split_def fst_conv snd_conv}
        val deflation_rules =
          @{thms conjI deflation_ID}
          @ deflation_abs_rep_thms
          @ Domain_Take_Proofs.get_deflation_thms thy
      in
        (* fixed-point induction: admissibility, then the bottom case,
           then the step case solved with the deflation rules and the
           assumptions *)
        Goal.prove_global thy [] assms goal (fn {prems, context = ctxt, ...} =>
         EVERY
          [rewrite_goals_tac ctxt map_apply_thms,
           resolve_tac ctxt [map_cont_thm RS @{thm cont_fix_ind}] 1,
           REPEAT (resolve_tac ctxt adm_rules 1),
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps bottom_rules) 1,
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps tuple_rules) 1,
           REPEAT (eresolve_tac ctxt @{thms conjE} 1),
           REPEAT (resolve_tac ctxt (deflation_rules @ prems) 1 ORELSE assume_tac ctxt 1)])
      end
    (* pair each name with the matching conjunct of thm *)
    fun conjuncts [] _ = []
      | conjuncts (n::[]) thm = [(n, thm)]
      | conjuncts (n::ns) thm = let
          val thmL = thm RS @{thm conjunct1}
          val thmR = thm RS @{thm conjunct2}
        in (n, thmL):: conjuncts ns thmR end
    val deflation_map_binds = dbinds |>
        map (Binding.prefix_name "deflation_" o Binding.suffix_name "_map")
    val (deflation_map_thms, thy) = thy |>
      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
        (conjuncts deflation_map_binds deflation_map_thm)

(*
    (* register indirect recursion in theory data *)
    local
      fun register_map (dname, args) =
        Domain_Take_Proofs.add_rec_type (dname, args)
      val dnames = map (fst o dest_Type o fst) dom_eqns
      fun args (T, _) = case T of Type (_, Ts) => map (is_cpo thy) Ts | _ => []
      val argss = map args dom_eqns
    in
      val thy =
          fold register_map (dnames ~~ argss) thy
    end
*)

    (* register deflation theorems *)
    val thy = fold Domain_Take_Proofs.add_deflation_thm deflation_map_thms thy

    val result =
      {
        map_consts = map_consts,
        map_apply_thms = map_apply_thms,
        map_unfold_thms = map_unfold_thms,
        map_cont_thm = map_cont_thm,
        deflation_map_thms = deflation_map_thms
      }
  in
    (result, thy)
  end

(******************************************************************************)
(******************************* main function ********************************)
(******************************************************************************)

(*
  Main worker behind the tycondef command.

  Arguments:
    param - name of the type parameter (e.g. "'a") to which the new type
            constructors are applied; it is given sort domain below
    doms  - one entry per (mutually recursive) type constructor:
            (type variables with sorts, type binding, mixfix,
             right-hand-side representation type, optional morphism names)
    thy   - the theory to extend

  Result: the list of iso_info records (rep/abs constants and their
  inverse theorems) together with the take-induction info returned by
  Domain_Take_Proofs.add_lub_take_theorems, and the updated theory.

  The steps follow the section comments below: check the specification
  against a temporary theory, declare and define deflation combinators,
  define the tycon instances, prove the DEFL equations, define rep/abs,
  prove iso/isodefl/map_ID rules, define take functions, and finally
  instantiate class prefunctor.
*)
fun domain_isomorphism
    (param : string)
    (doms : ((string * sort) list * binding * mixfix * typ *
             (binding * binding) option) list)
    (thy: theory)
    : (Domain_Take_Proofs.iso_info list
       * Domain_Take_Proofs.take_induct_info) * theory =
  let
    (* this theory is used just for parsing *)
    val tmp_thy = thy |>
      Sign.add_types_global (map (fn (tvs, tbind, mx, _, _) =>
        (tbind, length tvs, mx)) doms)

    (* declare arities in temporary theory *)
    val tmp_thy =
      let
        fun arity (vs, tbind, _, _, _) =
          (Sign.full_name thy tbind, map snd vs, @{sort "tycon"})
      in
        fold Axclass.arity_axiomatization (map arity doms) tmp_thy
      end

    (* check bifiniteness of right-hand sides *)
    fun check_rhs (_, _, _, rhs, _) =
      if Sign.of_sort tmp_thy (rhs, @{sort "domain"}) then ()
      else error ("Type not of sort domain: " ^
        quote (Syntax.string_of_typ_global tmp_thy rhs))
    (* evaluated only for the error it may raise; the unit list is dropped *)
    val _ = map check_rhs doms

    (* domain equations *)
    (* tc_eqns pair the bare tycon type with its rhs; dom_eqns pair the
       tycon applied to paramT (via mk_appT) with the same rhs *)
    val paramT = TFree (param, @{sort domain})
    fun mk_tc_eqn (vs, tbind, _, rhs, _) =
      (Type (Sign.full_name tmp_thy tbind, map TFree vs), rhs)
    val tc_eqns = map mk_tc_eqn doms
    fun mk_dom_eqn (vs, tbind, _, rhs, _) =
      (mk_appT paramT (Type (Sign.full_name tmp_thy tbind, map TFree vs)), rhs)
    val dom_eqns = map mk_dom_eqn doms

    (* check for valid type parameters *)
    (* validation only: the list built here is discarded; every domain must
       have duplicate-free parameters equal (as a set) to those of the first *)
    val (tyvars, _, _, _, _) = hd doms
    val _ = map (fn (tvs, tname, _, _, _) =>
      let val full_tname = Sign.full_name tmp_thy tname
      in
        (case duplicates (op =) (map fst tvs) of
          [] =>
            if eq_set (op =) (tyvars, tvs) then (full_tname, tvs)
            else error ("Mutually recursive domains must have same type parameters")
        | dups => error ("Duplicate parameter(s) for domain " ^ Binding.print tname ^
            " : " ^ commas dups))
      end) doms
    val dbinds = map (fn (_, dbind, _, _, _) => dbind) doms
    val morphs = map (fn (_, _, _, _, morphs) => morphs) doms

    (* determine deflation combinator arguments *)
    val tcs : typ list = map fst tc_eqns
    (* free variable "d<param>" standing for the deflation of the parameter *)
    val param_defl = Free ("d" ^ Library.unprefix "'" param, deflT)
    val lhsTs : typ list = map fst dom_eqns
    val defl_rec = Free ("t", mk_tupleT (map (K deflT) lhsTs))
    val defl_rews = mk_projs (map (fn T => mk_capply (mk_TC T, param_defl)) tcs) defl_rec

    fun defl_body (_, _, _, rhsT, _) = defl_of_typ tmp_thy defl_rews rhsT
    val functional = Term.lambda defl_rec (mk_tuple (map defl_body doms))

    (* for each domain, keep only those type variables and d/p deflation
       arguments that actually occur in the functional *)
    val tfrees = Term.add_tfrees functional []
    val frees = map fst (Term.add_frees functional [])
    fun get_defl_flags (vs, _, _, _, _) =
      let
        fun argT v = TFree v
        fun mk_d v = "d" ^ Library.unprefix "'" (fst v)
        fun mk_p v = "p" ^ Library.unprefix "'" (fst v)
        val args = maps (fn v => [(mk_d v, mk_DEFL (argT v)), (mk_p v, mk_LIFTDEFL (argT v))]) vs
        val typeTs = map argT (filter (member (op =) tfrees) vs)
        val defl_args = map snd (filter (member (op =) frees o fst) args)
      in
        (typeTs, defl_args)
      end
    val defl_flagss = map get_defl_flags doms

    (* declare deflation combinator constants *)
    fun declare_defl_const ((typeTs, defl_args), (_, tbind, _, _, _)) thy =
      let
        val defl_bind = Binding.suffix_name "_defl" tbind
        val defl_type =
          map Term.itselfT typeTs ---> map fastype_of defl_args -->> deflT ->> deflT
      in
        Sign.declare_const_global ((defl_bind, defl_type), NoSyn) thy
      end
    val (defl_consts, thy) =
      fold_map declare_defl_const (defl_flagss ~~ doms) thy
    (* the same declarations are repeated in tmp_thy; only the updated
       temporary theory is kept, the constants returned here are dropped *)
    val (_, tmp_thy) =
      fold_map declare_defl_const (defl_flagss ~~ doms) tmp_thy

    (* defining equations for type combinators *)
    fun mk_defl_term (defl_const, (typeTs, defl_args)) =
      let
        val type_args = map Logic.mk_type typeTs
      in
        list_ccomb (list_comb (defl_const, type_args), defl_args @ [param_defl])
      end
    val defl_terms = map mk_defl_term (defl_consts ~~ defl_flagss)
    (* rebinds defl_rews: same left-hand sides, now mapped to the
       combinator applications instead of tuple projections *)
    val defl_rews = map fst defl_rews ~~ defl_terms
    fun mk_defl_spec (lhsT, rhsT) =
      mk_eqs (defl_of_typ tmp_thy defl_rews lhsT,
              defl_of_typ tmp_thy defl_rews rhsT)
    val defl_specs = map mk_defl_spec dom_eqns

    (* register recursive definition of deflation combinators *)
    val defl_binds = map (Binding.suffix_name "_defl") dbinds
    val ((defl_apply_thms, defl_unfold_thms, defl_cont_thm), thy) =
      add_fixdefs (defl_binds ~~ defl_specs) thy

    (* define tycons using deflation combinators *)
    (* for one domain: define the singleton type, instantiate class tycon
       with tc defined as a constant function returning the given term tc,
       and store the derived TC_<name> theorem as a domain_defl_simps rule *)
    fun make_tycondef ((vs, tbind, mx, _, _), tc) thy =
      let
        val spec = (tbind, vs, mx)
        val ((full_tname, info), thy) = define_singleton_type spec thy
        val newT = #abs_type (#1 info)
        val lhs_tfrees = map dest_TFree (snd (dest_Type newT))
        val tc_eqn = Logic.mk_equals (tc_const newT, Abs ("", Term.itselfT newT, tc))
        val ((_, (_, tc_ldef)), lthy) = thy
          |> Class.instantiation ([full_tname], lhs_tfrees, @{sort tycon})
          |> Specification.definition NONE [] []
              ((Binding.prefix_name "tc_" (Binding.suffix_name "_def" tbind), []), tc_eqn)
        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy)
        val tc_def = singleton (Proof_Context.export lthy ctxt_thy) tc_ldef
        val thy = lthy
          |> Class.prove_instantiation_exit (fn ctxt => Class.intro_classes_tac ctxt [])
        (* declare domain_defl_simps *)
        val attr = @{attributes [domain_defl_simps]}
          |> map (Attrib.attribute_global thy)
        val tc_thm = Drule.zero_var_indexes (tc_def RS TC_simp)
        val tc_thm_bind = Binding.prefix_name "TC_" tbind
        val (_, thy) = thy |> Global_Theory.add_thm ((tc_thm_bind, tc_thm), attr)
      in
        (tc_def, thy)
      end
    (* NOTE: this shadows the earlier mk_defl_term/defl_terms; the only
       difference is that param_defl is no longer appended, so these terms
       are the combinators before application to the parameter deflation *)
    fun mk_defl_term (defl_const, (typeTs, defl_args)) =
      let
        val type_args = map Logic.mk_type typeTs
      in
        list_ccomb (list_comb (defl_const, type_args), defl_args)
      end
    val defl_terms = map mk_defl_term (defl_consts ~~ defl_flagss)
    val (tc_defs, thy) = fold_map make_tycondef (doms ~~ defl_terms) thy

    (* prove DEFL equations *)
    fun mk_DEFL_eq_thm (lhsT, rhsT) =
      let
        val goal = mk_eqs (mk_DEFL lhsT, mk_DEFL rhsT)
        val DEFL_simps = Global_Theory.get_thms thy "domain_defl_simps"
        fun tac ctxt =
          rewrite_goals_tac ctxt (map mk_meta_eq DEFL_simps @ tc_defs)
          THEN TRY (resolve_tac ctxt defl_unfold_thms 1)
      in
        Goal.prove_global thy [] [] goal (tac o #context)
      end
    val DEFL_eq_thms = map mk_DEFL_eq_thm dom_eqns

    (* register DEFL equations *)
    val DEFL_eq_binds = map (Binding.prefix_name "DEFL_eq_") dbinds
    val (_, thy) = thy |>
      (Global_Theory.add_thms o map Thm.no_attributes)
        (DEFL_eq_binds ~~ DEFL_eq_thms)

    (* define rep/abs functions *)
    (* both are defined as coercions between lhsT and rhsT; the morphism
       names paired with tbind are ignored here *)
    fun mk_rep_abs ((tbind, _), (lhsT, rhsT)) thy =
      let
        val rep_bind = Binding.suffix_name "_rep" tbind
        val abs_bind = Binding.suffix_name "_abs" tbind
        val ((rep_const, rep_def), thy) =
            define_const (rep_bind, coerce_const (lhsT, rhsT)) thy
        val ((abs_const, abs_def), thy) =
            define_const (abs_bind, coerce_const (rhsT, lhsT)) thy
      in
        (((rep_const, abs_const), (rep_def, abs_def)), thy)
      end
    val ((rep_abs_consts, rep_abs_defs), thy) = thy
      |> fold_map mk_rep_abs (dbinds ~~ morphs ~~ dom_eqns)
      |>> ListPair.unzip

    (* prove isomorphism and isodefl rules *)
    fun mk_iso_thms ((tbind, DEFL_eq), (rep_def, abs_def)) thy =
      let
        fun make thm =
            Drule.zero_var_indexes (thm OF [DEFL_eq, abs_def, rep_def])
        val rep_iso_thm = make @{thm domain_rep_iso}
        val abs_iso_thm = make @{thm domain_abs_iso}
        val isodefl_thm = make @{thm isodefl_abs_rep}
        val thy = thy
          |> snd o add_qualified_thm "rep_iso" (tbind, rep_iso_thm)
          |> snd o add_qualified_thm "abs_iso" (tbind, abs_iso_thm)
          |> snd o add_qualified_thm "isodefl_abs_rep" (tbind, isodefl_thm)
      in
        (((rep_iso_thm, abs_iso_thm), isodefl_thm), thy)
      end
    val ((iso_thms, isodefl_abs_rep_thms), thy) =
      thy
      |> fold_map mk_iso_thms (dbinds ~~ DEFL_eq_thms ~~ rep_abs_defs)
      |>> ListPair.unzip

    (* collect info about rep/abs *)
    val iso_infos : Domain_Take_Proofs.iso_info list =
      let
        fun mk_info (((lhsT, rhsT), (repC, absC)), (rep_iso, abs_iso)) =
          {
            repT = rhsT,
            absT = lhsT,
            rep_const = repC,
            abs_const = absC,
            rep_inverse = rep_iso,
            abs_inverse = abs_iso
          }
      in
        map mk_info (dom_eqns ~~ rep_abs_consts ~~ iso_thms)
      end

    (* definitions and proofs related to map functions *)
    val (map_info, thy) =
        define_map_functions (dbinds ~~ iso_infos) thy
    val { map_consts, map_apply_thms, map_cont_thm, ...} = map_info

    (* prove isodefl rules for map functions *)
    (* a single conjunction over all domains is proved by parallel fixed
       point induction and split into per-domain theorems afterwards *)
    val isodefl_thm =
      let
        fun unprime a = Library.unprefix "'" a
        fun mk_d T = Free ("d" ^ unprime (fst (dest_TFree T)), deflT)
        fun mk_p T = Free ("p" ^ unprime (fst (dest_TFree T)), udeflT)
        fun mk_f T = Free ("f" ^ unprime (fst (dest_TFree T)), T ->> T)
        fun mk_assm t =
          case try dest_LIFTDEFL t of
            SOME T => mk_trp (isodefl'_const T $ mk_f T $ mk_p T)
          | NONE =>
            let val T = dest_DEFL t
            in mk_trp (isodefl_const T $ mk_f T $ mk_d T) end
        fun mk_goal (map_const, ((T, _), defl_term)) =
          let
            val Ts = argumentTs T
            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts))
            val rews1 = map mk_DEFL Ts ~~ map mk_d Ts
            val rews2 = map mk_LIFTDEFL Ts ~~ map mk_p Ts
            val rews3 = [((mk_TC o snd o dest_appT) T, defl_term)]
            val rews = rews1 @ rews2 @ rews3
            val defl_term = defl_of_typ thy rews T
          in isodefl_const T $ map_term $ defl_term end
        (* assumptions come from the first domain's deflation arguments,
           plus one isodefl assumption for the parameter type *)
        val assms = (map mk_assm o snd o hd) defl_flagss
          @ [mk_trp (isodefl_const paramT $ mk_f paramT $ mk_d paramT)]
        val goals = map mk_goal (map_consts ~~ (dom_eqns ~~ defl_terms))
        val goal = mk_trp (foldr1 HOLogic.mk_conj goals)
        val adm_rules =
          @{thms adm_conj adm_isodefl cont2cont_fst cont2cont_snd cont_id}
        val bottom_rules =
          @{thms fst_strict snd_strict isodefl_bottom simp_thms}
        val tuple_rules =
          @{thms split_def fst_conv snd_conv}
        val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy
        val map_ID_simps = map (fn th => th RS sym) map_ID_thms
        val isodefl_rules =
          @{thms conjI isodefl_ID_DEFL isodefl_LIFTDEFL}
          @ isodefl_abs_rep_thms
          @ Global_Theory.get_thms thy "domain_isodefl"
      in
        Goal.prove_global thy [] assms goal (fn {prems, context = ctxt, ...} =>
         EVERY
          [rewrite_goals_tac ctxt (defl_apply_thms @ map_apply_thms),
           resolve_tac ctxt [@{thm cont_parallel_fix_ind} OF [defl_cont_thm, map_cont_thm]] 1,
           REPEAT (resolve_tac ctxt adm_rules 1),
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps bottom_rules) 1,
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps tuple_rules) 1,
           simp_tac (put_simpset HOL_basic_ss ctxt addsimps map_ID_simps) 1,
           REPEAT (eresolve_tac ctxt @{thms conjE} 1),
           REPEAT (resolve_tac ctxt (isodefl_rules @ prems) 1 ORELSE assume_tac ctxt 1)])
      end
    val isodefl_binds = map (Binding.prefix_name "isodefl_") dbinds
    val isodefl_attr = @{attributes [domain_isodefl]}
      |> map (Attrib.attribute_global thy)
    (* split a right-nested conjunction into one named theorem per binding *)
    fun conjuncts [] _ = []
      | conjuncts (n::[]) thm = [(n, thm)]
      | conjuncts (n::ns) thm = let
          val thmL = thm RS @{thm conjunct1}
          val thmR = thm RS @{thm conjunct2}
        in (n, thmL):: conjuncts ns thmR end
    val (isodefl_thms, thy) = thy |>
      (Global_Theory.add_thms o map (rpair isodefl_attr o apsnd Drule.zero_var_indexes))
        (conjuncts isodefl_binds isodefl_thm)

    (* prove map_ID theorems *)
    fun prove_map_ID_thm
        (((map_const, (lhsT, _)), DEFL_thm), isodefl_thm) =
      let
        val Ts = argumentTs lhsT
        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo})
        val lhs = list_ccomb (map_const, map mk_ID (filter is_cpo Ts))
        val goal = mk_eqs (lhs, mk_ID lhsT)
        fun tac ctxt = EVERY
          [resolve_tac ctxt @{thms isodefl_DEFL_imp_ID} 1,
           stac ctxt @{thm DEFL_app} 1,
           stac ctxt DEFL_thm 1,
           resolve_tac ctxt [isodefl_thm] 1,
           REPEAT (resolve_tac ctxt @{thms isodefl_ID_DEFL isodefl_LIFTDEFL} 1)]
      in
        Goal.prove_global thy [] [] goal (tac o #context)
      end
    val map_ID_binds = map (Binding.suffix_name "_map_ID") dbinds
    (* NOTE(review): tc_defs is passed in the position that
       prove_map_ID_thm names DEFL_thm (DEFL_eq_thms is not used here) --
       confirm this is intended *)
    val map_ID_thms =
      map prove_map_ID_thm
        (map_consts ~~ dom_eqns ~~ tc_defs ~~ isodefl_thms)
    val (_, thy) = thy |>
      (Global_Theory.add_thms o map (rpair [Domain_Take_Proofs.map_ID_add]))
        (map_ID_binds ~~ map_ID_thms)

    (* definitions and proofs related to take functions *)
    val (take_info, thy) =
        Domain_Take_Proofs.define_take_functions
          (dbinds ~~ iso_infos) thy
    val { take_consts, chain_take_thms, take_0_thms, take_Suc_thms, ...} =
        take_info

    (* least-upper-bound lemma for take functions *)
    (* states that the tuple of lubs of the take functions equals the
       tuple of map functions applied to ID *)
    val lub_take_lemma =
      let
        val lhs = mk_tuple (map mk_lub take_consts)
        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo})
        fun mk_map_ID (map_const, (lhsT, _)) =
          list_ccomb (map_const, map mk_ID (filter is_cpo (argumentTs lhsT)))
        val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns))
        val goal = mk_trp (mk_eq (lhs, rhs))
        val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy
        val start_rules =
            @{thms lub_Pair [symmetric] ch2ch_Pair} @ chain_take_thms
            @ @{thms prod.collapse split_def}
            @ map_apply_thms @ map_ID_thms
        val rules0 =
            @{thms iterate_0 Pair_strict} @ take_0_thms
        val rules1 =
            @{thms iterate_Suc prod_eq_iff fst_conv snd_conv}
            @ take_Suc_thms
        fun tac ctxt =
            EVERY
            [simp_tac (put_simpset HOL_basic_ss ctxt addsimps start_rules) 1,
             simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms fix_def2}) 1,
             resolve_tac ctxt @{thms lub_eq} 1,
             resolve_tac ctxt @{thms nat.induct} 1,
             simp_tac (put_simpset HOL_basic_ss ctxt addsimps rules0) 1,
             asm_full_simp_tac (put_simpset beta_ss ctxt addsimps rules1) 1]
      in
        Goal.prove_global thy [] [] goal (tac o #context)
      end

    (* prove lub of take equals ID *)
    fun prove_lub_take (((dbind, take_const), map_ID_thm), (lhsT, _)) thy =
      let
        val n = Free ("n", natT)
        val goal = mk_eqs (mk_lub (lambda n (take_const $ n)), mk_ID lhsT)
        fun tac ctxt =
            EVERY
            [resolve_tac ctxt @{thms trans} 1,
             resolve_tac ctxt [map_ID_thm] 2,
             cut_facts_tac [lub_take_lemma] 1,
             REPEAT (eresolve_tac ctxt @{thms Pair_inject} 1), assume_tac ctxt 1]
        val lub_take_thm = Goal.prove_global thy [] [] goal (tac o #context)
      in
        add_qualified_thm "lub_take" (dbind, lub_take_thm) thy
      end
    val (lub_take_thms, thy) =
        fold_map prove_lub_take
          (dbinds ~~ take_consts ~~ map_ID_thms ~~ dom_eqns) thy

    (* prove additional take theorems *)
    val (take_info2, thy) =
        Domain_Take_Proofs.add_lub_take_theorems
          (dbinds ~~ iso_infos) take_info lub_take_thms thy

    (* the constant fmapU at the type instance for tycon T *)
    fun fmapU_const T =
      let val U = mk_appT udomT T
      in Const (@{const_name fmapU}, (udomT ->> udomT) ->> (U ->> U)) end

    (* instantiate prefunctor class *)
    (* fmapU is defined as the map constant (at the udom instance) with ID
       supplied for every cpo type argument of the tycon *)
    fun inst_prefunctor (map_const, ((lhsT, _), tbind)) thy =
      let
        val (_, tyconT) = dest_appT lhsT
        val (full_tname, Ts) = dest_Type tyconT
        val lhs_tfrees = map dest_TFree Ts
        val argTs = filter (is_cpo thy) Ts
        val U = mk_appT udomT tyconT
        val mapT = map (fn T => T ->> T) argTs -->> (udomT ->> udomT) ->> (U ->> U)
        val mapC = Const (fst (dest_Const map_const), mapT)
        val fmap_rhs = list_ccomb (mapC, map mk_ID argTs)
        val fmap_eqn = Logic.mk_equals (fmapU_const tyconT, fmap_rhs)
        val fmap_def_bind = tbind
          |> Binding.suffix_name "_def"
          |> Binding.prefix_name "fmapU_"
        val ((_, (_, fmap_ldef)), lthy) = thy
          |> Class.instantiation ([full_tname], lhs_tfrees, @{sort prefunctor})
          |> Specification.definition NONE [] [] ((fmap_def_bind, []), fmap_eqn)
        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy)
        val fmap_def = singleton (Proof_Context.export lthy ctxt_thy) fmap_ldef
        fun tacf ctxt = EVERY
          [Class.intro_classes_tac ctxt [],
           rewrite_goals_tac ctxt (fmap_def :: tc_defs),
           resolve_tac ctxt isodefl_thms 1,
           REPEAT (resolve_tac ctxt @{thms isodefl_ID_DEFL isodefl_LIFTDEFL isodefl_cast} 1)]
        val thy = lthy
          |> Class.prove_instantiation_exit tacf
      in
        (fmap_def, thy)
      end
    (* fmap_defs is not referenced again in this function; only the
       updated theory matters *)
    val (fmap_defs, thy) = fold_map inst_prefunctor
      (map_consts ~~ (dom_eqns ~~ dbinds)) thy
  in
    ((iso_infos, take_info2), thy)
  end

(******************************************************************************)
(****************************** top-level command *****************************)
(******************************************************************************)

(* ----- calls for building new thy and thms -------------------------------- *)

(* Pair of results in the same shape as the first component returned by
   domain_isomorphism: per-domain iso infos plus take-induction info. *)
type info =
     Domain_Take_Proofs.iso_info list * Domain_Take_Proofs.take_induct_info

(* Declare a type constructor b with one argument per entry of sorts, then
   axiomatize the arity (sorts) sort for it.  The full name is computed
   relative to the incoming theory. *)
fun add_arity ((b, sorts, mx), sort) thy : theory =
  let
    val with_type = Sign.add_types_global [(b, length sorts, mx)] thy
    val full_name = Sign.full_name thy b
  in
    Axclass.arity_axiomatization (full_name, sorts, sort) with_type
  end

(*
  Generic implementation of the tycondef command.

  Arguments:
    prep_sort - turns a raw sort annotation into a sort
    prep_typ  - turns a raw constructor argument type into a typ, given
                the sorts of the declared type variables
    arg_sort  - sort required of a constructor argument; the boolean is
                true for a lazy argument without a selector
    raw_specs - one entry per type constructor: (parameter name,
                constrained type variables, type binding, mixfix,
                constructors); each constructor is (binding,
                [(lazy, optional selector, raw type)], mixfix)
    thy       - theory to extend

  Checks the specification (duplicate constructors, selectors and type
  arguments; identical parameter and type variables across all domains;
  argument sorts), builds the representation types, and then calls
  domain_isomorphism, Domain_Constructors.add_domain_constructors and
  Domain_Induction.comp_theorems.  Returns the updated theory.
*)
fun gen_add_tycon
    (prep_sort : theory -> 'a -> sort)
    (prep_typ : theory -> (string * sort) list -> 'b -> typ)
    (arg_sort : bool -> sort)
    (raw_specs : (string * (string * 'a) list * binding * mixfix *
               (binding * (bool * binding option * 'b) list * mixfix) list) list)
    (thy : theory) =
  let
    (* type variables as (name, sort) pairs; the parameter name a is not
       used in this projection *)
    val dtnvs0 : (binding * (string * sort) list * mixfix) list =
        map (fn (a, vs, dbind, mx, _) =>
            (dbind, map (apsnd (prep_sort thy)) vs, mx)) raw_specs

    (* the same information with type variables as TFree terms *)
    val dtnvs : (binding * typ list * mixfix) list =
      let
        fun prep_tvar (a, s) = TFree (a, prep_sort thy s)
      in
        map (fn (a, vs, dbind, mx, _) =>
                (dbind, map prep_tvar vs, mx)) raw_specs
      end

    fun thy_arity (dbind, tvars, mx) =
      ((dbind, map (snd o dest_TFree) tvars, mx), @{sort tycon})

    (* this theory is used just for parsing and error checking *)
    val tmp_thy = thy
      |> fold (add_arity o thy_arity) dtnvs

    val dbinds : binding list =
        map (fn (_,_,dbind,_,_) => dbind) raw_specs
    val raw_rhss :
        (binding * (bool * binding option * 'b) list * mixfix) list list =
        map (fn (_,_,_,_,cons) => cons) raw_specs
    val dtnvs' : (string * typ list) list =
        map (fn (dbind, vs, _) => (Sign.full_name thy dbind, vs)) dtnvs

    (* the following three checks are evaluated only for the errors they
       may raise; their boolean results are ignored *)
    val all_cons = map (Binding.name_of o first) (flat raw_rhss)
    val _ =
      case duplicates (op =) all_cons of 
        [] => false | dups => error ("Duplicate constructors: " 
                                      ^ commas_quote dups)
    val all_sels =
      (map Binding.name_of o map_filter second o maps second) (flat raw_rhss)
    val _ =
      case duplicates (op =) all_sels of
        [] => false | dups => error("Duplicate selectors: "^commas_quote dups)

    fun test_dupl_tvars s =
      case duplicates (op =) (map(fst o dest_TFree)s) of
        [] => false | dups => error("Duplicate type arguments: " 
                                    ^commas_quote dups)
    val _ = exists test_dupl_tvars (map snd dtnvs')

    (* all specs must name the same parameter; it is given sort domain *)
    val param : string * sort =
      let val all_params = map #1 raw_specs
      in
        case distinct (op =) all_params of
          [param] => (param, @{sort domain})
        | _ => error "Mutually recursive domains must have same type parameter"
      end

    (* all specs must declare the same set of type variables *)
    val sorts : (string * sort) list =
      let val all_sorts = map (map dest_TFree o snd) dtnvs'
      in
        case distinct (eq_set (op =)) all_sorts of
          [sorts] => sorts
        | _ => error "Mutually recursive domains must have same type parameters"
      end

    (* NOTE(review): sorts' is referenced only by check_rec below *)
    val sorts' : (string * sort) list = param :: sorts

    (* a lazy argument may have an unpointed type *)
    (* unless the argument has a selector function *)
    fun check_pcpo (lazy, sel, T) =
      let val sort = arg_sort (lazy andalso is_none sel) in
        if Sign.of_sort tmp_thy (T, sort) then ()
        else error ("Constructor argument type is not of sort " ^
                    Syntax.string_of_sort_global tmp_thy sort ^ ": " ^
                    Syntax.string_of_typ_global tmp_thy T)
      end

    (* test for free type variables, illegal sort constraints on rhs,
       non-pcpo-types and invalid use of recursive type
       replace sorts in type variables on rhs *)
    (* NOTE(review): the only call of check_rec (in prep_arg) is commented
       out, so check_rec and rec_tab are currently unused and none of the
       recursion checks below are performed *)
    val rec_tab = Domain_Take_Proofs.get_rec_tab thy
    fun check_rec _ (T as TFree (v,_))  =
        if AList.defined (op =) sorts' v then T
        else error ("Free type variable " ^ quote v ^ " on rhs.")
      | check_rec no_rec (T as Type (s, Ts)) =
        (case AList.lookup (op =) dtnvs' s of
          NONE =>
            let val no_rec' =
                  if no_rec = NONE then
                    if Symtab.defined rec_tab s then NONE else SOME s
                  else no_rec
            in Type (s, map (check_rec no_rec') Ts) end
        | SOME typevars =>
          if typevars <> Ts
          then error ("Recursion of type " ^ 
                      quote (Syntax.string_of_typ_global tmp_thy T) ^ 
                      " with different arguments")
          else (case no_rec of
                  NONE => T
                | SOME c =>
                  error ("Illegal indirect recursion of type " ^ 
                         quote (Syntax.string_of_typ_global tmp_thy T) ^
                         " under type constructor " ^ quote c)))
      | check_rec _ (TVar _) = error "extender:check_rec"

    (* prepare one constructor argument: read its type against tmp_thy
       and check its sort *)
    fun prep_arg (lazy, sel, raw_T) =
      let
        val T = prep_typ tmp_thy sorts raw_T
(*
        val _ = check_rec NONE T
*)
        val _ = check_pcpo (lazy, sel, T)
      in (lazy, sel, T) end
    fun prep_con (b, args, mx) = (b, map prep_arg args, mx)
    fun prep_rhs cons = map prep_con cons
    val rhss : (binding * (bool * binding option * typ) list * mixfix) list list =
        map prep_rhs raw_rhss

    (* representation type: lazy arguments are wrapped with mk_upT,
       arguments of a constructor are combined with mk_sprodT (oneT if
       there are none), constructors are combined with mk_ssumT *)
    fun mk_arg_typ (lazy, _, T) = if lazy then mk_upT T else T
    fun mk_con_typ (_, args, _) =
        if null args then oneT else foldr1 mk_sprodT (map mk_arg_typ args)
    fun mk_rhs_typ cons = foldr1 mk_ssumT (map mk_con_typ cons)

    val repTs : typ list = map mk_rhs_typ rhss

    (* no morphism names are supplied (NONE) *)
    val doms : ((string * sort) list * binding * mixfix * typ * (binding * binding) option) list =
      map (fn ((tbind, vs, mx), repT) => (vs, tbind, mx, repT, NONE)) (dtnvs0 ~~ repTs)
    val ((iso_infos, take_info), thy) = domain_isomorphism (fst param) doms thy

    val (constr_infos, thy) =
        thy
          |> fold_map (fn ((dbind, cons), info) =>
                Domain_Constructors.add_domain_constructors dbind cons info)
             (dbinds ~~ rhss ~~ iso_infos)

    val (_, thy) =
        Domain_Induction.comp_theorems
          dbinds take_info constr_infos thy
  in
    thy
  end

(* Sort required of a constructor argument type: predomain when the flag
   is true, domain otherwise. *)
fun rep_arg true = @{sort predomain}
  | rep_arg false = @{sort "domain"}

(* Read an optional sort annotation; without one, use the theory's
   default sort. *)
fun read_sort thy opt_sort =
  (case opt_sort of
    NONE => Sign.defaultS thy
  | SOME s => Syntax.read_sort_global thy s)

(* Adapted from src/HOL/Tools/Datatype/datatype_data.ML *)
(* Read a type from its string form in a global context of thy in which
   the given (name, sort) pairs have been declared as type variables. *)
fun read_typ thy sorts str =
  let
    val global_ctxt = Proof_Context.init_global thy
    val declare_tfree = Variable.declare_typ o TFree
    val ctxt = fold declare_tfree sorts global_ctxt
  in
    Syntax.read_typ ctxt str
  end

(* Certify a type that is already in internal form.  A TYPE exception is
   re-raised as a plain error with the same message.  The type's free
   type variables are added to sorts; if a variable name then occurs more
   than once, an error about inconsistent sort constraints is raised. *)
fun cert_typ sign sorts raw_T =
  let
    val T =
      (Type.no_tvars (Sign.certify_typ sign raw_T))
        handle TYPE (msg, _, _) => error msg
    val all_tfrees = Term.add_tfreesT T sorts
    val clashes = duplicates (op =) (map fst all_tfrees)
    val _ =
      if null clashes then ()
      else error ("Inconsistent sort constraints for " ^ commas clashes)
  in
    T
  end

(* Variant for already-internal input: sorts are taken as given (K I) and
   types are certified with cert_typ. *)
val add_tycon =
    gen_add_tycon (K I) cert_typ rep_arg

(* Variant for outer-syntax input: sorts and types are read from strings
   with read_sort and read_typ. *)
val add_tycon_cmd =
    gen_add_tycon read_sort read_typ rep_arg


(** outer syntax **)

(* Parser for one constructor argument, yielding
   (lazy flag, optional selector binding, type string).
   Alternatives, tried in order:
     1. "(" ["lazy"] selector "::" type ")"
     2. "(" "lazy" type ")"   -- lazy argument without selector
     3. type                  -- strict argument without selector *)
val dest_decl : (bool * binding option * string) parser =
  @{keyword "("} |-- Scan.optional (@{keyword "lazy"} >> K true) false --
    (Parse.binding >> SOME) -- (@{keyword "::"} |-- Parse.typ)  --| @{keyword ")"} >> Scan.triple1
    || @{keyword "("} |-- @{keyword "lazy"} |-- Parse.typ --| @{keyword ")"}
    >> (fn t => (true,NONE,t))
    || Parse.typ >> (fn t => (false,NONE,t))

(* Parser for one constructor: its binding, any number of arguments
   (dest_decl), and an optional mixfix annotation. *)
val cons_decl =
  Parse.binding -- Scan.repeat dest_decl -- Parse.opt_mixfix

(* Parser for one type constructor specification:
     parameter "\<cdot>" constrained-type-args binding [mixfix]
       "=" constructor "|" ... "|" constructor
   The keyword between the parameter and the type arguments is the
   type-application symbol; an empty keyword string matches nothing. *)
val tycon_decl =
  (Parse.type_ident --| @{keyword "\<cdot>"} -- Parse.type_args_constrained --
    Parse.binding -- Parse.opt_mixfix) --
    (@{keyword "="} |-- Parse.enum1 "|" cons_decl)

(* One or more tycon_decl specifications separated by "and". *)
val tycons_decl =
  Parse.and_list1 tycon_decl

(* Flatten the nested pairs produced by tycons_decl into the tuples
   expected by add_tycon_cmd and hand them over.  The result is the
   theory transformation returned by add_tycon_cmd. *)
fun mk_tycon
    (doms : ((((string * (string * string option) list) * binding) * mixfix) *
             ((binding * (bool * binding option * string) list) * mixfix) list) list ) =
  let
    (* ((c, ds), mx) --> (c, ds, mx) *)
    fun flatten_con ((c, ds), mx) = (c, ds, mx)
    (* ((((a, vs), t), mx), cons) --> (a, vs, t, mx, cons') *)
    fun flatten_spec ((((a, vs), t), mx), cons) =
      (a, vs, t, mx, map flatten_con cons)
    val specs : (string * (string * string option) list * binding * mixfix *
                 (binding * (bool * binding option * string) list * mixfix) list) list =
        map flatten_spec doms
  in
    add_tycon_cmd specs
  end

(* Register the outer-syntax command tycondef: parse with tycons_decl and
   run mk_tycon as a theory transformation. *)
val _ =
  Outer_Syntax.command @{command_keyword tycondef}
    "define recursive type constructors (HOLCF)"
    (tycons_decl >> (Toplevel.theory o mk_tycon))

end

Theory Monad

section ‹Monad Class›

theory Monad
imports Functor
begin

subsection ‹Class definition›

text ‹In Haskell, class \emph{Monad} is defined as follows:›

text_raw ‹
\begin{verbatim}
class Monad m where
  return :: a -> m a
  (>>=) :: m a -> (a -> m b) -> m b
\end{verbatim}
›

text ‹We formalize class ‹monad› in a manner similar to the
‹functor› class: We fix monomorphic versions of the class
constants, replacing type variables with ‹udom›, and assume
monomorphic versions of the class axioms.›

text ‹Because the monad laws imply the composition rule for ‹fmap›, we declare ‹prefunctor› as the superclass, and separately
prove a subclass relationship with ‹functor›.›

class monad = prefunctor +
  fixes returnU :: "udom → udom⋅'a::tycon"
  fixes bindU :: "udom⋅'a → (udom → udom⋅'a) → udom⋅'a"
  assumes fmapU_eq_bindU:
    "⋀f xs. fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
  assumes bindU_returnU:
    "⋀f x. bindU⋅(returnU⋅x)⋅f = f⋅x"
  assumes bindU_bindU:
    "⋀xs f g. bindU⋅(bindU⋅xs⋅f)⋅g = bindU⋅xs⋅(Λ x. bindU⋅(f⋅x)⋅g)"

instance monad ⊆ "functor"
proof
  fix f g :: "udom → udom" and xs :: "udom⋅'a"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    by (simp add: fmapU_eq_bindU bindU_bindU bindU_returnU)
qed

text ‹As with ‹fmap›, we define the polymorphic ‹return›
and ‹bind› by coercion from the monomorphic ‹returnU› and
‹bindU›.›

definition return :: "'a → 'a⋅'m::monad"
  where "return = coerce⋅(returnU :: udom → udom⋅'m)"

definition bind :: "'a⋅'m::monad → ('a → 'b⋅'m) → 'b⋅'m"
  where "bind = coerce⋅(bindU :: udom⋅'m → _)"

abbreviation bind_syn :: "'a⋅'m::monad → ('a → 'b⋅'m) → 'b⋅'m" (infixl "⤜" 55)
  where "m ⤜ f ≡ bind⋅m⋅f"

subsection ‹Naturality of bind and return›

text ‹The three class axioms imply naturality properties of ‹returnU› and ‹bindU›, i.e., that both commute with ‹fmapU›.›

lemma fmapU_returnU [coerce_simp]:
  "fmapU⋅f⋅(returnU⋅x) = returnU⋅(f⋅x)"
by (simp add: fmapU_eq_bindU bindU_returnU)

lemma fmapU_bindU [coerce_simp]:
  "fmapU⋅f⋅(bindU⋅m⋅k) = bindU⋅m⋅(Λ x. fmapU⋅f⋅(k⋅x))"
by (simp add: fmapU_eq_bindU bindU_bindU)

lemma bindU_fmapU:
  "bindU⋅(fmapU⋅f⋅xs)⋅k = bindU⋅xs⋅(Λ x. k⋅(f⋅x))"
by (simp add: fmapU_eq_bindU bindU_returnU bindU_bindU)

subsection ‹Polymorphic versions of class assumptions›

lemma monad_fmap:
  fixes xs :: "'a⋅'m::monad" and f :: "'a → 'b"
  shows "fmap⋅f⋅xs = xs ⤜ (Λ x. return⋅(f⋅x))"
unfolding bind_def return_def fmap_def
by (simp add: coerce_simp fmapU_eq_bindU bindU_returnU)

lemma monad_left_unit [simp]: "(return⋅x ⤜ f) = (f⋅x)"
unfolding bind_def return_def
by (simp add: coerce_simp bindU_returnU)

lemma bind_bind:
  fixes m :: "'a⋅'m::monad"
  shows "((m ⤜ f) ⤜ g) = (m ⤜ (Λ x. f⋅x ⤜ g))"
unfolding bind_def
by (simp add: coerce_simp bindU_bindU)

subsection ‹Derived rules›

text ‹The following properties can be derived using only the
abstract monad laws.›

lemma monad_right_unit [simp]: "(m ⤜ return) = m"
 apply (subgoal_tac "fmap⋅ID⋅m = m")
  apply (simp only: monad_fmap)
  apply (simp add: eta_cfun)
 apply simp
done

lemma fmap_return: "fmap⋅f⋅(return⋅x) = return⋅(f⋅x)"
by (simp add: monad_fmap)

lemma fmap_bind: "fmap⋅f⋅(bind⋅xs⋅k) = bind⋅xs⋅(Λ x. fmap⋅f⋅(k⋅x))"
by (simp add: monad_fmap bind_bind)

lemma bind_fmap: "bind⋅(fmap⋅f⋅xs)⋅k = bind⋅xs⋅(Λ x. k⋅(f⋅x))"
by (simp add: monad_fmap bind_bind)

text ‹Bind is strict in its first argument, if its second argument
is a strict function.›

lemma bind_strict:
  assumes "k⋅⊥ = ⊥" shows "⊥ ⤜ k = ⊥"
proof -
  have "⊥ ⤜ k ⊑ return⋅⊥ ⤜ k"
    by (intro monofun_cfun below_refl minimal)
  thus "⊥ ⤜ k = ⊥"
    by (simp add: assms)
qed

lemma congruent_bind:
  "(∀m. m ⤜ k1 = m ⤜ k2) = (k1 = k2)"
 apply (safe, rule cfun_eqI)
 apply (drule_tac x="return⋅x" in spec, simp)
done

subsection ‹Laws for join›

definition join :: "('a⋅'m)⋅'m → 'a⋅'m::monad"
  where "join ≡ Λ m. m ⤜ (Λ x. x)"

lemma join_fmap_fmap: "join⋅(fmap⋅(fmap⋅f)⋅xss) = fmap⋅f⋅(join⋅xss)"
by (simp add: join_def monad_fmap bind_bind)

lemma join_return: "join⋅(return⋅xs) = xs"
by (simp add: join_def)

lemma join_fmap_return: "join⋅(fmap⋅return⋅xs) = xs"
by (simp add: join_def monad_fmap eta_cfun bind_bind)

lemma join_fmap_join: "join⋅(fmap⋅join⋅xsss) = join⋅(join⋅xsss)"
by (simp add: join_def monad_fmap bind_bind)

lemma bind_def2: "m ⤜ k = join⋅(fmap⋅k⋅m)"
by (simp add: join_def monad_fmap eta_cfun bind_bind)

subsection ‹Equivalence of monad laws and fmap/join laws›

(* The three monad laws, re-proved here from bind_def2 and the
   fmap/join lemmas above. *)

lemma "(return⋅x ⤜ f) = (f⋅x)"
by (simp only: bind_def2 fmap_return join_return)

lemma "(m ⤜ return) = m"
by (simp only: bind_def2 join_fmap_return)

lemma "((m ⤜ f) ⤜ g) = (m ⤜ (Λ x. f⋅x ⤜ g))"
 apply (simp only: bind_def2)
 apply (subgoal_tac "join⋅(fmap⋅g⋅(join⋅(fmap⋅f⋅m))) =
    join⋅(fmap⋅join⋅(fmap⋅(fmap⋅g)⋅(fmap⋅f⋅m)))")
  apply (simp add: fmap_fmap)
 apply (simp add: join_fmap_join join_fmap_fmap)
done

subsection ‹Simplification of coercions›

text ‹We configure rewrite rules that push coercions inwards, and
reduce them to coercions on simpler types.›

lemma coerce_return [coerce_simp]:
  "COERCE('a⋅'m,'b⋅'m::monad)⋅(return⋅x) = return⋅(COERCE('a,'b)⋅x)"
by (simp add: coerce_functor fmap_return)

lemma coerce_bind [coerce_simp]:
  fixes m :: "'a⋅'m::monad" and k :: "'a → 'b⋅'m"
  shows "COERCE('b⋅'m,'c⋅'m)⋅(m ⤜ k) = m ⤜ (Λ x. COERCE('b⋅'m,'c⋅'m)⋅(k⋅x))"
by (simp add: coerce_functor fmap_bind)

lemma bind_coerce [coerce_simp]:
  fixes m :: "'a⋅'m::monad" and k :: "'b → 'c⋅'m"
  shows "COERCE('a⋅'m,'b⋅'m)⋅m ⤜ k = m ⤜ (Λ x. k⋅(COERCE('a,'b)⋅x))"
by (simp add: coerce_functor bind_fmap)

end

Theory Monad_Zero

section ‹Monad-Zero Class›

theory Monad_Zero
imports Monad
begin

(* Monomorphic zero constant at the udom instance of a tycon. *)
class zeroU = tycon +
  fixes zeroU :: "udom⋅'a::tycon"

class functor_zero = zeroU + "functor" +
  assumes fmapU_zeroU [coerce_simp]:
    "fmapU⋅f⋅zeroU = zeroU"

class monad_zero = zeroU + monad +
  assumes bindU_zeroU:
    "bindU⋅zeroU⋅f = zeroU"

instance monad_zero ⊆ functor_zero
proof
  fix f show "fmapU⋅f⋅zeroU = (zeroU :: udom⋅'a)"
    unfolding fmapU_eq_bindU
    by (rule bindU_zeroU)
qed

(* Polymorphic zero, obtained by coercion from zeroU. *)
definition fzero :: "'a⋅'f::functor_zero"
  where "fzero = coerce⋅(zeroU :: udom⋅'f)"

lemma fmap_fzero:
  "fmap⋅f⋅(fzero :: 'a⋅'f::functor_zero) = (fzero :: 'b⋅'f)"
unfolding fmap_def fzero_def
by (simp add: coerce_simp)

(* mzero is fzero restricted to class monad_zero. *)
abbreviation mzero :: "'a⋅'m::monad_zero"
  where "mzero ≡ fzero"

lemmas mzero_def = fzero_def [where 'f="'m::monad_zero"] for f
lemmas fmap_mzero = fmap_fzero [where 'f="'m::monad_zero"] for f

lemma bindU_eq_bind: "bindU = bind"
unfolding bind_def by simp

lemma bind_mzero:
  "bind⋅(fzero :: 'a⋅'m::monad_zero)⋅k = (mzero :: 'b⋅'m)"
unfolding bind_def mzero_def
by (simp add: coerce_simp bindU_zeroU)

end

Theory Monad_Plus

section ‹Monad-Plus Class›

theory Monad_Plus
imports Monad
begin

hide_const (open) Fixrec.mplus

class plusU = tycon +
  fixes plusU :: "udom'a  udom'a  udom'a::tycon"

class functor_plus = plusU + "functor" +
  assumes fmapU_plusU [coerce_simp]:
    "fmapUf(plusUab) = plusU(fmapUfa)(fmapUfb)"
  assumes plusU_assoc:
    "plusU(plusUab)c = plusUa(plusUbc)"

(* Monads with an associative plusU over which bindU distributes (in its
   first argument). *)
class monad_plus = plusU + monad +
  assumes bindU_plusU:
    "bindU⋅(plusU⋅xs⋅ys)⋅k = plusU⋅(bindU⋅xs⋅k)⋅(bindU⋅ys⋅k)"
  assumes plusU_assoc':
    "plusU⋅(plusU⋅a⋅b)⋅c = plusU⋅a⋅(plusU⋅b⋅c)"

(* Every monad_plus is a functor_plus; both axioms follow from the monad_plus
   axioms after rewriting fmapU in terms of bindU. *)
instance monad_plus ⊆ functor_plus
by standard (simp_all only: fmapU_eq_bindU bindU_plusU plusU_assoc')

(* Polymorphic plus, obtained by coercing the udom-level plusU. *)
definition fplus :: "'a⋅'f::functor_plus → 'a⋅'f → 'a⋅'f"
  where "fplus = coerce⋅(plusU :: udom⋅'f → _)"

(* Polymorphic version of fmapU_plusU. *)
lemma fmap_fplus:
  fixes f :: "'a → 'b" and a b :: "'a⋅'f::functor_plus"
  shows "fmap⋅f⋅(fplus⋅a⋅b) = fplus⋅(fmap⋅f⋅a)⋅(fmap⋅f⋅b)"
unfolding fmap_def fplus_def
by (simp add: coerce_simp)

(* Polymorphic version of plusU_assoc. *)
lemma fplus_assoc:
  fixes a b c :: "'a⋅'f::functor_plus"
  shows "fplus⋅(fplus⋅a⋅b)⋅c = fplus⋅a⋅(fplus⋅b⋅c)"
unfolding fplus_def
by (simp add: coerce_simp plusU_assoc)

(* mplus is fplus restricted to the monad_plus class. *)
abbreviation mplus :: "'a⋅'m::monad_plus → 'a⋅'m → 'a⋅'m"
  where "mplus ≡ fplus"

(* Copies of the fplus facts with the type constructor specialised to class
   monad_plus, published under mplus names. *)
lemmas mplus_def = fplus_def [where 'f="'m::monad_plus" for f]
lemmas fmap_mplus = fmap_fplus [where 'f="'m::monad_plus" for f]
lemmas mplus_assoc = fplus_assoc [where 'f="'m::monad_plus" for f]

(* Polymorphic version of bindU_plusU. *)
lemma bind_mplus:
  fixes a b :: "'a⋅'m::monad_plus"
  shows "bind⋅(mplus⋅a⋅b)⋅k = mplus⋅(bind⋅a⋅k)⋅(bind⋅b⋅k)"
unfolding bind_def mplus_def
by (simp add: coerce_simp bindU_plusU)

(* join distributes over mplus; follows from join_def and bind_mplus. *)
lemma join_mplus:
  fixes xss yss :: "('a⋅'m)⋅'m::monad_plus"
  shows "join⋅(mplus⋅xss⋅yss) = mplus⋅(join⋅xss)⋅(join⋅yss)"
by (simp add: join_def bind_mplus)

end

Theory Monad_Zero_Plus

section ‹Monad-Zero-Plus Class›

theory Monad_Zero_Plus
imports Monad_Zero Monad_Plus
begin

(* NOTE(review): Fixrec.mplus is already hidden in Monad_Plus; presumably it
   is repeated here because this theory merges two imports -- confirm. *)
hide_const (open) Fixrec.mplus

(* Functors with both zero and plus, where zeroU is a left and right unit of
   plusU. *)
class functor_zero_plus = functor_zero + functor_plus +
  assumes plusU_zeroU_left:
    "plusU⋅zeroU⋅m = m"
  assumes plusU_zeroU_right:
    "plusU⋅m⋅zeroU = m"

(* Combination of monad_zero, monad_plus and functor_zero_plus; adds no new
   axioms of its own. *)
class monad_zero_plus = monad_zero + monad_plus + functor_zero_plus

(* Polymorphic version of plusU_zeroU_left. *)
lemma fplus_fzero_left:
  fixes m :: "'a⋅'f::functor_zero_plus"
  shows "fplus⋅fzero⋅m = m"
unfolding fplus_def fzero_def
by (simp add: coerce_simp plusU_zeroU_left)

(* Polymorphic version of plusU_zeroU_right. *)
lemma fplus_fzero_right:
  fixes m :: "'a⋅'f::functor_zero_plus"
  shows "fplus⋅m⋅fzero = m"
unfolding fplus_def fzero_def
by (simp add: coerce_simp plusU_zeroU_right)

(* Copies of the fplus/fzero unit laws with the type constructor specialised
   to class monad_zero_plus. *)
lemmas mplus_mzero_left =
  fplus_fzero_left [where 'f="'m::monad_zero_plus"] for f

lemmas mplus_mzero_right =
  fplus_fzero_right [where 'f="'m::monad_zero_plus"] for f

end

Theory Lazy_List_Monad

section ‹Lazy list monad›

theory Lazy_List_Monad
imports Monad_Zero_Plus
begin

text ‹To illustrate the general process of defining a new type
constructor, we formalize the datatype of lazy lists. Below are the
Haskell datatype definition and class instances.›

text_raw ‹
\begin{verbatim}
data List a = Nil | Cons a (List a)

instance Functor List where
  fmap f Nil = Nil
  fmap f (Cons x xs) = Cons (f x) (fmap f xs)

instance Monad List where
  return x        = Cons x Nil
  Nil       >>= k = Nil
  Cons x xs >>= k = mplus (k x) (xs >>= k)

instance MonadZero List where
  mzero = Nil

instance MonadPlus List where
  mplus Nil         ys = ys
  mplus (Cons x xs) ys = Cons x (mplus xs ys)
\end{verbatim}
›

subsection ‹Type definition›

text ‹The first step is to register the datatype definition with
‹tycondef›.›

(* Lazy lists: both constructor arguments of LCons are lazy. *)
tycondef 'a⋅llist = LNil | LCons (lazy "'a") (lazy "'a⋅llist")

text ‹The ‹tycondef› command generates lots of theorems
automatically, but there are a few more involving ‹coerce› and
‹fmapU› that we still need to prove manually. These proofs could
be automated in a later version of ‹tycondef›.›

(* coerce commutes with the abstraction function and with each constructor. *)
lemma coerce_llist_abs [simp]: "coerce⋅(llist_abs⋅x) = llist_abs⋅(coerce⋅x)"
apply (simp add: llist_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_llist)
done

lemma coerce_LNil [simp]: "coerce⋅LNil = LNil"
unfolding LNil_def by simp

lemma coerce_LCons [simp]: "coerce⋅(LCons⋅x⋅xs) = LCons⋅(coerce⋅x)⋅(coerce⋅xs)"
unfolding LCons_def by simp

(* Defining equations of fmapU on lazy lists; each case unfolds the fixed
   point in the definition once. *)
lemma fmapU_llist_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅llist) = ⊥"
  "fmapU⋅f⋅LNil = LNil"
  "fmapU⋅f⋅(LCons⋅x⋅xs) = LCons⋅(f⋅x)⋅(fmapU⋅f⋅xs)"
unfolding fmapU_llist_def llist_map_def
apply (subst fix_eq, simp)
apply (subst fix_eq, simp add: LNil_def)
apply (subst fix_eq, simp add: LCons_def)
done

subsection ‹Class instances›

text ‹The ‹tycondef› command defines ‹fmapU› for us and
proves a ‹prefunctor› class instance automatically. For the
‹functor› instance we only need to prove the composition law,
which we can do by induction.›

(* Functor instance: the composition law for fmapU, by list induction. *)
instance llist :: "functor"
proof
  fix f g and xs :: "udom⋅llist"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    by (induct xs rule: llist.induct) simp_all
qed

text ‹For the other class instances, we need to provide definitions
for a few constants: ‹returnU›, ‹bindU›, ‹zeroU›, and
‹plusU›. We can use ordinary commands like ‹definition›
and ‹fixrec› for this purpose. Finally we prove the class
axioms, along with a few helper lemmas, using ordinary proof
procedures like induction.›

(* monad_zero_plus instance for lazy lists.  plusU concatenates, bindU applies
   k to each element and concatenates the results, zeroU is LNil and returnU
   builds a one-element list. *)
instantiation llist :: monad_zero_plus
begin

fixrec plusU_llist :: "udom⋅llist → udom⋅llist → udom⋅llist"
  where "plusU_llist⋅LNil⋅ys = ys"
  | "plusU_llist⋅(LCons⋅x⋅xs)⋅ys = LCons⋅x⋅(plusU_llist⋅xs⋅ys)"

lemma plusU_llist_strict [simp]: "plusU⋅⊥⋅ys = (⊥::udom⋅llist)"
by fixrec_simp

fixrec bindU_llist :: "udom⋅llist → (udom → udom⋅llist) → udom⋅llist"
  where "bindU_llist⋅LNil⋅k = LNil"
  | "bindU_llist⋅(LCons⋅x⋅xs)⋅k = plusU⋅(k⋅x)⋅(bindU_llist⋅xs⋅k)"

lemma bindU_llist_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅llist)"
by fixrec_simp

definition zeroU_llist_def:
  "zeroU = LNil"

definition returnU_llist_def:
  "returnU = (Λ x. LCons⋅x⋅LNil)"

lemma plusU_LNil_right: "plusU⋅xs⋅LNil = xs"
by (induct xs rule: llist.induct) simp_all

lemma plusU_llist_assoc:
  fixes xs ys zs :: "udom⋅llist"
  shows "plusU⋅(plusU⋅xs⋅ys)⋅zs = plusU⋅xs⋅(plusU⋅ys⋅zs)"
by (induct xs rule: llist.induct) simp_all

lemma bindU_plusU_llist:
  fixes xs ys :: "udom⋅llist" shows
  "bindU⋅(plusU⋅xs⋅ys)⋅f = plusU⋅(bindU⋅xs⋅f)⋅(bindU⋅ys⋅f)"
by (induct xs rule: llist.induct) (simp_all add: plusU_llist_assoc)

instance proof
  fix x :: "udom"
  fix f :: "udom → udom"
  fix h k :: "udom → udom⋅llist"
  fix xs ys zs :: "udom⋅llist"
  show "fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
    by (induct xs rule: llist.induct, simp_all add: returnU_llist_def)
  show "bindU⋅(returnU⋅x)⋅k = k⋅x"
    by (simp add: returnU_llist_def plusU_LNil_right)
  show "bindU⋅(bindU⋅xs⋅h)⋅k = bindU⋅xs⋅(Λ x. bindU⋅(h⋅x)⋅k)"
    by (induct xs rule: llist.induct)
       (simp_all add: bindU_plusU_llist)
  (* Already proved above as a helper lemma; reuse it rather than repeating
     the induction. *)
  show "bindU⋅(plusU⋅xs⋅ys)⋅k = plusU⋅(bindU⋅xs⋅k)⋅(bindU⋅ys⋅k)"
    by (rule bindU_plusU_llist)
  show "plusU⋅(plusU⋅xs⋅ys)⋅zs = plusU⋅xs⋅(plusU⋅ys⋅zs)"
    by (rule plusU_llist_assoc)
  show "bindU⋅zeroU⋅k = zeroU"
    by (simp add: zeroU_llist_def)
  show "fmapU⋅f⋅(plusU⋅xs⋅ys) = plusU⋅(fmapU⋅f⋅xs)⋅(fmapU⋅f⋅ys)"
    by (induct xs rule: llist.induct) simp_all
  show "fmapU⋅f⋅zeroU = (zeroU :: udom⋅llist)"
    by (simp add: zeroU_llist_def)
  show "plusU⋅zeroU⋅xs = xs"
    by (simp add: zeroU_llist_def)
  show "plusU⋅xs⋅zeroU = xs"
    by (simp add: zeroU_llist_def plusU_LNil_right)
qed

end

subsection ‹Transfer properties to polymorphic versions›

text ‹After proving the class instances, there is still one more
step: We must transfer all the list-specific lemmas about the
monomorphic constants (e.g., ‹fmapU› and ‹bindU›) to the
corresponding polymorphic constants (‹fmap› and ‹bind›).
These lemmas primarily consist of the defining equations for each
constant. The polymorphic constants are defined using ‹coerce›,
so the proofs proceed by unfolding the definitions and simplifying
with the ‹coerce_simp› rules.›

(* Defining equations of the polymorphic fmap on lazy lists. *)
lemma fmap_llist_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅llist) = ⊥"
  "fmap⋅f⋅LNil = LNil"
  "fmap⋅f⋅(LCons⋅x⋅xs) = LCons⋅(f⋅x)⋅(fmap⋅f⋅xs)"
unfolding fmap_def by simp_all

(* Defining equations of the polymorphic mplus on lazy lists. *)
lemma mplus_llist_simps [simp]:
  "mplus⋅(⊥::'a⋅llist)⋅ys = ⊥"
  "mplus⋅LNil⋅ys = ys"
  "mplus⋅(LCons⋅x⋅xs)⋅ys = LCons⋅x⋅(mplus⋅xs⋅ys)"
unfolding mplus_def by simp_all

(* Defining equations of the polymorphic bind on lazy lists. *)
lemma bind_llist_simps [simp]:
  "bind⋅(⊥::'a⋅llist)⋅f = ⊥"
  "bind⋅LNil⋅f = LNil"
  "bind⋅(LCons⋅x⋅xs)⋅f = mplus⋅(f⋅x)⋅(bind⋅xs⋅f)"
unfolding bind_def mplus_def
by (simp_all add: coerce_simp)

(* Polymorphic return on lazy lists builds a one-element list. *)
lemma return_llist_def:
  "return = (Λ x. LCons⋅x⋅LNil)"
unfolding return_def returnU_llist_def
by (simp add: coerce_simp)

(* Polymorphic mzero on lazy lists is the empty list LNil. *)
lemma mzero_llist_def:
  "mzero = LNil"
unfolding mzero_def zeroU_llist_def
by simp

(* Defining equations of join on lists of lists. *)
lemma join_llist_simps [simp]:
  "join⋅(⊥::'a⋅llist⋅llist) = ⊥"
  "join⋅LNil = LNil"
  "join⋅(LCons⋅xs⋅xss) = mplus⋅xs⋅(join⋅xss)"
unfolding join_def by simp_all

end

Theory Maybe_Monad

section ‹Maybe monad›

theory Maybe_Monad
imports Monad_Zero_Plus
begin

subsection ‹Type definition›

(* Option-like type constructor; the argument of Just is lazy. *)
tycondef 'a⋅maybe = Nothing | Just (lazy "'a")

(* coerce commutes with the abstraction function and with each constructor. *)
lemma coerce_maybe_abs [simp]: "coerce⋅(maybe_abs⋅x) = maybe_abs⋅(coerce⋅x)"
apply (simp add: maybe_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_maybe)
done

lemma coerce_Nothing [simp]: "coerce⋅Nothing = Nothing"
unfolding Nothing_def by simp

lemma coerce_Just [simp]: "coerce⋅(Just⋅x) = Just⋅(coerce⋅x)"
unfolding Just_def by simp

(* Defining equations of fmapU on maybe. *)
lemma fmapU_maybe_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅maybe) = ⊥"
  "fmapU⋅f⋅Nothing = Nothing"
  "fmapU⋅f⋅(Just⋅x) = Just⋅(f⋅x)"
unfolding fmapU_maybe_def maybe_map_def fix_const
apply simp
apply (simp add: Nothing_def)
apply (simp add: Just_def)
done

subsection ‹Class instance proofs›

(* Functor instance for maybe: the class obligation is discharged by
   induction over the maybe argument followed by simplification. *)
instance maybe :: "functor"
apply standard
apply (induct_tac xs rule: maybe.induct, simp_all)
done

(* maybe is a functor_zero_plus and a monad_zero.  plusU returns its first
   argument if it is a Just and otherwise its second argument; zeroU is
   Nothing and returnU is Just. *)
instantiation maybe :: "{functor_zero_plus, monad_zero}"
begin

fixrec plusU_maybe :: "udom⋅maybe → udom⋅maybe → udom⋅maybe"
  where "plusU_maybe⋅Nothing⋅ys = ys"
  | "plusU_maybe⋅(Just⋅x)⋅ys = Just⋅x"

lemma plusU_maybe_strict [simp]: "plusU⋅⊥⋅ys = (⊥::udom⋅maybe)"
by fixrec_simp

fixrec bindU_maybe :: "udom⋅maybe → (udom → udom⋅maybe) → udom⋅maybe"
  where "bindU_maybe⋅Nothing⋅k = Nothing"
  | "bindU_maybe⋅(Just⋅x)⋅k = k⋅x"

lemma bindU_maybe_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅maybe)"
by fixrec_simp

definition zeroU_maybe_def:
  "zeroU = Nothing"

definition returnU_maybe_def:
  "returnU = Just"

lemma plusU_Nothing_right: "plusU⋅xs⋅Nothing = xs"
by (induct xs rule: maybe.induct) simp_all

(* Deliberately abandoned with oops: this distributivity law does not hold
   for maybe, which is why the instantiation omits monad_plus.  See the
   counterexamples at the end of this theory. *)
lemma bindU_plusU_maybe:
  fixes xs ys :: "udom⋅maybe" shows
  "bindU⋅(plusU⋅xs⋅ys)⋅f = plusU⋅(bindU⋅xs⋅f)⋅(bindU⋅ys⋅f)"
apply (induct xs rule: maybe.induct)
apply simp_all
oops

instance proof
  fix x :: "udom"
  fix f :: "udom → udom"
  fix h k :: "udom → udom⋅maybe"
  fix xs ys zs :: "udom⋅maybe"
  show "fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
    by (induct xs rule: maybe.induct, simp_all add: returnU_maybe_def)
  (* No plusU occurs in this goal, so plusU_Nothing_right is not needed. *)
  show "bindU⋅(returnU⋅x)⋅k = k⋅x"
    by (simp add: returnU_maybe_def)
  show "bindU⋅(bindU⋅xs⋅h)⋅k = bindU⋅xs⋅(Λ x. bindU⋅(h⋅x)⋅k)"
    by (induct xs rule: maybe.induct) simp_all
  show "plusU⋅(plusU⋅xs⋅ys)⋅zs = plusU⋅xs⋅(plusU⋅ys⋅zs)"
    by (induct xs rule: maybe.induct) simp_all
  show "bindU⋅zeroU⋅k = zeroU"
    by (simp add: zeroU_maybe_def)
  show "fmapU⋅f⋅(plusU⋅xs⋅ys) = plusU⋅(fmapU⋅f⋅xs)⋅(fmapU⋅f⋅ys)"
    by (induct xs rule: maybe.induct) simp_all
  show "fmapU⋅f⋅zeroU = (zeroU :: udom⋅maybe)"
    by (simp add: zeroU_maybe_def)
  show "plusU⋅zeroU⋅xs = xs"
    by (simp add: zeroU_maybe_def)
  show "plusU⋅xs⋅zeroU = xs"
    by (simp add: zeroU_maybe_def plusU_Nothing_right)
qed

end

subsection ‹Transfer properties to polymorphic versions›

(* Defining equations of the polymorphic fmap on maybe. *)
lemma fmap_maybe_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅maybe) = ⊥"
  "fmap⋅f⋅Nothing = Nothing"
  "fmap⋅f⋅(Just⋅x) = Just⋅(f⋅x)"
unfolding fmap_def by simp_all

(* Defining equations of the polymorphic fplus on maybe. *)
lemma fplus_maybe_simps [simp]:
  "fplus⋅(⊥::'a⋅maybe)⋅ys = ⊥"
  "fplus⋅Nothing⋅ys = ys"
  "fplus⋅(Just⋅x)⋅ys = Just⋅x"
unfolding fplus_def by simp_all

(* Nothing is a right unit of fplus on maybe. *)
lemma fplus_Nothing_right [simp]:
  "fplus⋅m⋅Nothing = m"
by (simp add: fplus_def plusU_Nothing_right)

(* Defining equations of the polymorphic bind on maybe. *)
lemma bind_maybe_simps [simp]:
  "bind⋅(⊥::'a⋅maybe)⋅f = ⊥"
  "bind⋅Nothing⋅f = Nothing"
  "bind⋅(Just⋅x)⋅f = f⋅x"
unfolding bind_def fplus_def by simp_all

(* Polymorphic return on maybe is the Just constructor. *)
lemma return_maybe_def: "return = Just"
unfolding return_def returnU_maybe_def
by (simp add: coerce_cfun cfcomp1 eta_cfun)

(* Polymorphic mzero on maybe is Nothing. *)
lemma mzero_maybe_def: "mzero = Nothing"
unfolding mzero_def zeroU_maybe_def
by simp

(* Defining equations of join on nested maybe values. *)
lemma join_maybe_simps [simp]:
  "join⋅(⊥::'a⋅maybe⋅maybe) = ⊥"
  "join⋅Nothing = Nothing"
  "join⋅(Just⋅xs) = xs"
unfolding join_def by simp_all

subsection ‹Maybe is not in monad_plus›

text ‹
  The ‹maybe› type does not satisfy the law ‹bind_mplus›.
›

(* Counterexample to bind_mplus with an undefined second argument. *)
lemma maybe_counterexample1:
  "⟦a = Just⋅x; b = ⊥; k⋅x = Nothing⟧
    ⟹ fplus⋅a⋅b ⤜ k ≠ fplus⋅(a ⤜ k)⋅(b ⤜ k)"
by simp

(* Counterexample to bind_mplus in which both arguments are defined. *)
lemma maybe_counterexample2:
  "⟦a = Just⋅x; b = Just⋅y; k⋅x = Nothing; k⋅y = Just⋅z⟧
    ⟹ fplus⋅a⋅b ⤜ k ≠ fplus⋅(a ⤜ k)⋅(b ⤜ k)"
by simp

end

Theory Error_Monad

section ‹Error monad›

theory Error_Monad
imports Monad_Plus
begin

subsection ‹Type definition›

(* Error type constructor with parameter 'e for the Err payload; both
   constructor arguments are lazy. *)
tycondef 'a⋅'e error = Err (lazy "'e") | Ok (lazy "'a")

(* coerce commutes with the abstraction function and with each constructor. *)
lemma coerce_error_abs [simp]: "coerce⋅(error_abs⋅x) = error_abs⋅(coerce⋅x)"
apply (simp add: error_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_error)
done

lemma coerce_Err [simp]: "coerce⋅(Err⋅x) = Err⋅(coerce⋅x)"
unfolding Err_def by simp

lemma coerce_Ok [simp]: "coerce⋅(Ok⋅m) = Ok⋅(coerce⋅m)"
unfolding Ok_def by simp

(* Defining equations of fmapU on error: Err values are left unchanged. *)
lemma fmapU_error_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅'a error) = ⊥"
  "fmapU⋅f⋅(Err⋅e) = Err⋅e"
  "fmapU⋅f⋅(Ok⋅x) = Ok⋅(f⋅x)"
unfolding fmapU_error_def error_map_def fix_const
apply simp
apply (simp add: Err_def)
apply (simp add: Ok_def)
done

subsection ‹Monad class instance›

(* error is a monad and a functor_plus.  returnU is Ok, bindU propagates Err,
   and plusU falls through to its second argument when the first is an Err. *)
instantiation error :: ("domain") "{monad, functor_plus}"
begin

definition
  "returnU = Ok"

fixrec bindU_error :: "udom⋅'a error → (udom → udom⋅'a error) → udom⋅'a error"
  where "bindU_error⋅(Err⋅e)⋅f = Err⋅e"
  | "bindU_error⋅(Ok⋅x)⋅f = f⋅x"

lemma bindU_error_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅'a error)"
by fixrec_simp

fixrec plusU_error :: "udom⋅'a error → udom⋅'a error → udom⋅'a error"
  where "plusU_error⋅(Err⋅e)⋅f = f"
  | "plusU_error⋅(Ok⋅x)⋅f = Ok⋅x"

lemma plusU_error_strict [simp]: "plusU⋅(⊥ :: udom⋅'a error) = ⊥"
by fixrec_simp

instance proof
  fix f g :: "udom → udom" and r :: "udom⋅'a error"
  show "fmapU⋅f⋅(fmapU⋅g⋅r) = fmapU⋅(Λ x. f⋅(g⋅x))⋅r"
    by (induct r rule: error.induct) simp_all
next
  fix f :: "udom → udom" and r :: "udom⋅'a error"
  show "fmapU⋅f⋅r = bindU⋅r⋅(Λ x. returnU⋅(f⋅x))"
    by (induct r rule: error.induct)
       (simp_all add: returnU_error_def)
next
  fix f :: "udom → udom⋅'a error" and x :: "udom"
  show "bindU⋅(returnU⋅x)⋅f = f⋅x"
    by (simp add: returnU_error_def)
next
  fix r :: "udom⋅'a error" and f g :: "udom → udom⋅'a error"
  show "bindU⋅(bindU⋅r⋅f)⋅g = bindU⋅r⋅(Λ x. bindU⋅(f⋅x)⋅g)"
    by (induct r rule: error.induct)
       simp_all
next
  fix f :: "udom → udom" and a b :: "udom⋅'a error"
  show "fmapU⋅f⋅(plusU⋅a⋅b) = plusU⋅(fmapU⋅f⋅a)⋅(fmapU⋅f⋅b)"
    by (induct a rule: error.induct) simp_all
next
  fix a b c :: "udom⋅'a error"
  show "plusU⋅(plusU⋅a⋅b)⋅c = plusU⋅a⋅(plusU⋅b⋅c)"
    by (induct a rule: error.induct) simp_all
qed

end

subsection ‹Transfer properties to polymorphic versions›

(* Defining equations of the polymorphic fmap on error. *)
lemma fmap_error_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅'e error) = ⊥"
  "fmap⋅f⋅(Err⋅e :: 'a⋅'e error) = Err⋅e"
  "fmap⋅f⋅(Ok⋅x :: 'a⋅'e error) = Ok⋅(f⋅x)"
unfolding fmap_def [where 'f="'e error"]
by (simp_all add: coerce_simp)

(* Polymorphic return on error is the Ok constructor. *)
lemma return_error_def: "return = Ok"
unfolding return_def returnU_error_def
by (simp add: coerce_simp eta_cfun)

(* Defining equations of the polymorphic bind on error. *)
lemma bind_error_simps [simp]:
  "bind⋅(⊥ :: 'a⋅'e error)⋅f = ⊥"
  "bind⋅(Err⋅e :: 'a⋅'e error)⋅f = Err⋅e"
  "bind⋅(Ok⋅x :: 'a⋅'e error)⋅f = f⋅x"
unfolding bind_def
by (simp_all add: coerce_simp)

(* Defining equations of join on error. *)
lemma join_error_simps [simp]:
  "join⋅⊥ = (⊥ :: 'a⋅'e error)"
  "join⋅(Err⋅e) = Err⋅e"
  "join⋅(Ok⋅x) = x"
unfolding join_def by simp_all

(* Defining equations of the polymorphic fplus on error. *)
lemma fplus_error_simps [simp]:
  "fplus⋅⊥⋅r = (⊥ :: 'a⋅'e error)"
  "fplus⋅(Err⋅e)⋅r = r"
  "fplus⋅(Ok⋅x)⋅r = Ok⋅x"
unfolding fplus_def
by (simp_all add: coerce_simp)

end

Theory Writer_Monad

section ‹Writer monad›

theory Writer_Monad
imports Monad
begin

subsection ‹Monoid class›

(* Domains with a continuous associative operation mappend and a two-sided
   unit mempty. *)
class monoid = "domain" +
  fixes mempty :: "'a"
  fixes mappend :: "'a → 'a → 'a"
  assumes mempty_left: "⋀ys. mappend⋅mempty⋅ys = ys"
  assumes mempty_right: "⋀xs. mappend⋅xs⋅mempty = xs"
  assumes mappend_assoc:
    "⋀xs ys zs. mappend⋅(mappend⋅xs⋅ys)⋅zs = mappend⋅xs⋅(mappend⋅ys⋅zs)"

subsection ‹Writer monad type›

text ‹Below is the standard Haskell definition of a writer monad
type; it is an isomorphic copy of the lazy pair type \texttt{(a, w)}.
›

text_raw ‹
\begin{verbatim}
newtype Writer w a = Writer { runWriter :: (a, w) }
\end{verbatim}
›

text ‹Since HOLCF does not have a pre-defined lazy pair type, we
will base this formalization on an equivalent, more direct definition:
›

text_raw ‹
\begin{verbatim}
data Writer w a = Writer w a
\end{verbatim}
›

text ‹We can directly translate the above Haskell type definition
using ‹tycondef›. \medskip›

(* Writer type constructor with parameter 'w for the output; both fields of
   the single constructor are lazy. *)
tycondef 'a⋅'w writer = Writer (lazy "'w") (lazy "'a")

(* coerce commutes with the abstraction function and with the constructor. *)
lemma coerce_writer_abs [simp]: "coerce⋅(writer_abs⋅x) = writer_abs⋅(coerce⋅x)"
apply (simp add: writer_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_writer)
done

lemma coerce_Writer [simp]:
  "coerce⋅(Writer⋅w⋅x) = Writer⋅(coerce⋅w)⋅(coerce⋅x)"
unfolding Writer_def by simp

(* Defining equations of fmapU on writer: only the value field is mapped. *)
lemma fmapU_writer_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅'w writer) = ⊥"
  "fmapU⋅f⋅(Writer⋅w⋅x) = Writer⋅w⋅(f⋅x)"
unfolding fmapU_writer_def writer_map_def fix_const
apply simp
apply (simp add: Writer_def)
done

subsection ‹Class instance proofs›

(* Functor instance: the composition law for fmapU, by induction. *)
instance writer :: ("domain") "functor"
proof
  fix f g :: "udom → udom" and xs :: "udom⋅'a writer"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    by (induct xs rule: writer.induct) simp_all
qed

(* Monad instance for writer over a monoid: bindU combines the two outputs
   with mappend and returnU uses mempty as output. *)
instantiation writer :: (monoid) monad
begin

fixrec bindU_writer ::
    "udom⋅'a writer → (udom → udom⋅'a writer) → udom⋅'a writer"
  where "bindU_writer⋅(Writer⋅w⋅x)⋅f =
    (case f⋅x of Writer⋅w'⋅y ⇒ Writer⋅(mappend⋅w⋅w')⋅y)"

lemma bindU_writer_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅'a writer)"
by fixrec_simp

definition
  "returnU = Writer⋅mempty"

instance proof
  fix f :: "udom → udom" and m :: "udom⋅'a writer"
  show "fmapU⋅f⋅m = bindU⋅m⋅(Λ x. returnU⋅(f⋅x))"
    by (induct m rule: writer.induct)
       (simp_all add: returnU_writer_def mempty_right)
next
  fix f :: "udom → udom⋅'a writer" and x :: "udom"
  show "bindU⋅(returnU⋅x)⋅f = f⋅x"
    by (cases "f⋅x" rule: writer.exhaust)
       (simp_all add: returnU_writer_def mempty_left)
next
  fix m :: "udom⋅'a writer" and f g :: "udom → udom⋅'a writer"
  show "bindU⋅(bindU⋅m⋅f)⋅g = bindU⋅m⋅(Λ x. bindU⋅(f⋅x)⋅g)"
    apply (induct m rule: writer.induct, simp)
    apply (case_tac "f⋅a" rule: writer.exhaust, simp)
    apply (case_tac "g⋅aa" rule: writer.exhaust, simp)
    apply (simp add: mappend_assoc)
    done
qed

end

subsection ‹Transfer properties to polymorphic versions›

(* Defining equations of the polymorphic fmap on writer. *)
lemma fmap_writer_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅'w writer) = ⊥"
  "fmap⋅f⋅(Writer⋅w⋅x :: 'a⋅'w writer) = Writer⋅w⋅(f⋅x)"
unfolding fmap_def [where 'f="'w writer"]
by (simp_all add: coerce_simp)

(* Polymorphic return on writer pairs the value with the output mempty. *)
lemma return_writer_def: "return = Writer⋅mempty"
unfolding return_def returnU_writer_def
by (simp add: coerce_simp eta_cfun)

(* Defining equations of the polymorphic bind on writer. *)
lemma bind_writer_simps [simp]:
  "bind⋅(⊥ :: 'a⋅'w::monoid writer)⋅f = ⊥"
  "bind⋅(Writer⋅w⋅x :: 'a⋅'w::monoid writer)⋅k =
    (case k⋅x of Writer⋅w'⋅y ⇒ Writer⋅(mappend⋅w⋅w')⋅y)"
unfolding bind_def
apply (simp add: coerce_simp)
apply (cases "k⋅x" rule: writer.exhaust)
apply (simp_all add: coerce_simp)
done

(* Defining equations of join on nested writer values. *)
lemma join_writer_simps [simp]:
  "join⋅⊥ = (⊥ :: 'a⋅'w::monoid writer)"
  "join⋅(Writer⋅w⋅(Writer⋅w'⋅x)) = Writer⋅(mappend⋅w⋅w')⋅x"
unfolding join_def by simp_all

subsection ‹Extra operations›

(* tell⋅w is the writer value with output w and the unit value as result. *)
definition tell :: "'w → unit⋅('w::monoid writer)"
  where "tell = (Λ w. Writer⋅w⋅())"

end

Theory Binary_Tree_Monad

section ‹Binary tree monad›

theory Binary_Tree_Monad
imports Monad
begin

subsection ‹Type definition›

(* Binary trees with values at the leaves; all constructor arguments lazy. *)
tycondef 'a⋅btree =
  Leaf (lazy "'a") | Node (lazy "'a⋅btree") (lazy "'a⋅btree")

(* coerce commutes with the abstraction function and with each constructor. *)
lemma coerce_btree_abs [simp]: "coerce⋅(btree_abs⋅x) = btree_abs⋅(coerce⋅x)"
apply (simp add: btree_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_btree)
done

lemma coerce_Leaf [simp]: "coerce⋅(Leaf⋅x) = Leaf⋅(coerce⋅x)"
unfolding Leaf_def by simp

lemma coerce_Node [simp]: "coerce⋅(Node⋅xs⋅ys) = Node⋅(coerce⋅xs)⋅(coerce⋅ys)"
unfolding Node_def by simp

(* Defining equations of fmapU on binary trees; each case unfolds the fixed
   point in the definition once. *)
lemma fmapU_btree_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅btree) = ⊥"
  "fmapU⋅f⋅(Leaf⋅x) = Leaf⋅(f⋅x)"
  "fmapU⋅f⋅(Node⋅xs⋅ys) = Node⋅(fmapU⋅f⋅xs)⋅(fmapU⋅f⋅ys)"
unfolding fmapU_btree_def btree_map_def
apply (subst fix_eq, simp)
apply (subst fix_eq, simp add: Leaf_def)
apply (subst fix_eq, simp add: Node_def)
done

subsection ‹Class instance proofs›

(* Functor instance for btree: the class obligation is discharged by
   induction over the tree argument followed by simplification. *)
instance btree :: "functor"
apply standard
apply (induct_tac xs rule: btree.induct, simp_all)
done

(* Monad instance for binary trees: returnU is Leaf and bindU replaces every
   leaf by the tree that k yields for its value. *)
instantiation btree :: monad
begin

definition
  "returnU = Leaf"

fixrec bindU_btree :: "udom⋅btree → (udom → udom⋅btree) → udom⋅btree"
  where "bindU_btree⋅(Leaf⋅x)⋅k = k⋅x"
  | "bindU_btree⋅(Node⋅xs⋅ys)⋅k =
      Node⋅(bindU_btree⋅xs⋅k)⋅(bindU_btree⋅ys⋅k)"

lemma bindU_btree_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅btree)"
by fixrec_simp

instance proof
  fix x :: "udom"
  fix f :: "udom → udom"
  fix h k :: "udom → udom⋅btree"
  fix xs :: "udom⋅btree"
  show "fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
    by (induct xs rule: btree.induct, simp_all add: returnU_btree_def)
  show "bindU⋅(returnU⋅x)⋅k = k⋅x"
    by (simp add: returnU_btree_def)
  show "bindU⋅(bindU⋅xs⋅h)⋅k = bindU⋅xs⋅(Λ x. bindU⋅(h⋅x)⋅k)"
    by (induct xs rule: btree.induct) simp_all
qed

end

subsection ‹Transfer properties to polymorphic versions›

(* Defining equations of the polymorphic fmap on binary trees. *)
lemma fmap_btree_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅btree) = ⊥"
  "fmap⋅f⋅(Leaf⋅x) = Leaf⋅(f⋅x)"
  "fmap⋅f⋅(Node⋅xs⋅ys) = Node⋅(fmap⋅f⋅xs)⋅(fmap⋅f⋅ys)"
unfolding fmap_def by simp_all

(* Defining equations of the polymorphic bind on binary trees. *)
lemma bind_btree_simps [simp]:
  "bind⋅(⊥::'a⋅btree)⋅k = ⊥"
  "bind⋅(Leaf⋅x)⋅k = k⋅x"
  "bind⋅(Node⋅xs⋅ys)⋅k = Node⋅(bind⋅xs⋅k)⋅(bind⋅ys⋅k)"
unfolding bind_def
by (simp_all add: coerce_simp)

(* Polymorphic return on binary trees is the Leaf constructor. *)
lemma return_btree_def:
  "return = Leaf"
unfolding return_def returnU_btree_def
by (simp add: coerce_simp eta_cfun)

(* Defining equations of join on trees of trees. *)
lemma join_btree_simps [simp]:
  "join⋅(⊥::'a⋅btree⋅btree) = ⊥"
  "join⋅(Leaf⋅xs) = xs"
  "join⋅(Node⋅xss⋅yss) = Node⋅(join⋅xss)⋅(join⋅yss)"
unfolding join_def by simp_all

end

Theory Lift_Monad

section ‹Lift monad›

theory Lift_Monad
imports Monad
begin

subsection ‹Type definition›

(* Single-constructor type with one lazy argument. *)
tycondef 'a⋅lifted = Lifted (lazy "'a")

(* coerce commutes with the abstraction function and with the constructor. *)
lemma coerce_lifted_abs [simp]: "coerce⋅(lifted_abs⋅x) = lifted_abs⋅(coerce⋅x)"
apply (simp add: lifted_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_lifted)
done

lemma coerce_Lifted [simp]: "coerce⋅(Lifted⋅x) = Lifted⋅(coerce⋅x)"
unfolding Lifted_def by simp

(* Defining equations of fmapU on lifted. *)
lemma fmapU_lifted_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅lifted) = ⊥"
  "fmapU⋅f⋅(Lifted⋅x) = Lifted⋅(f⋅x)"
unfolding fmapU_lifted_def lifted_map_def fix_const
apply simp
apply (simp add: Lifted_def)
done

subsection ‹Class instance proofs›

(* Functor instance for lifted: the class obligation is discharged by
   induction over the lifted argument followed by simplification. *)
instance lifted :: "functor"
  by standard (induct_tac xs rule: lifted.induct, simp_all)

(* Monad instance for lifted: returnU is Lifted and bindU applies k to the
   wrapped value. *)
instantiation lifted :: monad
begin

fixrec bindU_lifted :: "udom⋅lifted → (udom → udom⋅lifted) → udom⋅lifted"
  where "bindU_lifted⋅(Lifted⋅x)⋅k = k⋅x"

lemma bindU_lifted_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅lifted)"
by fixrec_simp

definition returnU_lifted_def:
  "returnU = Lifted"

instance proof
  fix x :: "udom"
  fix f :: "udom → udom"
  fix h k :: "udom → udom⋅lifted"
  fix xs :: "udom⋅lifted"
  show "fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
    by (induct xs rule: lifted.induct, simp_all add: returnU_lifted_def)
  show "bindU⋅(returnU⋅x)⋅k = k⋅x"
    by (simp add: returnU_lifted_def)
  show "bindU⋅(bindU⋅xs⋅h)⋅k = bindU⋅xs⋅(Λ x. bindU⋅(h⋅x)⋅k)"
    by (induct xs rule: lifted.induct) simp_all
qed

end

subsection ‹Transfer properties to polymorphic versions›

(* Defining equations of the polymorphic fmap on lifted. *)
lemma fmap_lifted_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅lifted) = ⊥"
  "fmap⋅f⋅(Lifted⋅x) = Lifted⋅(f⋅x)"
unfolding fmap_def by simp_all

(* Defining equations of the polymorphic bind on lifted. *)
lemma bind_lifted_simps [simp]:
  "bind⋅(⊥::'a⋅lifted)⋅f = ⊥"
  "bind⋅(Lifted⋅x)⋅f = f⋅x"
unfolding bind_def by simp_all

(* Polymorphic return on lifted is the Lifted constructor. *)
lemma return_lifted_def: "return = Lifted"
unfolding return_def returnU_lifted_def
by (simp add: coerce_cfun cfcomp1 eta_cfun)

(* Defining equations of join on nested lifted values. *)
lemma join_lifted_simps [simp]:
  "join⋅(⊥::'a⋅lifted⋅lifted) = ⊥"
  "join⋅(Lifted⋅xs) = xs"
unfolding join_def by simp_all

end

Theory Resumption_Transformer

section ‹Resumption monad transformer›

theory Resumption_Transformer
imports Monad_Plus
begin

subsection ‹Type definition›

text ‹The standard Haskell libraries do not include a resumption
monad transformer type; below is the Haskell definition for the one we
will use here.›

text_raw ‹
\begin{verbatim}
data ResT m a = Done a | More (m (ResT m a))
\end{verbatim}
›

text ‹The above datatype definition can be translated directly into
HOLCF using ‹tycondef›. \medskip›

(* Resumption transformer over a functor 'f: either a finished value or one
   more step in 'f that yields a further resumption. *)
tycondef 'a⋅('f::"functor") resT =
  Done (lazy "'a") | More (lazy "('a⋅'f resT)⋅'f")

(* coerce commutes with the abstraction function and with each constructor. *)
lemma coerce_resT_abs [simp]: "coerce⋅(resT_abs⋅x) = resT_abs⋅(coerce⋅x)"
apply (simp add: resT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_resT)
done

lemma coerce_Done [simp]: "coerce⋅(Done⋅x) = Done⋅(coerce⋅x)"
unfolding Done_def by simp

lemma coerce_More [simp]: "coerce⋅(More⋅m) = More⋅(coerce⋅m)"
unfolding More_def by simp

(* Induction rule for resT.  In the More case the argument is presented as
   fmap⋅f⋅m, where P holds for every result of f.  Derived from
   resT.take_induct by induction on the approximation index. *)
lemma resT_induct [case_names adm bottom Done More]:
  fixes P :: "'a⋅'f::functor resT ⇒ bool"
  assumes adm: "adm P"
  assumes bottom: "P ⊥"
  assumes Done: "⋀x. P (Done⋅x)"
  assumes More: "⋀m f. (⋀(r::'a⋅'f resT). P (f⋅r)) ⟹ P (More⋅(fmap⋅f⋅m))"
  shows "P r"
proof (induct r rule: resT.take_induct [OF adm])
  fix n show "P (resT_take n⋅r)"
    apply (induct n arbitrary: r)
    apply (simp add: bottom)
    apply (case_tac r rule: resT.exhaust)
    apply (simp add: bottom)
    apply (simp add: Done)
    apply (simp add: More)
    done
qed

subsection ‹Class instance proofs›

(* Defining equations of fmapU on resT; each case unfolds the fixed point in
   the definition once. *)
lemma fmapU_resT_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅'f::functor resT) = ⊥"
  "fmapU⋅f⋅(Done⋅x) = Done⋅(f⋅x)"
  "fmapU⋅f⋅(More⋅m) = More⋅(fmap⋅(fmapU⋅f)⋅m)"
unfolding fmapU_resT_def resT_map_def
apply (subst fix_eq, simp)
apply (subst fix_eq, simp add: Done_def)
apply (subst fix_eq, simp add: More_def)
done

(* Functor instance: the composition law for fmapU, by resT_induct. *)
instance resT :: ("functor") "functor"
proof
  fix f g :: "udom → udom" and xs :: "udom⋅'a resT"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    by (induct xs rule: resT_induct, simp_all add: fmap_fmap)
qed

(* Monad instance for resT over any functor: returnU is Done, and bindU
   applies f to a Done value and recurses under More via fmap. *)
instantiation resT :: ("functor") monad
begin

fixrec bindU_resT :: "udom⋅'a resT → (udom → udom⋅'a resT) → udom⋅'a resT"
  where "bindU_resT⋅(Done⋅x)⋅f = f⋅x"
  | "bindU_resT⋅(More⋅m)⋅f = More⋅(fmap⋅(Λ r. bindU_resT⋅r⋅f)⋅m)"

lemma bindU_resT_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅'a resT)"
by fixrec_simp

definition
  "returnU = Done"

instance proof
  fix f :: "udom → udom" and xs :: "udom⋅'a resT"
  show "fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
    by (induct xs rule: resT_induct)
       (simp_all add: fmap_fmap returnU_resT_def)
next
  fix f :: "udom → udom⋅'a resT" and x :: "udom"
  show "bindU⋅(returnU⋅x)⋅f = f⋅x"
    by (simp add: returnU_resT_def)
next
  fix xs :: "udom⋅'a resT" and h k :: "udom → udom⋅'a resT"
  show "bindU⋅(bindU⋅xs⋅h)⋅k = bindU⋅xs⋅(Λ x. bindU⋅(h⋅x)⋅k)"
    by (induct xs rule: resT_induct)
       (simp_all add: fmap_fmap)
qed

end

subsection ‹Transfer properties to polymorphic versions›

(* Defining equations of the polymorphic fmap on resT. *)
lemma fmap_resT_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅'f::functor resT) = ⊥"
  "fmap⋅f⋅(Done⋅x :: 'a⋅'f::functor resT) = Done⋅(f⋅x)"
  "fmap⋅f⋅(More⋅m :: 'a⋅'f::functor resT) = More⋅(fmap⋅(fmap⋅f)⋅m)"
unfolding fmap_def [where 'f="'f resT"]
by (simp_all add: coerce_simp)

(* Polymorphic return on resT is the Done constructor. *)
lemma return_resT_def: "return = Done"
unfolding return_def returnU_resT_def
by (simp add: coerce_simp eta_cfun)

(* Defining equations of the polymorphic bind on resT. *)
lemma bind_resT_simps [simp]:
  "bind⋅(⊥ :: 'a⋅'f::functor resT)⋅f = ⊥"
  "bind⋅(Done⋅x :: 'a⋅'f::functor resT)⋅f = f⋅x"
  "bind⋅(More⋅m :: 'a⋅'f::functor resT)⋅f = More⋅(fmap⋅(Λ r. bind⋅r⋅f)⋅m)"
unfolding bind_def
by (simp_all add: coerce_simp)

(* Defining equations of join on nested resumptions. *)
lemma join_resT_simps [simp]:
  "join⋅⊥ = (⊥ :: 'a⋅'f::functor resT)"
  "join⋅(Done⋅x) = x"
  "join⋅(More⋅m) = More⋅(fmap⋅join⋅m)"
unfolding join_def by simp_all

subsection ‹Nondeterministic interleaving›

text ‹In this section we present a more general formalization of the
nondeterministic interleaving operation presented in Chapter 7 of the
author's PhD thesis \cite{holcf11}. If both arguments are ‹Done›,
then ‹zipRT› combines the results with the function
‹f› and terminates. While either argument is ‹More›,
‹zipRT› nondeterministically chooses one such argument, runs
it for one step, and then calls itself recursively.›

(* Interleaving of two resumptions.  Two Done values are combined with f; if
   exactly one side is More, that side is stepped; if both are More, the two
   possible next steps are combined with fplus. *)
fixrec zipRT ::
  "('a → 'b → 'c) → 'a⋅('m::functor_plus) resT → 'b⋅'m resT → 'c⋅'m resT"
  where zipRT_Done_Done:
    "zipRT⋅f⋅(Done⋅x)⋅(Done⋅y) = Done⋅(f⋅x⋅y)"
  | zipRT_Done_More:
    "zipRT⋅f⋅(Done⋅x)⋅(More⋅b) =
      More⋅(fmap⋅(Λ r. zipRT⋅f⋅(Done⋅x)⋅r)⋅b)"
  | zipRT_More_Done:
    "zipRT⋅f⋅(More⋅a)⋅(Done⋅y) =
      More⋅(fmap⋅(Λ r. zipRT⋅f⋅r⋅(Done⋅y))⋅a)"
  | zipRT_More_More:
    "zipRT⋅f⋅(More⋅a)⋅(More⋅b) =
      More⋅(fplus⋅(fmap⋅(Λ r. zipRT⋅f⋅(More⋅a)⋅r)⋅b)
                 ⋅(fmap⋅(Λ r. zipRT⋅f⋅r⋅(More⋅b))⋅a))"

(* zipRT is strict in each of its two resumption arguments. *)
lemma zipRT_strict1 [simp]: "zipRT⋅f⋅⊥⋅r = ⊥"
by fixrec_simp

lemma zipRT_strict2 [simp]: "zipRT⋅f⋅r⋅⊥ = ⊥"
by (fixrec_simp, cases r, simp_all)

(* Infix notation for zipRT with the identity function as combining function. *)
abbreviation apR (infixl "⋄" 70)
  where "a ⋄ b ≡ zipRT⋅ID⋅a⋅b"

text ‹Proofs that ‹zipRT› satisfies the applicative functor laws:›

(* Homomorphism, identity and interchange laws for the apR operator. *)
lemma zipRT_homomorphism: "Done⋅f ⋄ Done⋅x = Done⋅(f⋅x)"
  by simp

lemma zipRT_identity: "Done⋅ID ⋄ r = r"
  by (induct r rule: resT_induct, simp_all add: fmap_fmap eta_cfun)

lemma zipRT_interchange: "r ⋄ Done⋅x = Done⋅(Λ f. f⋅x) ⋄ r"
  by (induct r rule: resT_induct, simp_all add: fmap_fmap)

text ‹The associativity rule is the hard one!›

(* Composition law for apR.  The proof is a nested induction over r1, r2 and
   r3 with resT_induct; the More/More/More case also needs fmap_fplus and
   fplus_assoc. *)
lemma zipRT_associativity: "Done⋅cfcomp ⋄ r1 ⋄ r2 ⋄ r3 = r1 ⋄ (r2 ⋄ r3)"
proof (induct r1 arbitrary: r2 r3 rule: resT_induct)
  case (Done x1) thus ?case
  proof (induct r2 arbitrary: r3 rule: resT_induct)
    case (Done x2) thus ?case
    proof (induct r3 rule: resT_induct)
      case (More p3 c3) thus ?case (* Done/Done/More *)
        by (simp add: fmap_fmap)
    qed simp_all
  next
    case (More p2 c2) thus ?case
    proof (induct r3 rule: resT_induct)
      case (Done x3) thus ?case (* Done/More/Done *)
        by (simp add: fmap_fmap)
    next
      case (More p3 c3) thus ?case (* Done/More/More *)
        by (simp add: fmap_fmap fmap_fplus)
    qed simp_all
  qed simp_all
next
  case (More p1 c1) thus ?case
  proof (induct r2 arbitrary: r3 rule: resT_induct)
    case (Done y) thus ?case
    proof (induct r3 rule: resT_induct)
      case (Done x3) thus ?case (* More/Done/Done *)
        by (simp add: fmap_fmap)
    next
      case (More p3 c3) thus ?case (* More/Done/More *)
        by (simp add: fmap_fmap)
    qed simp_all
  next
    case (More p2 c2) thus ?case
    proof (induct r3 rule: resT_induct)
      case (Done x3) thus ?case (* More/More/Done *)
        by (simp add: fmap_fmap fmap_fplus)
    next
      case (More p3 c3) thus ?case (* More/More/More *)
        by (simp add: fmap_fmap fmap_fplus fplus_assoc)
    qed simp_all
  qed simp_all
qed simp_all

end

Theory State_Transformer

section ‹State monad transformer›

theory State_Transformer
imports Monad_Zero_Plus
begin

text ‹
  This version has non-lifted product, and a non-lifted function space.
›

(* State transformer over a functor 'f with state type 's; runStateT is the
   selector for the wrapped state-passing function. *)
tycondef 'a⋅('f::"functor", 's) stateT =
  StateT (runStateT :: "'s → ('a × 's)⋅'f")

(* coerce commutes with the abstraction function and with the constructor. *)
lemma coerce_stateT_abs [simp]: "coerce⋅(stateT_abs⋅x) = stateT_abs⋅(coerce⋅x)"
apply (simp add: stateT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_stateT)
done

lemma coerce_StateT [simp]: "coerce⋅(StateT⋅k) = StateT⋅(coerce⋅k)"
unfolding StateT_def by simp

(* Every stateT value has the form StateT⋅k (take k = runStateT⋅y). *)
lemma stateT_cases [case_names StateT]:
  obtains k where "y = StateT⋅k"
proof
  show "y = StateT⋅(runStateT⋅y)"
    by (cases y, simp_all)
qed

(* A property that holds for every StateT⋅k holds for every stateT value. *)
lemma stateT_induct [case_names StateT]:
  fixes P :: "'a⋅('f::functor,'s) stateT ⇒ bool"
  assumes "⋀k. P (StateT⋅k)"
  shows "P y"
by (cases y rule: stateT_cases, simp add: assms)

(* Extensionality: stateT values that agree under runStateT on every state
   are equal. *)
lemma stateT_eqI:
  "(⋀s. runStateT⋅a⋅s = runStateT⋅b⋅s) ⟹ a = b"
apply (cases a rule: stateT_cases)
apply (cases b rule: stateT_cases)
apply (simp add: cfun_eq_iff)
done

(* runStateT commutes with coerce. *)
lemma runStateT_coerce [simp]:
  "runStateT⋅(coerce⋅k)⋅s = coerce⋅(runStateT⋅k⋅s)"
by (induct k rule: stateT_induct, simp)

subsection ‹Functor class instance›

(* fmapU on StateT maps f over the result component of each result pair and
   leaves the state component unchanged. *)
lemma fmapU_StateT [simp]:
  "fmapU⋅f⋅(StateT⋅k) =
    StateT⋅(Λ s. fmap⋅(Λ(x, s'). (f⋅x, s'))⋅(k⋅s))"
unfolding fmapU_stateT_def stateT_map_def StateT_def
by (subst fix_eq, simp add: cfun_map_def csplit_def prod_map_def)

(* The same fact as fmapU_StateT, stated through the selector runStateT. *)
lemma runStateT_fmapU [simp]:
  "runStateT⋅(fmapU⋅f⋅m)⋅s =
    fmap⋅(Λ(x, s'). (f⋅x, s'))⋅(runStateT⋅m⋅s)"
by (cases m rule: stateT_cases, simp)

(* Functor instance for stateT over any functor and any domain of states; the
   class obligation is discharged by stateT_induct and simplification with
   fmap_fmap. *)
instantiation stateT :: ("functor", "domain") "functor"
begin

instance
apply standard
apply (induct_tac xs rule: stateT_induct)
apply (simp_all add: fmap_fmap ID_def csplit_def)
done

end

subsection ‹Monad class instance›

(* Monad instance for stateT over a monad.  returnU pairs the value with the
   unchanged state; bindU runs m and passes the resulting value and state on
   to k.  The monad laws are proved with stateT_eqI. *)
instantiation stateT :: (monad, "domain") monad
begin

definition returnU_stateT_def:
  "returnU = (Λ x. StateT⋅(Λ s. return⋅(x, s)))"

definition bindU_stateT_def:
  "bindU = (Λ m k. StateT⋅(Λ s. runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')))"

lemma bindU_stateT_StateT [simp]:
  "bindU⋅(StateT⋅f)⋅k =
    StateT⋅(Λ s. f⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s'))"
unfolding bindU_stateT_def by simp

lemma runStateT_bindU [simp]:
  "runStateT⋅(bindU⋅m⋅k)⋅s = runStateT⋅m⋅s ⤜ (Λ (x, s'). runStateT⋅(k⋅x)⋅s')"
unfolding bindU_stateT_def by simp

instance proof
  fix f :: "udom → udom" and r :: "udom⋅('a,'b) stateT"
  show "fmapU⋅f⋅r = bindU⋅r⋅(Λ x. returnU⋅(f⋅x))"
    by (rule stateT_eqI)
       (simp add: returnU_stateT_def monad_fmap prod_map_def csplit_def)
next
  fix f :: "udom → udom⋅('a,'b) stateT" and x :: "udom"
  show "bindU⋅(returnU⋅x)⋅f = f⋅x"
    by (rule stateT_eqI)
       (simp add: returnU_stateT_def eta_cfun)
next
  fix r :: "udom⋅('a,'b) stateT" and f g :: "udom → udom⋅('a,'b) stateT"
  show "bindU⋅(bindU⋅r⋅f)⋅g = bindU⋅r⋅(Λ x. bindU⋅(f⋅x)⋅g)"
    by (rule stateT_eqI)
       (simp add: bind_bind csplit_def)
qed

end

subsection ‹Monad zero instance›

(* monad_zero class instance for stateT over an inner monad_zero. *)
instantiation stateT :: (monad_zero, "domain") monad_zero
begin

(* zeroU ignores the incoming state and yields the inner monad's mzero. *)
definition zeroU_stateT_def:
  "zeroU = StateT(Λ s. mzero)"

(* Computation rule for zeroU observed through runStateT. *)
lemma runStateT_zeroU [simp]:
  "runStateTzeroUs = mzero"
unfolding zeroU_stateT_def by simp

(* Obligation: binding zeroU with any k gives zeroU; follows from
   bind_mzero of the inner monad. *)
instance proof
  fix k :: "udom  udom('a,'b) stateT"
  show "bindUzeroUk = zeroU"
    by (rule stateT_eqI, simp add: bind_mzero)
qed

end

subsection ‹Monad plus instance›

(* monad_plus class instance for stateT over an inner monad_plus. *)
instantiation stateT :: (monad_plus, "domain") monad_plus
begin

(* plusU runs both a and b on the same incoming state s and combines the
   two results with the inner monad's mplus. *)
definition plusU_stateT_def:
  "plusU = (Λ a b. StateT(Λ s. mplus(runStateTas)(runStateTbs)))"

(* Computation rule for plusU observed through runStateT. *)
lemma runStateT_plusU [simp]:
  "runStateT(plusUab)s =
    mplus(runStateTas)(runStateTbs)"
unfolding plusU_stateT_def by simp

(* Obligations: bindU distributes over plusU (via bind_mplus) and plusU is
   associative (via mplus_assoc). *)
instance proof
  fix a b :: "udom('a, 'b) stateT" and k :: "udom  udom('a, 'b) stateT"
  show "bindU(plusUab)k = plusU(bindUak)(bindUbk)"
    by (rule stateT_eqI, simp add: bind_mplus)
next
  fix a b c :: "udom('a, 'b) stateT"
  show "plusU(plusUab)c = plusUa(plusUbc)"
    by (rule stateT_eqI, simp add: mplus_assoc)
qed

end

subsection ‹Monad zero plus instance›

(* monad_zero_plus class instance for stateT: zeroU is a left and a right
   unit for plusU, inherited from mplus_mzero_left and mplus_mzero_right of
   the inner monad. *)
instance stateT :: (monad_zero_plus, "domain") monad_zero_plus
proof
  fix m :: "udom('a, 'b) stateT"
  show "plusUzeroUm = m"
    by (rule stateT_eqI, simp add: mplus_mzero_left)
next
  fix m :: "udom('a, 'b) stateT"
  show "plusUmzeroU = m"
    by (rule stateT_eqI, simp add: mplus_mzero_right)
qed

subsection ‹Transfer properties to polymorphic versions›

(* coerce_simp rule: a coercion applied to the result of csplit can be
   pushed inside, onto the result of the split function f. *)
lemma coerce_csplit [coerce_simp]:
  shows "coerce(csplitfp) = csplit(Λ x y. coerce(fxy))p"
unfolding csplit_def by simp

(* coerce_simp rule: csplit applied to a coerced pair equals csplit applied
   to the original pair with each component coerced separately. *)
lemma csplit_coerce [coerce_simp]:
  fixes p :: "'a × 'b"
  shows "csplitf(COERCE('a × 'b, 'c × 'd)p) =
    csplit(Λ x y. f(COERCE('a, 'c)x)(COERCE('b, 'd)y))p"
unfolding coerce_prod csplit_def prod_map_def by simp

(* Polymorphic counterpart of fmapU_StateT: computation rule for fmap on an
   explicit StateT value.  Obtained by unfolding fmap_def at the stateT type
   and simplifying with the coerce_simp rules. *)
lemma fmap_stateT_simps [simp]:
  "fmapf(StateTm :: 'a('f::functor,'s) stateT) =
    StateT(Λ s. fmap(Λ (x, s'). (fx, s'))(ms))"
unfolding fmap_def [where 'f="('f, 's) stateT"]
by (simp add: coerce_simp eta_cfun)

(* Polymorphic counterpart of runStateT_fmapU: fmap observed through
   runStateT. *)
lemma runStateT_fmap [simp]:
  "runStateT(fmapfm)s = fmap(Λ (x, s'). (fx, s'))(runStateTms)"
by (induct m rule: stateT_induct, simp)

(* Polymorphic characterisation of return at the stateT type, mirroring the
   definition of returnU. *)
lemma return_stateT_def:
  "(return :: _  'a('m::monad, 's) stateT) =
    (Λ x. StateT(Λ s. return(x, s)))"
unfolding return_def [where 'm="('m, 's) stateT"] returnU_stateT_def
by (simp add: coerce_simp)

(* Polymorphic characterisation of bind at the stateT type, mirroring the
   definition of bindU.  Besides the coerce_simp rules, the proof needs
   coerce_idem together with domain_defl_simps and monofun_cfun. *)
lemma bind_stateT_def:
  "bind = (Λ m k. StateT(Λ s. runStateTms  (Λ (x, s'). runStateT(kx)s')))"
apply (subst bind_def, subst bindU_stateT_def)
apply (simp add: coerce_simp)
apply (simp add: coerce_idem domain_defl_simps monofun_cfun)
apply (simp add: eta_cfun)
done

text "TODO: add ‹coerce_idem› to ‹coerce_simps›, along with monotonicity rules for DEFL."

(* Computation rule for the polymorphic bind when the first argument is an
   explicit StateT value. *)
lemma bind_stateT_simps [simp]:
  "bind(StateTm :: 'a('m::monad,'s) stateT)k =
    StateT(Λ s. ms  (Λ (x, s'). runStateT(kx)s'))"
unfolding bind_stateT_def by simp

(* Computation rule for the polymorphic bind observed through runStateT. *)
lemma runStateT_bind [simp]:
  "runStateT(m  k)s = runStateTms  (Λ (x, s'). runStateT(kx)s')"
unfolding bind_stateT_def by simp

end

Theory Error_Transformer

section ‹Error monad transformer›

theory Error_Transformer
imports Error_Monad
begin

subsection ‹Type definition›

text ‹The error monad transformer is defined in Haskell by composing
the given monad with a standard error monad:›

text_raw ‹
\begin{verbatim}
data Error e a = Err e | Ok a
newtype ErrorT e m a = ErrorT { runErrorT :: m (Error e a) }
\end{verbatim}
›

text ‹We can formalize this definition directly using ‹tycondef›. \medskip›

(* Type constructor errorT with a single constructor ErrorT and selector
   runErrorT; the payload is an error value wrapped in the inner functor 'f. *)
tycondef 'a('f::"functor",'e::"domain") errorT =
  ErrorT (runErrorT :: "('a'e error)'f")

(* Coercion commutes with the abstraction function errorT_abs. *)
lemma coerce_errorT_abs [simp]: "coerce(errorT_absx) = errorT_abs(coercex)"
apply (simp add: errorT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_errorT)
done

(* Coercion commutes with the ErrorT constructor. *)
lemma coerce_ErrorT [simp]: "coerce(ErrorTk) = ErrorT(coercek)"
unfolding ErrorT_def by simp

(* Case-analysis rule: every errorT value y can be written as ErrorT applied
   to some k; the witness in the proof is runErrorT applied to y. *)
lemma errorT_cases [case_names ErrorT]:
  obtains k where "y = ErrorTk"
proof
  show "y = ErrorT(runErrorTy)"
    by (cases y, simp_all)
qed

(* Rewrapping the result of runErrorT with ErrorT gives back the original
   value. *)
lemma ErrorT_runErrorT [simp]: "ErrorT(runErrorTm) = m"
by (cases m rule: errorT_cases, simp)

(* Induction rule for errorT: if P holds for every value built with ErrorT,
   it holds for an arbitrary y.  Derived from errorT_cases. *)
lemma errorT_induct [case_names ErrorT]:
  fixes P :: "'a('f::functor,'e) errorT  bool"
  assumes "k. P (ErrorTk)"
  shows "P y"
by (cases y rule: errorT_cases, simp add: assms)

(* Equality on errorT values is characterised by equality of their
   runErrorT images. *)
lemma errorT_eq_iff:
  "a = b  runErrorTa = runErrorTb"
apply (cases a rule: errorT_cases)
apply (cases b rule: errorT_cases)
apply simp
done

(* Introduction form of errorT_eq_iff, convenient with the rule method. *)
lemma errorT_eqI:
  "runErrorTa = runErrorTb  a = b"
by (simp add: errorT_eq_iff)

(* runErrorT commutes with coerce; proved by induction on k using
   errorT_induct. *)
lemma runErrorT_coerce [simp]:
  "runErrorT(coercek) = coerce(runErrorTk)"
by (induct k rule: errorT_induct, simp)

subsection ‹Functor class instance›

(* fmap at the error type coincides with error_map applied to ID.  Proved
   pointwise in f and x by exhaustion on x (error.exhaust); the three final
   simp steps close the resulting subgoals, the last two needing Err_def and
   Ok_def respectively. *)
lemma fmap_error_def: "fmap = error_mapID"
apply (rule cfun_eqI, rename_tac f)
apply (rule cfun_eqI, rename_tac x)
apply (case_tac x rule: error.exhaust, simp_all)
apply (simp add: error_map_def fix_const)
apply (simp add: error_map_def fix_const Err_def)
apply (simp add: error_map_def fix_const Ok_def)
done

(* Computation rule for fmapU on an explicit ErrorT value: f is mapped
   through two layers, the inner functor and the error type. *)
lemma fmapU_ErrorT [simp]:
  "fmapUf(ErrorTm) = ErrorT(fmap(fmapf)m)"
unfolding fmapU_errorT_def errorT_map_def fmap_error_def fix_const ErrorT_def
by simp

(* Same computation rule as fmapU_ErrorT, phrased through runErrorT for an
   arbitrary errorT value m. *)
lemma runErrorT_fmapU [simp]:
  "runErrorT(fmapUfm) = fmap(fmapf)(runErrorTm)"
by (induct m rule: errorT_induct) simp

(* Functor class instance for errorT.  The single obligation shown here is
   the composition law for fmapU, proved by induction on the errorT value
   and fmap_fmap for the inner layers. *)
instance errorT :: ("functor", "domain") "functor"
proof
  fix f g and xs :: "udom('a, 'b) errorT"
  show "fmapUf(fmapUgxs) = fmapU(Λ x. f(gx))xs"
    apply (induct xs rule: errorT_induct)
    apply (simp add: fmap_fmap eta_cfun)
    done
qed

subsection ‹Transfer properties to polymorphic versions›

(* Polymorphic counterpart of fmapU_ErrorT, obtained by unfolding fmap_def
   at the errorT type and simplifying with the coerce_simp rules. *)
lemma fmap_ErrorT [simp]:
  fixes f :: "'a  'b" and m :: "'a'e error('m::functor)"
  shows "fmapf(ErrorTm) = ErrorT(fmap(fmapf)m)"
unfolding fmap_def [where 'f="('m,'e) errorT"]
by (simp_all add: coerce_simp eta_cfun)

(* Polymorphic fmap observed through runErrorT; obtained by instantiating
   fmap_ErrorT with runErrorT applied to m. *)
lemma runErrorT_fmap [simp]:
  fixes f :: "'a  'b" and m :: "'a('m::functor,'e) errorT"
  shows "runErrorT(fmapfm) = fmap(fmapf)(runErrorTm)"
using fmap_ErrorT [of f "runErrorTm"]
by simp

(* Strictness rule for fmap at the errorT type; reduced via errorT_eq_iff to
   fmap_strict for the inner layers. *)
lemma errorT_fmap_strict [simp]:
  shows "fmapf(::'a('m::monad,'e) errorT) = "
by (simp add: errorT_eq_iff fmap_strict)

subsection ‹Monad operations›

text ‹The error monad transformer does not yield a monad in the
usual sense: We cannot prove a ‹monad› class instance, because
type ‹'a⋅('m,'e) errorT› contains values that break the monad
laws. However, it turns out that such values are inaccessible: The
monad laws are satisfied by all values constructible from the abstract
operations.›

text ‹To explore the properties of the error monad transformer
operations, we define them all as non-overloaded functions. \medskip
›

(* unitET: wrap x in Ok and return it in the inner monad. *)
definition unitET :: "'a  'a('m::monad,'e) errorT"
  where "unitET = (Λ x. ErrorT(return(Okx)))"

(* bindET: run m in the inner monad; an Err e result is returned again
   unchanged, an Ok x result is passed on to the continuation k. *)
definition bindET :: "'a('m::monad,'e) errorT 
    ('a  'b('m,'e) errorT)  'b('m,'e) errorT"
  where "bindET = (Λ m k. ErrorT(bind(runErrorTm)
    (Λ n. case n of Erre  return(Erre) | Okx  runErrorT(kx))))"

(* liftET: lift an inner-monad computation by mapping Ok over its result. *)
definition liftET :: "'a'm::monad  'a('m,'e) errorT"
  where "liftET = (Λ m. ErrorT(fmapOkm))"

(* throwET: return Err e in the inner monad. *)
definition throwET :: "'e  'a('m::monad,'e) errorT"
  where "throwET = (Λ e. ErrorT(return(Erre)))"

(* catchET: run m in the inner monad; an Err e result is handed to the
   handler h, an Ok x result is returned again unchanged. *)
definition catchET :: "'a('m::monad,'e) errorT 
    ('e  'a('m,'e) errorT)  'a('m,'e) errorT"
  where "catchET = (Λ m h. ErrorT(bind(runErrorTm)(Λ n. case n of
    Erre  runErrorT(he) | Okx  return(Okx))))"

(* fmapET: a non-overloaded map, defined from bindET and unitET. *)
definition fmapET :: "('a  'b) 
    'a('m::monad,'e) errorT  'b('m,'e) errorT"
  where "fmapET = (Λ f m. bindETm(Λ x. unitET(fx)))"

(* Simp rules describing each operation as observed through runErrorT.
   Each is proved by unfolding the corresponding definition. *)

(* unitET returns Ok x. *)
lemma runErrorT_unitET [simp]:
  "runErrorT(unitETx) = return(Okx)"
unfolding unitET_def by simp

(* bindET propagates Err and continues with k on Ok. *)
lemma runErrorT_bindET [simp]:
  "runErrorT(bindETmk) = bind(runErrorTm)
    (Λ n. case n of Erre  return(Erre) | Okx  runErrorT(kx))"
unfolding bindET_def by simp

(* liftET maps Ok over the inner computation. *)
lemma runErrorT_liftET [simp]:
  "runErrorT(liftETm) = fmapOkm"
unfolding liftET_def by simp

(* throwET returns Err e. *)
lemma runErrorT_throwET [simp]:
  "runErrorT(throwETe) = return(Erre)"
unfolding throwET_def by simp

(* catchET runs the handler on Err and passes Ok through. *)
lemma runErrorT_catchET [simp]:
  "runErrorT(catchETmh) =
    bind(runErrorTm)(Λ n. case n of
      Erre  runErrorT(he) | Okx  return(Okx))"
unfolding catchET_def by simp

(* fmapET propagates Err and applies f under Ok. *)
lemma runErrorT_fmapET [simp]:
  "runErrorT(fmapETfm) =
    bind(runErrorTm)(Λ n. case n of
      Erre  return(Erre) | Okx  return(Ok(fx)))"
unfolding fmapET_def by simp

subsection ‹Laws›

(* Basic laws, each proved by errorT_eqI followed by simplification. *)

(* Left unit law for bindET / unitET. *)
lemma bindET_unitET [simp]:
  "bindET(unitETx)k = kx"
by (rule errorT_eqI, simp)

(* Catching on a unitET value leaves it unchanged; the handler is unused. *)
lemma catchET_unitET [simp]:
  "catchET(unitETx)h = unitETx"
by (rule errorT_eqI, simp)

(* Catching a thrown error invokes the handler on it. *)
lemma catchET_throwET [simp]:
  "catchET(throwETe)h = he"
by (rule errorT_eqI, simp)

(* liftET maps the inner return to unitET. *)
lemma liftET_return:
  "liftET(returnx) = unitETx"
by (rule errorT_eqI, simp add: fmap_return)

(* liftET maps the inner bind to bindET. *)
lemma liftET_bind:
  "liftET(bindmk) = bindET(liftETm)(liftET oo k)"
by (rule errorT_eqI, simp add: fmap_bind bind_fmap)

(* A thrown error short-circuits bindET; the continuation is unused. *)
lemma bindET_throwET:
  "bindET(throwETe)k = throwETe"
by (rule errorT_eqI, simp)

(* Associativity of bindET.  After reassociating the inner binds with
   bind_bind, the two continuations are compared pointwise by case analysis
   on the intermediate error value x. *)
lemma bindET_bindET:
  "bindET(bindETmh)k = bindETm(Λ x. bindET(hx)k)"
apply (rule errorT_eqI)
apply simp
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict)
apply simp
apply simp
done

(* Composition law for fmapET, a consequence of bindET_bindET. *)
lemma fmapET_fmapET:
  "fmapETf(fmapETgm) = fmapET(Λ x. f(gx))m"
by (simp add: fmapET_def bindET_bindET)

text ‹Right unit monad law is not satisfied in general.›

(* Counterexample to the right unit law: for the specific m given by the
   first assumption, and an inner return satisfying the second assumption,
   binding m with unitET does not give back m. *)
lemma bindET_unitET_right_counterexample:
  fixes m :: "'a('m::monad,'e) errorT"
  assumes "m = ErrorT(return)"
  assumes "return  ( :: ('a'e error)'m)"
  shows "bindETmunitET  m"
by (simp add: errorT_eq_iff assms)

text ‹Right unit is satisfied for inner monads with strict return.›

(* Right unit law under the stated assumption on the inner return.  The
   goal is reduced to the inner monad's monad_right_unit by showing that the
   continuation agrees with return on every case of the error value x. *)
lemma bindET_unitET_right_restricted:
  fixes m :: "'a('m::monad,'e) errorT"
  assumes "return = ( :: ('a'e error)'m)"
  shows "bindETmunitET = m"
unfolding errorT_eq_iff
apply simp
apply (rule trans [OF _ monad_right_unit])
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
apply (case_tac x, simp_all add: assms)
done

subsection ‹Error monad transformer invariant›

text ‹This inductively-defined invariant is supposed to represent
the set of all values constructible using the standard ‹errorT›
operations.›

(* Inductive invariant on errorT values.  It has one introduction rule for
   the bottom value, one for limits of chains, and one for each of unitET,
   bindET, throwET, catchET and liftET.  The rules for bindET and catchET
   require the invariant of the first argument and of every result of the
   function argument; liftET needs no premise. *)
inductive invar :: "'a('m::monad, 'e) errorT  bool"
  where invar_bottom: "invar "
  | invar_lub: "Y. chain Y; i. invar (Y i)  invar (i. Y i)"
  | invar_unitET: "x. invar (unitETx)"
  | invar_bindET: "m k. invar m; x. invar (kx)  invar (bindETmk)"
  | invar_throwET: "e. invar (throwETe)"
  | invar_catchET: "m h. invar m; e. invar (he)  invar (catchETmh)"
  | invar_liftET: "m. invar (liftETm)"

text ‹Right unit is satisfied for arguments built from standard functions.›

(* Right unit law for every m satisfying invar, by rule induction on invar.
   NOTE(review): the case labels below assume the subgoals appear in the
   order of the introduction rules of invar -- confirm when replaying. *)
lemma bindET_unitET_right_invar:
  assumes "invar m"
  shows "bindETmunitET = m"
using assms
apply (induct set: invar)
(* bottom *)
apply (rule errorT_eqI, simp add: bind_strict)
(* lub of a chain, via admissibility *)
apply (rule admD, simp, assumption, assumption)
(* unitET *)
apply (rule errorT_eqI, simp)
(* bindET *)
apply (simp add: errorT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp, simp)
(* throwET *)
apply (rule errorT_eqI, simp)
(* catchET *)
apply (simp add: errorT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp, simp)
(* liftET *)
apply (rule errorT_eqI, simp add: monad_fmap bind_bind)
done

text ‹Monad-fmap is satisfied for arguments built from standard functions.›

(* For every m satisfying invar, the overloaded fmap agrees with the map
   built from bindET and unitET.  Proved by rule induction on invar.
   NOTE(review): the case labels below assume the subgoals appear in the
   order of the introduction rules of invar -- confirm when replaying. *)
lemma errorT_monad_fmap_invar:
  fixes f :: "'a  'b" and m :: "'a('m::monad,'e) errorT"
  assumes "invar m"
  shows "fmapfm = bindETm(Λ x. unitET(fx))"
using assms
apply (induct set: invar)
(* bottom *)
apply (rule errorT_eqI, simp add: bind_strict fmap_strict)
(* lub of a chain, via admissibility *)
apply (rule admD, simp, assumption, assumption)
(* unitET *)
apply (rule errorT_eqI, simp add: fmap_return)
(* bindET *)
apply (simp add: errorT_eq_iff bind_bind fmap_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict fmap_strict)
apply (simp add: fmap_return)
apply simp
(* throwET *)
apply (rule errorT_eqI, simp add: fmap_return)
(* catchET *)
apply (simp add: errorT_eq_iff bind_bind fmap_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict fmap_strict)
apply simp
apply (simp add: fmap_return)
(* liftET *)
apply (rule errorT_eqI, simp add: monad_fmap bind_bind return_error_def)
done

subsection ‹Invariant expressed as a deflation›

text ‹We can also define an invariant in a more semantic way, as the
set of fixed-points of a deflation.›

(* Alternative invariant: m satisfies invar' exactly when mapping ID over it
   with fmapET leaves it unchanged. *)
definition invar' :: "'a('m::monad, 'e) errorT  bool"
  where "invar' m  fmapETIDm = m"

text ‹All standard operations preserve the invariant.›

(* unitET values satisfy invar'. *)
lemma invar'_unitET: "invar' (unitETx)"
  unfolding invar'_def by (simp add: fmapET_def)

(* fmapET preserves invar'. *)
lemma invar'_fmapET: "invar' m  invar' (fmapETfm)"
  unfolding invar'_def
  by (erule subst, simp add: fmapET_def bindET_bindET eta_cfun)

(* bindET preserves invar' when m and every result of k satisfy it. *)
lemma invar'_bindET: "invar' m; x. invar' (kx)  invar' (bindETmk)"
  unfolding invar'_def
  by (simp add: fmapET_def bindET_bindET eta_cfun)

(* throwET values satisfy invar'. *)
lemma invar'_throwET: "invar' (throwETe)"
  unfolding invar'_def by (simp add: fmapET_def bindET_throwET eta_cfun)

(* catchET preserves invar' when m and every result of the handler h satisfy
   it.  After case analysis on the intermediate error value x, the handler
   premise is instantiated at e and used to rewrite one occurrence of h
   applied to e (the occurrence is selected with "back"). *)
lemma invar'_catchET: "invar' m; e. invar' (he)  invar' (catchETmh)"
  unfolding invar'_def
  apply (simp add: fmapET_def eta_cfun)
  apply (rule errorT_eqI)
  apply (simp add: bind_bind eta_cfun)
  apply (rule cfun_arg_cong)
  apply (rule cfun_eqI)
  apply (case_tac x)
  apply (simp add: bind_strict)
  apply simp
  apply (drule_tac x=e in meta_spec)
  apply (erule_tac t="he" in subst) back
  apply (simp add: eta_cfun)
  apply simp
  done

(* liftET values satisfy invar'; proved with monad_fmap and bind_bind of the
   inner monad. *)
lemma invar'_liftET: "invar' (liftETm)"
  unfolding invar'_def
  apply (simp add: fmapET_def errorT_eq_iff)
  apply (simp add: monad_fmap bind_bind)
  done

(* Base case of the invariant (named after the bottom element); follows from
   bind_strict of the inner monad. *)
lemma invar'_bottom: "invar' "
  unfolding invar'_def fmapET_def
  by (simp add: errorT_eq_iff bind_strict)

(* invar' is an admissible predicate; immediate once its definition is
   unfolded as a lambda abstraction. *)
lemma adm_invar': "adm invar'"
  unfolding invar'_def [abs_def] by simp

text ‹All monad laws are preserved by values satisfying the invariant.›

(* Right unit law for values of the form fmapET applied to f and m; holds
   without any assumption on m. *)
lemma bindET_fmapET_unitET:
  shows "bindET(fmapETfm)unitET = fmapETfm"
by (simp add: fmapET_def bindET_bindET)

(* Right unit law for values satisfying invar': the assumption rewrites m
   into fmapET form, where bindET_fmapET_unitET applies. *)
lemma invar'_right_unit: "invar' m  bindETmunitET = m"
unfolding invar'_def by (erule subst, rule bindET_fmapET_unitET)

(* For values satisfying invar', fmapET agrees with the map built from
   bindET and unitET. *)
lemma invar'_monad_fmap:
  "invar' m  fmapETfm = bindETm(Λ x. unitET(fx))"
  unfolding invar'_def by (erule subst, simp add: errorT_eq_iff)

(* Associativity of bindET, stated with invar' premises for uniformity with
   the other laws.  The proof is just bindET_bindET, which does not use the
   premises. *)
lemma invar'_bind_assoc:
  "invar' m; x. invar' (fx); y. invar' (gy)
     bindET(bindETmf)g = bindETm(Λ x. bindET(fx)g)"
  by (rule bindET_bindET)

end

Theory Writer_Transformer

section ‹Writer monad transformer›

theory Writer_Transformer
imports Writer_Monad
begin

subsection ‹Type definition›

text ‹Below is the standard Haskell definition of a writer monad
transformer:›

text_raw ‹
\begin{verbatim}
newtype WriterT w m a = WriterT { runWriterT :: m (a, w) }
\end{verbatim}
›

text ‹In this development, since a lazy pair type is not pre-defined
in HOLCF, we will use an equivalent formulation in terms of our
previous \texttt{Writer} type:›

text_raw ‹
\begin{verbatim}
data Writer w a = Writer w a
newtype WriterT w m a = WriterT { runWriterT :: m (Writer w a) }
\end{verbatim}
›

text ‹We can translate this definition directly into HOLCF using
‹tycondef›. \medskip›

(* Type constructor writerT with a single constructor WriterT and selector
   runWriterT; the payload is a writer value wrapped in the inner functor 'm. *)
tycondef 'a('m::"functor",'w) writerT =
  WriterT (runWriterT :: "('a'w writer)'m")

(* Coercion commutes with the abstraction function writerT_abs. *)
lemma coerce_writerT_abs [simp]:
  "coerce(writerT_absx) = writerT_abs(coercex)"
apply (simp add: writerT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_writerT)
done

(* Coercion commutes with the WriterT constructor. *)
lemma coerce_WriterT [simp]: "coerce(WriterTk) = WriterT(coercek)"
unfolding WriterT_def by simp

(* Case-analysis rule: every writerT value y can be written as WriterT
   applied to some k; the witness in the proof is runWriterT applied to y. *)
lemma writerT_cases [case_names WriterT]:
  obtains k where "y = WriterTk"
proof
  show "y = WriterT(runWriterTy)"
    by (cases y, simp_all)
qed

(* Rewrapping the result of runWriterT with WriterT gives back the original
   value. *)
lemma WriterT_runWriterT [simp]: "WriterT(runWriterTm) = m"
by (cases m rule: writerT_cases, simp)

(* Induction rule for writerT: if P holds for every value built with
   WriterT, it holds for an arbitrary y.  Derived from writerT_cases.
   NOTE(review): the type variables are named 'f and 'e here, whereas the
   tycondef for writerT uses 'm and 'w. *)
lemma writerT_induct [case_names WriterT]:
  fixes P :: "'a('f::functor,'e) writerT  bool"
  assumes "k. P (WriterTk)"
  shows "P y"
by (cases y rule: writerT_cases, simp add: assms)

(* Equality on writerT values is characterised by equality of their
   runWriterT images. *)
lemma writerT_eq_iff:
  "a = b  runWriterTa = runWriterTb"
apply (cases a rule: writerT_cases)
apply (cases b rule: writerT_cases)
apply simp
done

(* Ordering counterpart of writerT_eq_iff (per the lemma name): the relation
   between a and b is characterised by the same relation between their
   runWriterT images. *)
lemma writerT_below_iff:
  "a  b  runWriterTa  runWriterTb"
apply (cases a rule: writerT_cases)
apply (cases b rule: writerT_cases)
apply simp
done

(* Introduction form of writerT_eq_iff, convenient with the rule method. *)
lemma writerT_eqI:
  "runWriterTa = runWriterTb  a = b"
by (simp add: writerT_eq_iff)

(* Introduction form of writerT_below_iff. *)
lemma writerT_belowI:
  "runWriterTa  runWriterTb  a  b"
by (simp add: writerT_below_iff)

(* runWriterT commutes with coerce; proved by induction on k using
   writerT_induct. *)
lemma runWriterT_coerce [simp]:
  "runWriterT(coercek) = coerce(runWriterTk)"
by (induct k rule: writerT_induct, simp)

subsection ‹Functor class instance›

(* fmap at the writer type coincides with writer_map applied to ID.  Proved
   pointwise in f and x by exhaustion on x (writer.exhaust). *)
lemma fmap_writer_def: "fmap = writer_mapID"
apply (rule cfun_eqI, rename_tac f)
apply (rule cfun_eqI, rename_tac x)
apply (case_tac x rule: writer.exhaust, simp_all)
apply (simp add: writer_map_def fix_const)
apply (simp add: writer_map_def fix_const Writer_def)
done

(* Computation rule for fmapU on an explicit WriterT value: f is mapped
   through two layers, the inner functor and the writer type. *)
lemma fmapU_WriterT [simp]:
  "fmapUf(WriterTm) = WriterT(fmap(fmapf)m)"
unfolding fmapU_writerT_def writerT_map_def fmap_writer_def fix_const
  WriterT_def by simp

(* Same computation rule as fmapU_WriterT, phrased through runWriterT for an
   arbitrary writerT value m. *)
lemma runWriterT_fmapU [simp]:
  "runWriterT(fmapUfm) = fmap(fmapf)(runWriterTm)"
by (induct m rule: writerT_induct) simp

(* Functor class instance for writerT.  The single obligation shown here is
   the composition law for fmapU, proved by induction on the writerT value
   and fmap_fmap for the inner layers. *)
instance writerT :: ("functor", "domain") "functor"
proof
  fix f g :: "udom  udom" and xs :: "udom('a,'b) writerT"
  show "fmapUf(fmapUgxs) = fmapU(Λ x. f(gx))xs"
    apply (induct xs rule: writerT_induct)
    apply (simp add: fmap_fmap eta_cfun)
    done
qed

subsection ‹Monad operations›

text ‹The writer monad transformer does not yield a monad in the
usual sense: We cannot prove a ‹monad› class instance, because
type ‹'a⋅('m,'w) writerT› contains values that break the monad
laws. However, it turns out that such values are inaccessible: The
monad laws are satisfied by all values constructible from the abstract
operations.›

text ‹To explore the properties of the writer monad transformer
operations, we define them all as non-overloaded functions. \medskip
›

(* unitWT: return x paired with the empty output mempty. *)
definition unitWT :: "'a  'a('m::monad,'w::monoid) writerT"
  where "unitWT = (Λ x. WriterT(return(Writermemptyx)))"

(* bindWT: run m to obtain output w and value x, run k on x to obtain
   output w' and value y, then return y with the combined output
   mappend w w'. *)
definition bindWT :: "'a('m::monad,'w::monoid) writerT  ('a  'b('m,'w) writerT)  'b('m,'w) writerT"
  where "bindWT = (Λ m k. WriterT(bind(runWriterTm)
    (Λ(Writerwx). bind(runWriterT(kx))(Λ(Writerw'y).
      return(Writer(mappendww')y)))))"

(* liftWT: lift an inner-monad computation, attaching the empty output. *)
definition liftWT :: "'a'm  'a('m::monad,'w::monoid) writerT"
  where "liftWT = (Λ m. WriterT(fmap(Writermempty)m))"

(* tellWT: return x together with the given output w. *)
definition tellWT :: "'a  'w  'a('m::monad,'w::monoid) writerT"
  where "tellWT = (Λ x w. WriterT(return(Writerwx)))"

(* fmapWT: a non-overloaded map, defined from bindWT and unitWT. *)
definition fmapWT :: "('a  'b)  'a('m::monad,'w::monoid) writerT  'b('m,'w) writerT"
  where "fmapWT = (Λ f m. bindWTm(Λ x. unitWT(fx)))"

(* The overloaded fmap observed through runWriterT: f is mapped through the
   inner functor and the writer layer.  Obtained from fmap_def and the
   coerce_simp rules. *)
lemma runWriterT_fmap [simp]:
  "runWriterT(fmapfm) = fmap(fmapf)(runWriterTm)"
by (subst fmap_def, simp add: coerce_simp eta_cfun)

(* Simp rules describing each operation as observed through runWriterT. *)

(* unitWT returns x with the empty output. *)
lemma runWriterT_unitWT [simp]:
  "runWriterT(unitWTx) = return(Writermemptyx)"
unfolding unitWT_def by simp

(* bindWT sequences m and k and combines their outputs with mappend. *)
lemma runWriterT_bindWT [simp]:
  "runWriterT(bindWTmk) = bind(runWriterTm)
    (Λ(Writerwx). bind(runWriterT(kx))(Λ(Writerw'y).
      return(Writer(mappendww')y)))"
unfolding bindWT_def by simp

(* liftWT attaches the empty output to the inner result. *)
lemma runWriterT_liftWT [simp]:
  "runWriterT(liftWTm) = fmap(Writermempty)m"
unfolding liftWT_def by simp

(* tellWT returns x with output w. *)
lemma runWriterT_tellWT [simp]:
  "runWriterT(tellWTxw) = return(Writerwx)"
unfolding tellWT_def by simp

(* fmapWT applies f to the value and keeps the output w; the simplification
   uses mempty_right to remove the empty output contributed by unitWT. *)
lemma runWriterT_fmapWT [simp]:
  "runWriterT(fmapWTfm) =
    runWriterTm  (Λ (Writerwx). return(Writerw(fx)))"
by (simp add: fmapWT_def bindWT_def mempty_right)

subsection ‹Laws›

text ‹The ‹liftWT› function maps ‹return› and
‹bind› on the inner monad to ‹unitWT› and ‹bindWT›, as expected. \medskip›

(* liftWT maps the inner return to unitWT. *)
lemma liftWT_return:
  "liftWT(returnx) = unitWTx"
by (rule writerT_eqI, simp add: fmap_return)

(* liftWT maps the inner bind to bindWT; uses mempty_left to combine the two
   empty outputs. *)
lemma liftWT_bind:
  "liftWT(bindmk) = bindWT(liftWTm)(liftWT oo k)"
by (rule writerT_eqI)
   (simp add: monad_fmap bind_bind mempty_left)

text ‹The composition rule holds unconditionally for fmap. The fmap
function also interacts as expected with unit and bind. \medskip›

(* Composition law for fmapWT; the continuations are compared pointwise by
   case analysis on the intermediate writer value x. *)
lemma fmapWT_fmapWT:
  "fmapWTf(fmapWTgm) = fmapWT(Λ x. f(gx))m"
apply (simp add: writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp add: mempty_right)
done

(* fmapWT on a unitWT value applies f to the wrapped value. *)
lemma fmapWT_unitWT:
  "fmapWTf(unitWTx) = unitWT(fx)"
by (simp add: writerT_eq_iff mempty_right)

(* fmapWT can be pushed into the continuation of a bindWT.  The proof does
   two rounds of pointwise comparison, with case analysis on the
   intermediate writer values x and y. *)
lemma fmapWT_bindWT:
  "fmapWTf(bindWTmk) = bindWTm(Λ x. fmapWTf(kx))"
apply (simp add: writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, rename_tac x, simp)
apply (case_tac x, simp add: bind_strict, simp add: bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, rename_tac y, simp)
apply (case_tac y, simp add: bind_strict, simp add: mempty_right)
done

(* A fmapWT in the first argument of bindWT can be absorbed into the
   continuation by pre-composing it with f. *)
lemma bindWT_fmapWT:
  "bindWT(fmapWTfm)k = bindWTm(Λ x. k(fx))"
apply (simp add: writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, rename_tac x, simp)
apply (case_tac x, simp add: bind_strict, simp add: mempty_right)
done

text ‹The left unit monad law is not satisfied in general. \medskip›

(* Counterexample to the left unit law: for a continuation k whose value at
   x is given by assumption 1, and an inner return satisfying assumption 2,
   binding unitWT applied to x with k does not give k applied to x. *)
lemma bindWT_unitWT_counterexample:
  fixes k :: "'a  'b('m::monad,'w::monoid) writerT"
  assumes 1: "kx = WriterT(return)"
  assumes 2: "return  ( :: ('b'w writer)'m::monad)"
  shows "bindWT(unitWTx)k  kx"
by (simp add: writerT_eq_iff mempty_left assms)

text ‹However, left unit is satisfied for inner monads with a strict
‹return› function.›

(* Left unit law under the stated assumption on the inner return.  The goal
   is reduced to the inner monad's monad_right_unit by showing that the
   remaining continuation agrees with return on every case of x. *)
lemma bindWT_unitWT_restricted:
  fixes k :: "'a  'b('m::monad,'w::monoid) writerT"
  assumes "return = ( :: ('b'w writer)'m)"
  shows "bindWT(unitWTx)k = kx"
unfolding writerT_eq_iff
apply (simp add: mempty_left)
apply (rule trans [OF _ monad_right_unit])
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
apply (case_tac x, simp_all add: assms)
done

text ‹The associativity of ‹bindWT› holds
unconditionally. \medskip›

(* Associativity of bindWT.  The proof peels off three nested inner binds in
   turn (intermediate writer values x, y and z), each time reassociating
   with bind_bind and comparing continuations pointwise; the final step is
   associativity of mappend. *)
lemma bindWT_bindWT:
  "bindWT(bindWTmh)k = bindWTm(Λ x. bindWT(hx)k)"
apply (rule writerT_eqI)
apply simp
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict)
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp, rename_tac y)
apply (case_tac y)
apply (simp add: bind_strict)
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp, rename_tac z)
apply (case_tac z)
apply (simp add: bind_strict)
apply (simp add: mappend_assoc)
done

text ‹The right unit monad law is not satisfied in general. \medskip›

(* Counterexample to the right unit law: for the specific m given by the
   first assumption, and an inner return satisfying the second assumption,
   binding m with unitWT does not give back m. *)
lemma bindWT_unitWT_right_counterexample:
  fixes m :: "'a('m::monad,'w::monoid) writerT"
  assumes "m = WriterT(return)"
  assumes "return  ( :: ('a'w writer)'m)"
  shows "bindWTmunitWT  m"
by (simp add: writerT_eq_iff assms)

text ‹Right unit is satisfied for inner monads with a strict ‹return› function. \medskip›

(* Right unit law under the stated assumption on the inner return.  The goal
   is reduced to the inner monad's monad_right_unit; mempty_right removes
   the empty output contributed by unitWT. *)
lemma bindWT_unitWT_right_restricted:
  fixes m :: "'a('m::monad,'w::monoid) writerT"
  assumes "return = ( :: ('a'w writer)'m)"
  shows "bindWTmunitWT = m"
unfolding writerT_eq_iff
apply simp
apply (rule trans [OF _ monad_right_unit])
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
apply (case_tac x, simp_all add: assms mempty_right)
done

subsection ‹Writer monad transformer invariant›

text ‹We inductively define a predicate that includes all values
that can be constructed from the standard ‹writerT› operations.
\medskip›

(* Inductive invariant on writerT values.  It has one introduction rule for
   the bottom value, one for limits of chains, and one for each of unitWT,
   bindWT, tellWT and liftWT.  The rule for bindWT requires the invariant of
   m and of every result of k; the others have no premises. *)
inductive invar :: "'a('m::monad, 'w::monoid) writerT  bool"
  where invar_bottom: "invar "
  | invar_lub: "Y. chain Y; i. invar (Y i)  invar (i. Y i)"
  | invar_unitWT: "x. invar (unitWTx)"
  | invar_bindWT: "m k. invar m; x. invar (kx)  invar (bindWTmk)"
  | invar_tellWT: "x w. invar (tellWTxw)"
  | invar_liftWT: "m. invar (liftWTm)"

text ‹Right unit is satisfied for arguments built from standard
functions. \medskip›

(* Right unit law for every m satisfying invar, by rule induction on invar
   with one named case per introduction rule. *)
lemma bindWT_unitWT_right_invar:
  fixes m :: "'a('m::monad,'w::monoid) writerT"
  assumes "invar m"
  shows "bindWTmunitWT = m"
using assms proof (induct set: invar)
  case invar_bottom thus ?case
    by (rule writerT_eqI, simp add: bind_strict)
next
  (* limit case: closed by admissibility (admD) *)
  case invar_lub thus ?case
    by - (rule admD, simp, assumption, assumption)
next
  case invar_unitWT thus ?case
    by (rule writerT_eqI, simp add: bind_bind mempty_left)
next
  (* two rounds of pointwise comparison on the intermediate writer values *)
  case invar_bindWT thus ?case
    apply (simp add: writerT_eq_iff bind_bind)
    apply (rule cfun_arg_cong, rule cfun_eqI, simp)
    apply (case_tac x, simp add: bind_strict, simp add: bind_bind)
    apply (rule cfun_arg_cong, rule cfun_eqI, simp, rename_tac y)
    apply (case_tac y, simp add: bind_strict, simp add: mempty_right)
    done
next
  case invar_tellWT thus ?case
    by (simp add: writerT_eq_iff mempty_right)
next
  case invar_liftWT thus ?case
    by (rule writerT_eqI, simp add: monad_fmap bind_bind mempty_right)
qed

text ‹Left unit is also satisfied for arguments built from standard
functions. \medskip›

(* Auxiliary lemma for the left unit law: for m satisfying invar, binding
   the inner computation with a continuation that rebuilds each Writer
   result and returns it leaves the computation unchanged.  Proved by rule
   induction on invar with one named case per introduction rule. *)
lemma writerT_left_unit_invar_lemma:
  assumes "invar m"
  shows "runWriterTm  (Λ (Writerwx). return(Writerwx)) = runWriterTm"
using assms proof (induct m set: invar)
  case invar_bottom thus ?case
    by (simp add: bind_strict)
next
  (* limit case: closed by admissibility (admD) *)
  case invar_lub thus ?case
    by - (rule admD, simp, assumption, assumption)
next
  case invar_unitWT thus ?case
    by simp
next
  (* two rounds of pointwise comparison on the intermediate values n and p *)
  case invar_bindWT thus ?case
    apply (simp add: bind_bind)
    apply (rule cfun_arg_cong)
    apply (rule cfun_eqI, simp, rename_tac n)
    apply (case_tac n, simp add: bind_strict)
    apply (simp add: bind_bind)
    apply (rule cfun_arg_cong)
    apply (rule cfun_eqI, simp, rename_tac p)
    apply (case_tac p, simp add: bind_strict)
    apply simp
    done
next
  case invar_tellWT thus ?case
    by simp
next
  case invar_liftWT thus ?case
    by (simp add: monad_fmap bind_bind)
qed

(* Left unit law when the result of k at x satisfies invar; after
   simplification with mempty_left the goal is exactly
   writerT_left_unit_invar_lemma. *)
lemma bindWT_unitWT_invar:
  assumes "invar (kx)"
  shows "bindWT(unitWTx)k = kx"
apply (simp add: writerT_eq_iff mempty_left)
apply (rule writerT_left_unit_invar_lemma [OF assms])
done

subsection ‹Invariant expressed as a deflation›

(* Alternative invariant: m satisfies invar' exactly when mapping ID over it
   with fmapWT leaves it unchanged. *)
definition invar' :: "'a('m::monad, 'w::monoid) writerT  bool"
  where "invar' m  fmapWTIDm = m"

text ‹All standard operations preserve the invariant.›

(* Base case of the invariant (named after the bottom element); follows from
   bind_strict of the inner monad. *)
lemma invar'_bottom: "invar' "
  unfolding invar'_def by (simp add: writerT_eq_iff bind_strict)

(* invar' is an admissible predicate. *)
lemma adm_invar': "adm invar'"
  unfolding invar'_def [abs_def] by simp

(* unitWT values satisfy invar'. *)
lemma invar'_unitWT: "invar' (unitWTx)"
  unfolding invar'_def by (simp add: writerT_eq_iff)

(* bindWT preserves invar'.  The premise on m is used to rewrite m (erule
   subst); the rest is two rounds of pointwise comparison with case
   analysis on the intermediate writer values. *)
lemma invar'_bindWT: "invar' m; x. invar' (kx)  invar' (bindWTmk)"
  unfolding invar'_def
  apply (erule subst)
  apply (simp add: writerT_eq_iff)
  apply (simp add: bind_bind)
  apply (rule cfun_arg_cong)
  apply (rule cfun_eqI, case_tac x)
  apply (simp add: bind_strict)
  apply simp
  apply (simp add: bind_bind)
  apply (rule cfun_arg_cong)
  apply (rule cfun_eqI, rename_tac x, case_tac x)
  apply (simp add: bind_strict)
  apply simp
  done

(* tellWT values satisfy invar'. *)
lemma invar'_tellWT: "invar' (tellWTxw)"
  unfolding invar'_def by (simp add: writerT_eq_iff)

(* liftWT values satisfy invar'. *)
lemma invar'_liftWT: "invar' (liftWTm)"
  unfolding invar'_def by (simp add: writerT_eq_iff monad_fmap bind_bind)

text ‹Left unit is satisfied for arguments built from fmap.›

(* Left unit law when the continuation has the form fmapWT applied to f and
   the result of k; holds without any assumption on k. *)
lemma bindWT_unitWT_fmapWT:
  "bindWT(unitWTx)(Λ x. fmapWTf(kx))
    = fmapWTf(kx)"
apply (simp add: fmapWT_def writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp_all add: bind_strict mempty_left)
done

text ‹Right unit is satisfied for arguments built from fmap.›

(* Right unit law for values of the form fmapWT applied to f and m; holds
   without any assumption on m.  Follows from bindWT_fmapWT and the
   definition of fmapWT. *)
lemma bindWT_fmapWT_unitWT:
  shows "bindWT(fmapWTfm)unitWT = fmapWTfm"
apply (simp add: bindWT_fmapWT)
apply (simp add: fmapWT_def)
done

text ‹All monad laws are preserved by values satisfying the invariant.›

(* Right unit law for values satisfying invar': the assumption rewrites m
   into fmapWT form, where bindWT_fmapWT_unitWT applies. *)
lemma invar'_right_unit: "invar' m  bindWTmunitWT = m"
unfolding invar'_def by (erule subst, rule bindWT_fmapWT_unitWT)

(* For values satisfying invar', fmapWT agrees with the map built from
   bindWT and unitWT. *)
lemma invar'_monad_fmap:
  "invar' m  fmapWTfm = bindWTm(Λ x. unitWT(fx))"
  unfolding invar'_def
  by (erule subst, simp add: writerT_eq_iff mempty_right)

(* Associativity of bindWT, stated with invar' premises for uniformity with
   the other laws.  The proof is just bindWT_bindWT, which does not use the
   premises. *)
lemma invar'_bind_assoc:
  "invar' m; x. invar' (fx); y. invar' (gy)
     bindWT(bindWTmf)g = bindWTm(Λ x. bindWT(fx)g)"
  by (rule bindWT_bindWT)

end