diff --git a/pom.xml b/pom.xml index 767db565536..84187c76c60 100644 --- a/pom.xml +++ b/pom.xml @@ -177,6 +177,7 @@ 1.0.0-beta.13 1.1.0 4.31.1 + 1.9.25 2.29.29 diff --git a/spring-ai-core/pom.xml b/spring-ai-core/pom.xml index c1bcb0f3bd4..a082137a9df 100644 --- a/spring-ai-core/pom.xml +++ b/spring-ai-core/pom.xml @@ -144,6 +144,13 @@ true + + org.jetbrains.kotlin + kotlin-reflect + ${kotlin.version} + true + + org.springframework.boot diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/KotlinModule.java b/spring-ai-core/src/main/java/org/springframework/ai/model/KotlinModule.java new file mode 100644 index 00000000000..77938f32337 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/KotlinModule.java @@ -0,0 +1,103 @@ +package org.springframework.ai.model; + +import com.github.victools.jsonschema.generator.*; +import com.github.victools.jsonschema.generator.Module; +import kotlin.jvm.JvmClassMappingKt; +import kotlin.reflect.*; +import kotlin.reflect.full.KClasses; +import kotlin.reflect.jvm.ReflectJvmMapping; +import org.springframework.core.KotlinDetector; + +import java.lang.reflect.Field; +import java.util.HashSet; +import java.util.Set; + +public class KotlinModule implements Module { + + @Override + public void applyToConfigBuilder(SchemaGeneratorConfigBuilder builder) { + SchemaGeneratorConfigPart fieldConfigPart = builder.forFields(); + // SchemaGeneratorConfigPart methodConfigPart = builder.forMethods(); + + this.applyToConfigBuilderPart(fieldConfigPart); + // this.applyToConfigBuilderPart(methodConfigPart); + } + + private void applyToConfigBuilderPart(SchemaGeneratorConfigPart configPart) { + configPart.withNullableCheck(this::isNullable); + configPart.withPropertyNameOverrideResolver(this::getPropertyName); + configPart.withRequiredCheck(this::isRequired); + configPart.withIgnoreCheck(this::shouldIgnore); + } + + private Boolean isNullable(MemberScope member) { + KProperty kotlinProperty = getKotlinProperty(member); + if (kotlinProperty != null) { + return kotlinProperty.getReturnType().isMarkedNullable(); + } + return null; + } + + private String getPropertyName(MemberScope member) { + KProperty kotlinProperty = getKotlinProperty(member); + if (kotlinProperty != null) { + return kotlinProperty.getName(); + } + return null; + } + + private boolean isRequired(MemberScope member) { + KProperty kotlinProperty = getKotlinProperty(member); + if (kotlinProperty != null) { + KType returnType = kotlinProperty.getReturnType(); + boolean isNonNullable = !returnType.isMarkedNullable(); + + Class declaringClass = member.getDeclaringType().getErasedType(); + KClass kotlinClass = JvmClassMappingKt.getKotlinClass(declaringClass); + + Set constructorParamsWithoutDefault = getConstructorParametersWithoutDefault(kotlinClass); + + boolean isInConstructor = constructorParamsWithoutDefault.contains(kotlinProperty.getName()); + + return isNonNullable && isInConstructor; + } + + return false; + } + + private boolean shouldIgnore(MemberScope member) { + return member.getRawMember().isSynthetic(); // Ignore generated properties/methods + } + + private KProperty getKotlinProperty(MemberScope member) { + Class declaringClass = member.getDeclaringType().getErasedType(); + if (KotlinDetector.isKotlinType(declaringClass)) { + KClass kotlinClass = JvmClassMappingKt.getKotlinClass(declaringClass); + for (KProperty prop : KClasses.getMemberProperties(kotlinClass)) { + Field javaField = ReflectJvmMapping.getJavaField(prop); + if (javaField != null && javaField.equals(member.getRawMember())) { + return prop; + } + } + } + return null; + } + + private Set getConstructorParametersWithoutDefault(KClass kotlinClass) { + Set paramsWithoutDefault = new HashSet<>(); + KFunction primaryConstructor = KClasses.getPrimaryConstructor(kotlinClass); + if (primaryConstructor != null) { + primaryConstructor.getParameters().forEach(param -> { + if (param.getKind() != KParameter.Kind.INSTANCE && !param.isOptional()) { + String name = param.getName(); + if (name != null) { + paramsWithoutDefault.add(name); + } + } + }); + } + + return paramsWithoutDefault; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java index 65e1a824f9c..a64e22705f2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java @@ -51,6 +51,7 @@ import org.springframework.ai.util.JacksonUtils; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeanWrapperImpl; +import org.springframework.core.KotlinDetector; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; @@ -369,6 +370,10 @@ public static String getJsonSchema(Class clazz, boolean toUpperCaseTypeValues .with(swaggerModule) .with(jacksonModule); + if (KotlinDetector.isKotlinReflectPresent()) { + configBuilder.with(new KotlinModule()); + } + SchemaGeneratorConfig config = configBuilder.build(); SchemaGenerator generator = new SchemaGenerator(config); SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); @@ -403,6 +408,10 @@ public static String getJsonSchema(Type inputType, boolean toUpperCaseTypeValues .with(swaggerModule) .with(jacksonModule); + if (KotlinDetector.isKotlinReflectPresent()) { + configBuilder.with(new KotlinModule()); + } + SchemaGeneratorConfig config = configBuilder.build(); SchemaGenerator generator = new SchemaGenerator(config); SCHEMA_GENERATOR_CACHE.compareAndSet(null, generator); diff --git a/spring-ai-core/src/test/kotlin/org/springframework/ai/model/ModelOptionsUtilsTests.kt b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/ModelOptionsUtilsTests.kt new file mode 100644 index 00000000000..9f6cb1b8f77 --- /dev/null +++ b/spring-ai-core/src/test/kotlin/org/springframework/ai/model/ModelOptionsUtilsTests.kt @@ -0,0 +1,72 @@ +package org.springframework.ai.model + +import com.fasterxml.jackson.databind.ObjectMapper +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import java.lang.reflect.Type + +class KotlinModelOptionsUtilsTests { + + private class Foo(val bar: String, val baz: String?) + private class FooWithDefault(val bar: String, val baz: Int = 10) + + private val objectMapper = ObjectMapper() + + @Test + fun `test ModelOptionsUtils with Kotlin data class`() { + val portableOptions = Foo("John", "Doe") + + val optionsMap = ModelOptionsUtils.objectToMap(portableOptions) + assertThat(optionsMap).containsEntry("bar", "John") + assertThat(optionsMap).containsEntry("baz", "Doe") + + val newPortableOptions = ModelOptionsUtils.mapToClass(optionsMap, Foo::class.java) + assertThat(newPortableOptions.bar).isEqualTo("John") + assertThat(newPortableOptions.baz).isEqualTo("Doe") + } + + @Test + fun `test Kotlin data class schema generation using getJsonSchema`() { + val inputType: Type = Foo::class.java + + val schemaJson = ModelOptionsUtils.getJsonSchema(inputType, false) + + val schemaNode = objectMapper.readTree(schemaJson) + + val required = schemaNode["required"] + assertThat(required).isNotNull + assertThat(required.toString()).contains("bar") + assertThat(required.toString()).doesNotContain("baz") + + val properties = schemaNode["properties"] + assertThat(properties["bar"]["type"].asText()).isEqualTo("string") + + val bazTypeNode = properties["baz"]["type"] + if (bazTypeNode.isArray) { + assertThat(bazTypeNode.toString()).contains("string") + assertThat(bazTypeNode.toString()).contains("null") + } else { + assertThat(bazTypeNode.asText()).isEqualTo("string") + } + } + + @Test + fun `test data class with default values`() { + val inputType: Type = FooWithDefault::class.java + + val schemaJson = ModelOptionsUtils.getJsonSchema(inputType, false) + + val schemaNode = objectMapper.readTree(schemaJson) + + val required = schemaNode["required"] + assertThat(required).isNotNull + assertThat(required.toString()).contains("bar") + assertThat(required.toString()).doesNotContain("baz") + + val properties = schemaNode["properties"] + assertThat(properties["bar"]["type"].asText()).isEqualTo("string") + + val bazTypeNode = properties["baz"]["type"] + assertThat(bazTypeNode.asText()).isEqualTo("integer") + } +}