Skip to content

Commit

Permalink
feat: let dot notation see through CoeFun instances (#5692)
Browse files Browse the repository at this point in the history
Projects like mathlib like to define projection functions with extra
structure, for example one could imagine defining `Multiset.card :
Multiset α →+ Nat`, which bundles the fact that `Multiset.card (m1 + m2)
= Multiset.card m1 + Multiset.card m2` for all `m1 m2 : Multiset α`. A
problem though is that so far this has prevented dot notation from
working: you can't write `(m1 + m2).card = m1.card + m2.card`.

With this PR, now you can. The way it works is that "LValue resolution"
will apply CoeFun instances when trying to resolve which argument should
receive the object of dot notation.

A contrived-yet-representative example:
```lean
structure Equiv (α β : Sort _) where
  toFun : α → β
  invFun : β → α

infixl:25 " ≃ " => Equiv

instance: CoeFun (α ≃ β) fun _ => α → β where
  coe := Equiv.toFun

structure Foo where
  n : Nat

def Foo.n' : Foo ≃ Nat := ⟨Foo.n, Foo.mk⟩

variable (f : Foo)
#check f.n'
-- Foo.n'.toFun f : Nat
```

Design note 1: While LValue resolution attempts to make use of named
arguments when positional arguments cannot be used, when we apply CoeFun
instances we disallow making use of named arguments. The rationale is
that argument names for CoeFun instances tend to be random, which could
lead dot notation randomly succeeding or failing. It is better to be
uniform, and so it uniformly fails in this case.

Design note 2: There is a limitation in that this will *not* make use of
the values of any of the provided arguments when synthesizing the CoeFun
instances (see the tests for an example), since argument elaboration
takes place after LValue resolution. However, we make sure that
synthesis will fail rather than choose the wrong CoeFun instance.

Performance note: Such instances will be synthesized twice, once during
LValue resolution, and again when applying arguments.

This also adds in a small optimization to the parameter list computation
in LValue resolution so that it lazily reduces when a relevant parameter
hasn't been found yet, rather than using `forallTelescopeReducing`. It
also switches to using `forallMetaTelescope` to make sure the CoeFun
synthesis will fail if multiple instances could apply.

Getting this to pretty print will be deferred to future work.

Closes #1910
  • Loading branch information
kmill authored Oct 14, 2024
1 parent 36c2511 commit a026bc7
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 38 deletions.
107 changes: 69 additions & 38 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1118,9 +1118,17 @@ where

/-- Auxiliary inductive datatype that represents the resolution of an `LVal`. -/
inductive LValResolution where
/-- When applied to `f`, effectively expands to `BaseStruct.fieldName (self := Struct.toBase f)`.
This is a special named argument where it suppresses any explicit arguments depending on it so that type parameters don't need to be supplied. -/
| projFn (baseStructName : Name) (structName : Name) (fieldName : Name)
/-- Similar to `projFn`, but for extracting field indexed by `idx`. Works for structure-like inductive types in general. -/
| projIdx (structName : Name) (idx : Nat)
/-- When applied to `f`, effectively expands to `constName ... (Struct.toBase f)`, with the argument placed in the correct
positional argument if possible, or otherwise as a named argument. The `Struct.toBase` is not present if `baseStructName == structName`,
in which case these do not need to be structures. Supports generalized field notation. -/
| const (baseStructName : Name) (structName : Name) (constName : Name)
/-- Like `const`, but with `fvar` instead of `constName`.
The `fullName` is the name of the recursive function, and `baseName` is the base name of the type to search for in the parameter list. -/
| localRec (baseName : Name) (fullName : Name) (fvar : Expr)

private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermElabM α :=
Expand Down Expand Up @@ -1290,45 +1298,70 @@ private def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM Bool :=
else
return (← whnfR type).isAppOf baseName

/-- Auxiliary method for field notation. It tries to add `e` as a new argument to `args` or `namedArgs`.
This method first finds the parameter with a type of the form `(baseName ...)`.
When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`.
Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`.
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/
private def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (fType : Expr)
: TermElabM (Array Arg × Array NamedArg) :=
forallTelescopeReducing fType fun xs _ => do
let mut argIdx := 0 -- position of the next explicit argument
let mut remainingNamedArgs := namedArgs
for h : i in [:xs.size] do
let x := xs[i]
let xDecl ← x.fvarId!.getDecl
/- If there is named argument with name `xDecl.userName`, then we skip it. -/
match remainingNamedArgs.findIdx? (fun namedArg => namedArg.name == xDecl.userName) with
| some idx =>
/--
Auxiliary method for field notation. Tries to add `e` as a new argument to `args` or `namedArgs`.
This method first finds the parameter with a type of the form `(baseName ...)`.
When the parameter is found, if it an explicit one and `args` is big enough, we add `e` to `args`.
Otherwise, if there isn't another parameter with the same name, we add `e` to `namedArgs`.
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors
-/
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) :
MetaM (Array Arg × Array NamedArg) := do
withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true
where
/--
* `argIdx` is the position into `args` for the next place an explicit argument can be inserted.
* `remainingNamedArgs` keeps track of named arguments that haven't been visited yet,
for handling the case where multiple parameters have the same name.
* `unusableNamedArgs` keeps track of names that can't be used as named arguments. This is initialized with user-provided named arguments.
* `allowNamed` is whether or not to allow using named arguments.
Disabled after using `CoeFun` since those parameter names unlikely to be meaningful,
and otherwise whether dot notation works or not could feel random.
-/
go (f fType : Expr) (argIdx : Nat) (remainingNamedArgs : Array NamedArg) (unusableNamedArgs : Array Name) (allowNamed : Bool) := withIncRecDepth do
/- Use metavariables (rather than `forallTelescope`) to prevent `coerceToFunction?` from succeeding when multiple instances could apply -/
let (xs, bInfos, fType') ← forallMetaTelescope fType
let mut argIdx := argIdx
let mut remainingNamedArgs := remainingNamedArgs
let mut unusableNamedArgs := unusableNamedArgs
for x in xs, bInfo in bInfos do
let xDecl ← x.mvarId!.getDecl
if let some idx := remainingNamedArgs.findIdx? (·.name == xDecl.userName) then
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
| none =>
let type := xDecl.type
if (← typeMatchesBaseName type baseName) then
else
if (← typeMatchesBaseName xDecl.type baseName) then
/- We found a type of the form (baseName ...).
First, we check if the current argument is an explicit one,
and the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
if argIdx ≤ args.size && xDecl.binderInfo.isExplicit then
/- We insert `e` as an explicit argument -/
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
if argIdx ≤ args.size && bInfo.isExplicit then
/- We can insert `e` as an explicit argument -/
return (args.insertAt! argIdx (Arg.expr e), namedArgs)
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
if there isn't an argument with the same name occurring before it. -/
for j in [:i] do
let prev := xs[j]!
let prevDecl ← prev.fvarId!.getDecl
if prevDecl.userName == xDecl.userName then
throwError "invalid field notation, function '{fullName}' has argument with the expected type{indentExpr type}\nbut it cannot be used"
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
if xDecl.binderInfo.isExplicit then
-- advance explicit argument position
else
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
if there isn't an argument with the same name occurring before it. -/
if !allowNamed || unusableNamedArgs.contains xDecl.userName then
throwError "\
invalid field notation, function '{fullName}' has argument with the expected type\
{indentExpr xDecl.type}\n\
but it cannot be used"
else
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
/- Advance `argIdx` and update seen named arguments. -/
if bInfo.isExplicit then
argIdx := argIdx + 1
throwError "invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, it must be explicit or implicit with a unique name"
unusableNamedArgs := unusableNamedArgs.push xDecl.userName
/- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument.
Otherwise, we can abort now. -/
if allowNamed || argIdx ≤ args.size then
if let fType'@(.forallE ..) ← whnf fType' then
return ← go (mkAppN f xs) fType' argIdx remainingNamedArgs unusableNamedArgs allowNamed
if let some f' ← coerceToFunction? (mkAppN f xs) then
return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false
throwError "\
invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \
it must be explicit or implicit with a unique name"

/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
private def addProjTermInfo
Expand Down Expand Up @@ -1375,17 +1408,15 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
let projFn ← mkConst constName
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let projFnType ← inferType projFn
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFnType
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.localRec baseName fullName fvar =>
let fvar ← addProjTermInfo lval.getRef fvar
if lvals.isEmpty then
let fvarType ← inferType fvar
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvarType
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar
elabAppArgs fvar namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
Expand Down
95 changes: 95 additions & 0 deletions tests/lean/run/1910.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/-!
# Dot notation and CoeFun
https://github.com/leanprover/lean4/issues/1910
-/

set_option pp.mvars false

/-!
Test that dot notation resolution can see through CoeFun instances.
-/

structure Equiv (α β : Sort _) where
toFun : α → β
invFun : β → α

infixl:25 " ≃ " => Equiv

instance : CoeFun (α ≃ β) fun _ => α → β where
coe := Equiv.toFun

structure Foo where
n : Nat

def Foo.n' : Foo ≃ Nat := ⟨Foo.n, Foo.mk⟩

variable (f : Foo)
/-- info: Foo.n'.toFun f : Nat -/
#guard_msgs in #check f.n'

example (f : Foo) : f.n' = f.n := rfl

/-!
Fail dot notation if it requires using a named argument from the CoeFun instance.
-/
structure F where
f : Bool → Nat → Nat

instance : CoeFun F (fun _ => (x : Bool) → (y : Nat) → Nat) where
coe x := fun (a : Bool) (b : Nat) => x.f a b

-- Recall CoeFun oddity: it uses the unfolded *value* to figure out parameter names.
-- That's why this is `a` and `b` rather than `x` and `y`.
/-- info: fun x => (fun a b => x.f a b) true 2 : F → Nat -/
#guard_msgs in #check fun (x : F) => x (a := true) (b := 2)

def Nat.foo : F := { f := fun _ b => b }

-- Ok:
/-- info: fun n x => (fun a b => Nat.foo.f a b) x n : Nat → Bool → Nat -/
#guard_msgs in #check fun (n : Nat) => (Nat.foo · n)

-- Intentionally fails:
/--
error: invalid field notation, function 'Nat.foo' has argument with the expected type
Nat
but it cannot be used
---
info: fun n => sorryAx (?_ n) true : (n : Nat) → ?_ n
-/
#guard_msgs in #check fun (n : Nat) => n.foo

/-!
Make sure that dot notation does not use the wrong CoeFun instance.
The following instances rely on the second one having higher priority,
so we need to fail completely when the instances would depend on argument values.
-/

structure Bar (b : Bool) where

instance : CoeFun (Bar b) (fun _ => Bar b → Bool) where
coe := fun _ _ => b

instance : CoeFun (Bar true) (fun _ => (b : Bool) → Bar b) where
coe := fun _ _ => {}

def Bar.bar : Bar true := {}

/-- info: fun f => (fun x => false) f : Bar false → Bool -/
#guard_msgs in #check fun (f : Bar false) => Bar.bar false f
/--
error: invalid field notation, function 'Bar.bar' does not have argument with type (Bar ...) that can be used, it must be explicit or implicit with a unique name
---
info: fun f => sorryAx (?_ f) true : (f : Bar false) → ?_ f
-/
#guard_msgs in #check fun (f : Bar false) => f.bar false

/-- info: fun f => (fun x => false) f : Bar false → Bool -/
#guard_msgs in #check fun (f : Bar false) => Bar.bar true false f
/--
error: invalid field notation, function 'Bar.bar' does not have argument with type (Bar ...) that can be used, it must be explicit or implicit with a unique name
---
info: fun f => sorryAx (?_ f) true : (f : Bar false) → ?_ f
-/
#guard_msgs in #check fun (f : Bar false) => f.bar true false

0 comments on commit a026bc7

Please sign in to comment.