Skip to content

Commit

Permalink
Fix: Implement better lambda normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
asr2003 authored Dec 8, 2024
1 parent ce44235 commit dc89772
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ object LightTypeTagRef extends LTTOrdering {
unusedParamsSize < paramRefs.size
}

lazy val normalizedParams: List[NameReference] = makeFakeParams.map(_._2)
lazy val normalizedOutput: AbstractReference = RuntimeAPI.applyLambda(this, makeFakeParams)
lazy val normalizedParams: List[NameReference] = makeNormalizedParams.map(_._2)
lazy val normalizedOutput: AbstractReference = RuntimeAPI.applyLambda(this, makeNormalizedParams)

override def equals(obj: Any): Boolean = {
obj match {
Expand All @@ -107,10 +107,11 @@ object LightTypeTagRef extends LTTOrdering {
}
}

private[this] def makeFakeParams: List[(LambdaParamName, NameReference)] = {
private[this] def makeNormalizedParams: List[(LambdaParamName, NameReference)] = {
input.zipWithIndex.map {
case (p, idx) =>
p -> NameReference(SymName.LambdaParamName(idx, -2, input.size)) // s"!FAKE_$idx"
val relativeDepth = -1 * (input.size - idx)
p -> NameReference(SymName.LambdaParamName(idx, relativeDepth, input.size)) // s"!FAKE_$idx"
}
}
}
Expand Down Expand Up @@ -191,6 +192,64 @@ object LightTypeTagRef extends LTTOrdering {
def maybeIntersection(r: Set[_ <: LightTypeTagRef]): AppliedReference = maybeIntersection(r.iterator)
def maybeUnion(r: Set[_ <: LightTypeTagRef]): AppliedReference = maybeUnion(r.iterator)

private[reflect] def normalizeLambda(lambda: Lambda, parentDepth: Int = 0): Lambda = {
val updatedInput = lambda.input.zipWithIndex.map {
case (param, idx) =>
LambdaParamName(idx, parentDepth - (lambda.input.size - idx), lambda.input.size)
}
val updatedOutput = lambda.output match {
case nestedLambda: Lambda => normalizeLambda(nestedLambda, parentDepth - lambda.input.size)
case other => normalizeReferences(other, updatedInput)
}

Lambda(updatedInput, updatedOutput)
}

private[this] def normalizeReferences(output: AbstractReference, input: List[LambdaParamName]): AbstractReference = {
output match {
case NameReference(ref, boundaries, prefix) =>
ref match {
case param: LambdaParamName =>
val normalizedName = input.find(_.index == param.index).getOrElse(param)
NameReference(normalizedName, boundaries, prefix)
case _ =>
output
}

case FullReference(symName, parameters, prefix) =>
// Recursively normalize type parameters and prefix
FullReference(
symName,
parameters.map(tp => tp.copy(ref = normalizeReferences(tp.ref, input))),
prefix.map(pref => normalizeReferences(pref, input).asInstanceOf[AppliedReference])
)

case IntersectionReference(refs) =>
IntersectionReference(refs.map(ref => normalizeReferences(ref, input).asInstanceOf[AppliedReferenceExceptIntersection]))

case UnionReference(refs) =>
UnionReference(refs.map(ref => normalizeReferences(ref, input).asInstanceOf[AppliedReferenceExceptUnion]))

case Refinement(reference, decls) =>
Refinement(
normalizeReferences(reference, input).asInstanceOf[AppliedReference],
decls.map {
case RefinementDecl.Signature(name, inputs, output) =>
RefinementDecl.Signature(
name,
inputs.map(inp => normalizeReferences(inp, input).asInstanceOf[AppliedReference]),
normalizeReferences(output, input).asInstanceOf[AppliedReference]
)
case RefinementDecl.TypeMember(name, ref) =>
RefinementDecl.TypeMember(name, normalizeReferences(ref, input))
}
)

case other =>
other
}
}

sealed trait AppliedNamedReference extends AppliedReferenceExceptIntersection with AppliedReferenceExceptUnion {
def asName: NameReference
def symName: SymName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ class LightTypeTagTest extends SharedLightTypeTagTest {
assertDebugSame(Tag[Option[String] | Nothing].tag, LTT[Option[String]])
}

"normalize nested lambdas with relative depth indices" in {
val innerLambda = LightTypeTagRef.Lambda(
input = List(LightTypeTagRef.SymName.LambdaParamName(0, 0, 1)),
output = LightTypeTagRef.NameReference(LightTypeTagRef.SymName.LambdaParamName(0, 0, 1))
)

val outerLambda = LightTypeTagRef.Lambda(
input = List(LightTypeTagRef.SymName.LambdaParamName(0, 0, 1)),
output = innerLambda
)

val normalizedOuter = LightTypeTagRef.normalizeLambda(outerLambda)

assert(normalizedOuter.input.head.depth == -1)
assert(
normalizedOuter
.output
.asInstanceOf[LightTypeTagRef.Lambda].input.head.depth == -2
)
}

"support top-level abstract types (Scala 3 specific, top level type aliases)" in {
assertChildStrict(LTT[LightTypeTagTestT], LTT[String])
}
Expand Down

0 comments on commit dc89772

Please sign in to comment.