Skip to content

Commit

Permalink
feat: dot notation improvements
Browse files Browse the repository at this point in the history
This PR changes how generalized field notation ("dot notation") resolves the name of the function, adds a feature where terms such as `x.toString` can resolve as `toString x` as a last resort, and modifies the `Function` dot notation to always use the first explicit argument rather than the first function argument. The new rule is that if `x : S`, then `x.f` resolves the name `S.f` relative to the root namespace (hence it now responds to `export` and `open). Breaking change: aliases now resolve differently. Before, if `x : S`, and `S.f` is an alias for `S'.f`, then `x.f` would use `S'.f` and look for an argument of type `S'`. Now, it looks for an argument of type `S`, which is more generally useful behavior. Code making use of the old behavior should consider defining `S` or `S'` in terms of the other, since dot notation can unfold definitions during resolution.

This also fixes a bug in explicit field notation (`@x.f`) where `x` could be passed as the wrong argument.

Closes leanprover#3031
  • Loading branch information
kmill committed Nov 23, 2024
1 parent ba3f2b3 commit fa36ca0
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 68 deletions.
131 changes: 66 additions & 65 deletions src/Lean/Elab/App.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1140,8 +1140,9 @@ inductive LValResolution where
| projIdx (structName : Name) (idx : Nat)
/-- When applied to `f`, effectively expands to `constName ... (Struct.toBase f)`, with the argument placed in the correct
positional argument if possible, or otherwise as a named argument. The `Struct.toBase` is not present if `baseStructName == structName`,
in which case these do not need to be structures. Supports generalized field notation. -/
| const (baseStructName : Name) (structName : Name) (constName : Name)
in which case these do not need to be structures. Supports generalized field notation.
If `useFirstExplicit` is true, then `f` is inserted as the first explicit argument to `constName`. -/
| const (baseStructName : Name) (structName : Name) (constName : Name) (useFirstExplicit : Bool := false)
/-- Like `const`, but with `fvar` instead of `constName`.
The `fullName` is the name of the recursive function, and `baseName` is the base name of the type to search for in the parameter list. -/
| localRec (baseName : Name) (fullName : Name) (fvar : Expr)
Expand All @@ -1150,47 +1151,49 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE
throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}"

/--
`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`:
- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)`
- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)`
`findMethod? S fName` tries the for each namespace `S'` in the resolution order for `S` to resolve the name `S'.fname`.
If it resolves to `name`, returns `(S', name)`.
-/
private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do
let env ← getEnv
let find? structName' : MetaM (Option (Name × Name)) := do
let fullName := structName' ++ fieldName
if env.contains fullName then
return some (structName', fullName)
let fullNamePrv := mkPrivateName env fullName
if env.contains fullNamePrv then
return some (structName', fullNamePrv)
return none
-- We do not want to make use of the current namespace for resolution.
let candidates := ResolveName.resolveGlobalName (← getEnv) Name.anonymous (← getOpenDecls) fullName
|>.filter (fun (_, fieldList) => fieldList.isEmpty)
|>.map Prod.fst
match candidates with
| [] => return none
| [fullName'] => return some (structName', fullName')
| _ => throwError "\
invalid field notation '{fieldName}', the name '{fullName}' is ambiguous, possible interpretations: \
{MessageData.joinSep (candidates.map (m!"'{.ofConstName ·}'")) ", "}"
-- Optimization: the first element of the resolution order is `structName`,
-- so we can skip computing the resolution order in the common case
-- of the name resolving in the `structName` namespace.
find? structName <||> do
let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName]
for h : i in [1:resolutionOrder.size] do
if let some res ← find? resolutionOrder[i] then
for ns in resolutionOrder[1:resolutionOrder.size] do
if let some res ← find? ns then
return res
return none

/--
Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and
`fullName` is of the form `structName' ++ fieldName`.
TODO: if there is more than one applicable alias, it returns `none`. We should consider throwing an error or
warning.
-/
private def findMethodAlias? (env : Environment) (structName fieldName : Name) : Option (Name × Name) :=
let fullName := structName ++ fieldName
-- We never skip `protected` aliases when resolving dot-notation.
let aliasesCandidates := getAliases env fullName (skipProtected := false) |>.filterMap fun alias =>
match alias.eraseSuffix? fieldName with
| none => none
| some structName' => some (structName', alias)
match aliasesCandidates with
| [r] => some r
| _ => none
private def findTopLevelMethod? (fieldName : Name) : MetaM (Option Name) := do
let env ← getEnv
if env.contains fieldName then
return some fieldName
let prvName := mkPrivateName env fieldName
if env.contains prvName then
return some prvName
-- There should be no `protected` top-level aliases, but no need to filter.
let aliasCandidates := getAliases env fieldName (skipProtected := false)
match aliasCandidates with
| [] => pure ()
| [alias] => return some alias
| _ => throwError "\
invalid field notation '{fieldName}', the name '{fieldName}' is ambigous, possible interpretations: \
{MessageData.joinSep (aliasCandidates.map (m!"'{.ofConstName ·}'")) ", "}"
return none

private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α :=
throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant"
Expand All @@ -1201,7 +1204,7 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
| LVal.fieldName _ fieldName _ _ =>
let fullName := Name.str `Function fieldName
if (← getEnv).contains fullName then
return LValResolution.const `Function `Function fullName
return LValResolution.const `Function `Function fullName (useFirstExplicit := true)
| _ => pure ()
match eType.getAppFn.constName?, lval with
| some structName, LVal.fieldIdx _ idx =>
Expand All @@ -1223,30 +1226,27 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
| some structName, LVal.fieldName _ fieldName _ _ =>
let env ← getEnv
let searchEnv : Unit → TermElabM LValResolution := fun _ => do
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then
return LValResolution.const structName' structName' fullName
else
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
-- search local context first, then environment
let searchCtx : Unit → TermElabM LValResolution := fun _ => do
let fullName := Name.mkStr structName fieldName
for localDecl in (← getLCtx) do
if localDecl.isAuxDecl then
if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
/- LVal notation is being used to make a "local" recursive call. -/
return LValResolution.localRec structName fullName localDecl.toExpr
searchEnv ()
if isStructure env structName then
match findField? env structName (Name.mkSimple fieldName) with
| some baseStructName => return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
| none => searchCtx ()
else
searchCtx ()
if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then
return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName)
-- Search the local context first
let fullName := Name.mkStr structName fieldName
for localDecl in (← getLCtx) do
if localDecl.isAuxDecl then
if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then
if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then
/- LVal notation is being used to make a "local" recursive call. -/
return LValResolution.localRec structName fullName localDecl.toExpr
-- Then search the environment
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
-- Otherwise search for a top-level declaration named `fieldName`.
-- We only do this if `eType` does not unfold so that this is done as a last resort (`resolveLValLoop` loops so long as the type unfolds).
if (← unfoldDefinition? eType).isNone then
if let some fullName ← findTopLevelMethod? (.mkSimple fieldName) then
return LValResolution.const structName structName fullName (useFirstExplicit := true)
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
| none, LVal.fieldName _ _ (some suffix) _ =>
if e.isConst then
throwUnknownConstant (e.constName! ++ suffix)
Expand Down Expand Up @@ -1326,7 +1326,8 @@ Otherwise, if there isn't another parameter with the same name, we add `e` to `n
Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors
-/
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) :
private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr)
(explicit : Bool) (useFirstExplicit : Bool := false) :
MetaM (Array Arg × Array NamedArg) := do
withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true
where
Expand All @@ -1351,25 +1352,25 @@ where
/- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/
remainingNamedArgs := remainingNamedArgs.eraseIdx idx
else
if (← typeMatchesBaseName xDecl.type baseName) then
/- We found a type of the form (baseName ...).
First, we check if the current argument is an explicit one,
if ← (pure <| useFirstExplicit && bInfo.isExplicit) <||> typeMatchesBaseName xDecl.type baseName then
/- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode.
First, we check if the current argument is one that can be used positionally,
and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/
if h : argIdx ≤ args.size ∧ bInfo.isExplicit then
if h : argIdx ≤ args.size ∧ (explicit || bInfo.isExplicit) then
/- We can insert `e` as an explicit argument -/
return (args.insertIdx argIdx (Arg.expr e), namedArgs)
else
/- If we can't add `e` to `args`, we try to add it using a named argument, but this is only possible
if there isn't an argument with the same name occurring before it. -/
if !allowNamed || unusableNamedArgs.contains xDecl.userName then
throwError "\
invalid field notation, function '{fullName}' has argument with the expected type\
invalid field notation, function '{.ofConstName fullName}' has argument with the expected type\
{indentExpr xDecl.type}\n\
but it cannot be used"
else
return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e })
/- Advance `argIdx` and update seen named arguments. -/
if bInfo.isExplicit then
if explicit || bInfo.isExplicit then
argIdx := argIdx + 1
unusableNamedArgs := unusableNamedArgs.push xDecl.userName
/- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument.
Expand All @@ -1380,7 +1381,7 @@ where
if let some f' ← coerceToFunction? (mkAppN f xs) then
return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false
throwError "\
invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \
invalid field notation, function '{.ofConstName fullName}' does not have argument with type ({.ofConstName baseName} ...) that can be used, \
it must be explicit or implicit with a unique name"

/-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/
Expand Down Expand Up @@ -1421,20 +1422,20 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
else
let f ← elabAppArgs projFn #[{ name := `self, val := Arg.expr f, suppressDeps := true }] #[] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.const baseStructName structName constName =>
| LValResolution.const baseStructName structName constName useFirstExplicit =>
let f ← if baseStructName != structName then mkBaseProjections baseStructName structName f else pure f
let projFn ← mkConst constName
let projFn ← addProjTermInfo lval.getRef projFn
if lvals.isEmpty then
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn
let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn explicit useFirstExplicit
elabAppArgs projFn namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
loop f lvals
| LValResolution.localRec baseName fullName fvar =>
let fvar ← addProjTermInfo lval.getRef fvar
if lvals.isEmpty then
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar
let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar explicit
elabAppArgs fvar namedArgs args expectedType? explicit ellipsis
else
let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false)
Expand Down
119 changes: 119 additions & 0 deletions tests/lean/run/3031.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/-!
# Tests for generalized field notation through aliases and "top-level" dot notation
https://github.com/leanprover/lean4/issues/3031
-/

/-!
Alias dot notation. There used to be a different kind of alias dot notation;
in the following example, it would have looked for an argument of type `Common.String`.
Now it looks for one of type `String`, allowing libraries to add "extension methods" from within their own namespaces.
-/
def Common.String.a (s : String) : Nat := s.length

export Common (String.a)

/-- info: String.a "x" : Nat -/
#guard_msgs in #check String.a "x"
/-- info: String.a "x" : Nat -/
#guard_msgs in #check "x".a

/-!
Declarations take precedence over aliases
-/
def String.a (s : String) : Nat := s.length + 100
/-- info: "x".a : Nat -/
#guard_msgs in #check "x".a
/-- info: 100 -/
#guard_msgs in #eval "".a

/-!
Private declarations take precedence over aliases
-/
private def String.b (s : String) : Nat := 0
def Common.String.b (s : String) : Nat := 1
export Common (String.b)
/-- info: 0 -/
#guard_msgs in #eval "".b

/-!
Multiple aliases is an error
-/
def Common.String.c (s : String) : Nat := 0
def Common'.String.c (s : String) : Nat := 0
export Common (String.c)
export Common' (String.c)
/--
error: invalid field notation 'c', the name 'String.c' is ambiguous, possible interpretations: 'Common'.String.c', 'Common.String.c'
-/
#guard_msgs in #eval "".c

/-!
Aliases work with inheritance
-/
namespace Ex1
structure A
structure B extends A
def Common.A.x (_ : A) : Nat := 0
export Common (A.x)
/-- info: fun b => A.x b.toA : B → Nat -/
#guard_msgs in #check fun (b : B) => b.x
end Ex1

/-!
`open` also works
-/
def Common.String.parse (_ : String) : List Nat := []

namespace ExOpen1
/--
error: invalid field 'parse', the environment does not contain 'String.parse'
""
has type
String
-/
#guard_msgs in #check "".parse
section
open Common
/-- info: String.parse "" : List Nat -/
#guard_msgs in #check "".parse
end
section
open Common (String.parse)
/-- info: String.parse "" : List Nat -/
#guard_msgs in #check "".parse
end
end ExOpen1


namespace Ex2
class A (n : Nat) where
x : Nat

/-!
"Top-level" dot notation. As a last resort, field notation looks for a top-level declaration or alias
and supplies the value as the first explicit argument.
-/
instance : ToString (A n) where
toString a := s!"A.x is {a.x}"

/-- info: fun a => toString a : A 2 → String -/
#guard_msgs in #check fun (a : A 2) => a.toString

/-!
Incidental fix: `@` for generalized field notation was failing if there were implicit arguments.
True projections were ok.
-/
def A.x' {n : Nat} (a : A n) := a.x

/-- info: fun a => a.x' : A 2 → Nat -/
#guard_msgs in #check fun (a : A 2) => @a.x'
end Ex2

namespace Ex3
variable (f : α → β) (g : β → γ)
/-!
Functions use the "top-level" dot notation rule: they use the first explicit argument, rather than the first function argument.
-/
/-- info: g ∘ f : α → γ -/
#guard_msgs in #check g.comp f
end Ex3
13 changes: 12 additions & 1 deletion tests/lean/run/DVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ example (v : Vec Nat 1) : Nat :=

#check @Vec.hd

-- works
-- Does not work: Aliases find that `v` could be the `TypeVec` argument since `TypeVec` is an abbrev for `Vec`.
/--
error: application type mismatch
@Vec.hd ?_ v
argument
v
has type
Vec Nat 1 : Type
but is expected to have type
TypeVec (?_ + 1) : Type (_ + 1)
-/
#guard_msgs in set_option pp.mvars false in
example (v : Vec Nat 1) : Nat :=
v.hd
21 changes: 19 additions & 2 deletions tests/lean/run/alias.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,31 @@ def Set (α : Type) := α → Prop
def Set.union (s₁ s₂ : Set α) : Set α :=
fun a => s₁ a ∨ s₂ a

def FinSet (n : Nat) := Fin n → Prop
def FinSet (n : Nat) := Set (Fin n)

/-!
The type of `x` is unfolded to find `Set.union`
-/
example (x y : FinSet 10) : FinSet 10 :=
x.union y

namespace FinSet
export Set (union)
export Set (union)
end FinSet

/-!
Since the types are defeq, this alias works:
-/
example (x y : FinSet 10) : FinSet 10 :=
FinSet.union x y

/-!
However, this dot notation fails since there is no `FinSet` argument.
However, unfolding is the preferred method.
-/
/--
error: invalid field notation, function 'FinSet.union' does not have argument with type (FinSet ...) that can be used, it must be explicit or implicit with a unique name
-/
#guard_msgs in
example (x y : FinSet 10) : FinSet 10 :=
x.union y

0 comments on commit fa36ca0

Please sign in to comment.