Skip to content

Commit

Permalink
groq-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
igorbenav committed Nov 9, 2024
1 parent 04d449f commit 717f278
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 167 deletions.
37 changes: 25 additions & 12 deletions clientai/groq/_typing.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from typing import (
Any,
List,
Literal,
Optional,
Protocol,
TypedDict,
Union,
)

from .._common_types import GenericResponse


class Message(TypedDict):
@dataclass
class Message:
role: Literal["system", "user", "assistant"]
content: str


class GroqChoice(TypedDict):
@dataclass
class GroqChoice:
index: int
message: Message
logprobs: Optional[Any]
finish_reason: Optional[str]


class GroqUsage(TypedDict):
@dataclass
class GroqUsage:
queue_time: float
prompt_tokens: int
prompt_time: float
Expand All @@ -36,11 +39,13 @@ class GroqUsage(TypedDict):
total_time: float


class GroqMetadata(TypedDict):
@dataclass
class GroqMetadata:
id: str


class GroqResponse(TypedDict):
@dataclass
class GroqResponse:
id: str
object: str
created: int
Expand All @@ -51,18 +56,21 @@ class GroqResponse(TypedDict):
x_groq: GroqMetadata


class GroqStreamDelta(TypedDict):
role: Optional[Literal["system", "user", "assistant"]]
content: Optional[str]
@dataclass
class GroqStreamDelta:
role: Optional[Literal["system", "user", "assistant"]] = None
content: Optional[str] = None


class GroqStreamChoice(TypedDict):
@dataclass
class GroqStreamChoice:
index: int
delta: GroqStreamDelta
finish_reason: Optional[str]


class GroqStreamResponse(TypedDict):
@dataclass
class GroqStreamResponse:
id: str
object: str
created: int
Expand All @@ -74,7 +82,12 @@ class GroqStreamResponse(TypedDict):

class GroqChatCompletionProtocol(Protocol):
def create(
self, **kwargs: Any
self,
*,
messages: List[dict[str, str]],
model: str,
stream: bool = False,
**kwargs: Any,
) -> Union[GroqResponse, Iterator[GroqStreamResponse]]: ...


Expand Down
10 changes: 5 additions & 5 deletions clientai/groq/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _stream_response(
if return_full_response:
yield chunk
else:
content = chunk["choices"][0]["delta"].get("content")
content = chunk.choices[0].delta.content
if content:
yield content

Expand All @@ -113,7 +113,7 @@ def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError:
"""
error_message = str(e)

if isinstance(e, (GroqAuthenticationError, PermissionDeniedError)): # noqa: UP038
if isinstance(e, (GroqAuthenticationError, PermissionDeniedError)): # noqa: UP038
return AuthenticationError(
error_message,
status_code=getattr(e, "status_code", 401),
Expand All @@ -125,7 +125,7 @@ def _map_exception_to_clientai_error(self, e: Exception) -> ClientAIError:
)
elif isinstance(e, NotFoundError):
return ModelError(error_message, status_code=404, original_error=e)
elif isinstance( # noqa: UP038
elif isinstance( # noqa: UP038
e, (BadRequestError, UnprocessableEntityError, ConflictError)
):
return InvalidRequestError(
Expand Down Expand Up @@ -219,7 +219,7 @@ def generate_text(
if return_full_response:
return response
else:
return response["choices"][0]["message"]["content"]
return response.choices[0].message.content

except Exception as e:
raise self._map_exception_to_clientai_error(e)
Expand Down Expand Up @@ -284,7 +284,7 @@ def chat(
if return_full_response:
return response
else:
return response["choices"][0]["message"]["content"]
return response.choices[0].message.content

except Exception as e:
raise self._map_exception_to_clientai_error(e)
Loading

0 comments on commit 717f278

Please sign in to comment.