Skip to content

Commit

Permalink
chore: refactor Elab.StructInst to use mutual for its structures/…
Browse files Browse the repository at this point in the history
…`inductive`s (leanprover#6174)

Making use of leanprover#6125.
  • Loading branch information
kmill authored Nov 22, 2024
1 parent d3cb812 commit 5145030
Showing 1 changed file with 64 additions and 65 deletions.
129 changes: 64 additions & 65 deletions src/Lean/Elab/StructInst.lean
Original file line number Diff line number Diff line change
Expand Up @@ -302,67 +302,66 @@ instance : ToFormat FieldLHS where
| .fieldIndex _ i => format i
| .modifyOp _ i => "[" ++ i.prettyPrint ++ "]"

/--
`FieldVal StructInstView` is a representation of a field value in the structure instance.
-/
inductive FieldVal (σ : Type) where
/-- A `term` to use for the value of the field. -/
| term (stx : Syntax) : FieldVal σ
/-- A `StructInstView` to use for the value of a subobject field. -/
| nested (s : σ) : FieldVal σ
/-- A field that was not provided and should be synthesized using default values. -/
| default : FieldVal σ
deriving Inhabited
mutual
/--
`FieldVal StructInstView` is a representation of a field value in the structure instance.
-/
inductive FieldVal where
/-- A `term` to use for the value of the field. -/
| term (stx : Syntax) : FieldVal
/-- A `StructInstView` to use for the value of a subobject field. -/
| nested (s : StructInstView) : FieldVal
/-- A field that was not provided and should be synthesized using default values. -/
| default : FieldVal
deriving Inhabited

/--
`Field StructInstView` is a representation of a field in the structure instance.
-/
structure Field (σ : Type) where
/-- The whole field syntax. -/
ref : Syntax
/-- The LHS decomposed into components. -/
lhs : List FieldLHS
/-- The value of the field. -/
val : FieldVal σ
/-- The elaborated field value, filled in at `elabStruct`.
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
expr? : Option Expr := none
deriving Inhabited
/--
`Field StructInstView` is a representation of a field in the structure instance.
-/
structure Field where
/-- The whole field syntax. -/
ref : Syntax
/-- The LHS decomposed into components. -/
lhs : List FieldLHS
/-- The value of the field. -/
val : FieldVal
/-- The elaborated field value, filled in at `elabStruct`.
Missing fields use a metavariable for the elaborated value and are later solved for in `DefaultFields.propagate`. -/
expr? : Option Expr := none
deriving Inhabited

/--
The view for structure instance notation.
-/
structure StructInstView where
/-- The syntax for the whole structure instance. -/
ref : Syntax
/-- The name of the structure for the type of the structure instance. -/
structName : Name
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
params : Array (Name × Expr)
/-- The fields of the structure instance. -/
fields : List Field
/-- The additional sources for fields for the structure instance. -/
sources : SourcesView
deriving Inhabited
end

/--
Returns if the field has a single component in its LHS.
-/
def Field.isSimple {σ} : Field σ → Bool
def Field.isSimple : Field → Bool
| { lhs := [_], .. } => true
| _ => false

/--
The view for structure instance notation.
-/
structure StructInstView where
/-- The syntax for the whole structure instance. -/
ref : Syntax
/-- The name of the structure for the type of the structure instance. -/
structName : Name
/-- Used for default values, to propagate structure type parameters. It is initially empty, and then set at `elabStruct`. -/
params : Array (Name × Expr)
/-- The fields of the structure instance. -/
fields : List (Field StructInstView)
/-- The additional sources for fields for the structure instance. -/
sources : SourcesView
deriving Inhabited

/-- Abbreviation for the type of `StructInstView.fields`, namely `List (Field StructInstView)`. -/
abbrev Fields := List (Field StructInstView)

/-- `true` iff all fields of the given structure are marked as `default` -/
partial def StructInstView.allDefault (s : StructInstView) : Bool :=
s.fields.all fun { val := val, .. } => match val with
| .term _ => false
| .default => true
| .nested s => allDefault s

def formatField (formatStruct : StructInstView → Format) (field : Field StructInstView) : Format :=
def formatField (formatStruct : StructInstView → Format) (field : Field) : Format :=
Format.joinSep field.lhs " . " ++ " := " ++
match field.val with
| .term v => v.prettyPrint
Expand All @@ -378,11 +377,11 @@ partial def formatStruct : StructInstView → Format
else
"{" ++ format (source.explicit.map (·.stx)) ++ " with " ++ fieldsFmt ++ implicitFmt ++ "}"

instance : ToFormat StructInstView := ⟨formatStruct⟩
instance : ToFormat StructInstView := ⟨formatStruct⟩
instance : ToString StructInstView := ⟨toString ∘ format⟩

instance : ToFormat (Field StructInstView) := ⟨formatField formatStruct⟩
instance : ToString (Field StructInstView) := ⟨toString ∘ format⟩
instance : ToFormat Field := ⟨formatField formatStruct⟩
instance : ToString Field := ⟨toString ∘ format⟩

/--
Converts a `FieldLHS` back into syntax. This assumes the `ref` fields have the correct structure.
Expand All @@ -403,14 +402,14 @@ private def FieldLHS.toSyntax (first : Bool) : FieldLHS → Syntax
/--
Converts a `FieldVal StructInstView` back into syntax. Only supports `.term`, and it assumes the `stx` field has the correct structure.
-/
private def FieldVal.toSyntax : FieldVal Struct → Syntax
private def FieldVal.toSyntax : FieldVal → Syntax
| .term stx => stx
| _ => unreachable!

/--
Converts a `Field StructInstView` back into syntax. Used to construct synthetic structure instance notation for subobjects in `StructInst.expandStruct` processing.
-/
private def Field.toSyntax : Field Struct → Syntax
private def Field.toSyntax : Field → Syntax
| field =>
let stx := field.ref
let stx := stx.setArg 2 field.val.toSyntax
Expand Down Expand Up @@ -452,14 +451,14 @@ private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesVi
let val := fieldStx[2]
let first ← toFieldLHS fieldStx[0][0]
let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field StructInstView }
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
return { ref := stx, structName, params := #[], fields, sources }

def StructInstView.modifyFieldsM {m : TypeType} [Monad m] (s : StructInstView) (f : Fields → m Fields) : m StructInstView :=
def StructInstView.modifyFieldsM {m : TypeType} [Monad m] (s : StructInstView) (f : List Field → m (List Field)) : m StructInstView :=
match s with
| { ref, structName, params, fields, sources } => return { ref, structName, params, fields := (← f fields), sources }

def StructInstView.modifyFields (s : StructInstView) (f : Fields → Fields) : StructInstView :=
def StructInstView.modifyFields (s : StructInstView) (f : List Field → List Field) : StructInstView :=
Id.run <| s.modifyFieldsM f

/-- Expands name field LHSs with multi-component names into multi-component LHSs. -/
Expand Down Expand Up @@ -525,14 +524,14 @@ private def expandParentFields (s : StructInstView) : TermElabM StructInstView :
| _ => throwErrorAt ref "failed to access field '{fieldName}' in parent structure"
| _ => return field

private abbrev FieldMap := Std.HashMap Name Fields
private abbrev FieldMap := Std.HashMap Name (List Field)

/--
Creates a hash map collecting all fields with the same first name component.
Throws an error if there are multiple simple fields with the same name.
Used by `StructInst.expandStruct` processing.
-/
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
private def mkFieldMap (fields : List Field) : TermElabM FieldMap :=
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| .fieldName _ fieldName :: _ =>
Expand All @@ -548,7 +547,7 @@ private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
/--
Given a value of the hash map created by `mkFieldMap`, returns true if the value corresponds to a simple field.
-/
private def isSimpleField? : Fields → Option (Field StructInstView)
private def isSimpleField? : List Field → Option Field
| [field] => if field.isSimple then some field else none
| _ => none

Expand All @@ -566,7 +565,7 @@ def mkProjStx? (s : Syntax) (structName : Name) (fieldName : Name) : TermElabM (
/--
Finds a simple field of the given name.
-/
def findField? (fields : Fields) (fieldName : Name) : Option (Field StructInstView) :=
def findField? (fields : List Field) (fieldName : Name) : Option Field :=
fields.find? fun field =>
match field.lhs with
| [.fieldName _ n] => n == fieldName
Expand Down Expand Up @@ -620,7 +619,7 @@ mutual
match findField? s.fields fieldName with
| some field => return field::fields
| none =>
let addField (val : FieldVal StructInstView) : TermElabM Fields := do
let addField (val : FieldVal) : TermElabM (List Field) := do
return { ref, lhs := [FieldLHS.fieldName ref fieldName], val := val } :: fields
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName =>
Expand Down Expand Up @@ -773,7 +772,7 @@ private partial def elabStructInstView (s : StructInstView) (expectedType? : Opt
trace[Elab.struct] "elabStruct {field}, {type}"
match type with
| .forallE _ d b bi =>
let cont (val : Expr) (field : Field StructInstView) (instMVars := instMVars) : TermElabM (Expr × Expr × Fields × Array MVarId) := do
let cont (val : Expr) (field : Field) (instMVars := instMVars) : TermElabM (Expr × Expr × List Field × Array MVarId) := do
pushInfoTree <| InfoTree.node (children := {}) <| Info.ofFieldInfo {
projName := s.structName.append fieldName, fieldName, lctx := (← getLCtx), val, stx := ref }
let e := mkApp e val
Expand Down Expand Up @@ -879,33 +878,33 @@ partial def getHierarchyDepth (struct : StructInstView) : Nat :=
| _ => max

/-- Returns whether the field is still missing. -/
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field Struct) : m Bool := do
def isDefaultMissing? [Monad m] [MonadMCtx m] (field : Field) : m Bool := do
if let some expr := field.expr? then
if let some (.mvar mvarId) := defaultMissing? expr then
unless (← mvarId.isAssigned) do
return true
return false

/-- Returns a field that is still missing. -/
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option (Field StructInstView)) :=
partial def findDefaultMissing? [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Option Field) :=
struct.fields.findSomeM? fun field => do
match field.val with
| .nested struct => findDefaultMissing? struct
| _ => return if (← isDefaultMissing? field) then field else none

/-- Returns all fields that are still missing. -/
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array (Field StructInstView)) :=
partial def allDefaultMissing [Monad m] [MonadMCtx m] (struct : StructInstView) : m (Array Field) :=
go struct *> get |>.run' #[]
where
go (struct : StructInstView) : StateT (Array (Field StructInstView)) m Unit :=
go (struct : StructInstView) : StateT (Array Field) m Unit :=
for field in struct.fields do
if let .nested struct := field.val then
go struct
else if (← isDefaultMissing? field) then
modify (·.push field)

/-- Returns the name of the field. Assumes all fields under consideration are simple and named. -/
def getFieldName (field : Field StructInstView) : Name :=
def getFieldName (field : Field) : Name :=
match field.lhs with
| [.fieldName _ fieldName] => fieldName
| _ => unreachable!
Expand Down

0 comments on commit 5145030

Please sign in to comment.