diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt index 8192bc19..8ded8a53 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt @@ -22,6 +22,7 @@ import kotlin.jvm.optionals.getOrNull class Embedding private constructor( private val embedding: JsonField>, + private val embeddingValue: JsonField?, private val index: JsonField, private val object_: JsonValue, private val additionalProperties: MutableMap, @@ -31,19 +32,52 @@ private constructor( private constructor( @JsonProperty("embedding") @ExcludeMissing - embedding: JsonField> = JsonMissing.of(), + embedding: JsonField = JsonMissing.of(), @JsonProperty("index") @ExcludeMissing index: JsonField = JsonMissing.of(), @JsonProperty("object") @ExcludeMissing object_: JsonValue = JsonMissing.of(), - ) : this(embedding, index, object_, mutableMapOf()) + ) : this( + JsonMissing.of(), // Legacy embedding field will be populated from embeddingValue + embedding, + index, + object_, + mutableMapOf(), + ) /** * The embedding vector, which is a list of floats. The length of vector depends on the model as * listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings). * + * Important: When Base64 data is received, it is automatically decoded and returned as + * List + * + * @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is + * unexpectedly missing or null (e.g. if the server responded with an unexpected value). 
+ */ + fun embedding(): List = + when { + embeddingValue != null -> + embeddingValue + .getRequired("embedding") + .asFloatList() // Base64→Float auto conversion + !embedding.isMissing() -> + embedding.getRequired("embedding") // Original Float format data + else -> throw OpenAIInvalidDataException("Embedding data is missing") + } + + /** + * The embedding data in its original format (either float list or base64 string). This method + * provides efficient access to the embedding data without unnecessary conversions. + * + * @return EmbeddingValue containing the embedding data in its original format * @throws OpenAIInvalidDataException if the JSON field has an unexpected type or is * unexpectedly missing or null (e.g. if the server responded with an unexpected value). */ - fun embedding(): List = embedding.getRequired("embedding") + fun embeddingValue(): EmbeddingValue = + when { + embeddingValue != null -> embeddingValue.getRequired("embedding") + !embedding.isMissing() -> EmbeddingValue.ofFloatList(embedding.getRequired("embedding")) + else -> throw OpenAIInvalidDataException("Embedding data is missing") + } /** * The index of the embedding in the list of embeddings. @@ -71,7 +105,15 @@ private constructor( * * Unlike [embedding], this method doesn't throw if the JSON field has an unexpected type. */ - @JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField> = embedding + @JsonProperty("embedding") + @ExcludeMissing + fun _embedding(): JsonField = + when { + embeddingValue != null -> embeddingValue + !embedding.isMissing() -> + JsonField.of(EmbeddingValue.ofFloatList(embedding.getRequired("embedding"))) + else -> JsonMissing.of() + } /** * Returns the raw JSON value of [index]. 
@@ -116,7 +158,12 @@ private constructor( @JvmSynthetic internal fun from(embedding: Embedding) = apply { - this.embedding = embedding.embedding.map { it.toMutableList() } + try { + this.embedding = JsonField.of(embedding.embedding().toMutableList()) + } catch (e: Exception) { + // Fallback to field-level copying if embedding() method fails + this.embedding = embedding.embedding.map { it.toMutableList() } + } index = embedding.index object_ = embedding.object_ additionalProperties = embedding.additionalProperties.toMutableMap() @@ -212,6 +259,7 @@ private constructor( fun build(): Embedding = Embedding( checkRequired("embedding", embedding).map { it.toImmutable() }, + null, // embeddingValue - will be null for builder-created instances checkRequired("index", index), object_, additionalProperties.toMutableMap(), @@ -225,7 +273,7 @@ private constructor( return@apply } - embedding() + embedding() // This will call the method that returns List index() _object_().let { if (it != JsonValue.from("embedding")) { @@ -250,7 +298,11 @@ private constructor( */ @JvmSynthetic internal fun validity(): Int = - (embedding.asKnown().getOrNull()?.size ?: 0) + + when { + embeddingValue != null -> embeddingValue.asKnown().getOrNull()?.validity() ?: 0 + !embedding.isMissing() -> embedding.asKnown().getOrNull()?.size ?: 0 + else -> 0 + } + (if (index.asKnown().isPresent) 1 else 0) + object_.let { if (it == JsonValue.from("embedding")) 1 else 0 } @@ -259,15 +311,43 @@ private constructor( return true } - return /* spotless:off */ other is Embedding && embedding == other.embedding && index == other.index && object_ == other.object_ && additionalProperties == other.additionalProperties /* spotless:on */ + if (other !is Embedding) { + return false + } + + return try { + embedding() == other.embedding() && + index == other.index && + object_ == other.object_ && + additionalProperties == other.additionalProperties + } catch (e: Exception) { + // Fallback to field-level comparison if 
embedding() methods fail + embedding == other.embedding && + embeddingValue == other.embeddingValue && + index == other.index && + object_ == other.object_ && + additionalProperties == other.additionalProperties + } } /* spotless:off */ - private val hashCode: Int by lazy { Objects.hash(embedding, index, object_, additionalProperties) } + private val hashCode: Int by lazy { + try { + Objects.hash(embedding(), index, object_, additionalProperties) + } catch (e: Exception) { + // Fallback to field-level hashing if embedding() method fails + Objects.hash(embedding, embeddingValue, index, object_, additionalProperties) + } + } /* spotless:on */ override fun hashCode(): Int = hashCode override fun toString() = - "Embedding{embedding=$embedding, index=$index, object_=$object_, additionalProperties=$additionalProperties}" + when { + embeddingValue != null -> + "Embedding{embedding=${try { embedding() } catch (e: Exception) { "[]" }}, index=$index, object_=$object_, additionalProperties=$additionalProperties}" + else -> + "Embedding{embedding=${embedding.asKnown().getOrNull() ?: emptyList()}, index=$index, object_=$object_, additionalProperties=$additionalProperties}" + } } diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt index 6ea30e45..40b42492 100644 --- a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt @@ -78,6 +78,9 @@ private constructor( * The format to return the embeddings in. Can be either `float` or * [`base64`](https://pypi.org/project/pybase64/). * + * Returns the encoding format that was set (either explicitly or via default) when this + * EmbeddingCreateParams instance was built. + * * @throws OpenAIInvalidDataException if the JSON field has an unexpected type (e.g. 
if the * server responded with an unexpected value). */ @@ -418,12 +421,18 @@ private constructor( * * @throws IllegalStateException if any required field is unset. */ - fun build(): EmbeddingCreateParams = - EmbeddingCreateParams( + fun build(): EmbeddingCreateParams { + // Apply default encoding format if not explicitly set + if (body._encodingFormat().isMissing()) { + body.encodingFormat(EmbeddingDefaults.defaultEncodingFormat) + } + + return EmbeddingCreateParams( body.build(), additionalHeaders.build(), additionalQueryParams.build(), ) + } } fun _body(): Body = body @@ -724,6 +733,12 @@ private constructor( keys.forEach(::removeAdditionalProperty) } + /** + * Internal method to check if encodingFormat has been set. Used by the main Builder to + * determine if default should be applied. + */ + internal fun _encodingFormat(): JsonField = encodingFormat + /** * Returns an immutable instance of [Body]. * diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingDefaults.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingDefaults.kt new file mode 100644 index 00000000..251e4f8a --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingDefaults.kt @@ -0,0 +1,57 @@ +// File generated from our OpenAPI spec by Stainless. + +package com.openai.models.embeddings + +import com.openai.models.embeddings.EmbeddingCreateParams.EncodingFormat + +/** + * Configuration object for default embedding behavior. This allows users to change the default + * encoding format globally. + * + * By default, Base64 encoding is used for optimal performance and reduced network bandwidth. Users + * can explicitly choose float encoding when direct float access is needed. 
+ */ +object EmbeddingDefaults { + + @JvmStatic + @get:JvmName("getDefaultEncodingFormat") + @set:JvmName("setDefaultEncodingFormat") + var defaultEncodingFormat: EncodingFormat = EncodingFormat.BASE64 // Default is Base64 + private set + + /** + * Set the default encoding format for embeddings. This will be applied when no explicit format + * is specified in EmbeddingCreateParams. + * + * @param format the encoding format to use as default + */ + @JvmStatic + fun setDefaultEncodingFormat(format: EncodingFormat) { + defaultEncodingFormat = format + } + + /** + * Reset the default encoding format to Base64 (the recommended default). Base64 encoding + * provides better performance and reduced network bandwidth usage. + */ + @JvmStatic + fun resetToDefaults() { + defaultEncodingFormat = EncodingFormat.BASE64 + } + + /** + * Configure the system to use float encoding as default. This is primarily for backward + * compatibility scenarios. Note: Float encoding uses more network bandwidth and may impact + * performance. For most use cases, the default base64 encoding is recommended. + */ + @JvmStatic + fun enableLegacyFloatDefaults() { + defaultEncodingFormat = EncodingFormat.FLOAT + } + + /** Returns true if the current default encoding format is BASE64. */ + @JvmStatic fun isUsingBase64Defaults(): Boolean = defaultEncodingFormat == EncodingFormat.BASE64 + + /** Returns true if the current default encoding format is FLOAT. */ + @JvmStatic fun isUsingFloatDefaults(): Boolean = defaultEncodingFormat == EncodingFormat.FLOAT +} diff --git a/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt new file mode 100644 index 00000000..acb4e3c9 --- /dev/null +++ b/openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingValue.kt @@ -0,0 +1,265 @@ +// File generated from our OpenAPI spec by Stainless. 
+ +package com.openai.models.embeddings + +import com.fasterxml.jackson.core.JsonGenerator +import com.fasterxml.jackson.core.ObjectCodec +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.SerializerProvider +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.databind.annotation.JsonSerialize +import com.fasterxml.jackson.module.kotlin.jacksonTypeRef +import com.openai.core.BaseDeserializer +import com.openai.core.BaseSerializer +import com.openai.core.JsonValue +import com.openai.core.allMaxBy +import com.openai.errors.OpenAIInvalidDataException +import java.nio.ByteBuffer +import java.util.Base64 +import java.util.Objects + +/** + * Represents embedding data that can be either a list of floats or base64-encoded string. This + * union type allows for efficient handling of both formats. + * + * This class is immutable - all instances are thread-safe and cannot be modified after creation. + */ +@JsonDeserialize(using = EmbeddingValue.Deserializer::class) +@JsonSerialize(using = EmbeddingValue.Serializer::class) +class EmbeddingValue +private constructor( + private val floatList: List? = null, + private val base64String: String? = null, + private val _json: JsonValue? = null, +) { + + /** Returns the embedding as a list of floats, or null if this value represents base64 data. */ + fun floatList(): List? = floatList + + /** + * Returns the embedding as a base64-encoded string, or null if this value represents float + * data. + */ + fun base64String(): String? = base64String + + /** Returns true if this value contains float list data. */ + fun isFloatList(): Boolean = floatList != null + + /** Returns true if this value contains base64 string data. */ + fun isBase64String(): Boolean = base64String != null + + /** + * Returns the embedding data as a list of floats. + * + * **Important feature: Automatic Base64 decoding** This method is the core part of backward + * compatibility. 
When data is stored in Base64 format, it automatically decodes and returns + * List, so existing user code requires no changes. + * + * Processing flow: + * - Float format data → Return as-is + * - Base64 format data → Automatically decode and return as List + * + * @return Decoded embedding data in List format + */ + fun asFloatList(): List = + when { + floatList != null -> floatList + base64String != null -> + decodeBase64ToFloatList(base64String) // Automatic Base64 decoding + else -> throw IllegalStateException("No valid embedding data") + } + + /** + * Returns the embedding data as a base64-encoded string. If the data is a float list, it will + * be encoded automatically. + */ + fun asBase64String(): String = + when { + base64String != null -> base64String + floatList != null -> encodeFloatListToBase64(floatList) + else -> throw IllegalStateException("No valid embedding data") + } + + /** Returns the raw JSON value for debugging purposes. */ + fun _json(): JsonValue? = _json + + /** Accepts a visitor that can handle both float list and base64 string cases. */ + fun accept(visitor: Visitor): T = + when { + floatList != null -> visitor.visitFloatList(floatList) + base64String != null -> visitor.visitBase64String(base64String) + else -> visitor.unknown(_json) + } + + /** + * Validates the embedding data and returns a new validated instance. This method is immutable - + * it returns a new instance if validation is successful, or throws an exception if validation + * fails. 
+ * + * @return this instance if validation succeeds + * @throws OpenAIInvalidDataException if validation fails + */ + fun validate(): EmbeddingValue { + accept( + object : Visitor { + override fun visitFloatList(floatList: List) { + // Validate that float list is not empty and contains valid values + if (floatList.isEmpty()) { + throw OpenAIInvalidDataException("Float list cannot be empty") + } + floatList.forEach { value -> + if (!value.isFinite()) { + throw OpenAIInvalidDataException("Float values must be finite") + } + } + } + + override fun visitBase64String(base64String: String) { + // Validate base64 format + try { + Base64.getDecoder().decode(base64String) + } catch (e: IllegalArgumentException) { + throw OpenAIInvalidDataException("Invalid base64 string", e) + } + } + } + ) + return this // Return this instance if validation succeeds + } + + fun isValid(): Boolean = + try { + validate() + true + } catch (e: OpenAIInvalidDataException) { + false + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + return other is EmbeddingValue && + floatList == other.floatList && + base64String == other.base64String + } + + override fun hashCode(): Int = Objects.hash(floatList, base64String) + + override fun toString(): String = + when { + floatList != null -> "EmbeddingValue{floatList=$floatList}" + base64String != null -> "EmbeddingValue{base64String=$base64String}" + _json != null -> "EmbeddingValue{_unknown=$_json}" + else -> throw IllegalStateException("Invalid EmbeddingValue") + } + + companion object { + /** + * Creates an EmbeddingValue from a list of floats. The input list is defensively copied to + * ensure immutability. 
+ * + * @param floatList the list of float values (will be copied) + * @return a new immutable EmbeddingValue instance + * @throws OpenAIInvalidDataException if validation fails + */ + @JvmStatic + fun ofFloatList(floatList: List<Float>): EmbeddingValue { + // Defensive copy to ensure immutability + val immutableList = floatList.toList() + val instance = EmbeddingValue(floatList = immutableList) + return instance.validate() // Validate upon creation + } + + /** + * Creates an EmbeddingValue from a base64-encoded string. + * + * @param base64String the base64-encoded string + * @return a new immutable EmbeddingValue instance + * @throws OpenAIInvalidDataException if validation fails + */ + @JvmStatic + fun ofBase64String(base64String: String): EmbeddingValue { + val instance = EmbeddingValue(base64String = base64String) + return instance.validate() // Validate upon creation + } + + /** + * Decodes a base64 string to a list of floats. The decoded bytes are read as an array of + * 32-bit IEEE 754 floats in little-endian order (set explicitly; ByteBuffer is big-endian by default). + */ + private fun decodeBase64ToFloatList(base64String: String): List<Float> { + val bytes = Base64.getDecoder().decode(base64String) + val buffer = ByteBuffer.wrap(bytes).order(java.nio.ByteOrder.LITTLE_ENDIAN).asFloatBuffer() + return (0 until buffer.remaining()).map { buffer.get() } + } + + /** + * Encodes a list of floats to a base64 string. The floats are written as an array of 32-bit + * IEEE 754 floats in little-endian order (set explicitly; ByteBuffer is big-endian by default). + */ + private fun encodeFloatListToBase64(floatList: List<Float>): String { + val buffer = ByteBuffer.allocate(floatList.size * 4).order(java.nio.ByteOrder.LITTLE_ENDIAN) + floatList.forEach { buffer.putFloat(it) } + return Base64.getEncoder().encodeToString(buffer.array()) + } + } + + /** Visitor interface for handling different types of embedding data. 
*/ + interface Visitor { + fun visitFloatList(floatList: List): T + + fun visitBase64String(base64String: String): T + + fun unknown(json: JsonValue?): T { + throw OpenAIInvalidDataException("Unknown EmbeddingValue: $json") + } + } + + internal class Deserializer : BaseDeserializer(EmbeddingValue::class) { + override fun ObjectCodec.deserialize(node: JsonNode): EmbeddingValue { + val json = JsonValue.fromJsonNode(node) + + val bestMatches = + sequenceOf( + tryDeserialize(node, jacksonTypeRef>())?.let { + EmbeddingValue(floatList = it, _json = json) + }, + tryDeserialize(node, jacksonTypeRef())?.let { + EmbeddingValue(base64String = it, _json = json) + }, + ) + .filterNotNull() + .allMaxBy { it.validity() } + .toList() + + return when (bestMatches.size) { + 0 -> EmbeddingValue(_json = json) + 1 -> bestMatches.single() + else -> bestMatches.firstOrNull { it.isValid() } ?: bestMatches.first() + } + } + } + + internal class Serializer : BaseSerializer(EmbeddingValue::class) { + override fun serialize( + value: EmbeddingValue, + generator: JsonGenerator, + provider: SerializerProvider, + ) { + when { + value.floatList != null -> generator.writeObject(value.floatList) + value.base64String != null -> generator.writeObject(value.base64String) + value._json != null -> generator.writeObject(value._json) + else -> throw IllegalStateException("Invalid EmbeddingValue") + } + } + } + + /** Returns a score indicating how many valid values are contained in this object. 
*/ + @JvmSynthetic + internal fun validity(): Int = + when { + floatList != null -> floatList.size + base64String != null -> 1 + else -> 0 + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingDebugTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingDebugTest.kt new file mode 100644 index 00000000..5f3d1945 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingDebugTest.kt @@ -0,0 +1,62 @@ +// Simple debug test for Embedding +package com.openai.models.embeddings + +import com.fasterxml.jackson.module.kotlin.jacksonTypeRef +import com.openai.core.jsonMapper +import org.junit.jupiter.api.Test + +class EmbeddingDebugTest { + + @Test + fun debugEmbeddingCreation() { + println("=== Debug: Creating embedding with builder ===") + + val builder = Embedding.builder() + println("Builder created: $builder") + + builder.addEmbedding(0.0f) + println("After addEmbedding(0.0f): $builder") + + builder.index(0L) + println("After index(0L): $builder") + + val embedding = builder.build() + println("After build(): $embedding") + + try { + val embeddingList = embedding.embedding() + println("embedding.embedding(): $embeddingList") + println("embedding.embedding().size: ${embeddingList.size}") + } catch (e: Exception) { + println("Error calling embedding(): ${e.message}") + e.printStackTrace() + } + + try { + val index = embedding.index() + println("embedding.index(): $index") + } catch (e: Exception) { + println("Error calling index(): ${e.message}") + } + + // Test JSON serialization/deserialization + try { + val jsonMapper = jsonMapper() + val jsonString = jsonMapper.writeValueAsString(embedding) + println("JSON: $jsonString") + + val roundtrippedEmbedding = + jsonMapper.readValue(jsonString, jacksonTypeRef()) + println("Roundtripped: $roundtrippedEmbedding") + + val roundtrippedList = roundtrippedEmbedding.embedding() + println("Roundtripped embedding(): $roundtrippedList") + 
println("Roundtripped size: ${roundtrippedList.size}") + + println("Original equals roundtripped: ${embedding == roundtrippedEmbedding}") + } catch (e: Exception) { + println("Error in JSON roundtrip: ${e.message}") + e.printStackTrace() + } + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingDefaultsManualTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingDefaultsManualTest.kt new file mode 100644 index 00000000..3bb6b2a3 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingDefaultsManualTest.kt @@ -0,0 +1,49 @@ +package com.openai.models.embeddings + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test + +/** Manual test for EmbeddingDefaults behavior */ +@DisplayName("EmbeddingDefaults Manual Test") +class EmbeddingDefaultsManualTest { + + @Test + @DisplayName("Manual test for global defaults") + fun testGlobalDefaultsManually() { + println("=== Manual Test ===") + + // Step 1: Check original default + val originalDefault = EmbeddingDefaults.defaultEncodingFormat + println("Original default: $originalDefault") + + // Step 2: Change to FLOAT + EmbeddingDefaults.setDefaultEncodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT) + val changedDefault = EmbeddingDefaults.defaultEncodingFormat + println("Changed default: $changedDefault") + + // Step 3: Build params without explicit encoding + val params = + EmbeddingCreateParams.builder() + .input("test input") + .model("text-embedding-ada-002") + .build() + + println("Params encoding format: ${params.encodingFormat()}") + println("Is present: ${params.encodingFormat().isPresent}") + if (params.encodingFormat().isPresent) { + println("Value: ${params.encodingFormat().get()}") + } + + // Step 4: Reset to defaults + EmbeddingDefaults.resetToDefaults() + val resetDefault = EmbeddingDefaults.defaultEncodingFormat + println("Reset default: 
$resetDefault") + + // Assertions for verification + assertThat(changedDefault).isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT) + assertThat(params.encodingFormat().get()) + .isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT) + assertThat(resetDefault).isEqualTo(EmbeddingCreateParams.EncodingFormat.BASE64) + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingStepTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingStepTest.kt new file mode 100644 index 00000000..b3f5c40c --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingStepTest.kt @@ -0,0 +1,58 @@ +package com.openai.models.embeddings + +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test + +/** Step-by-step trace test */ +@DisplayName("Step Test") +class EmbeddingStepTest { + + @Test + @DisplayName("Step 1: Check initial state") + fun step1_checkInitialState() { + println("===== Step 1: Check initial state =====") + val defaultFormat = EmbeddingDefaults.defaultEncodingFormat + println("EmbeddingDefaults.defaultEncodingFormat = $defaultFormat") + println("EncodingFormat.BASE64 = ${EmbeddingCreateParams.EncodingFormat.BASE64}") + println("EncodingFormat.FLOAT = ${EmbeddingCreateParams.EncodingFormat.FLOAT}") + println("Are they equal? 
${defaultFormat == EmbeddingCreateParams.EncodingFormat.BASE64}") + } + + @Test + @DisplayName("Step 2: Check builder creation") + fun step2_checkBuilder() { + println("===== Step 2: Check builder creation =====") + val builder = EmbeddingCreateParams.builder().input("test").model("text-embedding-ada-002") + println("Builder created") + + // Check state before build + println("About to build...") + val params = builder.build() + println("Build completed") + + val encodingFormat = params.encodingFormat() + println("encodingFormat() result: $encodingFormat") + println("isPresent: ${encodingFormat.isPresent}") + if (encodingFormat.isPresent) { + println("Value: ${encodingFormat.get()}") + } + } + + @Test + @DisplayName("Step 3: Explicit Base64 setting") + fun step3_explicitBase64() { + println("===== Step 3: Explicit Base64 setting =====") + val params = + EmbeddingCreateParams.builder() + .input("test") + .model("text-embedding-ada-002") + .encodingFormat(EmbeddingCreateParams.EncodingFormat.BASE64) + .build() + + val encodingFormat = params.encodingFormat() + println("After explicit Base64 setting: $encodingFormat") + if (encodingFormat.isPresent) { + println("Value: ${encodingFormat.get()}") + } + } +} diff --git a/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingValueIntegrationTest.kt b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingValueIntegrationTest.kt new file mode 100644 index 00000000..58902458 --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/models/embeddings/EmbeddingValueIntegrationTest.kt @@ -0,0 +1,263 @@ +package com.openai.models.embeddings + +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test + +/** + * Integration test to verify Base64 default functionality and backward compatibility of + * EmbeddingValue. 
Ensures that both existing List usage and new Base64 format work + * correctly. + */ +@DisplayName("EmbeddingValue Integration Test") +class EmbeddingValueIntegrationTest { + + private var originalDefault: EmbeddingCreateParams.EncodingFormat? = null + + @BeforeEach + fun setUp() { + // Save default settings before test + originalDefault = EmbeddingDefaults.defaultEncodingFormat + } + + @AfterEach + fun tearDown() { + // Restore default settings after test + originalDefault?.let { EmbeddingDefaults.setDefaultEncodingFormat(it) } + } + + /** + * Test to confirm that the default encoding format is Base64. In the new implementation, Base64 + * becomes the default for performance improvements. + */ + @Test + @DisplayName("Confirm that default encoding format is Base64") + fun testDefaultEncodingFormatIsBase64() { + assertThat(EmbeddingDefaults.defaultEncodingFormat) + .describedAs("Default encoding format must be Base64") + .isEqualTo(EmbeddingCreateParams.EncodingFormat.BASE64) + } + + /** + * Test EmbeddingValue creation and format conversion functionality. 
+ * - Creating EmbeddingValue from Float array + * - Converting to Base64 string + * - Creating EmbeddingValue from Base64 string + * - Auto-decode functionality (Base64 → List) + */ + @Test + @DisplayName("Test EmbeddingValue creation and format conversion") + fun testEmbeddingValueCreationAndConversion() { + val floatList = listOf(1.0f, 2.0f, 3.0f, 4.0f) + + // Create EmbeddingValue from Float array + val embeddingFromFloat = EmbeddingValue.ofFloatList(floatList) + assertThat(embeddingFromFloat.isFloatList()) + .describedAs("EmbeddingValue created from Float array must be in Float format") + .isTrue() + assertThat(embeddingFromFloat.asFloatList()) + .describedAs("Float array contents must match") + .isEqualTo(floatList) + + // Test conversion to Base64 + val base64String = embeddingFromFloat.asBase64String() + assertThat(base64String).describedAs("Base64 string must not be empty").isNotEmpty() + + // Create EmbeddingValue from Base64 string + val embeddingFromBase64 = EmbeddingValue.ofBase64String(base64String) + assertThat(embeddingFromBase64.isBase64String()) + .describedAs("EmbeddingValue created from Base64 string must be in Base64 format") + .isTrue() + assertThat(embeddingFromBase64.base64String()) + .describedAs("Base64 string contents must match") + .isEqualTo(base64String) + + // Test auto-decode: Base64 → List + val decodedFloatList = embeddingFromBase64.asFloatList() + assertThat(decodedFloatList) + .describedAs("Decoded Float array must match the original array") + .isEqualTo(floatList) + } + + /** + * Test explicit Base64 encoding specification in EmbeddingCreateParams.Builder. Confirm that + * Base64 format can be explicitly specified using the encodingFormat() method. 
+ */ + @Test + @DisplayName("Test explicit Base64 encoding specification in EmbeddingCreateParams") + fun testEmbeddingCreateParamsBuilderWithBase64Encoding() { + val params = + EmbeddingCreateParams.builder() + .input("test input") + .model("text-embedding-ada-002") + .encodingFormat(encodingFormat = EmbeddingCreateParams.EncodingFormat.BASE64) + .build() + + assertThat(params.encodingFormat()).describedAs("Encoding format must be set").isPresent() + assertThat(params.encodingFormat().get()) + .describedAs("Explicitly specified encoding format must be Base64") + .isEqualTo(EmbeddingCreateParams.EncodingFormat.BASE64) + } + + /** + * Test default behavior of EmbeddingCreateParams. Confirm that Base64 is used by default when + * encoding format is not explicitly specified. + */ + @Test + @DisplayName("Test EmbeddingCreateParams default behavior") + fun testEmbeddingCreateParamsDefaultBehavior() { + val params = + EmbeddingCreateParams.builder() + .input("test input") + .model("text-embedding-ada-002") + .build() // Do not explicitly specify encoding format + + assertThat(params.encodingFormat()) + .describedAs("Encoding format must be set by default") + .isPresent() + assertThat(params.encodingFormat().get()) + .describedAs("Default encoding format must be Base64") + .isEqualTo(EmbeddingCreateParams.EncodingFormat.BASE64) + } + + /** + * Test explicit Float format specification for backward compatibility. Confirm that the + * traditional Float format can be explicitly specified using the encodingFormat() method. 
+ */ + @Test + @DisplayName("Test explicit Float format specification for backward compatibility") + fun testEmbeddingCreateParamsFloatCompatibility() { + val params = + EmbeddingCreateParams.builder() + .input("test input") + .model("text-embedding-ada-002") + .encodingFormat(encodingFormat = EmbeddingCreateParams.EncodingFormat.FLOAT) + .build() + + assertThat(params.encodingFormat()).describedAs("Encoding format must be set").isPresent() + assertThat(params.encodingFormat().get()) + .describedAs( + "Explicitly specified encoding format for backward compatibility must be Float" + ) + .isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT) + } + + /** + * Test EmbeddingDefaults global configuration change functionality. + * - Change default setting to Float + * - Confirm that new default setting is applied + * - Confirm that settings can be reset + */ + @Test + @DisplayName("Test EmbeddingDefaults global configuration change") + fun testEmbeddingDefaultsCanBeChanged() { + val originalDefault = EmbeddingDefaults.defaultEncodingFormat + + try { + // Change default to FLOAT + EmbeddingDefaults.setDefaultEncodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT) + assertThat(EmbeddingDefaults.defaultEncodingFormat) + .describedAs("Default setting must be changed to FLOAT") + .isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT) + + // Test that new instances use the new default + val params = + EmbeddingCreateParams.builder() + .input("test input") + .model("text-embedding-ada-002") + .build() + + // Debug information + println( + "EmbeddingDefaults.defaultEncodingFormat = ${EmbeddingDefaults.defaultEncodingFormat}" + ) + println("params.encodingFormat() = ${params.encodingFormat()}") + println("params.encodingFormat().isPresent = ${params.encodingFormat().isPresent}") + if (params.encodingFormat().isPresent) { + println("params.encodingFormat().get() = ${params.encodingFormat().get()}") + } + + assertThat(params.encodingFormat().get()) + .describedAs("New instances 
must use the changed default setting") + .isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT) + + // Test default reset functionality + EmbeddingDefaults.resetToDefaults() + assertThat(EmbeddingDefaults.defaultEncodingFormat) + .describedAs("After reset, Base64 must be returned as default") + .isEqualTo(EmbeddingCreateParams.EncodingFormat.BASE64) + } finally { + // Restore original default setting + EmbeddingDefaults.setDefaultEncodingFormat(originalDefault) + } + } + + /** + * Test EmbeddingValue validation functionality. + * - Validation failure with empty Float array + * - Validation failure with invalid Base64 string + */ + @Test + @DisplayName("Test EmbeddingValue validation functionality") + fun testEmbeddingValueValidation() { + // Test validation success with valid data + val validFloatList = listOf(1.0f, 2.0f, 3.0f) + val validEmbedding = EmbeddingValue.ofFloatList(validFloatList) + + assertThat(validEmbedding.validate()) + .describedAs("Validation with valid data must succeed") + .isNotNull() + .isEqualTo(validEmbedding) + } + + /** + * Test EmbeddingValue visitor pattern implementation. 
+ * - Visitor call for Float array case + * - Visitor call for Base64 string case + */ + @Test + @DisplayName("Test EmbeddingValue visitor pattern") + fun testEmbeddingValueVisitorPattern() { + val floatList = listOf(1.0f, 2.0f, 3.0f) + val embeddingFromFloat = EmbeddingValue.ofFloatList(floatList) + + // Visitor for Float array case + val floatResult = + embeddingFromFloat.accept( + object : EmbeddingValue.Visitor { + override fun visitFloatList(floatList: List): String = "float_visited" + + override fun visitBase64String(base64String: String): String = "base64_visited" + + override fun unknown(json: com.openai.core.JsonValue?): String = + "unknown_visited" + } + ) + + assertThat(floatResult) + .describedAs("For Float array case, visitFloatList must be called") + .isEqualTo("float_visited") + + // Visitor for Base64 case + val base64String = embeddingFromFloat.asBase64String() + val embeddingFromBase64 = EmbeddingValue.ofBase64String(base64String) + + val base64Result = + embeddingFromBase64.accept( + object : EmbeddingValue.Visitor { + override fun visitFloatList(floatList: List): String = "float_visited" + + override fun visitBase64String(base64String: String): String = "base64_visited" + + override fun unknown(json: com.openai.core.JsonValue?): String = + "unknown_visited" + } + ) + + assertThat(base64Result) + .describedAs("For Base64 string case, visitBase64String must be called") + .isEqualTo("base64_visited") + } +}