From e3f5487df46523c31e7820f7e7c171b12d2cb7df Mon Sep 17 00:00:00 2001 From: D Gardner Date: Fri, 2 May 2025 15:24:11 +0100 Subject: [PATCH 1/5] structured-outputs: updates and more unit tests. --- openai-java-core/build.gradle.kts | 2 + .../com/openai/core/StructuredOutputs.kt | 67 + .../completions/ChatCompletionCreateParams.kt | 8 + .../completions/StructuredChatCompletion.kt | 169 ++ .../StructuredChatCompletionCreateParams.kt | 744 +++++++++ .../StructuredChatCompletionMessage.kt | 92 ++ .../blocking/chat/ChatCompletionService.kt | 12 + .../openai/core/JsonSchemaValidatorTest.kt | 1403 +++++++++++++++++ .../ChatCompletionCreateParamsTest.kt | 32 + ...tructuredChatCompletionCreateParamsTest.kt | 499 ++++++ .../StructuredChatCompletionMessageTest.kt | 141 ++ .../StructuredChatCompletionTest.kt | 405 +++++ .../StructuredOutputsClassExample.java | 73 + 13 files changed, 3647 insertions(+) create mode 100644 openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt create mode 100644 openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt create mode 100644 openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt create mode 100644 openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java diff --git a/openai-java-core/build.gradle.kts 
b/openai-java-core/build.gradle.kts index 08a91e0d..894f0e23 100644 --- a/openai-java-core/build.gradle.kts +++ b/openai-java-core/build.gradle.kts @@ -27,6 +27,8 @@ dependencies { implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.18.2") implementation("org.apache.httpcomponents.core5:httpcore5:5.2.4") implementation("org.apache.httpcomponents.client5:httpclient5:5.3.1") + implementation("com.github.victools:jsonschema-generator:4.38.0") + implementation("com.github.victools:jsonschema-module-jackson:4.38.0") testImplementation(kotlin("test")) testImplementation(project(":openai-java-client-okhttp")) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt new file mode 100644 index 00000000..7f18d237 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -0,0 +1,67 @@ +package com.openai.core + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule +import com.fasterxml.jackson.module.kotlin.kotlinModule +import com.github.victools.jsonschema.generator.Option +import com.github.victools.jsonschema.generator.OptionPreset +import com.github.victools.jsonschema.generator.SchemaGenerator +import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder +import com.github.victools.jsonschema.module.jackson.JacksonModule +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.ResponseFormatJsonSchema + +// The SDK `ObjectMappers.jsonMapper()` requires that all fields of classes be marked with +// `@JsonProperty`, which is not desirable in this context, as it impedes usability. Therefore, a +// custom JSON mapper configuration is required. 
+private val MAPPER = + JsonMapper.builder() + .addModule(kotlinModule()) + .addModule(Jdk8Module()) + .addModule(JavaTimeModule()) + .build() + +/** Builds a [ResponseFormatJsonSchema] for [type]: the schema comes from [extractSchema] and the schema name is "json-schema-from-" followed by the simple name of [type]. */ fun fromClass(type: Class) = + ResponseFormatJsonSchema.builder() + .jsonSchema( + ResponseFormatJsonSchema.JsonSchema.builder() + .name("json-schema-from-${type.simpleName}") + .schema(JsonValue.from(extractSchema(type))) + .build() + ) + .build() + +/** Generates the JSON schema of [type] as a [JsonNode], using a generator configured for schema draft 2020-12 with the PLAIN_JSON option preset; the inline comments explain each further option. */ internal fun extractSchema(type: Class): JsonNode { + val configBuilder = + SchemaGeneratorConfigBuilder( + com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON, + ) + // Add `"additionalProperties" : false` to all object schemas (see OpenAI). + .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT) + // Use `JacksonModule` to support the use of Jackson annotations to set property and + // class names and descriptions and to mark fields with `@JsonIgnore`. + .with(JacksonModule()) + + configBuilder + .forFields() + // For OpenAI schemas, _all_ properties _must_ be required. Override the interpretation of + // the Jackson `required` parameter to the `@JsonProperty` annotation: it will always be + // assumed to be `true`, even if explicitly `false` and even if there is no `@JsonProperty` + // annotation present. + .withRequiredCheck { true } + + return SchemaGenerator(configBuilder.build()).generateSchema(type) +} + +/** Deserializes [json] into an instance of [type] with the file-private [MAPPER]. Any [Exception] thrown while reading is rethrown as an [OpenAIInvalidDataException] whose message embeds the full JSON text, with the original exception passed as the second constructor argument. */ fun fromJson(json: String, type: Class): T = + try { + MAPPER.readValue(json, type) + } catch (e: Exception) { + // The JSON document is included in the exception message to aid diagnosis of the problem. + // It is the responsibility of the SDK user to ensure that exceptions that may contain + // sensitive data are not exposed in logs.
+ throw OpenAIInvalidDataException("Error parsing JSON: $json", e) + } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index a3281dc6..cb3459fe 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -1297,6 +1297,14 @@ private constructor( body.responseFormat(jsonObject) } + /** + * Sets the class that defines the structured outputs response format. This changes the + * builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that will build a + * [StructuredChatCompletionCreateParams] instance when `build()` is called. + */ + fun responseFormat(responseFormat: Class) = + StructuredChatCompletionCreateParams.builder().wrap(responseFormat, this) + /** * This feature is in Beta.
If specified, our system will make a best effort to sample * deterministically, such that repeated requests with the same `seed` and parameters should diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt new file mode 100644 index 00000000..6ca931a5 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt @@ -0,0 +1,169 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import com.openai.models.chat.completions.ChatCompletion.Choice.FinishReason +import com.openai.models.chat.completions.ChatCompletion.Choice.Logprobs +import com.openai.models.chat.completions.ChatCompletion.ServiceTier +import com.openai.models.completions.CompletionUsage +import java.util.Objects +import java.util.Optional + +/** Pairs a raw [ChatCompletion] with the class given as `responseFormat`. Scalar accessors delegate straight to [chatCompletion]; the choices are re-wrapped as [Choice] instances that carry the same `responseFormat`. */ class StructuredChatCompletion( + val responseFormat: Class, + val chatCompletion: ChatCompletion, +) { + /** @see ChatCompletion.id */ + fun id(): String = chatCompletion.id() + + /** Lazily computed on first access: the `_choices()` field of [chatCompletion] with every element wrapped in a [Choice]. */ private val choices by lazy { + chatCompletion._choices().map { choices -> choices.map { Choice(responseFormat, it) } } + } + + /** @see ChatCompletion.choices */ + fun choices(): List> = choices.getRequired("choices") + + /** @see ChatCompletion.created */ + fun created(): Long = chatCompletion.created() + + /** @see ChatCompletion.model */ + fun model(): String = chatCompletion.model() + + /** @see ChatCompletion._object_ */ + fun _object_(): JsonValue = chatCompletion._object_() + + /** @see ChatCompletion.serviceTier */ + fun serviceTier(): Optional = chatCompletion.serviceTier() + + /** @see ChatCompletion.systemFingerprint */ + fun systemFingerprint(): Optional = chatCompletion.systemFingerprint() + + /** @see ChatCompletion.usage */ + fun usage(): Optional =
chatCompletion.usage() + + /** @see ChatCompletion._id */ + fun _id(): JsonField = chatCompletion._id() + + /** @see ChatCompletion._choices */ + fun _choices(): JsonField>> = choices + + /** @see ChatCompletion._created */ + fun _created(): JsonField = chatCompletion._created() + + /** @see ChatCompletion._model */ + fun _model(): JsonField = chatCompletion._model() + + /** @see ChatCompletion._serviceTier */ + fun _serviceTier(): JsonField = chatCompletion._serviceTier() + + /** @see ChatCompletion._systemFingerprint */ + fun _systemFingerprint(): JsonField = chatCompletion._systemFingerprint() + + /** @see ChatCompletion._usage */ + fun _usage(): JsonField = chatCompletion._usage() + + /** @see ChatCompletion._additionalProperties */ + fun _additionalProperties(): Map = chatCompletion._additionalProperties() + + /** Wraps one [ChatCompletion.Choice]; accessors delegate to [choice], except that the message is exposed as a [StructuredChatCompletionMessage] constructed with the same `responseFormat`. */ class Choice + internal constructor( + internal val responseFormat: Class, + internal val choice: ChatCompletion.Choice, + ) { + /** @see ChatCompletion.Choice.finishReason */ + fun finishReason(): FinishReason = choice.finishReason() + + /** @see ChatCompletion.Choice.index */ + fun index(): Long = choice.index() + + /** @see ChatCompletion.Choice.logprobs */ + fun logprobs(): Optional = choice.logprobs() + + /** @see ChatCompletion.Choice._finishReason */ + fun _finishReason(): JsonField = choice._finishReason() + + /** Lazily computed on first access: the `_message()` field of [choice] mapped into a [StructuredChatCompletionMessage]. */ private val message by lazy { + choice._message().map { StructuredChatCompletionMessage(responseFormat, it) } + } + + /** @see ChatCompletion.Choice.message */ + fun message(): StructuredChatCompletionMessage = message.getRequired("message") + + /** @see ChatCompletion.Choice._index */ + fun _index(): JsonField = choice._index() + + /** @see ChatCompletion.Choice._logprobs */ + fun _logprobs(): JsonField = choice._logprobs() + + /** @see ChatCompletion.Choice._message */ + fun _message(): JsonField> = message + + /** @see ChatCompletion.Choice._additionalProperties */ + fun _additionalProperties(): Map = choice._additionalProperties() + +
/** @see ChatCompletion.Choice.validate */ + fun validate(): Choice = apply { + message().validate() + choice.validate() + } + + /** @see ChatCompletion.Choice.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + /** Equal when [other] is the same instance, or is a [Choice] with an equal `responseFormat` and an equal wrapped `choice`. */ override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is Choice<*> && + responseFormat == other.responseFormat && + choice == other.choice + } + + /** Hash of `responseFormat` and `choice`, computed lazily on first use and returned by [hashCode]. */ private val hashCode: Int by lazy { Objects.hash(responseFormat, choice) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, choice=$choice}" + } + + /** @see ChatCompletion.validate */ + fun validate() = apply { + choices().forEach { it.validate() } + chatCompletion.validate() + } + + /** @see ChatCompletion.isValid */ + fun isValid(): Boolean = + try { + validate() + true + } catch (_: OpenAIInvalidDataException) { + false + } + + /** Equal when [other] is the same instance, or is a [StructuredChatCompletion] with an equal `responseFormat` and an equal wrapped `chatCompletion`. */ override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletion<*> && + responseFormat == other.responseFormat && + chatCompletion == other.chatCompletion + } + + /** Hash of `responseFormat` and `chatCompletion`, computed lazily on first use and returned by [hashCode]. */ private val hashCode: Int by lazy { Objects.hash(responseFormat, chatCompletion) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, chatCompletion=$chatCompletion}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt new file mode 100644 index 00000000..ae1ea1be --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -0,0 +1,744 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import
com.openai.core.JsonValue +import com.openai.core.checkRequired +import com.openai.core.fromClass +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.models.ChatModel +import com.openai.models.ReasoningEffort +import java.util.Objects +import java.util.Optional + +class StructuredChatCompletionCreateParams +internal constructor( + val responseFormat: Class, + /** + * The raw, underlying chat completion create parameters wrapped by this structured instance of + * the parameters. + */ + @get:JvmName("rawParams") val rawParams: ChatCompletionCreateParams, +) { + + companion object { + @JvmStatic fun builder() = Builder() + } + + class Builder internal constructor() { + private var responseFormat: Class? = null + private var paramsBuilder = ChatCompletionCreateParams.builder() + + @JvmSynthetic + internal fun wrap( + responseFormat: Class, + paramsBuilder: ChatCompletionCreateParams.Builder, + ) = apply { + this.responseFormat = responseFormat + this.paramsBuilder = paramsBuilder + // Convert the class to a JSON schema and apply it to the delegate `Builder`. + responseFormat(responseFormat) + } + + /** Injects a given `ChatCompletionCreateParams.Builder`. For use only when testing. 
*/ + @JvmSynthetic + internal fun inject(paramsBuilder: ChatCompletionCreateParams.Builder) = apply { + this.paramsBuilder = paramsBuilder + } + + /** @see ChatCompletionCreateParams.Builder.body */ + fun body(body: ChatCompletionCreateParams.Body) = apply { paramsBuilder.body(body) } + + /** @see ChatCompletionCreateParams.Builder.messages */ + fun messages(messages: List) = apply { + paramsBuilder.messages(messages) + } + + /** @see ChatCompletionCreateParams.Builder.messages */ + fun messages(messages: JsonField>) = apply { + paramsBuilder.messages(messages) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(message: ChatCompletionMessageParam) = apply { + paramsBuilder.addMessage(message) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(developer: ChatCompletionDeveloperMessageParam) = apply { + paramsBuilder.addMessage(developer) + } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessage */ + fun addDeveloperMessage(content: ChatCompletionDeveloperMessageParam.Content) = apply { + paramsBuilder.addDeveloperMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessage */ + fun addDeveloperMessage(text: String) = apply { paramsBuilder.addDeveloperMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addDeveloperMessageOfArrayOfContentParts */ + fun addDeveloperMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addDeveloperMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(system: ChatCompletionSystemMessageParam) = apply { + paramsBuilder.addMessage(system) + } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessage */ + fun addSystemMessage(content: ChatCompletionSystemMessageParam.Content) = apply { + paramsBuilder.addSystemMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessage */ + fun 
addSystemMessage(text: String) = apply { paramsBuilder.addSystemMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addSystemMessageOfArrayOfContentParts */ + fun addSystemMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addSystemMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(user: ChatCompletionUserMessageParam) = apply { + paramsBuilder.addMessage(user) + } + + /** @see ChatCompletionCreateParams.Builder.addUserMessage */ + fun addUserMessage(content: ChatCompletionUserMessageParam.Content) = apply { + paramsBuilder.addUserMessage(content) + } + + /** @see ChatCompletionCreateParams.Builder.addUserMessage */ + fun addUserMessage(text: String) = apply { paramsBuilder.addUserMessage(text) } + + /** @see ChatCompletionCreateParams.Builder.addUserMessageOfArrayOfContentParts */ + fun addUserMessageOfArrayOfContentParts( + arrayOfContentParts: List + ) = apply { paramsBuilder.addUserMessageOfArrayOfContentParts(arrayOfContentParts) } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(assistant: ChatCompletionAssistantMessageParam) = apply { + paramsBuilder.addMessage(assistant) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(assistant: ChatCompletionMessage) = apply { + paramsBuilder.addMessage(assistant) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + fun addMessage(tool: ChatCompletionToolMessageParam) = apply { + paramsBuilder.addMessage(tool) + } + + /** @see ChatCompletionCreateParams.Builder.addMessage */ + @Deprecated("deprecated") + fun addMessage(function: ChatCompletionFunctionMessageParam) = apply { + paramsBuilder.addMessage(function) + } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(model: ChatModel) = apply { paramsBuilder.model(model) } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(model: 
JsonField) = apply { paramsBuilder.model(model) } + + /** @see ChatCompletionCreateParams.Builder.model */ + fun model(value: String) = apply { paramsBuilder.model(value) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: ChatCompletionAudioParam?) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: Optional) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.audio */ + fun audio(audio: JsonField) = apply { paramsBuilder.audio(audio) } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Double?) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Double) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: Optional) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.frequencyPenalty */ + fun frequencyPenalty(frequencyPenalty: JsonField) = apply { + paramsBuilder.frequencyPenalty(frequencyPenalty) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCall: ChatCompletionCreateParams.FunctionCall) = apply { + paramsBuilder.functionCall(functionCall) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCall: JsonField) = apply { + paramsBuilder.functionCall(functionCall) + } + + /** @see ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(mode: ChatCompletionCreateParams.FunctionCall.FunctionCallMode) = apply { + paramsBuilder.functionCall(mode) + } + + /** @see 
ChatCompletionCreateParams.Builder.functionCall */ + @Deprecated("deprecated") + fun functionCall(functionCallOption: ChatCompletionFunctionCallOption) = apply { + paramsBuilder.functionCall(functionCallOption) + } + + /** @see ChatCompletionCreateParams.Builder.functions */ + @Deprecated("deprecated") + fun functions(functions: List) = apply { + paramsBuilder.functions(functions) + } + + /** @see ChatCompletionCreateParams.Builder.functions */ + @Deprecated("deprecated") + fun functions(functions: JsonField>) = apply { + paramsBuilder.functions(functions) + } + + /** @see ChatCompletionCreateParams.Builder.addFunction */ + @Deprecated("deprecated") + fun addFunction(function: ChatCompletionCreateParams.Function) = apply { + paramsBuilder.addFunction(function) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: ChatCompletionCreateParams.LogitBias?) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: Optional) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logitBias */ + fun logitBias(logitBias: JsonField) = apply { + paramsBuilder.logitBias(logitBias) + } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Boolean?) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Boolean) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: Optional) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.logprobs */ + fun logprobs(logprobs: JsonField) = apply { paramsBuilder.logprobs(logprobs) } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Long?) 
= apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Long) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: Optional) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxCompletionTokens */ + fun maxCompletionTokens(maxCompletionTokens: JsonField) = apply { + paramsBuilder.maxCompletionTokens(maxCompletionTokens) + } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Long?) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Long) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: Optional) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.maxTokens */ + @Deprecated("deprecated") + fun maxTokens(maxTokens: JsonField) = apply { paramsBuilder.maxTokens(maxTokens) } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: ChatCompletionCreateParams.Metadata?) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: Optional) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.metadata */ + fun metadata(metadata: JsonField) = apply { + paramsBuilder.metadata(metadata) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: List?) 
= apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: Optional>) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.modalities */ + fun modalities(modalities: JsonField>) = apply { + paramsBuilder.modalities(modalities) + } + + /** @see ChatCompletionCreateParams.Builder.addModality */ + fun addModality(modality: ChatCompletionCreateParams.Modality) = apply { + paramsBuilder.addModality(modality) + } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Long?) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Long) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: Optional) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.n */ + fun n(n: JsonField) = apply { paramsBuilder.n(n) } + + /** @see ChatCompletionCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: Boolean) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ChatCompletionCreateParams.Builder.parallelToolCalls */ + fun parallelToolCalls(parallelToolCalls: JsonField) = apply { + paramsBuilder.parallelToolCalls(parallelToolCalls) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: ChatCompletionPredictionContent?) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: Optional) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.prediction */ + fun prediction(prediction: JsonField) = apply { + paramsBuilder.prediction(prediction) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Double?) 
= apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Double) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: Optional) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.presencePenalty */ + fun presencePenalty(presencePenalty: JsonField) = apply { + paramsBuilder.presencePenalty(presencePenalty) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: ReasoningEffort?) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: Optional) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** @see ChatCompletionCreateParams.Builder.reasoningEffort */ + fun reasoningEffort(reasoningEffort: JsonField) = apply { + paramsBuilder.reasoningEffort(reasoningEffort) + } + + /** Sets the response format to a JSON schema derived from the given class. */ + fun responseFormat(responseFormat: Class) = apply { + this.responseFormat = responseFormat + paramsBuilder.responseFormat(fromClass(responseFormat)) + } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Long?) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Long) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: Optional) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.seed */ + fun seed(seed: JsonField) = apply { paramsBuilder.seed(seed) } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: ChatCompletionCreateParams.ServiceTier?) 
= apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: Optional) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.serviceTier */ + fun serviceTier(serviceTier: JsonField) = apply { + paramsBuilder.serviceTier(serviceTier) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: ChatCompletionCreateParams.Stop?) = apply { paramsBuilder.stop(stop) } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: Optional) = apply { + paramsBuilder.stop(stop) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(stop: JsonField) = apply { + paramsBuilder.stop(stop) + } + + /** @see ChatCompletionCreateParams.Builder.stop */ + fun stop(string: String) = apply { paramsBuilder.stop(string) } + + /** @see ChatCompletionCreateParams.Builder.stopOfStrings */ + fun stopOfStrings(strings: List) = apply { paramsBuilder.stopOfStrings(strings) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Boolean?) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Boolean) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: Optional) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.store */ + fun store(store: JsonField) = apply { paramsBuilder.store(store) } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: ChatCompletionStreamOptions?) 
= apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: Optional) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.streamOptions */ + fun streamOptions(streamOptions: JsonField) = apply { + paramsBuilder.streamOptions(streamOptions) + } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Double?) = apply { paramsBuilder.temperature(temperature) } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Double) = apply { paramsBuilder.temperature(temperature) } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: Optional) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ChatCompletionCreateParams.Builder.temperature */ + fun temperature(temperature: JsonField) = apply { + paramsBuilder.temperature(temperature) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: ChatCompletionToolChoiceOption) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(toolChoice: JsonField) = apply { + paramsBuilder.toolChoice(toolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(auto: ChatCompletionToolChoiceOption.Auto) = apply { + paramsBuilder.toolChoice(auto) + } + + /** @see ChatCompletionCreateParams.Builder.toolChoice */ + fun toolChoice(namedToolChoice: ChatCompletionNamedToolChoice) = apply { + paramsBuilder.toolChoice(namedToolChoice) + } + + /** @see ChatCompletionCreateParams.Builder.tools */ + fun tools(tools: List) = apply { paramsBuilder.tools(tools) } + + /** @see ChatCompletionCreateParams.Builder.tools */ + fun tools(tools: JsonField>) = apply { paramsBuilder.tools(tools) } + + /** @see 
ChatCompletionCreateParams.Builder.addTool */ + fun addTool(tool: ChatCompletionTool) = apply { paramsBuilder.addTool(tool) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Long?) = apply { paramsBuilder.topLogprobs(topLogprobs) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Long) = apply { paramsBuilder.topLogprobs(topLogprobs) } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: Optional) = apply { + paramsBuilder.topLogprobs(topLogprobs) + } + + /** @see ChatCompletionCreateParams.Builder.topLogprobs */ + fun topLogprobs(topLogprobs: JsonField) = apply { + paramsBuilder.topLogprobs(topLogprobs) + } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Double?) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Double) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: Optional) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.topP */ + fun topP(topP: JsonField) = apply { paramsBuilder.topP(topP) } + + /** @see ChatCompletionCreateParams.Builder.user */ + fun user(user: String) = apply { paramsBuilder.user(user) } + + /** @see ChatCompletionCreateParams.Builder.user */ + fun user(user: JsonField) = apply { paramsBuilder.user(user) } + + /** @see ChatCompletionCreateParams.Builder.webSearchOptions */ + fun webSearchOptions(webSearchOptions: ChatCompletionCreateParams.WebSearchOptions) = + apply { + paramsBuilder.webSearchOptions(webSearchOptions) + } + + /** @see ChatCompletionCreateParams.Builder.webSearchOptions */ + fun webSearchOptions( + webSearchOptions: JsonField + ) = apply { paramsBuilder.webSearchOptions(webSearchOptions) } + + /** @see ChatCompletionCreateParams.Builder.additionalBodyProperties */ + fun 
additionalBodyProperties(additionalBodyProperties: Map) = apply { + paramsBuilder.additionalBodyProperties(additionalBodyProperties) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalBodyProperty */ + fun putAdditionalBodyProperty(key: String, value: JsonValue) = apply { + paramsBuilder.putAdditionalBodyProperty(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalBodyProperties */ + fun putAllAdditionalBodyProperties(additionalBodyProperties: Map) = + apply { + paramsBuilder.putAllAdditionalBodyProperties(additionalBodyProperties) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalBodyProperty */ + fun removeAdditionalBodyProperty(key: String) = apply { + paramsBuilder.removeAdditionalBodyProperty(key) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalBodyProperties */ + fun removeAllAdditionalBodyProperties(keys: Set) = apply { + paramsBuilder.removeAllAdditionalBodyProperties(keys) + } + + /** @see ChatCompletionCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.additionalHeaders */ + fun additionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.additionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalHeader */ + fun putAdditionalHeader(name: String, value: String) = apply { + paramsBuilder.putAdditionalHeader(name, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalHeaders */ + fun putAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.putAdditionalHeaders(name, values) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see 
ChatCompletionCreateParams.Builder.putAllAdditionalHeaders */ + fun putAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.putAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, value: String) = apply { + paramsBuilder.replaceAdditionalHeaders(name, value) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalHeaders */ + fun replaceAdditionalHeaders(name: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalHeaders(name, values) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Headers) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalHeaders */ + fun replaceAllAdditionalHeaders(additionalHeaders: Map>) = apply { + paramsBuilder.replaceAllAdditionalHeaders(additionalHeaders) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalHeaders */ + fun removeAdditionalHeaders(name: String) = apply { + paramsBuilder.removeAdditionalHeaders(name) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalHeaders */ + fun removeAllAdditionalHeaders(names: Set) = apply { + paramsBuilder.removeAllAdditionalHeaders(names) + } + + /** @see ChatCompletionCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.additionalQueryParams */ + fun additionalQueryParams(additionalQueryParams: Map>) = apply { + paramsBuilder.additionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalQueryParam */ + fun putAdditionalQueryParam(key: String, value: String) = apply { + 
paramsBuilder.putAdditionalQueryParam(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.putAdditionalQueryParams */ + fun putAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.putAdditionalQueryParams(key, values) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.putAllAdditionalQueryParams */ + fun putAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.putAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, value: String) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, value) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAdditionalQueryParams */ + fun replaceAdditionalQueryParams(key: String, values: Iterable) = apply { + paramsBuilder.replaceAdditionalQueryParams(key, values) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: QueryParams) = apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.replaceAllAdditionalQueryParams */ + fun replaceAllAdditionalQueryParams(additionalQueryParams: Map>) = + apply { + paramsBuilder.replaceAllAdditionalQueryParams(additionalQueryParams) + } + + /** @see ChatCompletionCreateParams.Builder.removeAdditionalQueryParams */ + fun removeAdditionalQueryParams(key: String) = apply { + paramsBuilder.removeAdditionalQueryParams(key) + } + + /** @see ChatCompletionCreateParams.Builder.removeAllAdditionalQueryParams */ + fun removeAllAdditionalQueryParams(keys: Set) = apply { + 
paramsBuilder.removeAllAdditionalQueryParams(keys) + } + + /** + * Returns an immutable instance of [StructuredChatCompletionCreateParams]. + * + * Further updates to this [Builder] will not mutate the returned instance. + * + * The following fields are required: + * ```java + * .messages() + * .model() + * .responseFormat() + * ``` + * + * @throws IllegalStateException If any required field is unset. + */ + fun build() = + StructuredChatCompletionCreateParams( + checkRequired("responseFormat", responseFormat), + paramsBuilder.build(), + ) + } + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletionCreateParams<*> && + responseFormat == other.responseFormat && + rawParams == other.rawParams + } + + override fun hashCode(): Int = Objects.hash(responseFormat, rawParams) + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, params=$rawParams}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt new file mode 100644 index 00000000..519596ef --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt @@ -0,0 +1,92 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.fromJson +import com.openai.models.chat.completions.ChatCompletionMessage.FunctionCall +import java.util.Objects +import java.util.Optional + +class StructuredChatCompletionMessage +internal constructor( + val responseFormat: Class, + val chatCompletionMessage: ChatCompletionMessage, +) { + + private val content: JsonField by lazy { + chatCompletionMessage._content().map { fromJson(it, responseFormat) } + } + + /** @see ChatCompletionMessage.content */ + fun content(): Optional = 
content.getOptional("content") + + /** @see ChatCompletionMessage.refusal */ + fun refusal(): Optional = chatCompletionMessage.refusal() + + /** @see ChatCompletionMessage._role */ + fun _role(): JsonValue = chatCompletionMessage._role() + + /** @see ChatCompletionMessage.annotations */ + fun annotations(): Optional> = + chatCompletionMessage.annotations() + + /** @see ChatCompletionMessage.audio */ + fun audio(): Optional = chatCompletionMessage.audio() + + /** @see ChatCompletionMessage.functionCall */ + @Deprecated("deprecated") + fun functionCall(): Optional = chatCompletionMessage.functionCall() + + /** @see ChatCompletionMessage.toolCalls */ + fun toolCalls(): Optional> = + chatCompletionMessage.toolCalls() + + /** @see ChatCompletionMessage._content */ + fun _content(): JsonField = content + + /** @see ChatCompletionMessage._refusal */ + fun _refusal(): JsonField = chatCompletionMessage._refusal() + + /** @see ChatCompletionMessage._annotations */ + fun _annotations(): JsonField> = + chatCompletionMessage._annotations() + + /** @see ChatCompletionMessage._audio */ + fun _audio(): JsonField = chatCompletionMessage._audio() + + /** @see ChatCompletionMessage._functionCall */ + @Deprecated("deprecated") + fun _functionCall(): JsonField = chatCompletionMessage._functionCall() + + /** @see ChatCompletionMessage._toolCalls */ + fun _toolCalls(): JsonField> = + chatCompletionMessage._toolCalls() + + /** @see ChatCompletionMessage._additionalProperties */ + fun _additionalProperties(): Map = + chatCompletionMessage._additionalProperties() + + /** @see ChatCompletionMessage.validate */ + // NOTE(review): the delegate does not validate `content()`, nor is it called here; confirm. 
+ fun validate(): ChatCompletionMessage = chatCompletionMessage.validate() + + /** @see ChatCompletionMessage.isValid */ + fun isValid(): Boolean = chatCompletionMessage.isValid() + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is StructuredChatCompletionMessage<*> && + responseFormat == other.responseFormat && + chatCompletionMessage == other.chatCompletionMessage + } + + private val hashCode: Int by lazy { Objects.hash(responseFormat, chatCompletionMessage) } + + override fun hashCode(): Int = hashCode + + override fun toString() = + "${javaClass.simpleName}{responseFormat=$responseFormat, chatCompletionMessage=$chatCompletionMessage}" +} diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 3f46a970..0060c54f 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -15,6 +15,8 @@ import com.openai.models.chat.completions.ChatCompletionListPage import com.openai.models.chat.completions.ChatCompletionListParams import com.openai.models.chat.completions.ChatCompletionRetrieveParams import com.openai.models.chat.completions.ChatCompletionUpdateParams +import com.openai.models.chat.completions.StructuredChatCompletion +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams import com.openai.services.blocking.chat.completions.MessageService interface ChatCompletionService { @@ -53,6 +55,16 @@ interface ChatCompletionService { requestOptions: RequestOptions = RequestOptions.none(), ): ChatCompletion + /** @see create */ + fun create( + params: StructuredChatCompletionCreateParams + ): StructuredChatCompletion = + StructuredChatCompletion( + params.responseFormat, + // Normal, non-generic 
create method call via `ChatCompletionCreateParams`. + create(params.rawParams), + ) + /** * **Starting a new project?** We recommend trying * [Responses](https://platform.openai.com/docs/api-reference/responses) to take advantage of diff --git a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt new file mode 100644 index 00000000..31768c04 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt @@ -0,0 +1,1403 @@ +package com.openai.core + +import com.fasterxml.jackson.annotation.JsonClassDescription +import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.annotation.JsonPropertyDescription +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode +import java.util.Optional +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.AfterTestExecutionCallback +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.RegisterExtension + +/** Tests the [JsonSchemaValidator] and, in passing, tests the [extractSchema] function. */ +internal class JsonSchemaValidatorTest { + companion object { + private const val SCHEMA = "\$schema" + private const val SCHEMA_VER = "https://json-schema.org/draft/2020-12/schema" + private const val DEFS = "\$defs" + private const val REF = "\$ref" + + /** + * `true` to print the schema and validation errors for all executed tests, or `false` to + * print them only for failed tests. + */ + private const val VERBOSE_MODE = false + } + + /** + * A validator that can be used by each unit test. 
A new validation instance is created for each + * test, as each test is run from its own instance of the test class. If a test fails, any + * validation errors are automatically printed to standard output to aid diagnosis. + */ + val validator = JsonSchemaValidator.create() + + /** + * The schema that was created by the unit test. This may be printed out after a test fails to + * aid in diagnosing the cause of the failure. In that case, this property must be set, or an + * error will occur. However, it will only be printed if the failed test method has the name + * prefix `schemaTest_`, so only test methods with that naming pattern need to set this field. + */ + lateinit var schema: JsonNode + + /** + * An extension to JUnit that prints the [schema] and the validation status (including any + * errors) when a test fails. This applies only to test methods whose names are prefixed with + * `schemaTest_`. An error will occur if [schema] was not set, but this can be avoided by only + * using the method name prefix for test methods that set [schema]. This reporting is intended + * as an aid to diagnosing test failures. + */ + @Suppress("unused") + @RegisterExtension + val printValidationErrorsOnFailure: AfterTestExecutionCallback = + object : AfterTestExecutionCallback { + @Throws(Exception::class) + override fun afterTestExecution(context: ExtensionContext) { + if ( + context.displayName.startsWith("schemaTest_") && + (VERBOSE_MODE || context.executionException.isPresent) + ) { + // Test failed. + println("Schema: ${schema.toPrettyString()}\n") + println("$validator\n") + } + } + } + + // NOTE: In most of these tests, it is assumed that the schema is generated as expected; it is + // not examined in fine detail if the validator succeeds or fails with the expected errors. 
+ + @Test + fun schemaTest_minimalSchema() { + class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_minimalListSchema() { + val s: List = listOf() + + schema = extractSchema(s.javaClass) + validator.validate(schema) + + // FIXME: Currently, the generated schema looks like this: + // { + // "$schema" : "https://json-schema.org/draft/2020-12/schema", + // "type" : "array", + // "items" : { } + // } + // That causes an error, as the `"items"` object is empty when it should be a valid + // sub-schema. Something like this is what is expected: + // { + // "$schema" : "https://json-schema.org/draft/2020-12/schema", + // "type" : "array", + // "items" : { + // "type" : "string" + // } + // } + // Type erasure is the likely cause of the missing field: `s.javaClass` is the erased + // runtime class of the list. `schemaTest_listFieldSchema` (below) produces the expected + // `"items"` object, presumably as a class property keeps its generic type for reflection. + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_listFieldSchema() { + @Suppress("unused") class X(val s: List) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + // This gives a root schema with `"type" : "string"` and `"const" : "HELLO"` + // Unfortunately, an "enum class" cannot be defined within a function or within a class within + // a function. 
+ @Suppress("unused") + enum class MinimalEnum1 { + HELLO + } + + @Test + fun schemaTest_minimalEnumSchema1() { + schema = extractSchema(MinimalEnum1::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + // This gives a root schema with `"type" : "string"` and `"enum" : [ "HELLO", "WORLD" ]` + @Suppress("unused") + enum class MinimalEnum2 { + HELLO, + WORLD, + } + + @Test + fun schemaTest_minimalEnumSchema2() { + schema = extractSchema(MinimalEnum2::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_nonStringEnum() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "integer", + "enum" : [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchema() { + @Suppress("unused") class X(val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalString() { + // Using an `Optional` will result in this JSON: `"type" : [ "string", "null" ]`. + // That is supported by the OpenAI Structured Outputs API spec, as long as the field is also + // marked as required. Though required, it is still allowed for the field to be explicitly + // set to `"null"`. 
+ @Suppress("unused") class X(val s: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalBoolean() { + @Suppress("unused") class X(val b: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalInteger() { + @Suppress("unused") class X(val i: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinySchemaFromOptionalNumber() { + @Suppress("unused") class X(val n: Optional) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_arraySchemaFromOptional() { + @Suppress("unused") class X(val s: Optional>) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_arrayTypeMissingItems() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "array" + } + """ + ) + validator.validate(schema) + + // Check once here that "validator.isValid()" returns "false" when there is an error. In + // the other tests, there is no need to repeat this assertion, as it would be redundant. 
+ assertThat(validator.isValid()).isFalse + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'items' field is missing or is not an object.") + } + + @Test + fun schemaTest_arrayTypeWithWrongItemsType() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "array", + "items" : [ "should_not_be_an_array" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'items' field is missing or is not an object.") + } + + @Test + @Suppress("unused") + fun schemaTest_objectSubSchemaFromOptional() { + class X(val s: Optional) + class Y(val x: Optional) + + schema = extractSchema(Y::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_badOptionalTypeNotArray() { + // Testing more for code coverage than for anything expected to go wrong in practice. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : { "type" : "string" } + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'type' field is not a type name or array of type names.") + } + + @Test + fun schemaTest_badOptionalTypeNoNull1() { + // Testing more for code coverage than for anything expected to go wrong in practice. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeNoNull2() { + // If "type" is an array, one of the two "type" values must be "null". 
+ schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", "number" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected one type name and one \"null\".") + } + + @Test + fun schemaTest_badOptionalTypeNoNull3() { + // If "type" is an array, there must be two type values only, one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", "number", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeNoStringTypeNames() { + // If "type" is an array, both entries must be strings; JSON null is not a type name. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "string", null ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected exactly two types, both strings.") + } + + @Test + fun schemaTest_badOptionalTypeAllNull() { + // If "type" is an array, there must be two type values only, and only one of them "null". + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "null", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/type: Expected one type name and one \"null\".") + } + + @Test + fun schemaTest_badOptionalTypeUnknown() { + // If "type" is an array, the entry that is not "null" must be a supported type name. 
+ schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "unknown", "null" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#/type: Unsupported 'type' value: 'unknown'.") + } + + @Test + fun schemaTest_goodOptionalTypeNullFirst() { + // The validator should be lenient about the order of the null/not-null types in the array. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : [ "null", "string" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_tinyRecursiveSchema() { + @Suppress("unused") class X(val s: String, val x: X) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_unsupportedKeywords() { + // OpenAI lists a set of keywords that are not allowed, but the set is not exhaustive. Check + // that everything named in that set is identified as not allowed, as that is the minimum + // level of validation expected. Check at the root schema and a sub-schema. There is no need + // to match the keywords to their expected schema types or be concerned about the values of + // the keyword fields, which makes testing easier. 
+ val keywordsNotAllowed = + listOf( + "minLength", + "maxLength", + "pattern", + "format", + "minimum", + "maximum", + "multipleOf", + "patternProperties", + "unevaluatedProperties", + "propertyNames", + "minProperties", + "maxProperties", + "unevaluatedItems", + "contains", + "minContains", + "maxContains", + "minItems", + "maxItems", + "uniqueItems", + ) + val notAllowedUses = keywordsNotAllowed.joinToString(", ") { "\"$it\" : \"\"" } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "x" : { + "type" : "string", + $notAllowedUses + } + }, + $notAllowedUses, + "additionalProperties" : false, + "required" : [ "x" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(keywordsNotAllowed.size * 2) + keywordsNotAllowed.forEachIndexed { index, keyword -> + assertThat(validator.errors()[index]) + .isEqualTo("#: Use of '$keyword' is not supported here.") + assertThat(validator.errors()[index + keywordsNotAllowed.size]) + .isEqualTo("#/properties/x: Use of '$keyword' is not supported here.") + } + } + + @Test + fun schemaTest_propertyNotMarkedRequired() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_requiredArrayNull() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : null + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 
'required'.") + } + + @Test + fun schemaTest_requiredArrayMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/required: 'properties' field 'name' is not listed as 'required'.") + } + + @Test + fun schemaTest_additionalPropertiesMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'additionalProperties' field is missing or is not set to 'false'.") + } + + @Test + fun schemaTest_additionalPropertiesTrue() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : true, + "required" : [ "name" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'additionalProperties' field is missing or is not set to 'false'.") + } + + @Test + fun schemaTest_objectPropertiesMissing() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + // TODO: Decide if this is the expected behavior, i.e., that it is OK for an "object" schema + // to have no "properties". 
+ assertThat(validator.isValid()).isTrue() + } + + @Test + fun schemaTest_objectPropertiesNotObject() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : [ "name", "age" ], + "additionalProperties" : false, + "required" : [ "name", "age" ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is not a non-empty object.") + } + + @Test + fun schemaTest_objectPropertiesEmpty() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { }, + "additionalProperties" : false, + "required" : [ ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: 'properties' field is not a non-empty object.") + } + + @Test + fun schemaTest_anyOfInRootSchema() { + // OpenAI does not allow `"anyOf"` to appear at the root level of a schema. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "anyOf" : [ { + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : ["name"] + }, { + "type" : "array", + "items" : { + "type" : "object", + "properties" : { "name" : { "type" : "string" } }, + "additionalProperties" : false, + "required" : ["name"] + } + } ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#: Root schema contains 'anyOf' field.") + } + + @Test + fun schemaTest_anyOfNotArray() { + // Unlikely that this can occur in a generated schema, so this is more about code coverage. 
+ schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { + "anyOf" : { + "type" : "string" + } + } + }, + "additionalProperties" : false, + "required" : ["name"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/name: 'anyOf' field is not a non-empty array.") + } + + @Test + fun schemaTest_anyOfIsEmptyArray() { + // Unlikely that this can occur in a generated schema, so this is more about code coverage. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "name" : { + "anyOf" : [ ] + } + }, + "additionalProperties" : false, + "required" : ["name"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/properties/name: 'anyOf' field is not a non-empty array.") + } + + @Test + fun schemaTest_anyOfInSubSchemaArray() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "value" : { + "anyOf" : [ + { "type" : "string" }, + { "type" : "number" } + ] + } + }, + "additionalProperties" : false, + "required" : ["value"] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_noSchemaFieldRootSchema() { + @Suppress("unused") class X(val s: String) + + schema = extractSchema(X::class.java) + (schema as ObjectNode).remove(SCHEMA) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#: Root schema missing '$SCHEMA' field.") + } + + @Test + @Suppress("unused") + fun schemaTest_deepNestingAtLimit() { + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + + schema = extractSchema(Y::class.java) + validator.validate(schema) + + 
assertThat(validator.isValid()).isTrue + } + + @Test + @Suppress("unused") + fun schemaTest_deepNestingBeyondLimit() { + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + schema = extractSchema(Z::class.java) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).contains("Current nesting depth is 6, but maximum is 5.") + } + + @Test + fun schemaTest_stringEnum250ValueOverSizeLimit() { + // OpenAI specification: "For a single enum property with string values, the total string + // length of all enum values cannot exceed 7,500 characters when there are more than 250 + // enum values." + + // This test creates an enum with exactly 250 string values with more than 7,500 characters + // in total (31 characters per value for a total of 7,750 characters). No error is expected. + val values = (1..250).joinToString(", ") { "\"%s%03d\"".format("x".repeat(28), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_stringEnum251ValueUnderSizeLimit() { + // This test creates an enum with exactly 251 string values with fewer than 7,500 characters + // in total (29 characters per value for a total of 7,279 characters). No error is expected. + val values = (1..251).joinToString(", ") { "\"%s%03d\"".format("x".repeat(26), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_stringEnum251ValueOverSizeLimit() { + // This test creates an enum with exactly 251 string values with more than 7,500 characters + // in total (30 characters per value for a total of 7,530 characters). An error is expected. 
+ val values = (1..251).joinToString(", ") { "\"%s%03d\"".format("x".repeat(27), it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "string", + "enum" : [ $values ] + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo( + "#/enum: Total string length (7530) of values of an enum " + + "with 251 values exceeds limit of 7500." + ) + } + + @Test + fun schemaTest_totalEnumValuesAtLimit() { + // OpenAI specification: "A schema may have up to 500 enum values across all enum + // properties." + + // This test creates two enums with a total of 500 values. The total string length of the + // values is well within the limits (2,000 characters). + val valuesA = (1..250).joinToString(", ") { "\"a%03d\"".format(it) } + val valuesB = (1..250).joinToString(", ") { "\"b%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "string", + "enum" : [ $valuesA ] + }, + "b" : { + "type" : "string", + "enum" : [ $valuesB ] + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_totalEnumValuesOverLimit() { + // This test creates two enums with a total of 501 values. The total string length of the + // values is well within the limits (2,004 characters). 
+ val valuesA = (1..250).joinToString(", ") { "\"a%03d\"".format(it) } + val valuesB = (1..251).joinToString(", ") { "\"b%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "string", + "enum" : [ $valuesA ] + }, + "b" : { + "type" : "string", + "enum" : [ $valuesB ] + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total number of enum values (501) exceeds limit of 500.") + } + + @Test + fun schemaTest_maxObjectPropertiesAtLimit() { + // This test creates two object schemas with a total of 100 object properties. OpenAI does + // not support more than 100 properties total in the whole schema. Two objects are used to + // ensure that counting is not done per object, but across all objects. Note that each + // object schema is itself a property, so there are two properties at the top level and 49 + // properties each at the next level. No error is expected, as the limit is not exceeded. 
+ val propUses = + (1..49).joinToString(", ") { "\"x%02d\" : { \"type\" : \"string\" }".format(it) } + val propNames = (1..49).joinToString(", ") { "\"x%02d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + }, + "b" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_maxObjectPropertiesOverLimit() { + // This test creates two object schemas with a total of 101 object properties. OpenAI does + // not support more than 100 properties total in the whole schema. Expect an error. + val propUses = + (1..49).joinToString(", ") { "\"x_%02d\" : { \"type\" : \"string\" }".format(it) } + val propNames = (1..49).joinToString(", ") { "\"x_%02d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "type" : "object", + "properties" : { + "a" : { + "type" : "object", + "properties" : { + $propUses + }, + "required" : [ $propNames ], + "additionalProperties" : false + }, + "b" : { + "type" : "object", + "properties" : { + $propUses, + "property_101" : { "type" : "string" } + }, + "required" : [ $propNames, "property_101" ], + "additionalProperties" : false + } + }, + "required" : [ "a", "b" ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total number of object properties (101) exceeds limit of 100.") + } + + @Test + fun schemaTest_maxStringLengthAtLimit() { + // OpenAI specification: "In a schema, total string length of all property names, definition + // names, enum 
values, and const values cannot exceed 15,000 characters." + // + // This test creates a schema with many property names, definition names, enum values, and + // const values calculated to have a total string length of 15,000 characters. No error is + // expected. + // + // The test creates a schema that looks like the following, with the numbers adjusted to + // achieve a total of 15,000 characters for the relevant elements. + // + // { + // "$schema" : "...", + // "$defs" : { + // "d_001" : { + // "type" : "string", + // "const" : "c_001" + // }, + // ..., + // "d_nnn" : { + // "type" : "string", + // "const" : "c_nnn" + // } + // }, + // "type" : "object", + // "properties" : { + // "p_001" : { + // "type" : "string", + // "enum" : [ "eeeee..._001", ..., "eeeee..._nnn" ] + // }, + // ..., + // "p_nnn" : { + // "type" : "string", + // "enum" : [ "eeeee..._001", ..., "eeeee..._nnn" ] + // } + // }, + // "required" : [ "p_001", ..., "p_nnn" ], + // "additionalProperties" : false + // } + + val numDefs = 65 // Each also has one "const" value. + val numProps = 70 // Each also has "numEnumValues" enum values. + val nameLen = 5 // Length of names of definitions, properties and const values. + val numEnumValues = 5 // numProps * numEnumValues <= 500 limit (OpenAI) + val enumValueLen = 40 // Length of enum values. 
+ val expectedTotalStringLength = + nameLen * (numProps + numDefs * 2) + numProps * enumValueLen * numEnumValues + + val enumValues = + (1..numEnumValues).joinToString(", ") { "\"%s_%03d\"".format("e".repeat(36), it) } + val defs = + (1..numDefs).joinToString(", ") { + "\"d_%03d\" : { \"type\" : \"string\", \"const\" : \"c_%03d\" }".format(it, it) + } + val props = + (1..numProps).joinToString(", ") { + "\"p_%03d\" : { \"type\" : \"string\", \"enum\" : [ $enumValues ] }".format(it) + } + val propNames = (1..numProps).joinToString(", ") { "\"p_%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { $defs }, + "type" : "object", + "properties" : { $props }, + "required" : [ $propNames ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(expectedTotalStringLength).isEqualTo(15_000) // Exactly on the limit. + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_maxStringLengthOverLimit() { + // OpenAI specification: "In a schema, total string length of all property names, definition + // names, enum values, and const values cannot exceed 15,000 characters." + // + // This test creates a schema with many property names, definition names, enum values, and + // const values calculated to have a total string length of just over 15,000 characters. An + // error is expected. + + val numDefs = 66 // Each also has one "const" value. + val numProps = 70 // Each also has "numEnumValues" enum values. + val numEnumValues = 5 // numProps * numEnumValues <= 500 limit (OpenAI) + val nameLen = 5 // Length of names of definitions, properties and const values. + val enumValueLen = 40 // Length of enum values. 
+ val expectedTotalStringLength = + nameLen * (numProps + numDefs * 2) + numProps * enumValueLen * numEnumValues + + val enumValues = + (1..numEnumValues).joinToString(", ") { "\"%s_%03d\"".format("e".repeat(36), it) } + val defs = + (1..numDefs).joinToString(", ") { + "\"d_%03d\" : { \"type\" : \"string\", \"const\" : \"c_%03d\" }".format(it, it) + } + val props = + (1..numProps).joinToString(", ") { + "\"p_%03d\" : { \"type\" : \"string\", \"enum\" : [ $enumValues ] }".format(it) + } + val propNames = (1..numProps).joinToString(", ") { "\"p_%03d\"".format(it) } + + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { $defs }, + "type" : "object", + "properties" : { $props }, + "required" : [ $propNames ], + "additionalProperties" : false + } + """ + ) + validator.validate(schema) + + assertThat(expectedTotalStringLength).isGreaterThan(15_000) + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#: Total string length of all values (15010) exceeds limit of 15000.") + } + + @Test + fun schemaTest_annotatedWithJsonClassDescription() { + // Add a "description" to the root schema using an annotation. + @JsonClassDescription("A simple schema.") class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val desc = schema.get("description") + + assertThat(validator.isValid()).isTrue + assertThat(desc).isNotNull + assertThat(desc.isTextual).isTrue + assertThat(desc.asText()).isEqualTo("A simple schema.") + } + + @Test + fun schemaTest_annotatedWithJsonPropertyDescription() { + // Add a "description" to the property using an annotation. + @Suppress("unused") class X(@get:JsonPropertyDescription("A string value.") val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. 
+ val properties = schema.get("properties") + val stringProperty = properties.get("s") + val desc = stringProperty.get("description") + + assertThat(validator.isValid()).isTrue + assertThat(desc).isNotNull + assertThat(desc.isTextual).isTrue + assertThat(desc.asText()).isEqualTo("A string value.") + } + + @Test + fun schemaTest_annotatedWithJsonProperty() { + // Override the default name of the property using the annotation. + @Suppress("unused") class X(@get:JsonProperty("a_string") val s: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("a_string") + + assertThat(validator.isValid()).isTrue + assertThat(stringProperty).isNotNull + } + + @Test + fun schemaTest_annotatedWithJsonPropertyRejectDefaultValue() { + // Set a default value for the property. It should be ignored when the schema is generated, + // as default property values are not supported in OpenAI JSON schemas. (The Victools docs + // have examples of how to add support for default values via annotations or initial + // values, should support for default values be needed in the future.) + // + // Lack of support is not mentioned in the specification, but see the evidence at: + // https://engineering.fractional.ai/openai-structured-output-fixes + @Suppress("unused") + class X( + @get:JsonProperty(defaultValue = "default_value_1") val s: String = "default_value_2" + ) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val stringProperty = properties.get("s") + + assertThat(validator.isValid()).isTrue + assertThat(stringProperty).isNotNull + assertThat(stringProperty.get("default")).isNull() + } + + @Test + fun schemaTest_annotatedWithJsonIgnore() { + // Exclude a property from the generated schema using the annotation.
+ @Suppress("unused") class X(@get:JsonIgnore val s1: String, val s2: String) + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Assume that the schema is well-formed. + val properties = schema.get("properties") + val s1Property = properties.get("s1") + val s2Property = properties.get("s2") + + assertThat(validator.isValid()).isTrue + assertThat(s1Property).isNull() + assertThat(s2Property).isNotNull + } + + @Test + fun schemaTest_emptyDefinitions() { + // Be lenient about empty definitions. + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "type" : "string" + } + """ + ) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + @Test + fun schemaTest_referenceMissingReferent() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "$REF" : "#/$DEFS/Person" + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]) + .isEqualTo("#/$REF: Invalid or unsupported reference: '#/$DEFS/Person'.") + } + + @Test + fun schemaTest_referenceFieldIsNotTextual() { + schema = + parseJson( + """ + { + "$SCHEMA" : "$SCHEMA_VER", + "$DEFS" : { }, + "$REF" : 42 + } + """ + ) + validator.validate(schema) + + assertThat(validator.errors()).hasSize(1) + assertThat(validator.errors()[0]).isEqualTo("#/$REF: '$REF' field is not a text value.") + } + + @Test + fun validatorBeforeValidation() { + assertThat(validator.errors()).isEmpty() + assertThat(validator.isValid()).isFalse + } + + @Test + fun validatorReused() { + class X() + + schema = extractSchema(X::class.java) + validator.validate(schema) + + // Should fail if an attempt is made to reuse the validator. 
+ assertThatThrownBy { validator.validate(schema) } + .isExactlyInstanceOf(IllegalStateException::class.java) + .hasMessageContaining("Validation already complete.") + } + + @Test + @Suppress("unused") + fun schemaTest_largeLaureatesSchema() { + // This covers many cases: large and complex "$defs", resolution of references, recursive + // references, etc. The output is assumed to be good (it has been checked by eye) and the + // test just shows that the validator can handle the complexity without crashing or emitting + // spurious errors. + class Name(val givenName: String, val familyName: String) + + class Person( + @get:JsonPropertyDescription("The name of the person.") val name: Name, + @get:JsonProperty(value = "date_of_birth", defaultValue = "unknown_1") + @get:JsonPropertyDescription("The date of birth of the person.") + var dateOfBirth: String, + @get:JsonPropertyDescription("The country of citizenship of the person.") + var nationality: String, + // A child being a `Person` results in a recursive schema. + @get:JsonPropertyDescription("The children (if any) of the person.") + val children: List, + ) { + @get:JsonPropertyDescription("The other name of the person.") + var otherName: Name = Name("Bob", "Smith") + } + + class Laureate( + val laureate: Person, + val majorContribution: String, + val yearOfWinning: String, + @get:JsonIgnore val favoriteColor: String, + ) + + class Laureates( + // Two lists result in a `Laureate` definition that is referenced in the schema.
+ var laureates1901to1950: List, + var laureates1951to2025: List, + ) + + schema = extractSchema(Laureates::class.java) + validator.validate(schema) + + assertThat(validator.isValid()).isTrue + } + + private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt index e47aebd6..fb52ffc6 100644 --- a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParamsTest.kt @@ -347,4 +347,36 @@ internal class ChatCompletionCreateParamsTest { ) assertThat(body.model()).isEqualTo(ChatModel.GPT_4_1) } + + @Test + fun structuredOutputsBuilder() { + class X(val s: String) + + // Only interested in a few things: + // - Does the `Builder` type change when `responseFormat(Class)` is called? + // - Are values already set on the "old" `Builder` preserved in the change-over? + // - Can new values be set on the "new" `Builder` alongside the "old" values? + val params = + ChatCompletionCreateParams.builder() + .addDeveloperMessage("dev message") + .model(ChatModel.GPT_4_1) + .responseFormat(X::class.java) // Creates and returns a new builder.
+ .addSystemMessage("sys message") + .build() + + val body = params.rawParams._body() + + assertThat(params).isInstanceOf(StructuredChatCompletionCreateParams::class.java) + assertThat(params.responseFormat).isEqualTo(X::class.java) + assertThat(body.messages()) + .containsExactly( + ChatCompletionMessageParam.ofDeveloper( + ChatCompletionDeveloperMessageParam.builder().content("dev message").build() + ), + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content("sys message").build() + ), + ) + assertThat(body.model()).isEqualTo(ChatModel.GPT_4_1) + } } diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt new file mode 100644 index 00000000..4abd66b6 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParamsTest.kt @@ -0,0 +1,499 @@ +package com.openai.models.chat.completions + +import com.openai.core.fromClass +import com.openai.core.http.Headers +import com.openai.core.http.QueryParams +import com.openai.models.ChatModel +import com.openai.models.FunctionDefinition +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_FIELD +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_VALUE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.MESSAGE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.OPTIONAL +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.STRING +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.X +import java.lang.reflect.Method +import kotlin.collections.plus +import kotlin.reflect.full.declaredFunctions +import kotlin.reflect.jvm.javaMethod +import 
org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.fail +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletionCreateParams] class (delegator) and its delegation of + * most functions to a wrapped [ChatCompletionCreateParams] (delegate). It is the `Builder` class of + * each main class that is involved in the delegation. The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionCreateParamsTest { + companion object { + private fun checkOneDelegationWrite( + delegator: Any, + mockDelegate: Any, + testCase: DelegationWriteTestCase, + ) { + invokeMethod(findDelegationMethod(delegator, testCase), delegator, testCase) + + // Verify that the corresponding method on the mock delegate was called exactly once. 
+ verify(mockDelegate, times(1)).apply { + invokeMethod(findDelegationMethod(mockDelegate, testCase), mockDelegate, testCase) + } + verifyNoMoreInteractions(mockDelegate) + } + + private fun invokeMethod(method: Method, target: Any, testCase: DelegationWriteTestCase) { + val numParams = testCase.inputValues.size + val inputValue1 = testCase.inputValues[0] + val inputValue2 = testCase.inputValues.getOrNull(1) + + when (numParams) { + 1 -> method.invoke(target, inputValue1) + 2 -> method.invoke(target, inputValue1, inputValue2) + else -> fail { "Unexpected number of function parameters ($numParams)." } + } + } + + /** + * Finds the java method matching the test case's function name and parameter types in the + * delegator or delegate `target`. + */ + private fun findDelegationMethod(target: Any, testCase: DelegationWriteTestCase): Method { + val numParams = testCase.inputValues.size + val inputValue1: Any? = testCase.inputValues[0] + val inputValue2 = if (numParams > 1) testCase.inputValues[1] else null + + val method = + when (numParams) { + 1 -> + if (inputValue1 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + ) + } else { + // Only the first parameter may be nullable and only if it is the only + // parameter. If the first parameter is nullable, it will be the only + // function of the same name with a nullable first parameter. To handle + // the potentially nullable first parameter, Kotlin reflection is + // needed. This allows a function `f(Boolean)` to be distinguished from + // `f(Boolean?)`. For the tests, if the parameter type is nullable, the + // parameter value will always be `null` (if not, the function with the + // nullable parameter would not be matched). + // + // Using Kotlin reflection, the first parameter (zero index) is `this` + // object, so start matching from the second parameter onwards. 
+ target::class + .declaredFunctions + .find { + it.name == testCase.functionName && + it.parameters[1].type.isMarkedNullable + } + ?.javaMethod + } + 2 -> + if (inputValue1 != null && inputValue2 != null) { + findJavaMethod( + target.javaClass, + testCase.functionName, + toJavaType(inputValue1.javaClass), + toJavaType(inputValue2.javaClass), + ) + } else { + // There are no instances where there are two parameters and one of them + // is nullable. + fail { "Function $testCase second parameter must not be null." } + } + else -> fail { "Function $testCase has unsupported number of parameters." } + } + + // Using `if` and `fail`, so the compiler knows the code will not continue and can infer + // that `method` is not null. It cannot do this for `assertThat...isNotNull`. + if (method == null) { + fail { "Function $testCase cannot be found in $target." } + } + + return method + } + + private fun findJavaMethod( + clazz: Class<*>, + methodName: String, + vararg parameterTypes: Class<*>, + ): Method? = + clazz.declaredMethods.firstOrNull { method -> + method.name == methodName && + method.parameterTypes.size == parameterTypes.size && + method.parameterTypes.indices.all { index -> + (parameterTypes[index].isPrimitive && + method.parameterTypes[index] == parameterTypes[index]) || + method.parameterTypes[index].isAssignableFrom(parameterTypes[index]) + } + } + + /** + * Returns the Java type to use when matching type parameters for a Java method. The type is + * the type of the input value that will be used when the method is invoked. For most types, + * the given type is returned. However, if the type represents a Kotlin primitive, it will + * be converted to a Java primitive. This allows matching of methods with parameter types + * that are non-nullable Kotlin primitives. If not translated, methods with parameter types + * that are nullable Kotlin primitives would always be matched instead.
+ */ + private fun toJavaType(type: Class<*>) = + when (type) { + // This only needs to cover the types used in the test cases. + java.lang.Long::class.java -> java.lang.Long.TYPE + java.lang.Boolean::class.java -> java.lang.Boolean.TYPE + java.lang.Double::class.java -> java.lang.Double.TYPE + else -> type + } + + private val NULLABLE = null + private const val BOOLEAN: Boolean = true + private val NULLABLE_BOOLEAN: Boolean? = null + private const val LONG: Long = 42L + private val NULLABLE_LONG: Long? = null + private const val DOUBLE: Double = 42.0 + private val NULLABLE_DOUBLE: Double? = null + private val LIST = listOf(STRING) + private val SET = setOf(STRING) + private val MAP = mapOf(STRING to STRING) + + private val CHAT_MODEL = ChatModel.GPT_4 + + private val USER_MESSAGE_PARAM = + ChatCompletionUserMessageParam.builder().content(STRING).build() + private val DEV_MESSAGE_PARAM = + ChatCompletionDeveloperMessageParam.builder().content(STRING).build() + private val SYS_MESSAGE_PARAM = + ChatCompletionSystemMessageParam.builder().content(STRING).build() + private val ASSIST_MESSAGE_PARAM = + ChatCompletionAssistantMessageParam.builder().content(STRING).build() + private val TOOL_MESSAGE_PARAM = + ChatCompletionToolMessageParam.builder().content(STRING).toolCallId(STRING).build() + private val FUNC_MESSAGE_PARAM = + ChatCompletionFunctionMessageParam.builder().content(STRING).name(STRING).build() + private val MESSAGE_PARAM = ChatCompletionMessageParam.ofUser(USER_MESSAGE_PARAM) + + private val DEV_MESSAGE_PARAM_CONTENT = + ChatCompletionDeveloperMessageParam.Content.ofText(STRING) + private val SYS_MESSAGE_PARAM_CONTENT = + ChatCompletionSystemMessageParam.Content.ofText(STRING) + private val USER_MESSAGE_PARAM_CONTENT = + ChatCompletionUserMessageParam.Content.ofText(STRING) + + private val PARAMS_BODY = + ChatCompletionCreateParams.Body.builder() + .messages(listOf(MESSAGE_PARAM)) + .model(CHAT_MODEL) + .build() + private val WEB_SEARCH_OPTIONS = + 
ChatCompletionCreateParams.WebSearchOptions.builder().build() + + private val FUNCTION_CALL_MODE = + ChatCompletionCreateParams.FunctionCall.FunctionCallMode.AUTO + private val FUNCTION_CALL_OPTION = + ChatCompletionFunctionCallOption.builder().name(STRING).build() + private val FUNCTION_CALL = + ChatCompletionCreateParams.FunctionCall.ofFunctionCallOption(FUNCTION_CALL_OPTION) + + private val FUNCTION = ChatCompletionCreateParams.Function.builder().name(STRING).build() + private val METADATA = ChatCompletionCreateParams.Metadata.builder().build() + private val MODALITY = ChatCompletionCreateParams.Modality.TEXT + private val FUNCTION_DEFINITION = FunctionDefinition.builder().name(STRING).build() + private val TOOL = ChatCompletionTool.builder().function(FUNCTION_DEFINITION).build() + + private val NAMED_TOOL_CHOICE_FUNCTION = + ChatCompletionNamedToolChoice.Function.builder().name(STRING).build() + private val NAMED_TOOL_CHOICE = + ChatCompletionNamedToolChoice.builder().function(NAMED_TOOL_CHOICE_FUNCTION).build() + private val TOOL_CHOICE_OPTION_AUTO = ChatCompletionToolChoiceOption.Auto.AUTO + private val TOOL_CHOICE_OPTION = + ChatCompletionToolChoiceOption.ofAuto(TOOL_CHOICE_OPTION_AUTO) + + private val HEADERS = Headers.builder().build() + private val QUERY_PARAMS = QueryParams.builder().build() + + // Want `vararg`, so cannot use `data class`. Need a custom `toString`, anyway. + class DelegationWriteTestCase(val functionName: String, vararg val inputValues: Any?) { + /** + * Gets the string representation that identifies the test function when running JUnit. + */ + override fun toString(): String = + "$functionName(${inputValues.joinToString(", ") { + it?.javaClass?.simpleName ?: "null" + }})" + } + + // The list order follows the declaration order in `ChatCompletionCreateParams.Builder` for + // easier maintenance. 
+ @JvmStatic + fun builderDelegationTestCases() = + listOf( + DelegationWriteTestCase("body", PARAMS_BODY), + DelegationWriteTestCase("messages", LIST), + DelegationWriteTestCase("messages", JSON_FIELD), + DelegationWriteTestCase("addMessage", MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", DEV_MESSAGE_PARAM), + DelegationWriteTestCase("addDeveloperMessage", DEV_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addDeveloperMessage", STRING), + DelegationWriteTestCase("addDeveloperMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", SYS_MESSAGE_PARAM), + DelegationWriteTestCase("addSystemMessage", SYS_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addSystemMessage", STRING), + DelegationWriteTestCase("addSystemMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", USER_MESSAGE_PARAM), + DelegationWriteTestCase("addUserMessage", USER_MESSAGE_PARAM_CONTENT), + DelegationWriteTestCase("addUserMessage", STRING), + DelegationWriteTestCase("addUserMessageOfArrayOfContentParts", LIST), + DelegationWriteTestCase("addMessage", ASSIST_MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", MESSAGE), + DelegationWriteTestCase("addMessage", TOOL_MESSAGE_PARAM), + DelegationWriteTestCase("addMessage", FUNC_MESSAGE_PARAM), + DelegationWriteTestCase("model", CHAT_MODEL), + DelegationWriteTestCase("model", JSON_FIELD), + DelegationWriteTestCase("model", STRING), + DelegationWriteTestCase("audio", NULLABLE), + DelegationWriteTestCase("audio", OPTIONAL), + DelegationWriteTestCase("audio", JSON_FIELD), + DelegationWriteTestCase("frequencyPenalty", NULLABLE_DOUBLE), + DelegationWriteTestCase("frequencyPenalty", DOUBLE), + DelegationWriteTestCase("frequencyPenalty", OPTIONAL), + DelegationWriteTestCase("frequencyPenalty", JSON_FIELD), + DelegationWriteTestCase("functionCall", FUNCTION_CALL), + DelegationWriteTestCase("functionCall", JSON_FIELD), + DelegationWriteTestCase("functionCall", FUNCTION_CALL_MODE), + 
DelegationWriteTestCase("functionCall", FUNCTION_CALL_OPTION), + DelegationWriteTestCase("functions", LIST), + DelegationWriteTestCase("functions", JSON_FIELD), + DelegationWriteTestCase("addFunction", FUNCTION), + DelegationWriteTestCase("logitBias", NULLABLE), + DelegationWriteTestCase("logitBias", OPTIONAL), + DelegationWriteTestCase("logitBias", JSON_FIELD), + DelegationWriteTestCase("logprobs", NULLABLE_BOOLEAN), + DelegationWriteTestCase("logprobs", BOOLEAN), + DelegationWriteTestCase("logprobs", OPTIONAL), + DelegationWriteTestCase("logprobs", JSON_FIELD), + DelegationWriteTestCase("maxCompletionTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxCompletionTokens", LONG), + DelegationWriteTestCase("maxCompletionTokens", OPTIONAL), + DelegationWriteTestCase("maxCompletionTokens", JSON_FIELD), + DelegationWriteTestCase("maxTokens", NULLABLE_LONG), + DelegationWriteTestCase("maxTokens", LONG), + DelegationWriteTestCase("maxTokens", OPTIONAL), + DelegationWriteTestCase("maxTokens", JSON_FIELD), + DelegationWriteTestCase("metadata", METADATA), + DelegationWriteTestCase("metadata", OPTIONAL), + DelegationWriteTestCase("metadata", JSON_FIELD), + DelegationWriteTestCase("modalities", LIST), + DelegationWriteTestCase("modalities", OPTIONAL), + DelegationWriteTestCase("modalities", JSON_FIELD), + DelegationWriteTestCase("addModality", MODALITY), + DelegationWriteTestCase("n", NULLABLE_LONG), + DelegationWriteTestCase("n", LONG), + DelegationWriteTestCase("n", OPTIONAL), + DelegationWriteTestCase("n", JSON_FIELD), + DelegationWriteTestCase("parallelToolCalls", BOOLEAN), + DelegationWriteTestCase("parallelToolCalls", JSON_FIELD), + DelegationWriteTestCase("prediction", NULLABLE), + DelegationWriteTestCase("prediction", OPTIONAL), + DelegationWriteTestCase("prediction", JSON_FIELD), + DelegationWriteTestCase("presencePenalty", NULLABLE_DOUBLE), + DelegationWriteTestCase("presencePenalty", DOUBLE), + DelegationWriteTestCase("presencePenalty", OPTIONAL), + 
DelegationWriteTestCase("presencePenalty", JSON_FIELD), + DelegationWriteTestCase("reasoningEffort", NULLABLE), + DelegationWriteTestCase("reasoningEffort", OPTIONAL), + DelegationWriteTestCase("reasoningEffort", JSON_FIELD), + // `responseFormat()` is a special case and has its own unit test. + DelegationWriteTestCase("seed", NULLABLE_LONG), + DelegationWriteTestCase("seed", LONG), + DelegationWriteTestCase("seed", OPTIONAL), + DelegationWriteTestCase("seed", JSON_FIELD), + DelegationWriteTestCase("serviceTier", NULLABLE), + DelegationWriteTestCase("serviceTier", OPTIONAL), + DelegationWriteTestCase("serviceTier", JSON_FIELD), + DelegationWriteTestCase("stop", NULLABLE), + DelegationWriteTestCase("stop", OPTIONAL), + DelegationWriteTestCase("stop", JSON_FIELD), + DelegationWriteTestCase("stop", STRING), + DelegationWriteTestCase("stopOfStrings", LIST), + DelegationWriteTestCase("store", NULLABLE_BOOLEAN), + DelegationWriteTestCase("store", BOOLEAN), + DelegationWriteTestCase("store", OPTIONAL), + DelegationWriteTestCase("store", JSON_FIELD), + DelegationWriteTestCase("streamOptions", NULLABLE), + DelegationWriteTestCase("streamOptions", OPTIONAL), + DelegationWriteTestCase("streamOptions", JSON_FIELD), + DelegationWriteTestCase("temperature", NULLABLE_DOUBLE), + DelegationWriteTestCase("temperature", DOUBLE), + DelegationWriteTestCase("temperature", OPTIONAL), + DelegationWriteTestCase("temperature", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTION), + DelegationWriteTestCase("toolChoice", JSON_FIELD), + DelegationWriteTestCase("toolChoice", TOOL_CHOICE_OPTION_AUTO), + DelegationWriteTestCase("toolChoice", NAMED_TOOL_CHOICE), + DelegationWriteTestCase("tools", LIST), + DelegationWriteTestCase("tools", JSON_FIELD), + DelegationWriteTestCase("addTool", TOOL), + DelegationWriteTestCase("topLogprobs", NULLABLE_LONG), + DelegationWriteTestCase("topLogprobs", LONG), + DelegationWriteTestCase("topLogprobs", OPTIONAL), + 
DelegationWriteTestCase("topLogprobs", JSON_FIELD), + DelegationWriteTestCase("topP", NULLABLE_DOUBLE), + DelegationWriteTestCase("topP", DOUBLE), + DelegationWriteTestCase("topP", OPTIONAL), + DelegationWriteTestCase("topP", JSON_FIELD), + DelegationWriteTestCase("user", STRING), + DelegationWriteTestCase("user", JSON_FIELD), + DelegationWriteTestCase("webSearchOptions", WEB_SEARCH_OPTIONS), + DelegationWriteTestCase("webSearchOptions", JSON_FIELD), + DelegationWriteTestCase("additionalBodyProperties", MAP), + DelegationWriteTestCase("putAdditionalBodyProperty", STRING, JSON_VALUE), + DelegationWriteTestCase("putAllAdditionalBodyProperties", MAP), + DelegationWriteTestCase("removeAdditionalBodyProperty", STRING), + DelegationWriteTestCase("removeAllAdditionalBodyProperties", SET), + DelegationWriteTestCase("additionalHeaders", HEADERS), + DelegationWriteTestCase("additionalHeaders", MAP), + DelegationWriteTestCase("putAdditionalHeader", STRING, STRING), + DelegationWriteTestCase("putAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("putAllAdditionalHeaders", MAP), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalHeaders", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalHeaders", HEADERS), + DelegationWriteTestCase("replaceAllAdditionalHeaders", MAP), + DelegationWriteTestCase("removeAdditionalHeaders", STRING), + DelegationWriteTestCase("removeAllAdditionalHeaders", SET), + DelegationWriteTestCase("additionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("additionalQueryParams", MAP), + DelegationWriteTestCase("putAdditionalQueryParam", STRING, STRING), + DelegationWriteTestCase("putAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("putAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("putAllAdditionalQueryParams", MAP), + 
DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, STRING), + DelegationWriteTestCase("replaceAdditionalQueryParams", STRING, LIST), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", QUERY_PARAMS), + DelegationWriteTestCase("replaceAllAdditionalQueryParams", MAP), + DelegationWriteTestCase("removeAdditionalQueryParams", STRING), + DelegationWriteTestCase("removeAllAdditionalQueryParams", SET), + ) + } + + // New instances of the `mockBuilderDelegate` and `builderDelegator` are required for each test + // case (each test case runs in its own instance of the test class). + val mockBuilderDelegate: ChatCompletionCreateParams.Builder = + mock(ChatCompletionCreateParams.Builder::class.java) + val builderDelegator = + StructuredChatCompletionCreateParams.builder().inject(mockBuilderDelegate) + + @Test + fun allBuilderDelegateFunctionsExistInDelegator() { + // The delegator class does not implement the various `responseFormat` functions of the + // delegate class. + StructuredChatCompletionTest.checkAllDelegation( + ChatCompletionCreateParams.Builder::class, + StructuredChatCompletionCreateParams.Builder::class, + "responseFormat", + ) + } + + @Test + fun allBuilderDelegatorFunctionsExistInDelegate() { + // The delegator implements a different `responseFormat` function from those overloads in + // the delegate class. + StructuredChatCompletionTest.checkAllDelegation( + StructuredChatCompletionCreateParams.Builder::class, + ChatCompletionCreateParams.Builder::class, + "responseFormat", + ) + } + + @Test + fun allBuilderDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. There are many overloaded functions, so the + // approach here is to build a list (_not_ a set) of all function names and then "subtract" + // those for which tests are defined and see what remains. 
For example, there are (at this + // time) eight `addMessage` functions, so there must be eight tests defined for functions + // named `addMessage` that will be subtracted from the list of functions matching that name. + // Parameter types are not checked, as that is awkward and probably overkill. Therefore, + // this scheme is not reliable if a function is tested more than once. + val exceptionalTestedFns = listOf("responseFormat") + val testedFns = + (builderDelegationTestCases().map { it.functionName } + exceptionalTestedFns) + .toMutableList() + val nonDelegatingFns = listOf("build", "wrap", "inject") + + val delegatorFns = + StructuredChatCompletionCreateParams.Builder::class.declaredFunctions.toMutableList() + + // Making concurrent modifications to the list, so using an `Iterator`. + val i = delegatorFns.iterator() + + while (i.hasNext()) { + val functionName = i.next().name + + if (functionName in testedFns) { + testedFns.remove(functionName) + i.remove() + } + if (functionName in nonDelegatingFns) { + i.remove() + } + } + + // If there are function names remaining in `delegatorFns`, then there are tests missing. + // Only report the names of the functions not tested: parameters are not matched, so any + // signatures could be misleading. + assertThat(delegatorFns) + .describedAs { + "Delegation is not tested for functions ${delegatorFns.map { it.name }}." + } + .isEmpty() + + // If there are function names remaining in `testedFns`, then there are more tests than + // there should be. Functions might be tested twice, or there may be tests for functions + // that have since been removed from the delegate (though those tests probably failed). + assertThat(testedFns) + .describedAs { "Unexpected or redundant tests for functions $testedFns." 
} + .isEmpty() + } + + @ParameterizedTest + @MethodSource("builderDelegationTestCases") + fun `delegation of Builder write functions`(testCase: DelegationWriteTestCase) { + checkOneDelegationWrite(builderDelegator, mockBuilderDelegate, testCase) + } + + @Test + fun `delegation of responseFormat`() { + // Special unit test case as the delegator method signature does not match that of the + // delegate method. + val delegatorTestCase = DelegationWriteTestCase("responseFormat", X::class.java) + val delegatorMethod = findDelegationMethod(builderDelegator, delegatorTestCase) + val mockDelegateTestCase = + DelegationWriteTestCase("responseFormat", fromClass(X::class.java)) + val mockDelegateMethod = findDelegationMethod(mockBuilderDelegate, mockDelegateTestCase) + + delegatorMethod.invoke(builderDelegator, delegatorTestCase.inputValues[0]) + + // Verify that the corresponding method on the mock delegate was called exactly once. + verify(mockBuilderDelegate, times(1)).apply { + mockDelegateMethod.invoke(mockBuilderDelegate, mockDelegateTestCase.inputValues[0]) + } + verifyNoMoreInteractions(mockBuilderDelegate) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt new file mode 100644 index 00000000..347788a3 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessageTest.kt @@ -0,0 +1,141 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.DelegationReadTestCase +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_FIELD +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.JSON_VALUE +import 
com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.MESSAGE +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.OPTIONAL +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.X +import com.openai.models.chat.completions.StructuredChatCompletionTest.Companion.checkOneDelegationRead +import java.util.Optional +import kotlin.reflect.full.declaredFunctions +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletionMessage] class (delegator) and its delegation of most + * functions to a wrapped [ChatCompletionMessage] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. + */ +internal class StructuredChatCompletionMessageTest { + companion object { + // The list order follows the declaration order in `StructuredChatCompletionMessage` for + // easier maintenance. See `StructuredChatCompletionTest` for details on the values used. + @JvmStatic + fun delegationTestCases() = + listOf( + // `content()` is a special case and has its own test function. 
+ DelegationReadTestCase("refusal", OPTIONAL), + DelegationReadTestCase("_role", JSON_VALUE), + DelegationReadTestCase("annotations", OPTIONAL), + DelegationReadTestCase("audio", OPTIONAL), + DelegationReadTestCase("functionCall", OPTIONAL), + DelegationReadTestCase("toolCalls", OPTIONAL), + // `_content()` is a special case and has its own test function. + DelegationReadTestCase("_refusal", JSON_FIELD), + DelegationReadTestCase("_annotations", JSON_FIELD), + DelegationReadTestCase("_audio", JSON_FIELD), + DelegationReadTestCase("_functionCall", JSON_FIELD), + DelegationReadTestCase("_toolCalls", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + DelegationReadTestCase("validate", MESSAGE), + // For this boolean function, call with both possible values to ensure that any + // hard-coding or default value will not result in a false positive test. + DelegationReadTestCase("isValid", true), + DelegationReadTestCase("isValid", false), + ) + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + val mockDelegate: ChatCompletionMessage = mock(ChatCompletionMessage::class.java) + val delegator = StructuredChatCompletionMessage(X::class.java, mockDelegate) + + @Test + fun allDelegateFunctionsExistInDelegator() { + StructuredChatCompletionTest.checkAllDelegation( + ChatCompletionMessage::class, + StructuredChatCompletionMessage::class, + "toBuilder", + "toParam", + ) + } + + @Test + fun allDelegatorFunctionsExistInDelegate() { + StructuredChatCompletionTest.checkAllDelegation( + StructuredChatCompletionMessage::class, + ChatCompletionMessage::class, + ) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. 
+ val exceptionalTestedFns = setOf("content", "_content") + val testedFns = delegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns + // A few delegator functions do not delegate, so no test function is necessary. + val nonDelegatingFns = listOf("equals", "hashCode", "toString") + + val delegatorFunctions = StructuredChatCompletionMessage::class.declaredFunctions + + for (delegatorFunction in delegatorFunctions) { + assertThat( + delegatorFunction.name in testedFns || + delegatorFunction.name in nonDelegatingFns + ) + .describedAs("Delegation is not tested for function '${delegatorFunction.name}.") + .isTrue + } + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @Test + fun `delegation of content`() { + // Input and output are different types, so this test is an exceptional case. + // `content()` (without an underscore) delegates to `_content()` (with an underscore) + // indirectly via the `content` field initializer. + val input = JsonField.of("{\"s\" : \"hello\"}") + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator.content() // Without an underscore. + + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isEqualTo(Optional.of(X("hello"))) + } + + @Test + fun `delegation of _content`() { + // Input and output are different types, so this test is an exceptional case. + // `_content()` delegates to `_content()` indirectly via the `content` field initializer. + val input = JsonField.of("{\"s\" : \"hello\"}") + `when`(mockDelegate._content()).thenReturn(input) + val output = delegator._content() // With an underscore. 
+ + verify(mockDelegate, times(1))._content() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isEqualTo(JsonField.of(X("hello"))) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt new file mode 100644 index 00000000..af380bbf --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/chat/completions/StructuredChatCompletionTest.kt @@ -0,0 +1,405 @@ +package com.openai.models.chat.completions + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.errors.OpenAIInvalidDataException +import java.util.Optional +import kotlin.reflect.KClass +import kotlin.reflect.KVisibility +import kotlin.reflect.full.declaredFunctions +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito.mock +import org.mockito.Mockito.verifyNoMoreInteractions +import org.mockito.Mockito.`when` +import org.mockito.kotlin.times +import org.mockito.kotlin.verify + +/** + * Unit tests for the [StructuredChatCompletion] class (delegator) and its delegation of most + * functions to a wrapped [ChatCompletion] (delegate). The tests include confirmation of the + * following: + * - All functions in the delegator correspond to a function in the delegate and _vice versa_. + * - All functions in the delegator call their corresponding function in the delegate and only that + * function. + * - A unit test exists for all functions. + * + * There are some exceptions to the above that are handled differently. 
+ */ +internal class StructuredChatCompletionTest { + companion object { + internal fun checkAllDelegation( + delegateClass: KClass<*>, + delegatorClass: KClass<*>, + vararg exceptFunctionNames: String, + ) { + assertThat(delegateClass != delegatorClass) + .describedAs { "Delegate and delegator classes should not be the same." } + .isTrue + + val delegateFunctions = delegateClass.declaredFunctions + + for (delegateFunction in delegateFunctions) { + if (delegateFunction.visibility != KVisibility.PUBLIC) { + // Non-public methods are just implementation details of each class. + continue + } + + if (delegateFunction.name in exceptFunctionNames) { + // Ignore functions that are known exceptions (e.g., `toBuilder`). + continue + } + + // Drop the first parameter from each function, as it is the implicit "this" object + // and has the type of the class declaring the function, which will never match. + val delegatorFunction = + delegatorClass.declaredFunctions.find { + it.name == delegateFunction.name && + it.parameters.drop(1).map { it.type } == + delegateFunction.parameters.drop(1).map { it.type } + } + + assertThat(delegatorFunction != null) + .describedAs { + "Function $delegateFunction is not found in ${delegatorClass.simpleName}." 
+ } + .isTrue + } + } + + internal fun checkOneDelegationRead( + delegator: Any, + mockDelegate: Any, + testCase: DelegationReadTestCase, + ) { + // Stub the method in the mock delegate using reflection + val delegateMethod = mockDelegate::class.java.getMethod(testCase.functionName) + `when`(delegateMethod.invoke(mockDelegate)).thenReturn(testCase.expectedValue) + + // Call the corresponding method on the delegator using reflection + val delegatorMethod = delegator::class.java.getMethod(testCase.functionName) + val result = delegatorMethod.invoke(delegator) + + // Verify that the corresponding method on the mock delegate was called exactly once + verify(mockDelegate, times(1)).apply { delegateMethod.invoke(mockDelegate) } + verifyNoMoreInteractions(mockDelegate) + + // Assert that the result matches the expected value + assertThat(result).isEqualTo(testCase.expectedValue) + } + + // Where a function returns `Optional<T>`, `JsonField<T>` or `JsonValue`, there is no need to + // provide a value that matches the type `<T>`; a simple `String` value of `"a-string"` will + // work OK with the test. Constants have been provided for this purpose. + internal const val STRING = "a-string" + + internal val OPTIONAL = Optional.of(STRING) + internal val JSON_FIELD = JsonField.of(STRING) + internal val JSON_VALUE = JsonValue.from(STRING) + internal val MESSAGE = + ChatCompletionMessage.builder().content(STRING).refusal(STRING).build() + private val FINISH_REASON = ChatCompletion.Choice.FinishReason.STOP + private val CHOICE = + ChatCompletion.Choice.builder() + .message(MESSAGE) + .index(0L) + .finishReason(FINISH_REASON) + .logprobs( + ChatCompletion.Choice.Logprobs.builder().content(null).refusal(null).build() + ) + .build() + + data class DelegationReadTestCase(val functionName: String, val expectedValue: Any) + + // The list order follows the declaration order in `StructuredChatCompletion` for + // easier maintenance.
+ @JvmStatic + fun delegationTestCases() = + listOf( + DelegationReadTestCase("id", STRING), + // `choices()` is a special case and has its own test function. + DelegationReadTestCase("created", 123L), + DelegationReadTestCase("model", STRING), + DelegationReadTestCase("_object_", JSON_VALUE), + DelegationReadTestCase("serviceTier", OPTIONAL), + DelegationReadTestCase("systemFingerprint", OPTIONAL), + DelegationReadTestCase("usage", OPTIONAL), + DelegationReadTestCase("_id", JSON_FIELD), + // `_choices()` is a special case and has its own test function. + DelegationReadTestCase("_created", JSON_FIELD), + DelegationReadTestCase("_model", JSON_FIELD), + DelegationReadTestCase("_serviceTier", JSON_FIELD), + DelegationReadTestCase("_systemFingerprint", JSON_FIELD), + DelegationReadTestCase("_usage", JSON_FIELD), + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + + @JvmStatic + fun choiceDelegationTestCases() = + listOf( + DelegationReadTestCase("finishReason", FINISH_REASON), + DelegationReadTestCase("index", 123L), + DelegationReadTestCase("logprobs", OPTIONAL), + DelegationReadTestCase("_finishReason", JSON_FIELD), + // `message()` is a special case and has its own test function. + DelegationReadTestCase("_index", JSON_FIELD), + DelegationReadTestCase("_logprobs", JSON_FIELD), + // `_message()` is a special case and has its own test function. + DelegationReadTestCase("_additionalProperties", mapOf("key" to JSON_VALUE)), + // `validate()` and `isValid()` (which calls `validate()`) are tested separately, + // as they require special handling. + ) + + /** A basic class used as the generic type when testing. */ + internal class X(val s: String) { + override fun equals(other: Any?) 
= other is X && other.s == s + + override fun hashCode() = s.hashCode() + } + } + + // New instances of the `mockDelegate` and `delegator` are required for each test case (each + // test case runs in its own instance of the test class). + val mockDelegate: ChatCompletion = mock(ChatCompletion::class.java) + val delegator = StructuredChatCompletion(X::class.java, mockDelegate) + + val mockChoiceDelegate: ChatCompletion.Choice = mock(ChatCompletion.Choice::class.java) + val choiceDelegator = StructuredChatCompletion.Choice(X::class.java, mockChoiceDelegate) + + @Test + fun allChatCompletionDelegateFunctionsExistInDelegator() { + checkAllDelegation(ChatCompletion::class, StructuredChatCompletion::class, "toBuilder") + } + + @Test + fun allChatCompletionDelegatorFunctionsExistInDelegate() { + checkAllDelegation(StructuredChatCompletion::class, ChatCompletion::class) + } + + @Test + fun allChoiceDelegateFunctionsExistInDelegator() { + checkAllDelegation( + ChatCompletion.Choice::class, + StructuredChatCompletion.Choice::class, + "toBuilder", + ) + } + + @Test + fun allChoiceDelegatorFunctionsExistInDelegate() { + checkAllDelegation(StructuredChatCompletion.Choice::class, ChatCompletion.Choice::class) + } + + @Test + fun allDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. + val exceptionalTestedFns = setOf("choices", "_choices", "validate", "isValid") + val testedFns = delegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns + // A few delegator functions do not delegate, so no test function is necessary. 
+ val nonDelegatingFns = listOf("equals", "hashCode", "toString") + + val delegatorFunctions = StructuredChatCompletion::class.declaredFunctions + + for (delegatorFunction in delegatorFunctions) { + assertThat( + delegatorFunction.name in testedFns || + delegatorFunction.name in nonDelegatingFns + ) + .describedAs("Delegation is not tested for function '${delegatorFunction.name}.") + .isTrue + } + } + + @Test + fun allChoiceDelegatorFunctionsAreTested() { + // There are exceptional test cases for some functions. Most other functions are part of the + // list of those using the parameterized test. + val exceptionalTestedFns = setOf("message", "_message", "validate", "isValid") + val testedFns = + choiceDelegationTestCases().map { it.functionName }.toSet() + exceptionalTestedFns + // A few delegator functions do not delegate, so no test function is necessary. + val nonDelegatingFns = listOf("equals", "hashCode", "toString") + + val delegatorFunctions = StructuredChatCompletion.Choice::class.declaredFunctions + + for (delegatorFunction in delegatorFunctions) { + assertThat( + delegatorFunction.name in testedFns || + delegatorFunction.name in nonDelegatingFns + ) + .describedAs( + "Delegation is not tested for function 'Choice.${delegatorFunction.name}." + ) + .isTrue + } + } + + @ParameterizedTest + @MethodSource("delegationTestCases") + fun `delegation of functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(delegator, mockDelegate, testCase) + } + + @ParameterizedTest + @MethodSource("choiceDelegationTestCases") + fun `delegation of Choice functions in general`(testCase: DelegationReadTestCase) { + checkOneDelegationRead(choiceDelegator, mockChoiceDelegate, testCase) + } + + @Test + fun `delegation of choices`() { + // Input and output are different types, so this test is an exceptional case. + // `choices()` (without an underscore) delegates to `_choices()` (with an underscore) + // indirectly via the `choices` field initializer. 
+ val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.choices() // Without an underscore. + + verify(mockDelegate, times(1))._choices() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output[0].choice).isEqualTo(CHOICE) + } + + @Test + fun `delegation of _choices`() { + // Input and output are different types, so this test is an exceptional case. + // `_choices()` delegates to `_choices()` indirectly via the `choices` field initializer. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator._choices() // With an underscore. + + verify(mockDelegate, times(1))._choices() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output.getRequired("_choices")[0].choice).isEqualTo(CHOICE) + } + + @Test + fun `delegation of validate`() { + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.validate() + + // `validate()` calls `choices()` on the delegator which triggers the lazy initializer which + // calls `_choices()` on the delegate before `validate()` also calls `validate()` on the + // delegate. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isSameAs(delegator) + } + + @Test + fun `delegation of isValid when true`() { + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + val output = delegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of isValid when false`() { + // Try with a `false` value to make sure `isValid()` is not just hard-coded to `true`. 
Do + // this by making `validate()` on the delegate throw an exception. + val input = JsonField.of(listOf(CHOICE)) + `when`(mockDelegate._choices()).thenReturn(input) + `when`(mockDelegate.validate()).thenThrow(OpenAIInvalidDataException("test")) + val output = delegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockDelegate, times(1))._choices() + verify(mockDelegate, times(1)).validate() + verifyNoMoreInteractions(mockDelegate) + + assertThat(output).isFalse + } + + @Test + fun `delegation of Choice-message`() { + // Input and output are different types, so this test is an exceptional case. + // `message()` (without an underscore) delegates to `_message()` (with an underscore) + // indirectly via the `message` field initializer. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.message() // Without an underscore. + + verify(mockChoiceDelegate, times(1))._message() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output.chatCompletionMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of Choice-_message`() { + // Input and output are different types, so this test is an exceptional case. + // `_message()` delegates to `_message()` indirectly via the `message` field initializer. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator._message() // With an underscore. 
+ + verify(mockChoiceDelegate, times(1))._message() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output.getRequired("_message").chatCompletionMessage).isEqualTo(MESSAGE) + } + + @Test + fun `delegation of Choice-validate`() { + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.validate() + + // `validate()` calls `message()` on the delegator which triggers the lazy initializer which + // calls `_message()` on the delegate before `validate()` also calls `validate()` on the + // delegate. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isSameAs(choiceDelegator) + } + + @Test + fun `delegation of Choice-isValid when true`() { + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + val output = choiceDelegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. + verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isTrue + } + + @Test + fun `delegation of Choice-isValid when false`() { + // Try with a `false` value to make sure `isValid()` is not just hard-coded to `true`. Do + // this by making `validate()` on the delegate throw an exception. + val input = JsonField.of(MESSAGE) + `when`(mockChoiceDelegate._message()).thenReturn(input) + `when`(mockChoiceDelegate.validate()).thenThrow(OpenAIInvalidDataException("test")) + val output = choiceDelegator.isValid() + + // `isValid()` calls `validate()`, which has side effects explained in its test function. 
+ verify(mockChoiceDelegate, times(1))._message() + verify(mockChoiceDelegate, times(1)).validate() + verifyNoMoreInteractions(mockChoiceDelegate) + + assertThat(output).isFalse + } +} diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java new file mode 100644 index 00000000..bcc46a80 --- /dev/null +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java @@ -0,0 +1,73 @@ +package com.openai.example; + +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; +import java.util.List; + +public final class StructuredOutputsClassExample { + + public static class Person { + public String firstName; + public String surname; + + @JsonPropertyDescription("The date of birth of the person.") + public String dateOfBirth; + + @Override + public String toString() { + return "Person{firstName=" + firstName + ", surname=" + surname + ", dateOfBirth=" + dateOfBirth + '}'; + } + } + + public static class Laureate { + public Person person; + public String majorAchievement; + public int yearWon; + + @JsonPropertyDescription("The share of the prize money won by the Nobel Laureate.") + public double prizeMoney; + + @Override + public String toString() { + return "Laureate{person=" + + person + ", majorAchievement=" + + majorAchievement + ", yearWon=" + + yearWon + ", prizeMoney=" + + prizeMoney + '}'; + } + } + + public static class Laureates { + @JsonPropertyDescription("A list of winners of a Nobel Prize.") + public List laureates; + + @Override + public String toString() { + return "Laureates{laureates=" + laureates + '}'; + } + 
} + + private StructuredOutputsClassExample() {} + + public static void main(String[] args) { + // Configures using one of: + // - The `OPENAI_API_KEY` environment variable + // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables + OpenAIClient client = OpenAIOkHttpClient.fromEnv(); + + StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + .model(ChatModel.GPT_4O_MINI) + .maxCompletionTokens(2048) + .responseFormat(Laureates.class) + .addUserMessage("List some winners of the Nobel Prize in Physics since 2000.") + .build(); + + client.chat().completions().create(createParams).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .forEach(System.out::println); + } +} From 984fe7d2728e5e6651fbf18417c5b5e18d246f7c Mon Sep 17 00:00:00 2001 From: D Gardner Date: Fri, 2 May 2025 16:40:52 +0100 Subject: [PATCH 2/5] structured-outputs: repair after bad merge. --- .../com/openai/core/JsonSchemaValidator.kt | 670 ++++++++++++++++++ .../openai/core/JsonSchemaValidatorTest.kt | 3 +- 2 files changed, 672 insertions(+), 1 deletion(-) create mode 100644 openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt new file mode 100644 index 00000000..6af40929 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt @@ -0,0 +1,670 @@ +package com.openai.core + +import com.fasterxml.jackson.databind.JsonNode +import com.openai.core.JsonSchemaValidator.Companion.MAX_ENUM_TOTAL_STRING_LENGTH +import com.openai.core.JsonSchemaValidator.Companion.UNRESTRICTED_ENUM_VALUES_LIMIT + +/** + * A validator that ensures that a JSON schema complies with the rules and restrictions imposed by + * the OpenAI API specification for the input schemas used to define structured outputs. 
Only a + * subset of the JSON Schema language is supported. The purpose of this validator is to perform a + * quick check of a schema so that it can be determined to be likely to be accepted when passed in + * the request for an AI inference. + * + * This validator assumes that the JSON schema represents the structure of Java/Kotlin classes; it + * is not a general-purpose JSON schema validator. Assumptions are also made that the generator will + * be well-behaved, so the validation is not a check for strict conformance to the JSON Schema + * specification, but to the OpenAI API specification's restrictions on JSON schemas. + */ +internal class JsonSchemaValidator private constructor() { + + companion object { + // The names of the supported schema keywords. All other keywords will be rejected. + private const val SCHEMA = "\$schema" + private const val ID = "\$id" + private const val DEFS = "\$defs" + private const val REF = "\$ref" + private const val PROPS = "properties" + private const val ANY_OF = "anyOf" + private const val TYPE = "type" + private const val REQUIRED = "required" + private const val DESC = "description" + private const val TITLE = "title" + private const val ITEMS = "items" + private const val CONST = "const" + private const val ENUM = "enum" + private const val ADDITIONAL_PROPS = "additionalProperties" + + // The names of the supported schema data types. + // + // JSON Schema does not define an "integer" type, only a "number" type, but it allows any + // schema to define its own "vocabulary" of type names. "integer" is supported by OpenAI. + private const val TYPE_ARRAY = "array" + private const val TYPE_OBJECT = "object" + private const val TYPE_BOOLEAN = "boolean" + private const val TYPE_STRING = "string" + private const val TYPE_NUMBER = "number" + private const val TYPE_INTEGER = "integer" + private const val TYPE_NULL = "null" + + // The validator checks that unsupported type-specific keywords are not present in a + // property node. 
The OpenAI API specification states: + // + // "Notable keywords not supported include: + // + // - For strings: `minLength`, `maxLength`, `pattern`, `format` + // - For numbers: `minimum`, `maximum`, `multipleOf` + // - For objects: `patternProperties`, `unevaluatedProperties`, `propertyNames`, + // `minProperties`, `maxProperties` + // - For arrays: `unevaluatedItems`, `contains`, `minContains`, `maxContains`, `minItems`, + // `maxItems`, `uniqueItems`" + // + // As that list is not exhaustive, and no keywords are explicitly named as supported, this + // validation allows _no_ type-specific keywords. The following sets define the allowed + // keywords in different contexts and all others are rejected. + + /** + * The set of allowed keywords in the root schema only, not including the keywords that are + * also allowed in a sub-schema. + */ + private val ALLOWED_KEYWORDS_ROOT_SCHEMA_ONLY = setOf(SCHEMA, ID, DEFS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"anyOf"` field is + * present. OpenAI allows the `"anyOf"` field in sub-schemas, but not in the root schema. + */ + private val ALLOWED_KEYWORDS_ANY_OF_SUB_SCHEMA = setOf(ANY_OF, TITLE, DESC) + + /** + * The set of allowed keywords when defining sub-schemas when the `"$ref"` field is present. + */ + private val ALLOWED_KEYWORDS_REF_SUB_SCHEMA = setOf(REF, TITLE, DESC) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"object"`. + */ + private val ALLOWED_KEYWORDS_OBJECT_SUB_SCHEMA = + setOf(TYPE, REQUIRED, ADDITIONAL_PROPS, TITLE, DESC, PROPS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"array"`. + */ + private val ALLOWED_KEYWORDS_ARRAY_SUB_SCHEMA = setOf(TYPE, TITLE, DESC, ITEMS) + + /** + * The set of allowed keywords when defining sub-schemas when the `"type"` field is set to + * `"boolean"`, `"integer"`, `"number"`, or `"string"`. 
+ */ + private val ALLOWED_KEYWORDS_SIMPLE_SUB_SCHEMA = setOf(TYPE, TITLE, DESC, ENUM, CONST) + + /** + * The maximum total length of all strings used in the schema for property names, definition + * names, enum values and const values. The OpenAI specification states: + * > In a schema, total string length of all property names, definition names, enum values, + * > and const values cannot exceed 15,000 characters. + */ + private const val MAX_TOTAL_STRING_LENGTH = 15_000 + + /** The maximum number of object properties allowed in a schema. */ + private const val MAX_OBJECT_PROPERTIES = 100 + + /** The maximum number of enum values across all enums in the schema. */ + private const val MAX_ENUM_VALUES = 500 + + /** + * The number of enum values in any one enum with string values beyond which a limit of + * [MAX_ENUM_TOTAL_STRING_LENGTH] is imposed on the total length of all the string values of + * that one enum. + */ + private const val UNRESTRICTED_ENUM_VALUES_LIMIT = 250 + + /** + * The maximum total length of all string values of a single enum where the number of values + * exceeds [UNRESTRICTED_ENUM_VALUES_LIMIT]. + */ + private const val MAX_ENUM_TOTAL_STRING_LENGTH = 7_500 + + /** The maximum depth (number of levels) of nesting allowed in a schema. */ + private const val MAX_NESTING_DEPTH = 5 + + /** The depth value that corresponds to the root level of the schema. */ + private const val ROOT_DEPTH = 0 + + /** + * The path string that identifies the root node in the schema when appearing in error + * messages or references. + */ + private const val ROOT_PATH = "#" + + /** + * Creates a new [JsonSchemaValidator]. After calling [validate], the validator instance + * holds information about the errors that occurred during validation (if any). A validator + * instance can be used only once to validate a schema; to validate another schema, create + * another validator. 
+ */ + fun create() = JsonSchemaValidator() + } + + /** + * The total length of all strings used in the schema for property names, definition names, enum + * values and const values. + */ + private var totalStringLength: Int = 0 + + /** The total number of values across all enums in the schema. */ + private var totalEnumValues: Int = 0 + + /** The total number of object properties found in the schema, including in definitions. */ + private var totalObjectProperties: Int = 0 + + /** + * The set of valid references that may appear in the schema. This set includes the root schema + * and any definitions within the root schema. This is used to verify that references elsewhere + * in the schema are valid. This will always contain the root schema, but that may be the only + * member. + */ + private val validReferences: MutableSet<String> = mutableSetOf(ROOT_PATH) + + /** The list of error messages accumulated during the validation process. */ + private val errors: MutableList<String> = mutableListOf() + + /** + * Indicates if this validator has validated a schema or not. If a schema has been validated, + * this validator cannot be used again. + */ + private var isValidationComplete = false + + /** + * Gets the list of errors that were recorded during the validation pass. + * + * @return The list of errors. The list may be empty if no errors were recorded. In that case, + * the schema was found to be valid, or has not yet been validated by calling [validate]. + */ + fun errors(): List<String> = errors.toImmutable() + + /** + * Indicates if a validated schema is valid or not. + * + * @return `true` if a schema has been validated by calling [validate] and no errors were + * reported; or `false` if errors were reported or if a schema has not yet been validated. + */ + fun isValid(): Boolean = isValidationComplete && errors.isEmpty() + + /** + * Validates a schema with respect to the OpenAI API specifications.
+ * + * @param rootSchema The root node of the tree representing the JSON schema definition. + * @return This schema validator for convenience, such as when chaining calls. + * @throws IllegalStateException If called a second time. Create a new validator to validate + * each new schema. + */ + fun validate(rootSchema: JsonNode): JsonSchemaValidator { + if (isValidationComplete) { + throw IllegalStateException("Validation already complete.") + } + isValidationComplete = true + + validateSchema(rootSchema, ROOT_PATH, ROOT_DEPTH) + + // Verify total counts/lengths. These are not localized to a specific element in the schema, + // as no one element is the cause of the error; it is the combination of all elements that + // exceed the limits. Therefore, the root path is used in the error messages. + verify(totalEnumValues <= MAX_ENUM_VALUES, ROOT_PATH) { + "Total number of enum values ($totalEnumValues) exceeds limit of $MAX_ENUM_VALUES." + } + verify(totalStringLength <= MAX_TOTAL_STRING_LENGTH, ROOT_PATH) { + "Total string length of all values ($totalStringLength) exceeds " + + "limit of $MAX_TOTAL_STRING_LENGTH." + } + verify(totalObjectProperties <= MAX_OBJECT_PROPERTIES, ROOT_PATH) { + "Total number of object properties ($totalObjectProperties) exceeds " + + "limit of $MAX_OBJECT_PROPERTIES." + } + + return this + } + + /** + * Validates a schema. This may be the root schema or a sub-schema. Some validations are + * specific to the root schema, which is identified by the [depth] being equal to zero. + * + * This method is recursive: it will validate the given schema and any sub-schemas that it + * contains at any depth. References to other schemas (either the root schema or definition + * sub-schemas) do not increase the depth of nesting, as those references are not followed + * recursively, only checked to be valid internal schema references. + * + * @param schema The schema to be validated. This may be the root schema or any sub-schema. 
+ * @param path The path that identifies the location of this schema within the JSON schema. For + * example, for the root schema, this will be `"#"`; for a definition sub-schema of a `Person` + * object, this will be `"#/$defs/Person"`. + * @param depth The current depth of nesting. The OpenAI API specification places a maximum + * limit on the depth of nesting, which will result in an error if it is exceeded. The nesting + * depth increases with each recursion into a nested sub-schema. For the root schema, the + * nesting depth is zero; all other sub-schemas will have a nesting depth greater than zero. + */ + private fun validateSchema(schema: JsonNode, path: String, depth: Int) { + verify(depth <= MAX_NESTING_DEPTH, path) { + "Current nesting depth is $depth, but maximum is $MAX_NESTING_DEPTH." + } + + verify(schema.isObject, path, { "Schema or sub-schema is not an object." }) { + // If the schema is not an object, perform no further validations. + return + } + + verify(!schema.isEmpty, path) { "Schema or sub-schema is empty." } + + if (depth == ROOT_DEPTH) { + // Sanity check for the presence of the "$schema" field, as this makes it more likely + // that the schema with `depth == 0` is actually the root node of a JSON schema, not + // just a generic JSON document that is being validated in error. + verify(schema.get(SCHEMA) != null, path) { "Root schema missing '$SCHEMA' field." } + } + + // Before sub-schemas can be validated, the list of definitions must be recorded to ensure + // that "$ref" references can be checked for validity. Definitions are optional and only + // appear in the root schema. + validateDefinitions(schema.get(DEFS), "$path/$DEFS", depth) + + val anyOf = schema.get(ANY_OF) + val type = schema.get(TYPE) + val ref = schema.get(REF) + + verify( + listOf(anyOf, type, ref).count { it != null } == 1, + path, + { "Expected exactly one of '$TYPE' or '$ANY_OF' or '$REF'."
}, + ) { + // Validation cannot continue if none are set, or if more than one is set. + return + } + + validateAnyOfSchema(schema, path, depth) + validateTypeSchema(schema, path, depth) + validateRefSchema(schema, path, depth) + } + + /** + * Validates a schema if it has an `"anyOf"` field. OpenAI does not support the use of `"anyOf"` + * at the root of a JSON schema. The value of the field is expected to be an array of valid + * sub-schemas. If the schema has no `"anyOf"` field, no action is taken. + */ + private fun validateAnyOfSchema(schema: JsonNode, path: String, depth: Int) { + val anyOf = schema.get(ANY_OF) + + if (anyOf == null) return + + validateKeywords(schema, ALLOWED_KEYWORDS_ANY_OF_SUB_SCHEMA, path, depth) + + verify( + anyOf.isArray && !anyOf.isEmpty, + path, + { "'$ANY_OF' field is not a non-empty array." }, + ) { + return + } + + // Validates that the root schema does not contain an `anyOf` field. This is a restriction + // imposed by the OpenAI API specification. `anyOf` fields _can_ appear at other depths. + verify(depth != ROOT_DEPTH, path) { "Root schema contains '$ANY_OF' field." } + + // Each entry must be a valid sub-schema. + anyOf.forEachIndexed { index, subSchema -> + validateSchema(subSchema, "$path/$ANY_OF[$index]", depth + 1) + } + } + + /** + * Validates a schema if it has a `"$ref"` field. The reference is checked to ensure it + * corresponds to a valid definition, or is a reference to the root schema. Recursive references + * are allowed. If no `"$ref"` field is found in the schema, no action is taken. + */ + private fun validateRefSchema(schema: JsonNode, path: String, depth: Int) { + val ref = schema.get(REF) + + if (ref == null) return + + validateKeywords(schema, ALLOWED_KEYWORDS_REF_SUB_SCHEMA, path, depth) + + val refPath = "$path/$REF" + + verify(ref.isTextual, refPath, { "'$REF' field is not a text value." }) { + // No point checking the reference has a referent if it is definitely malformed.
+ return + } + verify(ref.asText() in validReferences, refPath) { + "Invalid or unsupported reference: '${ref.asText()}'." + } + } + + /** + * Validates a schema if it has a `"type"` field. This includes most sub-schemas, except those + * that have a `"$ref"` or `"anyOf"` field instead. The `"type"` field may be set to a text + * value that is the name of the type (e.g., `"object"`, `"array"`, `"number"`), or it may be + * set to an array that contains two text values: the name of the type and `"null"`. The OpenAI + * API specification explains that this is how a property can be both required (i.e., it must + * appear in the JSON document), but its value can be optional (i.e., it can be set explicitly + * to `"null"`). If the schema has no `"type"` field, no action is taken. + */ + private fun validateTypeSchema(schema: JsonNode, path: String, depth: Int) { + val type = schema.get(TYPE) + + if (type == null) return + + val typeName = + if (type.isTextual) { + // Type will be something like `"type" : "string"` + type.asText() + } else if (type.isArray) { + // Type will be something like `"type" : [ "string", "null" ]`. This corresponds to + // the use of "Optional" in Java/Kotlin. + getTypeNameFromTypeArray(type, "$path/$TYPE") + } else { + error(path) { "'$TYPE' field is not a type name or array of type names." } + return + } + + when (typeName) { + TYPE_ARRAY -> validateArraySchema(schema, path, depth) + TYPE_OBJECT -> validateObjectSchema(schema, path, depth) + + TYPE_BOOLEAN, + TYPE_INTEGER, + TYPE_NUMBER, + TYPE_STRING -> validateSimpleSchema(schema, typeName, path, depth) + + // The type name could not be determined from a type name array. An error will already + // have been logged by `getTypeNameFromTypeArray`, so no need to do anything more here. + null -> return + + else -> error("$path/$TYPE") { "Unsupported '$TYPE' value: '$typeName'." } + } + } + + /** + * Validates a schema whose `"type"` is `"object"`. 
It is the responsibility of the caller to + * ensure that [schema] contains that type definition. If no type, or a different type is + * present, the behavior is not defined. + */ + private fun validateObjectSchema(schema: JsonNode, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_OBJECT_SUB_SCHEMA, path, depth) + + // The schema must declare that additional properties are not allowed. For this check, it + // does not matter if there are no "properties" in the schema. + verify( + schema.get(ADDITIONAL_PROPS)?.isBoolean == true && + !schema.get(ADDITIONAL_PROPS).booleanValue(), + path, + ) { + "'$ADDITIONAL_PROPS' field is missing or is not set to 'false'." + } + + val properties = schema.get(PROPS) + + // The "properties" field may be missing (there may be no properties to declare), but if it + // is present, it must be a non-empty object, or validation cannot continue. + // TODO: Decide if a missing or empty "properties" field is OK or not. + verify( + properties == null || (properties.isObject && !properties.isEmpty), + path, + { "'$PROPS' field is not a non-empty object." }, + ) { + return + } + + if (properties != null) { // Must be an object. + // If a "properties" field is present, there must also be a "required" field. All + // properties must be named in the list of required properties. + validatePropertiesRequired( + properties.fieldNames().asSequence().toSet(), + schema.get(REQUIRED), + "$path/$REQUIRED", + ) + validateProperties(properties, "$path/$PROPS", depth) + } + } + + /** + * Validates a schema whose `"type"` is `"array"`. It is the responsibility of the caller to + * ensure that [schema] contains that type definition. If no type, or a different type is + * present, the behavior is not defined. + * + * An array schema must have an `"items"` field whose value is an object representing a valid + * sub-schema.
+ */ + private fun validateArraySchema(schema: JsonNode, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_ARRAY_SUB_SCHEMA, path, depth) + + val items = schema.get(ITEMS) + + verify( + items != null && items.isObject, + path, + { "'$ITEMS' field is missing or is not an object." }, + ) { + return + } + + validateSchema(items, "$path/$ITEMS", depth + 1) + } + + /** + * Validates a schema whose `"type"` is one of the supported simple type names other than + * `"object"` and `"array"`. It is the responsibility of the caller to ensure that [schema] + * contains the correct type definition. If no type, or a different type is present, the + * behavior is not defined. + * + * @param typeName The name of the specific type of the schema. Where the field value is + * optional and the type is defined as an array of a type name and a `"null"`, this is the + * value of the non-`"null"` type name. For example `"string"`, or `"number"`. + */ + private fun validateSimpleSchema(schema: JsonNode, typeName: String, path: String, depth: Int) { + validateKeywords(schema, ALLOWED_KEYWORDS_SIMPLE_SUB_SCHEMA, path, depth) + + val enumField = schema.get(ENUM) + + // OpenAI API specification: "For a single enum property with string values, the total + // string length of all enum values cannot exceed 7,500 characters when there are more than + // 250 enum values." + val isString = typeName == TYPE_STRING + var numEnumValues = 0 + var stringLength = 0 + + enumField?.forEach { value -> + // OpenAI places limits on the total string length of all enum values across all enums + // without being specific about the type of those enums (unlike for enums with string + // values, which have their own restrictions noted above). The specification does not + // indicate how to count the string length for boolean or number values. Here it is + // assumed that their simple string representations should be counted. 
+ val length = value.asText().length + + totalStringLength += length + totalEnumValues++ + + if (isString) { + numEnumValues++ + stringLength += length + } + } + + verify( + !isString || + numEnumValues <= UNRESTRICTED_ENUM_VALUES_LIMIT || + stringLength <= MAX_ENUM_TOTAL_STRING_LENGTH, + "$path/$ENUM", + ) { + "Total string length ($stringLength) of values of an enum with $numEnumValues " + + "values exceeds limit of $MAX_ENUM_TOTAL_STRING_LENGTH." + } + + schema.get(CONST)?.let { constValue -> totalStringLength += constValue.asText().length } + } + + /** + * Validates that the definitions (if present) contain fields that each define a valid schema. + * Records the names of any definitions to construct the set of possible valid references to + * those definitions. This set will be used to validate any references from within definition + * sub-schemas, or any other sub-schemas validated at a later time. + * + * @param defs The node containing the definitions. Definitions are optional, so this node may + * be `null`. Definitions may appear in the root schema, but will not appear in any + * sub-schemas. If no definitions are present, the list of valid references will not be + * changed and no errors will be recorded. + * @param path The path that identifies the location within the schema of the `"$defs"` node. + * @param depth The current depth of nesting. If definitions are present, this will be zero, as + * that is the depth of the root schema. + */ + private fun validateDefinitions(defs: JsonNode?, path: String, depth: Int) { + // Definitions are optional. If present, expect an object whose fields are named from the + // classes the definitions were extracted from. If not present, do not continue. + verify(defs == null || defs.isObject, path, { "'$DEFS' must be an object." 
}) { + return + } + + // First, record the valid references to definitions, as any definition sub-schema may + // contain a reference to any other definitions sub-schema (including itself) and those + // references need to be validated. + defs?.fieldNames()?.asSequence()?.forEach { defName -> + val reference = "$path/$defName" + + // Consider that there might be duplicate definition names if two different classes + // (from different packages) have the same simple name. That would be an error, but + // there is no need to stop the validations. + // TODO: How should duplicate names be handled? Will the generator use longer names? + verify(reference !in validReferences, path) { "Duplicate definition of '$defName'." } + validReferences.add(reference) + } + + // Second, recursively validate the definition sub-schemas. + defs?.fieldNames()?.asSequence()?.forEach { defName -> + totalStringLength += defName.length + validateSchema(defs.get(defName), "$path/$defName", depth + 1) + } + } + + /** + * Validates that every property in a collection of property names appears in the array of + * property names in a `"required"` field. + * + * @param propertyNames The collection of property names to check in the array of required + * properties. This collection will not be empty. + * @param required The `"required"` field. This is expected to be a non-`null` array field with + * a set of property names. + * @param path The path identifying the location of the `"required"` field within the schema. + */ + private fun validatePropertiesRequired( + propertyNames: Collection<String>, + required: JsonNode?, + path: String, + ) { + val requiredNames = required?.map { it.asText() }?.toSet() ?: emptySet() + + propertyNames.forEach { propertyName -> + verify(propertyName in requiredNames, path) { + "'$PROPS' field '$propertyName' is not listed as '$REQUIRED'."
+ } + } + } + + /** + * Validates that each named entry in the `"properties"` field of an object schema has a value + * that is a valid sub-schema. + * + * @param properties The `"properties"` field node of an object schema. + * @param path The path identifying the location of the `"properties"` field within the schema. + */ + private fun validateProperties(properties: JsonNode, path: String, depth: Int) { + val propertyNames = properties.fieldNames().asSequence().toList() + + propertyNames.forEach { propertyName -> + totalObjectProperties++ + totalStringLength += propertyName.length + validateSchema(properties.get(propertyName), "$path/$propertyName", depth + 1) + } + } + + /** + * Validates that the names of all fields in the given schema node are present in a collection + * of allowed keywords. + * + * @param depth The nesting depth of the [schema] node. If this depth is zero, an additional set + * of allowed keywords will be included automatically for keywords that are allowed to appear + * only at the root level of the schema (e.g., `"$schema"`, `"$defs"`). + */ + private fun validateKeywords( + schema: JsonNode, + allowedKeywords: Collection, + path: String, + depth: Int, + ) { + schema.fieldNames().forEach { keyword -> + verify( + keyword in allowedKeywords || + (depth == ROOT_DEPTH && keyword in ALLOWED_KEYWORDS_ROOT_SCHEMA_ONLY), + path, + ) { + "Use of '$keyword' is not supported here." + } + } + } + + /** + * Gets the name of a type from the given `"type"` field, where the field is an array that + * contains exactly two string values: a type name and a `"null"` (in any order). + * + * @param type The type node. This must be a field with an array value. If this is not an array + * field, the behavior is undefined. It is the responsibility of the caller to ensure that + * this function is only called for array fields. 
+ * @return The type name in the array that is not the `"null"` type; or `null` if no such type + * name was found, or if the array does not contain exactly two expected values: the type name + * and a `"null"` type. If `null`, one or more validation errors will be recorded. + */ + private fun getTypeNameFromTypeArray(type: JsonNode, path: String): String? { + val types = type.asSequence().toList() + + if (types.size == 2 && types.all { it.isTextual }) { + // Allow one type name and one "null". Be lenient about the order. Assume that there are + // no oddities like type names that are empty strings, etc., as the schemas are expected + // to be generated. + if (types[1].asText() == TYPE_NULL && types[0].asText() != TYPE_NULL) { + return types[0].asText() + } else if (types[0].asText() == TYPE_NULL && types[1].asText() != TYPE_NULL) { + return types[1].asText() + } else { + error(path) { "Expected one type name and one \"$TYPE_NULL\"." } + } + } else { + error(path) { "Expected exactly two types, both strings." 
} + } + + return null + } + + private inline fun verify(value: Boolean, path: String, lazyMessage: () -> Any) { + verify(value, path, lazyMessage) {} + } + + private inline fun verify( + value: Boolean, + path: String, + lazyMessage: () -> Any, + onFalse: () -> Unit, + ) { + if (!value) { + error(path, lazyMessage) + onFalse() + } + } + + private inline fun error(path: String, lazyMessage: () -> Any) { + errors.add("$path: ${lazyMessage()}") + } + + override fun toString(): String = + "${javaClass.simpleName}{isValidationComplete=$isValidationComplete, " + + "totalStringLength=$totalStringLength, " + + "totalObjectProperties=$totalObjectProperties, " + + "totalEnumValues=$totalEnumValues, errors=$errors}" +} diff --git a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt index 31768c04..ccbc3926 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt @@ -82,7 +82,8 @@ internal class JsonSchemaValidatorTest { assertThat(validator.isValid()).isTrue } - @Test + // FIXME: Disabled test until issues (noted below) are resolved. 
+ // @Test fun schemaTest_minimalListSchema() { val s: List = listOf() From 60b708dca2622db51652b785c4d0275c35f06a8d Mon Sep 17 00:00:00 2001 From: D Gardner Date: Mon, 5 May 2025 16:27:46 +0100 Subject: [PATCH 3/5] structured-outputs: local validation, unit tests and documentation --- README.md | 166 +++++++++++++++++- .../com/openai/core/StructuredOutputs.kt | 28 ++- .../completions/ChatCompletionCreateParams.kt | 23 ++- .../StructuredChatCompletionCreateParams.kt | 14 +- ...idatorTest.kt => StructuredOutputsTest.kt} | 103 ++++++++++- .../StructuredOutputsClassExample.java | 10 +- 6 files changed, 314 insertions(+), 30 deletions(-) rename openai-java-core/src/test/kotlin/com/openai/core/{JsonSchemaValidatorTest.kt => StructuredOutputsTest.kt} (92%) diff --git a/README.md b/README.md index e6b82333..eceb8074 100644 --- a/README.md +++ b/README.md @@ -286,7 +286,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() The SDK provides conveniences for streamed chat completions. A [`ChatCompletionAccumulator`](openai-java-core/src/main/kotlin/com/openai/helpers/ChatCompletionAccumulator.kt) -can record the stream of chat completion chunks in the response as they are processed and accumulate +can record the stream of chat completion chunks in the response as they are processed and accumulate a [`ChatCompletion`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletion.kt) object similar to that which would have been returned by the non-streaming API. @@ -334,6 +334,166 @@ client.chat() ChatCompletion chatCompletion = chatCompletionAccumulator.chatCompletion(); ``` +## Structured outputs with JSON schemas + +OpenAI [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat) +is a feature that ensures that the model will always generate responses that adhere to a supplied +[JSON schema](https://json-schema.org/overview/what-is-jsonschema).
+ +A JSON schema can be defined by creating a +[`ResponseFormatJsonSchema`](openai-java-core/src/main/kotlin/com/openai/models/ResponseFormatJsonSchema.kt) +and setting it on the input parameters. However, for greater convenience, a JSON schema can instead +be derived automatically from the structure of an arbitrary Java class. The response will then +automatically convert the generated JSON content to an instance of that Java class. + +Java classes can contain fields declared to be instances of other classes and can use collections: + +```java +class Person { + public String name; + public int yearOfBirth; +} + +class Book { + public String title; + public Person author; + public int yearPublished; +} + +class BookList { + public List<Book> books; +} +``` + +Pass the top-level class—`BookList` in this example—to `responseFormat(Class)` when building the +parameters and then access an instance of `BookList` from the generated message content in the +response: + +```java +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; + +StructuredChatCompletionCreateParams<BookList> params = ChatCompletionCreateParams.builder() + .addUserMessage("List six famous nineteenth century novels.") + .model(ChatModel.GPT_4_1) + .responseFormat(BookList.class) + .build(); + +client.chat().completions().create(params).choices().stream() + .flatMap(choice -> choice.message().content().stream()) + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(book.title + " by " + book.author.name)); +``` + +You can start building the parameters with an instance of +[`ChatCompletionCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt) +or
+[`StructuredChatCompletionCreateParams.Builder`](openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt). +If you start with the former (which allows for more compact code) the builder type will change to +the latter when `ChatCompletionCreateParams.Builder.responseFormat(Class)` is called. + +If a field in a class is optional and does not require a defined value, you can represent this using +the [`java.util.Optional`](https://docs.oracle.com/javase/8/docs/api/java/util/Optional.html) class. +It is up to the AI model to decide whether to provide a value for that field or leave it empty. + +```java +import java.util.Optional; + +class Book { + public String title; + public Person author; + public int yearPublished; + public Optional<String> isbn; +} +``` + +If an error occurs while converting a JSON response to an instance of a Java class, the error +message will include the JSON response to assist in diagnosis. For instance, if the response is +truncated, the JSON data will be incomplete and cannot be converted to a class instance. If your +JSON response may contain sensitive information, avoid logging it directly, or ensure that you +redact any sensitive details from the error message. + +### Local JSON schema validation + +Structured Outputs supports a +[subset](https://platform.openai.com/docs/guides/structured-outputs#supported-schemas) of the JSON +Schema language. Schemas are generated automatically from classes to align with this subset. +However, due to the inherent structure of the classes, the generated schema may still violate +certain OpenAI schema restrictions, such as exceeding the maximum nesting depth or utilizing +unsupported data types. + +To facilitate compliance, the method `responseFormat(Class)` performs a validation check on the +schema derived from the specified class. This validation ensures that all restrictions are adhered +to.
If any issues are detected, an exception will be thrown, providing a detailed message outlining +the reasons for the validation failure. + +- **Local Validation**: The validation process occurs locally, meaning no requests are sent to the +remote AI model. If the schema passes local validation, it is likely to pass remote validation as +well. +- **Remote Validation**: The remote AI model will conduct its own validation upon receiving the JSON +schema in the request. +- **Version Compatibility**: There may be instances where local validation fails while remote +validation succeeds. This can occur if the SDK version is outdated compared to the restrictions +enforced by the remote model. +- **Disabling Local Validation**: If you encounter compatibility issues and wish to bypass local +validation, you can disable it by passing `false` to the `responseFormat(Class, boolean)` method +when building the parameters. (The default value for this parameter is `true`.) + +```java +import com.openai.models.ChatModel; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .addUserMessage("List six famous nineteenth century novels.") + .model(ChatModel.GPT_4_1) + .responseFormat(BookList.class, false) // Disable local validation. + .build(); +``` + +By following these guidelines, you can ensure that your structured outputs conform to the necessary +schema requirements and minimize the risk of remote validation errors. + +### Annotating classes and JSON schemas + +You can use annotations to add further information to the JSON schema derived from your Java +classes, or to exclude individual fields from the schema. Details from annotations captured in the +JSON schema may be used by the AI model to improve its response. 
The SDK supports the use of +[Jackson Databind](https://github.com/FasterXML/jackson-databind) annotations. + +```java +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +class Person { + @JsonPropertyDescription("The first name and surname of the person") + public String name; + public int yearOfBirth; +} + +@JsonClassDescription("The details of one published book") +class Book { + public String title; + public Person author; + public int yearPublished; + @JsonIgnore public String genre; +} + +class BookList { + public List books; +} +``` + +- Use `@JsonClassDescription` to add a detailed description to a class. +- Use `@JsonPropertyDescription` to add a detailed description to a field of a class. +- Use `@JsonIgnore` to omit a field of a class from the generated JSON schema. + +If you use `@JsonProperty(required = false)`, the `false` value will be ignored. OpenAI JSON schemas +must mark all properties as _required_, so the schema generated from your Java classes will respect +that restriction and ignore any annotation that would violate it. + ## File uploads The SDK defines methods that accept files. @@ -607,7 +767,7 @@ If the SDK threw an exception, but you're _certain_ the version is compatible, t ## Microsoft Azure -To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the same +To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the same OpenAI client builder but with the Azure-specific configuration. ```java @@ -620,7 +780,7 @@ OpenAIClient client = OpenAIOkHttpClient.builder() .build(); ``` -See the complete Azure OpenAI example in the [`openai-java-example`](openai-java-example/src/main/java/com/openai/example/AzureEntraIdExample.java) directory. 
The other examples in the directory also work with Azure as long as the client is configured to use it. +See the complete Azure OpenAI example in the [`openai-java-example`](openai-java-example/src/main/java/com/openai/example/AzureEntraIdExample.java) directory. The other examples in the directory also work with Azure as long as the client is configured to use it. ## Network options diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index 7f18d237..ba828d9a 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -23,17 +23,37 @@ private val MAPPER = .addModule(JavaTimeModule()) .build() -fun fromClass(type: Class) = - ResponseFormatJsonSchema.builder() +internal fun fromClass( + type: Class, + localValidation: Boolean = true, +): ResponseFormatJsonSchema { + val schema = extractSchema(type) + + if (localValidation) { + val validator = JsonSchemaValidator.create().validate(schema) + + if (!validator.isValid()) { + throw IllegalArgumentException( + "Local validation failed for JSON schema derived from $type:\n" + + validator.errors().joinToString("\n") { " - $it" } + ) + } + } + + return ResponseFormatJsonSchema.builder() .jsonSchema( ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.from(extractSchema(type))) + .schema(JsonValue.from(schema)) .build() ) .build() +} internal fun extractSchema(type: Class): JsonNode { + // Validation is not performed by this function, as it allows extraction of the schema and + // validation of the schema to be controlled more easily when unit testing, as no exceptions + // will be thrown and any recorded validation errors can be inspected at leisure by the tests. 
val configBuilder = SchemaGeneratorConfigBuilder( com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12, @@ -56,7 +76,7 @@ internal fun extractSchema(type: Class): JsonNode { return SchemaGenerator(configBuilder.build()).generateSchema(type) } -fun fromJson(json: String, type: Class): T = +internal fun fromJson(json: String, type: Class): T = try { MAPPER.readValue(json, type) } catch (e: Exception) { diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index cb3459fe..0fe22b2b 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -1298,12 +1298,23 @@ private constructor( } /** - * Sets the class that defines the structured outputs response format. This changes the - * builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that will build a - * [StructuredChatCompletionCreateParams] instance when `build()` is called. - */ - fun responseFormat(responseFormat: Class) = - StructuredChatCompletionCreateParams.builder().wrap(responseFormat, this) + * Sets response format to a JSON schema derived from the structure of the given class. This + * changes the builder to a type-safe [StructuredChatCompletionCreateParams.Builder] that + * will build a [StructuredChatCompletionCreateParams] instance when `build()` is called. + * + * @param responseFormat A class from which a JSON schema will be derived to define the + * response format. + * @param localValidation `true` (the default) to validate the JSON schema locally when it + * is generated by this method to confirm that it adheres to the requirements and + * restrictions on JSON schemas imposed by the OpenAI specification; or `false` to disable + * local validation. 
See the SDK documentation for more details. + * @throws IllegalArgumentException If local validation is enabled, but it fails because a + * valid JSON schema cannot be derived from the given class. + */ + @JvmOverloads + fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = + StructuredChatCompletionCreateParams.builder() + .wrap(responseFormat, this, localValidation) /** * This feature is in Beta. If specified, our system will make a best effort to sample diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt index ae1ea1be..14194ac9 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -33,11 +33,12 @@ internal constructor( internal fun wrap( responseFormat: Class, paramsBuilder: ChatCompletionCreateParams.Builder, + localValidation: Boolean, ) = apply { this.responseFormat = responseFormat this.paramsBuilder = paramsBuilder // Convert the class to a JSON schema and apply it to the delegate `Builder`. - responseFormat(responseFormat) + responseFormat(responseFormat, localValidation) } /** Injects a given `ChatCompletionCreateParams.Builder`. For use only when testing. */ @@ -389,10 +390,15 @@ internal constructor( paramsBuilder.reasoningEffort(reasoningEffort) } - /** Sets the response format to a JSON schema derived from the given class. */ - fun responseFormat(responseFormat: Class) = apply { + /** + * Sets the response format to a JSON schema derived from the structure of the given class. 
+ * + * @see ChatCompletionCreateParams.Builder.responseFormat + */ + @JvmOverloads + fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = apply { this.responseFormat = responseFormat - paramsBuilder.responseFormat(fromClass(responseFormat)) + paramsBuilder.responseFormat(fromClass(responseFormat, localValidation)) } /** @see ChatCompletionCreateParams.Builder.seed */ diff --git a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt similarity index 92% rename from openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt rename to openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index ccbc3926..88696e55 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/JsonSchemaValidatorTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -7,16 +7,18 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription import com.fasterxml.jackson.databind.JsonNode import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.ObjectNode +import com.openai.errors.OpenAIInvalidDataException import java.util.Optional import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatNoException import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.AfterTestExecutionCallback import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.extension.RegisterExtension -/** Tests the [JsonSchemaValidator] and, in passing, tests the [extractSchema] function. */ -internal class JsonSchemaValidatorTest { +/** Tests for the `StructuredOutputs` functions and the [JsonSchemaValidator]. 
*/ +internal class StructuredOutputsTest { companion object { private const val SCHEMA = "\$schema" private const val SCHEMA_VER = "https://json-schema.org/draft/2020-12/schema" @@ -28,6 +30,8 @@ internal class JsonSchemaValidatorTest { * print them only for failed tests. */ private const val VERBOSE_MODE = false + + private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) } /** @@ -82,7 +86,7 @@ internal class JsonSchemaValidatorTest { assertThat(validator.isValid()).isTrue } - // FIXME: Disabled test until issues (noted below) are resolved. + // TODO: Disabled test until issues (noted below) are resolved. // @Test fun schemaTest_minimalListSchema() { val s: List = listOf() @@ -90,7 +94,7 @@ internal class JsonSchemaValidatorTest { schema = extractSchema(s.javaClass) validator.validate(schema) - // FIXME: Currently, the generated schema looks like this: + // TODO: Currently, the generated schema looks like this: // { // "$schema" : "https://json-schema.org/draft/2020-12/schema", // "type" : "array", @@ -1400,5 +1404,94 @@ internal class JsonSchemaValidatorTest { assertThat(validator.isValid()).isTrue } - private fun parseJson(schemaString: String) = ObjectMapper().readTree(schemaString) + @Test + fun fromJsonSuccess() { + @Suppress("unused") class X(val s: String) + + val x = fromJson("{\"s\" : \"hello\"}", X::class.java) + + assertThat(x.s).isEqualTo("hello") + } + + @Test + fun fromJsonFailure1() { + @Suppress("unused") class X(val s: String) + + // Well-formed JSON, but it does not match the schema of class `X`. + assertThatThrownBy { fromJson("{\"wrong\" : \"hello\"}", X::class.java) } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("Error parsing JSON: {\"wrong\" : \"hello\"}") + } + + @Test + fun fromJsonFailure2() { + @Suppress("unused") class X(val s: String) + + // Malformed JSON. 
+ assertThatThrownBy { fromJson("{\"truncated", X::class.java) } + .isExactlyInstanceOf(OpenAIInvalidDataException::class.java) + .hasMessage("Error parsing JSON: {\"truncated") + } + + @Test + @Suppress("unused") + fun fromClassSuccessWithoutValidation() { + // Exceed the maximum nesting depth, but do not enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatNoException().isThrownBy { fromClass(X::class.java, false) } + } + + @Test + fun fromClassSuccessWithValidation() { + @Suppress("unused") class X(val s: String) + + assertThatNoException().isThrownBy { fromClass(X::class.java, true) } + } + + @Test + @Suppress("unused") + fun fromClassFailureWithValidation() { + // Exceed the maximum nesting depth and enable validation. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { fromClass(Z::class.java, true) } + .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } + + @Test + @Suppress("unused") + fun fromClassFailureWithValidationDefault() { + // Confirm that the default value of the `localValidation` argument is `true` by expecting + // a validation error when that argument is not given an explicit value. + class U(val s: String) + class V(val u: U) + class W(val v: V) + class X(val w: W) + class Y(val x: X) + class Z(val y: Y) + + assertThatThrownBy { fromClass(Z::class.java) } // Use default for `localValidation` flag. 
+ .isExactlyInstanceOf(IllegalArgumentException::class.java) + .hasMessage( + "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + + " - #/properties/y/properties/x/properties/w/properties/v/properties/u" + + "/properties/s: Current nesting depth is 6, but maximum is 5." + ) + } } diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java index bcc46a80..30188fa5 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java @@ -13,8 +13,6 @@ public final class StructuredOutputsClassExample { public static class Person { public String firstName; public String surname; - - @JsonPropertyDescription("The date of birth of the person.") public String dateOfBirth; @Override @@ -44,11 +42,6 @@ public String toString() { public static class Laureates { @JsonPropertyDescription("A list of winners of a Nobel Prize.") public List laureates; - - @Override - public String toString() { - return "Laureates{laureates=" + laureates + '}'; - } } private StructuredOutputsClassExample() {} @@ -63,11 +56,12 @@ public static void main(String[] args) { .model(ChatModel.GPT_4O_MINI) .maxCompletionTokens(2048) .responseFormat(Laureates.class) - .addUserMessage("List some winners of the Nobel Prize in Physics since 2000.") + .addUserMessage("List five winners of the Nobel Prize in Physics.") .build(); client.chat().completions().create(createParams).choices().stream() .flatMap(choice -> choice.message().content().stream()) + .flatMap(laureates -> laureates.laureates.stream()) .forEach(System.out::println); } } From d4435b2f892312941235577116188b1ac10b2225 Mon Sep 17 00:00:00 2001 From: D Gardner Date: Tue, 6 May 2025 18:31:44 +0100 Subject: [PATCH 4/5] structured-outputs: changes from code review --- 
README.md | 48 ++++++++++++----- .../openai/core/JsonSchemaLocalValidation.kt | 19 +++++++ .../com/openai/core/JsonSchemaValidator.kt | 4 +- .../com/openai/core/StructuredOutputs.kt | 17 ++++--- .../completions/ChatCompletionCreateParams.kt | 16 ++++-- .../completions/StructuredChatCompletion.kt | 15 ++++-- .../StructuredChatCompletionCreateParams.kt | 19 +++++-- .../StructuredChatCompletionMessage.kt | 12 ++++- .../blocking/chat/ChatCompletionService.kt | 12 +++-- .../com/openai/core/StructuredOutputsTest.kt | 35 ++++++++----- .../StructuredOutputsClassExample.java | 51 ++++++++++--------- 11 files changed, 170 insertions(+), 78 deletions(-) create mode 100644 openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt diff --git a/README.md b/README.md index eceb8074..bb0bfa13 100644 --- a/README.md +++ b/README.md @@ -343,21 +343,23 @@ is a feature that ensures that the model will always generate responses that adh A JSON schema can be defined by creating a [`ResponseFormatJsonSchema`](openai-java-core/src/main/kotlin/com/openai/models/ResponseFormatJsonSchema.kt) and setting it on the input parameters. However, for greater convenience, a JSON schema can instead -be derived automatically from the structure of an arbitrary Java class. The response will then -automatically convert the generated JSON content to an instance of that Java class. +be derived automatically from the structure of an arbitrary Java class. The JSON content from the +response will then be converted automatically to an instance of that Java class. A full, working +example of the use of Structured Outputs with arbitrary Java classes can be seen in +[`StructuredOutputsClassExample`](openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java). 
Java classes can contain fields declared to be instances of other classes and can use collections: ```java class Person { public String name; - public int yearOfBirth; + public int birthYear; } class Book { public String title; public Person author; - public int yearPublished; + public int publicationYear; } class BookList { @@ -375,7 +377,7 @@ import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() - .addUserMessage("List six famous nineteenth century novels.") + .addUserMessage("List some famous late twentieth century novels.") .model(ChatModel.GPT_4_1) .responseFormat(BookList.class) .build(); @@ -403,11 +405,25 @@ import java.util.Optional; class Book { public String title; public Person author; - public int yearPublished; + public int publicationYear; public Optional isbn; } ``` +Generic type information for fields is retained in the class's metadata, but _generic type erasure_ +applies in other scopes. While, for example, a JSON schema defining an array of strings can be +derived from the `BookList.books` field with type `List`, a valid JSON schema cannot be +derived from a local variable of that same type, so the following will _not_ work: + +```java +List books = new ArrayList<>(); + +StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() + .responseFormat(books.getClass()) + // ... + .build(); +``` + If an error occurs while converting a JSON response to an instance of a Java class, the error message will include the JSON response to assist in diagnosis. For instance, if the response is truncated, the JSON data will be incomplete and cannot be converted to a class instance. If your @@ -435,20 +451,23 @@ well. schema in the request. - **Version Compatibility**: There may be instances where local validation fails while remote validation succeeds.
This can occur if the SDK version is outdated compared to the restrictions -enforced by the remote model. +enforced by the remote AI model. - **Disabling Local Validation**: If you encounter compatibility issues and wish to bypass local -validation, you can disable it by passing `false` to the `responseFormat(Class, boolean)` method -when building the parameters. (The default value for this parameter is `true`.) +validation, you can disable it by passing +[`JsonSchemaLocalValidation.NO`](openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt) +to the `responseFormat(Class, JsonSchemaLocalValidation)` method when building the parameters. +(The default value for this parameter is `JsonSchemaLocalValidation.YES`.) ```java +import com.openai.core.JsonSchemaLocalValidation; import com.openai.models.ChatModel; import com.openai.models.chat.completions.ChatCompletionCreateParams; import com.openai.models.chat.completions.StructuredChatCompletionCreateParams; StructuredChatCompletionCreateParams params = ChatCompletionCreateParams.builder() - .addUserMessage("List six famous nineteenth century novels.") + .addUserMessage("List some famous late twentieth century novels.") .model(ChatModel.GPT_4_1) - .responseFormat(BookList.class, false) // Disable local validation. 
+ .responseFormat(BookList.class, JsonSchemaLocalValidation.NO) .build(); ``` @@ -470,14 +489,17 @@ import com.fasterxml.jackson.annotation.JsonPropertyDescription; class Person { @JsonPropertyDescription("The first name and surname of the person") public String name; - public int yearOfBirth; + public int birthYear; + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; } @JsonClassDescription("The details of one published book") class Book { public String title; public Person author; - public int yearPublished; + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; @JsonIgnore public String genre; } diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt new file mode 100644 index 00000000..9a3ae799 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaLocalValidation.kt @@ -0,0 +1,19 @@ +package com.openai.core + +/** + * Options for local validation of JSON schemas derived from arbitrary classes before a request is + * executed. + */ +enum class JsonSchemaLocalValidation { + /** + * Validate the JSON schema locally before the request is executed. The remote AI model will + * also validate the JSON schema. + */ + YES, + + /** + * Do not validate the JSON schema locally before the request is executed. The remote AI model + * will always validate the JSON schema. 
+ */ + NO, +} diff --git a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt index 6af40929..85c20b43 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/JsonSchemaValidator.kt @@ -201,9 +201,7 @@ internal class JsonSchemaValidator private constructor() { * each new schema. */ fun validate(rootSchema: JsonNode): JsonSchemaValidator { - if (isValidationComplete) { - throw IllegalStateException("Validation already complete.") - } + check(!isValidationComplete) { "Validation already complete." } isValidationComplete = true validateSchema(rootSchema, ROOT_PATH, ROOT_DEPTH) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index ba828d9a..3c6e7dec 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -23,20 +23,19 @@ private val MAPPER = .addModule(JavaTimeModule()) .build() +@JvmSynthetic internal fun fromClass( type: Class, - localValidation: Boolean = true, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, ): ResponseFormatJsonSchema { val schema = extractSchema(type) - if (localValidation) { + if (localValidation == JsonSchemaLocalValidation.YES) { val validator = JsonSchemaValidator.create().validate(schema) - if (!validator.isValid()) { - throw IllegalArgumentException( - "Local validation failed for JSON schema derived from $type:\n" + - validator.errors().joinToString("\n") { " - $it" } - ) + require(validator.isValid()) { + "Local validation failed for JSON schema derived from $type:\n" + + validator.errors().joinToString("\n") { " - $it" } } } @@ -44,12 +43,13 @@ internal fun fromClass( .jsonSchema( 
ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") - .schema(JsonValue.from(schema)) + .schema(JsonValue.fromJsonNode(schema)) .build() ) .build() } +@JvmSynthetic internal fun extractSchema(type: Class): JsonNode { // Validation is not performed by this function, as it allows extraction of the schema and // validation of the schema to be controlled more easily when unit testing, as no exceptions @@ -76,6 +76,7 @@ internal fun extractSchema(type: Class): JsonNode { return SchemaGenerator(configBuilder.build()).generateSchema(type) } +@JvmSynthetic internal fun fromJson(json: String, type: Class): T = try { MAPPER.readValue(json, type) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt index 0fe22b2b..d012ff94 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/ChatCompletionCreateParams.kt @@ -19,6 +19,7 @@ import com.openai.core.Enum import com.openai.core.ExcludeMissing import com.openai.core.JsonField import com.openai.core.JsonMissing +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.Params import com.openai.core.allMaxBy @@ -1304,15 +1305,20 @@ private constructor( * * @param responseFormat A class from which a JSON schema will be derived to define the * response format. - * @param localValidation `true` (the default) to validate the JSON schema locally when it - * is generated by this method to confirm that it adheres to the requirements and - * restrictions on JSON schemas imposed by the OpenAI specification; or `false` to disable - * local validation. See the SDK documentation for more details. 
+ * @param localValidation [com.openai.core.JsonSchemaLocalValidation.YES] (the default) to + * validate the JSON schema locally when it is generated by this method to confirm that it + * adheres to the requirements and restrictions on JSON schemas imposed by the OpenAI + * specification; or [com.openai.core.JsonSchemaLocalValidation.NO] to skip local + * validation and rely only on remote validation. See the SDK documentation for more + * details. * @throws IllegalArgumentException If local validation is enabled, but it fails because a * valid JSON schema cannot be derived from the given class. */ @JvmOverloads - fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = + fun responseFormat( + responseFormat: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = StructuredChatCompletionCreateParams.builder() .wrap(responseFormat, this, localValidation) diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt index 6ca931a5..62cde6f5 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletion.kt @@ -10,9 +10,16 @@ import com.openai.models.completions.CompletionUsage import java.util.Objects import java.util.Optional +/** + * A wrapper for [ChatCompletion] that provides type-safe access to the [choices] when using the + * _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary class. + * See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the response will be deserialized. 
+ */ +class StructuredChatCompletion( - val responseFormat: Class, - val chatCompletion: ChatCompletion, + @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("chatCompletion") val chatCompletion: ChatCompletion, ) { /** @see ChatCompletion.id */ fun id(): String = chatCompletion.id() @@ -68,8 +75,8 @@ class StructuredChatCompletion( class Choice internal constructor( - internal val responseFormat: Class, - internal val choice: ChatCompletion.Choice, + @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("choice") val choice: ChatCompletion.Choice, ) { /** @see ChatCompletion.Choice.finishReason */ fun finishReason(): FinishReason = choice.finishReason() diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt index 14194ac9..4f6a3a63 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionCreateParams.kt @@ -1,6 +1,7 @@ package com.openai.models.chat.completions import com.openai.core.JsonField +import com.openai.core.JsonSchemaLocalValidation import com.openai.core.JsonValue import com.openai.core.checkRequired import com.openai.core.fromClass @@ -11,9 +12,18 @@ import com.openai.models.ReasoningEffort import java.util.Objects import java.util.Optional +/** + * A wrapper for [ChatCompletionCreateParams] that provides a type-safe [Builder] that can record + * the type of the [responseFormat] used to derive a JSON schema from an arbitrary class when using + * the _Structured Outputs_ feature. When a JSON response is received, it is deserialized to an
+ * + * @param T The type of the class that will be used to derive the JSON schema in the request and to + * which the JSON response will be deserialized. + */ class StructuredChatCompletionCreateParams internal constructor( - val responseFormat: Class, + @get:JvmName("responseFormat") val responseFormat: Class, /** * The raw, underlying chat completion create parameters wrapped by this structured instance of * the parameters. @@ -33,7 +43,7 @@ internal constructor( internal fun wrap( responseFormat: Class, paramsBuilder: ChatCompletionCreateParams.Builder, - localValidation: Boolean, + localValidation: JsonSchemaLocalValidation, ) = apply { this.responseFormat = responseFormat this.paramsBuilder = paramsBuilder @@ -396,7 +406,10 @@ internal constructor( * @see ChatCompletionCreateParams.Builder.responseFormat */ @JvmOverloads - fun responseFormat(responseFormat: Class, localValidation: Boolean = true) = apply { + fun responseFormat( + responseFormat: Class, + localValidation: JsonSchemaLocalValidation = JsonSchemaLocalValidation.YES, + ) = apply { this.responseFormat = responseFormat paramsBuilder.responseFormat(fromClass(responseFormat, localValidation)) } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt index 519596ef..b833dd47 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/chat/completions/StructuredChatCompletionMessage.kt @@ -7,10 +7,18 @@ import com.openai.models.chat.completions.ChatCompletionMessage.FunctionCall import java.util.Objects import java.util.Optional +/** + * A wrapper for [ChatCompletionMessage] that provides type-safe access to the [content] when using + * the _Structured Outputs_ feature to deserialize a JSON response to an instance of an arbitrary 
+ * class. See the SDK documentation for more details on _Structured Outputs_. + * + * @param T The type of the class to which the JSON data in the content will be deserialized when + * [content] is called. + */ class StructuredChatCompletionMessage internal constructor( - val responseFormat: Class, - val chatCompletionMessage: ChatCompletionMessage, + @get:JvmName("responseFormat") val responseFormat: Class, + @get:JvmName("chatCompletionMessage") val chatCompletionMessage: ChatCompletionMessage, ) { private val content: JsonField by lazy { diff --git a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt index 0060c54f..33207a8c 100644 --- a/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt +++ b/openai-java-core/src/main/kotlin/com/openai/services/blocking/chat/ChatCompletionService.kt @@ -58,12 +58,14 @@ interface ChatCompletionService { /** @see create */ fun create( params: StructuredChatCompletionCreateParams + ): StructuredChatCompletion = create(params, RequestOptions.none()) + + /** @see create */ + fun create( + params: StructuredChatCompletionCreateParams, + requestOptions: RequestOptions = RequestOptions.none(), ): StructuredChatCompletion = - StructuredChatCompletion( - params.responseFormat, - // Normal, non-generic create method call via `ChatCompletionCreateParams`. 
- create(params.rawParams), - ) + StructuredChatCompletion(params.responseFormat, create(params.rawParams, requestOptions)) /** * **Starting a new project?** We recommend trying diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index 88696e55..6ea7454b 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -86,22 +86,24 @@ internal class StructuredOutputsTest { assertThat(validator.isValid()).isTrue } - // TODO: Disabled test until issues (noted below) are resolved. - // @Test + @Test fun schemaTest_minimalListSchema() { val s: List = listOf() schema = extractSchema(s.javaClass) validator.validate(schema) - // TODO: Currently, the generated schema looks like this: + // Currently, the generated schema looks like this: + // // { // "$schema" : "https://json-schema.org/draft/2020-12/schema", // "type" : "array", // "items" : { } // } - // That causes an error, as the `"items"` object is empty when it should be a valid - // sub-schema. Something like this is what is expected: + // + // That causes an error, as the `"items"` object is empty when it should be a valid + // sub-schema. Something like this is what would be valid: + // // { // "$schema" : "https://json-schema.org/draft/2020-12/schema", // "type" : "array", @@ -109,10 +111,15 @@ internal class StructuredOutputsTest { // "type" : "string" // } // } - // It might be presumed that type erasure is the cause of the missing field. However, the - // `schemaTest_listFieldSchema` method (below) seems to be able to produce the expected - // `"items"` object when it is defined as a class property, so, well ... huh? 
- assertThat(validator.isValid()).isTrue + // + // The reason for the failure is that generic type information is erased for scopes like + // local variables, but generic type information for fields is retained as part of the class + // metadata. This is the expected behavior in Java, so this test expects an invalid schema. + assertThat(validator.isValid()).isFalse + assertThat(validator.errors()).hasSize(2) + assertThat(validator.errors()[0]).isEqualTo("#/items: Schema or sub-schema is empty.") + assertThat(validator.errors()[1]) + .isEqualTo("#/items: Expected exactly one of 'type' or 'anyOf' or '$REF'.") } @Test @@ -1444,14 +1451,18 @@ internal class StructuredOutputsTest { class Y(val x: X) class Z(val y: Y) - assertThatNoException().isThrownBy { fromClass(X::class.java, false) } + assertThatNoException().isThrownBy { + fromClass(X::class.java, JsonSchemaLocalValidation.NO) + } } @Test fun fromClassSuccessWithValidation() { @Suppress("unused") class X(val s: String) - assertThatNoException().isThrownBy { fromClass(X::class.java, true) } + assertThatNoException().isThrownBy { + fromClass(X::class.java, JsonSchemaLocalValidation.YES) + } } @Test @@ -1465,7 +1476,7 @@ internal class StructuredOutputsTest { class Y(val x: X) class Z(val y: Y) - assertThatThrownBy { fromClass(Z::class.java, true) } + assertThatThrownBy { fromClass(Z::class.java, JsonSchemaLocalValidation.YES) } .isExactlyInstanceOf(IllegalArgumentException::class.java) .hasMessage( "Local validation failed for JSON schema derived from ${Z::class.java}:\n" + diff --git a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java index 30188fa5..3f65a991 100644 --- a/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java +++ b/openai-java-example/src/main/java/com/openai/example/StructuredOutputsClassExample.java @@ -1,5 +1,6 @@ package 
com.openai.example; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; @@ -11,37 +12,41 @@ public final class StructuredOutputsClassExample { public static class Person { - public String firstName; - public String surname; - public String dateOfBirth; + @JsonPropertyDescription("The first name and surname of the person.") + public String name; + + public int birthYear; + + @JsonPropertyDescription("The year the person died, or 'present' if the person is living.") + public String deathYear; @Override public String toString() { - return "Person{firstName=" + firstName + ", surname=" + surname + ", dateOfBirth=" + dateOfBirth + '}'; + return name + " (" + birthYear + '-' + deathYear + ')'; } } - public static class Laureate { - public Person person; - public String majorAchievement; - public int yearWon; + public static class Book { + public String title; + + public Person author; + + @JsonPropertyDescription("The year in which the book was first published.") + public int publicationYear; + + public String genre; - @JsonPropertyDescription("The share of the prize money won by the Nobel Laureate.") - public double prizeMoney; + @JsonIgnore + public String isbn; @Override public String toString() { - return "Laureate{person=" - + person + ", majorAchievement=" - + majorAchievement + ", yearWon=" - + yearWon + ", prizeMoney=" - + prizeMoney + '}'; + return '"' + title + "\" (" + publicationYear + ") [" + genre + "] by " + author; } } - public static class Laureates { - @JsonPropertyDescription("A list of winners of a Nobel Prize.") - public List laureates; + public static class BookList { + public List books; } private StructuredOutputsClassExample() {} @@ -52,16 +57,16 @@ public static void main(String[] args) { // - The `OPENAI_BASE_URL` and `AZURE_OPENAI_KEY` environment variables OpenAIClient client = 
OpenAIOkHttpClient.fromEnv(); - StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() + StructuredChatCompletionCreateParams createParams = ChatCompletionCreateParams.builder() .model(ChatModel.GPT_4O_MINI) .maxCompletionTokens(2048) - .responseFormat(Laureates.class) - .addUserMessage("List five winners of the Nobel Prize in Physics.") + .responseFormat(BookList.class) + .addUserMessage("List some famous late twentieth century novels.") .build(); client.chat().completions().create(createParams).choices().stream() .flatMap(choice -> choice.message().content().stream()) - .flatMap(laureates -> laureates.laureates.stream()) - .forEach(System.out::println); + .flatMap(bookList -> bookList.books.stream()) + .forEach(book -> System.out.println(" - " + book)); } } From 67381ffc3bf5e82ffdff8c92c870bfc62175d0dc Mon Sep 17 00:00:00 2001 From: D Gardner Date: Wed, 7 May 2025 11:30:56 +0100 Subject: [PATCH 5/5] structured-outputs: added 'strict' flag --- .../kotlin/com/openai/core/StructuredOutputs.kt | 3 +++ .../com/openai/core/StructuredOutputsTest.kt | 14 +++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt index 3c6e7dec..6b1889ff 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/StructuredOutputs.kt @@ -44,6 +44,9 @@ internal fun fromClass( ResponseFormatJsonSchema.JsonSchema.builder() .name("json-schema-from-${type.simpleName}") .schema(JsonValue.fromJsonNode(schema)) + // Ensure the model's output strictly adheres to this JSON schema. This is the + // essential "ON switch" for Structured Outputs. 
+ .strict(true) .build() ) .build() diff --git a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt index 6ea7454b..2c1eb885 100644 --- a/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt +++ b/openai-java-core/src/test/kotlin/com/openai/core/StructuredOutputsTest.kt @@ -1440,6 +1440,18 @@ internal class StructuredOutputsTest { .hasMessage("Error parsing JSON: {\"truncated") } + @Test + fun fromClassEnablesStrictAdherenceToSchema() { + @Suppress("unused") class X(val s: String) + + val jsonSchema = fromClass(X::class.java) + + // The "strict" flag _must_ be set to ensure that the model's output will _always_ conform + // to the JSON schema. + assertThat(jsonSchema.jsonSchema().strict()).isPresent + assertThat(jsonSchema.jsonSchema().strict().get()).isTrue + } + @Test @Suppress("unused") fun fromClassSuccessWithoutValidation() { @@ -1452,7 +1464,7 @@ internal class StructuredOutputsTest { class Z(val y: Y) assertThatNoException().isThrownBy { - fromClass(X::class.java, JsonSchemaLocalValidation.NO) + fromClass(Z::class.java, JsonSchemaLocalValidation.NO) } }