@@ -3,7 +3,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
-from openai import OpenAI
+from openai import AsyncOpenAI, OpenAI
 from pydantic import computed_field
 
 from mirascope.core.base import (
@@ -231,29 +231,43 @@ def test_get_provider_call_xai():
 
 
 def test_get_local_provider_call_ollama():
     with patch("mirascope.core.openai.openai_call", new="openai_ollama_mock"):
-        func, client = _get_local_provider_call("ollama", None)
+        func, client = _get_local_provider_call("ollama", None, False)
         assert func == "openai_ollama_mock"
         assert (
             isinstance(client, OpenAI)
             and client.api_key == "ollama"
             and str(client.base_url) == "http://localhost:11434/v1/"
         )
+        func, client = _get_local_provider_call("ollama", None, True)
+        assert func == "openai_ollama_mock"
+        assert (
+            isinstance(client, AsyncOpenAI)
+            and client.api_key == "ollama"
+            and str(client.base_url) == "http://localhost:11434/v1/"
+        )
         mock_client = Mock()
-        _, client = _get_local_provider_call("ollama", mock_client)
+        _, client = _get_local_provider_call("ollama", mock_client, False)
         assert client == mock_client
 
 
 def test_get_local_provider_call_vllm():
     with patch("mirascope.core.openai.openai_call", new="openai_vllm_mock"):
-        func, client = _get_local_provider_call("vllm", None)
+        func, client = _get_local_provider_call("vllm", None, False)
         assert func == "openai_vllm_mock"
         assert (
             isinstance(client, OpenAI)
             and client.api_key == "ollama"
             and str(client.base_url) == "http://localhost:8000/v1/"
         )
+        func, client = _get_local_provider_call("vllm", None, True)
+        assert func == "openai_vllm_mock"
+        assert (
+            isinstance(client, AsyncOpenAI)
+            and client.api_key == "ollama"
+            and str(client.base_url) == "http://localhost:8000/v1/"
+        )
         mock_client = Mock()
-        _, client = _get_local_provider_call("vllm", mock_client)
+        _, client = _get_local_provider_call("vllm", mock_client, False)
         assert client == mock_client
 
 
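For context, here is a minimal sketch of what the updated helper presumably looks like after this change, inferred only from the assertions above: the new third argument selects between a synchronous and an asynchronous client, both local providers authenticate with the placeholder key "ollama", and a caller-supplied client is returned untouched. The mapping name _LOCAL_BASE_URLS and the exact signature are assumptions for illustration, not mirascope's actual implementation.

    # Hypothetical reconstruction of _get_local_provider_call, inferred from
    # the tests above; names and signature are assumptions, not the real source.
    from openai import AsyncOpenAI, OpenAI

    from mirascope.core import openai

    # Assumed mapping; the tests only pin down these two base URLs.
    _LOCAL_BASE_URLS = {
        "ollama": "http://localhost:11434/v1",
        "vllm": "http://localhost:8000/v1",
    }

    def _get_local_provider_call(
        provider: str, client: OpenAI | AsyncOpenAI | None, use_async: bool
    ):
        # A caller-supplied client is returned untouched (the mock_client case).
        if client is None:
            # The new third argument toggles the async client variant.
            client_cls = AsyncOpenAI if use_async else OpenAI
            client = client_cls(
                api_key="ollama", base_url=_LOCAL_BASE_URLS[provider]
            )
        # Attribute access at call time is what lets patch("mirascope.core.
        # openai.openai_call", ...) swap in the mock the tests assert against.
        return openai.openai_call, client

Note that the tests compare str(client.base_url) against a trailing-slash form ("…/v1/"): the openai client normalizes base_url by appending a trailing slash, so the helper can pass the bare "…/v1" URL.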