From 7c192351c007c6d15b08cbf451fea43337717b80 Mon Sep 17 00:00:00 2001 From: AlaeddineAbdessalem Date: Wed, 3 Jul 2024 14:37:55 +0200 Subject: [PATCH] Add support for Azure authentication using Azure AD Tokens (#165) * support Azure AD token * support Azure AD token * support Azure AD token * support Azure AD token * support Azure AD token * linting * apply suggestion * add copyrights * fix client * increase version * fix mypy * implement workaround to make azure_endpoint a positional argument * linting * Update README.md --------- Co-authored-by: Alaeddine Abdessalem Co-authored-by: Philip May --- LICENSE | 1 + README.md | 3 ++- mltb2/openai.py | 58 ++++++++++++++++++++++++++++++++++++++++--------- pyproject.toml | 2 +- 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/LICENSE b/LICENSE index 5d026d0..5229464 100644 --- a/LICENSE +++ b/LICENSE @@ -2,6 +2,7 @@ MIT License Copyright (c) 2023-2024 Philip May Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG +Copyright (c) 2023-2024 Alaeddine Abdessalem, Deutsche Telekom AG Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index d451492..54a0692 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,8 @@ To install those module specific dependencies see ## Licensing Copyright (c) 2023-2024 [Philip May](https://philipmay.org)\ -Copyright (c) 2023-2024 [Philip May](https://philipmay.org), [Deutsche Telekom AG](https://www.telekom.de/) +Copyright (c) 2023-2024 [Philip May](https://philipmay.org), [Deutsche Telekom AG](https://www.telekom.de/)\ +Copyright (c) 2023-2024 Alaeddine Abdessalem, [Deutsche Telekom AG](https://www.telekom.de/) Licensed under the **MIT License** (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License by reviewing the file diff --git a/mltb2/openai.py b/mltb2/openai.py index 7cc2a00..6453daa 100644 --- a/mltb2/openai.py +++ b/mltb2/openai.py @@ -1,5 +1,6 @@ # Copyright (c) 2023-2024 Philip May # Copyright (c) 2024 Philip May, Deutsche Telekom AG +# Copyright (c) 2024 Alaeddine Abdessalem, Deutsche Telekom AG # This software is distributed under the terms of the MIT license # which is available at https://opensource.org/licenses/MIT @@ -171,10 +172,10 @@ class OpenAiChat: model: The OpenAI model name. """ - api_key: str model: str client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False) async_client: Union[AsyncOpenAI, AsyncAzureOpenAI] = field(init=False, repr=False) + api_key: Optional[str] = None def __post_init__(self) -> None: """Do post init.""" @@ -182,7 +183,7 @@ def __post_init__(self) -> None: self.async_client = AsyncOpenAI(api_key=self.api_key) @classmethod - def from_yaml(cls, yaml_file): + def from_yaml(cls, yaml_file, api_key: Optional[str] = None, **kwargs): """Construct this class from a yaml file. If the ``api_key`` is not set in the yaml file, @@ -190,18 +191,21 @@ def from_yaml(cls, yaml_file): Args: yaml_file: The yaml file. + api_key: The OpenAI API key. + kwargs: extra kwargs to override parameters Returns: The constructed class. """ with open(yaml_file, "r") as file: completion_kwargs = yaml.safe_load(file) - # load api_key from environment variable if it is not set in the yaml file - if "api_key" not in completion_kwargs: - api_key = os.getenv("OPENAI_API_KEY") - if api_key is not None: - completion_kwargs["api_key"] = api_key + # set api_key according to this priority: + # method parameter > yaml > environment variable + api_key = api_key or completion_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY") + completion_kwargs["api_key"] = api_key + if kwargs: + completion_kwargs.update(kwargs) return cls(**completion_kwargs) def create_completions( @@ -323,8 +327,16 @@ async def create_completions_async( return result +# there is a limitation with python dataclasses when it comes to defining a subclass with positional arguments, while +# the parent class already defines keyword arguemnts (positional arguments cannot follow keyword arguments) +# workaroung is defined here: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses @dataclass -class OpenAiAzureChat(OpenAiChat): +class _OpenAiAzureChatBase: + azure_endpoint: str + + +@dataclass +class OpenAiAzureChat(OpenAiChat, _OpenAiAzureChatBase): """Tool to interact with Azure OpenAI chat models. This can also be constructed with :meth:`~OpenAiChat.from_yaml`. @@ -341,8 +353,9 @@ class OpenAiAzureChat(OpenAiChat): azure_endpoint: The Azure endpoint. """ - api_version: str - azure_endpoint: str + api_version: Optional[str] = None + api_key: Optional[str] = None + azure_ad_token: Optional[str] = None def __post_init__(self) -> None: """Do post init.""" @@ -350,9 +363,34 @@ def __post_init__(self) -> None: api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, + azure_ad_token=self.azure_ad_token, ) self.async_client = AsyncAzureOpenAI( api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, + azure_ad_token=self.azure_ad_token, ) + + @classmethod + def from_yaml(cls, yaml_file, api_key: Optional[str] = None, azure_ad_token: Optional[str] = None, **kwargs): + """Construct this class from a yaml file. + + If the ``api_key`` is not set in the yaml file, + it will be loaded from the environment variable ``OPENAI_API_KEY``. + + Args: + yaml_file: The yaml file. + api_key: The OpenAI API key. + azure_ad_token: Azure AD token + kwargs: extra kwargs to override parameters + Returns: + The constructed class. + """ + with open(yaml_file, "r") as file: + completion_kwargs = yaml.safe_load(file) + + # set azure_ad_token according to this priority: + # method parameter > yaml > environment variable + azure_ad_token = azure_ad_token or completion_kwargs.get("AZURE_AD_TOKEN") or os.getenv("AZURE_AD_TOKEN") + return super().from_yaml(yaml_file, api_key=api_key, azure_ad_token=azure_ad_token, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index d5bddab..0232edb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mltb2" -version = "1.0.1rc2" +version = "1.0.1rc3" description = "Machine Learning Toolbox 2" authors = ["PhilipMay "] readme = "README.md"