From dc8977296408049bde3b95cc582d81e9729da5b6 Mon Sep 17 00:00:00 2001 From: asr2003 <162500856+asr2003@users.noreply.github.com> Date: Mon, 9 Dec 2024 00:38:11 +0530 Subject: [PATCH] Fix: Implement better lambda normalization --- .../reflect/macrortti/LightTypeTagRef.scala | 67 +++++++++++++++++-- .../izumi/reflect/test/LightTypeTagTest.scala | 21 ++++++ 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/izumi-reflect/izumi-reflect/src/main/scala/izumi/reflect/macrortti/LightTypeTagRef.scala b/izumi-reflect/izumi-reflect/src/main/scala/izumi/reflect/macrortti/LightTypeTagRef.scala index 75272a76..ddc83cae 100644 --- a/izumi-reflect/izumi-reflect/src/main/scala/izumi/reflect/macrortti/LightTypeTagRef.scala +++ b/izumi-reflect/izumi-reflect/src/main/scala/izumi/reflect/macrortti/LightTypeTagRef.scala @@ -93,8 +93,8 @@ object LightTypeTagRef extends LTTOrdering { unusedParamsSize < paramRefs.size } - lazy val normalizedParams: List[NameReference] = makeFakeParams.map(_._2) - lazy val normalizedOutput: AbstractReference = RuntimeAPI.applyLambda(this, makeFakeParams) + lazy val normalizedParams: List[NameReference] = makeNormalizedParams.map(_._2) + lazy val normalizedOutput: AbstractReference = RuntimeAPI.applyLambda(this, makeNormalizedParams) override def equals(obj: Any): Boolean = { obj match { @@ -107,10 +107,11 @@ object LightTypeTagRef extends LTTOrdering { } } - private[this] def makeFakeParams: List[(LambdaParamName, NameReference)] = { + private[this] def makeNormalizedParams: List[(LambdaParamName, NameReference)] = { input.zipWithIndex.map { case (p, idx) => - p -> NameReference(SymName.LambdaParamName(idx, -2, input.size)) // s"!FAKE_$idx" + val relativeDepth = -1 * (input.size - idx) + p -> NameReference(SymName.LambdaParamName(idx, relativeDepth, input.size)) // s"!FAKE_$idx" } } } @@ -191,6 +192,64 @@ object LightTypeTagRef extends LTTOrdering { def maybeIntersection(r: Set[_ <: LightTypeTagRef]): AppliedReference = maybeIntersection(r.iterator) def maybeUnion(r: Set[_ <: LightTypeTagRef]): AppliedReference = maybeUnion(r.iterator) + private[reflect] def normalizeLambda(lambda: Lambda, parentDepth: Int = 0): Lambda = { + val updatedInput = lambda.input.zipWithIndex.map { + case (param, idx) => + LambdaParamName(idx, parentDepth - (lambda.input.size - idx), lambda.input.size) + } + val updatedOutput = lambda.output match { + case nestedLambda: Lambda => normalizeLambda(nestedLambda, parentDepth - lambda.input.size) + case other => normalizeReferences(other, updatedInput) + } + + Lambda(updatedInput, updatedOutput) + } + + private[this] def normalizeReferences(output: AbstractReference, input: List[LambdaParamName]): AbstractReference = { + output match { + case NameReference(ref, boundaries, prefix) => + ref match { + case param: LambdaParamName => + val normalizedName = input.find(_.index == param.index).getOrElse(param) + NameReference(normalizedName, boundaries, prefix) + case _ => + output + } + + case FullReference(symName, parameters, prefix) => + // Recursively normalize type parameters and prefix + FullReference( + symName, + parameters.map(tp => tp.copy(ref = normalizeReferences(tp.ref, input))), + prefix.map(pref => normalizeReferences(pref, input).asInstanceOf[AppliedReference]) + ) + + case IntersectionReference(refs) => + IntersectionReference(refs.map(ref => normalizeReferences(ref, input).asInstanceOf[AppliedReferenceExceptIntersection])) + + case UnionReference(refs) => + UnionReference(refs.map(ref => normalizeReferences(ref, input).asInstanceOf[AppliedReferenceExceptUnion])) + + case Refinement(reference, decls) => + Refinement( + normalizeReferences(reference, input).asInstanceOf[AppliedReference], + decls.map { + case RefinementDecl.Signature(name, inputs, output) => + RefinementDecl.Signature( + name, + inputs.map(inp => normalizeReferences(inp, input).asInstanceOf[AppliedReference]), + normalizeReferences(output, input).asInstanceOf[AppliedReference] + ) + case RefinementDecl.TypeMember(name, ref) => + RefinementDecl.TypeMember(name, normalizeReferences(ref, input)) + } + ) + + case other => + other + } + } + sealed trait AppliedNamedReference extends AppliedReferenceExceptIntersection with AppliedReferenceExceptUnion { def asName: NameReference def symName: SymName diff --git a/izumi-reflect/izumi-reflect/src/test/scala-3/izumi/reflect/test/LightTypeTagTest.scala b/izumi-reflect/izumi-reflect/src/test/scala-3/izumi/reflect/test/LightTypeTagTest.scala index 166aefd8..a934e919 100644 --- a/izumi-reflect/izumi-reflect/src/test/scala-3/izumi/reflect/test/LightTypeTagTest.scala +++ b/izumi-reflect/izumi-reflect/src/test/scala-3/izumi/reflect/test/LightTypeTagTest.scala @@ -75,6 +75,27 @@ class LightTypeTagTest extends SharedLightTypeTagTest { assertDebugSame(Tag[Option[String] | Nothing].tag, LTT[Option[String]]) } + "normalize nested lambdas with relative depth indices" in { + val innerLambda = LightTypeTagRef.Lambda( + input = List(LightTypeTagRef.SymName.LambdaParamName(0, 0, 1)), + output = LightTypeTagRef.NameReference(LightTypeTagRef.SymName.LambdaParamName(0, 0, 1)) + ) + + val outerLambda = LightTypeTagRef.Lambda( + input = List(LightTypeTagRef.SymName.LambdaParamName(0, 0, 1)), + output = innerLambda + ) + + val normalizedOuter = LightTypeTagRef.normalizeLambda(outerLambda) + + assert(normalizedOuter.input.head.depth == -1) + assert( + normalizedOuter + .output + .asInstanceOf[LightTypeTagRef.Lambda].input.head.depth == -2 + ) + } + "support top-level abstract types (Scala 3 specific, top level type aliases)" in { assertChildStrict(LTT[LightTypeTagTestT], LTT[String]) }