Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added suport to anthropic prompt cache #1413

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.api.AnthropicCacheType;
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
import org.springframework.ai.chat.messages.AbstractMessage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
Expand Down Expand Up @@ -401,7 +403,16 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
AbstractMessage abstractMessage = (AbstractMessage) message;
List<ContentBlock> contents;
if (abstractMessage.getCache() != null) {
AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache());
contents = new ArrayList<>(
List.of(new ContentBlock(message.getContent(), cacheType.cacheControl())));
}
else {
contents = new ArrayList<>(List.of(new ContentBlock(message.getContent())));
}
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
Expand All @@ -47,6 +44,11 @@
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;

import com.fasterxml.jackson.annotation.JsonProperty;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
* The Anthropic API client.
*
Expand Down Expand Up @@ -464,6 +466,14 @@ public ChatCompletionRequest(String model, List<AnthropicMessage> messages, Stri
this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null);
}

/**
* @param type is the cache type supported by anthropic. <a href=
* "https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations">Doc</a>
*/
@JsonInclude(Include.NON_NULL)
public record CacheControl(String type) {
}

public static ChatCompletionRequestBuilder builder() {
return new ChatCompletionRequestBuilder();
}
Expand Down Expand Up @@ -658,7 +668,10 @@ public record ContentBlock(

// tool_result response only
@JsonProperty("tool_use_id") String toolUseId,
@JsonProperty("content") String content
@JsonProperty("content") String content,

// cache object
@JsonProperty("cache_control") CacheControl cacheControl
) {
// @formatter:on

Expand All @@ -677,23 +690,27 @@ public ContentBlock(String mediaType, String data) {
* @param source The source of the content.
*/
public ContentBlock(Type type, Source source) {
this(type, source, null, null, null, null, null, null, null);
this(type, source, null, null, null, null, null, null, null, null);
}

/**
* Create content block
* @param source The source of the content.
*/
public ContentBlock(Source source) {
this(Type.IMAGE, source, null, null, null, null, null, null, null);
this(Type.IMAGE, source, null, null, null, null, null, null, null, null);
}

/**
* Create content block
* @param text The text of the content.
*/
public ContentBlock(String text) {
this(Type.TEXT, null, text, null, null, null, null, null, null);
this(Type.TEXT, null, text, null, null, null, null, null, null, null);
}

public ContentBlock(String text, CacheControl cache) {
this(Type.TEXT, null, text, null, null, null, null, null, null, cache);
}

// Tool result
Expand All @@ -704,7 +721,7 @@ public ContentBlock(String text) {
* @param content The content of the tool result.
*/
public ContentBlock(Type type, String toolUseId, String content) {
this(type, null, null, null, null, null, null, toolUseId, content);
this(type, null, null, null, null, null, null, toolUseId, content, null);
}

/**
Expand All @@ -715,7 +732,7 @@ public ContentBlock(Type type, String toolUseId, String content) {
* @param index The index of the content block.
*/
public ContentBlock(Type type, Source source, String text, Integer index) {
this(type, source, text, index, null, null, null, null, null);
this(type, source, text, index, null, null, null, null, null, null);
}

// Tool use input JSON delta streaming
Expand All @@ -727,7 +744,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) {
* @param input The input of the tool use.
*/
public ContentBlock(Type type, String id, String name, Map<String, Object> input) {
this(type, null, null, null, id, name, input, null, null);
this(type, null, null, null, id, name, input, null, null, null);
}

/**
Expand Down Expand Up @@ -886,7 +903,9 @@ public record ChatCompletionResponse(
public record Usage(
// @formatter:off
@JsonProperty("input_tokens") Integer inputTokens,
@JsonProperty("output_tokens") Integer outputTokens) {
@JsonProperty("output_tokens") Integer outputTokens,
@JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens,
@JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) {
// @formatter:off
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.springframework.ai.anthropic.api;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;

import java.util.function.Supplier;

public enum AnthropicCacheType {

EPHEMERAL(() -> new CacheControl("ephemeral"));

private Supplier<CacheControl> value;

AnthropicCacheType(Supplier<CacheControl> value) {
this.value = value;
}

public CacheControl cacheControl() {
return this.value.get();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {

if (messageDeltaEvent.usage() != null) {
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
messageDeltaEvent.usage().outputTokens());
messageDeltaEvent.usage().outputTokens(),
contentBlockReference.get().usage.cacheCreationInputTokens(),
contentBlockReference.get().usage.cacheReadInputTokens());
contentBlockReference.get().withUsage(totalUsage);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ public class AnthropicApiIT {

AnthropicApi anthropicApi = new AnthropicApi(System.getenv("ANTHROPIC_API_KEY"));

@Test
void chatWithPromptCache() {
String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which "
+ "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" "
+ "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English "
+ "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". "
+ "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\"";

AnthropicMessage chatCompletionMessage = new AnthropicMessage(
List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())),
Role.USER);

ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(
AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8,
false);
AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest)
.getBody()
.usage();

assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0);
assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0);

AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage();

assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0);
assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0);
}

@Test
void chatCompletionEntity() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,25 @@ public abstract class AbstractMessage implements Message {
*/
protected final String textContent;

protected String cache;

/**
* Additional options for the message to influence the response, not a generative map.
*/
protected final Map<String, Object> metadata;

protected AbstractMessage(MessageType messageType, String textContent, Map<String, Object> metadata, String cache) {
Assert.notNull(messageType, "Message type must not be null");
if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) {
Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages");
}
this.messageType = messageType;
this.textContent = textContent;
this.metadata = new HashMap<>(metadata);
this.metadata.put(MESSAGE_TYPE, messageType);
this.cache = cache;
}

/**
* Create a new AbstractMessage with the given message type, text content, and
* metadata.
Expand Down Expand Up @@ -93,6 +107,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map<String
this.metadata.put(MESSAGE_TYPE, messageType);
}

protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> metadata, String cache) {
Assert.notNull(resource, "Resource must not be null");
try (InputStream inputStream = resource.getInputStream()) {
this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
}
catch (IOException ex) {
throw new RuntimeException("Failed to read resource", ex);
}
this.messageType = messageType;
this.metadata = new HashMap<>(metadata);
this.metadata.put(MESSAGE_TYPE, messageType);
this.cache = cache;
}

/**
* Get the content of the message.
* @return the content of the message
Expand Down Expand Up @@ -125,6 +153,10 @@ public MessageType getMessageType() {
return this.messageType;
}

public String getCache() {
return cache;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public class UserMessage extends AbstractMessage implements MediaContent {

protected final List<Media> media;

public UserMessage(String textContent, String cache) {
this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache);
}

public UserMessage(String textContent) {
this(MessageType.USER, textContent, new ArrayList<>(), Map.of());
}
Expand All @@ -45,6 +49,11 @@ public UserMessage(Resource resource) {
this.media = new ArrayList<>();
}

public UserMessage(Resource resource, String cache) {
super(MessageType.USER, resource, Map.of(), cache);
this.media = new ArrayList<>();
}

public UserMessage(String textContent, List<Media> media) {
this(MessageType.USER, textContent, media, Map.of());
}
Expand All @@ -64,6 +73,17 @@ public UserMessage(MessageType messageType, String textContent, Collection<Media
this.media = new ArrayList<>(media);
}

public UserMessage(MessageType messageType, String textContent, Collection<Media> media,
Map<String, Object> metadata, String cache) {
super(messageType, textContent, metadata, cache);
Assert.notNull(media, "media data must not be null");
this.media = new ArrayList<>(media);
}

public List<Media> getMedia(String... dummy) {
return this.media;
}

@Override
public String toString() {
return "UserMessage{" + "content='" + getText() + '\'' + ", properties=" + this.metadata + ", messageType="
Expand All @@ -80,4 +100,9 @@ public String getText() {
return this.textContent;
}

@Override
public String getCache() {
return super.getCache();
}

}
Loading