From 0793f9c1f11304feacc99b21c81d4fa1ef45763e Mon Sep 17 00:00:00 2001 From: Gareth Evans Date: Thu, 14 Aug 2025 10:45:53 +0100 Subject: [PATCH] feat(google genai): support sending labels with chat request Signed-off-by: Gareth Evans --- .../ai/google/genai/GoogleGenAiChatModel.java | 6 +++ .../google/genai/GoogleGenAiChatOptions.java | 26 +++++++++++-- .../genai/GoogleGenAiChatOptionsTest.java | 39 ++++++++++++++++++- 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 6630b4227f9..d668fb11809 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -478,6 +478,8 @@ Prompt buildRequestPrompt(Prompt prompt) { runtimeOptions.getGoogleSearchRetrieval(), this.defaultOptions.getGoogleSearchRetrieval())); requestOptions.setSafetySettings(ModelOptionsUtils.mergeOption(runtimeOptions.getSafetySettings(), this.defaultOptions.getSafetySettings())); + requestOptions + .setLabels(ModelOptionsUtils.mergeOption(runtimeOptions.getLabels(), this.defaultOptions.getLabels())); } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); @@ -487,6 +489,7 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setGoogleSearchRetrieval(this.defaultOptions.getGoogleSearchRetrieval()); requestOptions.setSafetySettings(this.defaultOptions.getSafetySettings()); + requestOptions.setLabels(this.defaultOptions.getLabels()); } ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); @@ -677,6 +680,9 @@ GeminiRequest createGeminiRequest(Prompt prompt) { configBuilder .thinkingConfig(ThinkingConfig.builder().thinkingBudget(requestOptions.getThinkingBudget()).build()); } + if (requestOptions.getLabels() != null && !requestOptions.getLabels().isEmpty()) { + configBuilder.labels(requestOptions.getLabels()); + } // Add safety settings if (!CollectionUtils.isEmpty(requestOptions.getSafetySettings())) { diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java index 9254cbec4b6..0f69562262a 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java @@ -144,6 +144,9 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private List safetySettings = new ArrayList<>(); + + @JsonIgnore + private Map labels = new HashMap<>(); // @formatter:on public static Builder builder() { @@ -170,6 +173,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); options.setToolContext(fromOptions.getToolContext()); options.setThinkingBudget(fromOptions.getThinkingBudget()); + options.setLabels(fromOptions.getLabels()); return options; } @@ -332,6 +336,15 @@ public void setSafetySettings(List safetySettings) { this.safetySettings = safetySettings; } + public Map getLabels() { + return this.labels; + } + + public void setLabels(Map labels) { + Assert.notNull(labels, "labels must not be null"); + this.labels = labels; + } + @Override public Map getToolContext() { return this.toolContext; @@ -363,7 +376,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) - && Objects.equals(this.toolContext, that.toolContext); + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels); } @Override @@ -371,7 +384,7 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, - this.safetySettings, this.internalToolExecutionEnabled, this.toolContext); + this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels); } @Override @@ -382,7 +395,8 @@ public String toString() { + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" - + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}'; + + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels + + '}'; } @Override @@ -510,6 +524,12 @@ public Builder thinkingBudget(Integer thinkingBudget) { return this; } + public Builder labels(Map labels) { + Assert.notNull(labels, "labels must not be null"); + this.options.labels = labels; + return this; + } + public GoogleGenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java index d5051f8ec39..0636bff2bf6 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java @@ -18,6 +18,8 @@ import org.junit.jupiter.api.Test; +import java.util.Map; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -104,6 +106,29 @@ public void testEqualsAndHashCodeWithThinkingBudget() { assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); } + @Test + public void testEqualsAndHashCodeWithLabels() { + GoogleGenAiChatOptions options1 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + GoogleGenAiChatOptions options2 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + GoogleGenAiChatOptions options3 = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "other-org")) + .build(); + + assertThat(options1).isEqualTo(options2); + assertThat(options1.hashCode()).isEqualTo(options2.hashCode()); + assertThat(options1).isNotEqualTo(options3); + assertThat(options1.hashCode()).isNotEqualTo(options3.hashCode()); + } + @Test public void testToStringWithThinkingBudget() { GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() @@ -116,4 +141,16 @@ public void testToStringWithThinkingBudget() { assertThat(toString).contains("test-model"); } -} \ No newline at end of file + @Test + public void testToStringWithLabels() { + GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder() + .model("test-model") + .labels(Map.of("org", "my-org")) + .build(); + + String toString = options.toString(); + assertThat(toString).contains("labels={org=my-org}"); + assertThat(toString).contains("test-model"); + } + +}