Skip to content

Commit

Permalink
Merge pull request #80 from alan-turing-institute/anthropic-api-test
Browse files Browse the repository at this point in the history
Add Anthropic API tests
  • Loading branch information
rchan26 authored Aug 6, 2024
2 parents 3febf93 + 4f9bfb1 commit d9a50d0
Show file tree
Hide file tree
Showing 27 changed files with 1,858 additions and 278 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ jobs:

- name: Upload coverage report
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}

docs:
needs: [pre-commit, checks]
Expand Down
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
[![Actions Status][actions-badge]][actions-link]
[![Codecov Status][codecov-badge]][codecov-link]
[![PyPI version][pypi-version]][pypi-link]
[![PyPI platforms][pypi-platforms]][pypi-link]
![GitHub](https://img.shields.io/github/license/alan-turing-institute/prompto)

<!-- [![GitHub Discussion][github-discussions-badge]][github-discussions-link]
[![Gitter][gitter-badge]][gitter-link] -->

<!-- prettier-ignore-start -->
[actions-badge]: https://github.com/alan-turing-institute/prompto/workflows/CI/badge.svg
[actions-link]: https://github.com/alan-turing-institute/prompto/actions
[codecov-badge]: https://codecov.io/gh/alan-turing-institute/prompto/branch/main/graph/badge.svg?token=SU9HZ9NH70
[codecov-link]: https://codecov.io/gh/alan-turing-institute/prompto
[github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
[github-discussions-link]: https://github.com/alan-turing-institute/prompto/discussions
[pypi-link]: https://pypi.org/project/prompto/
[pypi-platforms]: https://img.shields.io/pypi/pyversions/prompto
[pypi-version]: https://img.shields.io/pypi/v/prompto
<!-- prettier-ignore-end -->


# prompto

`prompto` is a Python library which facilitates processing of experiments of Large Language Models (LLMs) stored as jsonl files. It automates _asynchronous querying of LLM API endpoints_ and logs progress.
Expand Down
64 changes: 36 additions & 28 deletions src/prompto/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ class DependencyWarning(Warning):
from prompto.apis.azure_openai import AzureOpenAIAPI

ASYNC_APIS["azure-openai"] = AzureOpenAIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Azure OpenAI model ('azure-openai') not available. Perhaps you need to install the Azure OpenAI dependencies: {exc}. "
"Azure OpenAI model ('azure-openai') not available. "
"Perhaps you need to install the Azure OpenAI dependencies: "
"Try `pip install prompto[azure_openai]`"
),
category=DependencyWarning,
Expand All @@ -35,23 +36,40 @@ class DependencyWarning(Warning):
from prompto.apis.openai import OpenAIAPI

ASYNC_APIS["openai"] = OpenAIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"OpenAI model ('openai') not available. Perhaps you need to install the OpenAI dependencies: {exc}. "
"OpenAI model ('openai') not available. "
"Perhaps you need to install the OpenAI dependencies: "
"Try `pip install prompto[openai]`"
),
category=DependencyWarning,
)

try:
from prompto.apis.anthropic import AnthropicAPI

ASYNC_APIS["anthropic"] = AnthropicAPI
except ImportError:
warnings.warn(
message=(
"Anthropic API ('anthropic') not available. "
"Perhaps you need to install the Anthropic dependencies: "
"Try `pip install prompto[anthropic]`"
),
category=DependencyWarning,
)


try:
from prompto.apis.gemini import GeminiAPI

ASYNC_APIS["gemini"] = GeminiAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Gemini API ('gemini') not available. Perhaps you need to install the Gemini dependencies: {exc}. "
"Gemini API ('gemini') not available. "
"Perhaps you need to install the Gemini dependencies: "
"Try `pip install prompto[gemini]`"
),
category=DependencyWarning,
Expand All @@ -61,10 +79,11 @@ class DependencyWarning(Warning):
from prompto.apis.vertexai import VertexAIAPI

ASYNC_APIS["vertexai"] = VertexAIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Vertex AI API ('vertexai') not available. Perhaps you need to install the Vertex AI dependencies: {exc}. "
"Vertex AI API ('vertexai') not available. "
"Perhaps you need to install the Vertex AI dependencies: "
"Try `pip install prompto[vertexai]`"
),
category=DependencyWarning,
Expand All @@ -74,10 +93,11 @@ class DependencyWarning(Warning):
from prompto.apis.ollama import OllamaAPI

ASYNC_APIS["ollama"] = OllamaAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Ollama API ('ollama') not available. Perhaps you need to install the Ollama dependencies: {exc}. "
"Ollama API ('ollama') not available. "
"Perhaps you need to install the Ollama dependencies: "
"Try `pip install prompto[ollama]`"
),
category=DependencyWarning,
Expand All @@ -87,10 +107,11 @@ class DependencyWarning(Warning):
from prompto.apis.huggingface_tgi import HuggingfaceTGIAPI

ASYNC_APIS["huggingface-tgi"] = HuggingfaceTGIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Huggingface TGI API ('huggingface-tgi') not available. Perhaps you need to install the Huggingface TGI dependencies: {exc}. "
"Huggingface TGI API ('huggingface-tgi') not available. "
"Perhaps you need to install the Huggingface TGI dependencies: "
"Try `pip install prompto[huggingface_tgi]`"
),
category=DependencyWarning,
Expand All @@ -100,27 +121,14 @@ class DependencyWarning(Warning):
from prompto.apis.quart import QuartAPI

ASYNC_APIS["quart"] = QuartAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Quart API ('quart') not available. Perhaps you need to install the Quart dependencies: {exc}. "
"Quart API ('quart') not available. "
"Perhaps you need to install the Quart dependencies: "
"Try `pip install prompto[quart]`"
),
category=DependencyWarning,
)

try:
from prompto.apis.anthropic import AnthropicAPI

ASYNC_APIS["anthropic"] = AnthropicAPI
except ImportError as exc:
warnings.warn(
message=(
f"Anthropic API ('anthropic') not available. Perhaps you need to install the Anthropic dependencies: {exc}. "
"Try `pip install prompto[anthropic]`"
),
category=DependencyWarning,
)


__all__ = ["ASYNC_APIS"]
91 changes: 48 additions & 43 deletions src/prompto/apis/anthropic/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,16 @@
write_log_message,
)

# set names of environment variables
API_KEY_VAR_NAME = "ANTHROPIC_API_KEY"

TYPE_ERROR = TypeError(
"if api == 'anthropic', then the prompt must be a str, list[str], or "
"list[dict[str,str]] where the dictionary contains the keys 'role' and "
"'content' only, and the values for 'role' must be one of 'user' or 'model', "
"except for the first message in the list of dictionaries can be a "
"system message with the key 'role' set to 'system'."
)


class AnthropicAPI(AsyncAPI):
"""
Expand All @@ -46,7 +53,6 @@ def __init__(
*args: Any,
**kwargs: Any,
):

super().__init__(settings=settings, log_file=log_file, *args, **kwargs)
self.api_type = "anthropic"

Expand Down Expand Up @@ -106,25 +112,26 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]:
elif isinstance(prompt_dict["prompt"], list):
if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
pass
if all(
isinstance(message, dict) for message in prompt_dict["prompt"]
) and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"]
]
elif (
all(isinstance(message, dict) for message in prompt_dict["prompt"])
and (
set(prompt_dict["prompt"][0].keys()) == {"role", "content"}
and prompt_dict["prompt"][0]["role"]
in list(anthropic_chat_roles) + ["system"]
)
and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"][1:]
]
)
):
pass
else:
issues.append(TYPE_ERROR)
else:
issues.append(
TypeError(
"if api == 'anthropic', then the prompt must be a str, list[str], or "
"list[dict[str,str]] where the dictionary contains the keys 'role' and "
"'content' only, and the values for 'role' must be one of 'system', 'user' or "
"'assistant'"
)
)
issues.append(TYPE_ERROR)

# use the model specific environment variables if they exist
model_name = prompt_dict["model_name"]
Expand All @@ -145,7 +152,7 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]:

async def _obtain_model_inputs(
self, prompt_dict: dict
) -> tuple[str, str, AsyncAnthropic, dict, str]:
) -> tuple[str, str, AsyncAnthropic, dict]:
"""
Async method to obtain the model inputs from the prompt dictionary.
Expand All @@ -156,7 +163,7 @@ async def _obtain_model_inputs(
Returns
-------
tuple[str, str, AsyncAnthropic, dict, str]
tuple[str, str, AsyncAnthropic, dict]
A tuple containing the prompt, model name, AsyncAnthropic client object,
the generation config, and mode to use for querying the model
"""
Expand Down Expand Up @@ -275,7 +282,6 @@ async def _query_chat(self, prompt_dict: dict, index: int | str) -> dict:

prompt_dict["response"] = response_list
return prompt_dict

except Exception as err:
error_as_string = f"{type(err).__name__} - {err}"
log_message = log_error_response_chat(
Expand Down Expand Up @@ -312,31 +318,32 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict:
prompt_dict
)

# Remove the "system" role from the prompt and add it to the system parameter
# pop the "system" role from the prompt
system = [
message_dict["content"]
for message_dict in prompt
if message_dict["role"] == "system"
]

prompt = [
# remove the system messages from prompt
messages = [
message_dict for message_dict in prompt if message_dict["role"] != "system"
]

# If system message is present, then it must be the only one
# if system message is present, then it must be the only one
if len(system) == 0:
system = None
elif len(system) == 1:
system = system[0]
else:
raise ValueError(
f"There are {len(system)} system messages. Only one system message is supported."
f"There are {len(system)} system messages. Only one system message is supported"
)

try:
response = await client.messages.create(
model=model_name,
messages=prompt,
messages=messages,
system=system,
**generation_config,
)
Expand Down Expand Up @@ -392,37 +399,35 @@ async def query(self, prompt_dict: dict, index: int | str = "NA") -> dict:
Exception
If an error occurs during the querying process
"""
# If prompt is a single string, then use query string method
if isinstance(prompt_dict["prompt"], str):
return await self._query_string(
prompt_dict=prompt_dict,
index=index,
)
elif isinstance(prompt_dict["prompt"], list):
# If prompt is a list of strings, then use query chat method
if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
return await self._query_chat(
prompt_dict=prompt_dict,
index=index,
)
# If prompt is a list of dictionaries, then use query history method
if all(
isinstance(message, dict) for message in prompt_dict["prompt"]
) and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"]
]
elif (
all(isinstance(message, dict) for message in prompt_dict["prompt"])
and (
set(prompt_dict["prompt"][0].keys()) == {"role", "content"}
and prompt_dict["prompt"][0]["role"]
in list(anthropic_chat_roles) + ["system"]
)
and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"][1:]
]
)
):
return await self._query_history(
prompt_dict=prompt_dict,
index=index,
)

raise TypeError(
"if api == 'anthropic', then the prompt must be a str, list[str], or "
"list[dict[str,str]] where the dictionary contains the keys 'role' and "
"'content' only, and the values for 'role' must be one of 'system', 'user' or "
"'assistant'"
)
raise TYPE_ERROR
2 changes: 1 addition & 1 deletion src/prompto/apis/anthropic/anthropic_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from anthropic.types.message import Message

anthropic_chat_roles = set(["user", "assistant", "system"])
anthropic_chat_roles = set(["user", "assistant"])


def process_response(response: Message) -> str | list[str]:
Expand Down
Loading

0 comments on commit d9a50d0

Please sign in to comment.