Skip to content

Commit

Permalink
Don't cache elements across rounds
Browse files Browse the repository at this point in the history
  • Loading branch information
eygraber committed Nov 20, 2024
1 parent 41810f0 commit 2e7720b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
package me.tatarka.inject.compiler

import com.squareup.kotlinpoet.AnnotationSpec
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import me.tatarka.kotlin.ast.AstClass
import me.tatarka.kotlin.ast.AstFunction
import com.squareup.kotlinpoet.TypeName
import me.tatarka.kotlin.ast.AstProvider

data class ComponentClassInfo(
val packageName: String,
val name: String,
val companionClassName: ClassName?,
val typeName: TypeName,
val className: ClassName,
)

data class KmpComponentCreateFunctionInfo(
val name: String,
val annotations: List<AnnotationSpec>,
val visibility: KModifier,
val receiverParameterType: TypeName?,
val parameters: List<Pair<String, TypeName>>,
val parametersTemplate: String,
val parameterNames: List<String>,
)

class KmpComponentCreateGenerator(
private val provider: AstProvider,
private val options: Options,
) {
fun generate(
componentClass: AstClass,
kmpComponentCreateFunctions: List<AstFunction>,
componentClass: ComponentClassInfo,
kmpComponentCreateFunctions: List<KmpComponentCreateFunctionInfo>,
) = with(provider) {
FileSpec.builder(
packageName = componentClass.packageName,
Expand All @@ -24,36 +43,37 @@ class KmpComponentCreateGenerator(
addFunction(
FunSpec
.builder(kmpComponentCreateFunction.name)
.addOriginatingElement(kmpComponentCreateFunction)
// .addOriginatingElement(kmpComponentCreateFunction)
.apply {
kmpComponentCreateFunction.annotations.forEach { annotation ->
addAnnotation(annotation.toAnnotationSpec())
addAnnotation(annotation)
}

addModifiers(
kmpComponentCreateFunction.visibility.toKModifier(),
kmpComponentCreateFunction.visibility,
KModifier.ACTUAL,
)

kmpComponentCreateFunction.receiverParameterType?.toTypeName()?.let(::receiver)
kmpComponentCreateFunction.receiverParameterType?.let(::receiver)

for (param in kmpComponentCreateFunction.parameters) {
addParameter(param.name, param.type.toTypeName())
val (name, typeName) = param
addParameter(name, typeName)
}

val funcParams = kmpComponentCreateFunction.parameters.joinToString { "%N" }
val funcParamsNames = kmpComponentCreateFunction.parameters.map { it.name }.toTypedArray()
val funcParamsNames = kmpComponentCreateFunction.parameterNames.toTypedArray()

val returnTypeCompanion = when {
options.generateCompanionExtensions -> componentClass.companion?.type
options.generateCompanionExtensions -> componentClass.companionClassName
else -> null
}

val returnTypeName = componentClass.type.toTypeName()
val returnTypeName = componentClass.typeName

val (createReceiver, createReceiverClassName) = when (returnTypeCompanion) {
null -> "%T::class" to componentClass.toClassName()
else -> "%T" to returnTypeCompanion.toAstClass().toClassName()
null -> "%T::class" to componentClass.className
else -> "%T" to returnTypeCompanion
}
addCode(
CodeBlock.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,4 +654,4 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
// default values are present.
val args: List<Pair<AstType, String>>,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSName
import me.tatarka.inject.compiler.COMPONENT
import me.tatarka.inject.compiler.ComponentClassInfo
import me.tatarka.inject.compiler.InjectGenerator
import me.tatarka.inject.compiler.KMP_COMPONENT_CREATE
import me.tatarka.inject.compiler.KmpComponentCreateFunctionInfo
import me.tatarka.inject.compiler.KmpComponentCreateGenerator
import me.tatarka.inject.compiler.Options
import me.tatarka.kotlin.ast.AstClass
import me.tatarka.kotlin.ast.AstFunction
import me.tatarka.kotlin.ast.KSAstProvider

class InjectProcessor(
Expand All @@ -32,7 +32,8 @@ class InjectProcessor(
private var deferredClassNames: List<KSName> = mutableListOf()
private var deferredFunctionNames: List<KSName> = mutableListOf()

private val kmpComponentCreateFunctionsByComponentType = mutableMapOf<AstClass, MutableList<AstFunction>>()
private val kmpComponentCreateFunctionsByComponentType =
mutableMapOf<ComponentClassInfo, MutableList<KmpComponentCreateFunctionInfo>>()

override fun process(resolver: Resolver): List<KSAnnotated> {
lastResolver = resolver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@ import com.google.devtools.ksp.processing.CodeGenerator
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.squareup.kotlinpoet.ksp.writeTo
import me.tatarka.inject.compiler.COMPONENT
import me.tatarka.inject.compiler.ComponentClassInfo
import me.tatarka.inject.compiler.KmpComponentCreateFunctionInfo
import me.tatarka.inject.compiler.KmpComponentCreateGenerator
import me.tatarka.kotlin.ast.AstClass
import me.tatarka.kotlin.ast.AstFunction
import me.tatarka.kotlin.ast.KSAstProvider

private typealias KmpComponentCreateFunctionsByComponentType =
MutableMap<ComponentClassInfo, MutableList<KmpComponentCreateFunctionInfo>>

internal fun processKmpComponentCreate(
element: KSFunctionDeclaration,
provider: KSAstProvider,
kmpComponentCreateFunctionsByComponentType: MutableMap<AstClass, MutableList<AstFunction>>
kmpComponentCreateFunctionsByComponentType: KmpComponentCreateFunctionsByComponentType
): Boolean = with(provider) {
val astFunction = element.toAstFunction()
val returnType = astFunction.returnType
Expand All @@ -26,15 +31,33 @@ internal fun processKmpComponentCreate(
val returnTypeClass = returnType.resolvedType().toAstClass()
if (!astFunction.validateReturnType(returnTypeClass, provider)) return true

kmpComponentCreateFunctionsByComponentType.getOrPut(returnTypeClass, ::ArrayList).add(astFunction)
val returnTypeClassInfo = ComponentClassInfo(
packageName = returnTypeClass.packageName,
name = returnTypeClass.name,
companionClassName = returnTypeClass.companion?.type?.toAstClass()?.toClassName(),
typeName = returnTypeClass.type.toTypeName(),
className = returnTypeClass.toClassName(),
)

val functionInfo = KmpComponentCreateFunctionInfo(
name = astFunction.name,
annotations = astFunction.annotations.map { it.toAnnotationSpec() }.toList(),
visibility = astFunction.visibility.toKModifier(),
receiverParameterType = astFunction.receiverParameterType?.toTypeName(),
parameters = astFunction.parameters.map { it.name to it.type.toTypeName() },
parametersTemplate = astFunction.parameters.joinToString { "%N" },
parameterNames = astFunction.parameters.map { it.name },
)

kmpComponentCreateFunctionsByComponentType.getOrPut(returnTypeClassInfo, ::ArrayList).add(functionInfo)

true
}

internal fun generateKmpComponentCreateFiles(
codeGenerator: CodeGenerator,
generator: KmpComponentCreateGenerator,
kmpComponentCreateFunctionsByComponentType: Map<AstClass, List<AstFunction>>
kmpComponentCreateFunctionsByComponentType: Map<ComponentClassInfo, List<KmpComponentCreateFunctionInfo>>
) {
kmpComponentCreateFunctionsByComponentType.forEach { (componentType, kmpComponentCreateFunctions) ->
val file = generator.generate(componentType, kmpComponentCreateFunctions)
Expand Down

0 comments on commit 2e7720b

Please sign in to comment.