Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: instantiate ematch theorems in grind #6485

Merged
merged 1 commit into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ builtin_initialize registerTraceClass `grind.internalize
builtin_initialize registerTraceClass `grind.ematch
builtin_initialize registerTraceClass `grind.ematch.pattern
builtin_initialize registerTraceClass `grind.ematch.instance
builtin_initialize registerTraceClass `grind.ematch.instance.assignment
builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.simp

Expand Down
75 changes: 70 additions & 5 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ structure Choice where

/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/
structure TheoremInstance where
prop : Expr
proof : Expr
prop : Expr
generation : Nat
deriving Inhabited

Expand Down Expand Up @@ -163,10 +163,75 @@ private def processContinue (c : Choice) (p : Expr) : M Unit := do
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }

private partial def instantiateTheorem (c : Choice) : M Unit := do
trace[grind.ematch.instance] "{(← read).thm.origin.key} : {assignmentToMessageData c.assignment}"
-- TODO
return ()
/--
Stores new theorem instance in the state.
Recall that new instances are internalized later, after a full round of ematching.
-/
private def addNewInstance (origin : Origin) (proof : Expr) (generation : Nat) : M Unit := do
let proof ← instantiateMVars proof
if grind.debug.proofs.get (← getOptions) then
check proof
let prop ← inferType proof
trace[grind.ematch.instance] "{← origin.pp}: {prop}"
modify fun s => { s with newInstances := s.newInstances.push { proof, prop, generation } }

/--
After processing a (multi-)pattern, use the choice assignment to instantiate the proof.
Missing parameters are synthesized using type inference and type class synthesis."
-/
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
let thm := (← read).thm
trace[grind.ematch.instance.assignment] "{← thm.origin.pp}: {assignmentToMessageData c.assignment}"
let proof ← thm.getProofWithFreshMVarLevels
let numParams := thm.numParams
assert! c.assignment.size == numParams
let (mvars, bis, _) ← forallMetaBoundedTelescope (← inferType proof) numParams
if mvars.size != thm.numParams then
trace[grind.issues] "unexpected number of parameters at {← thm.origin.pp}"
return ()
-- Apply assignment
for h : i in [:mvars.size] do
let v := c.assignment[numParams - i - 1]!
unless isSameExpr v unassigned do
let mvarId := mvars[i].mvarId!
unless (← mvarId.checkedAssign v) do
trace[grind.issues] "type error constructing proof for {← thm.origin.pp}\nwhen assigning metavariable {mvars[i]} with {indentExpr v}"
return ()
-- Synthesize instances
for mvar in mvars, bi in bis do
if bi.isInstImplicit && !(← mvar.mvarId!.isAssigned) then
let type ← inferType mvar
unless (← synthesizeInstance mvar type) do
trace[grind.issues] "failed to synthesize instance when instantiating {← thm.origin.pp}{indentExpr type}"
return ()
if (← mvars.allM (·.mvarId!.isAssigned)) then
addNewInstance thm.origin (mkAppN proof mvars) c.gen
else
-- instance has hypothesis
mkImp mvars 0 proof #[]
where
synthesizeInstance (x type : Expr) : MetaM Bool := do
let .some val ← trySynthInstance type | return false
isDefEq x val

mkImp (mvars : Array Expr) (i : Nat) (proof : Expr) (xs : Array Expr) : M Unit := do
if h : i < mvars.size then
let mvar := mvars[i]
if (← mvar.mvarId!.isAssigned) then
mkImp mvars (i+1) (mkApp proof mvar) xs
else
let mvarType ← instantiateMVars (← inferType mvar)
if mvarType.hasMVar then
let thm := (← read).thm
trace[grind.issues] "failed to create hypothesis for instance of {← thm.origin.pp} hypothesis type has metavars{indentExpr mvarType}"
return ()
withLocalDeclD (← mkFreshUserName `h) mvarType fun x => do
mkImp mvars (i+1) (mkApp proof x) (xs.push x)
else
let proof ← instantiateMVars proof
let proof ← mkLambdaFVars xs proof
let thm := (← read).thm
addNewInstance thm.origin proof c.gen

/-- Process choice stack until we don't have more choices to be processed. -/
private partial def processChoices : M Unit := do
Expand Down
31 changes: 30 additions & 1 deletion src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.HeadIndex
import Lean.PrettyPrinter
import Lean.Util.FoldConsts
import Lean.Util.CollectFVars
import Lean.Meta.Basic
Expand Down Expand Up @@ -32,8 +33,21 @@ def Origin.key : Origin → Name
| .stx id _ => id
| .other => `other

def Origin.pp [Monad m] [MonadEnv m] [MonadError m] (o : Origin) : m MessageData := do
match o with
| .decl declName => return MessageData.ofConst (← mkConstWithLevelParams declName)
| .fvar fvarId => return mkFVar fvarId
| .stx _ ref => return ref
| .other => return "[unknown]"

/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
/--
It stores universe parameter names for universe polymorphic proofs.
Recall that it is non-empty only when we elaborate an expression provided by the user.
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
-/
levelParams : Array Name
proof : Expr
numParams : Nat
patterns : List Expr
Expand All @@ -54,6 +68,20 @@ def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchThe
else
return PersistentHashMap.insert s declName [thm]

def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
if thm.proof.isConst && thm.levelParams.isEmpty then
let declName := thm.proof.constName!
let info ← getConstInfo declName
if info.levelParams.isEmpty then
return thm.proof
else
mkConstWithFreshMVarLevels declName
else if thm.levelParams.isEmpty then
return thm.proof
else
let us ← thm.levelParams.mapM fun _ => mkFreshLevelMVar
return thm.proof.instantiateLevelParamsArray thm.levelParams us

private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem EMatchTheorems ←
registerSimpleScopedEnvExtension {
addEntry := EMatchTheorems.insert
Expand Down Expand Up @@ -316,7 +344,8 @@ def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr)
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
ematchTheoremsExt.add {
proof, patterns, numParams, symbols
origin := .decl declName
levelParams := #[]
origin := .decl declName
}

def getEMatchTheorems : CoreM EMatchTheorems :=
Expand Down
19 changes: 10 additions & 9 deletions tests/lean/run/grind_ematch1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ grind_pattern Array.get_set_ne => (a.set i v hi)[j]

set_option trace.grind.ematch.instance true

set_option grind.debug.proofs true

/--
info: [grind.ematch.instance] Array.get_set_eq : [α, bs, j, w, Lean.Grind.nestedProof (j < bs.toList.length) h₂]
[grind.ematch.instance] Array.get_set_eq : [α, as, i, v, Lean.Grind.nestedProof (i < as.toList.length) h₁]
[grind.ematch.instance] Array.get_set_ne : [α, bs, j, Lean.Grind.nestedProof (j < bs.toList.length) h₂, i, w, _, _]
info: [grind.ematch.instance] Array.get_set_eq: (bs.set j w ⋯)[j] = w
[grind.ematch.instance] Array.get_set_eq: (as.set i v ⋯)[i] = v
-/
#guard_msgs (info) in
example (as : Array α)
Expand All @@ -31,8 +32,8 @@ theorem Rtrans (a b c : Nat) : R a b → R b c → R a c := sorry
grind_pattern Rtrans => R a b, R b c

/--
info: [grind.ematch.instance] Rtrans : [b, c, d, _, _]
[grind.ematch.instance] Rtrans : [a, b, c, _, _]
info: [grind.ematch.instance] Rtrans: R b c → R c d → R b d
[grind.ematch.instance] Rtrans: R a b → R b c → R a c
-/
#guard_msgs (info) in
example : R a b → R b c → R c d → False := by
Expand All @@ -41,10 +42,10 @@ example : R a b → R b c → R c d → False := by

-- In the following test we are performing one round of ematching only
/--
info: [grind.ematch.instance] Rtrans : [c, d, e, _, _]
[grind.ematch.instance] Rtrans : [c, d, n, _, _]
[grind.ematch.instance] Rtrans : [b, c, d, _, _]
[grind.ematch.instance] Rtrans : [a, b, c, _, _]
info: [grind.ematch.instance] Rtrans: R c d → R d e → R c e
[grind.ematch.instance] Rtrans: R c d → R d n → R c n
[grind.ematch.instance] Rtrans: R b c → R c d → R b d
[grind.ematch.instance] Rtrans: R a b → R b c → R a c
-/
#guard_msgs (info) in
example : R a b → R b c → R c d → R d e → R d n → False := by
Expand Down
Loading