diff --git a/genai_stack/model/__init__.py b/genai_stack/model/__init__.py
index fdd3eb4a..b9b4ad1d 100644
--- a/genai_stack/model/__init__.py
+++ b/genai_stack/model/__init__.py
@@ -3,3 +3,4 @@ from .run import list_supported_models, get_model_class, AVAILABLE_MODEL_MAPS, run_custom_model
 from .gpt4all import Gpt4AllModel
 from .hf import HuggingFaceModel
+from .azure import AzureModel
diff --git a/genai_stack/model/azure.py b/genai_stack/model/azure.py
new file mode 100644
index 00000000..9c955976
--- /dev/null
+++ b/genai_stack/model/azure.py
@@ -0,0 +1,50 @@
+from typing import Any, Dict
+
+from pydantic import Field
+from langchain.chat_models import AzureChatOpenAI
+
+from genai_stack.model.base import BaseModel, BaseModelConfig, BaseModelConfigModel
+
+
+class AzureModelParameters(BaseModelConfigModel):
+    """Connection and sampling parameters forwarded verbatim to AzureChatOpenAI."""
+
+    model_name: str = Field(default="gpt-4", alias="model")
+    azure_deployment: str
+    temperature: float = 0.1
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    api_key: str
+    openai_api_version: str = Field(default="2024-02-01", alias="api_version")
+    streaming: bool = False
+    azure_endpoint: str
+
+
+class AzureModelConfigModel(BaseModelConfigModel):
+    """
+    Data Model for the configs
+    """
+
+    parameters: AzureModelParameters
+
+
+class AzureModelConfig(BaseModelConfig):
+    data_model = AzureModelConfigModel
+
+
+class AzureModel(BaseModel):
+    config_class = AzureModelConfig
+
+    def _post_init(self, *args, **kwargs):
+        self.model = self.load()
+
+    def load(self):
+        """
+        Using dict method here to dynamically access object attributes
+        """
+        model = AzureChatOpenAI(**self.config.parameters.dict())
+        return model
+
+    def predict(self, prompt: str):
+        """Return the model's reply to the prompt wrapped as an output dict."""
+        response = self.model.predict(prompt)
+        return {"output": response}