diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 4ba28f696ead..350e833df365 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -1140,8 +1140,9 @@ inductive LValResolution where | projIdx (structName : Name) (idx : Nat) /-- When applied to `f`, effectively expands to `constName ... (Struct.toBase f)`, with the argument placed in the correct positional argument if possible, or otherwise as a named argument. The `Struct.toBase` is not present if `baseStructName == structName`, - in which case these do not need to be structures. Supports generalized field notation. -/ - | const (baseStructName : Name) (structName : Name) (constName : Name) + in which case these do not need to be structures. Supports generalized field notation. + If `useFirstExplicit` is true, then `f` is inserted as the first explicit argument to `constName`. -/ + | const (baseStructName : Name) (structName : Name) (constName : Name) (useFirstExplicit : Bool := false) /-- Like `const`, but with `fvar` instead of `constName`. The `fullName` is the name of the recursive function, and `baseName` is the base name of the type to search for in the parameter list. -/ | localRec (baseName : Name) (fullName : Name) (fvar : Expr) @@ -1150,47 +1151,49 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}" /-- -`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`: -- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)` -- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)` +`findMethod? S fName` tries the for each namespace `S'` in the resolution order for `S` to resolve the name `S'.fname`. +If it resolves to `name`, returns `(S', name)`. -/ private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do let env ← getEnv let find? structName' : MetaM (Option (Name × Name)) := do let fullName := structName' ++ fieldName - if env.contains fullName then - return some (structName', fullName) - let fullNamePrv := mkPrivateName env fullName - if env.contains fullNamePrv then - return some (structName', fullNamePrv) - return none + -- We do not want to make use of the current namespace for resolution. + let candidates := ResolveName.resolveGlobalName (← getEnv) Name.anonymous (← getOpenDecls) fullName + |>.filter (fun (_, fieldList) => fieldList.isEmpty) + |>.map Prod.fst + match candidates with + | [] => return none + | [fullName'] => return some (structName', fullName') + | _ => throwError "\ + invalid field notation '{fieldName}', the name '{fullName}' is ambiguous, possible interpretations: \ + {MessageData.joinSep (candidates.map (m!"'{.ofConstName ·}'")) ", "}" -- Optimization: the first element of the resolution order is `structName`, -- so we can skip computing the resolution order in the common case -- of the name resolving in the `structName` namespace. find? structName <||> do let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName] - for h : i in [1:resolutionOrder.size] do - if let some res ← find? resolutionOrder[i] then + for ns in resolutionOrder[1:resolutionOrder.size] do + if let some res ← find? ns then return res return none -/-- - Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and - `fullName` is of the form `structName' ++ fieldName`. - - TODO: if there is more than one applicable alias, it returns `none`. We should consider throwing an error or - warning. --/ -private def findMethodAlias? (env : Environment) (structName fieldName : Name) : Option (Name × Name) := - let fullName := structName ++ fieldName - -- We never skip `protected` aliases when resolving dot-notation. - let aliasesCandidates := getAliases env fullName (skipProtected := false) |>.filterMap fun alias => - match alias.eraseSuffix? fieldName with - | none => none - | some structName' => some (structName', alias) - match aliasesCandidates with - | [r] => some r - | _ => none +private def findTopLevelMethod? (fieldName : Name) : MetaM (Option Name) := do + let env ← getEnv + if env.contains fieldName then + return some fieldName + let prvName := mkPrivateName env fieldName + if env.contains prvName then + return some prvName + -- There should be no `protected` top-level aliases, but no need to filter. + let aliasCandidates := getAliases env fieldName (skipProtected := false) + match aliasCandidates with + | [] => pure () + | [alias] => return some alias + | _ => throwError "\ + invalid field notation '{fieldName}', the name '{fieldName}' is ambigous, possible interpretations: \ + {MessageData.joinSep (aliasCandidates.map (m!"'{.ofConstName ·}'")) ", "}" + return none private def throwInvalidFieldNotation (e eType : Expr) : TermElabM α := throwLValError e eType "invalid field notation, type is not of the form (C ...) where C is a constant" @@ -1201,7 +1204,7 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L | LVal.fieldName _ fieldName _ _ => let fullName := Name.str `Function fieldName if (← getEnv).contains fullName then - return LValResolution.const `Function `Function fullName + return LValResolution.const `Function `Function fullName (useFirstExplicit := true) | _ => pure () match eType.getAppFn.constName?, lval with | some structName, LVal.fieldIdx _ idx => @@ -1223,30 +1226,27 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)" | some structName, LVal.fieldName _ fieldName _ _ => let env ← getEnv - let searchEnv : Unit → TermElabM LValResolution := fun _ => do - if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then - return LValResolution.const baseStructName structName fullName - else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then - return LValResolution.const structName' structName' fullName - else - throwLValError e eType - m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'" - -- search local context first, then environment - let searchCtx : Unit → TermElabM LValResolution := fun _ => do - let fullName := Name.mkStr structName fieldName - for localDecl in (← getLCtx) do - if localDecl.isAuxDecl then - if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then - if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then - /- LVal notation is being used to make a "local" recursive call. -/ - return LValResolution.localRec structName fullName localDecl.toExpr - searchEnv () if isStructure env structName then - match findField? env structName (Name.mkSimple fieldName) with - | some baseStructName => return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName) - | none => searchCtx () - else - searchCtx () + if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then + return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName) + -- Search the local context first + let fullName := Name.mkStr structName fieldName + for localDecl in (← getLCtx) do + if localDecl.isAuxDecl then + if let some localDeclFullName := (← read).auxDeclToFullName.find? localDecl.fvarId then + if fullName == (privateToUserName? localDeclFullName).getD localDeclFullName then + /- LVal notation is being used to make a "local" recursive call. -/ + return LValResolution.localRec structName fullName localDecl.toExpr + -- Then search the environment + if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then + return LValResolution.const baseStructName structName fullName + -- Otherwise search for a top-level declaration named `fieldName`. + -- We only do this if `eType` does not unfold so that this is done as a last resort (`resolveLValLoop` loops so long as the type unfolds). + if (← unfoldDefinition? eType).isNone then + if let some fullName ← findTopLevelMethod? (.mkSimple fieldName) then + return LValResolution.const structName structName fullName (useFirstExplicit := true) + throwLValError e eType + m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'" | none, LVal.fieldName _ _ (some suffix) _ => if e.isConst then throwUnknownConstant (e.constName! ++ suffix) @@ -1326,7 +1326,8 @@ Otherwise, if there isn't another parameter with the same name, we add `e` to `n Remark: `fullName` is the name of the resolved "field" access function. It is used for reporting errors -/ -private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) : +private partial def addLValArg (baseName : Name) (fullName : Name) (e : Expr) (args : Array Arg) (namedArgs : Array NamedArg) (f : Expr) + (explicit : Bool) (useFirstExplicit : Bool := false) : MetaM (Array Arg × Array NamedArg) := do withoutModifyingState <| go f (← inferType f) 0 namedArgs (namedArgs.map (·.name)) true where @@ -1351,11 +1352,11 @@ where /- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/ remainingNamedArgs := remainingNamedArgs.eraseIdx idx else - if (← typeMatchesBaseName xDecl.type baseName) then - /- We found a type of the form (baseName ...). - First, we check if the current argument is an explicit one, + if ← (pure <| useFirstExplicit && bInfo.isExplicit) <||> typeMatchesBaseName xDecl.type baseName then + /- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode. + First, we check if the current argument is one that can be used positionally, and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ - if h : argIdx ≤ args.size ∧ bInfo.isExplicit then + if h : argIdx ≤ args.size ∧ (explicit || bInfo.isExplicit) then /- We can insert `e` as an explicit argument -/ return (args.insertIdx argIdx (Arg.expr e), namedArgs) else @@ -1363,13 +1364,13 @@ where if there isn't an argument with the same name occurring before it. -/ if !allowNamed || unusableNamedArgs.contains xDecl.userName then throwError "\ - invalid field notation, function '{fullName}' has argument with the expected type\ + invalid field notation, function '{.ofConstName fullName}' has argument with the expected type\ {indentExpr xDecl.type}\n\ but it cannot be used" else return (args, namedArgs.push { name := xDecl.userName, val := Arg.expr e }) /- Advance `argIdx` and update seen named arguments. -/ - if bInfo.isExplicit then + if explicit || bInfo.isExplicit then argIdx := argIdx + 1 unusableNamedArgs := unusableNamedArgs.push xDecl.userName /- If named arguments aren't allowed, then it must still be possible to pass the value as an explicit argument. @@ -1380,7 +1381,7 @@ where if let some f' ← coerceToFunction? (mkAppN f xs) then return ← go f' (← inferType f') argIdx remainingNamedArgs unusableNamedArgs false throwError "\ - invalid field notation, function '{fullName}' does not have argument with type ({baseName} ...) that can be used, \ + invalid field notation, function '{.ofConstName fullName}' does not have argument with type ({.ofConstName baseName} ...) that can be used, \ it must be explicit or implicit with a unique name" /-- Adds the `TermInfo` for the field of a projection. See `Lean.Parser.Term.identProjKind`. -/ @@ -1421,12 +1422,12 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp else let f ← elabAppArgs projFn #[{ name := `self, val := Arg.expr f, suppressDeps := true }] #[] (expectedType? := none) (explicit := false) (ellipsis := false) loop f lvals - | LValResolution.const baseStructName structName constName => + | LValResolution.const baseStructName structName constName useFirstExplicit => let f ← if baseStructName != structName then mkBaseProjections baseStructName structName f else pure f let projFn ← mkConst constName let projFn ← addProjTermInfo lval.getRef projFn if lvals.isEmpty then - let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn + let (args, namedArgs) ← addLValArg baseStructName constName f args namedArgs projFn explicit useFirstExplicit elabAppArgs projFn namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs projFn #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) @@ -1434,7 +1435,7 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp | LValResolution.localRec baseName fullName fvar => let fvar ← addProjTermInfo lval.getRef fvar if lvals.isEmpty then - let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar + let (args, namedArgs) ← addLValArg baseName fullName f args namedArgs fvar explicit elabAppArgs fvar namedArgs args expectedType? explicit ellipsis else let f ← elabAppArgs fvar #[] #[Arg.expr f] (expectedType? := none) (explicit := false) (ellipsis := false) diff --git a/tests/lean/run/3031.lean b/tests/lean/run/3031.lean new file mode 100644 index 000000000000..5fbbaddd51ed --- /dev/null +++ b/tests/lean/run/3031.lean @@ -0,0 +1,119 @@ +/-! +# Tests for generalized field notation through aliases and "top-level" dot notation +https://github.com/leanprover/lean4/issues/3031 +-/ + +/-! +Alias dot notation. There used to be a different kind of alias dot notation; +in the following example, it would have looked for an argument of type `Common.String`. +Now it looks for one of type `String`, allowing libraries to add "extension methods" from within their own namespaces. +-/ +def Common.String.a (s : String) : Nat := s.length + +export Common (String.a) + +/-- info: String.a "x" : Nat -/ +#guard_msgs in #check String.a "x" +/-- info: String.a "x" : Nat -/ +#guard_msgs in #check "x".a + +/-! +Declarations take precedence over aliases +-/ +def String.a (s : String) : Nat := s.length + 100 +/-- info: "x".a : Nat -/ +#guard_msgs in #check "x".a +/-- info: 100 -/ +#guard_msgs in #eval "".a + +/-! +Private declarations take precedence over aliases +-/ +private def String.b (s : String) : Nat := 0 +def Common.String.b (s : String) : Nat := 1 +export Common (String.b) +/-- info: 0 -/ +#guard_msgs in #eval "".b + +/-! +Multiple aliases is an error +-/ +def Common.String.c (s : String) : Nat := 0 +def Common'.String.c (s : String) : Nat := 0 +export Common (String.c) +export Common' (String.c) +/-- +error: invalid field notation 'c', the name 'String.c' is ambiguous, possible interpretations: 'Common'.String.c', 'Common.String.c' +-/ +#guard_msgs in #eval "".c + +/-! +Aliases work with inheritance +-/ +namespace Ex1 +structure A +structure B extends A +def Common.A.x (_ : A) : Nat := 0 +export Common (A.x) +/-- info: fun b => A.x b.toA : B → Nat -/ +#guard_msgs in #check fun (b : B) => b.x +end Ex1 + +/-! +`open` also works +-/ +def Common.String.parse (_ : String) : List Nat := [] + +namespace ExOpen1 +/-- +error: invalid field 'parse', the environment does not contain 'String.parse' + "" +has type + String +-/ +#guard_msgs in #check "".parse +section +open Common +/-- info: String.parse "" : List Nat -/ +#guard_msgs in #check "".parse +end +section +open Common (String.parse) +/-- info: String.parse "" : List Nat -/ +#guard_msgs in #check "".parse +end +end ExOpen1 + + +namespace Ex2 +class A (n : Nat) where + x : Nat + +/-! +"Top-level" dot notation. As a last resort, field notation looks for a top-level declaration or alias +and supplies the value as the first explicit argument. +-/ +instance : ToString (A n) where + toString a := s!"A.x is {a.x}" + +/-- info: fun a => toString a : A 2 → String -/ +#guard_msgs in #check fun (a : A 2) => a.toString + +/-! +Incidental fix: `@` for generalized field notation was failing if there were implicit arguments. +True projections were ok. +-/ +def A.x' {n : Nat} (a : A n) := a.x + +/-- info: fun a => a.x' : A 2 → Nat -/ +#guard_msgs in #check fun (a : A 2) => @a.x' +end Ex2 + +namespace Ex3 +variable (f : α → β) (g : β → γ) +/-! +Functions use the "top-level" dot notation rule: they use the first explicit argument, rather than the first function argument. +-/ +/-- info: g ∘ f : α → γ -/ +#guard_msgs in #check g.comp f +end Ex3 diff --git a/tests/lean/run/DVec.lean b/tests/lean/run/DVec.lean index e9f98f8b655c..0006fe7aa637 100644 --- a/tests/lean/run/DVec.lean +++ b/tests/lean/run/DVec.lean @@ -38,6 +38,17 @@ example (v : Vec Nat 1) : Nat := #check @Vec.hd --- works +-- Does not work: Aliases find that `v` could be the `TypeVec` argument since `TypeVec` is an abbrev for `Vec`. +/-- +error: application type mismatch + @Vec.hd ?_ v +argument + v +has type + Vec Nat 1 : Type +but is expected to have type + TypeVec (?_ + 1) : Type (_ + 1) +-/ +#guard_msgs in set_option pp.mvars false in example (v : Vec Nat 1) : Nat := v.hd diff --git a/tests/lean/run/alias.lean b/tests/lean/run/alias.lean index 1174ec65aa63..126c8ab1e49e 100644 --- a/tests/lean/run/alias.lean +++ b/tests/lean/run/alias.lean @@ -3,14 +3,31 @@ def Set (α : Type) := α → Prop def Set.union (s₁ s₂ : Set α) : Set α := fun a => s₁ a ∨ s₂ a -def FinSet (n : Nat) := Fin n → Prop +def FinSet (n : Nat) := Set (Fin n) + +/-! +The type of `x` is unfolded to find `Set.union` +-/ +example (x y : FinSet 10) : FinSet 10 := + x.union y namespace FinSet - export Set (union) +export Set (union) end FinSet +/-! +Since the types are defeq, this alias works: +-/ example (x y : FinSet 10) : FinSet 10 := FinSet.union x y +/-! +However, this dot notation fails since there is no `FinSet` argument. +However, unfolding is the preferred method. +-/ +/-- +error: invalid field notation, function 'FinSet.union' does not have argument with type (FinSet ...) that can be used, it must be explicit or implicit with a unique name +-/ +#guard_msgs in example (x y : FinSet 10) : FinSet 10 := x.union y