Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add mistral generator #1135

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions docs/source/garak.generators.mistral.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
garak.generators.mistral
========================

.. automodule:: garak.generators.mistral
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/generators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ For a detailed oversight into how a generator operates, see :ref:`garak.generato
garak.generators.langchain
garak.generators.langchain_serve
garak.generators.litellm
garak.generators.mistral
garak.generators.octo
garak.generators.ollama
garak.generators.openai
Expand Down
56 changes: 56 additions & 0 deletions garak/generators/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
DEFAULT_CLASS = "MistralGenerator"
import os
import backoff
from garak.generators.base import Generator
import garak._config as _config
from mistralai import Mistral, models
from garak import exception


class MistralGenerator(Generator):
    """Interface for public endpoints of models hosted on Mistral La Plateforme (console.mistral.ai).

    Expects the API key in the MISTRAL_API_KEY environment variable
    (see ``ENV_VAR``; the base ``Generator`` config machinery populates
    ``self.api_key`` from it).
    """

    generator_family_name = "mistral"
    fullname = "Mistral AI"
    supports_multiple_generations = False
    ENV_VAR = "MISTRAL_API_KEY"
    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "name": "mistral-large-latest",
    }

    def __init__(self, name="", config_root=_config):
        super().__init__(name, config_root)
        self._load_client()

    # avoid attempting to pickle the (unpicklable) API client attribute
    def __getstate__(self) -> object:
        self._clear_client()
        return dict(self.__dict__)

    # restore the client attribute after unpickling
    def __setstate__(self, d) -> None:
        self.__dict__.update(d)
        self._load_client()

    def _load_client(self):
        # self.api_key is resolved by the Generator base from ENV_VAR / config
        self.client = Mistral(api_key=self.api_key)

    def _clear_client(self):
        self.client = None

    @backoff.on_exception(backoff.fibo, models.SDKError, max_value=70)
    def _call_model(self, prompt, generations_this_call=1):
        """Send a single-turn user prompt to the chat completion endpoint.

        Retries with fibonacci backoff on SDK errors (e.g. rate limits).
        Returns a one-element list; multiple generations per call are not
        supported (``supports_multiple_generations = False``).
        """
        chat_response = self.client.chat.complete(
            model=self.name,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
        )
        return [chat_response.choices[0].message.content]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ dependencies = [
"xdg-base-dirs>=6.0.1",
"wn==0.9.5",
"ollama>=0.4.7",
"tiktoken>=0.7.0"
"tiktoken>=0.7.0",
"mistralai==1.5.2"
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ xdg-base-dirs>=6.0.1
wn==0.9.5
ollama>=0.4.7
tiktoken>=0.7.0
mistralai==1.5.2
# tests
pytest>=8.0
pytest-mock>=3.14.0
Expand Down
6 changes: 6 additions & 0 deletions tests/generators/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,9 @@ def watsonx_compat_mocks():
"""Mock responses for watsonx.ai based endpoints"""
with open(pathlib.Path(__file__).parents[0] / "watsonx.json") as mock_watsonx:
return json.load(mock_watsonx)

@pytest.fixture
def mistral_compat_mocks():
    """Mock responses for Mistral La Plateforme chat completion endpoints."""
    # fixture data lives alongside this conftest in tests/generators/mistral.json
    with open(pathlib.Path(__file__).parents[0] / "mistral.json") as mock_mistral:
        return json.load(mock_mistral)
26 changes: 26 additions & 0 deletions tests/generators/mistral.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"mistralai_generation": {
"code": 200,
"json": {
"id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
"object": "chat.completion",
"model": "mistral-small-latest",
"created": 1742909709,
"choices": [
{
"message": {
"role": "assistant",
"content": "Ceci est une génération de test. :)"
},
"finish_reason": "stop",
"index": 0
}
],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 32,
"total_tokens": 37
}
}
}
}
45 changes: 45 additions & 0 deletions tests/generators/test_mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import pytest
import httpx
from unittest.mock import patch
from garak.generators.mistral import MistralGenerator

DEFAULT_DEPLOYMENT_NAME = "mistral-small-latest"

@pytest.fixture
def set_fake_env(request) -> None:
    """Point MISTRAL_API_KEY at a harmless fake value for the test's duration.

    The value (this file's own path) is never a valid key, guaranteeing the
    test cannot reach the real API. The original value — or absence — is
    restored when the test finishes.
    """
    stored_env = os.getenv(MistralGenerator.ENV_VAR, None)

    def restore_env():
        if stored_env is not None:
            os.environ[MistralGenerator.ENV_VAR] = stored_env
        else:
            # pop, not del: the test body may itself have unset the variable
            os.environ.pop(MistralGenerator.ENV_VAR, None)

    os.environ[MistralGenerator.ENV_VAR] = os.path.abspath(__file__)
    request.addfinalizer(restore_env)

@pytest.mark.usefixtures("set_fake_env")
@pytest.mark.respx(base_url="https://api.mistral.ai/v1")
def test_mistral_generator(respx_mock, mistral_compat_mocks):
    """Offline test: mock the chat/completions endpoint and check one generation is parsed."""
    mock_response = mistral_compat_mocks["mistralai_generation"]
    extended_request = "chat/completions"
    respx_mock.post(extended_request).mock(
        return_value=httpx.Response(mock_response["code"], json=mock_response["json"])
    )
    generator = MistralGenerator(name=DEFAULT_DEPLOYMENT_NAME)
    assert generator.name == DEFAULT_DEPLOYMENT_NAME
    output = generator.generate("Hello Mistral!")
    assert len(output) == 1  # expect 1 generation by default

@pytest.mark.skipif(
    os.getenv(MistralGenerator.ENV_VAR, None) is None,
    reason=f"Mistral API key is not set in {MistralGenerator.ENV_VAR}",
)
def test_mistral_chat():
    """Live-API smoke test; only runs when a real key is present in the environment."""
    generator = MistralGenerator(name=DEFAULT_DEPLOYMENT_NAME)
    assert generator.name == DEFAULT_DEPLOYMENT_NAME
    output = generator.generate("Hello Mistral!")
    assert len(output) == 1  # expect 1 generation by default