Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed May 23, 2024
1 parent b46654c commit cc5444a
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 31 deletions.
10 changes: 5 additions & 5 deletions instructor/cli/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ def calculate_cost(

def group_and_sum_by_date_and_snapshot(usage_data: list[dict[str, Any]]) -> Table:
"""Group and sum the usage data by date and snapshot, including costs."""
summary: defaultdict[str, defaultdict[str, dict[str, Union[int, float]]]] = (
defaultdict(
lambda: defaultdict(
lambda: {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
)
summary: defaultdict[
str, defaultdict[str, dict[str, Union[int, float]]]
] = defaultdict(
lambda: defaultdict(
lambda: {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
)
)

Expand Down
27 changes: 18 additions & 9 deletions instructor/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def create(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Awaitable[T]: ...
) -> Awaitable[T]:
...

@overload
def create(
Expand All @@ -74,7 +75,8 @@ def create(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> T: ...
) -> T:
...

# TODO: we should overload a case where response_model is None
def create(
Expand Down Expand Up @@ -106,7 +108,8 @@ def create_partial(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> AsyncGenerator[T, None]: ...
) -> AsyncGenerator[T, None]:
...

@overload
def create_partial(
Expand All @@ -117,7 +120,8 @@ def create_partial(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Generator[T, None, None]: ...
) -> Generator[T, None, None]:
...

def create_partial(
self,
Expand Down Expand Up @@ -151,7 +155,8 @@ def create_iterable(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> AsyncGenerator[T, None]: ...
) -> AsyncGenerator[T, None]:
...

@overload
def create_iterable(
Expand All @@ -162,7 +167,8 @@ def create_iterable(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Generator[T, None, None]: ...
) -> Generator[T, None, None]:
...

def create_iterable(
self,
Expand Down Expand Up @@ -197,7 +203,8 @@ def create_with_completion(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> Awaitable[tuple[T, Any]]: ...
) -> Awaitable[tuple[T, Any]]:
...

@overload
def create_with_completion(
Expand All @@ -208,7 +215,8 @@ def create_with_completion(
validation_context: dict[str, Any] | None = None,
strict: bool = True,
**kwargs: Any,
) -> tuple[T, Any]: ...
) -> tuple[T, Any]:
...

def create_with_completion(
self,
Expand Down Expand Up @@ -424,7 +432,8 @@ def from_litellm(
completion: Callable[..., Any],
mode: instructor.Mode = instructor.Mode.TOOLS,
**kwargs: Any,
) -> Instructor: ...
) -> Instructor:
...


@overload
Expand Down
6 changes: 4 additions & 2 deletions instructor/client_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def from_anthropic(
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_JSON,
**kwargs: Any,
) -> instructor.Instructor: ...
) -> instructor.Instructor:
...


@overload
Expand All @@ -25,7 +26,8 @@ def from_anthropic(
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_JSON,
**kwargs: Any,
) -> instructor.AsyncInstructor: ...
) -> instructor.AsyncInstructor:
...


def from_anthropic(
Expand Down
6 changes: 4 additions & 2 deletions instructor/client_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ def from_cohere(
client: cohere.Client,
mode: instructor.Mode = instructor.Mode.COHERE_TOOLS,
**kwargs: Any,
) -> instructor.Instructor: ...
) -> instructor.Instructor:
...


@overload
def from_cohere(
client: cohere.AsyncClient,
mode: instructor.Mode = instructor.Mode.COHERE_TOOLS,
**kwargs: Any,
) -> instructor.AsyncInstructor: ...
) -> instructor.AsyncInstructor:
...


def from_cohere(
Expand Down
6 changes: 4 additions & 2 deletions instructor/client_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ def from_groq(
client: groq.Groq,
mode: instructor.Mode = instructor.Mode.TOOLS,
**kwargs: Any,
) -> instructor.Instructor: ...
) -> instructor.Instructor:
...


@overload
def from_groq(
client: groq.AsyncGroq,
mode: instructor.Mode = instructor.Mode.TOOLS,
**kwargs: Any,
) -> instructor.Instructor: ...
) -> instructor.Instructor:
...


def from_groq(
Expand Down
6 changes: 4 additions & 2 deletions instructor/client_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@ def from_mistral(
client: mistralai.client.MistralClient,
mode: instructor.Mode = instructor.Mode.MISTRAL_TOOLS,
**kwargs: Any,
) -> instructor.Instructor: ...
) -> instructor.Instructor:
...


@overload
def from_mistral(
client: mistralaiasynccli.MistralAsyncClient,
mode: instructor.Mode = instructor.Mode.MISTRAL_TOOLS,
**kwargs: Any,
) -> instructor.AsyncInstructor: ...
) -> instructor.AsyncInstructor:
...


def from_mistral(
Expand Down
18 changes: 12 additions & 6 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __call__(
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model: ...
) -> T_Model:
...


class AsyncInstructorChatCompletionCreate(Protocol):
Expand All @@ -46,35 +47,40 @@ async def __call__(
max_retries: int = 1,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model: ...
) -> T_Model:
...


@overload
def patch(
client: OpenAI,
mode: Mode = Mode.TOOLS,
) -> OpenAI: ...
) -> OpenAI:
...


@overload
def patch(
client: AsyncOpenAI,
mode: Mode = Mode.TOOLS,
) -> AsyncOpenAI: ...
) -> AsyncOpenAI:
...


@overload
def patch(
create: Callable[T_ParamSpec, T_Retval],
mode: Mode = Mode.TOOLS,
) -> InstructorChatCompletionCreate: ...
) -> InstructorChatCompletionCreate:
...


@overload
def patch(
create: Awaitable[T_Retval],
mode: Mode = Mode.TOOLS,
) -> InstructorChatCompletionCreate: ...
) -> InstructorChatCompletionCreate:
...


def patch(
Expand Down
4 changes: 1 addition & 3 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,7 @@ def handle_response_model(
+ "\n\n".join(openai_system_messages)
)

new_kwargs[
"system"
] += f"""
new_kwargs["system"] += f"""
You must only response in JSON format that adheres to the following schema:
<JSON_SCHEMA>
Expand Down

0 comments on commit cc5444a

Please sign in to comment.