Skip to content

Commit

Permalink
feat: ematch theorem activation for grind (#6475)
Browse files Browse the repository at this point in the history
This PR adds support for activating relevant theorems for the (WIP)
`grind` tactic. We say a theorem is relevant to a `grind` goal if the
symbols occurring in its patterns also occur in the goal.
  • Loading branch information
leodemoura authored Dec 30, 2024
1 parent 24a8561 commit 9b28c58
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def elabGrindPattern : CommandElab := fun stx => do
let pattern ← instantiateMVars (← elabTerm term none)
let pattern ← Grind.unfoldReducible pattern
return pattern.abstract xs
Grind.addTheoremPattern declName xs.size patterns.toList
Grind.addEMatchTheorem declName xs.size patterns.toList
| _ => throwUnsupportedSyntax

def grind (mvarId : MVarId) (mainDeclName : Name) : MetaM Unit := do
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.Ctor
import Lean.Meta.Tactic.Grind.Parser
import Lean.Meta.Tactic.Grind.TheoremPatterns
import Lean.Meta.Tactic.Grind.EMatchTheorem

namespace Lean

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@ inductive Origin where
| other
deriving Inhabited, Repr

structure TheoremPattern where
/-- A unique identifier corresponding to the origin. -/
def Origin.key : Origin → Name
| .decl declName => declName
| .fvar fvarId => fvarId.name
| .stx id _ => id
| .other => `other

/-- A theorem for heuristic instantiation based on E-matching. -/
structure EMatchTheorem where
proof : Expr
numParams : Nat
patterns : List Expr
Expand All @@ -34,16 +42,21 @@ structure TheoremPattern where
origin : Origin
deriving Inhabited

abbrev TheoremPatterns := SMap Name (List TheoremPattern)
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
abbrev EMatchTheorems := PHashMap Name (List EMatchTheorem)

builtin_initialize theoremPatternsExt : SimpleScopedEnvExtension TheoremPattern TheoremPatterns ←
def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchTheorems := Id.run do
let .const declName :: syms := thm.symbols
| unreachable!
let thm := { thm with symbols := syms }
if let some thms := s.find? declName then
return PersistentHashMap.insert s declName (thm::thms)
else
return PersistentHashMap.insert s declName [thm]

private builtin_initialize ematchTheoremsExt : SimpleScopedEnvExtension EMatchTheorem EMatchTheorems ←
registerSimpleScopedEnvExtension {
addEntry := fun s t => Id.run do
let .const declName :: _ := t.symbols | unreachable!
if let some ts := s.find? declName then
s.insert declName (t::ts)
else
s.insert declName [t]
addEntry := EMatchTheorems.insert
initial := .empty
}

Expand Down Expand Up @@ -282,19 +295,23 @@ private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : M
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg

def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
assert! symbols.all fun s => s matches .const _
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
theoremPatternsExt.add {
ematchTheoremsExt.add {
proof, patterns, numParams, symbols
origin := .decl declName
}

def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)

end Lean.Meta.Grind
19 changes: 18 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ private def updateAppMap (e : Expr) : GoalM Unit := do
s.appMap.insert key [e]
}

private def activateTheoremPatterns (fName : Name) : GoalM Unit := do
if let some thms := (← get).thmMap.find? fName then
modify fun s => { s with thmMap := s.thmMap.erase fName }
let appMap := (← get).appMap
for thm in thms do
let symbols := thm.symbols.filter fun sym => !appMap.contains sym
let thm := { thm with symbols }
match symbols with
| [] =>
trace[grind.pattern] "activated `{thm.origin.key}`"
modify fun s => { s with newThms := s.newThms.push thm }
| _ =>
trace[grind.pattern] "reinsert `{thm.origin.key}`"
modify fun s => { s with thmMap := s.thmMap.insert thm }

partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
match e with
Expand All @@ -63,7 +78,9 @@ partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
internalize c generation
registerParent e c
else
unless f.isConst do
if let .const fName _ := f then
activateTheoremPatterns fName
else
internalize f generation
registerParent e f
for h : i in [: args.size] do
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Run.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
def mkGoal (mvarId : MVarId) : GrindM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
GoalM.run' { mvarId } do
let thmMap ← getEMatchTheorems
GoalM.run' { mvarId, thmMap } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)

Expand Down
10 changes: 10 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.Attr
import Lean.Meta.Tactic.Grind.EMatchTheorem

namespace Lean.Meta.Grind

Expand Down Expand Up @@ -273,6 +274,15 @@ structure Goal where
gmt : Nat := 0
/-- Next unique index for creating ENodes -/
nextIdx : Nat := 0
/-- Active theorems that we have performed ematching at least once. -/
thms : PArray EMatchTheorem := {}
/-- Active theorems that we have not performed any round of ematching yet. -/
newThms : PArray EMatchTheorem := {}
/--
Inactive global theorems. As we internalize terms, we activate theorems as we find their symbols.
Local theorem provided by users are added directly into `newThms`.
-/
thmMap : EMatchTheorems
deriving Inhabited

def Goal.admit (goal : Goal) : MetaM Unit :=
Expand Down
39 changes: 39 additions & 0 deletions tests/lean/run/grind_pattern2.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
def Set (α : Type) := α → Bool

def insertElem [DecidableEq α] (s : Set α) (a : α) : Set α :=
fun x => a = x || s x

def contains (s : Set α) (a : α) : Bool :=
s a

theorem contains_insert [DecidableEq α] (s : Set α) (a : α) : contains (insertElem s a) a := by
simp [contains, insertElem]

grind_pattern contains_insert => contains (insertElem s a) a

-- TheoremPattern activation test

set_option trace.grind.pattern true

/--
warning: declaration uses 'sorry'
---
info: [grind.pattern] activated `contains_insert`
-/
#guard_msgs in
example [DecidableEq α] (s₁ s₂ : Set α) (a₁ a₂ : α) :
s₂ = insertElem s₁ a₁ → a₁ = a₂ → contains s₂ a₂ := by
fail_if_success grind
sorry

/--
warning: declaration uses 'sorry'
---
info: [grind.pattern] reinsert `contains_insert`
[grind.pattern] activated `contains_insert`
-/
#guard_msgs in
example [DecidableEq α] (s₁ s₂ : Set α) (a₁ a₂ : α) :
¬ contains s₂ a₂ → s₂ = insertElem s₁ a₁ → a₁ = a₂ → False := by
fail_if_success grind
sorry

0 comments on commit 9b28c58

Please sign in to comment.