Skip to content

Commit

Permalink
feat: FunInd: more equalities in context, more careful cleanup (#5364)
Browse files Browse the repository at this point in the history
A round of clean-up for the context of the functional induction
principle cases.

* Already previously, with `match e with | p => …`, functional induction
would ensure that `h : e = p` is in scope, but it wouldn’t work in
dependent cases. Now it introduces heterogeneous equality where needed
(fixes #4146)
* These equalities are now added always (previously we omitted them when
the discriminant was a variable that occurred in the goal, on the
grounds that the goal gets refined through the match, but it’s more
consistent to introduce the equality in any case)
* We no longer use `MVarId.cleanup` to clean up the goal; it was
sometimes too aggressive (fixes #5347)
* Instead, we clean up more carefully and with a custom strategy:
* First, we substitute all variables without a user-accessible name, if
we can.
  * Then, we substitute all variable, if we can, outside in.
* As we do that, we look for `HEq`s that we can turn into `Eq`s to
substitute some more
  * We substitute unused `let`s.
  
**Breaking change**: In some cases leads to a different functional
induction principle (different names and order of assumptions, for
example).
  • Loading branch information
nomeata authored Sep 16, 2024
1 parent 3f8e3e7 commit 445c8f2
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 77 deletions.
54 changes: 29 additions & 25 deletions src/Lean/Meta/Match/MatcherApp/Transform.lean
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ Performs a possibly type-changing transformation to a `MatcherApp`.
If `useSplitter` is true, the matcher is replaced with the splitter.
NB: Not all operations on `MatcherApp` can handle one `matcherName` is a splitter.
The array `addEqualities`, if provided, indicates for which of the discriminants an equality
connecting the discriminant to the parameters of the alternative (like in `match h : x with …`)
should be added (if it is isn't already there).
If `addEqualities` is true, then equalities connecting the discriminant to the parameters of the
alternative (like in `match h : x with …`) are be added, if not already there.
This function works even if the the type of alternatives do *not* fit the inferred type. This
allows you to post-process the `MatcherApp` with `MatcherApp.inferMatchType`, which will
Expand All @@ -212,20 +211,13 @@ def transform
[AddMessageContext n] [MonadOptions n]
(matcherApp : MatcherApp)
(useSplitter := false)
(addEqualities : Array Bool := mkArray matcherApp.discrs.size false)
(addEqualities : Bool := false)
(onParams : Expr → n Expr := pure)
(onMotive : Array Expr → Expr → n Expr := fun _ e => pure e)
(onAlt : Expr → Expr → n Expr := fun _ e => pure e)
(onRemaining : Array Expr → n (Array Expr) := pure) :
n MatcherApp := do

if addEqualities.size != matcherApp.discrs.size then
throwError "MatcherApp.transform: addEqualities has wrong size"

-- Do not add equalities when the matcher already does so
let addEqualities := Array.zipWith addEqualities matcherApp.discrInfos fun b di =>
if di.hName?.isSome then false else b

-- We also handle CasesOn applications here, and need to treat them specially in a
-- few places.
-- TODO: Expand MatcherApp with the necessary fields to make this more uniform
Expand All @@ -241,17 +233,26 @@ def transform
let params' ← matcherApp.params.mapM onParams
let discrs' ← matcherApp.discrs.mapM onParams


let (motive', uElim) ← lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do
let (motive', uElim, addHEqualities) ← lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do
unless motiveArgs.size == matcherApp.discrs.size do
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
let mut motiveBody' ← onMotive motiveArgs motiveBody

-- Prepend (x = e) → to the motive when an equality is requested
for arg in motiveArgs, discr in discrs', b in addEqualities do if b then
motiveBody' ← liftMetaM <| mkArrow (← mkEq discr arg) motiveBody'

return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody')
-- Prepend `(x = e) →` or `(HEq x e) → ` to the motive when an equality is requested
-- and not already present, and remember whether we added an Eq or a HEq
let mut addHEqualities : Array (Option Bool) := #[]
for arg in motiveArgs, discr in discrs', di in matcherApp.discrInfos do
if addEqualities && di.hName?.isNone then
if ← isProof arg then
addHEqualities := addHEqualities.push none
else
let heq ← mkEqHEq discr arg
motiveBody' ← liftMetaM <| mkArrow heq motiveBody'
addHEqualities := addHEqualities.push heq.isHEq
else
addHEqualities := addHEqualities.push none

return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody', addHEqualities)

let matcherLevels ← match matcherApp.uElimPos? with
| none => pure matcherApp.matcherLevels
Expand All @@ -261,15 +262,14 @@ def transform
-- (and count them along the way)
let mut remaining' := #[]
let mut extraEqualities : Nat := 0
for discr in discrs'.reverse, b in addEqualities.reverse do if b then
remaining' := remaining'.push (← mkEqRefl discr)
extraEqualities := extraEqualities + 1
for discr in discrs'.reverse, b in addHEqualities.reverse do
match b with
| none => pure ()
| some is_heq =>
remaining' := remaining'.push (← (if is_heq then mkHEqRefl else mkEqRefl) discr)
extraEqualities := extraEqualities + 1

if useSplitter && !isCasesOn then
-- We replace the matcher with the splitter
let matchEqns ← Match.getEquationsFor matcherApp.matcherName
let splitter := matchEqns.splitterName

let aux1 := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) params'
let aux1 := mkApp aux1 motive'
let aux1 := mkAppN aux1 discrs'
Expand All @@ -278,6 +278,10 @@ def transform
check aux1
let origAltTypes ← inferArgumentTypesN matcherApp.alts.size aux1

-- We replace the matcher with the splitter
let matchEqns ← Match.getEquationsFor matcherApp.matcherName
let splitter := matchEqns.splitterName

let aux2 := mkAppN (mkConst splitter matcherLevels.toList) params'
let aux2 := mkApp aux2 motive'
let aux2 := mkAppN aux2 discrs'
Expand Down
101 changes: 66 additions & 35 deletions src/Lean/Meta/Tactic/FunInd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ prelude
import Lean.Meta.Basic
import Lean.Meta.Match.MatcherApp.Transform
import Lean.Meta.Check
import Lean.Meta.Tactic.Cleanup
import Lean.Meta.Tactic.Subst
import Lean.Meta.Injective -- for elimOptParam
import Lean.Meta.ArgsPacker
Expand Down Expand Up @@ -402,19 +401,51 @@ def assertIHs (vals : Array Expr) (mvarid : MVarId) : MetaM MVarId := do
mvarid ← mvarid.assert (.mkSimple s!"ih{i+1}") (← inferType v) v
return mvarid


/--
Substitutes equations, but makes sure to only substitute variables introduced after the motives
(given by the index) as the motive could depend on anything before, and `substVar` would happily
drop equations about these fixed parameters.
Goal cleanup:
Substitutes equations (with `substVar`) to remove superfluous varialbes, and clears unused
let bindings.
Substitutes from the outside in so that the inner-bound variable name wins, but does a first pass
looking only at variables with names with macro scope, so that preferably they disappear.
Careful to only touch the context after the motives (given by the index) as the motive could depend
on anything before, and `substVar` would happily drop equations about these fixed parameters.
-/
def substVarAfter (mvarId : MVarId) (index : Nat) : MetaM MVarId := do
mvarId.withContext do
let mut mvarId := mvarId
for localDecl in (← getLCtx) do
if localDecl.index > index then
mvarId ← trySubstVar mvarId localDecl.fvarId
return mvarId
partial def cleanupAfter (mvarId : MVarId) (index : Nat) : MetaM MVarId := do
let mvarId ← go mvarId index true
let mvarId ← go mvarId index false
return mvarId
where
go (mvarId : MVarId) (index : Nat) (firstPass : Bool) : MetaM MVarId := do
if let some mvarId ← cleanupAfter? mvarId index firstPass then
go mvarId index firstPass
else
return mvarId

allHeqToEq (mvarId : MVarId) (index : Nat) : MetaM MVarId :=
mvarId.withContext do
let mut mvarId := mvarId
for localDecl in (← getLCtx) do
if localDecl.index > index then
let (_, mvarId') ← heqToEq mvarId localDecl.fvarId
mvarId := mvarId'
return mvarId

cleanupAfter? (mvarId : MVarId) (index : Nat) (firstPass : Bool) : MetaM (Option MVarId) := do
mvarId.withContext do
for localDecl in (← getLCtx) do
if localDecl.index > index && (!firstPass || localDecl.userName.hasMacroScopes) then
if localDecl.isLet then
if let some mvarId' ← observing? <| mvarId.clear localDecl.fvarId then
return some mvarId'
if let some mvarId' ← substVar? mvarId localDecl.fvarId then
-- After substituting, some HEq might turn into Eqs, and we want to be able to substitute
-- them as well
let mvarId' ← allHeqToEq mvarId' index
return some mvarId'
return none


/--
Second helper monad collecting the cases as mvars
Expand All @@ -429,7 +460,7 @@ def M2.branch {α} (act : M2 α) : M2 α :=


/-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/
def buildInductionCase (oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (toClear toPreserve : Array FVarId)
def buildInductionCase (oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (toClear : Array FVarId)
(goal : Expr) (e : Expr) : M2 Expr := do
let _e' ← foldAndCollect oldIH newIH isRecCall e
let IHs : Array Expr ← M.ask
Expand All @@ -441,8 +472,6 @@ def buildInductionCase (oldIH newIH : FVarId) (isRecCall : Expr → Option Expr)
trace[Meta.FunInd] "Goal before cleanup:{mvarId}"
for fvarId in toClear do
mvarId ← mvarId.clear fvarId
mvarId ← mvarId.cleanup (toPreserve := toPreserve)
trace[Meta.FunInd] "Goal after cleanup (toClear := {toClear.map mkFVar}) (toPreserve := {toPreserve.map mkFVar}):{mvarId}"
modify (·.push mvarId)
let mvar ← instantiateMVars mvar
pure mvar
Expand All @@ -457,7 +486,7 @@ Like `mkLambdaFVars (usedOnly := true)`, but
The result `r` can be applied with `r.beta (maskArray mask args)`.
We use this when generating the functional induction principle to refine the goal through a `match`,
here `xs` are the discriminans of the `match`.
here `xs` are the discriminants of the `match`.
We do not expect non-trivial discriminants to appear in the goal (and if they do, the user will
get a helpful equality into the context).
-/
Expand Down Expand Up @@ -487,7 +516,7 @@ Builds an expression of type `goal` by replicating the expression `e` into its t
where it calls `buildInductionCase`. Collects the cases of the final induction hypothesis
as `MVars` as it goes.
-/
partial def buildInductionBody (toClear toPreserve : Array FVarId) (goal : Expr)
partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
(oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (e : Expr) : M2 Expr := do

-- if-then-else cause case split:
Expand All @@ -496,10 +525,10 @@ partial def buildInductionBody (toClear toPreserve : Array FVarId) (goal : Expr)
let c' ← foldAndCollect oldIH newIH isRecCall c
let h' ← foldAndCollect oldIH newIH isRecCall h
let t' ← withLocalDecl `h .default c' fun h => M2.branch do
let t' ← buildInductionBody toClear (toPreserve.push h.fvarId!) goal oldIH newIH isRecCall t
let t' ← buildInductionBody toClear goal oldIH newIH isRecCall t
mkLambdaFVars #[h] t'
let f' ← withLocalDecl `h .default (mkNot c') fun h => M2.branch do
let f' ← buildInductionBody toClear (toPreserve.push h.fvarId!) goal oldIH newIH isRecCall f
let f' ← buildInductionBody toClear goal oldIH newIH isRecCall f
mkLambdaFVars #[h] f'
let u ← getLevel goal
return mkApp5 (mkConst ``dite [u]) goal c' h' t' f'
Expand All @@ -508,11 +537,11 @@ partial def buildInductionBody (toClear toPreserve : Array FVarId) (goal : Expr)
let h' ← foldAndCollect oldIH newIH isRecCall h
let t' ← withLocalDecl `h .default c' fun h => M2.branch do
let t ← instantiateLambda t #[h]
let t' ← buildInductionBody toClear (toPreserve.push h.fvarId!) goal oldIH newIH isRecCall t
let t' ← buildInductionBody toClear goal oldIH newIH isRecCall t
mkLambdaFVars #[h] t'
let f' ← withLocalDecl `h .default (mkNot c') fun h => M2.branch do
let f ← instantiateLambda f #[h]
let f' ← buildInductionBody toClear (toPreserve.push h.fvarId!) goal oldIH newIH isRecCall f
let f' ← buildInductionBody toClear goal oldIH newIH isRecCall f
mkLambdaFVars #[h] f'
let u ← getLevel goal
return mkApp5 (mkConst ``dite [u]) goal c' h' t' f'
Expand All @@ -523,8 +552,8 @@ partial def buildInductionBody (toClear toPreserve : Array FVarId) (goal : Expr)
match_expr goal with
| And goal₁ goal₂ => match_expr e with
| PProd.mk _α _β e₁ e₂ =>
let e₁' ← buildInductionBody toClear toPreserve goal₁ oldIH newIH isRecCall e₁
let e₂' ← buildInductionBody toClear toPreserve goal₂ oldIH newIH isRecCall e₂
let e₁' ← buildInductionBody toClear goal₁ oldIH newIH isRecCall e₁
let e₂' ← buildInductionBody toClear goal₂ oldIH newIH isRecCall e₂
return mkApp4 (.const ``And.intro []) goal₁ goal₂ e₁' e₂'
| _ =>
throwError "Goal is PProd, but expression is:{indentExpr e}"
Expand All @@ -543,14 +572,14 @@ partial def buildInductionBody (toClear toPreserve : Array FVarId) (goal : Expr)
-- so we need to replace that IH
if matcherApp.remaining.size == 1 && matcherApp.remaining[0]!.isFVarOf oldIH then
let matcherApp' ← matcherApp.transform (useSplitter := true)
(addEqualities := mask.map not)
(addEqualities := true)
(onParams := (foldAndCollect oldIH newIH isRecCall ·))
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => M2.branch do
removeLamda alt fun oldIH' alt => do
forallBoundedTelescope expAltType (some 1) fun newIH' goal' => do
let #[newIH'] := newIH' | unreachable!
let alt' ← buildInductionBody (toClear.push newIH'.fvarId!) toPreserve goal' oldIH' newIH'.fvarId! isRecCall alt
let alt' ← buildInductionBody (toClear.push newIH'.fvarId!) goal' oldIH' newIH'.fvarId! isRecCall alt
mkLambdaFVars #[newIH'] alt')
(onRemaining := fun _ => pure #[.fvar newIH])
return matcherApp'.toExpr
Expand All @@ -562,32 +591,34 @@ partial def buildInductionBody (toClear toPreserve : Array FVarId) (goal : Expr)
let (mask, absMotiveBody) ← mkLambdaFVarsMasked matcherApp.discrs goal

let matcherApp' ← matcherApp.transform (useSplitter := true)
(addEqualities := mask.map not)
(addEqualities := true)
(onParams := (foldAndCollect oldIH newIH isRecCall ·))
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => M2.branch do
buildInductionBody toClear toPreserve expAltType oldIH newIH isRecCall alt)
buildInductionBody toClear expAltType oldIH newIH isRecCall alt)
return matcherApp'.toExpr

if let .letE n t v b _ := e then
let t' ← foldAndCollect oldIH newIH isRecCall t
let v' ← foldAndCollect oldIH newIH isRecCall v
return ← withLetDecl n t' v' fun x => M2.branch do
let b' ← buildInductionBody toClear toPreserve goal oldIH newIH isRecCall (b.instantiate1 x)
let b' ← buildInductionBody toClear goal oldIH newIH isRecCall (b.instantiate1 x)
mkLetFVars #[x] b'

if let some (n, t, v, b) := e.letFun? then
let t' ← foldAndCollect oldIH newIH isRecCall t
let v' ← foldAndCollect oldIH newIH isRecCall v
return ← withLocalDecl n .default t' fun x => M2.branch do
let b' ← buildInductionBody toClear toPreserve goal oldIH newIH isRecCall (b.instantiate1 x)
let b' ← buildInductionBody toClear goal oldIH newIH isRecCall (b.instantiate1 x)
mkLetFun x v' b'

liftM <| buildInductionCase oldIH newIH isRecCall toClear toPreserve goal e
liftM <| buildInductionCase oldIH newIH isRecCall toClear goal e

/--
Given an expression `e` with metavariables
* collects all these meta-variables,
Given an expression `e` with metavariables `mvars`
* performs more cleanup:
* removes unused let-expressions after index `index`
* tries to substitute variables after index `index`
* lifts them to the current context by reverting all local declarations after index `index`
* introducing a local variable for each of the meta variable
* assigning that local variable to the mvar
Expand All @@ -605,7 +636,7 @@ do not handle delayed assignemnts correctly.
def abstractIndependentMVars (mvars : Array MVarId) (index : Nat) (e : Expr) : MetaM Expr := do
trace[Meta.FunInd] "abstractIndependentMVars, to revert after {index}, original mvars: {mvars}"
let mvars ← mvars.mapM fun mvar => do
let mvar ← substVarAfter mvar index
let mvar ← cleanupAfter mvar index
mvar.withContext do
let fvarIds := (← getLCtx).foldl (init := #[]) (start := index+1) fun fvarIds decl => fvarIds.push decl.fvarId
let (_, mvar) ← mvar.revert fvarIds
Expand Down Expand Up @@ -662,7 +693,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
let body ← instantiateLambda body targets
removeLamda body fun oldIH body => do
let body ← instantiateLambda body extraParams
let body' ← buildInductionBody #[genIH.fvarId!] #[] goal oldIH genIH.fvarId! isRecCall body
let body' ← buildInductionBody #[genIH.fvarId!] goal oldIH genIH.fvarId! isRecCall body
if body'.containsFVar oldIH then
throwError m!"Did not fully eliminate {mkFVar oldIH} from induction principle body:{indentExpr body}"
mkLambdaFVars (targets.push genIH) (← mkLambdaFVars extraParams body')
Expand Down Expand Up @@ -972,7 +1003,7 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit
removeLamda body fun oldIH body => do
trace[Meta.FunInd] "replacing {Expr.fvar oldIH} with {genIH}"
let body ← instantiateLambda body extraParams
let body' ← buildInductionBody #[genIH.fvarId!] #[] goal oldIH genIH.fvarId! isRecCall body
let body' ← buildInductionBody #[genIH.fvarId!] goal oldIH genIH.fvarId! isRecCall body
if body'.containsFVar oldIH then
throwError m!"Did not fully eliminate {mkFVar oldIH} from induction principle body:{indentExpr body}"
mkLambdaFVars (targets.push genIH) (← mkLambdaFVars extraParams body')
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/funind_structural.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ termination_by structural x => x

/--
info: zip.induct.{u_1, u_2} {α : Type u_1} {β : Type u_2} (motive : List α → List β → Prop)
(case1 : ∀ (x : List β), motive [] x) (case2 : ∀ (x : List α), (x = [] → False) → motive x [])
(case1 : ∀ (x : List β), motive [] x) (case2 : ∀ (t : List α), (t = [] → False) → motive t [])
(case3 : ∀ (x : α) (xs : List α) (y : β) (ys : List β), motive xs ys → motive (x :: xs) (y :: ys)) :
∀ (a : List α) (a_1 : List β), motive a a_1
-/
Expand Down
Loading

0 comments on commit 445c8f2

Please sign in to comment.