Skip to content

Commit

Permalink
Python: support Azure AD auth (microsoft#340)
Browse files Browse the repository at this point in the history
AAD tokens offer greater authentication security and is used by several products.

Add support for Azure Active Directory auth for the `Azure*` backends.
  • Loading branch information
Jordan Henkel authored and dluc committed Apr 13, 2023
1 parent 35cfdcf commit 7d9c40f
Show file tree
Hide file tree
Showing 10 changed files with 482 additions and 12 deletions.
1 change: 1 addition & 0 deletions python/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ OPENAI_API_KEY=""
OPENAI_ORG_ID=""
AZURE_OPENAI_API_KEY=""
AZURE_OPENAI_ENDPOINT=""
AZURE_OPENAI_DEPLOYMENT_NAME=""
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,40 @@
class AzureChatCompletion(OpenAIChatCompletion):
_endpoint: str
_api_version: str
_api_type: str

def __init__(
self,
deployment_name: str,
endpoint: str,
api_key: str,
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
api_version: str = "2023-03-15-preview",
logger: Optional[Logger] = None,
ad_auth=False,
) -> None:
"""
Initialize an AzureChatCompletion backend.
You must provide:
- A deployment_name, endpoint, and api_key (plus, optionally: ad_auth)
:param deployment_name: The name of the Azure deployment. This value
will correspond to the custom name you chose for your deployment
when you deployed a model. This value can be found under
Resource Management > Deployments in the Azure portal or, alternatively,
under Management > Deployments in Azure OpenAI Studio.
:param endpoint: The endpoint of the Azure deployment. This value
can be found in the Keys & Endpoint section when examining
your resource from the Azure portal.
:param api_key: The API key for the Azure deployment. This value can be
found in the Keys & Endpoint section when examining your resource in
the Azure portal. You can use either KEY1 or KEY2.
:param api_version: The API version to use. (Optional)
The default value is "2022-12-01".
:param logger: The logger instance to use. (Optional)
:param ad_auth: Whether to use Azure Active Directory authentication.
(Optional) The default value is False.
"""
if not deployment_name:
raise ValueError("The deployment name cannot be `None` or empty")
if not api_key:
Expand All @@ -32,13 +57,14 @@ def __init__(

self._endpoint = endpoint
self._api_version = api_version
self._api_type = "azure_ad" if ad_auth else "azure"

super().__init__(deployment_name, api_key, org_id=None, log=logger)

def _setup_open_ai(self) -> Any:
import openai

openai.api_type = "azure"
openai.api_type = self._api_type
openai.api_key = self._api_key
openai.api_base = self._endpoint
openai.api_version = self._api_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,40 @@
class AzureTextCompletion(OpenAITextCompletion):
_endpoint: str
_api_version: str
_api_type: str

def __init__(
self,
deployment_name: str,
endpoint: str,
api_key: str,
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
api_version: str = "2022-12-01",
logger: Optional[Logger] = None,
ad_auth=False,
) -> None:
"""
Initialize an AzureTextCompletion backend.
You must provide:
- A deployment_name, endpoint, and api_key (plus, optionally: ad_auth)
:param deployment_name: The name of the Azure deployment. This value
will correspond to the custom name you chose for your deployment
when you deployed a model. This value can be found under
Resource Management > Deployments in the Azure portal or, alternatively,
under Management > Deployments in Azure OpenAI Studio.
:param endpoint: The endpoint of the Azure deployment. This value
can be found in the Keys & Endpoint section when examining
your resource from the Azure portal.
:param api_key: The API key for the Azure deployment. This value can be
found in the Keys & Endpoint section when examining your resource in
the Azure portal. You can use either KEY1 or KEY2.
:param api_version: The API version to use. (Optional)
The default value is "2022-12-01".
:param logger: The logger instance to use. (Optional)
:param ad_auth: Whether to use Azure Active Directory authentication.
(Optional) The default value is False.
"""
if not deployment_name:
raise ValueError("The deployment name cannot be `None` or empty")
if not api_key:
Expand All @@ -32,13 +57,14 @@ def __init__(

self._endpoint = endpoint
self._api_version = api_version
self._api_type = "azure_ad" if ad_auth else "azure"

super().__init__(deployment_name, api_key, org_id=None, log=logger)

def _setup_open_ai(self) -> Any:
import openai

openai.api_type = "azure"
openai.api_type = self._api_type
openai.api_key = self._api_key
openai.api_base = self._endpoint
openai.api_version = self._api_version
Expand Down
32 changes: 29 additions & 3 deletions python/semantic_kernel/ai/open_ai/services/azure_text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,40 @@
class AzureTextEmbedding(OpenAITextEmbedding):
_endpoint: str
_api_version: str
_api_type: str

def __init__(
self,
deployment_name: str,
endpoint: str,
api_key: str,
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
api_version: str = "2022-12-01",
logger: Optional[Logger] = None,
ad_auth=False,
) -> None:
"""
Initialize an AzureTextEmbedding backend.
You must provide:
- A deployment_name, endpoint, and api_key (plus, optionally: ad_auth)
:param deployment_name: The name of the Azure deployment. This value
will correspond to the custom name you chose for your deployment
when you deployed a model. This value can be found under
Resource Management > Deployments in the Azure portal or, alternatively,
under Management > Deployments in Azure OpenAI Studio.
:param endpoint: The endpoint of the Azure deployment. This value
can be found in the Keys & Endpoint section when examining
your resource from the Azure portal.
:param api_key: The API key for the Azure deployment. This value can be
found in the Keys & Endpoint section when examining your resource in
the Azure portal. You can use either KEY1 or KEY2.
:param api_version: The API version to use. (Optional)
The default value is "2022-12-01".
:param logger: The logger instance to use. (Optional)
:param ad_auth: Whether to use Azure Active Directory authentication.
(Optional) The default value is False.
"""
if not deployment_name:
raise ValueError("The deployment name cannot be `None` or empty")
if not api_key:
Expand All @@ -32,13 +57,14 @@ def __init__(

self._endpoint = endpoint
self._api_version = api_version
self._api_type = "azure_ad" if ad_auth else "azure"

super().__init__(deployment_name, api_key, org_id=None, log=logger)

def _setup_open_ai(self) -> Any:
import openai

openai.api_type = "azure"
openai.api_type = self._api_type
openai.api_key = self._api_key
openai.api_base = self._endpoint
openai.api_version = self._api_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def complete_chat_async(
)

model_args = {}
if self.open_ai_instance.api_type == "azure":
if self.open_ai_instance.api_type in ["azure", "azure_ad"]:
model_args["engine"] = self._model_id
else:
model_args["model"] = self._model_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def complete_simple_async(
)

model_args = {}
if self.open_ai_instance.api_type == "azure":
if self.open_ai_instance.api_type in ["azure", "azure_ad"]:
model_args["engine"] = self._model_id
else:
model_args["model"] = self._model_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _setup_open_ai(self) -> Any:

async def generate_embeddings_async(self, texts: List[str]) -> ndarray:
model_args = {}
if self.open_ai_instance.api_type == "azure":
if self.open_ai_instance.api_type in ["azure", "azure_ad"]:
model_args["engine"] = self._model_id
else:
model_args["model"] = self._model_id
Expand Down
131 changes: 131 additions & 0 deletions python/tests/unit/ai/open_ai/services/test_azure_chat_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) Microsoft. All rights reserved.

from logging import Logger
from unittest.mock import Mock

from pytest import raises

from semantic_kernel.ai.open_ai.services.azure_chat_completion import (
AzureChatCompletion,
)
from semantic_kernel.ai.open_ai.services.open_ai_chat_completion import (
OpenAIChatCompletion,
)


def test_azure_chat_completion_init() -> None:
deployment_name = "test_deployment"
endpoint = "https://test-endpoint.com"
api_key = "test_api_key"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")

# Test successful initialization
azure_chat_completion = AzureChatCompletion(
deployment_name=deployment_name,
endpoint=endpoint,
api_key=api_key,
api_version=api_version,
logger=logger,
)

assert azure_chat_completion._endpoint == endpoint
assert azure_chat_completion._api_version == api_version
assert azure_chat_completion._api_type == "azure"
assert isinstance(azure_chat_completion, OpenAIChatCompletion)


def test_azure_chat_completion_init_with_empty_deployment_name() -> None:
# deployment_name = "test_deployment"
endpoint = "https://test-endpoint.com"
api_key = "test_api_key"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")

with raises(ValueError, match="The deployment name cannot be `None` or empty"):
AzureChatCompletion(
deployment_name="",
endpoint=endpoint,
api_key=api_key,
api_version=api_version,
logger=logger,
)


def test_azure_chat_completion_init_with_empty_api_key() -> None:
deployment_name = "test_deployment"
endpoint = "https://test-endpoint.com"
# api_key = "test_api_key"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")

with raises(ValueError, match="The Azure API key cannot be `None` or empty`"):
AzureChatCompletion(
deployment_name=deployment_name,
endpoint=endpoint,
api_key="",
api_version=api_version,
logger=logger,
)


def test_azure_chat_completion_init_with_empty_endpoint() -> None:
deployment_name = "test_deployment"
# endpoint = "https://test-endpoint.com"
api_key = "test_api_key"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")

with raises(ValueError, match="The Azure endpoint cannot be `None` or empty"):
AzureChatCompletion(
deployment_name=deployment_name,
endpoint="",
api_key=api_key,
api_version=api_version,
logger=logger,
)


def test_azure_chat_completion_init_with_invalid_endpoint() -> None:
deployment_name = "test_deployment"
endpoint = "http://test-endpoint.com"
api_key = "test_api_key"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")

with raises(ValueError, match="The Azure endpoint must start with https://"):
AzureChatCompletion(
deployment_name=deployment_name,
endpoint=endpoint,
api_key=api_key,
api_version=api_version,
logger=logger,
)


def test_azure_chat_completion_setup_open_ai() -> None:
import sys

deployment_name = "test_deployment"
endpoint = "https://test-endpoint.com"
api_key = "test_api_key"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")

azure_chat_completion = AzureChatCompletion(
deployment_name=deployment_name,
endpoint=endpoint,
api_key=api_key,
api_version=api_version,
logger=logger,
)

mock_openai = Mock()
sys.modules["openai"] = mock_openai

azure_chat_completion._setup_open_ai()

assert mock_openai.api_type == "azure"
assert mock_openai.api_key == api_key
assert mock_openai.api_base == endpoint
assert mock_openai.api_version == api_version
Loading

0 comments on commit 7d9c40f

Please sign in to comment.