
Add Anthropic API tests #80

Merged · 24 commits · Aug 6, 2024
0fac0b9  testing anthopic (EdwinB12, Jul 31, 2024)
7c66d47  test get_environment_variable function (rchan26, Aug 1, 2024)
7f2a9eb  test check_max_queries_dict (rchan26, Aug 1, 2024)
4995585  update get_environment_variable log (rchan26, Aug 1, 2024)
bf76d20  anthropic tests with mocking (rchan26, Aug 1, 2024)
a76a10e  fix small errors in anthropic (rchan26, Aug 1, 2024)
e547729  refactor anthropic tests (rchan26, Aug 1, 2024)
bae6efd  fix the type error checks for prompts (rchan26, Aug 1, 2024)
c178145  continue anthropic api tests (rchan26, Aug 1, 2024)
38d3cb7  enforce system message being the first for gemini, vertexai and anthr… (rchan26, Aug 1, 2024)
d8a355b  history query anthropic tests and system message enforce (rchan26, Aug 1, 2024)
5557486  tidy up query if/else (rchan26, Aug 1, 2024)
5eaa4e8  convert gemini_chat_roles to list first before adding (rchan26, Aug 1, 2024)
6ca835b  anthropic chat input tests (rchan26, Aug 2, 2024)
e59e220  remove the mimimal example when dev (rchan26, Aug 2, 2024)
a685151  add some general query tests for anthropic (rchan26, Aug 2, 2024)
9d491e3  get anthropic coverage to 100 (rchan26, Aug 2, 2024)
80f769f  tidy up (rchan26, Aug 2, 2024)
f06a60a  add tests for creating ASYNC_APIS and importing (rchan26, Aug 2, 2024)
24c4b5a  uncover error test (rchan26, Aug 2, 2024)
66f3590  Merge branch 'anthropic-api-test' into mock-api-imports (rchan26, Aug 2, 2024)
ea175ee  add badges to readme (rchan26, Aug 2, 2024)
43eb5e7  configure codecov (rchan26, Aug 5, 2024)
4f9bfb1  Merge pull request #81 from alan-turing-institute/mock-api-imports (EdwinB12, Aug 6, 2024)
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
@@ -57,6 +57,8 @@ jobs:

- name: Upload coverage report
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}

docs:
needs: [pre-commit, checks]
22 changes: 22 additions & 0 deletions README.md
@@ -1,3 +1,25 @@
[![Actions Status][actions-badge]][actions-link]
[![Codecov Status][codecov-badge]][codecov-link]
[![PyPI version][pypi-version]][pypi-link]
[![PyPI platforms][pypi-platforms]][pypi-link]
![GitHub](https://img.shields.io/github/license/alan-turing-institute/prompto)

<!-- [![GitHub Discussion][github-discussions-badge]][github-discussions-link]
[![Gitter][gitter-badge]][gitter-link] -->

<!-- prettier-ignore-start -->
[actions-badge]: https://github.com/alan-turing-institute/prompto/workflows/CI/badge.svg
[actions-link]: https://github.com/alan-turing-institute/prompto/actions
[codecov-badge]: https://codecov.io/gh/alan-turing-institute/prompto/branch/main/graph/badge.svg?token=SU9HZ9NH70
[codecov-link]: https://codecov.io/gh/alan-turing-institute/prompto
[github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github
[github-discussions-link]: https://github.com/alan-turing-institute/prompto/discussions
[pypi-link]: https://pypi.org/project/prompto/
[pypi-platforms]: https://img.shields.io/pypi/pyversions/prompto
[pypi-version]: https://img.shields.io/pypi/v/prompto
<!-- prettier-ignore-end -->


# prompto

`prompto` is a Python library which facilitates processing of experiments of Large Language Models (LLMs) stored as jsonl files. It automates _asynchronous querying of LLM API endpoints_ and logs progress.
64 changes: 36 additions & 28 deletions src/prompto/apis/__init__.py
@@ -22,10 +22,11 @@ class DependencyWarning(Warning):
from prompto.apis.azure_openai import AzureOpenAIAPI

ASYNC_APIS["azure-openai"] = AzureOpenAIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Azure OpenAI model ('azure-openai') not available. Perhaps you need to install the Azure OpenAI dependencies: {exc}. "
"Azure OpenAI model ('azure-openai') not available. "
"Perhaps you need to install the Azure OpenAI dependencies: "
"Try `pip install prompto[azure_openai]`"
),
category=DependencyWarning,
@@ -35,23 +36,40 @@ class DependencyWarning(Warning):
from prompto.apis.openai import OpenAIAPI

ASYNC_APIS["openai"] = OpenAIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"OpenAI model ('openai') not available. Perhaps you need to install the OpenAI dependencies: {exc}. "
"OpenAI model ('openai') not available. "
"Perhaps you need to install the OpenAI dependencies: "
"Try `pip install prompto[openai]`"
),
category=DependencyWarning,
)

try:
from prompto.apis.anthropic import AnthropicAPI

ASYNC_APIS["anthropic"] = AnthropicAPI
except ImportError:
warnings.warn(
message=(
"Anthropic API ('anthropic') not available. "
"Perhaps you need to install the Anthropic dependencies: "
"Try `pip install prompto[anthropic]`"
),
category=DependencyWarning,
)


try:
from prompto.apis.gemini import GeminiAPI

ASYNC_APIS["gemini"] = GeminiAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Gemini API ('gemini') not available. Perhaps you need to install the Gemini dependencies: {exc}. "
"Gemini API ('gemini') not available. "
"Perhaps you need to install the Gemini dependencies: "
"Try `pip install prompto[gemini]`"
),
category=DependencyWarning,
Expand All @@ -61,10 +79,11 @@ class DependencyWarning(Warning):
from prompto.apis.vertexai import VertexAIAPI

ASYNC_APIS["vertexai"] = VertexAIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Vertex AI API ('vertexai') not available. Perhaps you need to install the Vertex AI dependencies: {exc}. "
"Vertex AI API ('vertexai') not available. "
"Perhaps you need to install the Vertex AI dependencies: "
"Try `pip install prompto[vertexai]`"
),
category=DependencyWarning,
@@ -74,10 +93,11 @@ class DependencyWarning(Warning):
from prompto.apis.ollama import OllamaAPI

ASYNC_APIS["ollama"] = OllamaAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Ollama API ('ollama') not available. Perhaps you need to install the Ollama dependencies: {exc}. "
"Ollama API ('ollama') not available. "
"Perhaps you need to install the Ollama dependencies: "
"Try `pip install prompto[ollama]`"
),
category=DependencyWarning,
@@ -87,10 +107,11 @@ class DependencyWarning(Warning):
from prompto.apis.huggingface_tgi import HuggingfaceTGIAPI

ASYNC_APIS["huggingface-tgi"] = HuggingfaceTGIAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Huggingface TGI API ('huggingface-tgi') not available. Perhaps you need to install the Huggingface TGI dependencies: {exc}. "
"Huggingface TGI API ('huggingface-tgi') not available. "
"Perhaps you need to install the Huggingface TGI dependencies: "
"Try `pip install prompto[huggingface_tgi]`"
),
category=DependencyWarning,
@@ -100,27 +121,14 @@ class DependencyWarning(Warning):
from prompto.apis.quart import QuartAPI

ASYNC_APIS["quart"] = QuartAPI
except ImportError as exc:
except ImportError:
warnings.warn(
message=(
f"Quart API ('quart') not available. Perhaps you need to install the Quart dependencies: {exc}. "
"Quart API ('quart') not available. "
"Perhaps you need to install the Quart dependencies: "
"Try `pip install prompto[quart]`"
),
category=DependencyWarning,
)

try:
from prompto.apis.anthropic import AnthropicAPI

ASYNC_APIS["anthropic"] = AnthropicAPI
except ImportError as exc:
warnings.warn(
message=(
f"Anthropic API ('anthropic') not available. Perhaps you need to install the Anthropic dependencies: {exc}. "
"Try `pip install prompto[anthropic]`"
),
category=DependencyWarning,
)


__all__ = ["ASYNC_APIS"]
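The file above repeats one pattern per provider: import the optional API class inside `try`/`except ImportError` and downgrade a missing dependency to a `DependencyWarning` instead of a crash. A minimal standalone sketch of that pattern, using a hypothetical `register_api` helper (the helper name and signature are illustrative, not part of the PR):

```python
import warnings


class DependencyWarning(Warning):
    """Warns that an optional API integration could not be imported."""


ASYNC_APIS: dict = {}


def register_api(name: str, module_path: str, class_name: str, extra: str) -> None:
    # Import the optional API class; if its dependencies are missing,
    # record nothing and emit a DependencyWarning instead of failing.
    try:
        module = __import__(module_path, fromlist=[class_name])
        ASYNC_APIS[name] = getattr(module, class_name)
    except ImportError:
        warnings.warn(
            message=(
                f"{name} API ('{name}') not available. "
                f"Perhaps you need to install its dependencies: "
                f"Try `pip install prompto[{extra}]`"
            ),
            category=DependencyWarning,
        )
```

Registering an API whose dependencies are absent leaves `ASYNC_APIS` unchanged and surfaces a single actionable warning, which matches the diff's move away from interpolating the raw `{exc}` into the message.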
91 changes: 48 additions & 43 deletions src/prompto/apis/anthropic/anthropic.py
@@ -22,9 +22,16 @@
write_log_message,
)

# set names of environment variables
API_KEY_VAR_NAME = "ANTHROPIC_API_KEY"

TYPE_ERROR = TypeError(
"if api == 'anthropic', then the prompt must be a str, list[str], or "
"list[dict[str,str]] where the dictionary contains the keys 'role' and "
"'content' only, and the values for 'role' must be one of 'user' or 'model', "
"except for the first message in the list of dictionaries can be a "
"system message with the key 'role' set to 'system'."
)


class AnthropicAPI(AsyncAPI):
"""
@@ -46,7 +53,6 @@ def __init__(
*args: Any,
**kwargs: Any,
):

super().__init__(settings=settings, log_file=log_file, *args, **kwargs)
self.api_type = "anthropic"

@@ -106,25 +112,26 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]:
elif isinstance(prompt_dict["prompt"], list):
if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
pass
if all(
isinstance(message, dict) for message in prompt_dict["prompt"]
) and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"]
]
elif (
all(isinstance(message, dict) for message in prompt_dict["prompt"])
and (
set(prompt_dict["prompt"][0].keys()) == {"role", "content"}
and prompt_dict["prompt"][0]["role"]
in list(anthropic_chat_roles) + ["system"]
)
and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"][1:]
]
)
):
pass
else:
issues.append(TYPE_ERROR)
else:
issues.append(
TypeError(
"if api == 'anthropic', then the prompt must be a str, list[str], or "
"list[dict[str,str]] where the dictionary contains the keys 'role' and "
"'content' only, and the values for 'role' must be one of 'system', 'user' or "
"'assistant'"
)
)
issues.append(TYPE_ERROR)

# use the model specific environment variables if they exist
model_name = prompt_dict["model_name"]
@@ -145,7 +152,7 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]:

async def _obtain_model_inputs(
self, prompt_dict: dict
) -> tuple[str, str, AsyncAnthropic, dict, str]:
) -> tuple[str, str, AsyncAnthropic, dict]:
"""
Async method to obtain the model inputs from the prompt dictionary.

Expand All @@ -156,7 +163,7 @@ async def _obtain_model_inputs(

Returns
-------
tuple[str, str, AsyncAnthropic, dict, str]
tuple[str, str, AsyncAnthropic, dict]
A tuple containing the prompt, model name, AsyncAnthropic client object,
and the generation config to use for querying the model
"""
@@ -275,7 +282,6 @@ async def _query_chat(self, prompt_dict: dict, index: int | str) -> dict:

prompt_dict["response"] = response_list
return prompt_dict

except Exception as err:
error_as_string = f"{type(err).__name__} - {err}"
log_message = log_error_response_chat(
@@ -312,31 +318,32 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict:
prompt_dict
)

# Remove the "system" role from the prompt and add it to the system parameter
# pop the "system" role from the prompt
system = [
message_dict["content"]
for message_dict in prompt
if message_dict["role"] == "system"
]

prompt = [
# remove the system messages from prompt
messages = [
message_dict for message_dict in prompt if message_dict["role"] != "system"
]

# If system message is present, then it must be the only one
# if system message is present, then it must be the only one
if len(system) == 0:
system = None
elif len(system) == 1:
system = system[0]
else:
raise ValueError(
f"There are {len(system)} system messages. Only one system message is supported."
f"There are {len(system)} system messages. Only one system message is supported"
)

try:
response = await client.messages.create(
model=model_name,
messages=prompt,
messages=messages,
system=system,
**generation_config,
)
@@ -392,37 +399,35 @@ async def query(self, prompt_dict: dict, index: int | str = "NA") -> dict:
Exception
If an error occurs during the querying process
"""
# If prompt is a single string, then use query string method
if isinstance(prompt_dict["prompt"], str):
return await self._query_string(
prompt_dict=prompt_dict,
index=index,
)
elif isinstance(prompt_dict["prompt"], list):
# If prompt is a list of strings, then use query chat method
if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
return await self._query_chat(
prompt_dict=prompt_dict,
index=index,
)
# If prompt is a list of dictionaries, then use query history method
if all(
isinstance(message, dict) for message in prompt_dict["prompt"]
) and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"]
]
elif (
all(isinstance(message, dict) for message in prompt_dict["prompt"])
and (
set(prompt_dict["prompt"][0].keys()) == {"role", "content"}
and prompt_dict["prompt"][0]["role"]
in list(anthropic_chat_roles) + ["system"]
)
and all(
[
set(d.keys()) == {"role", "content"}
and d["role"] in anthropic_chat_roles
for d in prompt_dict["prompt"][1:]
]
)
):
return await self._query_history(
prompt_dict=prompt_dict,
index=index,
)

raise TypeError(
"if api == 'anthropic', then the prompt must be a str, list[str], or "
"list[dict[str,str]] where the dictionary contains the keys 'role' and "
"'content' only, and the values for 'role' must be one of 'system', 'user' or "
"'assistant'"
)
raise TYPE_ERROR
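The refactored checks in `check_prompt_dict` and `query` both encode the same rule: a list-of-dicts prompt must use exactly the keys `role` and `content` in every entry, only `user`/`assistant` roles are allowed, and only the first entry may instead carry `role: "system"`. A minimal standalone sketch of that rule (the function name `is_valid_history` is illustrative, not part of the PR):

```python
# chat roles accepted for non-system turns, per anthropic_utils in this PR
anthropic_chat_roles = {"user", "assistant"}


def is_valid_history(prompt: list) -> bool:
    """Validate a list-of-dicts prompt: every entry needs exactly the keys
    'role' and 'content', and only the first entry may use role 'system'."""
    if not prompt or not all(isinstance(m, dict) for m in prompt):
        return False
    # the first message may be a system message or a normal chat turn
    first_ok = (
        set(prompt[0].keys()) == {"role", "content"}
        and prompt[0]["role"] in anthropic_chat_roles | {"system"}
    )
    # every later message must be a normal chat turn
    rest_ok = all(
        set(m.keys()) == {"role", "content"} and m["role"] in anthropic_chat_roles
        for m in prompt[1:]
    )
    return first_ok and rest_ok
```

Factoring the shared `TYPE_ERROR` message out to module level, as the diff does, keeps the validation and dispatch paths reporting identical errors.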
2 changes: 1 addition & 1 deletion src/prompto/apis/anthropic/anthropic_utils.py
@@ -1,6 +1,6 @@
from anthropic.types.message import Message

anthropic_chat_roles = set(["user", "assistant", "system"])
anthropic_chat_roles = set(["user", "assistant"])


def process_response(response: Message) -> str | list[str]: