Skip to content

use hatch instead of poetry #445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@ jobs:
contents: write
steps:
- uses: actions/checkout@v4
- run: pipx install poetry
- uses: actions/setup-python@v5
- uses: astral-sh/setup-uv@v5
with:
python-version: '3.x'
cache: poetry
- run: |
poetry version ${GITHUB_REF_NAME#v}
poetry build
enable-cache: true
- run: uv build
- uses: pypa/gh-action-pypi-publish@release/v1
- run: gh release upload $GITHUB_REF_NAME dist/*
env:
Expand Down
34 changes: 12 additions & 22 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,29 @@ on:

jobs:
test:
strategy:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: pipx install poetry
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for removing the matrix tests?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is replaced by hatch test --all which tests all configured python versions. I'm still debating whether this is better than matrix tests

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we can setup a matrix if needed in the pyproject.toml if needed: https://hatch.pypa.io/latest/tutorials/testing/overview/#all-environments

I think it makes sense to just define the versions we want tested in the toml

Copy link
Collaborator Author

@mxyng mxyng Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's already implemented and runs with --all option to hatch test however it is slightly slower since the Python versions are sequential rather than parallel, ~1m vs. ~35s

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No that option, which is also set, parallelizes tests in a pytest run, not between python versions

cache: poetry
- run: poetry install --with=dev
- run: poetry run pytest . --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=ollama --cov-report=xml --cov-report=html
- uses: actions/upload-artifact@v4
- uses: astral-sh/setup-uv@v5
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
enable-cache: true
- run: uvx hatch test -acp
if: ${{ always() }}
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: pipx install poetry
- uses: actions/setup-python@v5
- uses: astral-sh/setup-uv@v5
with:
python-version: "3.13"
cache: poetry
- run: poetry install --with=dev
- run: poetry run ruff check --output-format=github .
- run: poetry run ruff format --check .
- name: check poetry.lock is up-to-date
run: poetry check --lock
enable-cache: true
- name: check formatting
run: uvx hatch fmt --check -f
- name: check linting
run: uvx hatch fmt --check -l --output-format=github
- name: check uv.lock is up-to-date
run: uv lock --check
- name: check requirements.txt is up-to-date
run: |
poetry export >requirements.txt
uv export >requirements.txt
git diff --exit-code requirements.txt
5 changes: 1 addition & 4 deletions examples/chat-with-history.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@
user_input = input('Chat with history: ')
response = chat(
'llama3.2',
messages=messages
+ [
{'role': 'user', 'content': user_input},
],
messages=[*messages, {'role': 'user', 'content': user_input}],
)

# Add the response to the messages to maintain the history
Expand Down
5 changes: 1 addition & 4 deletions examples/multimodal-generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
latest = httpx.get('https://xkcd.com/info.0.json')
latest.raise_for_status()

if len(sys.argv) > 1:
num = int(sys.argv[1])
else:
num = random.randint(1, latest.json().get('num'))
num = int(sys.argv[1]) if len(sys.argv) > 1 else random.randint(1, latest.json().get('num'))

comic = httpx.get(f'https://xkcd.com/{num}/info.0.json')
comic.raise_for_status()
Expand Down
30 changes: 16 additions & 14 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
overload,
)

import anyio
from pydantic.json_schema import JsonSchemaValue

from ollama._utils import convert_function_to_tool
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self,
client,
host: Optional[str] = None,
*,
follow_redirects: bool = True,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
Expand Down Expand Up @@ -253,7 +255,7 @@ def generate(
stream=stream,
raw=raw,
format=format,
images=[image for image in _copy_images(images)] if images else None,
images=list(_copy_images(images)) if images else None,
options=options,
keep_alive=keep_alive,
).model_dump(exclude_none=True),
Expand Down Expand Up @@ -336,8 +338,8 @@ def add_two_numbers(a: int, b: int) -> int:
'/api/chat',
json=ChatRequest(
model=model,
messages=[message for message in _copy_messages(messages)],
tools=[tool for tool in _copy_tools(tools)],
messages=list(_copy_messages(messages)),
tools=list(_copy_tools(tools)),
stream=stream,
format=format,
options=options,
Expand Down Expand Up @@ -756,7 +758,7 @@ async def generate(
stream=stream,
raw=raw,
format=format,
images=[image for image in _copy_images(images)] if images else None,
images=list(_copy_images(images)) if images else None,
options=options,
keep_alive=keep_alive,
).model_dump(exclude_none=True),
Expand Down Expand Up @@ -840,8 +842,8 @@ def add_two_numbers(a: int, b: int) -> int:
'/api/chat',
json=ChatRequest(
model=model,
messages=[message for message in _copy_messages(messages)],
tools=[tool for tool in _copy_tools(tools)],
messages=list(_copy_messages(messages)),
tools=list(_copy_tools(tools)),
stream=stream,
format=format,
options=options,
Expand Down Expand Up @@ -991,7 +993,7 @@ async def create(
parameters: Optional[Union[Mapping[str, Any], Options]] = None,
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
*,
stream: Literal[True] = True,
stream: Literal[False] = False,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This overload was incorrect since stream=True is also overloaded below

) -> ProgressResponse: ...

@overload
Expand Down Expand Up @@ -1054,19 +1056,19 @@ async def create(

async def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
async with await anyio.open_file(path, 'rb') as r:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not use a blocking open in an async function

while True:
chunk = r.read(32 * 1024)
chunk = await r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)

digest = f'sha256:{sha256sum.hexdigest()}'

async def upload_bytes():
with open(path, 'rb') as r:
async with await anyio.open_file(path, 'rb') as r:
while True:
chunk = r.read(32 * 1024)
chunk = await r.read(32 * 1024)
if not chunk:
break
yield chunk
Expand Down Expand Up @@ -1133,7 +1135,7 @@ def _copy_images(images: Optional[Sequence[Union[Image, Any]]]) -> Iterator[Imag
def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]]]) -> Iterator[Message]:
for message in messages or []:
yield Message.model_validate(
{k: [image for image in _copy_images(v)] if k == 'images' else v for k, v in dict(message).items() if v},
{k: list(_copy_images(v)) if k == 'images' else v for k, v in dict(message).items() if v},
)


Expand All @@ -1143,7 +1145,7 @@ def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable


def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
if isinstance(s, str) or isinstance(s, Path):
if isinstance(s, (str, Path)):
try:
if (p := Path(s)).exists():
return p
Expand Down Expand Up @@ -1225,7 +1227,7 @@ def _parse_host(host: Optional[str]) -> str:
elif scheme == 'https':
port = 443

split = urllib.parse.urlsplit('://'.join([scheme, hostport]))
split = urllib.parse.urlsplit(f'{scheme}://{hostport}')
host = split.hostname or '127.0.0.1'
port = split.port or port

Expand Down
17 changes: 8 additions & 9 deletions ollama/_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import json
from base64 import b64decode, b64encode
from datetime import datetime
Expand Down Expand Up @@ -78,8 +79,8 @@ def __contains__(self, key: str) -> bool:
if key in self.model_fields_set:
return True

if key in self.model_fields:
return self.model_fields[key].default is not None
if value := self.model_fields.get(key):
return value.default is not None

return False

Expand All @@ -97,7 +98,7 @@ def get(self, key: str, default: Any = None) -> Any:
>>> msg.get('tool_calls')[0]['function']['name']
'foo'
"""
return self[key] if key in self else default
return getattr(self, key) if hasattr(self, key) else default
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

linter suggests self.get(k, default) but that'll call itself



class Options(SubscriptableBaseModel):
Expand Down Expand Up @@ -332,7 +333,7 @@ class ChatRequest(BaseGenerateRequest):
@model_serializer(mode='wrap')
def serialize_model(self, nxt):
output = nxt(self)
if 'tools' in output and output['tools']:
if output.get('tools'):
for tool in output['tools']:
if 'function' in tool and 'parameters' in tool['function'] and 'defs' in tool['function']['parameters']:
tool['function']['parameters']['$defs'] = tool['function']['parameters'].pop('defs')
Expand Down Expand Up @@ -536,12 +537,10 @@ class ResponseError(Exception):
"""

def __init__(self, error: str, status_code: int = -1):
try:
# try to parse content as JSON and extract 'error'
# fallback to raw content if JSON parsing fails
# try to parse content as JSON and extract 'error'
# fallback to raw content if JSON parsing fails
with contextlib.suppress(json.JSONDecodeError):
error = json.loads(error).get('error', error)
except json.JSONDecodeError:
...

super().__init__(error)
self.error = error
Expand Down
6 changes: 3 additions & 3 deletions ollama/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
if not doc_string:
return parsed_docstring

key = hash(doc_string)
key = str(hash(doc_string))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy complains about a type mismatch for key since hash(...) returns an int but may be set to a 'args', i.e. string, later

for line in doc_string.splitlines():
lowered_line = line.lower().strip()
if lowered_line.startswith('args:'):
key = 'args'
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
elif lowered_line.startswith(('returns:', 'yields:', 'raises:')):
key = '_'

else:
Expand Down Expand Up @@ -54,7 +54,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:


def convert_function_to_tool(func: Callable) -> Tool:
doc_string_hash = hash(inspect.getdoc(func))
doc_string_hash = str(hash(inspect.getdoc(func)))
parsed_docstring = _parse_docstring(inspect.getdoc(func))
schema = type(
func.__name__,
Expand Down
Loading
Loading