diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index adcb897ad89..f2db871a300 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -34,8 +34,10 @@ import reactor.core.publisher.Mono; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; +import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.HttpHeaders; @@ -96,6 +98,8 @@ public static Builder builder() { private final WebClient webClient; + private final ApiKey apiKey; + /** * Create a new client api. * @param baseUrl api base URL. @@ -107,18 +111,18 @@ public static Builder builder() { * @param responseErrorHandler Response error handler. * @param anthropicBetaFeatures Anthropic beta features. */ - private AnthropicApi(String baseUrl, String completionsPath, String anthropicApiKey, String anthropicVersion, + private AnthropicApi(String baseUrl, String completionsPath, ApiKey anthropicApiKey, String anthropicVersion, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler, String anthropicBetaFeatures) { Consumer jsonContentHeaders = headers -> { - headers.add(HEADER_X_API_KEY, anthropicApiKey); headers.add(HEADER_ANTHROPIC_VERSION, anthropicVersion); headers.add(HEADER_ANTHROPIC_BETA, anthropicBetaFeatures); headers.setContentType(MediaType.APPLICATION_JSON); }; this.completionsPath = completionsPath; + this.apiKey = anthropicApiKey; this.restClient = restClientBuilder.clone() .baseUrl(baseUrl) @@ -160,12 +164,17 @@ public ResponseEntity chatCompletionEntity(ChatCompletio Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); + // @formatter:off return this.restClient.post() .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) .body(chatRequest) .retrieve() .toEntity(ChatCompletionResponse.class); + // @formatter:on } /** @@ -196,9 +205,13 @@ public Flux chatCompletionStream(ChatCompletionRequest c AtomicReference chatCompletionReference = new AtomicReference<>(); + // @formatter:off return this.webClient.post() .uri(this.completionsPath) - .headers(headers -> headers.addAll(additionalHttpHeader)) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:off .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) @@ -232,6 +245,15 @@ public Flux chatCompletionStream(ChatCompletionRequest c .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); } + private void addDefaultHeadersIfMissing(HttpHeaders headers) { + if (!headers.containsKey(HEADER_X_API_KEY)) { + String apiKeyValue = this.apiKey.getValue(); + if (StringUtils.hasText(apiKeyValue)) { + headers.add(HEADER_X_API_KEY, apiKeyValue); + } + } + } + /** * Check the Models * overview and AnthropicApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> AnthropicApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> AnthropicApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidCompletionsPath() { + assertThatThrownBy(() -> AnthropicApi.builder().completionsPath("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("completionsPath cannot be null or empty"); + + assertThatThrownBy(() -> AnthropicApi.builder().completionsPath(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("completionsPath cannot be null or empty"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> AnthropicApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidWebClientBuilder() { + assertThatThrownBy(() -> AnthropicApi.builder().webClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("webClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> AnthropicApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + AnthropicApi api = AnthropicApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-opus-3-latest", + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 25, + "output_tokens": 1 + } + } + """); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + AnthropicApi.AnthropicMessage chatCompletionMessage = new AnthropicApi.AnthropicMessage( + List.of(new AnthropicApi.ContentBlock("Hello world")), AnthropicApi.Role.USER); + AnthropicApi.ChatCompletionRequest request = AnthropicApi.ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS) + .temperature(0.8) + .messages(List.of(chatCompletionMessage)) + .build(); + ResponseEntity response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("x-api-key")).isEqualTo("key1"); + + response = api.chatCompletionEntity(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("x-api-key")).isEqualTo("key2"); + } + + @Test + void dynamicApiKeyRestClientWithAdditionalApiKeyHeader() throws InterruptedException { + AnthropicApi api = AnthropicApi.builder().apiKey(() -> { + throw new AssertionFailedError("Should not be called, API key is provided in headers"); + }).baseUrl(mockWebServer.url("/").toString()).build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-opus-3-latest", + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 25, + "output_tokens": 1 + } + } + """); + mockWebServer.enqueue(mockResponse); + + AnthropicApi.AnthropicMessage chatCompletionMessage = new AnthropicApi.AnthropicMessage( + List.of(new AnthropicApi.ContentBlock("Hello world")), AnthropicApi.Role.USER); + AnthropicApi.ChatCompletionRequest request = AnthropicApi.ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS) + .temperature(0.8) + .messages(List.of(chatCompletionMessage)) + .build(); + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add("x-api-key", "additional-key"); + ResponseEntity response = api.chatCompletionEntity(request, + additionalHeaders); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("x-api-key")).isEqualTo("additional-key"); + } + + @Test + void dynamicApiKeyWebClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + AnthropicApi api = AnthropicApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE) + .setBody(""" + { + "type": "message_start", + "message": { + "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-opus-4-20250514", + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 25, + "output_tokens": 1 + } + } + } + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + mockWebServer.enqueue(mockResponse); + + AnthropicApi.AnthropicMessage chatCompletionMessage = new AnthropicApi.AnthropicMessage( + List.of(new AnthropicApi.ContentBlock("Hello world")), AnthropicApi.Role.USER); + AnthropicApi.ChatCompletionRequest request = AnthropicApi.ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS) + .temperature(0.8) + .messages(List.of(chatCompletionMessage)) + .stream(true) + .build(); + api.chatCompletionStream(request).collectList().block(); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("x-api-key")).isEqualTo("key1"); + + api.chatCompletionStream(request).collectList().block(); + + recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("x-api-key")).isEqualTo("key2"); + } + + @Test + void dynamicApiKeyWebClientWithAdditionalApiKey() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + AnthropicApi api = AnthropicApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_EVENT_STREAM_VALUE) + .setBody(""" + { + "type": "message_start", + "message": { + "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", + "type": "message", + "role": "assistant", + "content": [], + "model": "claude-opus-4-20250514", + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 25, + "output_tokens": 1 + } + } + } + """.replace("\n", "")); + mockWebServer.enqueue(mockResponse); + + AnthropicApi.AnthropicMessage chatCompletionMessage = new AnthropicApi.AnthropicMessage( + List.of(new AnthropicApi.ContentBlock("Hello world")), AnthropicApi.Role.USER); + AnthropicApi.ChatCompletionRequest request = AnthropicApi.ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS) + .temperature(0.8) + .messages(List.of(chatCompletionMessage)) + .stream(true) + .build(); + MultiValueMap additionalHeaders = new LinkedMultiValueMap<>(); + additionalHeaders.add("x-api-key", "additional-key"); + + api.chatCompletionStream(request, additionalHeaders).collectList().block(); + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("x-api-key")).isEqualTo("additional-key"); + } + + } + +}