diff --git a/python_modules/libraries/dagster-openai/dagster_openai/resources.py b/python_modules/libraries/dagster-openai/dagster_openai/resources.py index 63939f86d8421..8e6c2e883f174 100644 --- a/python_modules/libraries/dagster-openai/dagster_openai/resources.py +++ b/python_modules/libraries/dagster-openai/dagster_openai/resources.py @@ -180,6 +180,9 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource): """ api_key: str = Field(description=("OpenAI API key. See https://platform.openai.com/api-keys")) + organization: Optional[str] = Field(default=None) + project: Optional[str] = Field(default=None) + base_url: Optional[str] = Field(default=None) _client: Client = PrivateAttr() @@ -212,7 +215,12 @@ def _wrap_with_usage_metadata( def setup_for_execution(self, context: InitResourceContext) -> None: # Set up an OpenAI client based on the API key. - self._client = Client(api_key=self.api_key) + self._client = Client( + api_key=self.api_key, + organization=self.organization, + project=self.project, + base_url=self.base_url, + ) @public @contextmanager diff --git a/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py b/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py index 4f96599d4d3d0..13bf90b25f08f 100644 --- a/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py +++ b/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py @@ -28,7 +28,32 @@ def test_openai_client(mock_client) -> None: mock_context = MagicMock() with openai_resource.get_client(mock_context): - mock_client.assert_called_once_with(api_key="xoxp-1234123412341234-12341234-1234") + mock_client.assert_called_once_with( + api_key="xoxp-1234123412341234-12341234-1234", + organization=None, + project=None, + base_url=None, + ) + + +@patch("dagster_openai.resources.Client") +def test_openai_client_with_config(mock_client) -> None: + openai_resource = OpenAIResource( + api_key="xoxp-1234123412341234-12341234-1234", + organization="foo", + project="bar", + base_url="http://foo.bar", + ) + openai_resource.setup_for_execution(build_init_resource_context()) + + mock_context = MagicMock() + with openai_resource.get_client(mock_context): + mock_client.assert_called_once_with( + api_key="xoxp-1234123412341234-12341234-1234", + organization="foo", + project="bar", + base_url="http://foo.bar", + ) @patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")