Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
leijendary authored Jan 1, 2025
2 parents 253dd30 + 34ac319 commit 845ec9d
Show file tree
Hide file tree
Showing 680 changed files with 21,686 additions and 7,188 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Source Code Format
name: PR Check

on:
pull_request:
Expand All @@ -20,6 +20,6 @@ jobs:
distribution: 'temurin'
cache: 'maven'

- name: Source code formatting check
- name: Run tests
run: |
./mvnw spring-javaformat:validate
./mvnw test
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Please refer to the [Getting Started Guide](https://docs.spring.io/spring-ai/ref
* [Issues](https://github.com/spring-projects/spring-ai/issues)
<!-- * [Discussions](https://github.com/spring-projects/spring-ai/discussions) - Go here if you have a question, suggestion, or feedback! -->
* [Awesome Spring AI](https://github.com/danvega/awesome-spring-ai) - A curated list of awesome resources, tools, tutorials, and projects for building generative AI applications using Spring AI
* [Spring AI Examples](https://github.com/spring-projects/spring-ai-examples) contains example projects that explain specific features in more detail.

## Breaking changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ void classpathRead() {
.withNumberOfTopTextLinesToDelete(0)
.withNumberOfBottomTextLinesToDelete(3)
.withNumberOfTopPagesToSkipBeforeDelete(0)
.overrideLineSeparator("\n")
.build())
.withPagesPerDocument(1)
.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -78,6 +81,7 @@
* @author Mariusz Bernacki
* @author Thomas Vitale
* @author Claudio Silva Junior
* @author Alexandros Pappas
* @since 1.0.0
*/
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
Expand Down Expand Up @@ -124,9 +128,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
public AnthropicChatModel(AnthropicApi anthropicApi) {
this(anthropicApi,
AnthropicChatOptions.builder()
.withModel(DEFAULT_MODEL_NAME)
.withMaxTokens(DEFAULT_MAX_TOKENS)
.withTemperature(DEFAULT_TEMPERATURE)
.model(DEFAULT_MODEL_NAME)
.maxTokens(DEFAULT_MAX_TOKENS)
.temperature(DEFAULT_TEMPERATURE)
.build());
}

Expand Down Expand Up @@ -210,6 +214,10 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul

@Override
public ChatResponse call(Prompt prompt) {
return this.internalCall(prompt, null);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand All @@ -226,8 +234,14 @@ public ChatResponse call(Prompt prompt) {
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));

ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
AnthropicApi.Usage usage = completionResponse.usage();

Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(completionResponse.usage())
: new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);

ChatResponse chatResponse = toChatResponse(completionEntity.getBody(), accumulatedUsage);
observationContext.setResponse(chatResponse);

return chatResponse;
Expand All @@ -236,14 +250,18 @@ public ChatResponse call(Prompt prompt) {
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
&& this.isToolCall(response, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, response);
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}

return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.internalStream(prompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Expand All @@ -263,11 +281,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {

// @formatter:off
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
AnthropicApi.Usage usage = chatCompletionResponse.usage();
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
}

return Mono.just(chatResponse);
Expand All @@ -281,7 +302,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}

private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage usage) {

if (chatCompletion == null) {
logger.warn("Null chat completion returned");
Expand Down Expand Up @@ -327,19 +348,22 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
allGenerations.add(toolCallGeneration);
}

return new ChatResponse(allGenerations, this.from(chatCompletion));
return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
return from(result, AnthropicUsage.from(result.usage()));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
AnthropicUsage usage = AnthropicUsage.from(result.usage());
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
.withUsage(usage)
.withKeyValue("stop-reason", result.stopReason())
.withKeyValue("stop-sequence", result.stopSequence())
.withKeyValue("type", result.type())
.id(result.id())
.model(result.model())
.usage(usage)
.keyValue("stop-reason", result.stopReason())
.keyValue("stop-sequence", result.stopSequence())
.keyValue("type", result.type())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand Down Expand Up @@ -89,17 +90,17 @@ public static Builder builder() {
}

public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
return builder().withModel(fromOptions.getModel())
.withMaxTokens(fromOptions.getMaxTokens())
.withMetadata(fromOptions.getMetadata())
.withStopSequences(fromOptions.getStopSequences())
.withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withTopK(fromOptions.getTopK())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.withProxyToolCalls(fromOptions.getProxyToolCalls())
.withToolContext(fromOptions.getToolContext())
return builder().model(fromOptions.getModel())
.maxTokens(fromOptions.getMaxTokens())
.metadata(fromOptions.getMetadata())
.stopSequences(fromOptions.getStopSequences())
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.topK(fromOptions.getTopK())
.functionCallbacks(fromOptions.getFunctionCallbacks())
.functions(fromOptions.getFunctions())
.proxyToolCalls(fromOptions.getProxyToolCalls())
.toolContext(fromOptions.getToolContext())
.build();
}

Expand Down Expand Up @@ -227,69 +228,69 @@ public static class Builder {

private final AnthropicChatOptions options = new AnthropicChatOptions();

public Builder withModel(String model) {
public Builder model(String model) {
this.options.model = model;
return this;
}

public Builder withModel(AnthropicApi.ChatModel model) {
public Builder model(AnthropicApi.ChatModel model) {
this.options.model = model.getValue();
return this;
}

public Builder withMaxTokens(Integer maxTokens) {
public Builder maxTokens(Integer maxTokens) {
this.options.maxTokens = maxTokens;
return this;
}

public Builder withMetadata(ChatCompletionRequest.Metadata metadata) {
public Builder metadata(ChatCompletionRequest.Metadata metadata) {
this.options.metadata = metadata;
return this;
}

public Builder withStopSequences(List<String> stopSequences) {
public Builder stopSequences(List<String> stopSequences) {
this.options.stopSequences = stopSequences;
return this;
}

public Builder withTemperature(Double temperature) {
public Builder temperature(Double temperature) {
this.options.temperature = temperature;
return this;
}

public Builder withTopP(Double topP) {
public Builder topP(Double topP) {
this.options.topP = topP;
return this;
}

public Builder withTopK(Integer topK) {
public Builder topK(Integer topK) {
this.options.topK = topK;
return this;
}

public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
}

public Builder withFunctions(Set<String> functionNames) {
public Builder functions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
}

public Builder withFunction(String functionName) {
public Builder function(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
}

public Builder withProxyToolCalls(Boolean proxyToolCalls) {
public Builder proxyToolCalls(Boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
return this;
}

public Builder withToolContext(Map<String, Object> toolContext) {
public Builder toolContext(Map<String, Object> toolContext) {
if (this.options.toolContext == null) {
this.options.toolContext = toolContext;
}
Expand All @@ -299,6 +300,110 @@ public Builder withToolContext(Map<String, Object> toolContext) {
return this;
}

/**
* @deprecated use {@link #model(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withModel(String model) {
return model(model);
}

/**
* @deprecated use {@link #model(AnthropicApi.ChatModel)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withModel(AnthropicApi.ChatModel model) {
return model(model);
}

/**
* @deprecated use {@link #maxTokens(Integer)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withMaxTokens(Integer maxTokens) {
return maxTokens(maxTokens);
}

/**
* @deprecated use {@link #metadata(ChatCompletionRequest.Metadata)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withMetadata(ChatCompletionRequest.Metadata metadata) {
return metadata(metadata);
}

/**
* @deprecated use {@link #stopSequences(List)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withStopSequences(List<String> stopSequences) {
return stopSequences(stopSequences);
}

/**
* @deprecated use {@link #temperature(Double)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withTemperature(Double temperature) {
return temperature(temperature);
}

/**
* @deprecated use {@link #topP(Double)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withTopP(Double topP) {
return topP(topP);
}

/**
* @deprecated use {@link #topK(Integer)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withTopK(Integer topK) {
return topK(topK);
}

/**
* @deprecated use {@link #functionCallbacks(List)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
return functionCallbacks(functionCallbacks);
}

/**
* @deprecated use {@link #functions(Set)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withFunctions(Set<String> functionNames) {
return functions(functionNames);
}

/**
* @deprecated use {@link #function(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withFunction(String functionName) {
return function(functionName);
}

/**
* @deprecated use {@link #proxyToolCalls(Boolean)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withProxyToolCalls(Boolean proxyToolCalls) {
return proxyToolCalls(proxyToolCalls);
}

/**
* @deprecated use {@link #toolContext(Map)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M5")
public Builder withToolContext(Map<String, Object> toolContext) {
return toolContext(toolContext);
}

public AnthropicChatOptions build() {
return this.options;
}
Expand Down
Loading

0 comments on commit 845ec9d

Please sign in to comment.