
Commit d7bd2ad

Merge pull request #901 from Mirascope/fix-899

Fix 899: thread an is_async flag through _get_local_provider_call so that async calls to the local providers (ollama, vllm) receive an AsyncOpenAI client instead of a synchronous OpenAI one. Also bumps the package version to 1.21.1.

2 parents: dedcded + 692ce07

File tree

3 files changed: +42 −12 lines

- mirascope/llm/_call.py (+22 −6)
- pyproject.toml (+1 −1)
- tests/llm/test_call.py (+19 −5)

mirascope/llm/_call.py (+22 −6)

```diff
@@ -50,24 +50,36 @@
 def _get_local_provider_call(
     provider: LocalProvider,
     client: Any | None,  # noqa: ANN401
+    is_async: bool,
 ) -> tuple[Callable, Any | None]:
     if provider == "ollama":
         from ..core.openai import openai_call
 
         if client:
             return openai_call, client
-        from openai import OpenAI
+        if is_async:
+            from openai import AsyncOpenAI
 
-        client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
+            client = AsyncOpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
+        else:
+            from openai import OpenAI
+
+            client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
         return openai_call, client
     else:  # provider == "vllm"
         from ..core.openai import openai_call
 
         if client:
             return openai_call, client
-        from openai import OpenAI
 
-        client = OpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
+        if is_async:
+            from openai import AsyncOpenAI
+
+            client = AsyncOpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
+        else:
+            from openai import OpenAI
+
+            client = OpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
         return openai_call, client
 
 
@@ -245,7 +257,9 @@ async def inner_async(
 
         if effective_provider in get_args(LocalProvider):
             provider_call, effective_client = _get_local_provider_call(
-                cast(LocalProvider, effective_provider), effective_client
+                cast(LocalProvider, effective_provider),
+                effective_client,
+                True,
             )
             effective_call_args["client"] = effective_client
         else:
@@ -293,7 +307,9 @@ def inner(
 
         if effective_provider in get_args(LocalProvider):
             provider_call, effective_client = _get_local_provider_call(
-                cast(LocalProvider, effective_provider), effective_client
+                cast(LocalProvider, effective_provider),
+                effective_client,
+                False,
             )
             effective_call_args["client"] = effective_client
         else:
```
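The effect of the new flag can be sanity-checked directly against the helper. A minimal sketch, assuming mirascope 1.21.1 and the openai package are installed; the private import path is assumed to match the one used by the test module below:

```python
# Sanity check of the fix. Constructing the clients makes no network calls,
# so this runs without a local server. _get_local_provider_call is private
# API; the import path here is assumed from tests/llm/test_call.py.
from openai import AsyncOpenAI, OpenAI

from mirascope.llm._call import _get_local_provider_call

# Sync path: still returns a regular OpenAI client pointed at Ollama.
_, sync_client = _get_local_provider_call("ollama", None, False)
assert isinstance(sync_client, OpenAI)

# Async path: now returns an AsyncOpenAI client instead of a sync one.
_, async_client = _get_local_provider_call("ollama", None, True)
assert isinstance(async_client, AsyncOpenAI)
assert str(async_client.base_url) == "http://localhost:11434/v1/"
```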

pyproject.toml (+1 −1)

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "mirascope"
-version = "1.21.0"
+version = "1.21.1"
 description = "LLM abstractions that aren't obstructions"
 readme = "README.md"
 license = { file = "LICENSE" }
```

tests/llm/test_call.py (+19 −5)

```diff
@@ -3,7 +3,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
-from openai import OpenAI
+from openai import AsyncOpenAI, OpenAI
 from pydantic import computed_field
 
 from mirascope.core.base import (
@@ -231,29 +231,43 @@ def test_get_provider_call_xai():
 
 
 def test_get_local_provider_call_ollama():
     with patch("mirascope.core.openai.openai_call", new="openai_ollama_mock"):
-        func, client = _get_local_provider_call("ollama", None)
+        func, client = _get_local_provider_call("ollama", None, False)
         assert func == "openai_ollama_mock"
         assert (
             isinstance(client, OpenAI)
             and client.api_key == "ollama"
             and str(client.base_url) == "http://localhost:11434/v1/"
         )
+        func, client = _get_local_provider_call("ollama", None, True)
+        assert func == "openai_ollama_mock"
+        assert (
+            isinstance(client, AsyncOpenAI)
+            and client.api_key == "ollama"
+            and str(client.base_url) == "http://localhost:11434/v1/"
+        )
         mock_client = Mock()
-        _, client = _get_local_provider_call("ollama", mock_client)
+        _, client = _get_local_provider_call("ollama", mock_client, False)
         assert client == mock_client
 
 
 def test_get_local_provider_call_vllm():
     with patch("mirascope.core.openai.openai_call", new="openai_vllm_mock"):
-        func, client = _get_local_provider_call("vllm", None)
+        func, client = _get_local_provider_call("vllm", None, False)
         assert func == "openai_vllm_mock"
         assert (
             isinstance(client, OpenAI)
             and client.api_key == "ollama"
             and str(client.base_url) == "http://localhost:8000/v1/"
         )
+        func, client = _get_local_provider_call("vllm", None, True)
+        assert func == "openai_vllm_mock"
+        assert (
+            isinstance(client, AsyncOpenAI)
+            and client.api_key == "ollama"
+            and str(client.base_url) == "http://localhost:8000/v1/"
+        )
         mock_client = Mock()
-        _, client = _get_local_provider_call("vllm", mock_client)
+        _, client = _get_local_provider_call("vllm", mock_client, False)
         assert client == mock_client
 
 
```
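For context, the user-facing path these tests exercise looks roughly like the following. A sketch only: the model name is hypothetical, an Ollama server is assumed on its default port, and the decorator usage follows Mirascope's llm.call API. Before this commit, the async variant was handed a synchronous OpenAI client by _get_local_provider_call.

```python
# Sketch of the async local-provider path the fix unblocks (hypothetical
# model name; assumes an Ollama server listening on localhost:11434).
import asyncio

from mirascope import llm


@llm.call(provider="ollama", model="llama3.2")
async def answer(question: str) -> str:
    return f"Answer concisely: {question}"


async def main() -> None:
    # With is_async=True threaded through, this now runs on AsyncOpenAI.
    response = await answer("What does the is_async flag change?")
    print(response.content)


asyncio.run(main())
```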
