Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested generics with partial #1207

Merged
merged 3 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from jiter import from_json
from pydantic import BaseModel, create_model
from typing import Union
import types
import sys
from pydantic.fields import FieldInfo
from typing import (
Any,
Expand All @@ -29,6 +32,12 @@

T_Model = TypeVar("T_Model", bound=BaseModel)

# Origins that identify a union annotation. ``typing.Union`` always counts;
# ``types.UnionType`` (the runtime type of ``X | Y``) is only added where the
# attribute exists, i.e. on Python 3.10 and above.
if sys.version_info < (3, 10):
    UNION_ORIGINS = (Union,)
else:
    UNION_ORIGINS = (Union, types.UnionType)


class MakeFieldsOptional:
pass
Expand All @@ -38,6 +47,37 @@
pass


def _process_generic_arg(
    arg: Any,
    make_fields_optional: bool = False,
) -> Any:
    """Recursively replace every ``BaseModel`` inside an annotation with ``Partial``.

    Args:
        arg: A type annotation, typically one argument of a generic field
            annotation. It may itself be a nested generic (Union, list,
            dict, ...), a model class, or any other object.
        make_fields_optional: When true, model classes are wrapped as
            ``Partial[model, MakeFieldsOptional]`` instead of ``Partial[model]``.
            The flag is forwarded unchanged to every nested argument.

    Returns:
        ``arg`` itself when it contains no ``BaseModel`` subclass; otherwise
        the same generic shape rebuilt around the wrapped models. Union
        annotations that had to be rebuilt are returned as ``typing.Union``.
    """
    arg_origin = get_origin(arg)
    if arg_origin is None:
        # Leaf of the annotation tree: only model classes are wrapped.
        if isinstance(arg, type) and issubclass(arg, BaseModel):
            return (
                Partial[arg, MakeFieldsOptional]  # type: ignore[valid-type]
                if make_fields_optional
                else Partial[arg]
            )
        return arg

    # Handle any nested generic type (Union, List, Dict, etc.)
    nested_args = get_args(arg)
    modified_nested_args = tuple(
        _process_generic_arg(
            t,
            make_fields_optional=make_fields_optional,
        )
        for t in nested_args
    )

    # Nothing was replaced, so keep the original annotation untouched.
    # Rebuilding would be wrong for bare aliases, whose argument tuple is
    # empty: ``typing.Tuple`` would turn into ``tuple[()]`` (empty tuple only)
    # and ``typing.Callable`` would raise on ``Callable[()]``.
    if all(new is old for new, old in zip(modified_nested_args, nested_args)):
        return arg

    # Special handling for Union types (types.UnionType isn't subscriptable)
    if arg_origin in UNION_ORIGINS:
        return Union[modified_nested_args]  # type: ignore

    return arg_origin[modified_nested_args]


def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
Expand All @@ -51,14 +91,8 @@
generic_base = get_origin(annotation)
generic_args = get_args(annotation)

# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
(
Partial[arg, MakeFieldsOptional] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
)
for arg in generic_args
_process_generic_arg(arg, make_fields_optional=True) for arg in generic_args
)

# Reconstruct the generic type with modified arguments
Expand All @@ -72,10 +106,10 @@
tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type]
tmp_field.default = {}
else:
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
tmp_field.annotation = Optional[field.annotation]

Check failure on line 109 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Expected class but received "type[Any] | None"   "None" is not a class (reportGeneralTypeIssues)
tmp_field.default = None

return tmp_field.annotation, tmp_field # type: ignore
return tmp_field.annotation, tmp_field

Check failure on line 112 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Type of "annotation" is partially unknown   Type of "annotation" is "type[Any] | type[None] | type[Partial[BaseModel]] | Unknown | None" (reportUnknownMemberType)

Check failure on line 112 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (ubuntu-latest, 3.11)

Return type, "tuple[type[Any] | type[None] | type[Partial[BaseModel]] | Unknown | None, FieldInfo]", is partially unknown (reportUnknownVariableType)


class PartialBase(Generic[T_Model]):
Expand Down Expand Up @@ -360,15 +394,7 @@
generic_base = get_origin(annotation)
generic_args = get_args(annotation)

# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
(
Partial[arg]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
)
for arg in generic_args
)
modified_args = tuple(_process_generic_arg(arg) for arg in generic_args)

# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Expand Down
26 changes: 26 additions & 0 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from typing import Optional, Union
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
Expand All @@ -20,6 +21,24 @@ class SamplePartial(BaseModel):
b: SampleNestedPartial


# Fixture model for test_union_with_nested: one member of the Union inside
# ``UnionWithNested.a`` and the element type of ``UnionWithNested.b``.
# Documented with comments rather than a docstring so the model itself is
# left exactly as declared.
class NestedA(BaseModel):
    a: str
    b: Optional[str]


# Fixture model for test_union_with_nested: the other member of the Union
# inside ``UnionWithNested.a`` and the type of ``UnionWithNested.c``.
class NestedB(BaseModel):
    c: str
    d: str
    # A generic whose Union argument contains no model classes.
    e: list[Union[str, int]]
    f: str


# Top-level fixture for test_union_with_nested, combining three shapes of
# nested model reference.
class UnionWithNested(BaseModel):
    # Models nested two levels deep: a list whose element is a Union of models.
    a: list[Union[NestedA, NestedB]]
    # Model nested one level deep inside a list.
    b: list[NestedA]
    # Model referenced directly, with no generic wrapper.
    c: NestedB


def test_partial():
partial = Partial[SamplePartial]
assert partial.model_json_schema() == {
Expand Down Expand Up @@ -166,3 +185,10 @@ class Summary(BaseModel, PartialLiteralMixin):
previous_summary = extraction.summary

assert updates == 1


def test_union_with_nested():
    """Validate an incomplete payload against the partial model of ``UnionWithNested``.

    The test passes as long as ``model_validate_json`` does not raise; the
    parsed result is not inspected.
    """
    # Every nested object omits at least one field its model declares.
    # NOTE(review): the top-level "e" key is not a field of UnionWithNested
    # (``e`` is declared on NestedB) - confirm whether it was meant to sit
    # inside "c".
    payload = '{"a": [{"b": "b"}, {"d": "d"}], "b": [{"b": "b"}], "c": {"d": "d"}, "e": [1, "a"]}'
    partial_model = Partial[UnionWithNested].get_partial_model()
    partial_model.model_validate_json(payload)
Loading