Session FO_Theory_Rewriting

<div class="head">

Theory Utils

theory Utils
  imports Regular_Tree_Relations.Term_Context
    Regular_Tree_Relations.FSet_Utils
begin

subsection ‹Misc›

(* Signature of a rewrite system ℛ: the union of funas_term over the left- and
   right-hand side of every rule (s, t) in ℛ. *)
definition "funas_trs ℛ = ⋃ ((λ (s, t). funas_term s ∪ funas_term t) ` ℛ)"

(* Linearity of a term: a variable is always linear; a function application is
   linear if the variable sets of its arguments satisfy is_partition and every
   argument is itself linear. *)
fun linear_term :: "('f, 'v) term ⇒ bool" where
  "linear_term (Var _) = True" |
  "linear_term (Fun _ ts) = (is_partition (map vars_term ts) ∧ (∀t∈set ts. linear_term t))"

(* Variables of a term as a list: a variable yields the singleton list, a
   function application the concatenation of the lists of its arguments
   (argument order is kept, repeated occurrences are kept). *)
fun vars_term_list :: "('f, 'v) term ⇒ 'v list" where
  "vars_term_list (Var x) = [x]" |
  "vars_term_list (Fun _ ts) = concat (map vars_term_list ts)"

(* Set of positions of variable occurrences: the root position [] for a
   variable; for Fun f ts every position i # p with i < length ts and p a
   variable position of the i-th argument. *)
fun varposs :: "('f, 'v) term ⇒ pos set" where
  "varposs (Var x) = {[]}" |
  "varposs (Fun f ts) = (⋃i<length ts. {i # p | p. p ∈ varposs (ts ! i)})"

(* For every argument t at index i of ts, prefix each element of the list (f t)
   with i; the result is one list per argument (a list of lists). *)
abbreviation "poss_args f ts ≡ map2 (λ i t. map ((#) i) (f t)) ([0 ..< length ts]) ts"

(* List-valued counterpart of varposs: [[]] for a variable, and for Fun f ts
   the concatenation of the argument-wise position lists, each prefixed with
   the argument index (via poss_args). *)
fun varposs_list :: "('f, 'v) term ⇒ pos list" where
  "varposs_list (Var x) = [[]]" |
  "varposs_list (Fun f ts) = concat (poss_args varposs_list ts)"

(* Splits a flat index into a list of lists into an (outer, inner) index pair.
   The first component of the argument pair counts the inner lists already
   skipped, the second is the remaining offset: if the offset falls into the
   head list x the current pair is returned, otherwise length x is subtracted
   and the search continues in the tail with the outer counter incremented.
   NOTE: there is no equation for the empty list, so the result is unspecified
   when the offset is not covered by the lists. *)
fun concat_index_split where
  "concat_index_split (o_idx, i_idx) (x # xs) =
     (if i_idx < length x
      then (o_idx, i_idx)
      else concat_index_split (Suc o_idx, i_idx - length x) xs)"

(* Inductively defined relation on lists built from ℛ:
   - base: two lists of equal length whose elements are pointwise related by ℛ;
   - list_trancl: if (xs, ys) is in the relation and one entry ys ! i is
     related by ℛ to z, then xs is also related to ys with entry i replaced
     by z. *)
inductive_set trancl_list for ℛ where
  base[intro, Pure.intro] : "length xs = length ys ⟹
      (∀ i < length ys. (xs ! i, ys ! i) ∈ ℛ) ⟹ (xs, ys) ∈ trancl_list ℛ"
| list_trancl [Pure.intro]: "(xs, ys) ∈ trancl_list ℛ ⟹ i < length ys ⟹ (ys ! i, z) ∈ ℛ ⟹
    (xs, ys[i := z]) ∈ trancl_list ℛ"


(* Appending an element y that is an upper bound of all elements of a sorted
   list keeps the list sorted. Proved by induction on the list. *)
lemma sorted_append_bigger:
  "sorted xs ⟹  ∀x ∈ set xs. x ≤ y ⟹ sorted (xs @ [y])"
proof (induct xs)
  case Nil
  then show ?case by simp
next
  case (Cons x xs)
  (* the tail of a sorted list is sorted, and y still bounds the tail *)
  then have s: "sorted xs" by (cases xs) simp_all
  from Cons have a: "∀x∈set xs. x ≤ y" by simp
  from Cons(1)[OF s a] Cons(2-) show ?case by (cases xs) simp_all
qed

(* Destruction rules for List.find: a found element satisfies the predicate
   and is a member of the searched list. *)
lemma find_SomeD:
  "List.find P xs = Some x ⟹ P x"
  "List.find P xs = Some x ⟹ x∈set xs"
  by (auto simp add: find_Some_iff)

(* Summing n copies of Suc 0 gives n (declared as a simp rule). *)
lemma sum_list_replicate_length' [simp]:
  "sum_list (replicate n (Suc 0)) = n"
  by (induct n) simp_all

(* Every direct argument of a function application is a subterm of it
   (declared as a simp rule). *)
lemma arg_subteq [simp]:
  assumes "t ∈ set ts" shows "Fun f ts ⊵ t"
  using assms by auto

(* The set funas_term s of a single term is finite; by structural induction. *)
lemma finite_funas_term: "finite (funas_term s)"
  by (induct s) auto

(* A finite rewrite system has a finite signature; by induction over the
   finiteness derivation, using finiteness of funas_term. *)
lemma finite_funas_trs:
  "finite ℛ ⟹ finite (funas_trs ℛ)"
  by (induct rule: finite.induct) (auto simp: finite_funas_term funas_trs_def)

(* Set of all subterms of a term, including the term itself: a variable has
   only itself; a function application has itself plus all subterms of its
   arguments. *)
fun subterms where
  "subterms (Var x) = {Var x}"|
  "subterms (Fun f ts) = {Fun f ts} ∪ (⋃ (subterms ` set ts))"

(* The set computed by subterms is finite; by structural induction. *)
lemma finite_subterms_fun: "finite (subterms s)"
  by (induct s) auto

(* Membership in subterms s coincides with the subterm relation s ⊵ t. *)
lemma subterms_supteq_conv: "t ∈ subterms s ⟷ s ⊵ t"
  by (induct s) (auto elim: supteq.cases)

(* Set-level form of subterms_supteq_conv: subterms s is exactly the set of
   all t with s ⊵ t. *)
lemma set_all_subteq_subterms:
  "subterms s = {t. s ⊵ t}"
  using subterms_supteq_conv by auto

(* A term has finitely many subterms; obtained by rewriting the set to
   subterms s and using finite_subterms_fun. *)
lemma finite_subterms: "finite {t. s ⊵ t}"
  unfolding set_all_subteq_subterms[symmetric]
  by (simp add: finite_subterms_fun)

(* A term has finitely many proper subterms: they form a subset of the
   (finite) set of all subterms. *)
lemma finite_strict_subterms: "finite {t. s ⊳ t}"
  by (intro finite_subset[OF _ finite_subterms]) auto

(* The union of a finite family of finite sets is finite. *)
lemma finite_UN_I2:
  "finite A ⟹ (∀ B ∈ A. finite B) ⟹ finite (⋃ A)"
  by blast

(* The function symbols of a term are exactly the roots of its subterms:
   taking root of every subterm, dropping None and unwrapping the remaining
   options yields funas_term s. Both inclusions are shown separately by
   structural induction on s. *)
lemma root_substerms_funas_term:
  "the ` (root ` (subterms s) - {None}) = funas_term s" (is "?Ls = ?Rs")
proof -
  (* inclusion from left to right *)
  {fix x assume "x ∈ ?Ls" then have "x ∈ ?Rs"
    proof (induct s arbitrary: x)
      case (Fun f ts)
      then show ?case
        by auto (metis DiffI Fun.hyps imageI option.distinct(1) singletonD)
    qed auto}
  moreover
  (* inclusion from right to left *)
  {fix g assume "g ∈ ?Rs" then have "g ∈ ?Ls"
    proof (induct s arbitrary: g)
      case (Fun f ts)
      (* g is either the root symbol itself or occurs in some argument *)
      from Fun(2) consider "g = (f, length ts)" | "∃ t ∈ set ts. g ∈ funas_term t"
        by (force simp: in_set_conv_nth)
      then show ?case
      proof cases
        case 1 then show ?thesis
          by (auto simp: image_iff intro: bexI[of _ "Some (f, length ts)"])
      next
        case 2
        then obtain t where wit: "t ∈ set ts" "g ∈ funas_term t" by blast
        have "g ∈ the ` (root ` subterms t - {None})" using Fun(1)[OF wit] .
        then show ?thesis using wit(1)
          by (auto simp: image_iff)
      qed
    qed auto}
  ultimately show ?thesis by auto
qed

(* Lifting of root_substerms_funas_term to a set R of terms: the roots of all
   subterms of members of R give the union of funas_term over R. *)
lemma root_substerms_funas_term_set:
  "the ` (root ` ⋃ (subterms ` R) - {None}) = ⋃ (funas_term ` R)"
  using root_substerms_funas_term
  by auto (smt DiffE DiffI UN_I image_iff)


(* If the variable sets of the terms ts satisfy is_partition, then an indexed
   family τ of substitutions can be merged into a single σ that agrees with
   τ i on the variables of the i-th term. The witness is built with
   fun_merge. *)
lemma subst_merge:
  assumes part: "is_partition (map vars_term ts)"
  shows "∃σ. ∀i<length ts. ∀x∈vars_term (ts ! i). σ x = τ i x"
proof -
  let ?τ = "map τ [0 ..< length ts]"
  let ?σ = "fun_merge ?τ (map vars_term ts)"
  show ?thesis
    by (rule exI[of _ ?σ], intro allI impI ballI,
      insert fun_merge_part[OF part, of _ _ ?τ], auto)
qed


(* If composing R with itself is empty, the transitive closure adds nothing:
   R⇧+ is R itself. *)
lemma rel_comp_empty_trancl_simp: "R O R = {} ⟹ R⇧+ = R"
  by (metis O_assoc relcomp_empty2 sup_bot_right trancl_unfold trancl_unfold_right)

(* Bounded choice on natural numbers: if for every index below n some witness
   exists, there is a single function picking a witness for each such index.
   The bounded statement is first turned into an unbounded implication so that
   the general choice rule applies. *)
lemma choice_nat:
  assumes "∀i<n. ∃x. P x i"
  shows "∃f. ∀x<n. P (f x) x"
proof -
  from assms have "∀ i. ∃ x. i < n ⟶ P x i" by simp
  from choice[OF this] show ?thesis by auto
qed


(* All entries of a list (addressed by index) lie in T iff the set of the
   list is a subset of T. *)
lemma subseteq_set_conv_nth:
  "(∀ i < length ss. ss ! i ∈ T) ⟷ set ss ⊆ T"
  by (metis in_set_conv_nth subset_code(1))

(* The transitive closure of a relation consisting of a single pair is that
   relation itself (declared as a simp rule). *)
lemma singelton_trancl [simp]: "{a}⇧+ = {a}"
  using tranclD tranclD2 by fastforce

(* Versions of set-level lemmas obtained through the transfer package with the
   fset.lifting setup included (presumably their finite-set counterparts, as
   the f-prefixed names suggest). *)
context
includes fset.lifting
begin
lemmas frelcomp_empty_ftrancl_simp = rel_comp_empty_trancl_simp [Transfer.transferred]
lemmas in_fset_idx = in_set_idx [Transfer.transferred]
lemmas fsubseteq_fset_conv_nth = subseteq_set_conv_nth [Transfer.transferred]
lemmas singelton_ftrancl [simp] = singelton_trancl [Transfer.transferred]
end

(* An element of the first i entries of xs occurs at some index j that is
   below both i and length xs. *)
lemma set_take_nth:
  assumes "x ∈ set (take i xs)"
  shows "∃ j < length xs. j < i ∧ xs ! j = x" using assms
  by (metis in_set_conv_nth length_take min_less_iff_conj nth_take)

(* Two lists of equal length with pointwise equal entries have the same sum
   (they are in fact equal by nth_equalityI). *)
lemma nth_sum_listI:
  assumes "length xs = length ys"
    and "∀ i < length xs. xs ! i = ys ! i"
  shows "sum_list xs = sum_list ys"
  using assms nth_equalityI by blast

(* A valid (outer, inner) index pair into a list of lists corresponds to a
   valid flat index into the concatenation: total length of the first i lists
   plus the inner offset j. *)
lemma concat_nth_length:
  "i < length uss ⟹ j < length (uss ! i) ⟹
    sum_list (map length (take i uss)) + j < length (concat uss)"
by (induct uss arbitrary: i j) (simp, case_tac i, auto)

(* Elimination rule: a list of naturals summing to Suc 0 has exactly one entry
   equal to Suc 0 and all other entries equal to 0. The existence statement is
   proved by induction on the list with a case split on the head. *)
lemma sum_list_1_E [elim]:
  assumes "sum_list xs = Suc 0"
  obtains i where "i < length xs" "xs ! i = Suc 0" "∀ j < length xs. j ≠ i ⟶ xs ! j = 0"
proof -
  have "∃ i < length xs. xs ! i = Suc 0 ∧ (∀ j < length xs. j ≠ i ⟶ xs ! j = 0)" using assms
  proof (induct xs)
    case (Cons a xs) show ?case
    proof (cases a)
      (* head is 0: the unique 1 lives in the tail, shift its index by one *)
      case [simp]: 0
      obtain i where "i < length xs" "xs ! i = Suc 0" "(∀ j < length xs. j ≠ i ⟶ xs ! j = 0)"
        using Cons by auto
      then show ?thesis using less_Suc_eq_0_disj
        by (intro exI[of _ "Suc i"]) auto
    next
      (* head is a successor: it is the unique 1 and the tail sums to 0 *)
      case (Suc nat) then show ?thesis using Cons by auto
    qed
  qed auto
  then show " (⋀i. i < length xs ⟹ xs ! i = Suc 0 ⟹ ∀j<length xs. j ≠ i ⟶ xs ! j = 0 ⟹ thesis) ⟹ thesis"
    by blast
qed


(* Elimination form of list equality: from xs = ys one may use equal lengths
   and pointwise equality of the entries. *)
lemma nth_equalityE:
  "xs = ys ⟹ (length xs = length ys ⟹ (⋀i. i < length xs ⟹ xs ! i = ys ! i) ⟹ P) ⟹ P"
  by simp

(* Prefixing every list in a distinct list with the same element i keeps the
   list distinct. *)
lemma map_cons_presv_distinct:
  "distinct t ⟹ distinct (map ((#) i) t)"
  by (simp add: distinct_conv_nth)

(* If two lists of lists have the same shape (same outer length and same inner
   lengths) and P relates all corresponding inner entries, then P relates all
   corresponding entries of the two concatenations. *)
lemma concat_nth_nthI:
  assumes "length ss = length ts" "∀ i < length ts. length (ss ! i) = length (ts ! i)"
    and "∀ i < length ts. ∀ j < length (ts ! i). P (ss ! i ! j) (ts ! i ! j)"
  shows "∀ i < length (concat ts). P (concat ss ! i) (concat ts ! i)"
  using assms by (metis nth_concat_two_lists)


(* A valid index that is not below length ts - 1 addresses the last element. *)
lemma last_nthI:
  assumes "i < length ts" "¬ i < length ts - Suc 0"
  shows "ts ! i = last ts" using assms
  by (induct ts arbitrary: i)
    (auto, metis last_conv_nth length_0_conv less_antisym nth_Cons')

(* induction scheme for transitive closures of lists *)
(* trancl_list is closed under consing a pair related by ℛ in front of both
   lists. Proved by induction on the derivation of (xs, ys) ∈ trancl_list ℛ. *)
lemma trancl_list_appendI [simp, intro]:
  "(xs, ys) ∈ trancl_list ℛ ⟹ (x, y) ∈ ℛ ⟹ (x # xs, y # ys) ∈ trancl_list ℛ"
proof (induct rule: trancl_list.induct)
  case (base xs ys)
  then show ?case using less_Suc_eq_0_disj
    by (intro trancl_list.base) auto
next
  case (list_trancl xs ys i z)
  (* an update at index i of the tail is an update at index Suc i of the consed list *)
  from list_trancl(3) have *: "y # ys[i := z] = (y # ys)[Suc i := z]" by auto
  show ?case using list_trancl unfolding *
    by (intro trancl_list.list_trancl) auto
qed

(* Strengthening of trancl_list_appendI: the new head pair only needs to be in
   the transitive closure ℛ⇧+. Proved by induction on the closure; each extra
   ℛ-step on the head is replayed as a list_trancl step at index 0. *)
lemma trancl_list_append_tranclI [intro]:
  "(x, y) ∈ ℛ⇧+ ⟹ (xs, ys) ∈ trancl_list ℛ ⟹ (x # xs, y # ys) ∈ trancl_list ℛ"
proof (induct rule: trancl.induct)
  case (trancl_into_trancl a b c)
  then have "(a # xs, b # ys) ∈ trancl_list ℛ" by auto
  from trancl_list.list_trancl[OF this, of 0 c]
  show ?case using trancl_into_trancl(3)
    by auto
qed auto

(* Characterisation of trancl_list: two lists are related iff they have the
   same length and are pointwise related by the transitive closure ℛ⇧+. *)
lemma trancl_list_conv:
  "(xs, ys) ∈ trancl_list ℛ ⟷ length xs = length ys ∧ (∀ i < length ys. (xs ! i, ys ! i) ∈ ℛ⇧+)" (is "?Ls ⟷ ?Rs")
proof
  (* left to right: induction on the derivation *)
  assume "?Ls" then show ?Rs
  proof (induct)
    case (list_trancl xs ys i z)
    then show ?case
      by auto (metis nth_list_update trancl.trancl_into_trancl)
  qed auto
next
  (* right to left: induction on ys, peeling off the heads of both lists *)
  assume ?Rs then show ?Ls
  proof (induct ys arbitrary: xs)
    case Nil
    then show ?case by (cases xs) auto
  next
    case (Cons y ys)
    from Cons(2) obtain x xs' where *: "xs = x # xs'" and
      inv: "(x, y) ∈ ℛ⇧+"
      by (cases xs) auto
    show ?case using Cons(1)[of "tl xs"] Cons(2) unfolding *
      by (intro trancl_list_append_tranclI[OF inv]) force
  qed
qed

(* Induction principle for lists that are pointwise related by ℛ⇧+.
   The first two premises are consumed facts; the remaining ones are the cases:
   - base: P holds for lists of equal length pointwise related by ℛ;
   - step: P is preserved when one entry of the right list takes a further
     ℛ-step.
   Derived from trancl_list.induct via trancl_list_conv. *)
lemma trancl_list_induct [consumes 2, case_names base step]:
  assumes "length ss = length ts" "∀ i < length ts. (ss ! i, ts ! i) ∈ ℛ⇧+"
   and "⋀xs ys. length xs = length ys ⟹ ∀ i < length ys. (xs ! i, ys ! i) ∈ ℛ ⟹ P xs ys"
   and "⋀xs ys i z. length xs = length ys ⟹ ∀ i < length ys. (xs ! i, ys ! i) ∈ ℛ⇧+ ⟹ P xs ys
      ⟹ i < length ys ⟹ (ys ! i, z) ∈ ℛ ⟹ P xs (ys[i := z])"
 shows "P ss ts" using assms
  by (intro trancl_list.induct[of ss ts ℛ P]) (auto simp: trancl_list_conv)


(* Swapping the components of all pairs commutes with transitive closure.
   The image under prod.swap is first rewritten to the converse relation, for
   which the library fact trancl_converse applies. *)
lemma swap_trancl:
  "(prod.swap ` R)⇧+ = prod.swap ` (R⇧+)"
proof -
  have [simp]: "prod.swap ` X = X¯" for X by auto
  show ?thesis by (simp add: trancl_converse)
qed

(* Swapping the components of all pairs commutes with reflexive transitive
   closure; same argument as swap_trancl, using rtrancl_converse. *)
lemma swap_rtrancl:
  "(prod.swap ` R)⇧* = prod.swap ` (R⇧*)"
proof -
  have [simp]: "prod.swap ` X = X¯" for X by auto
  show ?thesis by (simp add: rtrancl_converse)
qed

(* Simplification facts for relations that already live inside X × X:
   (1) restricting the transitive closure to X changes nothing,
   (2)/(3) composing with the identity restricted to X changes nothing,
   (4) restricting a composition of two such relations changes nothing,
   (5) the transitive closure stays inside X × X. *)
lemma Restr_simps:
  "R ⊆ X × X ⟹ Restr (R⇧+) X = R⇧+"
  "R ⊆ X × X ⟹ Restr Id X O R = R"
  "R ⊆ X × X ⟹ R O Restr Id X = R"
  "R ⊆ X × X ⟹ S ⊆ X × X ⟹ Restr (R O S) X = R O S"
  "R ⊆ X × X ⟹ R⇧+ ⊆ X × X"
  subgoal using trancl_mono_set[of R "X × X"] by (auto simp: trancl_full_on)
  subgoal by auto
  subgoal by auto
  subgoal by auto
  subgoal using trancl_subset_Sigma .
  done

(* Compositions involving transitive closures of relations inside X × X stay
   inside X × X; follows from the fifth fact of Restr_simps. *)
lemma Restr_tracl_comp_simps:
  "ℛ ⊆ X × X ⟹ ℒ ⊆ X × X ⟹ ℒ⇧+ O ℛ ⊆ X × X"
  "ℛ ⊆ X × X ⟹ ℒ ⊆ X × X ⟹ ℒ O ℛ⇧+ ⊆ X × X"
  "ℛ ⊆ X × X ⟹ ℒ ⊆ X × X ⟹ ℒ⇧+ O ℛ O ℒ⇧+ ⊆ X × X"
  by (auto dest: subsetD[OF Restr_simps(5)[of ℒ]] subsetD[OF Restr_simps(5)[of ℛ]])


text ‹Conversions of the nth function between lists and a splitting of the list into lists of lists›

(* For a valid flat index, the outer index returned by concat_index_split is
   never smaller than the outer counter it was started with. *)
lemma concat_index_split_mono_first_arg:
  "i < length (concat xs) ⟹ o_idx ≤ fst (concat_index_split (o_idx, i) xs)"
  by (induct xs arbitrary: o_idx i) (auto, metis Suc_leD add_diff_inverse_nat nat_add_left_cancel_less)

(* Upper bound on the returned outer index, generalised over the start
   counter: it stays below length xs + o_idx. *)
lemma concat_index_split_sound_fst_arg_aux:
  "i < length (concat xs) ⟹ fst (concat_index_split (o_idx, i) xs) < length xs + o_idx"
  by (induct xs arbitrary: o_idx i) (auto, metis add_Suc_right add_diff_inverse_nat nat_add_left_cancel_less)

(* Started with counter 0, the returned outer index is a valid index into xs. *)
lemma concat_index_split_sound_fst_arg:
  "i < length (concat xs) ⟹ fst (concat_index_split (0, i) xs) < length xs"
  using concat_index_split_sound_fst_arg_aux[of i xs 0] by auto

(* The returned inner index is a valid index into the inner list selected by
   the returned outer index (shifted back by the start counter n).
   Induction on xs with a case split on whether i falls into the head list. *)
lemma concat_index_split_sound_snd_arg_aux:
  assumes "i < length (concat xs)"
  shows "snd (concat_index_split (n, i) xs) < length (xs ! (fst (concat_index_split (n, i) xs) - n))" using assms
proof (induct xs arbitrary: i n)
  case (Cons x xs)
  show ?case proof (cases "i < length x")
    case False then have size: "i - length x < length (concat xs)"
      using Cons(2) False by auto
    obtain k j where [simp]: "concat_index_split (Suc n, i - length x) xs = (k, j)"
      using old.prod.exhaust by blast
    show ?thesis using False Cons(1)[OF size, of "Suc n"] concat_index_split_mono_first_arg[OF size, of "Suc n"]
      by (auto simp: nth_append)
  qed (auto simp add: nth_append) 
qed auto

(* Instance of the auxiliary lemma for start counter 0: the returned inner
   index is valid for the inner list at the returned outer index. *)
lemma concat_index_split_sound_snd_arg:
  assumes "i < length (concat xs)"
  shows "snd (concat_index_split (0, i) xs) < length (xs ! fst (concat_index_split (0, i) xs))"
  using concat_index_split_sound_snd_arg_aux[OF assms, of 0] by auto

(* The flat index can be reconstructed from the result (m, j) of
   concat_index_split: it is the total length of the first m - n inner lists
   plus j. Induction on xs; in the step case the head length is split off the
   sum. *)
lemma reconstr_1d_concat_index_split:
  assumes "i < length (concat xs)"
  shows "i = (λ (m, j). sum_list (map length (take (m - n) xs)) + j) (concat_index_split (n, i) xs)" using assms
proof (induct xs arbitrary: i n)
  case (Cons x xs) show ?case
  proof (cases "i < length x")
    case False
    obtain m k where res: "concat_index_split (Suc n, i - length x) xs = (m, k)"
      using prod_decode_aux.cases by blast
    then have unf_ind: "concat_index_split (n, i) (x # xs) = concat_index_split (Suc n, i - length x) xs" and
      size: "i - length x < length (concat xs)" using Cons(2) False by auto
    (* the recursive call can only increase the outer counter *)
    have "Suc n ≤ m" using concat_index_split_mono_first_arg[OF size, of "Suc n"] by (auto simp: res)
    then have [simp]: "sum_list (map length (take (m - n) (x # xs))) = sum_list (map length (take (m - Suc n) xs)) + length x"
      by (simp add: take_Cons')
    show ?thesis using Cons(1)[OF size, of "Suc n"] False unfolding unf_ind res by auto
  qed auto
qed auto

(* Appending further inner lists does not change the result for an index that
   is already covered by xs (declared as a simp rule). *)
lemma concat_index_split_larger_lists [simp]:
  assumes "i < length (concat xs)"
  shows "concat_index_split (n, i) (xs @ ys) = concat_index_split (n, i) xs" using assms
  by (induct xs arbitrary: ys n i) auto

(* Soundness of concat_index_split, generalised over the start counter n:
   the flat lookup in concat xs equals the two-level lookup with the returned
   pair (outer index shifted back by n). *)
lemma concat_index_split_split_sound_aux:
  assumes "i < length (concat xs)"
  shows "concat xs ! i = (λ (k, j). xs ! (k - n) ! j) (concat_index_split (n, i) xs)" using assms
proof (induct xs arbitrary: i n)
  case (Cons x xs)
  show ?case proof (cases "i < length x")
    case False then have size: "i - length x < length (concat xs)"
      using Cons(2) False by auto
    obtain k j where [simp]: "concat_index_split (Suc n, i - length x) xs = (k, j)"
      using prod_decode_aux.cases by blast
    show ?thesis using False Cons(1)[OF size, of "Suc n"]
      using concat_index_split_mono_first_arg[OF size, of "Suc n"]
      by (auto simp: nth_append)
  qed (auto simp add: nth_append) 
qed auto

(* Soundness for start counter 0: the flat lookup equals the two-level lookup
   at the pair returned by concat_index_split. *)
lemma concat_index_split_sound:
  assumes "i < length (concat xs)"
  shows "concat xs ! i = (λ (k, j). xs ! k ! j) (concat_index_split (0, i) xs)"
  using concat_index_split_split_sound_aux[OF assms, of 0] by auto

(* Both components of the result (m, n) are within bounds: m indexes xs and
   n indexes xs ! m. *)
lemma concat_index_split_sound_bounds:
  assumes "i < length (concat xs)" and "concat_index_split (0, i) xs = (m, n)"
  shows "m < length xs" "n < length (xs ! m)"
  using concat_index_split_sound_fst_arg[OF assms(1)] concat_index_split_sound_snd_arg[OF assms(1)]
  by (auto simp: assms(2))

(* Summary of the previous facts for a result (m, n) with start counter 0:
   reconstruction of the flat index, both bounds, and agreement of the flat
   and the two-level lookup. *)
lemma concat_index_split_less_length_concat:
  assumes "i < length (concat xs)" and "concat_index_split (0, i) xs = (m, n)"
  shows "i = sum_list (map length (take m xs)) + n" "m < length xs" "n < length (xs ! m)"
    "concat xs ! i = xs ! m ! n"
  using concat_index_split_sound[OF assms(1)] reconstr_1d_concat_index_split[OF assms(1), of 0]
  using concat_index_split_sound_fst_arg[OF assms(1)] concat_index_split_sound_snd_arg[OF assms(1)] assms(2)
  by auto

(* Elimination form: every valid flat index i into concat xs can be written as
   an in-bounds pair (j, k) that addresses the same element and from which i
   is recovered as prefix length sum plus k. *)
lemma nth_concat_split':
  assumes "i < length (concat xs)"
  obtains j k where "j < length xs" "k < length (xs ! j)" "concat xs ! i = xs ! j ! k" "i = sum_list (map length (take j xs)) + k"
  using concat_index_split_less_length_concat[OF assms]
  by (meson old.prod.exhaust)

(* Uniqueness of the (outer, inner) decomposition: if two in-bounds index
   pairs denote the same flat index, the pairs are equal. Declared as an
   aggressive destruction rule; proved by reverse induction on xs. *)
lemma sum_list_split [dest!, consumes 1]:
  assumes "sum_list (map length (take i xs)) + j = sum_list (map length (take k xs)) + l"
   and "i < length xs" "k < length xs"
   and "j < length (xs ! i)" "l < length (xs ! k)"
 shows "i = k ∧ j = l" using assms
proof (induct xs rule: rev_induct)
  case (snoc x xs)
  then show ?case
    by (auto simp: nth_append split: if_splits)
       (metis concat_nth_length length_concat not_add_less1)+
qed auto

(* The result of concat_index_split depends only on the shape of the list of
   lists: two lists of lists with equal outer length and equal inner lengths
   give the same result for a valid index. Induction on xs with a case split
   on ys. *)
lemma concat_index_split_unique:
  assumes "i < length (concat xs)" and "length xs = length ys"
    and "∀ i < length xs. length (xs ! i) = length (ys ! i)"
  shows "concat_index_split (n, i) xs = concat_index_split (n, i) ys" using assms
proof (induct xs arbitrary: ys n i)
  case (Cons x xs) note IH = this show ?case
  proof (cases ys)
    case Nil then show ?thesis using Cons(3) by auto
  next
    case [simp]: (Cons y ys')
    (* heads have equal length, and the tails again have equal shape *)
    have [simp]: "length y = length x" using IH(4) by force
    have [simp]: "¬ i < length x ⟹ i - length x < length (concat xs)" using IH(2) by auto
    have [simp]: "i < length ys' ⟹ length (xs ! i) = length (ys' ! i)" for i using IH(3, 4)
      by (auto simp: All_less_Suc) (metis IH(4) Suc_less_eq length_Cons Cons nth_Cons_Suc)
    show ?thesis using IH(2-) IH(1)[of "i - length x" ys' "Suc n"] by auto
  qed
qed auto

(* The elements of vars_term_list t are exactly the variables of t
   (declared as a simp rule). *)
lemma set_vars_term_list [simp]:
  "set (vars_term_list t) = vars_term t"
  by (induct t) simp_all

(* The variable list of a term is empty iff the term is ground
   (declared as a simp rule). *)
lemma vars_term_list_empty_ground [simp]:
  "vars_term_list t = [] ⟷ ground t"
  by (induct t) auto

(* Every variable position of a term is a position of that term. *)
lemma varposs_imp_poss:
  assumes "p ∈ varposs t"
  shows "p ∈ poss t"
  using assms by (induct t arbitrary: p) auto

(* A position in the variable position list of a function application starts
   with a valid argument index. *)
lemma vaposs_list_fun:
  assumes "p ∈ set (varposs_list (Fun f ts))"
  obtains i ps where "i < length ts" "p = i # ps"
  using assms set_zip_leftD by fastforce

(* varposs_list never lists a position twice. Outer induction on the term;
   the Fun case uses an inner reverse induction on the argument list. *)
lemma varposs_list_distinct:
  "distinct (varposs_list t)"
proof (induct t)
  case (Fun f ts)
  then show ?case proof (induct ts rule: rev_induct)
    case (snoc x xs)
    then have "distinct (varposs_list (Fun f xs))" "distinct (varposs_list x)" by auto
    then show ?case using snoc by (auto simp add: map_cons_presv_distinct dest: set_zip_leftD)
  qed auto
qed auto

(* Adding one more argument t at the end of the argument list adds exactly the
   variable positions of t prefixed with the new argument index length ts. *)
lemma varposs_append:
  "varposs (Fun f (ts @ [t])) = varposs (Fun f ts) ∪ ((#) (length ts)) ` varposs t"
  by (auto simp: nth_append split: if_splits)

(* The list version and the set version of the variable positions agree.
   Same proof structure as varposs_list_distinct, using varposs_append in the
   snoc case. *)
lemma varposs_eq_varposs_list:
  "set (varposs_list t) = varposs t"
proof (induct t)
  case (Fun f ts)
  then show ?case proof (induct ts rule: rev_induct)
    case (snoc x xs)
    then have "varposs (Fun f xs) = set (varposs_list (Fun f xs))"
      "varposs x = set (varposs_list x)" by auto
    then show ?case using snoc unfolding varposs_append
      by auto
  qed auto
qed auto

(* The list of variable positions and the list of variables of a term have the
   same length. *)
lemma varposs_list_var_terms_length:
  "length (varposs_list t) = length (vars_term_list t)"
  by (induct t) (auto simp: vars_term_list.simps intro: eq_length_concat_nth)

(* Index decomposition for the variable list of a function application: if
   concat_index_split maps the flat index i to (k, j), then k is a valid
   argument index, j is valid in the variable list of the k-th argument, the
   flat and the two-level lookups agree, and i is recovered from (k, j). *)
lemma vars_term_list_nth:
  assumes "i < length (vars_term_list (Fun f ts))"
    and "concat_index_split (0, i) (map vars_term_list ts) = (k, j)"
  shows "k < length ts ∧ j < length (vars_term_list (ts ! k)) ∧
    vars_term_list (Fun f ts) ! i = map vars_term_list ts ! k ! j ∧
    i = sum_list (map length (map vars_term_list (take k ts))) + j"
  using assms concat_index_split_less_length_concat[of i "map vars_term_list ts" k j]
  by (auto simp: vars_term_list.simps comp_def take_map) 

(* Index decomposition for the variable position list of a function
   application, analogous to vars_term_list_nth: the i-th position is the
   j-th variable position of the k-th argument, prefixed with k. *)
lemma varposs_list_nth:
  assumes "i < length (varposs_list (Fun f ts))"
     and "concat_index_split (0, i) (poss_args varposs_list ts) = (k, j)"
  shows "k < length ts ∧ j < length (varposs_list (ts ! k)) ∧
    varposs_list (Fun f ts) ! i = k # (map varposs_list ts) ! k ! j ∧
    i = sum_list (map length (map varposs_list (take k ts))) + j"
  using assms concat_index_split_less_length_concat[of i "poss_args varposs_list ts" k j]
  by (auto simp: comp_def take_map intro: nth_sum_listI)

(* The two lists are aligned: the subterm at the i-th variable position is the
   i-th entry of the variable list. In the Fun case both lists are
   concatenations of the same shape, so concat_index_split yields the same
   (k, j) for both, and the induction hypothesis applies to argument k. *)
lemma varposs_list_to_var_term_list:
  assumes "i < length (varposs_list t)"
  shows "the_Var (t |_ (varposs_list t ! i)) = (vars_term_list t) ! i" using assms
proof (induct t arbitrary: i)
  case (Fun f ts)
  have "concat_index_split (0, i) (poss_args varposs_list ts) = concat_index_split (0, i) (map vars_term_list ts)"
    using Fun(2) concat_index_split_unique[of i "poss_args varposs_list ts" "map vars_term_list ts" 0]
    using varposs_list_var_terms_length[of "ts ! i" for i]
    by (auto simp: vars_term_list.simps)
  then obtain k j where "concat_index_split (0, i) (poss_args varposs_list ts) = (k, j)"
    "concat_index_split (0, i) (map vars_term_list ts) = (k, j)" by fastforce
  from varposs_list_nth[OF Fun(2) this(1)] vars_term_list_nth[OF _ this(2)]
  show ?case using Fun(2) Fun(1)[OF nth_mem] varposs_list_var_terms_length[of "Fun f ts"] by auto
qed (auto simp: vars_term_list.simps)

end

Theory Multihole_Context

(*
Author:  Bertram Felgenhauer <bertram.felgenhauer@uibk.ac.at> (2015)
Author:  Christian Sternagel <c.sternagel@gmail.com> (2013-2016)
Author:  Martin Avanzini <martin.avanzini@uibk.ac.at> (2014)
Author:  René Thiemann <rene.thiemann@uibk.ac.at> (2013-2015)
Author:  Julian Nagele <julian.nagele@uibk.ac.at> (2016)
License: LGPL (see file COPYING.LESSER)
*)

section ‹Preliminaries›
subsection ‹Multihole Contexts›

theory Multihole_Context
imports 
  Utils
begin

(* Activates the notation of the lattice_syntax bundle for the rest of the
   theory (the infix ⊓ is used for multihole contexts further down). *)
unbundle lattice_syntax

subsubsection ‹Partitioning lists into chunks of given length›

(* Two-level indexing into a concatenation: element n of inner list m sits at
   flat index (total length of the first m inner lists) + n.
   Induction on xs with a case split on the outer index. *)
lemma concat_nth:
  assumes "m < length xs" and "n < length (xs ! m)"
    and "i = sum_list (map length (take m xs)) + n"
  shows "concat xs ! i = xs ! m ! n"
using assms
proof (induct xs arbitrary: m n i)
  case (Cons x xs)
  show ?case
  proof (cases m)
    case 0
    then show ?thesis using Cons by (simp add: nth_append)
  next
    case (Suc k)
    with Cons(1) [of k n "i - length x"] and Cons(2-)
      show ?thesis by (simp_all add: nth_append)
  qed
qed simp

(* Splits the sum of the first i entries at an earlier index k: the sum of the
   first k entries, the entry at k, and the sum of the next i - Suc k entries
   after k. *)
lemma sum_list_take_eq:
  fixes xs :: "nat list"
  shows "k < i ⟹ i < length xs ⟹ sum_list (take i xs) =
    sum_list (take k xs) + xs ! k + sum_list (take (i - Suc k) (drop (Suc k) xs))"
  by (subst id_take_nth_drop [of k]) (auto simp: min_def drop_take)

(* Cuts xs into consecutive chunks whose lengths are given by the second
   argument: for each length y the next y elements are taken (fewer if xs runs
   out, since take is used) and the rest is processed with the remaining
   lengths. The result has one chunk per given length. *)
fun partition_by where
  "partition_by xs [] = []" |
  "partition_by xs (y#ys) = take y xs # partition_by (drop y xs) ys"

(* Leading chunk lengths of 0 produce leading empty chunks and leave xs
   untouched for the remaining lengths (declared as a simp rule). *)
lemma partition_by_map0_append [simp]:
  "partition_by xs (map (λx. 0) ys @ zs) = replicate (length ys) [] @ partition_by xs zs"
by (induct ys) simp_all

(* If the chunk lengths add up to the length of xs, concatenating the chunks
   gives back xs (declared as a simp rule). *)
lemma concat_partition_by [simp]:
  "sum_list ys = length xs ⟹ concat (partition_by xs ys) = xs"
by (induct ys arbitrary: xs) simp_all

(* Index bookkeeping for partition_by: partitions the index list [0..<l] by ys
   and returns entry j of chunk i, i.e. the original index that ends up at
   that place. *)
definition partition_by_idx where
  "partition_by_idx l ys i j = partition_by [0..<l] ys ! i ! j"

(* Entry j of chunk i equals the entry of xs at the flat index given by the
   total length of the earlier chunks plus j; instance of concat_nth combined
   with concat_partition_by. *)
lemma partition_by_nth_nth_old:
  assumes "i < length (partition_by xs ys)"
    and "j < length (partition_by xs ys ! i)"
    and "sum_list ys = length xs"
  shows "partition_by xs ys ! i ! j = xs ! (sum_list (map length (take i (partition_by xs ys))) + j)"
  using concat_nth [OF assms(1, 2) refl]
  unfolding concat_partition_by [OF assms(3)] by simp

(* Mapping a function over every chunk is the same as partitioning the mapped
   list. *)
lemma map_map_partition_by:
  "map (map f) (partition_by xs ys) = partition_by (map f xs) ys"
by (induct ys arbitrary: xs) (auto simp: take_map drop_map)

(* The number of chunks equals the number of given chunk lengths
   (declared as a simp rule). *)
lemma length_partition_by [simp]:
  "length (partition_by xs ys) = length ys"
  by (induct ys arbitrary: xs) simp_all

(* Partitioning the empty list yields only empty chunks
   (declared as a simp rule). *)
lemma partition_by_Nil [simp]:
  "partition_by [] ys = replicate (length ys) []"
  by (induct ys) simp_all

(* Partitioning a concatenation by the lengths of its parts recovers the
   parts (declared as a simp rule). *)
lemma partition_by_concat_id [simp]:
  assumes "length xss = length ys"
    and "⋀i. i < length ys ⟹ length (xss ! i) = ys ! i"
  shows "partition_by (concat xss) ys = xss"
  using assms by (induct ys arbitrary: xss) (simp, case_tac xss, simp, fastforce)

(* Closed form of chunk i: skip as many elements as the earlier chunk lengths
   sum up to, then take ys ! i elements. *)
lemma partition_by_nth:
  "i < length ys ⟹ partition_by xs ys ! i = take (ys ! i) (drop (sum_list (take i ys)) xs)"
  by (induct ys arbitrary: xs i) (simp, case_tac i, simp_all add: ac_simps)

(* If the prefix xs already covers the chunks before index i, then a chunk
   k < i of xs @ y # ys is determined by xs alone. *)
lemma partition_by_nth_less:
  assumes "k < i" and "i < length zs"
    and "length xs = sum_list (take i zs) + j"
  shows "partition_by (xs @ y # ys) zs ! k = take (zs ! k) (drop (sum_list (take k zs)) xs)"
proof -
  have "partition_by (xs @ y # ys) zs ! k =
    take (zs ! k) (drop (sum_list (take k zs)) (xs @ y # ys))"
    using assms by (auto simp: partition_by_nth)
  (* chunk k ends within xs, so take/drop never reach y # ys *)
  moreover have "zs ! k + sum_list (take k zs) ≤ length xs"
    using assms by (simp add: sum_list_take_eq)
  ultimately show ?thesis by simp
qed

(* If the element y sits inside chunk i, then a later chunk k > i of
   xs @ y # ys can be read off xs @ ys (y removed) by dropping one element
   less. *)
lemma partition_by_nth_greater:
  assumes "i < k" and "k < length zs" and "j < zs ! i"
    and "length xs = sum_list (take i zs) + j"
  shows "partition_by (xs @ y # ys) zs ! k =
    take (zs ! k) (drop (sum_list (take k zs) - 1) (xs @ ys))"
proof -
  have "partition_by (xs @ y # ys) zs ! k =
    take (zs ! k) (drop (sum_list (take k zs)) (xs @ y # ys))"
    using assms by (auto simp: partition_by_nth)
  (* chunk k starts strictly after the position of y *)
  moreover have "sum_list (take k zs) > length xs"
    using assms by (auto simp: sum_list_take_eq)
  ultimately show ?thesis by (auto) (metis Suc_diff_Suc drop_Suc_Cons)
qed

(* When the chunk lengths add up to the length of xs, chunk i has exactly the
   requested length ys ! i. *)
lemma length_partition_by_nth:
  "sum_list ys = length xs ⟹ i < length ys ⟹ length (partition_by xs ys ! i) = ys ! i"
by (induct ys arbitrary: xs i; case_tac i) auto

(* Every in-bounds entry of every chunk is an element of the original list;
   shown via membership in the concatenation of the chunks. *)
lemma partition_by_nth_nth_elem:
  assumes "sum_list ys = length xs" "i < length ys" "j < ys ! i"
  shows "partition_by xs ys ! i ! j ∈ set xs"
proof -
  from assms have "j < length (partition_by xs ys ! i)" by (simp only: length_partition_by_nth)
  then have "partition_by xs ys ! i ! j ∈ set (partition_by xs ys ! i)" by auto
  with assms(2) have "partition_by xs ys ! i ! j ∈ set (concat (partition_by xs ys))" by auto
  then show ?thesis using assms by simp
qed

(* Entry j of chunk i is the entry of xs at index partition_by_idx, and that
   index is within bounds. The first claim is obtained by viewing xs as
   map ((!) xs) [0..<length xs] and pushing the map through partition_by. *)
lemma partition_by_nth_nth:
  assumes "sum_list ys = length xs" "i < length ys" "j < ys ! i"
  shows "partition_by xs ys ! i ! j = xs ! partition_by_idx (length xs) ys i j"
        "partition_by_idx (length xs) ys i j < length xs"
unfolding partition_by_idx_def
proof -
  let ?n = "partition_by [0..<length xs] ys ! i ! j"
  show "?n < length xs"
    using partition_by_nth_nth_elem[OF _ assms(2,3), of "[0..<length xs]"] assms(1) by simp
  have li: "i < length (partition_by [0..<length xs] ys)" using assms(2) by simp
  have lj: "j < length (partition_by [0..<length xs] ys ! i)"
    using assms by (simp add: length_partition_by_nth)
  have "partition_by (map ((!) xs) [0..<length xs]) ys ! i ! j = xs ! ?n"
    by (simp only: map_map_partition_by[symmetric] nth_map[OF li] nth_map[OF lj])
  then show "partition_by xs ys ! i ! j = xs ! ?n" by (simp add: map_nth)
qed
  
(* When the chunk lengths add up to the length of xs, the lengths of the
   produced chunks are exactly ys (declared as a simp rule). *)
lemma map_length_partition_by [simp]:
  "sum_list ys = length xs ⟹ map length (partition_by xs ys) = ys"
  by (intro nth_equalityI, auto simp: length_partition_by_nth)

(* Mapping a function over a single chunk equals the corresponding chunk of
   the mapped list (declared as a simp rule). *)
lemma map_partition_by_nth [simp]:
  "i < length ys ⟹ map f (partition_by xs ys ! i) = partition_by (map f xs) ys ! i"
  by (induct ys arbitrary: i xs) (simp, case_tac i, simp_all add: take_map drop_map)

(* Summing f chunk-wise and then summing the chunk totals equals summing f
   over the whole list (declared as a simp rule). *)
lemma sum_list_partition_by [simp]:
  "sum_list ys = length xs ⟹
    sum_list (map (λx. sum_list (map f x)) (partition_by xs ys)) = sum_list (map f xs)"
  by (induct ys arbitrary: xs) (simp_all, metis append_take_drop_id sum_list_append map_append)

(* Non-recursive description of partition_by: chunk i is given by the
   take/drop expression of partition_by_nth, for every index of ys. *)
lemma partition_by_map_conv:
  "partition_by xs ys = map (λi. take (ys ! i) (drop (sum_list (take i ys)) xs)) [0 ..< length ys]"
  by (rule nth_equalityI) (simp_all add: partition_by_nth)

(* For a list of sets map f xs: the union over all chunks of the union of each
   chunk equals the union of the whole list. *)
lemma UN_set_partition_by_map:
  "sum_list ys = length xs ⟹ (⋃x∈set (partition_by (map f xs) ys). ⋃ (set x)) = ⋃(set (map f xs))"
  by (induct ys arbitrary: xs)
     (simp_all add: drop_map take_map, metis UN_Un append_take_drop_id set_append)

(* Taking the union of f over all entries chunk by chunk equals taking it over
   all entries of the original list. *)
lemma UN_set_partition_by:
  "sum_list ys = length xs ⟹ (⋃zs ∈ set (partition_by xs ys). ⋃x ∈ set zs. f x) = (⋃x ∈ set xs. f x)"
  by (induct ys arbitrary: xs) (simp_all, metis UN_Un append_take_drop_id set_append)

(* Quantifying over chunk indices and chunk entries is the same as quantifying
   over the union of the element sets of all chunks. *)
lemma Ball_atLeast0LessThan_partition_by_conv:
  "(∀i∈{0..<length ys}. ∀x∈set (partition_by xs ys ! i). P x) =
    (∀x ∈ ⋃(set (map set (partition_by xs ys))). P x)"
  by auto (metis atLeast0LessThan in_set_conv_nth length_partition_by lessThan_iff)

(* When the chunk lengths add up to the length of xs, a property holds for all
   entries of all chunks iff it holds for all entries of xs. In the step case
   xs is split into take y xs @ drop y xs on the right-hand side. *)
lemma Ball_set_partition_by:
  "sum_list ys = length xs ⟹
  (∀x ∈ set (partition_by xs ys). ∀y ∈ set x. P y) = (∀x ∈ set xs. P x)"
proof (induct ys arbitrary: xs)
  case (Cons y ys)
  then show ?case
    apply (subst (2) append_take_drop_id [of y xs, symmetric])
    apply (simp only: set_append)
    apply auto
  done
qed simp

(* Partitioning by an appended length list splits into partitioning the first
   sum_list ys elements by ys and the remaining elements by zs. *)
lemma partition_by_append2:
  "partition_by xs (ys @ zs) = partition_by (take (sum_list ys) xs) ys @ partition_by (drop (sum_list ys) xs) zs"
by (induct ys arbitrary: xs) (auto simp: drop_take ac_simps split: split_min)

(* Partitioning by a concatenated list of length lists can be done in two
   stages: first by the group totals (map sum_list ys), then each group by its
   own length list; the results are concatenated. The index-based map is first
   rewritten into a map over a zip, then the claim follows by induction. *)
lemma partition_by_concat2:
  "partition_by xs (concat ys) =
   concat (map (λi . partition_by (partition_by xs (map sum_list ys) ! i) (ys ! i)) [0..<length ys])"
proof -
  have *: "map (λi . partition_by (partition_by xs (map sum_list ys) ! i) (ys ! i)) [0..<length ys] =
    map (λ(x,y). partition_by x y) (zip (partition_by xs (map sum_list ys)) ys)"
    using zip_nth_conv[of "partition_by xs (map sum_list ys)" ys] by auto
  show ?thesis unfolding * by (induct ys arbitrary: xs) (auto simp: partition_by_append2)
qed

(* Regrouping the fine partition (by concat ys) according to the group sizes
   map length ys gives the same nested structure as the two-stage partition of
   partition_by_concat2. *)
lemma partition_by_partition_by:
  "length xs = sum_list (map sum_list ys) ⟹
   partition_by (partition_by xs (concat ys)) (map length ys) =
   map (λi. partition_by (partition_by xs (map sum_list ys) ! i) (ys ! i)) [0..<length ys]"
by (auto simp: partition_by_concat2 intro: partition_by_concat_id)

subsubsection ‹Multihole contexts definition and functionalities›
(* Multihole contexts over function symbols 'f and variables 'v: a variable,
   a hole, or a function symbol applied to a list of multihole contexts.
   vars_mctxt is the name chosen for the set function of the type
   parameter 'v. *)
datatype ('f, vars_mctxt : 'v) mctxt = MVar 'v | MHole | MFun 'f "('f, 'v) mctxt list"

subsubsection ‹Conversions from and to multihole contexts›

(* Embeds a term as a multihole context (no MHole is produced): Var becomes
   MVar and Fun becomes MFun, recursively. *)
primrec mctxt_of_term :: "('f, 'v) term ⇒ ('f, 'v) mctxt" where
  "mctxt_of_term (Var x) = MVar x" |
  "mctxt_of_term (Fun f ts) = MFun f (map mctxt_of_term ts)"

(* Converts a multihole context back into a term: MVar becomes Var and MFun
   becomes Fun, recursively.
   NOTE: there is no equation for MHole, so the result is unspecified for
   contexts that contain a hole. *)
primrec term_of_mctxt :: "('f, 'v) mctxt ⇒ ('f, 'v) term" where
  "term_of_mctxt (MVar x) = Var x" |
  "term_of_mctxt (MFun f Cs) = Fun f (map term_of_mctxt Cs)"

(* Number of holes in a multihole context: 0 for a variable, 1 for a hole,
   and the sum over all arguments for a function application. *)
fun num_holes :: "('f, 'v) mctxt ⇒ nat" where
  "num_holes (MVar _) = 0" |
  "num_holes MHole = 1" |
  "num_holes (MFun _ ctxts) = sum_list (map num_holes ctxts)"

(* Groundness of a multihole context: it contains no MVar. A hole counts as
   ground; a function application is ground if all its arguments are. *)
fun ground_mctxt :: "('f, 'v) mctxt ⇒ bool" where 
  "ground_mctxt (MVar _) = False" |
  "ground_mctxt MHole = True" |
  "ground_mctxt (MFun f Cs) = Ball (set Cs) ground_mctxt"

(* Renames the function symbols of a multihole context with fg; variables and
   holes are left unchanged. *)
fun map_mctxt :: "('f ⇒ 'g) ⇒ ('f, 'v) mctxt ⇒ ('g, 'v) mctxt"
where
  "map_mctxt _ (MVar x) = (MVar x)" |
  "map_mctxt _ (MHole) = MHole" |
  "map_mctxt fg (MFun f Cs) = MFun (fg f) (map (map_mctxt fg) Cs)"

(* Cuts the list xs into one chunk per context in Cs, the chunk for a context
   having as many entries as that context has holes. *)
abbreviation "partition_holes xs Cs ≡ partition_by xs (map num_holes Cs)"
(* partition_by_idx specialised to chunk lengths given by the hole counts of
   the contexts in Cs. *)
abbreviation "partition_holes_idx l Cs ≡ partition_by_idx l (map num_holes Cs)"

(* Fills the holes of a multihole context with the terms of the given list.
   - A variable ignores the list.
   - A hole is replaced by the single term of a one-element list
     (NOTE: no equation is given for MHole with a list of another length).
   - For MFun the list is cut with partition_holes and each argument is filled
     with its own chunk. *)
fun fill_holes :: "('f, 'v) mctxt ⇒ ('f, 'v) term list ⇒ ('f, 'v) term" where
  "fill_holes (MVar x) _ = Var x" |
  "fill_holes MHole [t] = t" |
  "fill_holes (MFun f cs) ts = Fun f (map (λ i. fill_holes (cs ! i)
    (partition_holes ts cs ! i)) [0 ..< length cs])"

(* Like fill_holes, but the holes are filled with multihole contexts and the
   result is again a multihole context. A hole with the empty list stays a
   hole; a hole with a one-element list is replaced by that element.
   NOTE: no equation is given for MHole with a list of two or more elements. *)
fun fill_holes_mctxt :: "('f, 'v) mctxt ⇒ ('f, 'v) mctxt list ⇒ ('f, 'v) mctxt" where
  "fill_holes_mctxt (MVar x) _ = MVar x" |
  "fill_holes_mctxt MHole [] = MHole" |
  "fill_holes_mctxt MHole [t] = t" |
  "fill_holes_mctxt (MFun f cs) ts = (MFun f (map (λ i. fill_holes_mctxt (cs ! i) 
    (partition_holes ts cs ! i)) [0 ..< length cs]))"


(* Extracts the list of subterms of a term that sit at the holes of a
   multihole context.
   - A hole yields the whole term.
   - MVar against the same variable yields the empty list.
   - MFun against Fun with the same symbol and the same number of arguments
     concatenates the results for the arguments.
   Mismatching variables / symbols / arities give undefined, and no equation
   exists for the remaining constructor combinations (e.g. MVar against
   Fun). *)
fun unfill_holes :: "('f, 'v) mctxt ⇒ ('f, 'v) term ⇒ ('f, 'v) term list" where
  "unfill_holes MHole t = [t]"
| "unfill_holes (MVar w) (Var v) = (if v = w then [] else undefined)"
| "unfill_holes (MFun g Cs) (Fun f ts) = (if f = g ∧ length ts = length Cs then
    concat (map (λi. unfill_holes (Cs ! i) (ts ! i)) [0..<length ts]) else undefined)"

(* Function symbols (paired with their number of arguments) occurring in a
   multihole context; variables and holes contribute nothing. *)
fun funas_mctxt where
  "funas_mctxt (MFun f Cs) = {(f, length Cs)} ∪ ⋃(funas_mctxt ` set Cs)" |
  "funas_mctxt _ = {}"

(* Splits a term into a multihole context and a list of variables: every
   variable occurrence is replaced by a hole and the replaced variables are
   collected in a list (argument order is kept by concat). *)
fun split_vars :: "('f, 'v) term ⇒ (('f, 'v) mctxt × 'v list)" where
  "split_vars (Var x) = (MHole, [x])" |
  "split_vars (Fun f ts) = (MFun f (map (fst ∘ split_vars) ts), concat (map (snd ∘ split_vars) ts))"


(* List of the positions of the holes of a multihole context: none for a
   variable, the root position for a hole, and for MFun the argument-wise
   lists prefixed with the argument index (via poss_args). *)
fun hole_poss_list :: "('f, 'v) mctxt ⇒ pos list" where
  "hole_poss_list (MVar x) = []" |
  "hole_poss_list MHole = [[]]" |
  "hole_poss_list (MFun f cs) = concat (poss_args hole_poss_list cs)"

(* Renames the variables of a multihole context with vw; holes and function
   symbols are left unchanged. *)
fun map_vars_mctxt :: "('v ⇒ 'w) ⇒ ('f, 'v) mctxt ⇒ ('f, 'w) mctxt"
where
  "map_vars_mctxt vw MHole = MHole" |
  "map_vars_mctxt vw (MVar v) = (MVar (vw v))" |
  "map_vars_mctxt vw (MFun f Cs) = MFun f (map (map_vars_mctxt vw) Cs)"

(* Relates a term t to a pair (D, ss) of a multihole context and a term list:
   it holds when t is D filled with ss and ss has exactly as many entries as D
   has holes. *)
inductive eq_fill :: "('f, 'v) term ⇒ ('f, 'v) mctxt × ('f, 'v) term list ⇒ bool" ("(_/ =f _)" [51, 51] 50)
where
  eqfI [intro]: "t = fill_holes D ss ⟹ num_holes D = length ss ⟹ t =f (D, ss)"

subsubsection ‹Semilattice Structures›

(* Infimum on multihole contexts (instance of the syntactic class inf).
   A hole on either side gives a hole; equal variables are kept; function
   applications with the same symbol and the same number of arguments are
   combined argument-wise; every other combination gives a hole. *)
instantiation mctxt :: (type, type) inf

begin

fun inf_mctxt :: "('a, 'b) mctxt ⇒ ('a, 'b) mctxt ⇒ ('a, 'b) mctxt"
where
  "MHole ⊓ D = MHole" |
  "C ⊓ MHole = MHole" |
  "MVar x ⊓ MVar y = (if x = y then MVar x else MHole)" |
  "MFun f Cs ⊓ MFun g Ds =
    (if f = g ∧ length Cs = length Ds then MFun f (map (case_prod (⊓)) (zip Cs Ds))
    else MHole)" |
  "C ⊓ D = MHole"

instance ..

end

(* Idempotence of ⊓ on multihole contexts; registered as a simp rule.
   Proved by structural induction on C. *)
lemma inf_mctxt_idem [simp]:
  fixes C :: "('f, 'v) mctxt"
  shows "C ⊓ C = C"
  by (induct C) (auto simp: zip_same_conv_map intro: map_idI)

(* MHole is absorbing on the right of ⊓, for an arbitrary left argument C;
   registered as a simp rule. *)
lemma inf_mctxt_MHole2 [simp]:
  "C ⊓ MHole = MHole"
  by (induct C) simp_all

(* Commutativity of ⊓ on multihole contexts; added to ac_simps.
   The induction follows the recursion structure of inf_mctxt. *)
lemma inf_mctxt_comm [ac_simps]:
  "(C :: ('f, 'v) mctxt) ⊓ D = D ⊓ C"
  by (induct C D rule: inf_mctxt.induct) (fastforce simp: in_set_conv_nth intro!: nth_equalityI)+

(* Associativity of ⊓ on multihole contexts; added to ac_simps.
   The induction follows the recursion structure of inf_mctxt on the first two
   arguments, with the third argument E kept arbitrary. Remaining goals are
   closed by case analysis on E; the MFun/MFun case is settled argument-wise
   via nth_equalityI. *)
lemma inf_mctxt_assoc [ac_simps]:
  fixes C :: "('f, 'v) mctxt"
  shows "C ⊓ D ⊓ E = C ⊓ (D ⊓ E)"
  apply (induct C D arbitrary: E rule: inf_mctxt.induct)
  apply auto
  apply (case_tac E, auto)+
  apply (fastforce simp: in_set_conv_nth intro!: nth_equalityI)
  apply (case_tac E, auto)+
done

(* Multihole contexts form a partial order, with the order derived from ⊓:
   C is below D iff the meet of C and D is C. The strict order is defined from
   the non-strict one in the usual way. *)
instantiation mctxt :: (type, type) order
begin

definition "(C :: ('a, 'b) mctxt) ≤ D ⟷ C ⊓ D = C"
definition "(C :: ('a, 'b) mctxt) < D ⟷ C ≤ D ∧ ¬ D ≤ C"

(* The order axioms are discharged by unfolding both definitions and using the
   ac_simps rules; metis with inf_mctxt_assoc closes what simp leaves open. *)
instance
  by (standard, simp_all add: less_eq_mctxt_def less_mctxt_def ac_simps, metis inf_mctxt_assoc)

end

(* Inductive, structural relation on multihole contexts:
   - a hole is related to every context;
   - a variable is related to itself;
   - MFun f cs is related to MFun f ds when both have the same number of
     arguments and the arguments are related pointwise. *)
inductive less_eq_mctxt' :: "('f, 'v) mctxt ⇒ ('f,'v) mctxt ⇒ bool" where
  "less_eq_mctxt' MHole u"
| "less_eq_mctxt' (MVar v) (MVar v)"
| "length cs = length ds ⟹ (⋀i. i < length cs ⟹ less_eq_mctxt' (cs ! i) (ds ! i)) ⟹ less_eq_mctxt' (MFun f cs) (MFun f ds)"


subsubsection ‹Lemmata›

(* The MFun case of fill_holes, restated with a list comprehension over the
   argument indices instead of an explicit map. *)
lemma partition_holes_fill_holes_conv:
  "fill_holes (MFun f cs) ts =
    Fun f [fill_holes (cs ! i) (partition_holes ts cs ! i). i ← [0 ..< length cs]]"
  by (simp add: partition_by_nth take_map)

(* The MFun case of fill_holes_mctxt, restated with a list comprehension over
   the argument indices instead of an explicit map. *)
lemma partition_holes_fill_holes_mctxt_conv:
  "fill_holes_mctxt (MFun f Cs) ts =
    MFun f [fill_holes_mctxt (Cs ! i) (partition_holes ts Cs ! i). i ← [0 ..< length Cs]]"
  by (simp add: partition_by_nth take_map)

text ‹The following induction scheme provides the @{term MFun} case with the list argument split
  according to the argument contexts. This feature is quite delicate: its benefit can be
  destroyed by premature simplification using the @{thm concat_partition_by} simplification rule.›

(* Induction scheme over a multihole context C together with two lists xs and
   ys, each having as many elements as C has holes.
   The rule consumes the two length facts and has the cases MHole, MVar, MFun.
   In the MFun case both lists are presented as concat of their partition, and
   the induction hypothesis is available for every argument context together
   with its chunks. *)
lemma fill_holes_induct2[consumes 2, case_names MHole MVar MFun]:
  fixes P :: "('f,'v) mctxt ⇒ 'a list ⇒ 'b list ⇒ bool"
  assumes len1: "num_holes C = length xs" and len2: "num_holes C = length ys"
  and Hole: "⋀x y. P MHole [x] [y]"
  and Var: "⋀v. P (MVar v) [] []"
  and Fun: "⋀f Cs xs ys.  sum_list (map num_holes Cs) = length xs ⟹
    sum_list (map num_holes Cs) = length ys ⟹
    (⋀i. i < length Cs ⟹ P (Cs ! i) (partition_holes xs Cs ! i) (partition_holes ys Cs ! i)) ⟹
    P (MFun f Cs) (concat (partition_holes xs Cs)) (concat (partition_holes ys Cs))"
  shows "P C xs ys"
(* Structural induction on C with both lists generalised. *)
proof (insert len1 len2, induct C arbitrary: xs ys)
  (* Both lists have length one here, so case analysis exposes the single elements. *)
  case MHole then show ?case using Hole by (cases xs; cases ys) auto
next
  case (MVar v) then show ?case using Var by auto
next
  case (MFun f Cs) then show ?case using Fun[of Cs xs ys f] by (auto simp: length_partition_by_nth)
qed

(* One-list variant of fill_holes_induct2: induction over a multihole context C
   together with a list xs that has as many elements as C has holes.
   Obtained from the two-list scheme by passing xs twice and using a predicate
   that ignores its third argument. *)
lemma fill_holes_induct[consumes 1, case_names MHole MVar MFun]:
  fixes P :: "('f,'v) mctxt ⇒ 'a list ⇒ bool"
  assumes len: "num_holes C = length xs"
  and Hole: "⋀x. P MHole [x]"
  and Var: "⋀v. P (MVar v) []"
  and Fun: "⋀f Cs xs. sum_list (map num_holes Cs) = length xs ⟹
    (⋀i. i < length Cs ⟹ P (Cs ! i) (partition_holes xs Cs ! i)) ⟹
    P (MFun f Cs) (concat (partition_holes xs Cs))"
  shows "P C xs"
  using fill_holes_induct2[of C xs xs "λ C xs _. P C xs"] assms by simp

(* When ts has exactly as many elements as the contexts cs have holes in
   total, the i-th chunk of the partition has as many elements as the i-th
   context has holes. Registered as a simp rule. *)
lemma length_partition_holes_nth [simp]:
  assumes "sum_list (map num_holes cs) = length ts"
    and "i < length cs"
  shows "length (partition_holes ts cs ! i) = num_holes (cs ! i)"
  using assms by (simp add: length_partition_by_nth)

(* Some compatibility lemmas (which should be dropped eventually).
   The generic partition_by facts map_partition_by_nth and length_partition_by
   are instantiated with the length list map num_holes Cs, have length_map
   unfolded, and are registered as simp rules under partition_holes names. *)
lemmas
  map_partition_holes_nth [simp] =
    map_partition_by_nth [of _ "map num_holes Cs" for Cs, unfolded length_map] and
  length_partition_holes [simp] =
    length_partition_by [of _ "map num_holes Cs" for Cs, unfolded length_map]

(* Filling a context without holes with the empty list gives the same term as
   converting the context with term_of_mctxt. *)
lemma fill_holes_term_of_mctxt:
  "num_holes C = 0 ⟹ fill_holes C [] = term_of_mctxt C"
  by (induct C) (auto simp add: map_eq_nth_conv)

(* Filling a single hole with a one-element list returns that element; the
   list is described by its length and its element at index 0. *)
lemma fill_holes_MHole:
  "length ts = Suc 0 ⟹ ts ! 0 = u ⟹ fill_holes MHole ts = u"
  by (cases ts) simp_all

(* Generic argument-wise filling lemma, stated for an arbitrary function f.
   If ss supplies, for every index i, a list of length num_holes (Cs ! i) such
   that f (Cs ! i) (ss ! i) is ts ! i, then applying f to each context with the
   matching chunk of the partition of concat ss yields exactly ts. *)
lemma fill_holes_arbitrary:
  assumes lCs: "length Cs = length ts"
    and lss: "length ss = length ts"
    and rec: "⋀ i. i < length ts ⟹ num_holes (Cs ! i) = length (ss ! i) ∧ f (Cs ! i) (ss ! i) = ts ! i"
  shows "map (λi. f (Cs ! i) (partition_holes (concat ss) Cs ! i)) [0 ..< length Cs] = ts"
proof -
  (* The total hole count matches the length of the concatenation. *)
  have "sum_list (map num_holes Cs) = length (concat ss)" using assms
    by (auto simp: length_concat map_nth_eq_conv intro: arg_cong[of _ _ "sum_list"])
  (* Partitioning the concatenation recovers the original list of lists. *)
  moreover have "partition_holes (concat ss) Cs = ss"
    using assms by (auto intro: partition_by_concat_id)
  ultimately show ?thesis using assms by (auto intro: nth_equalityI)
qed

(* Filling MFun f Cs with concat ss gives Fun f ts, provided each ss ! i has
   the right length for Cs ! i and filling Cs ! i with it gives ts ! i.
   This is fill_holes_arbitrary instantiated with fill_holes. *)
lemma fill_holes_MFun:
  assumes lCs: "length Cs = length ts"
    and lss: "length ss = length ts"
    and rec: "⋀ i. i < length ts ⟹ num_holes (Cs ! i) = length (ss ! i) ∧ fill_holes (Cs ! i) (ss ! i) = ts ! i"
  shows "fill_holes (MFun f Cs) (concat ss) = Fun f ts" 
  unfolding fill_holes.simps term.simps
    by (rule conjI[OF refl], rule fill_holes_arbitrary[OF lCs lss rec])

(* Destruction of t =f (D, ss) into its two components: the equation for t and
   the length condition on ss. *)
lemma eqfE:
  assumes "t =f (D, ss)" shows "t = fill_holes D ss" "num_holes D = length ss"
  using assms[unfolded eq_fill.simps] by auto

(* Elimination rule for s =f (MFun f Cs, ss): s is of the form Fun f ts, and ss
   splits into a list of lists sss, one per argument context, such that each
   argument ts ! i is related by =f to (Cs ! i, sss ! i). *)
lemma eqf_MFunE:
  assumes "s =f (MFun f Cs,ss)"  
  obtains ts sss where "s = Fun f ts" "length ts = length Cs" "length sss = length Cs" 
  "⋀ i. i < length Cs ⟹ ts ! i =f (Cs ! i, sss ! i)"
  "ss = concat sss"
proof -
  from eqfE[OF assms] have fh: "s = fill_holes (MFun f Cs) ss" 
    and nh: "sum_list (map num_holes Cs) = length ss" by auto
  from fh obtain ts where s: "s = Fun f ts" by (cases s, auto)
  (* The arguments ts are given by the MFun equation of fill_holes. *)
  from fh[unfolded s] 
  have ts: "ts = map (λi. fill_holes (Cs ! i) (partition_holes ss Cs ! i)) [0..<length Cs]" 
    (is "_ = map (?f Cs ss) _")
    by auto
  (* The witness for sss is the partition of ss along the hole counts of Cs. *)
  let ?sss = "partition_holes ss Cs"
  from nh 
  have *: "length ?sss = length Cs" "⋀i. i < length Cs ⟹ ts ! i =f (Cs ! i, ?sss ! i)" "ss = concat ?sss"
    by (auto simp: ts)
  have len: "length ts = length Cs" unfolding ts by auto
  (* The premise of the obtains statement, assumed explicitly. *)
  assume ass: "⋀ts sss. s = Fun f ts ⟹
              length ts = length Cs ⟹
              length sss = length Cs ⟹ (⋀i. i < length Cs ⟹ ts ! i =f (Cs ! i, sss ! i)) ⟹ ss = concat sss ⟹ thesis"
  show thesis
    by (rule ass[OF s len *])
qed

(* Introduction rule for =f at an MFun node: if every argument ts ! i is related
   by =f to (Cs ! i, sss ! i), then Fun f ts is related to MFun f Cs together
   with the concatenation of sss. *)
lemma eqf_MFunI:
  assumes "length sss = length Cs"
    and "length ts = length Cs"
    and"⋀ i. i < length Cs ⟹ ts ! i =f (Cs ! i, sss ! i)"
  shows "Fun f ts =f (MFun f Cs, concat sss)"
proof 
  (* Length condition: the hole counts of Cs agree pointwise with the lengths
     of the lists in sss. *)
  have "num_holes (MFun f Cs) = sum_list (map num_holes Cs)" by simp
  also have "map num_holes Cs = map length sss"
    by (rule nth_equalityI, insert assms eqfE[OF assms(3)], auto)
  also have "sum_list (…) = length (concat sss)" unfolding length_concat ..
  finally show "num_holes (MFun f Cs) = length (concat sss)" .
  (* Equation for the term: fill_holes_MFun read from right to left. *)
  show "Fun f ts = fill_holes (MFun f Cs) (concat sss)"
    by (rule fill_holes_MFun[symmetric], insert assms(1,2) eqfE[OF assms(3)], auto)
qed

(* For a ground multihole context C and a list of variables xs with one entry
   per hole, split_vars recovers (C, xs) from the term obtained by filling the
   holes of C with the variables xs. *)
lemma split_vars_ground_vars:
  assumes "ground_mctxt C" and "num_holes C = length xs" 
  shows "split_vars (fill_holes C (map Var xs)) = (C, xs)" using assms
proof (induct C arbitrary: xs)
  case (MHole xs)
  then show ?case by (cases xs, auto)
next
  case (MFun f Cs xs)
  (* Decompose the filled term argument-wise via eqf_MFunE. *)
  have "fill_holes (MFun f Cs) (map Var xs) =f (MFun f Cs, map Var xs)"
    by (rule eqfI, insert MFun(3), auto)
  from eqf_MFunE[OF this] 
  obtain ts xss where fh: "fill_holes (MFun f Cs) (map Var xs) = Fun f ts"
    and lent: "length ts = length Cs"
    and lenx: "length xss = length Cs"
    and args: "⋀i. i < length Cs ⟹ ts ! i =f (Cs ! i, xss ! i)"
    and id: "map Var xs = concat xss" by auto
  (* Every element of xss is a variable term, so the_Var recovers xs. *)
  from arg_cong[OF id, of "map the_Var"] have id2: "xs = concat (map (map the_Var) xss)" 
    by (metis map_concat length_map map_nth_eq_conv term.sel(1))    
  {
    fix i
    assume i: "i < length Cs"
    then have mem: "Cs ! i ∈ set Cs" by auto
    with MFun(2) have ground: "ground_mctxt (Cs ! i)" by auto
    have "map Var (map the_Var (xss ! i)) = map id (xss ! i)" unfolding map_map o_def map_eq_conv
    proof
      fix x
      assume "x ∈ set (xss ! i)"
      with lenx i have "x ∈ set (concat xss)" by auto
      from this[unfolded id[symmetric]] show "Var (the_Var x) = id x" by auto
    qed
    then have idxss: "map Var (map the_Var (xss ! i)) = xss ! i" by auto
    note rec = eqfE[OF args[OF i]]
    (* Induction hypothesis for the i-th argument context. *)
    note IH = MFun(1)[OF mem ground, of "map the_Var (xss ! i)", unfolded rec(2) idxss rec(1)[symmetric]]
    from IH have "split_vars (ts ! i) = (Cs ! i, map the_Var (xss ! i))" by auto
    note this idxss
  }
  note IH = this
  (* Reduce the goal to a statement about the two components of split_vars. *)
  have "?case = (map fst (map split_vars ts) = Cs ∧ concat (map snd (map split_vars ts)) = concat (map (map the_Var) xss))"
    unfolding fh unfolding id2 by auto
  also have "…"
  proof (rule conjI[OF nth_equalityI arg_cong[of _ _ concat, OF nth_equalityI, rule_format]], unfold length_map lent lenx)
    fix i
    assume i: "i < length Cs" 
    with arg_cong[OF IH(2)[OF this], of "map the_Var"]
      IH[OF this] show "map snd (map split_vars ts) ! i = map (map the_Var) xss ! i" using lent lenx by auto
  qed (insert IH lent, auto)
  finally show ?case .
qed auto


(* The variable list returned by split_vars is the same list as vars_term_list. *)
lemma split_vars_vars_term_list: "snd (split_vars t) = vars_term_list t"
proof (induct t)
  case (Fun f ts)
  (* After simplification a statement about the argument list remains, which is
     proved by a nested induction on ts. *)
  then show ?case by (auto simp: vars_term_list.simps o_def, induct ts, auto)
qed (auto simp: vars_term_list.simps)


(* The context returned by split_vars has exactly as many holes as the returned
   variable list has elements. *)
lemma split_vars_num_holes: "num_holes (fst (split_vars t)) = length (snd (split_vars t))"
proof (induct t)
  case (Fun f ts)
  then show ?case by (induct ts, auto)
qed simp

(* If t is C filled with ss, then t is ground exactly when C is a ground
   context and every term in ss is ground. *)
lemma ground_eq_fill: "t =f (C,ss) ⟹ ground t = (ground_mctxt C ∧ (∀ s ∈ set ss. ground s))" 
proof (induct C arbitrary: t ss)
  case (MVar x)
  from eqfE[OF this] show ?case by simp
next
  case (MHole t ss)
  from eqfE[OF this] show ?case by (cases ss, auto)
next
  case (MFun f Cs s ss)
  (* Split s and ss argument-wise. *)
  from eqf_MFunE[OF MFun(2)] obtain ts sss where s: "s = Fun f ts" and len: "length ts = length Cs" "length sss = length Cs" 
    and IH: "⋀ i. i < length Cs ⟹ ts ! i =f (Cs ! i, sss ! i)" and ss: "ss = concat sss" by metis
  {
    fix i
    assume i: "i < length Cs"
    then have "Cs ! i ∈ set Cs" by simp
    from MFun(1)[OF this IH[OF i]]
    have "ground (ts ! i) = (ground_mctxt (Cs ! i) ∧ (∀a∈set (sss ! i). ground a))" .
  } note IH = this
  (* Set membership is rewritten to indexing so that the per-index hypothesis applies. *)
  note conv = set_conv_nth
  have "?case = ((∀x∈set ts. ground x) = ((∀x∈set Cs. ground_mctxt x) ∧ (∀a∈set sss. ∀x∈set a. ground x)))"
    unfolding s ss by simp
  also have "..." unfolding conv[of ts] conv[of Cs] conv[of sss] len using IH by auto
  finally show ?case by simp
qed

(* Groundness of a filled context, stated directly for fill_holes; an instance
   of ground_eq_fill. *)
lemma ground_fill_holes:
  assumes nh: "num_holes C = length ss"
  shows "ground (fill_holes C ss) = (ground_mctxt C ∧ (∀ s ∈ set ss. ground s))"
  by (rule ground_eq_fill[OF eqfI[OF refl nh]])

(* The context returned by split_vars is always ground; simp rule. *)
lemma split_vars_ground' [simp]:
  "ground_mctxt (fst (split_vars t))"
  by (induct t) auto

(* The context returned by split_vars has the same function symbols (with
   argument counts) as the original term; simp rule. *)
lemma split_vars_funas_mctxt [simp]:
  "funas_mctxt (fst (split_vars t)) = funas_term t"
  by (induct t) auto


(* The order on multihole contexts (defined via ⊓) coincides with the inductive
   relation less_eq_mctxt'. *)
lemma less_eq_mctxt_prime: "C ≤ D ⟷ less_eq_mctxt' C D"
proof
  (* From the inductive relation to the order: rule induction. *)
  assume "less_eq_mctxt' C D" then show "C ≤ D"
    by (induct C D rule: less_eq_mctxt'.induct) (auto simp: less_eq_mctxt_def intro: nth_equalityI)
next
  (* From the order to the inductive relation: induction along inf_mctxt. *)
  assume "C ≤ D" then show "less_eq_mctxt' C D" unfolding less_eq_mctxt_def
  by (induct C D rule: inf_mctxt.induct)
     (auto split: if_splits simp: set_zip intro!: less_eq_mctxt'.intros nth_equalityI elim!: nth_equalityE, metis)
qed

(* Induction and introduction rules of less_eq_mctxt', restated for the order
   by folding with less_eq_mctxt_prime. The induction rule consumes one fact. *)
lemmas less_eq_mctxt_induct = less_eq_mctxt'.induct[folded less_eq_mctxt_prime, consumes 1]
lemmas less_eq_mctxt_intros = less_eq_mctxt'.intros[folded less_eq_mctxt_prime]

(* Elimination for a hole on the right: only a hole is below a hole. *)
lemma less_eq_mctxt_MHoleE2:
  assumes "C ≤ MHole"
  obtains (MHole) "C = MHole"
  using assms unfolding less_eq_mctxt_prime by (cases C, auto)

(* Elimination for a variable on the right: a context below MVar v is either a
   hole or MVar v itself. *)
lemma less_eq_mctxt_MVarE2:
  assumes "C ≤ MVar v"
  obtains (MHole) "C = MHole" | (MVar) "C = MVar v"
  using assms unfolding less_eq_mctxt_prime by (cases C) auto

(* Elimination for MFun on the right: a context below MFun f ds is either a
   hole, or MFun f cs with as many arguments as ds and pointwise smaller
   arguments. *)
lemma less_eq_mctxt_MFunE2:
  assumes "C ≤ MFun f ds"
  obtains (MHole) "C = MHole"
    | (MFun) cs where "C = MFun f cs" "length cs = length ds" "⋀i. i < length cs ⟹ cs ! i ≤ ds ! i"
  using assms unfolding less_eq_mctxt_prime by (cases C) auto

(* The three elimination rules for a known right-hand side, bundled under one name. *)
lemmas less_eq_mctxtE2 = less_eq_mctxt_MHoleE2 less_eq_mctxt_MVarE2 less_eq_mctxt_MFunE2


(* Elimination for a variable on the left: only MVar v itself is above MVar v. *)
lemma less_eq_mctxt_MVarE1:
  assumes "MVar v ≤ D"
  obtains (MVar) "D = MVar v"
  using assms by (cases D) (auto elim: less_eq_mctxtE2)

(* The hole is the least multihole context; simp rule. Follows from the first
   introduction rule of the order. *)
lemma MHole_Bot [simp]: "MHole ≤ D"
  by (simp add: less_eq_mctxt_intros(1))

(* Elimination for MFun on the left: a context above MFun f cs is MFun f ds
   with as many arguments as cs and pointwise larger arguments. *)
lemma less_eq_mctxt_MFunE1:
  assumes "MFun f cs ≤ D"
  obtains (MFun) ds where "D = MFun f ds" "length cs = length ds" "⋀i. i < length cs ⟹ cs ! i ≤ ds ! i"
  using assms by (cases D) (auto elim: less_eq_mctxtE2)


(* When C is below the context of the term t, unfill_holes returns one term per
   hole of C; simp rule. The induction follows the recursion of unfill_holes. *)
lemma length_unfill_holes [simp]:
  assumes "C ≤ mctxt_of_term t"
  shows "length (unfill_holes C t) = num_holes C"
  using assms
proof (induct C t rule: unfill_holes.induct)
  (* Case 3 is the MFun/Fun equation of unfill_holes. *)
  case (3 f Cs g ts) with 3(1)[OF _ nth_mem] 3(2) show ?case
    by (auto simp: less_eq_mctxt_def length_concat
      intro!: cong[of sum_list, OF refl] nth_equalityI elim!: nth_equalityE)
qed (auto simp: less_eq_mctxt_def)

(* Renaming variables with the identity function leaves a context unchanged;
   simp rule. *)
lemma map_vars_mctxt_id [simp]:
  "map_vars_mctxt (λ x. x) C = C"
  by (induct C, auto intro: nth_equalityI)


(* Applying a substitution σ to t gives the context of split_vars t (with its
   variables renamed by an arbitrary vw) filled with the σ-images of the
   variables of split_vars t. *)
lemma split_vars_eqf_subst_map_vars_term:
  "t ⋅ σ =f (map_vars_mctxt vw (fst (split_vars t)), map σ (snd (split_vars t)))"
proof (induct t)
  case (Fun f ts)
  (* Unfold substitution, split_vars and map_vars_mctxt at the root. *)
  have "?case = (Fun f (map (λt. t ⋅ σ) ts)
    =f (MFun f (map (map_vars_mctxt vw ∘ (fst ∘ split_vars)) ts), concat (map (map σ ∘ (snd ∘ split_vars)) ts)))"
    by (simp add: map_concat)
  also have "..." 
  (* Argument-wise introduction; the two length premises are solved by simp. *)
  proof (rule eqf_MFunI, simp, simp, unfold length_map)
    fix i
    assume i: "i < length ts"
    then have mem: "ts ! i ∈ set ts" by auto
    show "map (λt. t ⋅ σ) ts ! i =f (map (map_vars_mctxt vw ∘ (fst ∘ split_vars)) ts ! i, map (map σ ∘ (snd ∘ split_vars)) ts ! i)"
      using Fun[OF mem] i by auto
  qed
  finally show ?case by simp
qed auto

(* Special case of split_vars_eqf_subst_map_vars_term with the identity
   renaming: the substituted term is the split context filled with the
   σ-images of the split variables. *)
lemma split_vars_eqf_subst: "t ⋅ σ =f (fst (split_vars t), (map σ (snd (split_vars t))))"
  using split_vars_eqf_subst_map_vars_term[of t σ "λ x. x"] by simp

(* Filling the context of split_vars s with the split variables (as terms)
   gives back s. *)
lemma split_vars_fill_holes:
  assumes "C = fst (split_vars s)" and "ss = map Var (snd (split_vars s))"
  shows "fill_holes C ss = s" using assms
  by (metis eqfE(1) split_vars_eqf_subst subst_apply_term_empty)


(* When C is below the context of the term t, filling C with the terms read
   off by unfill_holes gives back t. *)
lemma fill_unfill_holes:
  assumes "C ≤ mctxt_of_term t"
  shows "fill_holes C (unfill_holes C t) = t"
  using assms
proof (induct C t rule: unfill_holes.induct)
  (* Case 3 is the MFun/Fun equation of unfill_holes. *)
  case (3 f Cs g ts) with 3(1)[OF _ nth_mem] 3(2) show ?case
    by (auto simp: less_eq_mctxt_def intro!: fill_holes_arbitrary elim!: nth_equalityE)
qed (auto simp: less_eq_mctxt_def split: if_splits)


(* The list of hole positions has one entry per hole. *)
lemma hole_poss_list_length:
  "length (hole_poss_list D) = num_holes D"
  by (induct D) (auto simp: length_concat intro!: nth_sum_listI)

(* When C is below the context of the term t, unfill_holes C t and
   hole_poss_list C have the same length.
   NOTE(review): the name is spelled with a double l (holles); it is kept as is
   because the proof of unfill_holes_to_subst_at_hole_poss refers to the lemma
   under this name. *)
lemma unfill_holles_hole_poss_list_length:
  assumes "C ≤ mctxt_of_term t"
  shows "length (unfill_holes C t) = length (hole_poss_list C)" using assms
proof (induct C arbitrary: t)
  case (MVar x)
  (* A variable context is only below the context of that same variable. *)
  then have [simp]: "t = Var x" by (cases t) (auto dest: less_eq_mctxt_MVarE1)
  show ?case by simp
next
  case (MFun f ts) then show ?case
    by (cases t) (auto simp: length_concat comp_def
      elim!: less_eq_mctxt_MFunE1 less_eq_mctxt_MVarE1 intro!: nth_sum_listI)
qed auto

(* When C is below the context of the term t, unfill_holes C t is the list
   obtained by applying the operator section (|_) t to each position in
   hole_poss_list C, in that order. *)
lemma unfill_holes_to_subst_at_hole_poss:
  assumes "C ≤ mctxt_of_term t"
  shows "unfill_holes C t = map ((|_) t) (hole_poss_list C)" using assms
proof (induct C arbitrary: t)
  case (MVar x)
  then show ?case by (cases t) (auto elim: less_eq_mctxt_MVarE1)
next
  case (MFun f ts)
  from MFun(2) obtain ss where [simp]: "t = Fun f ss" and l: "length ts = length ss"
    by (cases t) (auto elim: less_eq_mctxt_MFunE1)
  (* ?ts and ?ss are the two lists of lists whose concatenations are compared:
     the per-argument unfill results, and the per-argument mapped positions. *)
  let ?ts = "map (λi. unfill_holes (ts ! i) (ss ! i)) [0..<length ts]"
  let ?ss = "map (λ x. map ((|_) (Fun f ss)) (case x of (x, y) ⇒ map ((#) x) (hole_poss_list y))) (zip [0..<length ts] ts)"
  have eq_l [simp]: "length (concat ?ts) = length (concat ?ss)" using MFun
    by (auto simp: length_concat comp_def elim!: less_eq_mctxt_MFunE1 split!: prod.splits intro!: nth_sum_listI)
  (* Pointwise equality of the two concatenations: an index i is translated by
     concat_index_split into a pair (m, n), the same pair for both lists. *)
  {fix i assume ass: "i < length (concat ?ts)"
    then have lss: "i < length (concat ?ss)" by auto
    obtain m n where [simp]: "concat_index_split (0, i) ?ts = (m, n)" by fastforce
    then have [simp]: "concat_index_split (0, i) ?ss = (m, n)" using concat_index_split_unique[OF ass, of ?ss 0] MFun(2)
      by (auto simp: unfill_holles_hole_poss_list_length[of "ts ! i" "ss ! i" for i]
       simp del: length_unfill_holes elim!: less_eq_mctxt_MFunE1)
    from concat_index_split_less_length_concat(2-)[OF ass ] concat_index_split_less_length_concat(2-)[OF lss]
    have "concat ?ts ! i = concat ?ss! i" using MFun(1)[OF nth_mem, of m "ss ! m"] MFun(2)
      by (auto elim!: less_eq_mctxt_MFunE1)} note nth = this
  show ?case using MFun
    by (auto simp: comp_def map_concat length_concat
        elim!: less_eq_mctxt_MFunE1 split!: prod.splits
        intro!: nth_equalityI nth_sum_listI nth)
qed auto

(* The hole position list of the split context and the variable position list
   of the term have the same length; simp rule. *)
lemma hole_poss_split_varposs_list_length [simp]:
  "length (hole_poss_list (fst (split_vars t))) = length (varposs_list t)"
  by (induct t)(auto simp: length_concat comp_def intro!: nth_sum_listI)

(* The hole positions of the split context are exactly the variable positions
   of the term, as lists (same order). *)
lemma hole_poss_split_vars_varposs_list:
  "hole_poss_list (fst (split_vars t)) = varposs_list t"
proof (induct t)
  case (Fun f ts)
  (* The two lists of lists whose concatenations are compared. *)
  let ?ts = "poss_args hole_poss_list (map (fst ∘ split_vars) ts)"
  let ?ss = "poss_args varposs_list ts"
  have len: "length (concat ?ts) = length (concat ?ss)" "length ?ts = length ?ss"
    "∀ i < length ?ts. length (?ts ! i) = length (?ss ! i)" by (auto intro: eq_length_concat_nth)
  (* Pointwise equality: an index i is translated by concat_index_split into a
     pair (m, n), the same pair for both lists; the induction hypothesis for
     the m-th argument then applies. *)
  {fix i assume ass: "i < length (concat ?ts)"
    then have lss: "i < length (concat ?ss)" using len by auto
    obtain m n where int: "concat_index_split (0, i) ?ts = (m, n)" by fastforce
    then have [simp]: "concat_index_split (0, i) ?ss = (m, n)" using concat_index_split_unique[OF ass len(2-)] by auto
    from concat_index_split_less_length_concat(2-)[OF ass int] concat_index_split_less_length_concat(2-)[OF lss]
    have "concat ?ts ! i = concat ?ss! i" using Fun[OF nth_mem, of m] by auto}
  then show ?case using len by (auto intro: nth_equalityI)
qed auto



(* A symbol g occurs in a filled context exactly when it occurs in the context
   or in one of the terms used for filling. Proved with the induction scheme
   fill_holes_induct. *)
lemma funas_term_fill_holes_iff: "num_holes C = length ts ⟹
   g ∈ funas_term (fill_holes C ts) ⟷ g ∈ funas_mctxt C ∨ (∃t ∈ set ts. g ∈ funas_term t)"
proof (induct C ts rule: fill_holes_induct)
  case (MFun f Cs ts)
  (* The same statement, argument-wise, for the partition chunks. *)
  have "(∃i < length Cs. g ∈ funas_term (fill_holes (Cs ! i) (partition_holes (concat (partition_holes ts Cs)) Cs ! i)))
    ⟷ (∃C ∈ set Cs. g ∈ funas_mctxt C) ∨ (∃us ∈ set (partition_holes ts Cs). ∃t ∈ set us. g ∈ funas_term t)"
    using MFun by (auto simp: ex_set_conv_ex_nth) blast
  then show ?case by auto
qed auto

(* For a ground context, the variables of the filled term are exactly the
   variables of the terms used for filling; simp rule. *)
lemma vars_term_fill_holes [simp]:
  "num_holes C = length ts ⟹ ground_mctxt C ⟹
    vars_term (fill_holes C ts) = ⋃(vars_term ` set ts)"
proof (induct C arbitrary: ts)
  case MHole
  then show ?case by (cases ts) simp_all
next
  case (MFun f Cs)
  then have *: "length (partition_holes ts Cs) = length Cs" by simp
  (* ?f collects the variables of a list of terms. *)
  let ?f = "λx. ⋃y ∈ set x. vars_term y"
  show ?case
    using MFun
    unfolding partition_holes_fill_holes_conv
    by (simp add: UN_upt_len_conv [OF *, of ?f] UN_set_partition_by)
qed simp



(* Set form of funas_term_fill_holes_iff: the symbols of a filled context are
   those of the context together with those of the filling terms; simp rule. *)
lemma funas_mctxt_fill_holes [simp]:
  assumes "num_holes C = length ts"
  shows "funas_term (fill_holes C ts) = funas_mctxt C ∪ ⋃(set (map funas_term ts))"
  using funas_term_fill_holes_iff[OF assms] by auto

(* The symbols of a context filled with contexts are those of the outer context
   together with those of the filling contexts; simp rule.
   ?f and ?g abbreviate the left- and right-hand side of the equation. *)
lemma funas_mctxt_fill_holes_mctxt [simp]:
  assumes "num_holes C = length Ds"
  shows "funas_mctxt (fill_holes_mctxt C Ds) = funas_mctxt C ∪ ⋃(set (map funas_mctxt Ds))"
  (is "?f C Ds = ?g C Ds")
using assms
proof (induct C arbitrary: Ds)
  case MHole
  then show ?case by (cases Ds) simp_all
next
  case (MFun f Cs)
  then have num_holes: "sum_list (map num_holes Cs) = length Ds" by simp
  let ?ys = "partition_holes Ds Cs"
  (* Induction hypothesis for every argument context with its chunk. *)
  have "⋀i. i < length Cs ⟹ ?f (Cs ! i) (?ys ! i) = ?g (Cs ! i) (?ys ! i)"
    using MFun by (metis nth_mem num_holes.simps(3) length_partition_holes_nth)
  then have "(⋃i ∈ {0 ..< length Cs}. ?f (Cs ! i) (?ys ! i)) =
    (⋃i ∈ {0 ..< length Cs}. ?g (Cs ! i) (?ys ! i))" by simp
  then show ?case
    using num_holes
    unfolding partition_holes_fill_holes_mctxt_conv
    by (simp add: UN_Un_distrib UN_upt_len_conv [of _ _ "λx. ⋃(set x)"] UN_set_partition_by_map)
qed simp

end
dy>

Theory Ground_MCtxt

theory Ground_MCtxt
  imports
   Multihole_Context
   Regular_Tree_Relations.Ground_Terms
   Regular_Tree_Relations.Ground_Ctxt
begin

subsection ‹Ground multihole context›

(* Ground multihole contexts over function symbols 'f: either a hole, or a
   function symbol applied to a list of ground multihole contexts. There is no
   variable constructor. The set function of the datatype is named gfuns_mctxt. *)
datatype (gfuns_mctxt: 'f) gmctxt = GMHole | GMFun 'f "'f gmctxt list"

subsubsection ‹Basic functions on ground multihole contexts›

(* Embed a ground term as a ground multihole context: every GFun node becomes
   a GMFun node with the same symbol. *)
primrec gmctxt_of_gterm :: "'f gterm ⇒ 'f gmctxt" where
  "gmctxt_of_gterm (GFun f ts) = GMFun f (map gmctxt_of_gterm ts)"

(* Number of holes of a ground multihole context: one for a hole, and the sum
   over the arguments for a GMFun node. *)
fun num_gholes :: "'f gmctxt ⇒ nat" where
  "num_gholes GMHole = Suc 0"
| "num_gholes (GMFun _ ctxts) = sum_list (map num_gholes ctxts)"

(* Convert a ground multihole context to a ground term by turning every GMFun
   node into a GFun node. No equation is given for GMHole. *)
primrec gterm_of_gmctxt :: "'f gmctxt ⇒ 'f gterm" where
  "gterm_of_gmctxt (GMFun f Cs) = GFun f (map gterm_of_gmctxt Cs)"

(* Convert a ground multihole context to a term (over an arbitrary variable
   type) by turning every GMFun node into a Fun node. No equation is given for
   GMHole. *)
primrec term_of_gmctxt :: "'f gmctxt ⇒ ('f, 'v) term" where
  "term_of_gmctxt (GMFun f Cs) = Fun f (map term_of_gmctxt Cs)"

(* Embed a ground context as a ground multihole context: the hole becomes
   GMHole, and GMore f ss C ts becomes a GMFun node whose arguments are the
   embedded terms ss, then the embedded context C, then the embedded terms ts. *)
primrec gmctxt_of_gctxt :: "'f gctxt ⇒ 'f gmctxt" where
  "gmctxt_of_gctxt □G = GMHole"
| "gmctxt_of_gctxt (GMore f ss C ts) =
    GMFun f (map gmctxt_of_gterm ss @ gmctxt_of_gctxt C # map gmctxt_of_gterm ts)"

(* Convert a ground multihole context to a ground context.
   For GMFun f Cs, n is the length of the longest prefix of Cs consisting of
   arguments without holes. If n is a valid index, the arguments before n and
   after n are converted with gterm_of_gmctxt and the n-th argument is
   converted recursively; otherwise (no argument has a hole) the result is
   undefined.
   NOTE(review): gterm_of_gmctxt has no equation for GMHole, so the result looks
   meaningful only when the arguments after index n contain no holes, i.e.
   presumably for contexts with exactly one hole - confirm against the users of
   this function. *)
fun gctxt_of_gmctxt :: "'f gmctxt ⇒ 'f gctxt" where
  "gctxt_of_gmctxt GMHole = □G"
| "gctxt_of_gmctxt (GMFun f Cs) = (let n = length (takeWhile (λ C. num_gholes C = 0) Cs) in
     (if n < length Cs then
        GMore f (map gterm_of_gmctxt (take n Cs)) (gctxt_of_gmctxt (Cs ! n)) (map gterm_of_gmctxt (drop (Suc n) Cs))
      else undefined))"

(* Convert a multihole context to a ground multihole context, constructor by
   constructor. No equation is given for MVar. *)
primrec gmctxt_of_mctxt :: "('f, 'v) mctxt ⇒ 'f gmctxt" where
   "gmctxt_of_mctxt MHole = GMHole"
|  "gmctxt_of_mctxt (MFun f Cs) = GMFun f (map gmctxt_of_mctxt Cs)"

(* Embed a ground multihole context into the multihole contexts (over an
   arbitrary variable type), constructor by constructor. *)
primrec mctxt_of_gmctxt :: "'f gmctxt ⇒ ('f, 'v) mctxt" where
   "mctxt_of_gmctxt GMHole = MHole"
|  "mctxt_of_gmctxt (GMFun f Cs) = MFun f (map mctxt_of_gmctxt Cs)"

(* Set of (symbol, argument count) pairs of all GMFun nodes of a ground
   multihole context; the remaining constructor (caught by the wildcard
   equation) contributes the empty set. *)
fun funas_gmctxt where
  "funas_gmctxt (GMFun f Cs) = {(f, length Cs)} ∪ ⋃(funas_gmctxt ` set Cs)" |
  "funas_gmctxt _ = {}"

(* Partition the list xs with partition_by, using the hole counts of the ground
   multihole contexts Cs as chunk lengths. *)
abbreviation "partition_gholes xs Cs ≡ partition_by xs (map num_gholes Cs)"

(* Fill the holes of a ground multihole context with ground terms.
   A hole with a singleton list is replaced by that term. For GMFun f cs the
   list ts is cut into chunks by partition_gholes, and the i-th argument is
   filled with the i-th chunk.
   No equation is given for GMHole with a list that is not a singleton. *)
fun fill_gholes :: "'f gmctxt ⇒ 'f gterm list ⇒ 'f gterm" where
  "fill_gholes GMHole [t] = t"
| "fill_gholes (GMFun f cs) ts = GFun f (map (λ i. fill_gholes (cs ! i)
    (partition_gholes ts cs ! i)) [0 ..< length cs])"

(* Fill the holes of a ground multihole context with ground multihole contexts.
   A hole with the empty list stays a hole; a hole with a singleton list is
   replaced by that context. For GMFun f cs the list ts is cut into chunks by
   partition_gholes, and the i-th argument is filled with the i-th chunk.
   No equation is given for GMHole with a list of two or more elements. *)
fun fill_gholes_gmctxt :: "'f gmctxt ⇒ 'f gmctxt list ⇒ 'f gmctxt" where
  "fill_gholes_gmctxt GMHole [] = GMHole" |
  "fill_gholes_gmctxt GMHole [t] = t" |
  "fill_gholes_gmctxt (GMFun f cs) ts = (GMFun f (map (λ i. fill_gholes_gmctxt (cs ! i) 
    (partition_gholes ts cs ! i)) [0 ..< length cs]))"

subsubsection ‹An inverse of @{term fill_gholes}›
(* Read off, from a ground term, the list of subterms sitting where the ground
   multihole context has holes.
   A hole yields the whole term as a singleton list. GMFun g Cs against
   GFun f ts requires f = g and equal argument counts; the results for the
   arguments are concatenated in argument order. Otherwise the result is
   undefined. *)
fun unfill_gholes :: "'f gmctxt ⇒ 'f gterm ⇒ 'f gterm list" where
  "unfill_gholes GMHole t = [t]"
| "unfill_gholes (GMFun g Cs) (GFun f ts) = (if f = g ∧ length ts = length Cs then
    concat (map (λi. unfill_gholes (Cs ! i) (ts ! i)) [0..<length ts]) else undefined)"

(* Compute a list of ground multihole contexts from two ground multihole
   contexts C and D by simultaneous traversal:
   - where C is a hole, the corresponding part of D is returned as a singleton;
   - where D is a hole (and C is not), one GMHole per hole of C is returned;
   - two GMFun nodes with the same symbol and the same number of arguments give
     the concatenation of the results for the argument pairs; otherwise the
     result is undefined.
   NOTE(review): judging by the name, the result is presumably the list of
   fillers for the holes of C that yields the supremum of C and D - confirm
   against the lemmas that use this function. *)
fun sup_gmctxt_args :: "'f gmctxt ⇒ 'f gmctxt ⇒ 'f gmctxt list" where
  "sup_gmctxt_args GMHole D = [D]" |
  "sup_gmctxt_args C GMHole = replicate (num_gholes C) GMHole" |
  "sup_gmctxt_args (GMFun f Cs) (GMFun g Ds) =
    (if f = g ∧ length Cs = length Ds then concat (map (case_prod sup_gmctxt_args) (zip Cs Ds))
    else undefined)"

(* Set of hole positions of a ground multihole context: the root position for
   a hole, and for GMFun f cs the positions of the i-th argument prefixed with
   i, collected over all argument indices. *)
fun ghole_poss :: "'f gmctxt ⇒ pos set" where
  "ghole_poss GMHole = {[]}" |
  "ghole_poss (GMFun f cs) = ⋃(set (map (λ i. (λ p. i # p) ` ghole_poss (cs ! i)) [0 ..< length cs]))"

(* For a list ts and a function f returning position lists: pair every element
   of ts with its index i and prefix i to each position that f returns for it.
   The right-hand side is the same as that of poss_args in theory Utils. *)
abbreviation "poss_rec f ts ≡ map2 (λ i t. map ((#) i) (f t)) ([0 ..< length ts]) ts"