-
Notifications
You must be signed in to change notification settings - Fork 679
use hatch instead of poetry #445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
overload, | ||
) | ||
|
||
import anyio | ||
from pydantic.json_schema import JsonSchemaValue | ||
|
||
from ollama._utils import convert_function_to_tool | ||
|
@@ -75,6 +76,7 @@ def __init__( | |
self, | ||
client, | ||
host: Optional[str] = None, | ||
*, | ||
follow_redirects: bool = True, | ||
timeout: Any = None, | ||
headers: Optional[Mapping[str, str]] = None, | ||
|
@@ -253,7 +255,7 @@ def generate( | |
stream=stream, | ||
raw=raw, | ||
format=format, | ||
images=[image for image in _copy_images(images)] if images else None, | ||
images=list(_copy_images(images)) if images else None, | ||
options=options, | ||
keep_alive=keep_alive, | ||
).model_dump(exclude_none=True), | ||
|
@@ -336,8 +338,8 @@ def add_two_numbers(a: int, b: int) -> int: | |
'/api/chat', | ||
json=ChatRequest( | ||
model=model, | ||
messages=[message for message in _copy_messages(messages)], | ||
tools=[tool for tool in _copy_tools(tools)], | ||
messages=list(_copy_messages(messages)), | ||
tools=list(_copy_tools(tools)), | ||
stream=stream, | ||
format=format, | ||
options=options, | ||
|
@@ -756,7 +758,7 @@ async def generate( | |
stream=stream, | ||
raw=raw, | ||
format=format, | ||
images=[image for image in _copy_images(images)] if images else None, | ||
images=list(_copy_images(images)) if images else None, | ||
options=options, | ||
keep_alive=keep_alive, | ||
).model_dump(exclude_none=True), | ||
|
@@ -840,8 +842,8 @@ def add_two_numbers(a: int, b: int) -> int: | |
'/api/chat', | ||
json=ChatRequest( | ||
model=model, | ||
messages=[message for message in _copy_messages(messages)], | ||
tools=[tool for tool in _copy_tools(tools)], | ||
messages=list(_copy_messages(messages)), | ||
tools=list(_copy_tools(tools)), | ||
stream=stream, | ||
format=format, | ||
options=options, | ||
|
@@ -991,7 +993,7 @@ async def create( | |
parameters: Optional[Union[Mapping[str, Any], Options]] = None, | ||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, | ||
*, | ||
stream: Literal[True] = True, | ||
stream: Literal[False] = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This overload was incorrect since stream=True is also overloaded below |
||
) -> ProgressResponse: ... | ||
|
||
@overload | ||
|
@@ -1054,19 +1056,19 @@ async def create( | |
|
||
async def create_blob(self, path: Union[str, Path]) -> str: | ||
sha256sum = sha256() | ||
with open(path, 'rb') as r: | ||
async with await anyio.open_file(path, 'rb') as r: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should not use a blocking open in an async function |
||
while True: | ||
chunk = r.read(32 * 1024) | ||
chunk = await r.read(32 * 1024) | ||
if not chunk: | ||
break | ||
sha256sum.update(chunk) | ||
|
||
digest = f'sha256:{sha256sum.hexdigest()}' | ||
|
||
async def upload_bytes(): | ||
with open(path, 'rb') as r: | ||
async with await anyio.open_file(path, 'rb') as r: | ||
while True: | ||
chunk = r.read(32 * 1024) | ||
chunk = await r.read(32 * 1024) | ||
if not chunk: | ||
break | ||
yield chunk | ||
|
@@ -1133,7 +1135,7 @@ def _copy_images(images: Optional[Sequence[Union[Image, Any]]]) -> Iterator[Imag | |
def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]]]) -> Iterator[Message]: | ||
for message in messages or []: | ||
yield Message.model_validate( | ||
{k: [image for image in _copy_images(v)] if k == 'images' else v for k, v in dict(message).items() if v}, | ||
{k: list(_copy_images(v)) if k == 'images' else v for k, v in dict(message).items() if v}, | ||
) | ||
|
||
|
||
|
@@ -1143,7 +1145,7 @@ def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable | |
|
||
|
||
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: | ||
if isinstance(s, str) or isinstance(s, Path): | ||
if isinstance(s, (str, Path)): | ||
try: | ||
if (p := Path(s)).exists(): | ||
return p | ||
|
@@ -1225,7 +1227,7 @@ def _parse_host(host: Optional[str]) -> str: | |
elif scheme == 'https': | ||
port = 443 | ||
|
||
split = urllib.parse.urlsplit('://'.join([scheme, hostport])) | ||
split = urllib.parse.urlsplit(f'{scheme}://{hostport}') | ||
host = split.hostname or '127.0.0.1' | ||
port = split.port or port | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import contextlib | ||
import json | ||
from base64 import b64decode, b64encode | ||
from datetime import datetime | ||
|
@@ -78,8 +79,8 @@ def __contains__(self, key: str) -> bool: | |
if key in self.model_fields_set: | ||
return True | ||
|
||
if key in self.model_fields: | ||
return self.model_fields[key].default is not None | ||
if value := self.model_fields.get(key): | ||
return value.default is not None | ||
|
||
return False | ||
|
||
|
@@ -97,7 +98,7 @@ def get(self, key: str, default: Any = None) -> Any: | |
>>> msg.get('tool_calls')[0]['function']['name'] | ||
'foo' | ||
""" | ||
return self[key] if key in self else default | ||
return getattr(self, key) if hasattr(self, key) else default | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. linter suggests |
||
|
||
|
||
class Options(SubscriptableBaseModel): | ||
|
@@ -332,7 +333,7 @@ class ChatRequest(BaseGenerateRequest): | |
@model_serializer(mode='wrap') | ||
def serialize_model(self, nxt): | ||
output = nxt(self) | ||
if 'tools' in output and output['tools']: | ||
if output.get('tools'): | ||
for tool in output['tools']: | ||
if 'function' in tool and 'parameters' in tool['function'] and 'defs' in tool['function']['parameters']: | ||
tool['function']['parameters']['$defs'] = tool['function']['parameters'].pop('defs') | ||
|
@@ -536,12 +537,10 @@ class ResponseError(Exception): | |
""" | ||
|
||
def __init__(self, error: str, status_code: int = -1): | ||
try: | ||
# try to parse content as JSON and extract 'error' | ||
# fallback to raw content if JSON parsing fails | ||
# try to parse content as JSON and extract 'error' | ||
# fallback to raw content if JSON parsing fails | ||
with contextlib.suppress(json.JSONDecodeError): | ||
error = json.loads(error).get('error', error) | ||
except json.JSONDecodeError: | ||
... | ||
|
||
super().__init__(error) | ||
self.error = error | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,12 +15,12 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: | |
if not doc_string: | ||
return parsed_docstring | ||
|
||
key = hash(doc_string) | ||
key = str(hash(doc_string)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
for line in doc_string.splitlines(): | ||
lowered_line = line.lower().strip() | ||
if lowered_line.startswith('args:'): | ||
key = 'args' | ||
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'): | ||
elif lowered_line.startswith(('returns:', 'yields:', 'raises:')): | ||
key = '_' | ||
|
||
else: | ||
|
@@ -54,7 +54,7 @@ def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: | |
|
||
|
||
def convert_function_to_tool(func: Callable) -> Tool: | ||
doc_string_hash = hash(inspect.getdoc(func)) | ||
doc_string_hash = str(hash(inspect.getdoc(func))) | ||
parsed_docstring = _parse_docstring(inspect.getdoc(func)) | ||
schema = type( | ||
func.__name__, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason for removing the matrix tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is replaced by
hatch test --all
which tests all configured Python versions. I'm still debating whether this is better than matrix tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like we can set up a matrix in the
pyproject.toml
if needed: https://hatch.pypa.io/latest/tutorials/testing/overview/#all-environments
I think it makes sense to just define the versions we want tested in the toml.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's already implemented and runs with
--all
option to hatch test
however it is slightly slower since the Python versions are sequential rather than parallel, ~1m vs. ~35s.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would parallelizing the tests themselves help? https://hatch.pypa.io/1.13/tutorials/testing/overview/#parallelize-test-execution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No that option, which is also set, parallelizes tests in a pytest run, not between python versions