Skip to content

Commit

Permalink
Add Configuration to OpenAIResource (#23260)
Browse files Browse the repository at this point in the history
## Summary & Motivation

When creating an `OpenAIResource`, we might want to pass these
additional arguments to the OpenAI `Client`.

## How I Tested These Changes

I added a test case that validates that these new arguments are passed
from the `OpenAIResource` to the `Client`.
  • Loading branch information
chasleslr authored and clairelin135 committed Aug 13, 2024
1 parent 0056d04 commit 50ac858
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
"""

api_key: str = Field(description=("OpenAI API key. See https://platform.openai.com/api-keys"))
organization: Optional[str] = Field(default=None)
project: Optional[str] = Field(default=None)
base_url: Optional[str] = Field(default=None)

_client: Client = PrivateAttr()

Expand Down Expand Up @@ -212,7 +215,12 @@ def _wrap_with_usage_metadata(

def setup_for_execution(self, context: InitResourceContext) -> None:
# Set up an OpenAI client based on the API key.
self._client = Client(api_key=self.api_key)
self._client = Client(
api_key=self.api_key,
organization=self.organization,
project=self.project,
base_url=self.base_url,
)

@public
@contextmanager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,32 @@ def test_openai_client(mock_client) -> None:

mock_context = MagicMock()
with openai_resource.get_client(mock_context):
mock_client.assert_called_once_with(api_key="xoxp-1234123412341234-12341234-1234")
mock_client.assert_called_once_with(
api_key="xoxp-1234123412341234-12341234-1234",
organization=None,
project=None,
base_url=None,
)


@patch("dagster_openai.resources.Client")
def test_openai_client_with_config(mock_client) -> None:
openai_resource = OpenAIResource(
api_key="xoxp-1234123412341234-12341234-1234",
organization="foo",
project="bar",
base_url="http://foo.bar",
)
openai_resource.setup_for_execution(build_init_resource_context())

mock_context = MagicMock()
with openai_resource.get_client(mock_context):
mock_client.assert_called_once_with(
api_key="xoxp-1234123412341234-12341234-1234",
organization="foo",
project="bar",
base_url="http://foo.bar",
)


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
Expand Down

0 comments on commit 50ac858

Please sign in to comment.