Skip to content

Commit

Permalink
added test to validate bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Claudio-code committed Dec 24, 2024
1 parent fe3db2a commit 893a8df
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import io.micrometer.observation.Observation;
Expand Down Expand Up @@ -280,15 +281,21 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);

// @formatter:off
AtomicReference<String> toolCallId = new AtomicReference<>("");
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
AnthropicApi.Usage usage = chatCompletionResponse.usage();
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
var toolCallConversation = prompt.getInstructions();
if (toolCallId.get().equalsIgnoreCase(chatResponse.getMetadata().getId())) {
toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
} else {
toolCallId.set(chatResponse.getMetadata().getId());
}
}

return Mono.just(chatResponse);
Expand Down Expand Up @@ -493,7 +500,7 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
}).toList();
}

private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest request) {
private ChatOptions buildRequestOptions(ChatCompletionRequest request) {
return ChatOptions.builder()
.model(request.model())
.maxTokens(request.maxTokens())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,25 @@
package org.springframework.ai.anthropic;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.Appender;
import ch.qos.logback.core.AppenderBase;
import ch.qos.logback.core.read.ListAppender;
import org.junit.Rule;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.test.system.OutputCaptureExtension;
import org.springframework.boot.test.system.OutputCaptureRule;
import reactor.core.publisher.Flux;

import org.springframework.ai.anthropic.api.AnthropicApi;
Expand Down Expand Up @@ -336,10 +343,17 @@ void streamFunctionCallUsageTest() {

List<Message> messages = new ArrayList<>(List.of(userMessage));

var mockService = new MockWeatherService();

MemoryAppender appender = new MemoryAppender();
appender.setContext((LoggerContext) LoggerFactory.getILoggerFactory());
MockWeatherService.log.addAppender(appender);
appender.start();

var promptOptions = AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.function("getCurrentWeather", mockService)
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
Expand All @@ -352,9 +366,12 @@ void streamFunctionCallUsageTest() {

logger.info("Response: {}", chatResponse);
Usage usage = chatResponse.getMetadata().getUsage();
appender.stop();

assertThat(usage).isNotNull();
assertThat(appender.getLoggedEvents().size()).isEqualTo(3);
assertThat(usage.getTotalTokens()).isLessThan(4000).isGreaterThan(1800);

}

@Test
Expand Down Expand Up @@ -417,4 +434,39 @@ public AnthropicChatModel openAiChatModel(AnthropicApi api) {

}

public static class MemoryAppender extends ListAppender<ILoggingEvent> {

public void reset() {
this.list.clear();
}

public boolean contains(String string, Level level) {
return this.list.stream()
.anyMatch(event -> event.toString().contains(string) && event.getLevel().equals(level));
}

public int countEventsForLogger(String loggerName) {
return (int) this.list.stream().filter(event -> event.getLoggerName().contains(loggerName)).count();
}

public List<ILoggingEvent> search(String string) {
return this.list.stream().filter(event -> event.toString().contains(string)).collect(Collectors.toList());
}

public List<ILoggingEvent> search(String string, Level level) {
return this.list.stream()
.filter(event -> event.toString().contains(string) && event.getLevel().equals(level))
.collect(Collectors.toList());
}

public int getSize() {
return this.list.size();
}

public List<ILoggingEvent> getLoggedEvents() {
return Collections.unmodifiableList(this.list);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@

import java.util.function.Function;

import ch.qos.logback.classic.Logger;
import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import org.slf4j.LoggerFactory;

/**
* @author Christian Tzolov
*/
public class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {

public static final Logger log = (Logger) LoggerFactory.getLogger(MockWeatherService.class.getName());

@Override
public Response apply(Request request) {

log.info("Weather Request: {}", request.toString());
double temperature = 0;
if (request.location().contains("Paris")) {
temperature = 15;
Expand Down

0 comments on commit 893a8df

Please sign in to comment.