Session Planarity_Certificates

Theory Lib

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(*
   Miscellaneous library definitions and lemmas.
*)

chapter "Library"

theory Lib
imports Main
begin

(* FIXME: eliminate *)
lemma hd_map_simp:
  "b \<noteq> [] \<Longrightarrow> hd (map a b) = a (hd b)"
  by (rule hd_map)

lemma tl_map_simp:
  "tl (map a b) = map a (tl b)"
  by (induct b,auto)

(* FIXME: could be added to Set.thy *)
lemma Collect_eq:
  "{x. P x} = {x. Q x} \<longleftrightarrow> (\<forall>x. P x = Q x)"
  by (rule iffI) auto

(* FIXME: move next to HOL.iff_allI *)
lemma iff_impI: "\<lbrakk>P \<Longrightarrow> Q = R\<rbrakk> \<Longrightarrow> (P \<longrightarrow> Q) = (P \<longrightarrow> R)" by blast

(* Low-precedence, right-associative function application: f $ x = f x.
   The defining equation is declared [iff] below, so it is unfolded
   automatically by the simplifier and the classical reasoner. *)
definition
  fun_app :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b" (infixr "$" 10) where
  "f $ x \<equiv> f x"

declare fun_app_def [iff]

(* Congruence rules registered with the function package ([fundef_cong]),
   so that recursive calls occurring under $, under an applied
   if-then-else, or under an applied case_prod are extracted with the
   correct context. *)
lemma fun_app_cong[fundef_cong]:
  "\<lbrakk> f x = f' x' \<rbrakk> \<Longrightarrow> (f $ x) = (f' $ x')"
  by simp

lemma fun_app_apply_cong[fundef_cong]:
  "f x y = f' x' y' \<Longrightarrow> (f $ x) y = (f' $ x') y'"
  by simp

lemma if_apply_cong[fundef_cong]:
  "\<lbrakk> P = P'; x = x'; P' \<Longrightarrow> f x' = f' x'; \<not> P' \<Longrightarrow> g x' = g' x' \<rbrakk>
     \<Longrightarrow> (if P then f else g) x = (if P' then f' else g') x'"
  by simp

lemma case_prod_apply_cong[fundef_cong]:
  "\<lbrakk> f (fst p) (snd p) s = f' (fst p') (snd p') s' \<rbrakk> \<Longrightarrow> case_prod f p s = case_prod f' p' s'"
  by (simp add: split_def)

(* Pointwise conjunction of two predicates, written "P and Q". *)
definition
  pred_conj :: "('a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> bool)" (infixl "and" 35)
where
  "pred_conj P Q \<equiv> \<lambda>x. P x \<and> Q x"

(* Pointwise disjunction of two predicates, written "P or Q". *)
definition
  pred_disj :: "('a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> bool)" (infixl "or" 30)
where
  "pred_disj P Q \<equiv> \<lambda>x. P x \<or> Q x"

(* Pointwise negation of a predicate, written "not P". *)
definition
  pred_neg :: "('a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> bool)" ("not _" [40] 40)
where
  "pred_neg P \<equiv> \<lambda>x. \<not> P x"

(* The constant combinator: K x ignores its second argument. *)
definition "K \<equiv> \<lambda>x y. x"

(* Combine two lists element-wise with a binary function; like zip, the
   result is as long as the shorter input. *)
definition
  zipWith :: "('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'a list \<Rightarrow> 'b list \<Rightarrow> 'c list" where
  "zipWith f xs ys \<equiv> map (case_prod f) (zip xs ys)"

(* Remove the first occurrence of y (only the first) from a list. *)
primrec
  delete :: "'a \<Rightarrow> 'a list \<Rightarrow> 'a list"
where
  "delete y [] = []"
| "delete y (x#xs) = (if y=x then xs else x # delete y xs)"

(* The first list element satisfying f, if any.
   NOTE(review): Main's List theory presumably already provides a `find`
   with the same equations, which this definition shadows; confirm
   whether it is still needed. *)
primrec
  find :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a option"
where
  "find f [] = None"
| "find f (x # xs) = (if f x then Some x else find f xs)"

(* Swap the two arguments of a binary function. *)
definition
 "swp f \<equiv> \<lambda>x y. f y x"

(* Partial projections out of a sum: only the matching constructor is
   specified; the other case is left unspecified. *)
primrec (nonexhaustive)
  theRight :: "'a + 'b \<Rightarrow> 'b" where
  "theRight (Inr x) = x"

primrec (nonexhaustive)
  theLeft :: "'a + 'b \<Rightarrow> 'a" where
  "theLeft (Inl x) = x"

(* Discriminators for the sum type. *)
definition
 "isLeft x \<equiv> (\<exists>y. x = Inl y)"

definition
 "isRight x \<equiv> (\<exists>y. x = Inr y)"

(* Same as K: const x ignores its argument and returns x. *)
definition
 "const x \<equiv> \<lambda>y. x"

(* Split off the last step of a transitive-closure path. *)
lemma tranclD2:
  "(x, y) \<in> R\<^sup>+ \<Longrightarrow> \<exists>z. (x, z) \<in> R\<^sup>* \<and> (z, y) \<in> R"
  by (erule tranclE) auto

lemma linorder_min_same1 [simp]:
  "(min y x = y) = (y \<le> (x::'a::linorder))"
  by (auto simp: min_def linorder_not_less)

lemma linorder_min_same2 [simp]:
  "(min x y = y) = (y \<le> (x::'a::linorder))"
  by (auto simp: min_def linorder_not_le)

text \<open>A combinator for pairing up well-formed relations.
        The divisor function splits the population in halves,
        with the True half greater than the False half, and
        the supplied relations control the order within the halves.\<close>

definition
  wf_sum :: "('a \<Rightarrow> bool) \<Rightarrow> ('a \<times> 'a) set \<Rightarrow> ('a \<times> 'a) set \<Rightarrow> ('a \<times> 'a) set"
where
  "wf_sum divisor r r' \<equiv>
     ({(x, y). \<not> divisor x \<and> \<not> divisor y} \<inter> r')
   \<union> {(x, y). \<not> divisor x \<and> divisor y}
   \<union> ({(x, y). divisor x \<and> divisor y} \<inter> r)"

lemma wf_sum_wf:
  "\<lbrakk> wf r; wf r' \<rbrakk> \<Longrightarrow> wf (wf_sum divisor r r')"
  apply (simp add: wf_sum_def)
  apply (rule wf_Un)+
      apply (erule wf_Int2)
     apply (rule wf_subset
             [where r="measure (\<lambda>x. If (divisor x) 1 0)"])
      apply simp
     apply clarsimp
    apply blast
   apply (erule wf_Int2)
  apply blast
  done

(* Legacy name for map_option. *)
abbreviation(input)
 "option_map == map_option"

lemmas option_map_def = map_option_case

(* A premise that itself depends on False is vacuous and can be dropped. *)
lemma False_implies_equals [simp]:
  "((False \<Longrightarrow> P) \<Longrightarrow> PROP Q) \<equiv> PROP Q"
  apply (rule equal_intr_rule)
   apply (erule meta_mp)
   apply simp
  apply simp
  done

lemma split_paired_Ball:
  "(\<forall>x \<in> A. P x) = (\<forall>x y. (x,y) \<in> A \<longrightarrow> P (x,y))"
  by auto

lemma split_paired_Bex:
  "(\<exists>x \<in> A. P x) = (\<exists>x y. (x,y) \<in> A \<and> P (x,y))"
  by auto

end

Theory OptionMonad

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)
(*
 * Contributions by:
 *   2012 Lars Noschinski <noschinl@in.tum.de>
 *     Option monad while loop formalisation.
 *)

theory OptionMonad
imports Lib
begin

(* A lookup reads the state and either yields a result or fails (None). *)
type_synonym ('s,'a) lookup = "'s \<Rightarrow> 'a option"

text \<open>Similar to map_option but the second function returns option as well\<close>
definition
  opt_map :: "('s,'a) lookup \<Rightarrow> ('a \<Rightarrow> 'b option) \<Rightarrow> ('s,'b) lookup" (infixl "|>" 54)
where
  "f |> g \<equiv> \<lambda>s. case f s of None \<Rightarrow> None | Some x \<Rightarrow> g x"

lemma opt_map_cong [fundef_cong]:
  "\<lbrakk> f = f'; \<And>v s. f s = Some v \<Longrightarrow> g v = g' v\<rbrakk> \<Longrightarrow> f |> g = f' |> g'"
  by (rule ext) (simp add: opt_map_def split: option.splits)

lemma in_opt_map_eq:
  "((f |> g) s = Some v) = (\<exists>v'. f s = Some v' \<and> g v' = Some v)"
  by (simp add: opt_map_def split: option.splits)

lemma opt_mapE:
  "\<lbrakk> (f |> g) s = Some v; \<And>v'. \<lbrakk>f s = Some v'; g v' = Some v \<rbrakk> \<Longrightarrow> P \<rbrakk> \<Longrightarrow> P"
  by (auto simp: in_opt_map_eq)


(* Monadic bind: run f in state s; on success run g on its result in the
   same state s. A failure (None) of f short-circuits. *)
definition
  obind :: "('s,'a) lookup \<Rightarrow> ('a \<Rightarrow> ('s,'b) lookup) \<Rightarrow> ('s,'b) lookup" (infixl "|>>" 53)
where
  "f |>> g \<equiv> \<lambda>s. case f s of None \<Rightarrow> None | Some x \<Rightarrow> g x s"

(* Fails in every state. *)
definition
  "ofail = K None"

(* Succeeds with the given value in every state. *)
definition
  "oreturn = K o Some"

(* Succeeds with () if P holds, fails otherwise. *)
definition
  "oassert P \<equiv> if P then oreturn () else ofail"

text \<open>
  Variants for computations whose result may be an exception:
  \<open>Inr\<close> marks a normal result, \<open>Inl\<close> an exception.
  Corresponding bindE would be analogous to lifting in NonDetMonad.
\<close>

definition
  "oreturnOk x = K (Some (Inr x))"

definition
  "othrow e = K (Some (Inl e))"

(* Succeeds with () in states satisfying G, fails in all others. *)
definition
  "oguard G \<equiv> (\<lambda>s. if G s then Some () else None)"

(* Runs L or R depending on whether c holds in the current state. *)
definition
  "ocondition c L R \<equiv> (\<lambda>s. if c s then L s else R s)"

definition
  "oskip \<equiv> oreturn ()"

text \<open>Monad laws\<close>
lemma oreturn_bind [simp]: "(oreturn x |>> f) = f x"
  by (auto simp add: oreturn_def obind_def K_def intro!: ext)

lemma obind_return [simp]: "(m |>> oreturn) = m"
  by (auto simp add: oreturn_def obind_def K_def intro!: ext split: option.splits)

lemma obind_assoc:
  "(m |>> f) |>> g  =  m |>> (\<lambda>x. f x |>> g)"
  by (auto simp add: oreturn_def obind_def K_def intro!: ext split: option.splits)


text \<open>Binding fail\<close>

lemma obind_fail [simp]:
  "f |>> (\<lambda>_. ofail) = ofail"
  by (auto simp add: ofail_def obind_def K_def intro!: ext split: option.splits)

lemma ofail_bind [simp]:
  "ofail |>> m = ofail"
  by (auto simp add: ofail_def obind_def K_def intro!: ext split: option.splits)



text \<open>Function package setup\<close>
lemma opt_bind_cong [fundef_cong]:
  "\<lbrakk> f = f'; \<And>v s. f' s = Some v \<Longrightarrow> g v s = g' v s \<rbrakk> \<Longrightarrow> f |>> g = f' |>> g'"
  by (rule ext) (simp add: obind_def split: option.splits)

lemma opt_bind_cong_apply [fundef_cong]:
  "\<lbrakk> f s = f' s; \<And>v. f' s = Some v \<Longrightarrow> g v s = g' v s \<rbrakk> \<Longrightarrow> (f |>> g) s = (f' |>> g') s"
  by (simp add: obind_def split: option.splits)

lemma oassert_bind_cong [fundef_cong]:
  "\<lbrakk> P = P'; P' \<Longrightarrow> m = m' \<rbrakk> \<Longrightarrow> oassert P |>> m = oassert P' |>> m'"
  by (auto simp: oassert_def)

lemma oassert_bind_cong_apply [fundef_cong]:
  "\<lbrakk> P = P'; P' \<Longrightarrow> m () s = m' () s \<rbrakk> \<Longrightarrow> (oassert P |>> m) s = (oassert P' |>> m') s"
  by (auto simp: oassert_def)

lemma oreturn_bind_cong [fundef_cong]:
  "\<lbrakk> x = x'; m x' = m' x' \<rbrakk> \<Longrightarrow> oreturn x |>> m = oreturn x' |>> m'"
  by simp

lemma oreturn_bind_cong_apply [fundef_cong]:
  "\<lbrakk> x = x'; m x' s = m' x' s \<rbrakk> \<Longrightarrow> (oreturn x |>> m) s = (oreturn x' |>> m') s"
  by simp

lemma oreturn_bind_cong2 [fundef_cong]:
  "\<lbrakk> x = x'; m x' = m' x' \<rbrakk> \<Longrightarrow> (oreturn $ x) |>> m = (oreturn $ x') |>> m'"
  by simp

lemma oreturn_bind_cong2_apply [fundef_cong]:
  "\<lbrakk> x = x'; m x' s = m' x' s \<rbrakk> \<Longrightarrow> ((oreturn $ x) |>> m) s = ((oreturn $ x') |>> m') s"
  by simp

lemma ocondition_cong [fundef_cong]:
"\<lbrakk>c = c'; \<And>s. c' s \<Longrightarrow> l s = l' s; \<And>s. \<not>c' s \<Longrightarrow> r s = r' s\<rbrakk>
  \<Longrightarrow> ocondition c l r = ocondition c' l' r'"
  by (auto simp: ocondition_def)


text \<open>Decomposition\<close>

lemma ocondition_K_true [simp]:
  "ocondition (\<lambda>_. True) T F = T"
  by (simp add: ocondition_def)

lemma ocondition_K_false [simp]:
  "ocondition (\<lambda>_. False) T F = F"
  by (simp add: ocondition_def)

lemma ocondition_False:
    "\<lbrakk> \<And>s. \<not> P s \<rbrakk> \<Longrightarrow> ocondition P L R = R"
  by (rule ext, clarsimp simp: ocondition_def)

lemma ocondition_True:
    "\<lbrakk> \<And>s. P s \<rbrakk> \<Longrightarrow> ocondition P L R = L"
  by (rule ext, clarsimp simp: ocondition_def)

lemma in_oreturn [simp]:
  "(oreturn x s = Some v) = (v = x)"
  by (auto simp: oreturn_def K_def)

lemma oreturnE:
  "\<lbrakk>oreturn x s = Some v; v = x \<Longrightarrow> P x\<rbrakk> \<Longrightarrow> P v"
  by simp

lemma in_ofail [simp]:
  "ofail s \<noteq> Some v"
  by (auto simp: ofail_def K_def)

lemma ofailE:
  "ofail s = Some v \<Longrightarrow> P"
  by simp

lemma in_oassert_eq [simp]:
  "(oassert P s = Some v) = P"
  by (simp add: oassert_def)

lemma oassertE:
  "\<lbrakk> oassert P s = Some v; P \<Longrightarrow> Q \<rbrakk> \<Longrightarrow> Q"
  by simp

lemma in_obind_eq:
  "((f |>> g) s = Some v) = (\<exists>v'. f s = Some v' \<and> g v' s = Some v)"
  by (simp add: obind_def split: option.splits)

lemma obindE:
  "\<lbrakk> (f |>> g) s = Some v;
     \<And>v'. \<lbrakk>f s = Some v'; g v' s = Some v\<rbrakk> \<Longrightarrow> P\<rbrakk> \<Longrightarrow> P"
  by (auto simp: in_obind_eq)

lemma in_othrow_eq [simp]:
  "(othrow e s = Some v) = (v = Inl e)"
  by (auto simp: othrow_def K_def)

lemma othrowE:
  "\<lbrakk>othrow e s = Some v; v = Inl e \<Longrightarrow> P (Inl e)\<rbrakk> \<Longrightarrow> P v"
  by simp

lemma in_oreturnOk_eq [simp]:
  "(oreturnOk x s = Some v) = (v = Inr x)"
  by (auto simp: oreturnOk_def K_def)

lemma oreturnOkE:
  "\<lbrakk>oreturnOk x s = Some v; v = Inr x \<Longrightarrow> P (Inr x)\<rbrakk> \<Longrightarrow> P v"
  by simp

lemmas omonadE [elim!] =
  opt_mapE obindE oreturnE ofailE othrowE oreturnOkE oassertE

section \<open>"While" loops over option monad.\<close>

text \<open>
  This is an inductive definition of a while loop over the plain option monad
  (without passing through a state)
\<close>

inductive_set
  option_while' :: "('a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> 'a option) \<Rightarrow> 'a option rel"
  for C B
where
    final: "\<not> C r \<Longrightarrow> (Some r, Some r) \<in> option_while' C B"
  | fail: "\<lbrakk> C r; B r = None \<rbrakk> \<Longrightarrow> (Some r, None) \<in> option_while' C B"
  | step: "\<lbrakk> C r;  B r = Some r'; (Some r', sr'') \<in> option_while' C B \<rbrakk>
           \<Longrightarrow> (Some r, sr'') \<in> option_while' C B"

(* The loop result, if the loop terminates from r; None otherwise (and
   also None when the body fails). *)
definition
  "option_while C B r \<equiv>
    (if (\<exists>s. (Some r, s) \<in> option_while' C B) then
      (THE s. (Some r, s) \<in> option_while' C B) else None)"

lemma option_while'_inj:
  assumes "(s,s') \<in> option_while' C B" "(s, s'') \<in> option_while' C B"
  shows "s' = s''"
  using assms by (induct rule: option_while'.induct) (auto elim: option_while'.cases)

lemma option_while'_inj_step:
  "\<lbrakk> C s; B s = Some s'; (Some s, t) \<in> option_while' C B ; (Some s', t') \<in> option_while' C B \<rbrakk> \<Longrightarrow> t = t'"
  by (metis option_while'.step option_while'_inj)

lemma option_while'_THE:
  assumes "(Some r, sr') \<in> option_while' C B"
  shows "(THE s. (Some r, s) \<in> option_while' C B) = sr'"
  using assms by (blast dest: option_while'_inj)

lemma option_while_simps:
  "\<not> C s \<Longrightarrow> option_while C B s = Some s"
  "C s \<Longrightarrow> B s = None \<Longrightarrow> option_while C B s = None"
  "C s \<Longrightarrow> B s = Some s' \<Longrightarrow> option_while C B s = option_while C B s'"
  "(Some s, ss') \<in> option_while' C B \<Longrightarrow> option_while C B s = ss'"
  using option_while'_inj_step[of C s B s']
  by (auto simp: option_while_def option_while'_THE
      intro: option_while'.intros
      dest: option_while'_inj
      elim: option_while'.cases)

lemma option_while_rule:
  assumes "option_while C B s = Some s'"
  assumes "I s"
  assumes istep: "\<And>s s'. C s \<Longrightarrow> I s \<Longrightarrow> B s = Some s' \<Longrightarrow> I s'"
  shows "I s' \<and> \<not> C s'"
proof -
  { fix ss ss' assume "(ss, ss') \<in> option_while' C B" "ss = Some s" "ss' = Some s'"
    then have ?thesis using \<open>I s\<close>
      by (induct arbitrary: s) (auto intro: istep) }
  then show ?thesis using assms(1)
    by (auto simp: option_while_def option_while'_THE split: if_split_asm)
qed

lemma option_while'_term:
  assumes "I r"
  assumes "wf M"
  assumes step_less: "\<And>r r'. \<lbrakk>I r; C r; B r = Some r'\<rbrakk> \<Longrightarrow> (r',r) \<in> M"
  assumes step_I: "\<And>r r'. \<lbrakk>I r; C r; B r = Some r'\<rbrakk> \<Longrightarrow> I r'"
  obtains sr' where "(Some r, sr') \<in> option_while' C B"
  apply atomize_elim
  using assms(2,1)
proof induct
  case (less r)
  show ?case
  proof (cases "C r" "B r" rule: bool.exhaust[case_product option.exhaust])
    case (True_Some r')
    then have "(r',r) \<in> M" "I r'"
      by (auto intro: less step_less step_I)
    then obtain sr' where "(Some r', sr') \<in> option_while' C B"
      by atomize_elim (rule less)
    then have "(Some r, sr') \<in> option_while' C B"
      using True_Some by (auto intro: option_while'.intros)
    then show ?thesis ..
  qed (auto intro: option_while'.intros)
qed

lemma option_while_rule':
  assumes "option_while C B s = ss'"
  assumes "wf M"
  assumes "I (Some s)"
  assumes less: "\<And>s s'. C s \<Longrightarrow> I (Some s) \<Longrightarrow> B s = Some s' \<Longrightarrow> (s', s) \<in> M"
  assumes step: "\<And>s s'. C s \<Longrightarrow> I (Some s) \<Longrightarrow> B s = Some s' \<Longrightarrow> I (Some s')"
  assumes final: "\<And>s. C s \<Longrightarrow> I (Some s) \<Longrightarrow> B s = None \<Longrightarrow> I None"
  shows "I ss' \<and> (case ss' of Some s' \<Rightarrow> \<not> C s' | _ \<Rightarrow> True)"
proof -
  define ss where "ss = Some s"
  obtain ss1' where "(Some s, ss1') \<in> option_while' C B"
    using assms(3,2,4,5) by (rule option_while'_term)
  then have *: "(ss, ss') \<in> option_while' C B" using \<open>option_while C B s = ss'\<close>
    by (auto simp: option_while_simps ss_def)
  show ?thesis
  proof (cases ss')
    case (Some s') with * ss_def show ?thesis using \<open>I _\<close>
      by (induct arbitrary:s) (auto intro: step)
  next
    case None with * ss_def show ?thesis using \<open>I _\<close>
      by (induct arbitrary:s) (auto intro: step final)
  qed
qed

section \<open>Lift @{term option_while} to the @{typ "('s,'a) lookup"} monad\<close>

(* Run the loop in a fixed state s: the condition and body may read s,
   but the state is never changed. *)
definition
  owhile :: "('a \<Rightarrow> 's \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> ('s,'a) lookup) \<Rightarrow> 'a \<Rightarrow> ('s,'a) lookup"
where
 "owhile c b a \<equiv> \<lambda>s. option_while (\<lambda>a. c a s) (\<lambda>a. b a s) a"

lemma owhile_unroll:
  "owhile C B r = ocondition (C r) (B r |>> owhile C B) (oreturn r)"
  by (auto simp: ocondition_def obind_def oreturn_def owhile_def
           option_while_simps K_def split: option.split)

text \<open>rule for terminating loops\<close>

lemma owhile_rule:
  assumes "I r s"
  assumes "wf M"
  assumes less: "\<And>r r'. \<lbrakk>I r s; C r s; B r s = Some r'\<rbrakk> \<Longrightarrow> (r',r) \<in> M"
  assumes step: "\<And>r r'. \<lbrakk>I r s; C r s; B r s = Some r'\<rbrakk> \<Longrightarrow> I r' s"
  assumes fail: "\<And>r r'. \<lbrakk>I r s; C r s; B r s = None\<rbrakk> \<Longrightarrow> Q None"
  assumes final: "\<And>r. \<lbrakk>I r s; \<not>C r s\<rbrakk> \<Longrightarrow> Q (Some r)"
  shows "Q (owhile C B r s)"
proof -
  let ?rs' = "owhile C B r s"
  have "(case ?rs' of Some r \<Rightarrow> I r s | _ \<Rightarrow> Q None)
      \<and> (case ?rs' of Some r' \<Rightarrow> \<not> C r' s | _ \<Rightarrow> True)"
    by (rule option_while_rule'[where B="\<lambda>r. B r s" and s=r, OF _ \<open>wf _\<close>])
       (auto simp: owhile_def intro: assms)
  then show ?thesis by (auto intro: final split: option.split_asm)
qed

end

Theory NonDetMonad

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(* 
   Nondeterministic state and error monads with failure in Isabelle.
*)

chapter "Nondeterministic State Monad with Failure"

theory NonDetMonad
imports "../Lib"
begin

text ‹
  \label{c:monads}

  State monads are used extensively in the seL4 specification. They are
  defined below.
›

section "The Monad"

text \<open>
  The basic type of the nondeterministic state monad with failure is
  very similar to the normal state monad. Instead of a pair consisting
  of result and new state, we return a set of these pairs coupled with
  a failure flag. Each element in the set is a potential result of the
  computation. The flag is @{const True} if there is an execution path
  in the computation that may have failed. Conversely, if the flag is
  @{const False}, none of the computations resulting in the returned
  set can have failed.\<close>
type_synonym ('s,'a) nondet_monad = "'s \<Rightarrow> ('a \<times> 's) set \<times> bool"


text \<open>
  The definition of fundamental monad functions \<open>return\<close> and
  \<open>bind\<close>. The monad function \<open>return x\<close> does not change
  the state, does not fail, and returns \<open>x\<close>.
\<close>
definition
  return :: "'a \<Rightarrow> ('s,'a) nondet_monad" where
  "return a \<equiv> \<lambda>s. ({(a,s)},False)"

text \<open>
  The monad function \<open>bind f g\<close>, also written \<open>f >>= g\<close>,
  is the execution of @{term f} followed by the execution of \<open>g\<close>.
  The function \<open>g\<close> takes the result value \emph{and} the result
  state of \<open>f\<close> as parameter. The definition says that the result of
  the combined operation is the union of the result sets obtained by
  applying \<open>g\<close> to each result (value and state) of \<open>f\<close>. The combined
  operation may have failed, if \<open>f\<close> may have failed or \<open>g\<close> may
  have failed on any of the results of \<open>f\<close>.
\<close>
definition
  bind :: "('s, 'a) nondet_monad \<Rightarrow> ('a \<Rightarrow> ('s, 'b) nondet_monad) \<Rightarrow>
           ('s, 'b) nondet_monad" (infixl ">>=" 60)
  where
  "bind f g \<equiv> \<lambda>s. (\<Union>(fst ` case_prod g ` fst (f s)),
                   True \<in> snd ` case_prod g ` fst (f s) \<or> snd (f s))"

text \<open>
  Sometimes it is convenient to write \<open>bind\<close> in reverse order.
\<close>
abbreviation(input)
  bind_rev :: "('c \<Rightarrow> ('a, 'b) nondet_monad) \<Rightarrow> ('a, 'c) nondet_monad \<Rightarrow>
               ('a, 'b) nondet_monad" (infixl "=<<" 60) where
  "g =<< f \<equiv> f >>= g"

text \<open>
  The basic accessor functions of the state monad. \<open>get\<close> returns
  the current state as result, does not fail, and does not change the state.
  \<open>put s\<close> returns nothing (@{typ unit}), changes the current state
  to \<open>s\<close> and does not fail.
\<close>
definition
  get :: "('s,'s) nondet_monad" where
  "get \<equiv> \<lambda>s. ({(s,s)}, False)"

definition
  put :: "'s \<Rightarrow> ('s, unit) nondet_monad" where
  "put s \<equiv> \<lambda>_. ({((),s)}, False)"


subsection "Nondeterminism"

text \<open>
  Basic nondeterministic functions. \<open>select A\<close> chooses an element
  of the set \<open>A\<close>, does not change the state, and does not fail
  (even if the set is empty). \<open>f \<sqinter> g\<close> executes \<open>f\<close> or
  executes \<open>g\<close>. It returns the union of results of \<open>f\<close> and
  \<open>g\<close>, and may have failed if either may have failed.
\<close>
definition
  select :: "'a set \<Rightarrow> ('s,'a) nondet_monad" where
  "select A \<equiv> \<lambda>s. (A \<times> {s}, False)"

definition
  alternative :: "('s,'a) nondet_monad \<Rightarrow> ('s,'a) nondet_monad \<Rightarrow>
                  ('s,'a) nondet_monad"
  (infixl "\<sqinter>" 20)
where
  "f \<sqinter> g \<equiv> \<lambda>s. (fst (f s) \<union> fst (g s), snd (f s) \<or> snd (g s))"


text \<open>A variant of \<open>select\<close> that takes a pair. The first component
  is a set as in normal \<open>select\<close>, the second component indicates
  whether the execution failed. This is useful to lift monads between
  different state spaces.
\<close>
definition
  select_f :: "'a set \<times> bool  \<Rightarrow> ('s,'a) nondet_monad" where
  "select_f S \<equiv> \<lambda>s. (fst S \<times> {s}, snd S)"

text \<open>\<open>state_select\<close> takes a relationship between
  states, and outputs nondeterministically a state
  related to the input state. It fails if no state is related
  to the input state.\<close>

definition
  state_select :: "('s \<times> 's) set \<Rightarrow> ('s, unit) nondet_monad"
where
  "state_select r \<equiv> \<lambda>s. ((\<lambda>x. ((), x)) ` {s'. (s, s') \<in> r}, \<not> (\<exists>s'. (s, s') \<in> r))"

subsection "Failure"

text \<open>The monad function that always fails. Returns an empty set of
results and sets the failure flag.\<close>
definition
  fail :: "('s, 'a) nondet_monad" where
 "fail \<equiv> \<lambda>s. ({}, True)"

text \<open>Assertions: fail if the property \<open>P\<close> is not true\<close>
definition
  assert :: "bool \<Rightarrow> ('a, unit) nondet_monad" where
 "assert P \<equiv> if P then return () else fail"

text \<open>Fail if the value is @{const None},
  return result \<open>v\<close> for @{term "Some v"}\<close>
definition
  assert_opt :: "'a option \<Rightarrow> ('b, 'a) nondet_monad" where
 "assert_opt v \<equiv> case v of None \<Rightarrow> fail | Some v \<Rightarrow> return v"

text \<open>An assertion that also can introspect the current state.\<close>

definition
  state_assert :: "('s \<Rightarrow> bool) \<Rightarrow> ('s, unit) nondet_monad"
where
  "state_assert P \<equiv> get >>= (\<lambda>s. assert (P s))"

subsection "Generic functions on top of the state monad"

text \<open>Apply a function to the current state and return the result
without changing the state.\<close>
definition
  gets :: "('s \<Rightarrow> 'a) \<Rightarrow> ('s, 'a) nondet_monad" where
 "gets f \<equiv> get >>= (\<lambda>s. return (f s))"

text \<open>Modify the current state using the function passed in.\<close>
definition
  modify :: "('s \<Rightarrow> 's) \<Rightarrow> ('s, unit) nondet_monad" where
 "modify f \<equiv> get >>= (\<lambda>s. put (f s))"

lemma simpler_gets_def: "gets f = (\<lambda>s. ({(f s, s)}, False))"
  apply (simp add: gets_def return_def bind_def get_def)
  done

lemma simpler_modify_def:
  "modify f = (\<lambda>s. ({((), f s)}, False))"
  by (simp add: modify_def bind_def get_def put_def)

text \<open>Execute the given monad when the condition is true,
  return \<open>()\<close> otherwise.\<close>
definition
  "when" :: "bool \<Rightarrow> ('s, unit) nondet_monad \<Rightarrow>
           ('s, unit) nondet_monad" where
  "when P m \<equiv> if P then m else return ()"

text \<open>Execute the given monad unless the condition is true,
  return \<open>()\<close> otherwise.\<close>
definition
  unless :: "bool \<Rightarrow> ('s, unit) nondet_monad \<Rightarrow>
            ('s, unit) nondet_monad" where
  "unless P m \<equiv> when (\<not>P) m"

text \<open>
  Perform a test on the current state, performing the left monad if
  the result is true or the right monad if the result is false.
\<close>
definition
  condition :: "('s \<Rightarrow> bool) \<Rightarrow> ('s, 'r) nondet_monad \<Rightarrow> ('s, 'r) nondet_monad \<Rightarrow> ('s, 'r) nondet_monad"
where
  "condition P L R \<equiv> \<lambda>s. if (P s) then (L s) else (R s)"

notation (output)
  condition  ("(condition (_)//  (_)//  (_))" [1000,1000,1000] 1000)

text \<open>
Apply an option valued function to the current state, fail
if it returns @{const None}, return \<open>v\<close> if it returns
@{term "Some v"}.
\<close>
definition
  gets_the :: "('s \<Rightarrow> 'a option) \<Rightarrow> ('s, 'a) nondet_monad" where
 "gets_the f \<equiv> gets f >>= assert_opt"


subsection \<open>The Monad Laws\<close>

text \<open>A more expanded definition of \<open>bind\<close>\<close>
lemma bind_def':
  "(f >>= g) \<equiv>
       \<lambda>s. ({(r'', s''). \<exists>(r', s') \<in> fst (f s). (r'', s'') \<in> fst (g r' s') },
                     snd (f s) \<or> (\<exists>(r', s') \<in> fst (f s). snd (g r' s')))"
  apply (rule eq_reflection)
  apply (auto simp add: bind_def split_def Let_def)
  done

text \<open>Each monad satisfies at least the following three laws.\<close>

text \<open>@{term return} is absorbed at the left of a @{term bind},
  applying the return value directly:\<close>
lemma return_bind [simp]: "(return x >>= f) = f x"
  by (simp add: return_def bind_def)

text \<open>@{term return} is absorbed on the right of a @{term bind}\<close>
lemma bind_return [simp]: "(m >>= return) = m"
  apply (rule ext)
  apply (simp add: bind_def return_def split_def)
  done

text \<open>@{term bind} is associative\<close>
lemma bind_assoc:
  fixes m :: "('a,'b) nondet_monad"
  fixes f :: "'b \<Rightarrow> ('a,'c) nondet_monad"
  fixes g :: "'c \<Rightarrow> ('a,'d) nondet_monad"
  shows "(m >>= f) >>= g  =  m >>= (\<lambda>x. f x >>= g)"
  apply (unfold bind_def Let_def split_def)
  apply (rule ext)
  apply clarsimp
  apply (auto intro: rev_image_eqI)
  done


section \<open>Adding Exceptions\<close>

text \<open>
  The type @{typ "('s,'a) nondet_monad"} gives us nondeterminism and
  failure. We now extend this monad with exceptional return values
  that abort normal execution, but can be handled explicitly.
  We use the sum type to indicate exceptions.

  In @{typ "('s, 'e + 'a) nondet_monad"}, @{typ "'s"} is the state,
  @{typ 'e} is an exception, and @{typ 'a} is a normal return value.

  This new type itself forms a monad again. Since type classes in
  Isabelle are not powerful enough to express the class of monads,
  we provide new names for the @{term return} and @{term bind} functions
  in this monad. We call them \<open>returnOk\<close> (for normal return values)
  and \<open>bindE\<close> (for composition). We also define \<open>throwError\<close>
  to return an exceptional value.
\<close>
definition
  returnOk :: "'a \<Rightarrow> ('s, 'e + 'a) nondet_monad" where
  "returnOk \<equiv> return o Inr"

definition
  throwError :: "'e \<Rightarrow> ('s, 'e + 'a) nondet_monad" where
  "throwError \<equiv> return o Inl"

text \<open>
  Lifting a function over the exception type: if the input is an
  exception, return that exception; otherwise continue execution.
\<close>
definition
  lift :: "('a \<Rightarrow> ('s, 'e + 'b) nondet_monad) \<Rightarrow>
           'e +'a \<Rightarrow> ('s, 'e + 'b) nondet_monad"
where
  "lift f v \<equiv> case v of Inl e \<Rightarrow> throwError e
                      | Inr v' \<Rightarrow> f v'"

text \<open>
  The definition of @{term bind} in the exception monad (new
  name \<open>bindE\<close>): the same as normal @{term bind}, but
  the right-hand side is skipped if the left-hand side
  produced an exception.
\<close>
definition
  bindE :: "('s, 'e + 'a) nondet_monad \<Rightarrow>
            ('a \<Rightarrow> ('s, 'e + 'b) nondet_monad) \<Rightarrow>
            ('s, 'e + 'b) nondet_monad"  (infixl ">>=E" 60)
where
  "bindE f g \<equiv> bind f (lift g)"


text \<open>
  Lifting a normal nondeterministic monad into the
  exception monad is achieved by always returning its
  result as normal result and never throwing an exception.
\<close>
definition
  liftE :: "('s,'a) nondet_monad \<Rightarrow> ('s, 'e+'a) nondet_monad"
where
  "liftE f \<equiv> f >>= (\<lambda>r. return (Inr r))"


text \<open>
  Since the underlying type and \<open>return\<close> function changed,
  we need new definitions for when and unless:
\<close>
definition
  whenE :: "bool \<Rightarrow> ('s, 'e + unit) nondet_monad \<Rightarrow>
            ('s, 'e + unit) nondet_monad"
  where
  "whenE P f \<equiv> if P then f else returnOk ()"

definition
  unlessE :: "bool \<Rightarrow> ('s, 'e + unit) nondet_monad \<Rightarrow>
            ('s, 'e + unit) nondet_monad"
  where
  "unlessE P f \<equiv> if P then returnOk () else f"


text \<open>
  Throwing an exception when the parameter is @{term None}, otherwise
  returning @{term "v"} for @{term "Some v"}.
\<close>
definition
  throw_opt :: "'e \<Rightarrow> 'a option \<Rightarrow> ('s, 'e + 'a) nondet_monad" where
  "throw_opt ex x \<equiv>
  case x of None \<Rightarrow> throwError ex | Some v \<Rightarrow> returnOk v"


text \<open>
  Failure in the exception monad is redefined in the same way
  as @{const whenE} and @{const unlessE}, with @{term returnOk}
  instead of @{term return}.
\<close>
definition
  assertE :: "bool \<Rightarrow> ('a, 'e + unit) nondet_monad" where
 "assertE P \<equiv> if P then returnOk () else fail"

subsection "Monad Laws for the Exception Monad"

text \<open>More direct definition of @{const liftE}:\<close>
lemma liftE_def2:
  "liftE f = (\<lambda>s. ((\<lambda>(v,s'). (Inr v, s')) ` fst (f s), snd (f s)))"
  by (auto simp: liftE_def return_def split_def bind_def)

text \<open>Left @{const returnOk} absorption over @{term bindE}:\<close>
lemma returnOk_bindE [simp]: "(returnOk x >>=E f) = f x"
  apply (simp only: bindE_def return_def returnOk_def)
  apply (clarsimp simp: lift_def)
  done

lemma lift_return [simp]:
  "lift (return \<circ> Inr) = return"
  by (rule ext)
     (simp add: lift_def throwError_def split: sum.splits)

text \<open>Right @{const returnOk} absorption over @{term bindE}:\<close>
lemma bindE_returnOk [simp]: "(m >>=E returnOk) = m"
  by (simp add: bindE_def returnOk_def)

text \<open>Associativity of @{const bindE}:\<close>
lemma bindE_assoc:
  "(m >>=E f) >>=E g = m >>=E (\<lambda>x. f x >>=E g)"
  apply (simp add: bindE_def bind_assoc)
  apply (rule arg_cong [where f="\<lambda>x. m >>= x"])
  apply (rule ext)
  apply (case_tac x, simp_all add: lift_def throwError_def)
  done

text \<open>@{const returnOk} could also be defined via @{const liftE}:\<close>
lemma returnOk_liftE:
  "returnOk x = liftE (return x)"
  by (simp add: liftE_def returnOk_def)

text \<open>Execution after throwing an exception is skipped:\<close>
lemma throwError_bindE [simp]:
  "(throwError E >>=E f) = throwError E"
  by (simp add: bindE_def bind_def throwError_def lift_def return_def)


section "Syntax"

text \<open>This section defines traditional Haskell-like do-syntax
  for the state monad in Isabelle.\<close>

subsection "Syntax for the Nondeterministic State Monad"

text \<open>We use \<open>K_bind\<close> to syntactically indicate the
  case where the return argument of the left side of a @{term bind}
  is ignored\<close>
definition
  K_bind_def [iff]: "K_bind \<equiv> \<lambda>x y. x"

nonterminal
  dobinds and dobind and nobind

syntax (ASCII)
  "_dobind"    :: "[pttrn, 'a] => dobind"             ("(_ <-/ _)" 10)
syntax
  "_dobind"    :: "[pttrn, 'a] => dobind"             ("(_ \<leftarrow>/ _)" 10)
  ""           :: "dobind => dobinds"                 ("_")
  "_nobind"    :: "'a => dobind"                      ("_")
  "_dobinds"   :: "[dobind, dobinds] => dobinds"      ("(_);//(_)")

  "_do"        :: "[dobinds, 'a] => 'a"               ("(do ((_);//(_))//od)" 100)
translations
  "_do (_dobinds b bs) e"  == "_do b (_do bs e)"
  "_do (_nobind b) e"      == "b >>= (CONST K_bind e)"
  "do x <- a; e od"        == "a >>= (\<lambda>x. e)"

text \<open>Syntax examples:\<close>
lemma "do x \<leftarrow> return 1;
          return (2::nat);
          return x
       od =
       return 1 >>=
       (\<lambda>x. return (2::nat) >>=
            K_bind (return x))"
  by (rule refl)

lemma "do x \<leftarrow> return 1;
          return 2;
          return x
       od = return 1"
  by simp

subsection "Syntax for the Exception Monad"

text \<open>
  Since the exception monad is a different type, we
  need to syntactically distinguish it in the syntax.
  We use \<open>doE\<close>/\<open>odE\<close> for this, but can re-use
  most of the productions from \<open>do\<close>/\<open>od\<close>
  above.
\<close>

syntax
  "_doE" :: "[dobinds, 'a] => 'a"  ("(doE ((_);//(_))//odE)" 100)

translations
  "_doE (_dobinds b bs) e"  == "_doE b (_doE bs e)"
  "_doE (_nobind b) e"      == "b >>=E (CONST K_bind e)"
  "doE x <- a; e odE"       == "a >>=E (\<lambda>x. e)"

text \<open>Syntax examples:\<close>
lemma "doE x \<leftarrow> returnOk 1;
           returnOk (2::nat);
           returnOk x
       odE =
       returnOk 1 >>=E
       (\<lambda>x. returnOk (2::nat) >>=E
            K_bind (returnOk x))"
  by (rule refl)

lemma "doE x \<leftarrow> returnOk 1;
           returnOk 2;
           returnOk x
       odE = returnOk 1"
  by simp



section "Library of Monadic Functions and Combinators"


text \<open>Lifting a normal function into the monad type:\<close>
definition
  liftM :: "('a \<Rightarrow> 'b) \<Rightarrow> ('s,'a) nondet_monad \<Rightarrow> ('s, 'b) nondet_monad"
where
  "liftM f m \<equiv> do x \<leftarrow> m; return (f x) od"

text \<open>The same for the exception monad:\<close>
definition
  liftME :: "('a \<Rightarrow> 'b) \<Rightarrow> ('s,'e+'a) nondet_monad \<Rightarrow> ('s,'e+'b) nondet_monad"
where
  "liftME f m \<equiv> doE x \<leftarrow> m; returnOk (f x) odE"

text \<open>
  Run a sequence of monads from left to right, ignoring return values.\<close>
definition
  sequence_x :: "('s, 'a) nondet_monad list \<Rightarrow> ('s, unit) nondet_monad"
where
  "sequence_x xs \<equiv> foldr (\<lambda>x y. x >>= (\<lambda>_. y)) xs (return ())"

text \<open>
  Map a monadic function over a list by applying it to each element
  of the list from left to right, ignoring return values.
\<close>
definition
  mapM_x :: "('a \<Rightarrow> ('s,'b) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, unit) nondet_monad"
where
  "mapM_x f xs \<equiv> sequence_x (map f xs)"

text \<open>
  Map a monadic function with two parameters over two lists,
  going through both lists simultaneously, left to right, ignoring
  return values.
\<close>
definition
  zipWithM_x :: "('a \<Rightarrow> 'b \<Rightarrow> ('s,'c) nondet_monad) \<Rightarrow>
                 'a list \<Rightarrow> 'b list \<Rightarrow> ('s, unit) nondet_monad"
where
  "zipWithM_x f xs ys \<equiv> sequence_x (zipWith f xs ys)"


text \<open>The same three functions as above, but returning a list of
return values instead of \<open>unit\<close>\<close>
definition
  sequence :: "('s, 'a) nondet_monad list \<Rightarrow> ('s, 'a list) nondet_monad"
where
  "sequence xs \<equiv> let mcons = (\<lambda>p q. p >>= (\<lambda>x. q >>= (\<lambda>y. return (x#y))))
                 in foldr mcons xs (return [])"

definition
  mapM :: "('a \<Rightarrow> ('s,'b) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'b list) nondet_monad"
where
  "mapM f xs \<equiv> sequence (map f xs)"

definition
  zipWithM :: "('a \<Rightarrow> 'b \<Rightarrow> ('s,'c) nondet_monad) \<Rightarrow>
                 'a list \<Rightarrow> 'b list \<Rightarrow> ('s, 'c list) nondet_monad"
where
  "zipWithM f xs ys \<equiv> sequence (zipWith f xs ys)"

text \<open>Monadic fold with initial accumulator \<open>a\<close>. Because it is built
  on @{const foldr}, the step for the \emph{last} list element is run
  first: \<open>foldM m [x1, x2] a = (return a >>= m x2) >>= m x1\<close>.\<close>
definition
  foldM :: "('b \<Rightarrow> 'a \<Rightarrow> ('s, 'a) nondet_monad) \<Rightarrow> 'b list \<Rightarrow> 'a \<Rightarrow> ('s, 'a) nondet_monad"
where
  "foldM m xs a \<equiv> foldr (\<lambda>p q. q >>= m p) xs (return a) "

text \<open>The sequence and map functions above for the exception monad,
with and without lists of return value\<close>
definition
  sequenceE_x :: "('s, 'e+'a) nondet_monad list \<Rightarrow> ('s, 'e+unit) nondet_monad"
where
  "sequenceE_x xs \<equiv> foldr (\<lambda>x y. doE _ <- x; y odE) xs (returnOk ())"

definition
  mapME_x :: "('a \<Rightarrow> ('s,'e+'b) nondet_monad) \<Rightarrow> 'a list \<Rightarrow>
              ('s,'e+unit) nondet_monad"
where
  "mapME_x f xs \<equiv> sequenceE_x (map f xs)"

definition
  sequenceE :: "('s, 'e+'a) nondet_monad list \<Rightarrow> ('s, 'e+'a list) nondet_monad"
where
  "sequenceE xs \<equiv> let mcons = (\<lambda>p q. p >>=E (\<lambda>x. q >>=E (\<lambda>y. returnOk (x#y))))
                 in foldr mcons xs (returnOk [])"

definition
  mapME :: "('a \<Rightarrow> ('s,'e+'b) nondet_monad) \<Rightarrow> 'a list \<Rightarrow>
              ('s,'e+'b list) nondet_monad"
where
  "mapME f xs \<equiv> sequenceE (map f xs)"


text \<open>Filtering a list using a monadic function as predicate:\<close>
primrec
  filterM :: "('a \<Rightarrow> ('s, bool) nondet_monad) \<Rightarrow> 'a list \<Rightarrow> ('s, 'a list) nondet_monad"
where
  "filterM P []       = return []"
| "filterM P (x # xs) = do
     b  <- P x;
     ys <- filterM P xs;
     return (if b then (x # ys) else ys)
   od"


section "Catching and Handling Exceptions"

text ‹
  Turning an exception monad into a normal state monad
  by catching and handling any potential exceptions:
›
(* f <catch> handler: run f.  A normal result Inr b is returned as b; an
   exception Inl e is passed on to handler e, whose result becomes the result
   of the whole computation. *)
definition
  catch :: "('s, 'e + 'a) nondet_monad 
            ('e  ('s, 'a) nondet_monad) 
            ('s, 'a) nondet_monad" (infix "<catch>" 10)
where
  "f <catch> handler 
     do x  f;
        case x of
          Inr b  return b
        | Inl e  handler e
     od"

text ‹
  Handling exceptions, but staying in the exception monad.
  The handler may throw a type of exceptions different from
  the left side.
›
(* f <handle2> handler: run f.  An exception Inl e is passed to handler e,
   which may throw exceptions of a different type 'e2; a normal result
   Inr v' is re-wrapped unchanged. *)
definition
  handleE' :: "('s, 'e1 + 'a) nondet_monad 
               ('e1  ('s, 'e2 + 'a) nondet_monad) 
               ('s, 'e2 + 'a) nondet_monad" (infix "<handle2>" 10)
where
  "f <handle2> handler 
   do
      v  f;
      case v of
        Inl e  handler e
      | Inr v'  return (Inr v')
   od"

text ‹
  A type restriction of the above that is used more commonly in
  practice: the exception handler (potentially) throws exceptions
  of the same type as the left-hand side.
›
(* handleE is handleE' with the handler's exception type fixed to the
   exception type 'x of the handled computation. *)
definition
  handleE :: "('s, 'x + 'a) nondet_monad  
              ('x  ('s, 'x + 'a) nondet_monad)  
              ('s, 'x + 'a) nondet_monad" (infix "<handle>" 10)
where
  "handleE  handleE'"


text ‹
  Handling exceptions, and additionally providing a continuation
  if the left-hand side throws no exception:
›
(* f <handle> handler <else> continue: run f.  An exception Inl e goes to
   handler e and a normal result Inr v' goes to continue v'.  Both branches
   produce a computation in the same result type 'ee + 'b. *)
definition
  handle_elseE :: "('s, 'e + 'a) nondet_monad 
                   ('e  ('s, 'ee + 'b) nondet_monad) 
                   ('a  ('s, 'ee + 'b) nondet_monad) 
                   ('s, 'ee + 'b) nondet_monad"
  ("_ <handle> _ <else> _" 10)
where
  "f <handle> handler <else> continue 
   do v  f;
   case v of Inl e   handler e
           | Inr v'  continue v'
   od"

subsection "Loops"

text ‹
  Loops are handled using the following inductive predicate;
  non-termination is represented using the failure flag of the
  monad.
›

(* whileLoop_results C B relates a start configuration to a final one.
   Some (r, s) is a loop configuration: current loop value r and state s.
   None as the final component means that the body failed along the way.
   Rule 1: if the condition C r s is false, the loop stops in (r, s).
   Rule 2: if C r s holds and the body's failure flag snd (B r s) is set,
           the outcome is None.
   Rule 3: if C r s holds and B r s can produce (r', s'), every outcome
           reachable from (r', s') is also reachable from (r, s). *)
inductive_set
  whileLoop_results :: "('r  's  bool)  ('r  ('s, 'r) nondet_monad)  ((('r × 's) option) × (('r × 's) option)) set"
  for C B
where
    " ¬ C r s   (Some (r, s), Some (r, s))  whileLoop_results C B"
  | " C r s; snd (B r s)   (Some (r, s), None)  whileLoop_results C B"
  | " C r s; (r', s')  fst (B r s); (Some (r', s'), z)  whileLoop_results C B  
        (Some (r, s), z)  whileLoop_results C B"

(* Case-analysis and simplification rules derived from the definition above.
   No introduction rule has None on the left, so start_fail is a [simp] rule
   that rules such pairs out. *)
inductive_cases whileLoop_results_cases_valid: "(Some x, Some y)  whileLoop_results C B"
inductive_cases whileLoop_results_cases_fail: "(Some x, None)  whileLoop_results C B"
inductive_simps whileLoop_results_simps: "(Some x, y)  whileLoop_results C B"
inductive_simps whileLoop_results_simps_valid: "(Some x, Some y)  whileLoop_results C B"
inductive_simps whileLoop_results_simps_start_fail [simp]: "(None, x)  whileLoop_results C B"

(* whileLoop_terminates C B r s: the loop started in (r, s) terminates on
   every path.  Either the condition is already false, or it holds and the
   loop terminates from every successor (r', s') in fst (B r s).  Because the
   predicate is inductive (least fixed point), it excludes infinite chains of
   iterations. *)
inductive
  whileLoop_terminates :: "('r  's  bool)  ('r  ('s, 'r) nondet_monad)  'r  's  bool"
  for C B
where
    "¬ C r s  whileLoop_terminates C B r s"
  | " C r s; (r', s')  fst (B r s). whileLoop_terminates C B r' s' 
         whileLoop_terminates C B r s"

inductive_cases whileLoop_terminates_cases: "whileLoop_terminates C B r s"
inductive_simps whileLoop_terminates_simps: "whileLoop_terminates C B r s"

(* whileLoop C B r s: the result set is every final configuration (r', s')
   reachable from (r, s) according to whileLoop_results.  The failure flag is
   set if the body can fail (the None outcome) or if the loop may not
   terminate; non-termination is reported as failure. *)
definition
  "whileLoop C B  (λr s.
     ({(r',s'). (Some (r, s), Some (r', s'))  whileLoop_results C B},
        (Some (r, s), None)  whileLoop_results C B  (¬ whileLoop_terminates C B r s)))"

(* Output-only (pretty-printing) syntax for whileLoop. *)
notation (output)
  whileLoop  ("(whileLoop (_)//  (_))" [1000, 1000] 1000)

(* whileLoopE C body r: exception-monad loop, encoded as a whileLoop over
   values of type 'e + 'r that starts at Inr r.  The loop condition is False
   for any Inl value, so an exception ends the loop.  The body is adapted
   with lift, which is defined elsewhere and presumably only runs body on
   Inr values. *)
definition
  whileLoopE :: "('r  's  bool)  ('r  ('s, 'e + 'r) nondet_monad)
       'r  's  (('e + 'r) × 's) set × bool"
where
  "whileLoopE C body 
      λr. whileLoop (λr s. (case r of Inr v  C v s | _  False)) (lift body) (Inr r)"

(* Output-only (pretty-printing) syntax for whileLoopE. *)
notation (output)
  whileLoopE  ("(whileLoopE (_)//  (_))" [1000, 1000] 1000)

section "Hoare Logic"

subsection "Validity"

text ‹This section defines a Hoare logic for partial correctness for
  the nondeterministic state monad as well as the exception monad.
  The logic talks only about the behaviour part of the monad and ignores
  the failure flag.

  The logic is defined semantically. Rules work directly on the
  validity predicate.

  In the nondeterministic state monad, validity is a triple of precondition,
  monad, and postcondition. The precondition is a function from state to 
  bool (a state predicate), the postcondition is a function from return value
  to state to bool. A triple is valid if for all states that satisfy the
  precondition, all result values and result states that are returned by
  the monad satisfy the postcondition. Note that if the computation returns
  the empty set, the triple is trivially valid. This means @{term "assert P"} 
  does not require us to prove that @{term P} holds, but rather allows us
  to assume @{term P}! Proving non-failure is done via a separate predicate
  and calculus (see below).
›
(* Hoare triple for the nondeterministic state monad: for every state s
   satisfying P, every (result, state) pair in fst (f s) satisfies Q.
   The failure flag snd (f s) is not consulted; see no_fail / validNF. *)
definition
  valid :: "('s  bool)  ('s,'a) nondet_monad  ('a  's  bool)  bool" 
  ("_/ _ /_")
where
  "P f Q  s. P s  ((r,s')  fst (f s). Q r s')"

text ‹
  Validity for the exception monad is similar and builds on the standard 
  validity above. Instead of one postcondition, we have two: one for
  normal and one for exceptional results.
›
(* Hoare triple for the exception monad, defined via valid by case-splitting
   the sum result: Q must hold for normal results (Inr r) and E for
   exceptional ones (Inl e). *)
definition
  validE :: "('s  bool)  ('s, 'a + 'b) nondet_monad  
             ('b  's  bool)  
             ('a  's  bool)  bool" 
("_/ _ /(_⦄,/ _)")
where
  "P f Q⦄,E  P f  λv s. case v of Inr r  Q r s | Inl e  E e s "


text ‹
  The following two instantiations are convenient to separate reasoning
  for exceptional and normal case.
›
(* validE_R: validE with a trivially true exceptional postcondition, so only
   the normal-result postcondition Q is constrained. *)
definition
  validE_R :: "('s  bool)  ('s, 'e + 'a) nondet_monad  
               ('a  's  bool)  bool"
   ("_/ _ /_⦄, -")
where
 "P f Q⦄,-  validE P f Q (λx y. True)"

(* validE_E: validE with a trivially true normal-result postcondition, so
   only the exceptional postcondition Q is constrained. *)
definition
  validE_E :: "('s  bool)   ('s, 'e + 'a) nondet_monad  
               ('e  's  bool)  bool"
   ("_/ _ /-, _")
where
 "P f -,Q  validE P f (λx y. True) Q"


text ‹Abbreviations for trivial preconditions:›
(* Input-only abbreviation for the everywhere-true state predicate. *)
abbreviation(input)
  top :: "'a  bool" ("")
where
  "  λ_. True"

(* Input-only abbreviation for the everywhere-false state predicate. *)
abbreviation(input)
  bottom :: "'a  bool" ("")
where
  "  λ_. False"

text ‹Abbreviations for trivial postconditions (taking two arguments):›
(* Two-argument versions (return value and state): always True / always False. *)
abbreviation(input)
  toptop :: "'a  'b  bool" ("⊤⊤")
where
 "⊤⊤  λ_ _. True"

abbreviation(input)
  botbot :: "'a  'b  bool" ("⊥⊥")
where
 "⊥⊥  λ_ _. False"

text ‹
  Lifting ‹∧› and ‹∨› over two arguments. 
  Lifting ‹∧› and ‹∨› over one argument is already
  defined (written ‹and› and ‹or›).
›
(* P And Q: pointwise conjunction of two-argument predicates. *)
definition
  bipred_conj :: "('a  'b  bool)  ('a  'b  bool)  ('a  'b  bool)" 
  (infixl "And" 96)
where
  "bipred_conj P Q  λx y. P x y  Q x y"

(* P Or Q: pointwise disjunction of two-argument predicates; it binds more
   loosely (priority 91) than And (priority 96). *)
definition
  bipred_disj :: "('a  'b  bool)  ('a  'b  bool)  ('a  'b  bool)" 
  (infixl "Or" 91)
where
  "bipred_disj P Q  λx y. P x y  Q x y"


subsection "Determinism"

text ‹A monad of type ‹nondet_monad› is deterministic iff it
returns exactly one state and result and does not fail› 
(* det f: for every start state, f returns exactly one result pair and the
   failure flag is False.
   NOTE(review): the type variables are written ('a,'s) here, whereas the
   rest of the file writes ('s,'a) with 's as the state.  Only the names
   differ; the first argument is still the state type. *)
definition
  det :: "('a,'s) nondet_monad  bool"
where
  "det f  s. r. f s = ({r},False)" 

text ‹A deterministic ‹nondet_monad› can be turned
  into a normal state monad:›
(* the_run_state M s: the unique (result, state) pair of fst (M s), selected
   with THE.  It is only meaningful when that set is a singleton (e.g. when
   det M holds); otherwise THE yields an unspecified value. *)
definition
  the_run_state :: "('s,'a) nondet_monad  's  'a × 's"
where
  "the_run_state M  λs. THE s'. fst (M s) = {s'}"


subsection "Non-Failure"

text ‹
  With the failure flag, we can formulate non-failure separately
  from validity. A monad ‹m› does not fail under precondition
  ‹P›, if for no start state in that precondition it sets
  the failure flag.
›
(* no_fail P m: from every state satisfying P, the failure flag snd (m s) of
   m is False. *)
definition
  no_fail :: "('s  bool)  ('s,'a) nondet_monad  bool"
where
  "no_fail P m  s. P s  ¬ (snd (m s))"


text ‹
  It is often desired to prove non-failure and a Hoare triple
  simultaneously, as the reasoning is often similar. The following
  definitions allow such reasoning to take place.
›

(* validNF: total-correctness-style triple, i.e. valid P f Q plus
   no_fail P f. *)
definition
  validNF ::"('s  bool)  ('s,'a) nondet_monad  ('a  's  bool)  bool"
      ("_/ _ /_⦄!")
where
  "validNF P f Q  valid P f Q  no_fail P f"

(* validE_NF: validE P f Q E plus no_fail P f. *)
definition
  validE_NF :: "('s  bool)  ('s, 'a + 'b) nondet_monad 
             ('b  's  bool) 
             ('a  's  bool)  bool"
  ("_/ _ /(_⦄,/ _⦄!)")
where
  "validE_NF P f Q E  validE P f Q E  no_fail P f"

(* validE_NF restated as a validNF triple whose postcondition case-splits the
   sum result. *)
lemma validE_NF_alt_def:
  " P  B  Q ⦄, E ⦄! =  P  B  λv s. case v of Inl e  E e s | Inr r  Q r s ⦄!"
  by (clarsimp simp: validE_NF_def validE_def validNF_def)

text ‹
  Usually, well-formed monads constructed from the primitives
  above will have the following property: if they return an
  empty set of results, they will have the failure flag set.
›
(* empty_fail m: whenever m returns an empty result set, it also sets the
   failure flag. *)
definition
  empty_fail :: "('s,'a) nondet_monad  bool" 
where
  "empty_fail m  s. fst (m s) = {}  snd (m s)"


text ‹
  Useful in forcing otherwise unknown executions to have
  the @{const empty_fail} property.
›
(* mk_ef S: keep the result set of S unchanged and additionally set the
   failure flag when that set is empty. *)
definition
  mk_ef :: "'a set × bool  'a set × bool"
where
  "mk_ef S  (fst S, fst S = {}  snd S)"

section "Basic exception reasoning"

text ‹
  The following predicates ‹no_throw› and ‹no_return› allow
  reasoning that functions in the exception monad either do
  not throw an exception or never return normally.
›

(* no_throw P A: from states satisfying P, A never produces an exceptional
   result (its exceptional postcondition is False). *)
definition "no_throw P A   P  A  λ_ _. True ⦄, λ_ _. False "

(* no_return P A: from states satisfying P, A never returns normally (its
   normal-result postcondition is False). *)
definition "no_return P A   P  A λ_ _. False⦄,λ_ _. True "

end

Theory NonDetMonadLemmas

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

theory NonDetMonadLemmas
imports NonDetMonad
begin

section "General Lemmas Regarding the Nondeterministic State Monad"

subsection "Congruence Rules for the Function Package"

(* Congruence rules for the function package (fundef_cong).  When a
   recursive definition goes through >>= (or >>=E), the continuation only has
   to agree on results the first computation can actually produce; for bindE,
   only on its normal results (Inr). *)
lemma bind_cong[fundef_cong]:
  " f = f'; v s s'. (v, s')  fst (f' s)  g v s' = g' v s'   f >>= g = f' >>= g'"
  apply (rule ext) 
  apply (auto simp: bind_def Let_def split_def intro: rev_image_eqI)
  done

(* State-applied variant of bind_cong: (f >>= g) s. *)
lemma bind_apply_cong [fundef_cong]:
  " f s = f' s'; rv st. (rv, st)  fst (f' s')  g rv st = g' rv st 
        (f >>= g) s = (f' >>= g') s'"
  apply (simp add: bind_def)
  apply (auto simp: split_def intro: SUP_cong [OF refl] intro: rev_image_eqI)
  done

(* bindE version, reduced to bind_cong after unfolding bindE and lift. *)
lemma bindE_cong[fundef_cong]:
  " M = M' ; v s s'. (Inr v, s')  fst (M' s)  N v s' = N' v s'   bindE M N = bindE M' N'"
  apply (simp add: bindE_def)
  apply (rule bind_cong)
   apply (rule refl)
  apply (unfold lift_def)
  apply (case_tac v, simp_all)
  done

(* State-applied bindE version, reduced to bind_apply_cong. *)
lemma bindE_apply_cong[fundef_cong]:
  " f s = f' s'; rv st. (Inr rv, st)  fst (f' s')  g rv st = g' rv st  
   (f >>=E g) s = (f' >>=E g') s'"
  apply (simp add: bindE_def)
  apply (rule bind_apply_cong)
   apply assumption
  apply (case_tac rv, simp_all add: lift_def)
  done

(* K_bind ignores its second argument, so only the first must agree. *)
lemma K_bind_apply_cong[fundef_cong]:
  " f st = f' st'   K_bind f arg st = K_bind f' arg' st'"
  by simp

(* Congruence rules for the function package (fundef_cong) covering the
   state-applied forms of when and unless.  Inside `when C m`, the body m only
   has to agree when C' holds; inside `unless C m`, only when ¬ C' holds.
   NOTE: these lemmas previously restated whenE/unlessE, making them verbatim
   duplicates of whenE_apply_cong/unlessE_apply_cong and leaving when/unless
   without the congruence rule their names promise. *)
lemma when_apply_cong[fundef_cong]:
  "⟦ C = C'; s = s'; C' ⟹ m s' = m' s' ⟧ ⟹ when C m s = when C' m' s'"
  by (simp add: when_def)

lemma unless_apply_cong[fundef_cong]:
  "⟦ C = C'; s = s'; ¬ C' ⟹ m s' = m' s' ⟧ ⟹ unless C m s = unless C' m' s'"
  by (simp add: unless_def when_def)

(* Congruence rules for the function package covering the state-applied forms
   of whenE / unlessE.  The guarded body only has to agree when the guard
   (resp. its negation) holds. *)
lemma whenE_apply_cong[fundef_cong]:
  " C = C'; s = s'; C'  m s' = m' s'   whenE C m s = whenE C' m' s'"
  by (simp add: whenE_def)

lemma unlessE_apply_cong[fundef_cong]:
  " C = C'; s = s'; ¬ C'  m s' = m' s'   unlessE C m s = unlessE C' m' s'"
  by (simp add: unlessE_def)

subsection "Simplifying Monads"

(* [simp] rule: a bind whose first computation ends in a return can pass the
   returned value straight to the outer continuation. *)
lemma nested_bind [simp]:
  "do x <- do y <- f; return (g y) od; h x od =
   do y <- f; h (g y) od"
  apply (clarsimp simp add: bind_def)
  apply (rule ext)
  apply (clarsimp simp add: Let_def split_def return_def)
  done

(* [simp] rules for binds whose first computation is known: fail absorbs the
   continuation, assert/assertE reduce to fail or to the continuation, and
   when/whenE with a literal guard either skip the guarded computation or
   keep it. *)
lemma fail_bind [simp]:
  "fail >>= f = fail"
  by (simp add: bind_def fail_def)

lemma fail_bindE [simp]:
  "fail >>=E f = fail"
  by (simp add: bindE_def bind_def fail_def)

lemma assert_False [simp]:
  "assert False >>= f = fail"
  by (simp add: assert_def)

lemma assert_True [simp]:
  "assert True >>= f = f ()"
  by (simp add: assert_def)

lemma assertE_False [simp]:
  "assertE False >>=E f = fail"
  by (simp add: assertE_def)

lemma assertE_True [simp]:
  "assertE True >>=E f = f ()"
  by (simp add: assertE_def)

lemma when_False_bind [simp]:
  "when False g >>= f = f ()"
  by (rule ext) (simp add: when_def bind_def return_def)

lemma when_True_bind [simp]:
  "when True g >>= f = g >>= f"
  by (simp add: when_def bind_def return_def)

lemma whenE_False_bind [simp]:
  "whenE False g >>=E f = f ()"
  by (simp add: whenE_def bindE_def returnOk_def lift_def)

lemma whenE_True_bind [simp]:
  "whenE True g >>=E f = g >>=E f"
  by (simp add: whenE_def bindE_def returnOk_def lift_def)

(* when/unless with a literal guard either run the computation or return (). *)
lemma when_True [simp]: "when True X = X"
  by (clarsimp simp: when_def)

lemma when_False [simp]: "when False X = return ()"
  by (clarsimp simp: when_def)

lemma unless_False [simp]: "unless False X = X"
  by (clarsimp simp: unless_def)

lemma unless_True [simp]: "unless True X = return ()"
  by (clarsimp simp: unless_def)

(* unless/unlessE expressed through when/whenE with a negated guard. *)
lemma unlessE_whenE:
  "unlessE P = whenE (~P)"
  by (rule ext)+ (simp add: unlessE_def whenE_def)

lemma unless_when:
  "unless P = when (~P)"
  by (rule ext)+ (simp add: unless_def when_def)

(* gets of a constant function is just return. *)
lemma gets_to_return [simp]: "gets (λs. v) = return v"
  by (clarsimp simp: gets_def put_def get_def bind_def return_def)

lemma assert_opt_Some:
  "assert_opt (Some x) = return x"
  by (simp add: assert_opt_def)

lemma assertE_liftE:
  "assertE P = liftE (assert P)"
  by (simp add: assertE_def assert_def liftE_def returnOk_def)

(* A lifted (non-throwing) computation is unaffected by an exception handler. *)
lemma liftE_handleE' [simp]: "((liftE a) <handle2> b) = liftE a"
  apply (clarsimp simp: liftE_def handleE'_def)
  done

lemma liftE_handleE [simp]: "((liftE a) <handle> b) = liftE a"
  apply (unfold handleE_def)
  apply simp
  done

(* Split rules for `condition C a b s`: case-split on C s.  They are collected
   as condition_splits for use with `split:`. *)
lemma condition_split:
  "P (condition C a b s) = ((((C s)  P (a s))  (¬ (C s)  P (b s))))"
  apply (clarsimp simp: condition_def)
  done

lemma condition_split_asm:
  "P (condition C a b s) = (¬ (C s  ¬ P (a s)  ¬ C s  ¬ P (b s)))"
  apply (clarsimp simp: condition_def)
  done

lemmas condition_splits = condition_split condition_split_asm

(* A condition with a constant guard reduces to one branch. *)
lemma condition_true_triv [simp]:
  "condition (λ_. True) A B = A"
  apply (rule ext)
  apply (clarsimp split: condition_splits)
  done

lemma condition_false_triv [simp]:
  "condition (λ_. False) A B = B"
  apply (rule ext)
  apply (clarsimp split: condition_splits)
  done

(* Select a branch given the truth value of the guard in state s. *)
lemma condition_true: " P s   condition P A B s = A s"
  apply (clarsimp simp: condition_def)
  done

lemma condition_false: " ¬ P s   condition P A B s = B s"
  apply (clarsimp simp: condition_def)
  done

section "Low-level monadic reasoning"

(* Extensionality for nondet_monad: two monads are equal if, in every state,
   they return the same result pairs (each inclusion given separately) and
   have the same failure flag. *)
lemma monad_eqI [intro]:
  " r t s. (r, t)  fst (A s)  (r, t)  fst (B s);
     r t s. (r, t)  fst (B s)  (r, t)  fst (A s);
     x. snd (A x) = snd (B x) 
   (A :: ('s, 'a) nondet_monad) = B"
  apply (fastforce intro!: set_eqI prod_eqI)
  done

(* Pointwise version: A applied to s equals B applied to s'. *)
lemma monad_state_eqI [intro]:
  " r t. (r, t)  fst (A s)  (r, t)  fst (B s');
     r t. (r, t)  fst (B s')  (r, t)  fst (A s);
     snd (A s) = snd (B s') 
   (A :: ('s, 'a) nondet_monad) s = B s'"
  apply (fastforce intro!: set_eqI prod_eqI)
  done

subsection "General whileLoop reasoning"

(* Termination predicate for whileLoopE, using the same encoding as
   whileLoopE (loop over 'e + 'r starting at Inr r; condition False on Inl). *)
definition
  "whileLoop_terminatesE C B  (λr.
     whileLoop_terminates (λr s. case r of Inr v  C v s | _  False) (lift B) (Inr r))"

(* If the loop condition is false initially, the loop behaves like return. *)
lemma whileLoop_cond_fail:
    " ¬ C x s   (whileLoop C B x s) = (return x s)"
  apply (auto simp: return_def whileLoop_def
       intro: whileLoop_results.intros
              whileLoop_terminates.intros
       elim!: whileLoop_results.cases)
  done

(* Exception-monad analogue: a loop whose condition is initially false
   behaves like returnOk. *)
lemma whileLoopE_cond_fail:
    " ¬ C x s   (whileLoopE C B x s) = (returnOk x s)"
  apply (clarsimp simp: whileLoopE_def returnOk_def)
  apply (auto intro: whileLoop_cond_fail)
  done

(* A configuration is related to itself by whileLoop_results exactly when
   the loop condition is false there. *)
lemma whileLoop_results_simps_no_move [simp]:
  shows "((Some x, Some x)  whileLoop_results C B) = (¬ C (fst x) (snd x))"
    (is "?LHS x = ?RHS x")
proof (rule iffI)
  (* Forward direction: restate the goal so that induction over
     whileLoop_results applies, then specialise. *)
  assume "?LHS x"
  then have "(a. Some x = Some a)  ?RHS (the (Some x))"
   by (induct rule: whileLoop_results.induct, auto)
  thus "?RHS x"
    by clarsimp
next
  (* Backward direction: the first introduction rule of whileLoop_results. *)
  assume "?RHS x"
  thus "?LHS x"
    by (metis surjective_pairing whileLoop_results.intros(1))
qed

(* One-step unfolding of whileLoop: if the condition holds, run the body once
   and continue with the loop; otherwise return the current value. *)
lemma whileLoop_unroll:
  "(whileLoop C B r) =  ((condition (C r) (B r >>= (whileLoop C B)) (return r)))"
  (is "?LHS r = ?RHS r")
proof -
  (* Condition false: both sides behave like return (whileLoop_cond_fail). *)
  have cond_fail: "r s. ¬ C r s  ?LHS r s = ?RHS r s"
    apply (subst whileLoop_cond_fail, simp)
    apply (clarsimp simp: condition_def bind_def return_def)
    done

  (* Condition true: compare result sets and failure flags separately
     (monad_state_eqI), unfolding one step of the inductive definitions. *)
  have cond_pass: "r s. C r s  whileLoop C B r s = (B r >>= (whileLoop C B)) s"
    apply (rule monad_state_eqI)
      apply (clarsimp simp: whileLoop_def bind_def split_def)
      apply (subst (asm) whileLoop_results_simps_valid)
      apply fastforce
     apply (clarsimp simp: whileLoop_def bind_def split_def)
     apply (subst whileLoop_results.simps)
     apply fastforce
    apply (clarsimp simp: whileLoop_def bind_def split_def)
    apply (subst whileLoop_results.simps)
    apply (subst whileLoop_terminates.simps)
    apply fastforce
    done

  show ?thesis
    apply (rule ext)
    apply (metis cond_fail cond_pass condition_def)
    done
qed

(* Alternative unfolding: the conditional step comes first and the loop is
   always bound afterwards.  When the condition is false, the re-entered loop
   returns immediately (whileLoop_cond_fail). *)
lemma whileLoop_unroll':
    "(whileLoop C B r) = ((condition (C r) (B r) (return r)) >>= (whileLoop C B))"
  apply (rule ext)
  apply (subst whileLoop_unroll)
  apply (clarsimp simp: condition_def bind_def return_def split_def)
  apply (subst whileLoop_cond_fail, simp)
  apply (clarsimp simp: return_def)
  done

(* One-step unfolding of whileLoopE, the exception-monad analogue of
   whileLoop_unroll.  The proof unfolds whileLoopE into whileLoop and
   handles the Inl case (exception), where the loop stops. *)
lemma whileLoopE_unroll:
  "(whileLoopE C B r) =  ((condition (C r) (B r >>=E (whileLoopE C B)) (returnOk r)))"
  apply (rule ext)
  apply (unfold whileLoopE_def)
  apply (subst whileLoop_unroll)
  apply (clarsimp simp: whileLoopE_def bindE_def returnOk_def split: condition_splits)
  apply (clarsimp simp: lift_def)
  apply (rename_tac x, rule_tac f="λa. (B r >>= a) x" in arg_cong)
  apply (rule ext)+
  apply (clarsimp simp: lift_def split: sum.splits)
  apply (subst whileLoop_unroll)
  apply (subst condition_false)
   apply metis
  apply (clarsimp simp: throwError_def)
  done

(* Exception-monad analogue of whileLoop_unroll'. *)
lemma whileLoopE_unroll':
  "(whileLoopE C B r) =  ((condition (C r) (B r) (returnOk r)) >>=E (whileLoopE C B))"
  apply (rule ext)
  apply (subst whileLoopE_unroll)
  apply (clarsimp simp: condition_def bindE_def bind_def returnOk_def return_def lift_def split_def)
  apply (subst whileLoopE_cond_fail, simp)
  apply (clarsimp simp: returnOk_def return_def)
  done

(* These lemmas are useful to apply to rules to convert valid rules into
 * a format suitable for wp.
 * Each one turns a family of triples indexed by an auxiliary parameter s0
 * into a single triple with an arbitrary postcondition Q' (and E'), under a
 * precondition requiring some s0 whose postconditions imply Q' (and E'). *)

lemma valid_make_schematic_post:
  "(s0.  λs. P s0 s  f  λrv s. Q s0 rv s ) 
     λs. s0. P s0 s  (rv s'. Q s0 rv s'  Q' rv s')  f  Q' "
  by (auto simp add: valid_def no_fail_def split: prod.splits)

lemma validNF_make_schematic_post:
  "(s0.  λs. P s0 s  f  λrv s. Q s0 rv s ⦄!) 
     λs. s0. P s0 s  (rv s'. Q s0 rv s'  Q' rv s')  f  Q' ⦄!"
  by (auto simp add: valid_def validNF_def no_fail_def split: prod.splits)

lemma validE_make_schematic_post:
  "(s0.  λs. P s0 s  f  λrv s. Q s0 rv s ⦄,  λrv s. E s0 rv s ) 
     λs. s0. P s0 s  (rv s'. Q s0 rv s'  Q' rv s')
         (rv s'. E s0 rv s'  E' rv s')  f  Q' ⦄,  E' "
  by (auto simp add: validE_def valid_def no_fail_def split: prod.splits sum.splits)

lemma validE_NF_make_schematic_post:
  "(s0.  λs. P s0 s  f  λrv s. Q s0 rv s ⦄,  λrv s. E s0 rv s ⦄!) 
     λs. s0. P s0 s  (rv s'. Q s0 rv s'  Q' rv s')
         (rv s'. E s0 rv s'  E' rv s')  f  Q' ⦄,  E' ⦄!"
  by (auto simp add: validE_NF_def validE_def valid_def no_fail_def split: prod.splits sum.splits)

(* Drop one conjunct from the postcondition of a validNF triple. *)
lemma validNF_conjD1: " P  f  λrv s. Q rv s  Q' rv s ⦄!   P  f  Q ⦄!"
  by (fastforce simp: validNF_def valid_def no_fail_def)

lemma validNF_conjD2: " P  f  λrv s. Q rv s  Q' rv s ⦄!   P  f  Q' ⦄!"
  by (fastforce simp: validNF_def valid_def no_fail_def)

end

Theory OptionMonadND

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(* Option monad syntax plus the connection between the option monad and the nondet monad *)

theory OptionMonadND
imports
  OptionMonad
  "wp/NonDetMonadLemmas"
begin

(* FIXME: better concrete syntax? *)
(* Syntax defined here so we can reuse NonDetMonad definitions *)
(* DO ... OD block notation for the option (lookup) monad.  It reuses the
   dobinds grammar of do ... od.  A binding x <- a translates to
   a |>> (λx. e), and a statement without a binder is sequenced with
   K_bind. *)
syntax
  "_doO" :: "[dobinds, 'a] => 'a"  ("(DO (_);//   (_)//OD)" 100)

translations
  "_doO (_dobinds b bs) e" == "_doO b (_doO bs e)"
  "_doO (_nobind b) e"     == "b |>> (CONST K_bind e)"
  "DO x <- a; e OD"        == "a |>> (λx. e)"


(* ogets f: lookup that always succeeds with f applied to the state. *)
definition
 ogets :: "('a  'b)  ('a  'b option)"
where
 "ogets f  (λs. Some (f s))"

(* f <ocatch> handler: option-monad analogue of <catch>.  An Inr b result is
   returned as b, and an Inl e result is passed to handler e. *)
definition
  ocatch :: "('s,('e + 'a)) lookup  ('e  ('s,'a) lookup)  ('s, 'a) lookup"
  (infix "<ocatch>" 10)
where
  "f <ocatch> handler 
     DO x  f;
        case x of
          Inr b  oreturn b
        | Inl e  handler e
     OD"


(* odrop f: keep normal results (Inr b) and turn any exception (Inl e) into
   failure (ofail). *)
definition
  odrop :: "('s, 'e + 'a) lookup  ('s, 'a) lookup"
where
  "odrop f 
     DO x  f;
        case x of
          Inr b  oreturn b
        | Inl e  ofail
     OD"

(* osequence_x xs: run the lookups of xs in list order, discarding their
   results. *)
definition
  osequence_x :: "('s, 'a) lookup list  ('s, unit) lookup"
where
  "osequence_x xs  foldr (λx y. DO _ <- x; y OD) xs (oreturn ())"

(* osequence xs: run the lookups of xs in list order and collect their
   results in a list. *)
definition
  osequence :: "('s, 'a) lookup list  ('s, 'a list) lookup"
where
  "osequence xs  let mcons = (λp q. p |>> (λx. q |>> (λy. oreturn (x#y))))
                 in foldr mcons xs (oreturn [])"

(* omap f xs: option-monad analogue of mapM. *)
definition
  omap :: "('a  ('s,'b) lookup)  'a list  ('s, 'b list) lookup"
where
  "omap f xs  osequence (map f xs)"

(* x o# xs: prepend x' to xs when x = Some x'; leave xs unchanged when
   x = None. *)
definition
  opt_cons :: "'a option  'a list  'a list" (infixr "o#" 65)
where
  "opt_cons x xs  case x of None  xs | Some x'  x' # xs"

(* Unfolding rules used by the proofs below to compute gets_the results. *)
lemmas monad_simps =
  gets_the_def bind_def assert_def assert_opt_def 
  simpler_gets_def fail_def return_def

(* The following lemmas translate option-monad combinators under gets_the
   into their nondet_monad counterparts. *)
lemma gets_the_opt_map:
  "gets_the (f |> g) = do x  gets_the f; assert_opt (g x) od"
  by (rule ext) (simp add: monad_simps opt_map_def split: option.splits)

lemma gets_the_opt_o:
  "gets_the (f |> Some o g) = do x  gets_the f; return (g x) od"
  by (simp add: gets_the_opt_map assert_opt_Some)

lemma gets_the_obind:
  "gets_the (f |>> g) = gets_the f >>= (λx. gets_the (g x))"
  by (rule ext) (simp add: monad_simps obind_def split: option.splits)

lemma gets_the_return:
  "gets_the (oreturn x) = return x"
  by (simp add: monad_simps oreturn_def K_def)

lemma gets_the_fail:
  "gets_the ofail = fail"
  by (simp add: monad_simps ofail_def K_def)

lemma gets_the_returnOk:
  "gets_the (oreturnOk x) = returnOk x"
  by (simp add: monad_simps K_def oreturnOk_def returnOk_def)

lemma gets_the_throwError:
  "gets_the (othrow e) = throwError e"
  by (simp add: monad_simps othrow_def throwError_def K_def)

lemma gets_the_assert:
  "gets_the (oassert P) = assert P"
  by (simp add: oassert_def assert_def gets_the_fail gets_the_return)

(* Translation rules above enabled as simp rules (gets_the_opt_o excluded). *)
lemmas omonad_simps [simp] =
  gets_the_opt_map assert_opt_Some gets_the_obind
  gets_the_return gets_the_fail gets_the_returnOk
  gets_the_throwError gets_the_assert



section "Relation between option monad loops and non-deterministic monad loops."

(* Option monad whileLoop formalisation thanks to Lars Noschinski <noschinl@in.tum.de>. *)

(* gets_the B s in closed form: if B s = Some r', it returns the single
   result (r', s) without failure; otherwise it returns the empty set with
   the failure flag set. *)
lemma gets_the_conv:
  "(gets_the B s) = (case B s of Some r'  ({(r', s)}, False) | _  ({}, True))"
  by (auto simp: gets_the_def gets_def get_def bind_def return_def fail_def assert_opt_def split: option.splits)

(* Termination of whileLoop with a gets_the body corresponds to the
   option-monad loop relation option_while' reaching some outcome.  Each
   direction is proved by induction on the respective relation. *)
lemma gets_the_loop_terminates:
  "whileLoop_terminates C (λa. gets_the (B a)) r s
     (rs'. (Some r, rs')  option_while' (λa. C a s) (λa. B a s))" (is "?L  ?R")
proof
  assume ?L then show ?R
  proof (induct rule: whileLoop_terminates.induct[case_names 1 2])
    case (2 r s) then show ?case
      by (cases "B r s") (auto simp: gets_the_conv intro: option_while'.intros)
  qed (auto intro: option_while'.intros)
next
  assume ?R then show ?L
  proof (elim exE)
    fix rs' assume "(Some r, rs')  option_while' (λa. C a s) (λa. B a s)"
    then have "whileLoop_terminates C (λa. gets_the (B a)) (the (Some r)) s"
      by induct (auto intro: whileLoop_terminates.intros simp: gets_the_conv)
    then show ?thesis by simp
  qed
qed

(* A nondet whileLoop whose body is a lifted lookup equals the lifted
   option-monad loop owhile.  The proof first relates whileLoop_results and
   option_while' in both directions (for normal and failing outcomes); the
   state never changes because gets_the does not modify it.  It then handles
   the None and Some cases of owhile separately. *)
lemma gets_the_whileLoop:
  fixes C :: "'a  's  bool"
  shows "whileLoop C (λa. gets_the (B a)) r = gets_the (owhile C B r)"
proof -
  { fix r s r' s' assume "(Some (r,s), Some (r', s'))  whileLoop_results C (λa. gets_the (B a))"
    then have "s = s'  (Some r, Some r')  option_while' (λa. C a s) (λa. B a s)"
    by (induct "Some (r, s)" "Some (r', s')" arbitrary: r s)
       (auto intro: option_while'.intros simp: gets_the_conv split: option.splits) }
  note wl'_Inl = this

  { fix r s assume "(Some (r,s), None)  whileLoop_results C (λa. gets_the (B a))"
    then have "(Some r, None)  option_while' (λa. C a s) (λa. B a s)"
      by (induct "Some (r, s)" "None :: (('a × 's) option)" arbitrary: r s)
         (auto intro: option_while'.intros simp: gets_the_conv split: option.splits) }
  note wl'_Inr = this

  { fix r s r' assume "(Some r, Some r')  option_while' (λa. C a s) (λa. B a s)"
    then have "(Some (r,s), Some (r',s))  whileLoop_results C (λa. gets_the (B a))"
    by (induct "Some r" "Some r'" arbitrary: r)
       (auto intro: whileLoop_results.intros simp: gets_the_conv) }
  note option_while'_Some = this

  { fix r s assume "(Some r, None)  option_while' (λa. C a s) (λa. B a s)"
    then have "(Some (r,s), None)  whileLoop_results C (λa. gets_the (B a))"
    by (induct "Some r" "None :: 'a option" arbitrary: r)
       (auto intro: whileLoop_results.intros simp: gets_the_conv) }
  note option_while'_None = this

  (* owhile fails: the nondet loop has no results and sets the failure flag. *)
  have "s. owhile C B r s = None
       whileLoop C (λa. gets_the (B a)) r s = ({}, True)"
    by (auto simp: whileLoop_def owhile_def option_while_def option_while'_THE gets_the_loop_terminates
      split: if_split_asm dest: option_while'_None wl'_Inl option_while'_inj)
  moreover
  (* owhile succeeds with r': the nondet loop returns exactly (r', s). *)
  have "s r'. owhile C B r s = Some r'
       whileLoop C (λa. gets_the (B a)) r s = ({(r', s)}, False)"
    by (auto simp: whileLoop_def owhile_def option_while_def option_while'_THE gets_the_loop_terminates
      split: if_split_asm dest: wl'_Inl wl'_Inr option_while'_inj intro: option_while'_Some)
  ultimately
  show ?thesis
    by (auto simp: fun_eq_iff gets_the_conv split: option.split)
qed

end

Theory WP

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

theory WP
imports Main
begin

(* triple_judgement pre body property: property s body holds for every state
   s satisfying pre.  Conclusions of this shape are what WP-method.ML's
   get_key recognises in order to index wp rules by `body`. *)
definition
  triple_judgement :: "('a  bool)  'b  ('a  'b  bool)  bool"
where
 "triple_judgement pre body property = (s. pre s  property s body)"

(* postcondition P f: the predicate on (a, b) stating that every (rv, s) in
   the result set f a b satisfies P rv s. *)
definition
  postcondition :: "('r  's  bool)  ('a  'b  ('r × 's) set)
             'a  'b  bool"
where
 "postcondition P f = (λa b. (rv, s)  f a b. P rv s)"

(* postconditions P Q: pointwise conjunction of two-argument predicates. *)
definition
  postconditions :: "('a  'b  bool)  ('a  'b  bool)  ('a  'b  bool)"
where
 "postconditions P Q = (λa b. P a b  Q a b)"

(* Load the ML implementation of the wp tool (structure WeakestPre). *)
ML_file ‹WP-method.ML›

(* Configuration option, switched off here; presumably controls warnings
   about unused wp rules. *)
declare [[wp_warn_unused = false]]

setup WeakestPre.setup

(* The proof methods.  The boolean argument selects tracing: the *_trace
   variants pass true. *)
method_setup wp = WeakestPre.apply_rules_args false›
  "applies weakest precondition rules"

method_setup wp_once = WeakestPre.apply_once_args false›
  "applies one weakest precondition rule"

method_setup wp_trace = WeakestPre.apply_rules_args true›
  "applies weakest precondition rules with tracing"

method_setup wp_once_trace = WeakestPre.apply_once_args true›
  "applies one weakest precondition rule with tracing"

end

File ‹WP-method.ML›

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(* Interface of the weakest-precondition (wp) proof-method machinery.  The
   wp / wp_once / wp_trace / wp_once_trace methods in WP.thy are built from
   WeakestPre.apply_rules_args and WeakestPre.apply_once_args. *)
signature WP =
sig
  (* Rule database.
     trips: equations used to rewrite rule conclusions into triple_judgement
            form, with a conversion presumably built from them by
            mk_trip_conv.
     rules: (net of indexed rules keyed by the program term, next index,
            list of indexed rules without a key).
     The remaining fields are plain theorem lists. *)
  type wp_rules = {trips: thm list * (theory -> term -> term),
    rules: (int * thm) Net.net * int * (int * thm) list,
    splits: thm list, combs: thm list, unsafe_rules: thm list};

  (* Presumably exposes the rule database of a context for debugging. *)
  val debug_get: Proof.context -> wp_rules;

  (* Rule combination with combinator theorems; implementation not visible
     in this excerpt. *)
  val derived_rule: thm -> thm -> thm list;
  val get_combined_rules': thm list -> thm -> thm list;
  val get_combined_rules: thm list -> thm list -> thm list;

  (* Presumably the context's rule database extended with extra theorems. *)
  val get_rules: Proof.context -> thm list -> wp_rules;

  (* Tactics and method parsers.  The leading bool is the tracing flag (see
     the method_setup calls in WP.thy).  The thm list ref presumably collects
     the rules that were used; confirm against the implementation. *)
  val apply_rules_tac_n: bool -> Proof.context -> thm list -> thm list Unsynchronized.ref
                      -> int -> tactic;
  val apply_rules_tac: bool -> Proof.context -> thm list -> thm list Unsynchronized.ref
                    -> tactic;
  val apply_rules_args: bool -> (Proof.context -> Method.method) context_parser;

  val apply_once_tac: bool -> Proof.context -> thm list -> thm list Unsynchronized.ref
                   -> tactic;
  val apply_once_args: bool -> (Proof.context -> Method.method) context_parser;

  (* Theory setup, and the configuration option presumably registered as
     wp_warn_unused (set in WP.thy). *)
  val setup: theory -> theory;
  val warn_unused: bool Config.T

  (* Attributes adding/removing wp rules, split rules, combinators and unsafe
     rules. *)
  val wp_add: Thm.attribute;
  val wp_del: Thm.attribute;
  val splits_add: Thm.attribute;
  val splits_del: Thm.attribute;
  val combs_add: Thm.attribute;
  val combs_del: Thm.attribute;
  val wp_unsafe_add: Thm.attribute;
  val wp_unsafe_del: Thm.attribute;
end;

(* Implementation of the wp tool: a rule database stored as Generic_Data,
   attributes to maintain it, and tactics that resolve subgoals against it. *)
structure WeakestPre =
struct

(* Same record as in signature WP; see the field descriptions there. *)
type wp_rules = {trips: thm list * (theory -> term -> term),
    rules: (int * thm) Net.net * int * (int * thm) list,
    splits: thm list, combs: thm list, unsafe_rules: thm list};

(* Keeps, for each theorem proposition, only its LAST occurrence in the
   input list. The kept element is paired with the list of values seen for that
   proposition up to and including that position (collected via tt1/tt2).
   The returned table (unit-valued) marks propositions already kept further
   right; it is internal bookkeeping. *)
fun accum_last_occurence' [] _ = ([], Termtab.empty)
  | accum_last_occurence' ((t, v) :: ts) tt1 = let
      val tm = Thm.prop_of t;
      val tt2 = Termtab.insert_list (K false) (tm, v) tt1;
      val (ts', tt3) = accum_last_occurence' ts tt2;
  in case Termtab.lookup tt3 tm of
        NONE => ((t, Termtab.lookup_list tt2 tm)  :: ts',
                    Termtab.update (tm, ()) tt3)
      | SOME _ => (ts', tt3)
  end;

(* Wrapper starting from an empty table and dropping the bookkeeping table. *)
fun accum_last_occurence ts =
        fst (accum_last_occurence' ts Termtab.empty);

(* Removes duplicate theorems (same proposition) from a list, keeping the
   last occurrence of each. *)
fun flat_last_occurence ts =
  map fst (accum_last_occurence (map (fn v => (v, ())) ts));

(* Flattens the rules triple into a plain theorem list, ordered by
   descending index (order_list sorts ascending, then rev). *)
fun dest_rules (trips, _, others) =
  rev (order_list (Net.entries trips @ others));

(* Computes the net key for a rule: its conclusion is rewritten with
   trip_conv and beta-eta-contracted; if it then has the form
   Trueprop (triple_judgement _ f _), the key is f, otherwise NONE. *)
fun get_key trip_conv t = let
    val t' = Thm.concl_of t |> trip_conv (Thm.theory_of_thm t)
        |> Envir.beta_eta_contract;
  in case t' of Const (@{const_name Trueprop}, _) $
      (Const (@{const_name triple_judgement}, _) $ _ $ f $ _) => SOME f
    | _ => NONE end;

(* Adds rule t with index n: keyed rules go into the net, unkeyed ones into
   the others list. The index counter is incremented in both cases. *)
fun add_rule_inner trip_conv t (trips, n, others) = (
  case get_key trip_conv t of
      SOME k => (Net.insert_term (K false)
                 (k, (n, t)) trips, n + 1, others)
    | _ => (trips, n + 1, (n, t) :: others)
  );

(* Removes rule t (compared by proposition, index ignored) from the net or
   the others list. The index counter is left unchanged. *)
fun del_rule_inner trip_conv t (trips, n, others) =
    case get_key trip_conv t of
      SOME k => (Net.delete_term_safe (Thm.eq_thm_prop o apply2 snd)
                 (k, (n, t)) trips, n, others)
    | _ => (trips, n, remove (Thm.eq_thm_prop o apply2 snd) (n, t) others)

(* Empty rules triple: empty net, next index 0, no unkeyed rules. *)
val no_rules = (Net.empty, 0, []);

(* Builds a rules triple from a list. fold_rev inserts the last list element
   first, so earlier list elements receive higher indices. *)
fun mk_rules trip_conv rules = fold_rev (add_rule_inner trip_conv) rules no_rules;

(* Builds the triple-conversion function: rewriting with the equations that
   form the conclusions of the trip theorems. *)
fun mk_trip_conv trips thy = Pattern.rewrite_term thy
    (map (Thm.concl_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq) trips) []

(* Merges two databases. Trips, splits, combs and unsafe rules are merged as
   theorem lists; the wp rules are concatenated, deduplicated by proposition,
   and re-keyed under the merged triple conversion. *)
fun rules_merge (wp_rules, wp_rules') = let
    val trips = Thm.merge_thms (fst (#trips wp_rules), fst (#trips wp_rules'));
    val trip_conv = mk_trip_conv trips
    val rules = flat_last_occurence (dest_rules (#rules wp_rules) @ dest_rules (#rules wp_rules'));
  in {trips = (trips, trip_conv),
        rules = mk_rules trip_conv rules,
        splits = Thm.merge_thms (#splits wp_rules, #splits wp_rules'),
        combs = Thm.merge_thms (#combs wp_rules, #combs wp_rules'),
        unsafe_rules = Thm.merge_thms (#unsafe_rules wp_rules, #unsafe_rules wp_rules')} end

(* The database as generic context data. The empty triple conversion is the
   identity (K I). *)
structure WPData = Generic_Data
(struct
    type T = wp_rules;
    val empty = {trips = ([], K I), rules = no_rules,
      splits = [], combs = [], unsafe_rules = []};
    val extend = I;

    val merge = rules_merge;
end);

(* Resolves rule into the first premise of combinator; [] if that fails. *)
fun derived_rule rule combinator =
  [rule RSN (1, combinator)] handle THM _ => [];

(* The rule itself followed by every successful combination with combs'. *)
fun get_combined_rules' combs' rule =
  rule :: (List.concat (map (derived_rule rule) combs'));

(* get_combined_rules' applied to every rule, results concatenated. *)
fun get_combined_rules rules' combs' =
  List.concat (map (get_combined_rules' combs') rules');

(* Functional updates of single database components. add_rule and del_rule
   key the rule using the current triple conversion. *)
fun add_rule rule rs =
    {trips = #trips rs,
      rules = add_rule_inner (snd (#trips rs)) rule (#rules rs),
      splits = #splits rs, combs = #combs rs,
      unsafe_rules = #unsafe_rules rs};

fun del_rule rule rs =
    {trips = #trips rs,
      rules = del_rule_inner (snd (#trips rs)) rule (#rules rs),
      splits = #splits rs, combs = #combs rs,
      unsafe_rules = #unsafe_rules rs};

(* Adding or removing a trip rule changes the triple conversion, so all wp
   rules are re-inserted to recompute their keys. *)
fun add_trip rule (rs : wp_rules) = let
    val trips = Thm.add_thm rule (fst (#trips rs));
    val trip_conv = mk_trip_conv trips
  in {trips = (trips, trip_conv),
      rules = mk_rules trip_conv (dest_rules (#rules rs)),
      splits = #splits rs, combs = #combs rs,
      unsafe_rules = #unsafe_rules rs} end;

fun del_trip rule (rs : wp_rules) = let
    val trips = Thm.del_thm rule (fst (#trips rs));
    val trip_conv = mk_trip_conv trips
  in {trips = (trips, trip_conv),
      rules = mk_rules trip_conv (dest_rules (#rules rs)),
      splits = #splits rs, combs = #combs rs,
      unsafe_rules = #unsafe_rules rs} end;

fun add_split rule (rs : wp_rules) =
    {trips = #trips rs, rules = #rules rs,
      splits = Thm.add_thm rule (#splits rs), combs = #combs rs,
      unsafe_rules = #unsafe_rules rs};

fun add_comb rule (rs : wp_rules) =
    {trips = #trips rs, rules = #rules rs,
      splits = #splits rs, combs = Thm.add_thm rule (#combs rs),
      unsafe_rules = #unsafe_rules rs};

fun del_split rule rs =
    {trips = #trips rs, rules = #rules rs,
      splits = Thm.del_thm rule (#splits rs), combs = #combs rs,
      unsafe_rules = #unsafe_rules rs};

fun del_comb rule rs =
    {trips = #trips rs, rules = #rules rs,
      splits = #splits rs, combs = Thm.del_thm rule (#combs rs),
      unsafe_rules = #unsafe_rules rs};

fun add_unsafe_rule rule rs =
    {trips = #trips rs, rules = #rules rs,
      splits = #splits rs, combs = #combs rs,
      unsafe_rules = Thm.add_thm rule (#unsafe_rules rs)};

fun del_unsafe_rule rule rs =
    {trips = #trips rs, rules = #rules rs,
      splits = #splits rs, combs = #combs rs,
      unsafe_rules = Thm.del_thm rule (#unsafe_rules rs)};

(* Turns a database update into a declaration attribute. *)
fun gen_att m = Thm.declaration_attribute (fn thm => fn context =>
    WPData.map (m thm) context);

val wp_add = gen_att add_rule;
val wp_del = gen_att del_rule;
val trip_add = gen_att add_trip;
val trip_del = gen_att del_trip;
val splits_add = gen_att add_split;
val splits_del = gen_att del_split;
val combs_add = gen_att add_comb;
val combs_del = gen_att del_comb;
val wp_unsafe_add = gen_att add_unsafe_rule;
val wp_unsafe_del = gen_att del_unsafe_rule;

(* Registers the attributes wp, wp_trip, wp_split, wp_comb and wp_unsafe,
   each with add/del variants. *)
val setup =
      Attrib.setup @{binding "wp"}
          (Attrib.add_del wp_add wp_del)
          "monadic weakest precondition rules"
      #> Attrib.setup @{binding "wp_trip"}
          (Attrib.add_del trip_add trip_del)
          "monadic triple conversion rules"
      #> Attrib.setup @{binding "wp_split"}
          (Attrib.add_del splits_add splits_del)
          "monadic split rules"
      #> Attrib.setup @{binding "wp_comb"}
          (Attrib.add_del combs_add combs_del)
          "monadic combination rules"
      #> Attrib.setup @{binding "wp_unsafe"}
          (Attrib.add_del wp_unsafe_add wp_unsafe_del)
          "unsafe monadic weakest precondition rules"

(* The database of a proof context. *)
fun debug_get ctxt = WPData.get (Context.Proof ctxt);

(* The context's database with the extra rules added. Because of fold_rev,
   the first extra rule ends up with the highest index. *)
fun get_rules ctxt extras = fold_rev add_rule extras (debug_get ctxt);

(* Append a theorem to a ref'd list (list append, linear in its length). *)
fun append_used_rule rule used_rules = used_rules := !used_rules @ [rule]

fun add_extra_rule rule extra_rules = extra_rules := !extra_rules @ [rule]

(* One wp step on subgoal n of state t.
   The subgoal is rewritten with the triple conversion, beta-eta-contracted and
   stripped of assumptions. If subgoal n does not exist, the term True is used
   instead, which falls through to the last branch.
   - If it has the form Trueprop (triple_judgement _ f _): the keyed rules
     unifying with f are tried in descending index order, each first directly,
     then on every state obtained by first resolving with a combinator rule.
     If none applies, the split rules are tried.
   - Otherwise: the unkeyed rules and then the split rules are tried.
   Rules (and combinators) are appended to used_rules_ref when a result
   state is produced. Note that the combinator applications (combapps) are
   computed eagerly on t before any rule is tried. *)
fun resolve_ruleset_tac ctxt rs used_rules_ref n t =
  let
    fun append_rule rule thm = Seq.map (fn thm => (append_used_rule rule used_rules_ref; thm)) thm;
    fun rtac th = resolve_tac ctxt [th]
  in case
    Thm.cprem_of t n |> Thm.term_of |> snd (#trips rs) (Thm.theory_of_thm t)
        |> Envir.beta_eta_contract |> Logic.strip_assums_concl
     handle THM _ => @{const True}
  of Const (@{const_name Trueprop}, _) $
      (Const (@{const_name triple_judgement}, _) $ _ $ f $ _) => let
        val ts = Net.unify_term (#1 (#rules rs)) f |> order_list |> rev;
        val combapps = Seq.maps (fn combapp => Seq.map (fn combapp' => (combapp, combapp')) (rtac combapp n t))
                                (Seq.of_list (#combs rs)) |> Seq.list_of |> Seq.of_list;
        fun per_rule_tac t = (fn thm => append_rule t (rtac t n thm)) ORELSE
                             (fn _ => Seq.maps (fn combapp => append_rule t
                                        (append_rule (#1 combapp) (rtac t n (#2 combapp)))) combapps);
      in FIRST (map per_rule_tac ts) ORELSE
         FIRST (map (fn split => fn thm => append_rule split (rtac split n thm)) (#splits rs)) end t
    | _ => FIRST (map (fn rule => fn thm => append_rule rule (rtac rule n thm))
                      (map snd (#3 (#rules rs)) @ #splits rs)) t end;

(* A theorem printed under its name hint, as a string. *)
fun pretty_rule ctxt rule =
  Pretty.big_list (Thm.get_name_hint rule) [Thm.pretty_thm ctxt rule]
           |> Pretty.string_of;

(* When tracing is on, emits a tracing message per used rule. The messages
   are produced by the (eager) map; callers discard the returned sequence. *)
fun trace_used_thms false _ _ = Seq.empty
  | trace_used_thms true used_rules ctxt =
      let val used_thms = !used_rules
      in map (fn rule => tracing (pretty_rule ctxt rule)) used_thms
        |> Seq.of_list end;

(* Configuration option wp_warn_unused (default false). *)
val warn_unused = Attrib.setup_config_bool @{binding wp_warn_unused} (K false);

(* If wp_warn_unused is set, warns about the names of the theorems in
   extra_rules and thms whose names do not appear among the used rules. *)
fun warn_unused_thms ctxt thms extra_rules used_rules =
  if Config.get ctxt warn_unused
  then
    let val used_thms = map (fn rule => Thm.get_name_hint rule) (!used_rules)
        val unused_thms = map Thm.get_name_hint (!extra_rules @ thms) |> subtract (op =) used_thms
    in if not (null unused_thms)
       then "Unused theorems: " ^ commas_quote unused_thms |> warning
       else ()
    end
  else ()

(* Warns about each unsafe rule for which one resolve_ruleset_tac step on
   subgoal n of t succeeds. Note that get_rules also includes the
   context's ordinary wp rules in each trial. *)
fun warn_unsafe_thms unsafe_thms n ctxt t =
  let val used_rules = Unsynchronized.ref [] : thm list Unsynchronized.ref;
      val useful_unsafe_thms =
          filter (fn rule =>
            (is_some o SINGLE (
              resolve_ruleset_tac ctxt (get_rules ctxt [rule]) used_rules n)) t) unsafe_thms
      val unsafe_thm_names = map (fn rule => Thm.get_name_hint rule) useful_unsafe_thms
  in if not (null unsafe_thm_names)
     then "Unsafe theorems that could be used: " ^ commas_quote unsafe_thm_names |> warning
     else () end;

(* The wp tactic on subgoal n: applies resolve_ruleset_tac repeatedly
   (REPEAT_DETERM) and fails unless the state changed. Each result state
   triggers the unused-rule warning and tracing. Afterwards (THEN_ELSE) the
   unsafe-rule warning is issued on the result state on success, or on the
   original state before failing. *)
fun apply_rules_tac_n trace ctxt extras extras_ref n =
let
  val rules = get_rules ctxt extras;
  val used_rules = Unsynchronized.ref [] : thm list Unsynchronized.ref
in
  (fn t => Seq.map (fn thm => (warn_unused_thms ctxt extras extras_ref used_rules;
                               trace_used_thms trace used_rules ctxt; thm))
    (CHANGED (REPEAT_DETERM (resolve_ruleset_tac ctxt rules used_rules n)) t)) THEN_ELSE
  (fn t => (warn_unsafe_thms (#unsafe_rules rules) n ctxt t; all_tac t),
  fn t => (warn_unsafe_thms (#unsafe_rules rules) n ctxt t; no_tac t))
end;

(* apply_rules_tac_n on the first subgoal. *)
fun apply_rules_tac trace ctxt extras extras_ref = apply_rules_tac_n trace ctxt extras extras_ref 1;
(* A single resolve_ruleset_tac step on subgoal 1, with the same
   unused-rule warning and tracing per result state. *)
fun apply_once_tac trace ctxt extras extras_ref t =
  let val used_rules = Unsynchronized.ref [] : thm list Unsynchronized.ref;
  in Seq.map (fn thm => (warn_unused_thms ctxt extras extras_ref used_rules;
                         trace_used_thms trace used_rules ctxt; thm))
    (resolve_ruleset_tac ctxt (get_rules ctxt extras) used_rules 1 t) end

(* Empties the wp rules, keeping trips, splits, combs and unsafe rules. *)
fun clear_rules ({combs, rules=_, trips, splits, unsafe_rules}) =
  {combs=combs, rules=no_rules, trips=trips, splits=splits, unsafe_rules=unsafe_rules}

(* Method modifiers: add:, del:, comb:, comb add:, comb del:, only:.
   Added rules and combinators are also recorded in extras_ref (for the
   unused-rule warning); only: first clears the context's wp rules. *)
fun wp_modifiers extras_ref =
 [Args.add -- Args.colon >> K (I, fn att => (add_extra_rule (#2 att) extras_ref; wp_add att)),
  Args.del -- Args.colon >> K (I, wp_del),
  Args.$$$ "comb" -- Args.colon >> K (I, fn att => (add_extra_rule (#2 att) extras_ref; combs_add att)),
  Args.$$$ "comb" -- Args.add -- Args.colon >> K (I, fn att => (add_extra_rule (#2 att) extras_ref; combs_add att)),
  Args.$$$ "comb" -- Args.del -- Args.colon >> K (I, combs_del),
  Args.$$$ "only" -- Args.colon
    >> K (Context.proof_map (WPData.map clear_rules), fn att =>
                               (add_extra_rule (#2 att) extras_ref; wp_add att))];

(* Whether any of the tokens is the keyword ":". *)
fun has_colon xs = exists (Token.keyword_with (curry (op =) ":")) xs;

(* Uses scan1 if the remaining tokens contain a colon, otherwise scan2. *)
fun if_colon scan1 scan2 xs = if has_colon (snd xs) then scan1 xs else scan2 xs;

(* FIXME: It would be nice if we could just use Method.sections, but to maintain
   compatibility we require that the order of thms in each section is reversed. *)
fun thms ss = Scan.repeat (Scan.unless (Scan.lift (Scan.first ss)) Attrib.multi_thm) >> flat;
fun app (f, att) ths context = fold_map (Thm.apply_attribute att) ths (Context.map_proof f context);

(* One modifier followed by its theorems; the theorems are processed in
   reverse order (see the FIXME above). *)
fun section ss = Scan.depend (fn context => (Scan.first ss -- Scan.pass context (thms ss)) :|--
  (fn (m, ths) => Scan.succeed (swap (app m (rev ths) context))));

fun sections ss = Scan.repeat (section ss);

(* Argument parser for wp: with modifiers (any colon present) the theorems
   are installed via attributes; otherwise the plain theorem list becomes the
   extra rules. A fresh extras_ref is created per parse. *)
fun apply_rules_args trace xs =
  let val extras_ref = Unsynchronized.ref [] : thm list Unsynchronized.ref;
  in
    if_colon
    (sections (wp_modifiers extras_ref) >>
      K (fn ctxt => SIMPLE_METHOD (apply_rules_tac trace ctxt [] extras_ref)))
    (Attrib.thms >> curry (fn (extras, ctxt) =>
      Method.SIMPLE_METHOD (
        apply_rules_tac trace ctxt extras extras_ref
      )
    ))
  end xs;

(* Same as apply_rules_args, but for the single-step tactic. *)
fun apply_once_args trace xs =
  let val extras_ref = Unsynchronized.ref [] : thm list Unsynchronized.ref;
  in
    if_colon
    (sections (wp_modifiers extras_ref) >>
      K (fn ctxt => SIMPLE_METHOD (apply_once_tac trace ctxt [] extras_ref)))
    (Attrib.thms >> curry (fn (extras, ctxt) =>
      Method.SIMPLE_METHOD (
        apply_once_tac trace ctxt extras extras_ref
      )
    ))
  end xs;

end;

(* Ascribes signature WP to WeakestPre, so that the compiler checks the
   implementation against the published interface. *)
structure WeakestPreInst : WP = WeakestPre;

Theory OptionMonadWP

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(*
Hoare reasoning and WP (weakest-precondition) generator rules for the option monad.

This list is almost certainly incomplete; add rules here as they are needed.
*)

theory OptionMonadWP
imports
  OptionMonadND
  "wp/WP"
begin

declare K_def [simp]

(* Hoare triples.
   TODO: design a sensible syntax for them. *)

(* Partial correctness: from a state satisfying P, any result Some r of f
   satisfies Q r in that state. Failure (None) is not excluded. *)
definition ovalid :: "('s ⇒ bool) ⇒ ('s ⇒ 'a option) ⇒ ('a ⇒ 's ⇒ bool) ⇒ bool" where
  "ovalid P f Q ≡ ∀s r. P s ∧ f s = Some r ⟶ Q r s"
(* Total correctness: as ovalid, and additionally f does not fail under P. *)
definition ovalidNF :: "('s ⇒ bool) ⇒ ('s ⇒ 'a option) ⇒ ('a ⇒ 's ⇒ bool) ⇒ bool" where
  "ovalidNF P f Q ≡ ∀s. P s ⟶ (f s ≠ None ∧ (∀r. f s = Some r ⟶ Q r s))"
(* Termination: f does not fail in any state satisfying P. *)
definition no_ofail where "no_ofail P f ≡ ∀s. P s ⟶ f s ≠ None"

(*
This rule lets us apply ovalidNF machinery for proving no_ofail.
However, we ought to eventually write working wp rules for no_ofail (see below).
*)
lemma no_ofail_is_ovalidNF: "no_ofail P f ≡ ovalidNF P f (λ_ _. True)"
  by (simp add: no_ofail_def ovalidNF_def)
lemma ovalidNF_combine: "⟦ ovalid P f Q; no_ofail P f ⟧ ⟹ ovalidNF P f Q"
  by (auto simp: ovalidNF_def ovalid_def no_ofail_def)


(* Annotating programs with loop invariant and measure.
   I and M do not affect the meaning (owhile_inv is just owhile); they are
   carried along so that the loop rules below can use them. *)
definition owhile_inv ::
  "('a ⇒ 's ⇒ bool) ⇒ ('a ⇒ 's ⇒ 'a option) ⇒ 'a
   ⇒ ('a ⇒ 's ⇒ bool) ⇒ ('a ⇒ 's ⇒ nat) ⇒ 's ⇒ 'a option"
  where "owhile_inv C B x I M = owhile C B x"

lemmas owhile_add_inv = owhile_inv_def[symmetric]


(* WP rules for ovalid. *)
lemma obind_wp [wp]:
  "⟦ ⋀r. ovalid (R r) (g r) Q; ovalid P f R ⟧ ⟹ ovalid P (obind f g) Q"
  by (simp add: ovalid_def obind_def split: option.splits, fast)

lemma oreturn_wp [wp]:
  "ovalid (P x) (oreturn x) P"
  by (simp add: ovalid_def oreturn_def K_def)

lemma ocondition_wp [wp]:
  "⟦ ovalid L l Q; ovalid R r Q ⟧
   ⟹ ovalid (λs. if C s then L s else R s) (ocondition C l r) Q"
  by (auto simp: ovalid_def ocondition_def)

lemma ofail_wp [wp]:
  "ovalid (λ_. True) ofail Q"
  by (simp add: ovalid_def ofail_def)

lemma ovalid_K_bind_wp [wp]:
  "ovalid P f Q ⟹ ovalid P (K_bind f x) Q"
  by simp

lemma ogets_wp [wp]: "ovalid (λs. P (f s) s) (ogets f) P"
  by (simp add: ovalid_def ogets_def)

lemma oguard_wp [wp]: "ovalid (λs. f s ⟶ P () s) (oguard f) P"
  by (simp add: ovalid_def oguard_def)

lemma oskip_wp [wp]:
  "ovalid (λs. P () s) oskip P"
  by (simp add: ovalid_def oskip_def)

lemma ovalid_case_prod [wp]:
  assumes "(⋀x y. ovalid (P x y) (B x y) Q)"
  shows "ovalid (case v of (x, y) ⇒ P x y) (case v of (x, y) ⇒ B x y) Q"
  using assms unfolding ovalid_def by auto

lemma owhile_ovalid [wp]:
"⟦⋀a. ovalid (λs. I a s ∧ C a s) (B a) I;
   ⋀a s. ⟦I a s; ¬ C a s⟧ ⟹ Q a s⟧
  ⟹ ovalid (I a) (owhile_inv C B a I M) Q"
  unfolding owhile_inv_def owhile_def ovalid_def
  apply clarify
  apply (frule_tac I = "λa. I a s" in option_while_rule)
  apply auto
  done

(* Connects ovalid to the generic triple_judgement used by the wp tool. *)
definition ovalid_property where "ovalid_property P x = (λs f. (∀r. Some r = x s f ⟶ P r s))"
lemma ovalid_is_triple [wp_trip]:
  "ovalid P f Q = triple_judgement P f (ovalid_property Q (λs f. f s))"
  by (auto simp: triple_judgement_def ovalid_def ovalid_property_def)


lemma ovalid_wp_comb1 [wp_comb]:
  "⟦ ovalid P' f Q; ovalid P f Q'; ⋀s. P s ⟹ P' s ⟧ ⟹ ovalid P f (λr s. Q r s ∧ Q' r s)"
  by (simp add: ovalid_def)

lemma ovalid_wp_comb2 [wp_comb]:
  "⟦ ovalid P f Q; ⋀s. P' s ⟹ P s ⟧ ⟹ ovalid P' f Q"
  by (auto simp: ovalid_def)

lemma ovalid_wp_comb3 [wp_comb]:
  "⟦ ovalid P f Q; ovalid P' f Q' ⟧ ⟹ ovalid (λs. P s ∧ P' s) f (λr s. Q r s ∧ Q' r s)"
  by (auto simp: ovalid_def)



(* WP rules for ovalidNF. *)
lemma obind_NF_wp [wp]:
  "⟦ ⋀r. ovalidNF (R r) (g r) Q; ovalidNF P f R ⟧ ⟹ ovalidNF P (obind f g) Q"
  by (auto simp: ovalidNF_def obind_def split: option.splits)

lemma oreturn_NF_wp [wp]:
  "ovalidNF (P x) (oreturn x) P"
  by (simp add: ovalidNF_def oreturn_def)

lemma ocondition_NF_wp [wp]:
  "⟦ ovalidNF L l Q; ovalidNF R r Q ⟧
   ⟹ ovalidNF (λs. if C s then L s else R s) (ocondition C l r) Q"
  by (simp add: ovalidNF_def ocondition_def)

lemma ofail_NF_wp [wp]:
  "ovalidNF (λ_. False) ofail Q"
  by (simp add: ovalidNF_def ofail_def)

lemma ovalidNF_K_bind_wp [wp]:
  "ovalidNF P f Q ⟹ ovalidNF P (K_bind f x) Q"
  by simp

lemma ogets_NF_wp [wp]:
  "ovalidNF (λs. P (f s) s) (ogets f) P"
  by (simp add: ovalidNF_def ogets_def)

lemma oguard_NF_wp [wp]:
  "ovalidNF (λs. f s ∧ P () s) (oguard f) P"
  by (simp add: ovalidNF_def oguard_def)

lemma oskip_NF_wp [wp]:
  "ovalidNF (λs. P () s) oskip P"
  by (simp add: ovalidNF_def oskip_def)

lemma ovalid_NF_case_prod [wp]:
  assumes "(⋀x y. ovalidNF (P x y) (B x y) Q)"
  shows "ovalidNF (case v of (x, y) ⇒ P x y) (case v of (x, y) ⇒ B x y) Q"
  using assms unfolding ovalidNF_def by auto

(* Total-correctness loop rule: the measure M must decrease on every
   iteration taken from a state satisfying the invariant and the guard. *)
lemma owhile_NF [wp]:
"⟦⋀a. ovalidNF (λs. I a s ∧ C a s) (B a) I;
   ⋀a m. ovalid (λs. I a s ∧ C a s ∧ M a s = m) (B a) (λr s. M r s < m);
   ⋀a s. ⟦I a s; ¬ C a s⟧ ⟹ Q a s⟧
  ⟹ ovalidNF (I a) (owhile_inv C B a I M) Q"
  unfolding owhile_inv_def ovalidNF_def ovalid_def
  apply clarify
  apply (rename_tac s, rule_tac I = I and M = "measure (λr. M r s)" in owhile_rule)
       apply fastforce
      apply fastforce
     apply fastforce
    apply blast+
  done

(* Connects ovalidNF to the generic triple_judgement used by the wp tool. *)
definition ovalidNF_property where "ovalidNF_property P x = (λs f. (x s f ≠ None ∧ (∀r. Some r = x s f ⟶ P r s)))"
lemma ovalidNF_is_triple [wp_trip]:
  "ovalidNF P f Q = triple_judgement P f (ovalidNF_property Q (λs f. f s))"
  by (auto simp: triple_judgement_def ovalidNF_def ovalidNF_property_def)


lemma ovalidNF_wp_comb1 [wp_comb]:
  "⟦ ovalidNF P' f Q; ovalidNF P f Q'; ⋀s. P s ⟹ P' s ⟧ ⟹ ovalidNF P f (λr s. Q r s ∧ Q' r s)"
  by (simp add: ovalidNF_def)

lemma ovalidNF_wp_comb2 [wp_comb]:
  "⟦ ovalidNF P f Q; ⋀s. P' s ⟹ P s ⟧ ⟹ ovalidNF P' f Q"
  by (simp add: ovalidNF_def)

lemma ovalidNF_wp_comb3 [wp_comb]:
  "⟦ ovalidNF P f Q; ovalidNF P' f Q' ⟧ ⟹ ovalidNF (λs. P s ∧ P' s) f (λr s. Q r s ∧ Q' r s)"
  by (simp add: ovalidNF_def)



(* FIXME: WP rules for no_ofail, which might not be correct. *)
lemma no_ofail_ofail [wp]: "no_ofail (λ_. False) ofail"
  by (simp add: no_ofail_def ofail_def)

lemma no_ofail_ogets [wp]: "no_ofail (λ_. True) (ogets f)"
  by (simp add: no_ofail_def ogets_def)

lemma no_ofail_obind [wp]:
  "⟦ ⋀r. no_ofail (P r) (g r); no_ofail Q f; ovalid Q f P ⟧ ⟹ no_ofail Q (obind f g)"
  by (auto simp: no_ofail_def obind_def ovalid_def)

lemma no_ofail_K_bind [wp]:
  "no_ofail P f ⟹ no_ofail P (K_bind f x)"
  by simp

lemma no_ofail_oguard [wp]:
  "no_ofail (λs. f s) (oguard f)"
  by (auto simp: no_ofail_def oguard_def)

lemma no_ofail_ocondition [wp]:
  "⟦ no_ofail L l; no_ofail R r ⟧
     ⟹ no_ofail (λs. if C s then L s else R s) (ocondition C l r)"
  by (simp add: no_ofail_def ocondition_def)

lemma no_ofail_oreturn [wp]:
  "no_ofail (λ_. True) (oreturn x)"
  by (simp add: no_ofail_def oreturn_def)

lemma no_ofail_oskip [wp]:
  "no_ofail (λ_. True) oskip"
  by (simp add: no_ofail_def oskip_def)

lemma no_ofail_is_triple [wp_trip]:
  "no_ofail P f = triple_judgement P f (λs f. f s ≠ None)"
  by (auto simp: triple_judgement_def no_ofail_def)

lemma no_ofail_wp_comb1 [wp_comb]:
  "⟦ no_ofail P f; ⋀s. P' s ⟹ P s ⟧ ⟹ no_ofail P' f"
  by (simp add: no_ofail_def)

lemma no_ofail_wp_comb2 [wp_comb]:
  "⟦ no_ofail P f; no_ofail P' f ⟧ ⟹ no_ofail (λs. P s ∧ P' s) f"
  by (simp add: no_ofail_def)



(* Some extra lemmas for our predicates. *)
lemma ovalid_grab_asm:
  "(G ⟹ ovalid P f Q) ⟹ ovalid (λs. G ∧ P s) f Q"
  by (simp add: ovalid_def)
lemma ovalidNF_grab_asm:
  "(G ⟹ ovalidNF P f Q) ⟹ ovalidNF (λs. G ∧ P s) f Q"
  by (simp add: ovalidNF_def)
lemma no_ofail_grab_asm:
  "(G ⟹ no_ofail P f) ⟹ no_ofail (λs. G ∧ P s) f"
  by (simp add: no_ofail_def)

lemma ovalid_assume_pre:
  "(⋀s. P s ⟹ ovalid P f Q) ⟹ ovalid P f Q"
  by (auto simp: ovalid_def)
lemma ovalidNF_assume_pre:
  "(⋀s. P s ⟹ ovalidNF P f Q) ⟹ ovalidNF P f Q"
  by (simp add: ovalidNF_def)
lemma no_ofail_assume_pre:
  "(⋀s. P s ⟹ no_ofail P f) ⟹ no_ofail P f"
  by (simp add: no_ofail_def)

lemma ovalid_pre_imp:
  "⟦ ⋀s. P' s ⟹ P s; ovalid P f Q ⟧ ⟹ ovalid P' f Q"
  by (simp add: ovalid_def)
lemma ovalidNF_pre_imp:
  "⟦ ⋀s. P' s ⟹ P s; ovalidNF P f Q ⟧ ⟹ ovalidNF P' f Q"
  by (simp add: ovalidNF_def)
lemma no_ofail_pre_imp:
  "⟦ ⋀s. P' s ⟹ P s; no_ofail P f ⟧ ⟹ no_ofail P' f"
  by (simp add: no_ofail_def)

lemma ovalid_post_imp:
  "⟦ ⋀r s. Q r s ⟹ Q' r s; ovalid P f Q ⟧ ⟹ ovalid P f Q'"
  by (simp add: ovalid_def)
lemma ovalidNF_post_imp:
  "⟦ ⋀r s. Q r s ⟹ Q' r s; ovalidNF P f Q ⟧ ⟹ ovalidNF P f Q'"
  by (simp add: ovalidNF_def)

lemma ovalid_post_imp_assuming_pre:
  "⟦ ⋀r s. ⟦ P s; Q r s ⟧ ⟹ Q' r s; ovalid P f Q ⟧ ⟹ ovalid P f Q'"
  by (simp add: ovalid_def)
lemma ovalidNF_post_imp_assuming_pre:
  "⟦ ⋀r s. ⟦ P s; Q r s ⟧ ⟹ Q' r s; ovalidNF P f Q ⟧ ⟹ ovalidNF P f Q'"
  by (simp add: ovalidNF_def)

end

Theory Graph_Genus

theory Graph_Genus
imports
  "HOL-Combinatorics.Permutations"
  Graph_Theory.Graph_Theory
begin

(* Reducing the subtrahend modulo c does not change (a - b) mod c, provided
   the natural-number subtraction does not truncate (b < a). The proof
   transfers to int. *)
lemma nat_diff_mod_right:
  fixes a b c :: nat
  assumes "b < a"
  shows "(a - b) mod c = (a - b mod c) mod c"
proof -
  from assms have b_mod: "b mod c ≤ a"
    by (metis mod_less_eq_dividend linear not_le order_trans)
  have "int ((a - b) mod c) = (int a - int b mod int c) mod int c"
    using assms by (simp add: zmod_int of_nat_diff mod_simps)
  also have "… = int ((a - b mod c) mod c)"
    using assms b_mod
    by (simp add: zmod_int [symmetric] of_nat_diff [symmetric])
  finally show ?thesis by simp
qed

(* If f is injective on S, then taking images under f is injective on any
   family T of subsets of S. *)
lemma inj_on_f_imageI:
  assumes "inj_on f S" "⋀t. t ∈ T ⟹ t ⊆ S"
  shows "inj_on ((`) f) T"
  using assms by (auto simp: inj_on_image_eq_iff intro: inj_onI)



section ‹Combinatorial Maps›

(* arev leaves every non-arc fixed, i.e. its domain is contained in arcs G;
   this follows from the locale fact arev_dom. *)
lemma (in bidirected_digraph) has_dom_arev:
  "has_dom arev (arcs G)"
  using arev_dom by (auto simp: has_dom_def)

(* A combinatorial map on arcs of type 'b: an arc reversal and a successor
   function (constrained by locale digraph_map below). *)
record 'b pre_map =
  edge_rev :: "'b ⇒ 'b"
  edge_succ :: "'b ⇒ 'b"

(* The predecessor function: the inverse of edge_succ. *)
definition edge_pred :: "'b pre_map ⇒ 'b ⇒ 'b" where
  "edge_pred M = inv (edge_succ M)"


(* A pre_digraph together with a map M that is not yet constrained. *)
locale pre_digraph_map =
  pre_digraph +
  fixes M :: "'b pre_map"

(* A combinatorial map on a finite digraph: edge_rev M is the arc reversal
   of a bidirected digraph, edge_succ M permutes the arcs, and on every vertex
   with outgoing arcs it cycles through all of them. *)
locale digraph_map = fin_digraph G
  + pre_digraph_map G M
  + bidirected_digraph G "edge_rev M" for G M +
  assumes edge_succ_permutes: "edge_succ M permutes arcs G"
  assumes edge_succ_cyclic: "⋀v. v ∈ verts G ⟹ out_arcs G v ≠ {} ⟹ cyclic_on (edge_succ M) (out_arcs G v)"

(* Introduction rule for digraph_map with the bidirected_digraph axioms
   for edge_rev M spelled out pointwise. *)
lemma (in fin_digraph) digraph_mapI:
  assumes bidi: "⋀a. a ∉ arcs G ⟹ edge_rev M a = a"
    "⋀a. a ∈ arcs G ⟹ edge_rev M a ≠ a"
    "⋀a. a ∈ arcs G ⟹ edge_rev M (edge_rev M a) = a"
    "⋀a. a ∈ arcs G ⟹ tail G (edge_rev M a) = head G a"
  assumes edge_succ_permutes: "edge_succ M permutes arcs G"
  assumes edge_succ_cyclic: "⋀v. v ∈ verts G ⟹ out_arcs G v ≠ {} ⟹ cyclic_on (edge_succ M) (out_arcs G v)"
  shows "digraph_map G M"
  using assms by unfold_locales auto

(* Variant of digraph_mapI where "edge_rev M fixes all non-arcs" is given
   as "edge_rev M permutes arcs G". *)
lemma (in fin_digraph) digraph_mapI_permutes:
  assumes bidi: "edge_rev M permutes arcs G"
    "⋀a. a ∈ arcs G ⟹ edge_rev M a ≠ a"
    "⋀a. a ∈ arcs G ⟹ edge_rev M (edge_rev M a) = a"
    "⋀a. a ∈ arcs G ⟹ tail G (edge_rev M a) = head G a"
  assumes edge_succ_permutes: "edge_succ M permutes arcs G"
  assumes edge_succ_cyclic: "⋀v. v ∈ verts G ⟹ out_arcs G v ≠ {} ⟹ cyclic_on (edge_succ M) (out_arcs G v)"
  shows "digraph_map G M"
proof -
  interpret bidirected_digraph G "edge_rev M" using bidi by unfold_locales (auto simp: permutes_def)
  show ?thesis
    using edge_succ_permutes edge_succ_cyclic by unfold_locales
qed


context digraph_map
begin

  lemma digraph_map[intro]: "digraph_map G M" by unfold_locales

  (* Basic facts about edge_succ and its inverse edge_pred. *)
  lemma permutation_edge_succ: "permutation (edge_succ M)"
    by (metis edge_succ_permutes finite_arcs permutation_permutes)

  lemma edge_pred_succ[simp]: "edge_pred M (edge_succ M a) = a"
    by (metis edge_pred_def edge_succ_permutes permutes_inverses(2))

  lemma edge_succ_pred[simp]: "edge_succ M (edge_pred M a) = a"
    by (metis edge_pred_def edge_succ_permutes permutes_inverses(1))

  lemma edge_pred_permutes: "edge_pred M permutes arcs G"
    unfolding edge_pred_def using edge_succ_permutes by (rule permutes_inv)

  lemma permutation_edge_pred: "permutation (edge_pred M)"
    by (metis edge_pred_permutes finite_arcs permutation_permutes)

  lemma edge_succ_eq_iff[simp]: "⋀x y. edge_succ M x = edge_succ M y ⟷ x = y"
    by (metis edge_pred_succ)

  lemma edge_rev_in_arcs[simp]: "edge_rev M a ∈ arcs G ⟷ a ∈ arcs G"
    by (metis arev_arev arev_permutes_arcs permutes_not_in)

  lemma edge_succ_in_arcs[simp]: "edge_succ M a ∈ arcs G ⟷ a ∈ arcs G"
    by (metis edge_pred_succ edge_succ_permutes permutes_not_in)

  lemma edge_pred_in_arcs[simp]: "edge_pred M a ∈ arcs G ⟷ a ∈ arcs G"
    by (metis edge_succ_pred edge_pred_permutes permutes_not_in)

  (* edge_succ stays within the out-arcs of the same tail vertex. *)
  lemma tail_edge_succ[simp]: "tail G (edge_succ M a) = tail G a"
  proof cases
    assume "a ∈ arcs G"
    then have "tail G a ∈ verts G" by auto
    moreover
    then have "out_arcs G (tail G a) ≠ {}"
      using ‹a ∈ arcs G› by auto
    ultimately
    have "cyclic_on (edge_succ M) (out_arcs G (tail G a))"
      by (rule edge_succ_cyclic)
    moreover
    have "a ∈ out_arcs G (tail G a)"
      using ‹a ∈ arcs G› by simp
    ultimately
    have "edge_succ M a ∈ out_arcs G (tail G a)"
      by (rule cyclic_on_inI)
    then show ?thesis by simp
  next
    assume "a ∉ arcs G" then show ?thesis using edge_succ_permutes by (simp add: permutes_not_in)
  qed

  lemma tail_edge_pred[simp]: "tail G (edge_pred M a) = tail G a"
  by (metis edge_succ_pred tail_edge_succ)

  lemma bij_edge_succ[intro]: "bij (edge_succ M)"
    using edge_succ_permutes by (simp add: permutes_conv_has_dom)

  lemma edge_pred_cyclic:
    assumes "v ∈ verts G" "out_arcs G v ≠ {}"
    shows "cyclic_on (edge_pred M) (out_arcs G v)"
  proof -
    obtain a where orb_a_eq: "orbit (edge_succ M) a = out_arcs G v"
      using edge_succ_cyclic[OF assms] by (auto simp: cyclic_on_def)
    have "cyclic_on (edge_pred M) (orbit (edge_pred M) a)"
      using permutation_edge_pred by (rule cyclic_on_orbit')
    also have "orbit (edge_pred M) a = orbit (edge_succ M) a"
      unfolding edge_pred_def using permutation_edge_succ by (rule orbit_inv_eq)
    finally show "cyclic_on (edge_pred M) (out_arcs G v)" by (simp add: orb_a_eq)
  qed

  (* Face-cycle successor: apply edge_rev M first, then edge_succ M;
     face_cycle_pred is its inverse. *)
  definition (in pre_digraph_map) face_cycle_succ :: "'b ⇒ 'b" where
    "face_cycle_succ ≡ edge_succ M o edge_rev M"

  definition (in pre_digraph_map) face_cycle_pred :: "'b ⇒ 'b" where
    "face_cycle_pred ≡ edge_rev M o edge_pred M"

  lemma face_cycle_pred_succ[simp]:
    shows "face_cycle_pred (face_cycle_succ a) = a"
    unfolding face_cycle_pred_def face_cycle_succ_def by simp

  lemma face_cycle_succ_pred[simp]:
    shows "face_cycle_succ (face_cycle_pred a) = a"
    unfolding face_cycle_pred_def face_cycle_succ_def by simp

  lemma tail_face_cycle_succ: "a ∈ arcs G ⟹ tail G (face_cycle_succ a) = head G a"
    by (auto simp: face_cycle_succ_def)

  lemma funpow_prop:
    assumes "⋀x. P (f x) ⟷ P x"
    shows "P ((f ^^ n) x) ⟷ P x"
    using assms by (induct n) (auto simp: )

  lemma face_cycle_succ_no_arc[simp]: "a ∉ arcs G ⟹ face_cycle_succ a = a"
    by (auto simp: face_cycle_succ_def permutes_not_in[OF arev_permutes_arcs]
      permutes_not_in[OF edge_succ_permutes])

  lemma funpow_face_cycle_succ_no_arc[simp]:
    assumes "a ∉ arcs G" shows "(face_cycle_succ ^^ n) a = a"
    using assms by (induct n) auto

  lemma funpow_face_cycle_pred_no_arc[simp]:
    assumes "a ∉ arcs G" shows "(face_cycle_pred ^^ n) a = a"
    using assms
    by (induct n) (auto simp: face_cycle_pred_def permutes_not_in[OF arev_permutes_arcs]
      permutes_not_in[OF edge_pred_permutes])

  lemma face_cycle_succ_closed[simp]:
    "face_cycle_succ a ∈ arcs G ⟷ a ∈ arcs G"
    by (metis comp_apply edge_rev_in_arcs edge_succ_in_arcs face_cycle_succ_def)

  lemma face_cycle_pred_closed[simp]:
    "face_cycle_pred a ∈ arcs G ⟷ a ∈ arcs G"
    by (metis face_cycle_succ_closed face_cycle_succ_pred)

  lemma face_cycle_succ_permutes:
    "face_cycle_succ permutes arcs G"
    unfolding face_cycle_succ_def
    using arev_permutes_arcs edge_succ_permutes by (rule permutes_compose)

  lemma permutation_face_cycle_succ: "permutation face_cycle_succ"
    using face_cycle_succ_permutes finite_arcs by (metis permutation_permutes)

  lemma bij_face_cycle_succ: "bij face_cycle_succ"
    using face_cycle_succ_permutes by (simp add: permutes_conv_has_dom)

  lemma face_cycle_pred_permutes:
    "face_cycle_pred permutes arcs G"
    unfolding face_cycle_pred_def
    using edge_pred_permutes arev_permutes_arcs by (rule permutes_compose)

  (* The face cycle of an arc is its orbit under face_cycle_succ; the faces
     of the map are the face cycles of all arcs. *)
  definition (in pre_digraph_map) face_cycle_set :: "'b ⇒ 'b set" where
    "face_cycle_set a = orbit face_cycle_succ a"

  definition (in pre_digraph_map) face_cycle_sets :: "'b set set" where
    "face_cycle_sets = face_cycle_set ` arcs G"

  lemma face_cycle_set_altdef: "face_cycle_set a = {(face_cycle_succ ^^ n) a | n. True}"
    unfolding face_cycle_set_def
    by (intro orbit_altdef_self_in permutation_self_in_orbit permutation_face_cycle_succ)

  lemma face_cycle_set_self[simp, intro]: "a ∈ face_cycle_set a"
    unfolding face_cycle_set_def using permutation_face_cycle_succ by (rule permutation_self_in_orbit)

  lemma empty_not_in_face_cycle_sets: "{} ∉ face_cycle_sets"
      by (auto simp: face_cycle_sets_def)

  lemma finite_face_cycle_set[simp, intro]: "finite (face_cycle_set a)"
    using face_cycle_set_self unfolding face_cycle_set_def by (simp add: finite_orbit)

  lemma finite_face_cycle_sets[simp, intro]: "finite face_cycle_sets"
    by (auto simp: face_cycle_sets_def)

  lemma face_cycle_set_induct[case_names base step, induct set: face_cycle_set]:
    assumes consume: "a ∈ face_cycle_set x"
      and ih_base: "P x"
      and ih_step: "⋀y. y ∈ face_cycle_set x ⟹ P y ⟹ P (face_cycle_succ y)"
    shows "P a"
    using consume unfolding face_cycle_set_def
    by induct (auto simp: ih_step face_cycle_set_def[symmetric] ih_base )

  lemma face_cycle_succ_cyclic:
    "cyclic_on face_cycle_succ (face_cycle_set a)"
    unfolding face_cycle_set_def using permutation_face_cycle_succ by (rule cyclic_on_orbit')

  lemma face_cycle_eq:
    assumes "b ∈ face_cycle_set a" shows "face_cycle_set b = face_cycle_set a"
    using assms unfolding face_cycle_set_def
    by (auto intro: orbit_swap orbit_trans permutation_face_cycle_succ permutation_self_in_orbit)

  lemma face_cycle_succ_in_arcsI: "⋀a. a ∈ arcs G ⟹ face_cycle_succ a ∈ arcs G"
    by (auto simp: face_cycle_succ_def)

  lemma face_cycle_succ_inI: "⋀x y. x ∈ face_cycle_set y ⟹ face_cycle_succ x ∈ face_cycle_set y"
    by (metis face_cycle_succ_cyclic cyclic_on_inI)

  lemma face_cycle_succ_inD: "⋀x y. face_cycle_succ x ∈ face_cycle_set y ⟹ x ∈ face_cycle_set y"
    by (metis face_cycle_eq face_cycle_set_self face_cycle_succ_inI)

  lemma face_cycle_set_parts:
    "face_cycle_set a = face_cycle_set b ∨ face_cycle_set a ∩ face_cycle_set b = {}"
    by (metis disjoint_iff_not_equal face_cycle_eq)

  (* "Lies on the same face cycle" is an equivalence relation. *)
  definition fc_equiv :: "'b ⇒ 'b ⇒ bool" where
    "fc_equiv a b ≡ a ∈ face_cycle_set b"

  lemma reflp_fc_equiv: "reflp fc_equiv"
    by (rule reflpI) (simp add: fc_equiv_def)

  lemma symp_fc_equiv: "symp fc_equiv"
    using face_cycle_set_parts
    by (intro sympI) (auto simp: fc_equiv_def)

  lemma transp_fc_equiv: "transp fc_equiv"
    using face_cycle_set_parts
    by (intro transpI) (auto simp: fc_equiv_def)

  lemma "equivp fc_equiv"
    by (intro equivpI reflp_fc_equiv symp_fc_equiv transp_fc_equiv)

  lemma in_face_cycle_setD:
    assumes "y ∈ face_cycle_set x" "x ∈ arcs G" shows "y ∈ arcs G"
    using assms
    by (auto simp: face_cycle_set_def dest: permutes_orbit_subset[OF face_cycle_succ_permutes])

  lemma in_face_cycle_setsD:
    assumes "x ∈ face_cycle_sets" shows "x ⊆ arcs G"
    using assms by (auto simp: face_cycle_sets_def dest: in_face_cycle_setD)

end

(* The vertices of G that have no outgoing arcs. *)
definition (in pre_digraph) isolated_verts :: "'a set" where
  "isolated_verts  {v  verts G. out_arcs G v = {}}"

(* Euler characteristic of the map: |V| - |A| div 2 + (number of face
   cycles). Halving the arc count presumably counts each edge once, since
   every arc has a distinct reverse (edge_rev). *)
definition (in pre_digraph_map) euler_char :: int where
  "euler_char  int (card (verts G)) - int (card (arcs G) div 2) + int (card face_cycle_sets)"

(* Euler genus of the map: (2 * |sccs| - |isolated vertices| - euler_char)
   divided by 2 (integer division). *)
definition (in pre_digraph_map) euler_genus :: int where
  "euler_genus  (int (2 * card sccs) - int (card isolated_verts) - euler_char) div 2"

(* Combinatorial planarity: G admits some digraph map M whose Euler genus
   is 0. *)
definition comb_planar :: "('a,'b) pre_digraph  bool" where
  "comb_planar G  M. digraph_map G M  pre_digraph_map.euler_genus G M = 0"


text ‹The number of isolated vertices is a graph invariant.›
(* Behaviour of isolated_verts under a digraph isomorphism hom of G. *)
context
  fixes G hom assumes hom: "pre_digraph.digraph_isomorphism G hom"
begin

  interpretation wf_digraph G using hom by (auto simp: pre_digraph.digraph_isomorphism_def)

  (* The isolated vertices of the image graph are the images of the
     isolated vertices of G. *)
  lemma isolated_verts_app_iso[simp]:
    "pre_digraph.isolated_verts (app_iso hom G) = iso_verts hom ` isolated_verts"
    using hom
    by (auto simp: pre_digraph.isolated_verts_def iso_verts_tail inj_image_mem_iff out_arcs_app_iso_eq)

  (* Taking the image under iso_verts hom preserves the number of isolated
     vertices. card_image reduces this to injectivity on isolated_verts,
     which comes from injectivity on verts G (subset_inj_on). *)
  lemma card_isolated_verts_iso[simp]:
    "card (iso_verts hom ` pre_digraph.isolated_verts G) = card isolated_verts"
    apply (rule card_image)
    using hom apply (rule digraph_isomorphism_inj_on_verts[THEN subset_inj_on])
    apply (auto simp: isolated_verts_def)
    done

end



context digraph_map begin

  (* For an arc a that is not a loop (tail differs from head),
     face_cycle_succ a differs from a. *)
  lemma face_cycle_succ_neq:
    assumes "a  arcs G" "tail G a  head G a" shows "face_cycle_succ a  a "
  proof -
    (* The reverse of a is an arc, and edge_succ cycles through the out-arcs
       at its tail, so edge_succ M (edge_rev M a) is again such an out-arc. *)
    from assms have "edge_rev M a  arcs G"
      by (subst edge_rev_in_arcs) simp
    then have "cyclic_on (edge_succ M) (out_arcs G (tail G (edge_rev M a)))"
      by (intro edge_succ_cyclic) (auto dest: tail_in_verts simp: out_arcs_def intro: exI[where x="edge_rev M a"])
    then have "edge_succ M (edge_rev M a)  (out_arcs G (tail G (edge_rev M a)))"
      by (rule cyclic_on_inI) (auto simp: ‹edge_rev M a  _[simplified])
    (* That arc starts at head G a, while a starts at tail G a. *)
    moreover have "tail G (edge_succ M (edge_rev M a)) = head G a"
      using assms by auto
    then have "edge_succ M (edge_rev M a)  a" using assms by metis
    ultimately show ?thesis
      using assms by (auto simp: face_cycle_succ_def)
  qed

end


section ‹Maps and Isomorphism›

(* Transport a function f on the arcs of G along the isomorphism hom:
   conjugate f by iso_arcs hom and its inverse, then restrict the result
   with perm_restrict to the arcs of app_iso hom G. *)
definition (in pre_digraph)
  "wrap_iso_arcs hom f = perm_restrict (iso_arcs hom o f o iso_arcs (inv_iso hom)) (arcs (app_iso hom G))"

(* Transport the map M along the isomorphism f by wrapping both edge_rev
   and edge_succ with wrap_iso_arcs. *)
definition (in pre_digraph_map) map_iso :: "('a,'b,'a2,'b2) digraph_isomorphism  'b2 pre_map" where
  "map_iso f  
   edge_rev = wrap_iso_arcs f (edge_rev M)
  , edge_succ = wrap_iso_arcs f (edge_succ M)
  "

(* A permutation of S maps S into S. *)
lemma funcsetI_permutes:
  assumes "f permutes S" shows "f  S  S"
  by (metis assms funcsetI permutes_in_image)

(* Lemmas about wrap_iso_arcs for a fixed digraph isomorphism hom of G. *)
context
  fixes G hom assumes hom: "pre_digraph.digraph_isomorphism G hom"
begin

  interpretation wf_digraph G using hom by (auto simp: pre_digraph.digraph_isomorphism_def)

  (* On the image of an arc x of G, wrap_iso_arcs hom f acts like f
     transported along iso_arcs hom. *)
  lemma wrap_iso_arcs_iso_arcs[simp]:
    assumes "x  arcs G"
    shows "wrap_iso_arcs hom f (iso_arcs hom x) = iso_arcs hom (f x)"
    using assms hom by (auto simp: wrap_iso_arcs_def perm_restrict_def)

  (* wrap_iso_arcs hom is injective on a family F of functions that map
     arcs G into arcs G and satisfy has_dom f (arcs G). *)
  lemma inj_on_wrap_iso_arcs:
    assumes dom: "f. f  F  has_dom f (arcs G)"
    assumes funcset: "F  arcs G  arcs G"
    shows "inj_on (wrap_iso_arcs hom) F"
  proof (rule inj_onI)
    fix f g assume F: "f  F" "g  F" and eq: "wrap_iso_arcs hom f = wrap_iso_arcs hom g"
    (* Off arcs G, has_dom makes both functions the identity. *)
    { fix x assume "x  arcs G"
      then have "f x = x" "g x = x" using F dom by (auto simp: has_dom_def)
      then have "f x = g x" by simp
    }
    moreover
    (* On arcs G, compare the wrapped functions and cancel iso_arcs hom,
       which is injective on arcs G. *)
    { fix x assume "x  arcs G"
      then have "f x  arcs G" "g x  arcs G" using F funcset by auto
      with digraph_isomorphism_inj_on_arcs[OF hom] _
      have "iso_arcs hom (f x) = iso_arcs hom (g x)  f x = g x"
        by (rule inj_onD)
      then have "f x = g x"
        using assms hom  x  arcs G eq
        by (auto simp: wrap_iso_arcs_def fun_eq_iff perm_restrict_def split: if_splits)
    }
    ultimately show "f = g" by auto
  qed

  (* If f is injective on A (A within arcs G, f mapping A into A), then
     wrap_iso_arcs hom f is injective on B = iso_arcs hom ` A. *)
  lemma inj_on_wrap_iso_arcs_f:
    assumes "A  arcs G" "f  A  A" "B = iso_arcs hom ` A"
    assumes "inj_on f A" shows "inj_on (wrap_iso_arcs hom f) B"
  proof (rule inj_onI)
    fix x y
    assume in_hom_A: "x  B" "y  B"
      and wia_eq: "wrap_iso_arcs hom f x = wrap_iso_arcs hom f y"
    from in_hom_A B = _ obtain x0 where x0: "x = iso_arcs hom x0" "x0  A" by auto
    from in_hom_A B = _ obtain y0 where y0: "y = iso_arcs hom y0" "y0  A" by auto
    have arcs_0: "x0  arcs G" "y0  arcs G" "f x0  arcs G" "f y0  arcs G"
      using x0 y0 A  _ f  _ by auto

    have "(iso_arcs hom o f o iso_arcs (inv_iso hom)) x = (iso_arcs hom o f o iso_arcs (inv_iso hom)) y"
      using in_hom_A wia_eq assms(1) B = _ by (auto simp: wrap_iso_arcs_def perm_restrict_def split: if_splits)
    then show "x = y"
      using hom assms digraph_isomorphism_inj_on_arcs[OF hom] x0 y0 arcs_0 ‹inj_on f A A  _
      by (auto dest!:  inj_onD)
  qed

  (* If f maps A (a subset of arcs G) into A, then wrap_iso_arcs hom f maps
     iso_arcs hom ` A into itself. *)
  lemma wrap_iso_arcs_in_funcsetI:
    assumes "A  arcs G" "f  A  A"
    shows "wrap_iso_arcs hom f  iso_arcs hom ` A   iso_arcs hom ` A"
  proof
    fix x assume "x  iso_arcs hom ` A"
    then obtain x0 where "x = iso_arcs hom x0" "x0  A" by blast
    then have "f x0  A" using f  _ by auto
    then show "wrap_iso_arcs hom f x  iso_arcs hom ` A"
      unfolding x = _ using x0  A assms hom by (auto simp: wrap_iso_arcs_def perm_restrict_def)
  qed

  (* Transporting a permutation f of A (a subset of arcs G) along hom yields
     a permutation of iso_arcs hom ` A. *)
  lemma wrap_iso_arcs_permutes:
    assumes "A  arcs G" "f permutes A"
    shows "wrap_iso_arcs hom f permutes (iso_arcs hom ` A)"
  proof -
    (* Outside iso_arcs hom ` A the wrapped function is the identity. *)
    { fix x assume A: "x  iso_arcs hom ` A"
      have "wrap_iso_arcs hom f x = x"
      proof cases
        assume "x  iso_arcs hom ` arcs G"
        then have "iso_arcs (inv_iso hom) x  A" "x  arcs (app_iso hom G)"
          using A hom by (metis arcs_app_iso image_eqI pre_digraph.iso_arcs_iso_inv, simp)
        then have "f (iso_arcs (inv_iso hom) x) = (iso_arcs (inv_iso hom) x)"
          using f permutes A by (simp add: permutes_not_in)
        then show ?thesis using hom assms x  arcs _
          by (simp add: wrap_iso_arcs_def perm_restrict_def)
      next
        assume "x  iso_arcs hom ` arcs G"
        then show ?thesis
          by (simp add: wrap_iso_arcs_def perm_restrict_def)
      qed
    } note not_in_id = this

    have "f  A  A" using assms by (intro funcsetI_permutes)
    have inj_on_wrap: "inj_on (wrap_iso_arcs hom f) (iso_arcs hom ` A)"
      using assms f  A  A by (intro inj_on_wrap_iso_arcs_f) (auto intro: subset_inj_on permutes_inj)
    have woa_in_fs: "wrap_iso_arcs hom f  iso_arcs hom ` A  iso_arcs hom ` A"
      using assms f  A  A by (intro wrap_iso_arcs_in_funcsetI)

    (* Global injectivity: injective on the image of A, and the identity
       outside of it. *)
    { fix x y assume "wrap_iso_arcs hom f x = wrap_iso_arcs hom f y"
      then have "x = y"
        apply (cases "x  iso_arcs hom ` A"; cases "y  iso_arcs hom ` A")
        using woa_in_fs inj_on_wrap by (auto dest: inj_onD simp: not_in_id)
    } note uniqueD = this

    note f permutes A
    moreover
    note not_in_id
    moreover
    (* Surjectivity: every y has a preimage under the wrapped function. *)
    { fix y have "x. wrap_iso_arcs hom f x = y"
      proof cases
        assume "y  iso_arcs hom ` A"
        then obtain y0 where "y0  A" "iso_arcs hom y0 = y" by blast
        with f permutes A obtain x0 where "x0  A" "f x0 = y0" unfolding permutes_def by metis
        moreover
        then have "x. x  arcs G  iso_arcs hom x0 = iso_arcs hom x  x0 = x"
          using assms hom by (auto simp: digraph_isomorphism_def dest: inj_onD)
        ultimately
        have "wrap_iso_arcs hom f (iso_arcs hom x0) = y"
          using _ = y assms hom by (auto simp: wrap_iso_arcs_def perm_restrict_def)
        then show ?thesis ..
      qed (metis not_in_id)
    }
    ultimately
    show ?thesis unfolding permutes_def by (auto dest: uniqueD)
  qed

end

(* Transporting the digraph map M along an isomorphism hom (via map_iso)
   yields a digraph map on the image graph app_iso hom G. *)
lemma (in digraph_map) digraph_map_isoI:
  assumes "digraph_isomorphism hom" shows "digraph_map (app_iso hom G) (map_iso hom)"
proof -
  interpret iG: fin_digraph "app_iso hom G" using assms by (rule fin_digraphI_app_iso)
  show ?thesis
  proof (rule iG.digraph_mapI_permutes)
    (* Both wrapped components permute the arcs of the image graph. *)
    show "edge_rev (map_iso hom) permutes arcs (app_iso hom G)"
      using assms unfolding map_iso_def by (simp add: wrap_iso_arcs_permutes arev_permutes_arcs)
  next
    show "edge_succ (map_iso hom) permutes arcs (app_iso hom G)"
      using assms unfolding map_iso_def by (simp add: wrap_iso_arcs_permutes edge_succ_permutes)
  next
    (* Arc-wise properties of the transported edge_rev, mostly by case
       analysis with in_arcs_app_iso_cases. *)
    fix a assume A: "a  arcs (app_iso hom G)"
    show "tail (app_iso hom G) (edge_rev (map_iso hom) a) = head (app_iso hom G) a"
      using A assms
      by (cases rule: in_arcs_app_iso_cases) (auto simp: map_iso_def iso_verts_tail iso_verts_head)
    show "edge_rev (map_iso hom) (edge_rev (map_iso hom) a) = a"
      using A assms by (cases rule: in_arcs_app_iso_cases) (auto simp: map_iso_def)
    show "edge_rev (map_iso hom) a  a"
      using A assms by (auto simp: map_iso_def arev_neq)
  next
    (* Cyclicity of edge_succ on out-arcs: pull v back to a vertex v0 of G,
       use cyclicity of edge_succ M there, and push the orbit forward. *)
    fix v assume "v  verts (app_iso hom G)" and oa_hom: "out_arcs (app_iso hom G) v  {}"
    then obtain v0 where "v0  verts G" "v = iso_verts hom v0" by auto
    moreover
    then have oa: "out_arcs G v0  {}"
      using assms oa_hom by (auto simp: out_arcs_def iso_verts_tail)
    ultimately
    have cyclic_on_v0: "cyclic_on (edge_succ M) (out_arcs G v0)"
      by (intro edge_succ_cyclic)

    from oa_hom obtain a where "a  out_arcs (app_iso hom G) v" by blast
    then obtain a0 where "a0  arcs G" "a = iso_arcs hom a0" by auto
    then have "a0  out_arcs G v0"
      using v = _ v0  _ a  _ assms by (simp add: iso_verts_tail)

    show "cyclic_on (edge_succ (map_iso hom)) (out_arcs (app_iso hom G) v)"
    proof (rule cyclic_on_singleI)
      show "a  out_arcs (app_iso hom G) v" by fact
    next
      have "out_arcs (app_iso hom G) v = iso_arcs hom ` out_arcs G v0"
        unfolding v = _ by (rule out_arcs_app_iso_eq) fact+
      also have "out_arcs G v0 = orbit (edge_succ M) a0"
        using cyclic_on_v0 a0  out_arcs G v0 unfolding cyclic_on_alldef by simp
      also have "iso_arcs hom `  = orbit (edge_succ (map_iso hom)) a"
      proof -
        have "x. x  orbit (edge_succ M) a0  x  arcs G"
          using ‹out_arcs G v0 = _ by auto
        then show ?thesis using ‹out_arcs G v0 = _
          unfolding a = _ using a0  out_arcs G v0 assms
          by (intro orbit_inverse) (auto simp: map_iso_def)
      qed
      finally show "out_arcs (app_iso hom G) v = orbit (edge_succ (map_iso hom)) a" .
    qed
  qed
qed

end

Theory List_Aux

theory List_Aux
imports
  "List-Index.List_Index"
begin

section ‹Auxiliary List Lemmas›

(* Position m of rotate1 xs holds the element at position Suc m of xs,
   wrapping around modulo the length. *)
lemma nth_rotate_conv_nth1_conv_nth:
  assumes "m < length xs"
  shows "rotate1 xs ! m = xs ! (Suc m mod length xs)"
  using assms
proof (induction xs arbitrary: m)
  case (Cons x xs)
  show ?case
  proof (cases "m < length xs")
    case False
    with Cons.prems have "m = length xs" by force
    then show ?thesis by (auto simp: nth_append)
  qed (auto simp: nth_append)
qed simp

(* Generalisation to rotate n: position m of rotate n xs holds the element
   at position (m + n) mod length xs of xs. *)
lemma nth_rotate_conv_nth:
  assumes "m < length xs"
  shows "rotate n xs ! m = xs ! ((m + n) mod length xs)"
  using assms
proof (induction n arbitrary: m)
  case 0 then show ?case by simp
next
  case (Suc n)
  show ?case
  proof cases
    assume "m + 1 < length xs"
    with Suc show ?thesis by (auto simp: nth_rotate_conv_nth1_conv_nth)
  next
    assume "¬(m + 1 < length xs)"
    (* m is the last index, so the offset Suc (m + n) wraps around to n. *)
    with Suc have "m + 1 = length xs" "0 < length xs" by auto
    moreover
    { have "Suc (m + n) mod length xs = (Suc m + n) mod length xs"
        by auto
      also have " = n mod length xs" using m + 1 = _ by simp
      finally have "Suc (m + n) mod length xs = n mod length xs" .}
    ultimately
    show ?thesis by (auto simp: nth_rotate_conv_nth1_conv_nth Suc.IH)
  qed
qed

(* A list containing some element is non-empty. *)
lemma not_nil_if_in_set:
  assumes "x  set xs" shows "xs  []"
  using assms by auto

(* Removing elements with fold remove1 never increases the length. *)
lemma length_fold_remove1_le:
 "length (fold remove1 ys xs)  length xs"
proof (induct ys arbitrary: xs)
  case (Cons y ys)
  then have "length (fold remove1 ys (remove1 y xs))  length (remove1 y xs)" by auto
  also have "  length xs" by (auto simp: length_remove1)
  finally show ?case by simp
qed simp

(* Elements of xs that do not occur in ys survive fold remove1 ys. *)
lemma set_fold_remove1':
  assumes "x  set xs - set ys" shows "x  set (fold remove1 ys xs)"
  using assms by (induct ys arbitrary: xs) auto

(* fold remove1 only ever removes elements. *)
lemma set_fold_remove1:
  "set (fold remove1 xs ys)  set ys"
  by (induct xs arbitrary: ys) (auto, metis notin_set_remove1 subsetCE)

(* For a distinct list, fold remove1 ys removes exactly the elements of ys. *)
lemma set_fold_remove1_distinct:
  assumes "distinct xs" shows "set (fold remove1 ys xs) = set xs - set ys"
  using assms by (induct ys arbitrary: xs) auto

(* fold remove1 preserves distinctness. *)
lemma distinct_fold_remove1:
  assumes "distinct xs"
  shows "distinct (fold remove1 ys xs)"
  using assms by (induct ys arbitrary: xs) auto

end

Theory Executable_Permutations

section ‹Permutations as Products of Disjoint Cycles›

theory Executable_Permutations
imports
  "HOL-Library.Rewrite"
  "HOL-Combinatorics.Permutations"
  Graph_Theory.Auxiliary
  List_Aux
begin

subsection ‹Cyclic Permutations›

(* list_succ xs maps an element of xs to the element after its first
   occurrence (index xs x), wrapping from the last position back to the
   first. Elements not in xs are fixed. For distinct xs this is the cyclic
   permutation described by xs (see list_succ_permutes). *)
definition list_succ :: "'a list  'a  'a" where
  "list_succ xs x = (if x  set xs then xs ! ((index xs x + 1) mod length xs) else x)"

text ‹
  We demonstrate the function on the following simple examples:

  @{lemma "list_succ [1 :: int, 2, 3] 1 = 2" by code_simp}
  @{lemma "list_succ [1 :: int, 2, 3] 2 = 3" by code_simp}
  @{lemma "list_succ [1 :: int, 2, 3] 3 = 1" by code_simp}
›

(* Alternative definition of list_succ without mod, by cases on the
   position n = index xs x: last position, inner position, or not in xs. *)
lemma list_succ_altdef:
  "list_succ xs x = (let n = index xs x in if n + 1 = length xs then xs ! 0 else if n + 1 < length xs then xs ! (n + 1) else x)"
  using index_le_size[of xs x] unfolding list_succ_def index_less_size_conv[symmetric] by (auto simp: Let_def)

(* The empty list induces the identity. *)
lemma list_succ_Nil:
  "list_succ [] = id"
  by (simp add: list_succ_def fun_eq_iff)

(* A singleton list behaves like the empty list. *)
lemma list_succ_singleton:
  "list_succ [x] = list_succ []"
  by (simp add: fun_eq_iff list_succ_def)

(* Lists of length less than 2 induce the identity. The proof splits into
   the empty list and the singleton case. *)
lemma list_succ_short:
  assumes "length xs < 2" shows "list_succ xs = id"
  using assms
  by (cases xs) (rename_tac [2] y ys, case_tac [2] ys, auto simp: list_succ_Nil list_succ_singleton)

(* Conditional rewrite rules for the three cases of list_succ_altdef. *)
lemma list_succ_simps:
  "index xs x + 1 = length xs  list_succ xs x = xs ! 0"
  "index xs x + 1 < length xs  list_succ xs x = xs ! (index xs x + 1)"
  "length xs  index xs x  list_succ xs x = x"
  by (auto simp: list_succ_altdef)

(* Elements outside xs are fixed by list_succ xs. *)
lemma list_succ_not_in:
  assumes "x  set xs" shows "list_succ xs x = x"
  using assms by (auto simp: list_succ_def)

(* For distinct xs, list_succ (rev xs) undoes list_succ xs. *)
lemma list_succ_list_succ_rev:
  assumes "distinct xs" shows "list_succ (rev xs) (list_succ xs x) = x"
proof -
  (* Case: x is at an inner position of xs. *)
  { assume "index xs x + 1 < length xs"
    moreover then have "length xs - Suc (Suc (length xs - Suc (Suc (index xs x)))) = index xs x"
      by linarith
    ultimately have ?thesis using assms
      by (simp add: list_succ_def index_rev index_nth_id rev_nth)
  }
  moreover
  (* Case: x is the last element of xs. *)
  { assume A: "index xs x + 1 = length xs"
    moreover
    from A have "xs  []" by auto
    moreover
    with A have "last xs = xs ! index xs x" by (cases "length xs") (auto simp: last_conv_nth)
    ultimately
    have ?thesis
      using assms
      by (auto simp add: list_succ_def rev_nth index_rev index_nth_id last_conv_nth)
  }
  moreover
  (* Case: index xs x is out of range, i.e. x is not in xs. *)
  { assume A: "index xs x  length xs"
    then have "x  set xs" by (metis index_less less_irrefl)
    then have ?thesis by (auto simp: list_succ_def) }
  (* The three cases on index xs x are exhaustive. *)
  ultimately show ?thesis by (metis discrete le_less not_less) 
qed

(* Consequences of list_succ_list_succ_rev for distinct xs: list_succ xs is
   injective, its inverse is list_succ (rev xs), and it is a bijection. *)
lemma inj_list_succ: "distinct xs  inj (list_succ xs)"
  by (metis injI list_succ_list_succ_rev)

lemma inv_list_succ_eq: "distinct xs  inv (list_succ xs) = list_succ (rev xs)"
  by (metis distinct_rev inj_imp_inv_eq inj_list_succ list_succ_list_succ_rev)

lemma bij_list_succ: "distinct xs  bij (list_succ xs)"
  by (metis bij_def inj_list_succ distinct_rev list_succ_list_succ_rev surj_def)

(* For distinct xs, list_succ xs permutes exactly set xs. *)
lemma list_succ_permutes:
  assumes "distinct xs" shows "list_succ xs permutes set xs"
  using assms by (auto simp: permutes_conv_has_dom bij_list_succ has_dom_def list_succ_def)

(* Hence it is a permutation (of finite support). *)
lemma permutation_list_succ:
  assumes "distinct xs" shows "permutation (list_succ xs)"
  using list_succ_permutes[OF assms] by (auto simp: permutation_permutes)

(* In a distinct list, the successor of the n-th element is the element at
   position Suc n mod length xs. *)
lemma list_succ_nth:
  assumes "distinct xs" "n < length xs" shows "list_succ xs (xs ! n) = xs ! (Suc n mod length xs)"
  using assms by (auto simp: list_succ_def index_nth_id)

(* The last element of a non-empty distinct list is mapped to its head. *)
lemma list_succ_last[simp]:
  assumes "distinct xs" "xs  []" shows "list_succ xs (last xs) = hd xs"
  using assms by (auto simp: list_succ_def hd_conv_nth)

(* list_succ of a distinct list is invariant under rotate1. *)
lemma list_succ_rotate1[simp]:
  assumes "distinct xs" shows "list_succ (rotate1 xs) = list_succ xs"
proof (rule ext)
  fix y show "list_succ (rotate1 xs) y = list_succ xs y"
    using assms
  proof (induct xs)
    case Nil then show ?case by simp
  next
    case (Cons x xs)
    show ?case
    proof (cases "x = y")
      (* y is the former head, which ends up at position length xs of the
         rotated list. *)
      case True
      then have "index (xs @ [y]) y = length xs"
        using ‹distinct (x # xs) by (simp add: index_append)
      with True show ?thesis by (cases "xs=[]") (auto simp: list_succ_def nth_append)
    next
      case False
      then show ?thesis
        apply (cases "index xs y + 1 < length xs")
        apply (auto simp:list_succ_def index_append nth_append)
        by (metis Suc_lessI index_less_size_conv mod_self nth_Cons_0 nth_append nth_append_length)
    qed
  qed
qed
  
(* Iterated version of list_succ_rotate1: invariance under rotate n. *)
lemma list_succ_rotate[simp]:
  assumes "distinct xs" shows "list_succ (rotate n xs) = list_succ xs"
  using assms by (induct n) auto

(* list_succ xs x lies in set xs exactly when x does. *)
lemma list_succ_in_conv:
  "list_succ xs x  set xs  x  set xs"
  by (auto simp: list_succ_def not_nil_if_in_set )

(* For a set A disjoint from set xs, list_succ xs x is in A exactly when x
   is. *)
lemma list_succ_in_conv1:
  assumes "A  set xs = {}"
  shows "list_succ xs x  A  x  A"
  by (metis assms disjoint_iff_not_equal list_succ_in_conv list_succ_not_in)

(* The list_succ functions of two lists with disjoint element sets commute. *)
lemma list_succ_commute:
  assumes "set xs  set ys = {}"
  shows "list_succ xs (list_succ ys x) = list_succ ys (list_succ xs x)"
proof -
  (* Each function fixes the elements of the other list. *)
  have "x. x  set xs  list_succ ys x = x"
     "x. x  set ys  list_succ xs x = x"
    using assms by (blast intro: list_succ_not_in)+
  then show ?thesis
    by (cases "x  set xs  set ys") (auto simp: list_succ_in_conv list_succ_not_in)
qed


subsection ‹Arbitrary Permutations›

(* Composition of the list_succ functions of all lists in xss.
   lists_succ (xs # xss) applies lists_succ xss first and then list_succ xs. *)
fun lists_succ :: "'a list list  'a  'a" where
  "lists_succ [] x = x"
| "lists_succ (xs # xss) x = list_succ xs (lists_succ xss x)"

(* distincts xss: the lists in xss are pairwise different, each one is
   distinct and non-empty, and different lists have disjoint element sets.
   Such an xss describes a product of disjoint cycles. *)
definition distincts ::  "'a list list  bool" where
  "distincts xss  distinct xss  (xs  set xss. distinct xs  xs  [])  (xs  set xss. ys  set xss. xs  ys  set xs  set ys = {})"

(* Basic facts about distincts: it implies distinct, it holds for the empty
   list of cycles, and it characterises a single cycle. *)
lemma distincts_distinct: "distincts xss  distinct xss"
  by (auto simp: distincts_def)

lemma distincts_Nil[simp]: "distincts []"
  by (simp add: distincts_def)

lemma distincts_single: "distincts [xs]  distinct xs  xs  []"
  by (auto simp add: distincts_def)

(* Recursive characterisation of distincts: the head cycle is non-empty
   and distinct, the tail satisfies distincts, and the head is disjoint from
   the union of the tail's cycles. *)
lemma distincts_Cons: "distincts (xs # xss)
    xs  []  distinct xs  distincts xss  (set xs  (ys  set xss. set ys)) = {}" (is "?L  ?R")
proof 
  assume ?L then show ?R by (auto simp: distincts_def)
next
  (* Converse: establish the three conjuncts of distincts_def separately. *)
  assume ?R
  then have "distinct (xs # xss)"
    apply (auto simp: disjoint_iff_not_equal distincts_distinct)
    apply (metis length_greater_0_conv nth_mem)
    done
  moreover
  from ?R have "xs  set (xs # xss). distinct xs  xs  []"
    by (auto simp: distincts_def)
  moreover
  from ?R have "xs'  set (xs # xss). ys  set (xs # xss). xs'  ys  set xs'  set ys = {}"
    by (simp add: distincts_def) blast
  ultimately show ?L unfolding distincts_def by (intro conjI)
qed

(* Variant of distincts_Cons with disjointness stated cycle by cycle. *)
lemma distincts_Cons': "distincts (xs # xss)
    xs  []  distinct xs  distincts xss  (ys  set xss. set xs  set ys = {})" (is "?L  ?R")
 unfolding distincts_Cons by blast

(* Reversing every cycle does not affect distincts. *)
lemma distincts_rev:
  "distincts (map rev xss)  distincts xss"
  by (simp add: distincts_def distinct_map)

(* Under distincts, the number of cycles equals the number of distinct
   cycle element sets. *)
lemma length_distincts:
  assumes "distincts xss"
  shows "length xss = card (set ` set xss)"
  using assms
proof (induct xss)
  case Nil then show ?case by simp
next
  case (Cons xs xss)
  (* The head's element set is new: it is non-empty and disjoint from the
     sets of the tail. *)
  then have "set xs  set ` set xss"
    using equals0I[of "set xs"] by (auto simp: distincts_Cons disjoint_iff_not_equal )
  with Cons show ?case by (auto simp add: distincts_Cons)
qed

(* Removing a cycle preserves distincts. *)
lemma distincts_remove1: "distincts xss  distincts (remove1 xs xss)"
  by (auto simp: distincts_def)

(* Moving a member x to the front with remove1 preserves distinctness and
   the element set. *)
lemma distinct_Cons_remove1:
  "x  set xs  distinct (x # remove1 x xs) = distinct xs"
  by (induct xs) auto

lemma set_Cons_remove1:
  "x  set xs  set (x # remove1 x xs) = set xs"
  by (induct xs) auto

(* Consequently, moving a member cycle to the front preserves distincts. *)
lemma distincts_Cons_remove1:
  "xs  set xss  distincts (xs # remove1 xs xss) = distincts xss"
  by (simp only: distinct_Cons_remove1 set_Cons_remove1 distincts_def)

(* Under distincts, different cycles have different element sets. *)
lemma distincts_inj_on_set:
  assumes "distincts xss" shows "inj_on set (set xss)"
  by (rule inj_onI) (metis assms distincts_def inf.idem set_empty)

(* Hence the list of element sets is distinct. *)
lemma distincts_distinct_set:
  assumes "distincts xss" shows "distinct (map set xss)"
  using assms by (auto simp: distinct_map distincts_distinct distincts_inj_on_set)

(* Every cycle, accessed by index, is a distinct list. *)
lemma distincts_distinct_nth:
  assumes "distincts xss" "n < length xss" shows "distinct (xss ! n)"
  using assms by (auto simp: distincts_def)

(* lists_succ xss fixes every element outside the union of the cycles. *)
lemma lists_succ_not_in:
  assumes "x  (xsset xss. set xs)" shows "lists_succ xss x = x"
  using assms by (induct xss) (auto simp: list_succ_not_in)

(* lists_succ xss maps the union of the cycles into itself, and its
   complement into itself. *)
lemma lists_succ_in_conv:
  "lists_succ xss x  (xsset xss. set xs)  x  (xsset xss. set xs)"
  by (induct xss) (auto simp: list_succ_in_conv lists_succ_not_in list_succ_not_in)

(* For A disjoint from all cycles, lists_succ xss x is in A exactly when x
   is. *)
lemma lists_succ_in_conv1:
  assumes "A  (xsset xss. set xs) = {}"
  shows "lists_succ xss x  A  x  A"
  by (metis Int_iff assms emptyE lists_succ_in_conv lists_succ_not_in)

(* Point-free versions of the lists_succ equations, collected in
   lists_succ_simps_pf. *)
lemma lists_succ_Cons_pf: "lists_succ (xs # xss) = list_succ xs o lists_succ xss"
  by auto

lemma lists_succ_Nil_pf: "lists_succ [] = id"
  by (simp add: fun_eq_iff)

lemmas lists_succ_simps_pf = lists_succ_Cons_pf lists_succ_Nil_pf

(* Under distincts, lists_succ xss permutes the union of the cycles. *)
lemma lists_succ_permutes:
  assumes "distincts xss"
  shows "lists_succ xss permutes (xs  set xss. set xs)"
  using assms
proof (induction xss)
  case Nil then show ?case by auto
next
  case (Cons xs xss)
  (* The head cycle permutes its own elements; the tail permutes its union
     by the induction hypothesis; compose and enlarge both domains. *)
  have "list_succ xs permutes (set xs)"
    using Cons by (intro list_succ_permutes) (simp add: distincts_def in_set_member)
  moreover
  have "lists_succ xss permutes (ys  set xss. set ys)"
    using Cons by (auto simp: Cons distincts_def)
  ultimately show "lists_succ (xs # xss) permutes (ys  set (xs # xss). set ys)"
    using Cons by (auto simp: lists_succ_Cons_pf intro: permutes_compose permutes_subset)
qed

(* lists_succ of a distincts list of cycles is a bijection. *)
lemma bij_lists_succ: "distincts xss  bij (lists_succ xss)"
  by (induct xss) (auto simp: lists_succ_simps_pf bij_comp bij_list_succ distincts_Cons)

(* Appending a cycle at the end composes its list_succ on the right. *)
lemma lists_succ_snoc: "lists_succ (xss @ [xs]) = lists_succ xss o list_succ xs"
  by (induct xss) auto

(* The inverse of lists_succ xss is obtained by reversing both the order of
   the cycles and each individual cycle. *)
lemma inv_lists_succ_eq:
  assumes "distincts xss"
  shows "inv (lists_succ xss) = lists_succ (rev (map rev xss))"
proof -
  have *: "f g. inv (λb. f (g b)) = inv (f o g)" by (simp add: o_def)
  (* Same statement as lists_succ_Nil_pf. *)
  have **: "lists_succ [] = id" by auto
  show ?thesis
    using assms by (induct xss) (auto simp: * ** lists_succ_snoc lists_succ_Cons_pf o_inv_distrib
      inv_list_succ_eq distincts_Cons bij_list_succ bij_lists_succ)
qed

(* Moving a cycle xs of xss to the front does not change lists_succ,
   because disjoint cycles commute. *)
lemma lists_succ_remove1:
  assumes "distincts xss" "xs  set xss"
  shows "lists_succ (xs # remove1 xs xss) = lists_succ xss"
  using assms
proof (induct xss)
  case Nil then show ?case by simp
next
  case (Cons ys xss)
  show ?case
  proof cases
    assume "xs = ys" then show ?case by simp
  next
    (* xs lies in the tail: swap list_succ xs past list_succ ys, then apply
       the induction hypothesis. *)
    assume "xs  ys"
    with Cons.prems have inter: "set xs  set ys = {}" and "xs  set xss"
      by (auto simp: distincts_Cons)
    have dists:
        "distincts (xs # remove1 xs xss)"
        "distincts (xs # ys # remove1 xs xss)"
      using ‹distincts (ys # xss) xs  set xss by (auto simp: distincts_def)

    have "list_succ xs  (list_succ ys  lists_succ (remove1 xs xss))
        = list_succ ys  (list_succ xs  lists_succ (remove1 xs xss))"
      using inter unfolding fun_eq_iff comp_def
      by (subst list_succ_commute) auto
    also have " = list_succ ys o (lists_succ (xs # remove1 xs xss))"
      using dists by (simp add: lists_succ_Cons_pf distincts_Cons)
    also have " = list_succ ys o lists_succ xss"
      using xs  set xss ‹distincts (ys # xss)
      by (simp add: distincts_Cons Cons.hyps)
    finally
    show "lists_succ (xs # remove1 xs (ys # xss)) = lists_succ (ys # xss)"
      using Cons dists by (auto simp: lists_succ_Cons_pf distincts_Cons)
  qed
qed

(* For distincts lists of cycles, lists_succ depends only on the set of
   cycles, not on their order. *)
lemma lists_succ_no_order:
  assumes "distincts xss" "distincts yss" "set xss = set yss"
  shows "lists_succ xss = lists_succ yss"
  using assms
proof (induct xss arbitrary: yss)
  case Nil then show ?case by simp
next
  (* Compare with yss after moving xs to its front (lists_succ_remove1). *)
  case (Cons xs xss)
  have "xs  set xss" "xs  set yss" using Cons.prems
    by (auto dest: distincts_distinct)
  have "lists_succ xss = lists_succ (remove1 xs yss)"
    using Cons.prems xs  _
    by (intro Cons.hyps) (auto simp add: distincts_Cons distincts_remove1 distincts_distinct)
  then have "lists_succ (xs # xss) = lists_succ (xs # remove1 xs yss)"
    using Cons.prems xs  _
    by (simp add: lists_succ_Cons_pf distincts_Cons_remove1)
  then show ?case
    using Cons.prems xs  _ by (simp add: lists_succ_remove1)
qed



section ‹List Orbits›

text ‹Computes the orbit of @{term x} under @{term f} as a list.›
(* The iterates of f starting at x (x, f x, ...), funpow_dist1 f x x of them.
   When x lies on its own orbit, this list is distinct and its set is
   orbit f x (set_orbit_list, distinct_orbit_list). *)
definition orbit_list :: "('a  'a)  'a  'a list" where
  "orbit_list f x  iterate 0 (funpow_dist1 f x x) f x"

(* Tail-recursive executable version of orbit_list. Visited elements are
   accumulated in reverse in acc, and iteration stops as soon as the next
   element f x equals the start element s. Related to orbit_list by
   orbit_list_conv_impl. *)
partial_function (tailrec)
  orbit_list_impl :: "('a  'a)  'a  'a list  'a  'a list"
where
  "orbit_list_impl f s acc x = (let x' = f x in if x' = s then rev (x # acc) else orbit_list_impl f s (x # acc) x')"

(* length_fold_remove1_le is added to the simpset in this context,
   presumably so that the termination proofs of the two recursive functions
   below (which recurse on fold remove1 ... xs) go through. *)
context notes [simp] = length_fold_remove1_le begin
text ‹Computes the list of orbits›
fun orbits_list :: "('a  'a)  'a list  'a list list" where
  "orbits_list f [] = []"
| "orbits_list f (x # xs) =
     orbit_list f x # orbits_list f (fold remove1 (orbit_list f x) xs)"

(* Same recursion as orbits_list, but each orbit is computed with
   orbit_list_impl; the two agree for permutations (orbits_list_conv_impl). *)
fun orbits_list_impl :: "('a  'a)  'a list  'a list list" where
  "orbits_list_impl f [] = []"
| "orbits_list_impl f (x # xs) =
     (let fc = orbit_list_impl f x [] x in fc # orbits_list_impl f (fold remove1 fc xs))"

(* Use the partial_function equation as code equation. *)
declare orbit_list_impl.simps[code]
end

(* The set of element sets of a list of lists. *)
abbreviation sset :: "'a list list  'a set set" where
  "sset xss  set ` set xss"

(* One unfolding step of the iterate list. If y lies on the orbit of x and
   f x differs from y, the list of iterates from x towards y is x followed
   by the list of iterates from f x towards y. *)
lemma iterate_funpow_step:
  assumes "f x  y" "y  orbit f x"
  shows "iterate 0 (funpow_dist1 f x y) f x = x # iterate 0 (funpow_dist1 f (f x) y) f (f x)"
proof -
  from assms have A: "y  orbit f (f x)" by (simp add: orbit_step)
  (* Split off the first element of the iterate list. *)
  have "iterate 0 (funpow_dist1 f x y) f x = x # iterate 1 (funpow_dist1 f x y) f x" (is "_ = _ # ?it")
    unfolding iterate_def by (rewrite in " = _" upt_conv_Cons) auto
  also have "?it = map (λn. (f ^^ n) x) (map Suc [0..<funpow_dist f (f x) y])"
    unfolding iterate_def map_Suc_upt by simp
  also have " = map (λn. (f ^^ n) (f x)) [0..<funpow_dist f (f x) y]"
    by (simp add: funpow_swap1)
  also have " = iterate 0 (funpow_dist1 f (f x) y) f (f x)"
    unfolding iterate_def by (simp add: funpow_dist_step[OF assms(1) A])
  finally show ?thesis .
qed

(* Loop invariant of orbit_list_impl. With stop element y on the orbit of x
   and accumulator acc, the result is rev acc followed by
   iterate 0 (funpow_dist1 f x y) f x. Proof by induction on
   funpow_dist1 f x y. *)
lemma orbit_list_impl_conv:
  assumes "y  orbit f x"
  shows "orbit_list_impl f y acc x = rev acc @ iterate 0 (funpow_dist1 f x y) f x"
  using assms
proof (induct n"funpow_dist1 f x y" arbitrary: x acc)
  case (Suc x)

  show ?case
  proof cases
    (* The next element is the stop element: the loop terminates. *)
    assume "f x = y"
    then show ?thesis by (subst orbit_list_impl.simps) (simp add: Let_def iterate_def funpow_dist_0)
  next
    (* Otherwise take one loop step and use the induction hypothesis at f x. *)
    assume not_y :"f x  y"

    have y_in_succ: "y  orbit f (f x)"
      by (intro orbit_step Suc.prems not_y)

    have "orbit_list_impl f y acc x = orbit_list_impl f y (x # acc) (f x)"
      using not_y by (subst orbit_list_impl.simps) simp
    also have " = rev (x # acc) @ iterate 0 (funpow_dist1 f (f x) y) f (f x)" (is "_ = ?rev @ ?it")
      by (intro Suc funpow_dist_step not_y y_in_succ)
    also have " = rev acc @ iterate 0 (funpow_dist1 f x y) f x"
      using not_y Suc.prems by (simp add: iterate_funpow_step)
    finally show ?thesis .
  qed
qed

(* orbit_list agrees with orbit_list_impl started with an empty accumulator,
   provided x lies on its own orbit. *)
lemma orbit_list_conv_impl:
  assumes "x  orbit f x"
  shows "orbit_list f x = orbit_list_impl f x [] x"
  unfolding orbit_list_impl_conv[OF assms] orbit_list_def by simp


(* If x lies on its own orbit, orbit_list f x enumerates exactly orbit f x. *)
lemma set_orbit_list:
  assumes "x  orbit f x"
  shows "set (orbit_list f x) = orbit f x"
  by (simp add: orbit_list_def orbit_conv_funpow_dist1[OF assms] set_iterate)

(* Same for permutations, where every element lies on its own orbit. *)
lemma set_orbit_list':
  assumes "permutation f" shows "set (orbit_list f x) = orbit f x"
  using assms by (simp add: permutation_self_in_orbit set_orbit_list)

(* Under the same assumptions, orbit_list f x has no duplicates. *)
lemma distinct_orbit_list:
  assumes "x  orbit f x"
  shows "distinct (orbit_list f x)"
  by (simp del: upt_Suc add: orbit_list_def iterate_def distinct_map inj_on_funpow_dist1[OF assms])

lemma distinct_orbit_list':
  assumes "permutation f" shows "distinct (orbit_list f x)"
  using assms by (simp add: permutation_self_in_orbit distinct_orbit_list)

(* For a permutation f, orbits_list and orbits_list_impl coincide.
   Strong induction on the length of the input list. *)
lemma orbits_list_conv_impl:
  assumes "permutation f"
  shows "orbits_list f xs = orbits_list_impl f xs"
proof (induct "length xs" arbitrary: xs rule: less_induct)
  case less show ?case
    using assms by (cases xs) (auto simp: assms less less_Suc_eq_le length_fold_remove1_le
      orbit_list_conv_impl permutation_self_in_orbit Let_def)
qed

(* orbit_list never returns the empty list. *)
lemma orbit_list_not_nil[simp]: "orbit_list f x  []"
  by (simp add: orbit_list_def)

(* For a permutation f, the element sets of orbits_list f xs are exactly
   the orbits of the elements of xs. *)
lemma sset_orbits_list:
  assumes "permutation f" shows "sset (orbits_list f xs) = (orbit f) ` set xs"
proof (induct "length xs" arbitrary: xs rule: less_induct)
  case less
  show ?case
  proof (cases xs)
    case Nil then show ?thesis by simp
  next
    case (Cons x' xs')
    let ?xs'' = "fold remove1 (orbit_list f x') xs'"
    have A: "sset (orbits_list f ?xs'') = orbit f ` set ?xs''"
      using Cons by (simp add: less_Suc_eq_le length_fold_remove1_le less.hyps)
    have B: "set (orbit_list f x') = orbit f x'"
      by (rule set_orbit_list) (simp add: permutation_self_in_orbit assms)

    (* Removing the elements of the orbit of x' can only lose orbits ... *)
    have "orbit f ` set (fold remove1 (orbit_list f x') xs')  orbit f ` set xs'"
      using set_fold_remove1[of _ xs'] by auto
    moreover
    (* ... and the only orbit it loses is the orbit of x' itself. *)
    have "orbit f ` set xs' - {orbit f x'}  (orbit f ` set (fold remove1 (orbit_list f x') xs'))" (is "?L  ?R")
    proof
      fix A assume "A  ?L"
      then obtain y where "A = orbit f y" "y  set xs'" by auto
      have "A  orbit f x'" using A  ?L by auto
      from A = _ A  _ have "y  orbit f x'"
        by (meson assms cyclic_on_orbit orbit_cyclic_eq3 permutation_permutes)
      with y  _ have "y  set (fold remove1 (orbit_list f x') xs')"
        by (auto simp: set_fold_remove1' set_orbit_list permutation_self_in_orbit assms)
      then show "A  ?R" using A = _ by auto
    qed
    ultimately
    show ?thesis by (auto simp: A B Cons)
  qed
qed



subsection ‹Relation to @{term cyclic_on}›

(* If s lies on its own orbit and f fixes every element off that orbit,
   then f is exactly list_succ (orbit_list f s). *)
lemma list_succ_orbit_list:
  assumes "s  orbit f s" "x. x  orbit f s  f x = x"
  shows "list_succ (orbit_list f s) = f"
proof -
  have "distinct (orbit_list f s)" "x. x  set (orbit_list f s)  x = f x"
    using assms by (simp_all add: distinct_orbit_list set_orbit_list)
  moreover
  (* Consecutive entries of orbit_list f s (cyclically) are related by f. *)
  have "i. i < length (orbit_list f s)  orbit_list f s ! (Suc i mod length (orbit_list f s)) = f (orbit_list f s ! i)"
    using funpow_dist1_prop[OF s  orbit f s] by (auto simp: orbit_list_def funpow_mod_eq)
  ultimately show ?thesis
    by (auto simp: list_succ_def fun_eq_iff)
qed

(* Iterating list_succ xs n times on an element x of a distinct list xs
   advances its index by n, modulo the length. *)
lemma list_succ_funpow_conv:
  assumes A: "distinct xs" "x  set xs"
  shows "(list_succ xs ^^ n) x = xs ! ((index xs x + n) mod length xs)"
proof -
  have "xs  []" using assms by auto
  then show ?thesis
    by (induct n) (auto simp: hd_conv_nth A index_nth_id list_succ_def mod_simps)
qed

(* For distinct xs, the orbit of any element of xs under list_succ xs is
   the whole of set xs. *)
lemma orbit_list_succ:
  assumes "distinct xs" "x  set xs"
  shows "orbit (list_succ xs) x = set xs"
proof (intro set_eqI iffI)
  (* The orbit stays inside set xs. *)
  fix y assume "y  orbit (list_succ xs) x"
  then show "y  set xs"
    by induct (auto simp: list_succ_in_conv x  set xs)
next
  (* Every position j can be reached from any position i by a suitable
     number of steps. *)
  fix y assume "y  set xs"
  moreover
  { fix i j have "i < length xs  j < length xs  n. xs ! j = xs ! ((i + n) mod length xs)"
      using assms by (auto simp: exI[where x="j + (length xs - i)"])
  }
  ultimately
  show "y  orbit (list_succ xs) x"
    using assms by (auto simp: orbit_altdef_permutation permutation_list_succ list_succ_funpow_conv index_nth_id in_set_conv_nth)
qed

(* list_succ of a non-empty distinct list acts cyclically on its elements. *)
lemma cyclic_on_list_succ:
  assumes "distinct xs" "xs  []" shows "cyclic_on (list_succ xs) (set xs)"
  using assms last_in_set by (auto simp: cyclic_on_def orbit_list_succ)

(* Under the hypotheses of list_succ_orbit_list, f is list_succ of some
   distinct list whose element set is orbit f s and whose head is s.  The
   proof uses orbit_list f s as the witness. *)
lemma obtain_orbit_list_func:
  assumes "s  orbit f s" "x. x  orbit f s  f x = x"
  obtains xs where "f = list_succ xs" "set xs = orbit f s" "distinct xs" "hd xs = s"
proof -
  (* establish all four properties for orbit_list f s, then discharge the
     obtains-goal with that witness *)
  { from assms have "f = list_succ (orbit_list f s)" by (simp add: list_succ_orbit_list)
    moreover
    have "set (orbit_list f s) = orbit f s" "distinct (orbit_list f s)"
      by (auto simp: set_orbit_list distinct_orbit_list assms)
    moreover have "hd (orbit_list f s) = s"
      by (simp add: orbit_list_def iterate_def hd_map del: upt_Suc)
    ultimately have "xs. f = list_succ xs  set xs = orbit f s  distinct xs  hd xs = s" by blast
  } then show ?thesis by (metis that)
qed

(* If f is cyclic on S and fixes every point outside S, then f is list_succ
   of a distinct enumeration of S.  Unfolding cyclic_on_def gives a point s
   with S = orbit f s, and obtain_orbit_list_func then finishes the proof. *)
lemma cyclic_on_obtain_list_succ:
  assumes "cyclic_on f S" "x. x  S  f x = x"
  obtains xs where "f = list_succ xs" "set xs = S" "distinct xs"
proof -
  from assms obtain s where s: "s  orbit f s" "x. x  orbit f s  f x = x"  "S = orbit f s"
    by (auto simp: cyclic_on_def)
  then show ?thesis by (metis that obtain_orbit_list_func)
qed

(* Variant of cyclic_on_obtain_list_succ with the hypothesis "f permutes S".
   Unfolding permutes_def supplies the fixpoint condition outside S. *)
lemma cyclic_on_obtain_list_succ':
  assumes "cyclic_on f S" "f permutes S"
  obtains xs where "f = list_succ xs" "set xs = S" "distinct xs"
  using assms unfolding permutes_def by (metis cyclic_on_obtain_list_succ)

(* Uniqueness of the list representation.  Under the hypotheses of
   obtain_orbit_list_func, exactly one distinct list with head s and element
   set orbit f s has f as its successor function.
   - Existence comes from obtain_orbit_list_func.
   - Uniqueness: any second such list zs has the same length as xs and agrees
     with it at every index.  This is shown by induction on the index: both
     lists start with s, and each later entry is the list_succ image of the
     previous one.
   NOTE(review): several literal fact references in this proof have lost
   their cartouche delimiters and symbols (apparently in extraction); restore
   them before checking this theory. *)
lemma list_succ_unique:
  assumes "s  orbit f s" "x. x  orbit f s  f x = x"
  shows "∃!xs. f = list_succ xs  distinct xs  hd xs = s  set xs = orbit f s"
proof -
  from assms obtain xs where xs: "f = list_succ xs" "distinct xs" "hd xs = s" "set xs = orbit f s" 
    by (rule obtain_orbit_list_func)
  moreover
  (* any other list zs with the same properties equals xs *)
  { fix zs
    assume A: "f = list_succ zs" "distinct zs" "hd zs = s" "set zs = orbit f s"
    then have "zs  []" using s  orbit f s by auto
    (* distinct enumerations of the same set have the same length *)
    from ‹distinct xs ‹distinct zs ‹set xs = orbit f s ‹set zs = orbit f s
    have len: "length xs = length zs" by (metis distinct_card)

    (* pointwise agreement, by induction on the index n *)
    { fix n assume "n < length xs"
      then have "zs ! n = xs ! n"
      proof (induct n)
        case 0 with A xs zs  [] show ?case by (simp add: hd_conv_nth nth_rotate_conv_nth)
      next
        case (Suc n)
        then have "list_succ zs (zs ! n) = list_succ xs (xs! n)"
          using f = list_succ xs f = list_succ zs by simp
        with ‹Suc n < _ show ?case
          by (simp add:list_succ_nth len ‹distinct xs ‹distinct zs)
      qed }
    then have "zs = xs" by (metis len nth_equalityI) }
  ultimately show ?thesis by metis
qed

(* For a permutation f and a distinct list as, orbits_list f as satisfies
   distincts.  The proof is by strong induction on length as.  For
   as = a # as':
   - A applies the IH to as' with the elements of orbit_list f a removed.
   - B shows that the orbit of a is disjoint from all the remaining orbits. *)
lemma distincts_orbits_list:
  assumes "distinct as" "permutation f"
  shows "distincts (orbits_list f as)"
  using assms(1)
proof (induct "length as" arbitrary: as rule: less_induct)
  case less
  show ?case
  proof (cases as)
    case Nil then show ?thesis by simp
  next
    case (Cons a as')
    let ?as' = "fold remove1 (orbit_list f a) as'"
    from Cons less.prems have A: "distincts (orbits_list f (fold remove1 (orbit_list f a) as'))"
      by (intro less) (auto simp: distinct_fold_remove1 length_fold_remove1_le less_Suc_eq_le)

    have B: "set (orbit_list f a)  (sset (orbits_list f (fold remove1 (orbit_list f a) as'))) = {}"
    proof -
      (* the remaining elements avoid the orbit of a ... *)
      have "orbit f a  set (fold remove1 (orbit_list f a) as') = {}"
        using assms less.prems Cons by (simp add: set_fold_remove1_distinct set_orbit_list')
      (* ... hence so do their whole orbits *)
      then have "orbit f a   (orbit f ` set (fold remove1 (orbit_list f a) as')) = {}"
        by auto (metis assms(2) cyclic_on_orbit disjoint_iff_not_equal permutation_self_in_orbit[OF assms(2)] orbit_cyclic_eq3 permutation_permutes)
      then show ?thesis using assms
      by (auto simp: set_orbit_list' sset_orbits_list disjoint_iff_not_equal)
    qed
    show ?thesis
      using A B assms by (auto simp: distincts_Cons Cons distinct_orbit_list')
  qed
qed

(* If xss satisfies distincts, then lists_succ xss is cyclic on every set A
   in sset xss.  The proof is by induction on xss and splits on whether A is
   the set of the head list.  The composition lemmas collected in pcp are
   instantiated with the disjointness fact inter. *)
lemma cyclic_on_lists_succ':
  assumes "distincts xss"
  shows "A  sset xss  cyclic_on (lists_succ xss) A"
  using assms
proof (induction xss arbitrary: A)
  case Nil then show ?case by auto
next
  case (Cons xs xss A)
  (* the head's set is disjoint from the sets of the remaining lists *)
  then have inter: "set xs  (ysset xss. set ys) = {}" by (auto simp: distincts_Cons)

  note pcp[OF _ _ inter] = permutes_comp_preserves_cyclic1 permutes_comp_preserves_cyclic2
  from Cons show "cyclic_on (lists_succ (xs # xss)) A"
    by (cases "A = set xs")
      (auto intro: pcp simp: cyclic_on_list_succ list_succ_permutes
        lists_succ_permutes lists_succ_Cons_pf distincts_Cons)
qed

(* Element-wise form of cyclic_on_lists_succ': lists_succ xss is cyclic on
   set xs for each xs in xss. *)
lemma cyclic_on_lists_succ:
  assumes "distincts xss"
  shows "xs. xs  set xss  cyclic_on (lists_succ xss) (set xs)"
  using assms by (auto intro: cyclic_on_lists_succ')

(* Characterisation of lists_succ xss.  Suppose xss satisfies distincts, f
   permutes the elements of all lists in xss, and the restriction of f to each
   set xs is list_succ xs.  Then f = lists_succ xss.
   The proof is by induction on xss:
   - perm_xs and perm_xss show that the restrictions of f to the head's set
     and to the remaining sets are permutations of those sets.
   - The IH identifies the latter restriction with lists_succ xss (f_xss).
   - perm_restrict_union reassembles f.
   NOTE(review): the fact reference in perm_xs has lost its closing cartouche
   delimiter (apparently in extraction). *)
lemma permutes_as_lists_succ:
  assumes "distincts xss"
  assumes ls_eq: "xs. xs  set xss  list_succ xs = perm_restrict f (set xs)"
  assumes "f permutes ((sset xss))"
  shows "f = lists_succ xss"
  using assms
proof (induct xss arbitrary: f)
  case Nil then show ?case by simp
next
  case (Cons xs xss)
  let ?sets = "λxss. ys  set xss. set ys"

  have xs: "distinct xs" "xs  []" using Cons by (auto simp: distincts_Cons)

  have f_xs: "perm_restrict f (set xs) = list_succ xs"
    using Cons by simp

  have co_xs: "cyclic_on (perm_restrict f (set xs)) (set xs)"
    unfolding f_xs using xs by (rule cyclic_on_list_succ)

  have perm_xs: "perm_restrict f (set xs) permutes set xs"
    unfolding f_xs using ‹distinct xs by (rule list_succ_permutes)

  (* removing the cyclic head part leaves a permutation of the other sets *)
  have perm_xss: "perm_restrict f (?sets xss) permutes (?sets xss)"
  proof -
    have "perm_restrict f (?sets (xs # xss) - set xs) permutes (?sets (xs # xss) - set xs)"
      using Cons co_xs by (intro perm_restrict_diff_cyclic) (auto simp: cyclic_on_perm_restrict)
    also have "?sets (xs # xss) - set xs = ?sets xss"
      using Cons by (auto simp: distincts_Cons)
    finally show ?thesis .
  qed

  have f_xss: "perm_restrict f (?sets xss) = lists_succ xss"
  proof -
    have *: "xs. xs  set xss  ((xset xss. set x)  set xs) = set xs"
      by blast
    with perm_xss Cons.prems show ?thesis
      by (intro Cons.hyps) (auto simp: distincts_Cons perm_restrict_perm_restrict *)
  qed

  from Cons.prems show "f = lists_succ (xs # xss)"
    by (simp add: lists_succ_Cons_pf distincts_Cons f_xss[symmetric]
      perm_restrict_union perm_xs perm_xss)
qed

(* Suppose css satisfies distincts, f permutes S (the union of the element
   sets of css), and f is cyclic on each set cs.  Then f = lists_succ xss for
   some xss satisfying distincts with the same element sets and heads as css.
   - Each cs is replaced by some_list cs, a list chosen with SOME that
     represents the restriction of f to set cs.  Its properties are collected
     in sl_cs.
   - permutes_as_lists_succ then identifies f.
   NOTE(review): the reference to "distincts css" below has lost its closing
   cartouche delimiter (apparently in extraction). *)
lemma cyclic_on_obtain_lists_succ:
  assumes
    permutes: "f permutes S" and
    S: "S = (sset css)" and
    dists: "distincts css" and
    cyclic: "cs. cs  set css  cyclic_on f (set cs)"
  obtains xss where "f = lists_succ xss" "distincts xss" "map set xss = map set css" "map hd xss = map hd css"
proof -
  let ?fc = "λcs. perm_restrict f (set cs)"
  define some_list where "some_list cs = (SOME xs. ?fc cs = list_succ xs  set xs = set cs  distinct xs  hd xs = hd cs)" for cs
  (* sl_cs: for each cs in css, some_list cs represents ?fc cs, has the same
     element set and head as cs, and is distinct *)
  { fix cs assume "cs  set css"
    then have "cyclic_on (?fc cs) (set cs)" "x. x  set cs  ?fc cs x = x" "hd cs  set cs"
      using cyclic dists by (auto simp add: cyclic_on_perm_restrict perm_restrict_def distincts_def)
    then have "hd cs  orbit (?fc cs) (hd cs)"  "x. x  orbit (?fc cs) (hd cs)  ?fc cs x = x" "hd cs  set cs" "set cs = orbit (?fc cs) (hd cs)"
      by (auto simp: cyclic_on_alldef)
    then have "xs. ?fc cs = list_succ xs  set xs = set cs  distinct xs  hd xs = hd cs"
      by (metis obtain_orbit_list_func)
    then have "?fc cs = list_succ (some_list cs)  set (some_list cs) = set cs  distinct (some_list cs)  hd (some_list cs) = hd cs"
      unfolding some_list_def by (rule someI_ex)
    then have "?fc cs = list_succ (some_list cs)" "set (some_list cs) = set cs" "distinct (some_list cs)" "hd (some_list cs) = hd cs"
      by auto
  } note sl_cs  = this

  have "cs. cs  set css  cs  []" using dists by (auto simp: distincts_def)
  then have some_list_ne: "cs. cs  set css  some_list cs  []"
    by (metis set_empty sl_cs(2))

  have set: "map set (map some_list css) = map set css" "map hd (map some_list css) = map hd css"
    using sl_cs(2,4) by (auto simp add: map_idI)

  have distincts: "distincts (map some_list css)"
  proof -
    have c_dist: "xs ys. xsset css; ysset css; xs  ys  set xs  set ys = {}"
      using dists by (auto simp: distincts_def)

    have "distinct (map some_list css)"
    proof -
      (* distinct members of css have disjoint non-empty sets, so some_list
         is injective on set css *)
      have "inj_on some_list (set css)"
        using sl_cs(2) c_dist by (intro inj_onI) (metis inf.idem set_empty) 
      with ‹distincts css show ?thesis
        by (auto simp: distincts_distinct distinct_map)
    qed
    moreover
    have "xsset (map some_list css). distinct xs  xs  []"
      using sl_cs(3) some_list_ne by auto
    moreover
    from c_dist have "(xsset (map some_list css). ysset (map some_list css). xs  ys  set xs  set ys = {})"
      using sl_cs(2) by auto
    ultimately
    show ?thesis by (simp add: distincts_def)
  qed

  have f: "f = lists_succ (map some_list css)"
    using distincts
  proof (rule permutes_as_lists_succ)
    fix xs assume "xs  set (map some_list css)"
    then show "list_succ xs = perm_restrict f (set xs)"
      using sl_cs(1) sl_cs(2) by auto
  next
    have "S = (xsset (map some_list css). set xs)"
      using S sl_cs(2) by auto
    with permutes show "f permutes (sset (map some_list css))"
      by simp
  qed

  from f distincts set  show ?thesis ..
qed


subsection ‹Permutations of a List›

(* Removing an element that occurs in xs strictly decreases the length of
   xs.  The lemma is enabled as a local simp rule when permutations is
   defined. *)
lemma length_remove1_less:
  assumes "x  set xs" shows "length (remove1 x xs) < length xs"
proof -
  from assms have "0 < length xs" by auto
  with assms show ?thesis by (auto simp: length_remove1)
qed
(* All permutations of a list.  Each element y of xs is put in front of every
   permutation of remove1 y xs.  length_remove1_less is added as a local simp
   rule, presumably so that the automatic termination proof of fun can show
   that the recursive argument shrinks. *)
context notes [simp] = length_remove1_less begin
fun permutations :: "'a list  'a list list" where
  permutations_Nil: "permutations [] = [[]]"
| permutations_Cons:
    "permutations xs = [y # ys. y <- xs, ys <- permutations (remove1 y xs)]"
end

(* Keep the recursive equation out of the simp set; later proofs add
   permutations_Cons to simp explicitly where they need it. *)
declare permutations_Cons[simp del]

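text ‹
  The function @{term permutations} above returns all permutations of a list. The function
  below keeps the first element of the list fixed and permutes only the remaining elements;
  thus it computes only those permutations which yield distinct cyclic permutation functions
  (cf. @{term list_succ}).
›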

(* Permutations of a list that keep its first element in front.  The empty
   list has [] as its only such permutation. *)
fun cyc_permutations :: "'a list  'a list list" where
  "cyc_permutations [] = [[]]"
| "cyc_permutations (x # xs) = map (Cons x) (permutations xs)"



(* The empty list occurs among the permutations of xs exactly when xs is
   empty. *)
lemma nil_in_permutations[simp]: "[]  set (permutations xs)  xs = []"
  by (induct xs) (auto simp: permutations_Cons)

(* Recursive equation of permutations for a non-empty xs, with the list
   comprehension written out using concat and map. *)
lemma permutations_not_nil:
  assumes "xs  []"
  shows "permutations xs = concat (map (λx. map ((#) x) (permutations (remove1 x xs))) xs)"
  using assms by (cases xs) (auto simp: permutations_Cons)

(* Set version of the recursion.  For a non-empty xs, the permutations are
   obtained by prepending each x in xs to the permutations of remove1 x xs. *)
lemma set_permutations_step:
  assumes "xs  []"
  shows "set (permutations xs) = (x  set xs. Cons x ` set (permutations (remove1 x xs)))"
  using assms by (cases xs) (auto simp: permutations_Cons)

(* For a distinct xs, the permutations of xs are exactly the distinct lists ys
   with set ys = set xs.  The proof is by induction on length xs.  Both
   directions split ys into head y and tail ys', then apply the IH to
   remove1 y xs.
   NOTE(review): several fact references in this proof have lost their
   cartouche delimiters and symbols (apparently in extraction). *)
lemma in_set_permutations:
  assumes "distinct xs"
  shows "ys  set (permutations xs)  distinct ys  set xs = set ys" (is "?L xs ys  ?R xs ys")
  using assms
proof (induct "length xs" arbitrary: xs ys)
  case 0 then show ?case by auto
next
  case (Suc n)
  then have "xs  []" by auto

  show ?case
  proof
    (* left-to-right: unfold one recursion step, then use the IH on the tail *)
    assume "?L xs ys"
    then obtain y ys' where "ys = y # ys'" "y  set xs" "ys'  set (permutations (remove1 (hd ys) xs))"
      using xs  [] by (auto simp: permutations_not_nil)
    moreover
    then have "?R (remove1 y xs) ys'"
      using Suc.prems Suc.hyps(2) by (intro Suc.hyps(1)[THEN iffD1]) (auto simp: length_remove1)
    ultimately show "?R xs ys"
      using Suc by auto
  next
    (* right-to-left: the tail is a permutation by the IH, then fold back *)
    assume "?R xs ys"
    with xs  [] obtain y ys' where "ys = y # ys'" "y  set xs" by (cases ys) auto
    moreover
    then have "ys'  set (permutations (remove1 y xs))"
      using Suc ?R xs ys by (intro Suc.hyps(1)[THEN iffD2]) (auto simp: length_remove1)
    ultimately
    show "?L xs ys"
      using xs  [] by (auto simp: permutations_not_nil)
  qed
qed

(* For a distinct xs, cyc_permutations xs consists exactly of the distinct
   lists that have the same elements and the same head as xs.  The Nil case
   is closed by auto; the Cons case reduces to in_set_permutations. *)
lemma in_set_cyc_permutations:
  assumes "distinct xs"
  shows "ys  set (cyc_permutations xs)  distinct ys  set xs = set ys  hd ys = hd xs" (is "?L xs ys  ?R xs ys")
proof (cases xs)
  case (Cons x xs) with assms show ?thesis
    by (cases ys) (auto simp: in_set_permutations intro!: imageI)
qed auto

(* Let ys be a distinct list with the same elements as a distinct xs.  Then
   some rotation of ys lies in cyc_permutations xs.  For xs = x # xs',
   rotating ys by index ys x brings x to the front.
   NOTE(review): the reference to "distinct xs" below has lost its closing
   cartouche delimiter (apparently in extraction). *)
lemma in_set_cyc_permutations_obtain:
  assumes "distinct xs" "distinct ys" "set xs = set ys"
  obtains n where "rotate n ys  set (cyc_permutations xs)"
proof (cases xs)
  case Nil with assms have "rotate 0 ys  set (cyc_permutations xs)" by auto
  then show ?thesis ..
next
  case (Cons x xs')
  let ?ys' = "rotate (index ys x) ys"
  have "ys  []" "x  set ys"
    using Cons assms by auto
  then have "distinct ?ys'  set xs = set ?ys'  hd ?ys' = hd xs"
    using assms Cons by (auto simp add: hd_rotate_conv_nth)
  with ‹distinct xs have "?ys'  set (cyc_permutations xs)"
    by (rule in_set_cyc_permutations[THEN iffD2])
  then show ?thesis ..
qed

(* For a non-empty distinct xs, the successor functions of the lists in
   cyc_permutations xs are exactly the permutations of set xs that are cyclic
   on set xs.  In the right-to-left direction, f is written as list_succ ys
   and ys is rotated into cyc_permutations xs.  The closing simp steps rely
   on list_succ being unchanged by rotation (presumably via a simp rule about
   list_succ and rotate; TODO confirm).
   NOTE(review): the reference to "distinct xs" below has lost its closing
   cartouche delimiter (apparently in extraction). *)
lemma list_succ_set_cyc_permutations:
  assumes "distinct xs" "xs  []"
  shows "list_succ ` set (cyc_permutations xs) = {f. f permutes set xs  cyclic_on f (set xs)}" (is "?L = ?R")
proof (intro set_eqI iffI)
  fix f assume "f  ?L"
  moreover have "ys. set xs = set ys  xs  []  ys  []" by auto
  ultimately show "f  ?R"
    using assms by (auto simp: in_set_cyc_permutations list_succ_permutes cyclic_on_list_succ)
next
  fix f assume "f  ?R"
  then obtain ys where ys: "list_succ ys = f" "distinct ys" "set ys = set xs"
    by (auto elim: cyclic_on_obtain_list_succ')
  moreover
  with ‹distinct xs obtain n where "rotate n ys  set (cyc_permutations xs)"
    by (auto elim: in_set_cyc_permutations_obtain)
  then have "list_succ (rotate n ys)  ?L" by simp
  ultimately
  show "f  ?L" by simp
qed


subsection ‹Enumerating Permutations from List Orbits›

(* Lifts cyc_permutations to a list of lists.  product_lists enumerates every
   way of choosing one element of cyc_permutations for each component list. *)
definition cyc_permutationss :: "'a list list  'a list list list" where
  "cyc_permutationss = product_lists o map cyc_permutations"

(* Base case: the empty list of lists has a single combination, the empty
   one.  After unfolding the definition, simp evaluates the composition,
   map on [] and product_lists on []. *)
lemma cyc_permutationss_Nil[simp]: "cyc_permutationss [] = [[]]"
  unfolding cyc_permutationss_def by simp

(* For xss satisfying distincts: yss belongs to cyc_permutationss xss iff yss
   satisfies distincts and has, component by component, the same element sets
   and heads as xss.
   - Fact X derives the right-hand side from a list_all2 condition on
     yss and map cyc_permutations xss.
   - Fact Y proves the converse.
   - product_lists_set is used to relate that list_all2 condition to
     membership in cyc_permutationss. *)
lemma in_set_cyc_permutationss:
  assumes "distincts xss"
  shows "yss  set (cyc_permutationss xss)  distincts yss  map set xss = map set yss  map hd xss = map hd yss"
proof -
  { assume A: "list_all2 (λx ys. x  set ys) yss (map cyc_permutations xss)"
    then have "length yss = length xss" by (auto simp: list_all2_lengthD)
    then have "(sset xss) = (sset yss)" "distincts yss" "map set xss = map set yss" "map hd xss = map hd yss"
      using A assms
      by (induct yss xss rule: list_induct2) (auto simp: distincts_Cons in_set_cyc_permutations)
  } note X = this
  { assume A: "distincts yss" "map set xss = map set yss" "map hd xss = map hd yss"
    then have "length yss = length xss" by (auto dest: map_eq_imp_length_eq)
    then have "list_all2 (λx ys. x  set ys) yss (map cyc_permutations xss)"
      using A assms
      by (induct yss xss rule: list_induct2) (auto simp: distincts_Cons in_set_cyc_permutations)
  } note Y = this
  show "?thesis"
    unfolding cyc_permutationss_def
    by (auto simp: product_lists_set intro: X Y)
qed

lemma lists_succ_set_cyc_permutationss:
  assumes "distincts xss"
  shows "lists_succ ` set (cyc_permutationss xss) = {f. f permutes (sset xss)  (c  sset xss. cyclic_on f c)}" (is "?L = ?R")
  using assms
proof (intro set_eqI iffI)
  fix f assume "f  ?L"
  then obtain yss where "yss  set (cyc_permutationss xss)" "f = lists_succ yss" by (rule imageE)
  moreover
  from yss  _ assms have "set (map set xss) = set (map set yss)"
    by (auto simp: in_set_cyc_permutationss)
  then have "sset xss = sset