Skip to content

Implement LiteLLM Instruction Param #10323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions litellm/llms/base_llm/speech/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import OpenAISpeechOptionalParams
from litellm.types.utils import FileTypes, ModelResponse

if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any


class BaseSpeechConfig(BaseConfig, ABC):
    """Base config for text-to-speech (audio speech) providers.

    Subclasses declare which OpenAI speech params they support via
    `get_supported_openai_params` and inherit (or override) the default
    param mapping in `map_openai_params`.
    """

    @abstractmethod
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAISpeechOptionalParams]:
        """Return the OpenAI speech params supported for `model`."""
        pass

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map the OpenAI params to the provider's speech params.

        Default implementation: copy every supported non-default param
        through unchanged into `optional_params`. Not abstract — providers
        that need no renaming can simply inherit this, instead of
        duplicating it (as the gpt/tts configs previously did).
        """
        supported_params = self.get_supported_openai_params(model)
        for k, v in non_default_params.items():
            if k in supported_params:
                optional_params[k] = v
        return optional_params
2 changes: 2 additions & 0 deletions litellm/llms/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,7 @@ def audio_speech(
model: str,
input: str,
voice: str,
instructions: Optional[str],
optional_params: dict,
api_key: Optional[str],
api_base: Optional[str],
Expand Down Expand Up @@ -1393,6 +1394,7 @@ def audio_speech(
response = cast(OpenAI, openai_client).audio.speech.create(
model=model,
voice=voice, # type: ignore
instructions=instructions,
input=input,
**optional_params,
)
Expand Down
35 changes: 35 additions & 0 deletions litellm/llms/openai/speech/gpt_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import List

from litellm.llms.base_llm.speech.transformation import (
BaseSpeechConfig,
)
from litellm.types.llms.openai import OpenAISpeechOptionalParams

class OpenAIGPTSpeechConfig(BaseSpeechConfig):
    """Speech config for OpenAI `gpt-*` TTS models (e.g. gpt-4o-mini-tts)."""

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAISpeechOptionalParams]:
        """
        Get the supported OpenAI params for the gpt models.

        Unlike the tts-1 family, gpt TTS models also accept
        `instructions` (voice/style steering).
        """
        return [
            "instructions",
            "response_format",
            "speed",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map the OpenAI params to the Speech params.

        Delegates to the shared base implementation (copies each supported
        non-default param through unchanged) instead of duplicating it.
        """
        return super().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
33 changes: 33 additions & 0 deletions litellm/llms/openai/speech/tts_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import List
from litellm.llms.base_llm.speech.transformation import (
BaseSpeechConfig,
)
from litellm.types.llms.openai import OpenAISpeechOptionalParams

class OpenAITTSSpeechConfig(BaseSpeechConfig):
    """Speech config for OpenAI `tts-1` / `tts-1-hd` models."""

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAISpeechOptionalParams]:
        """
        Get the supported OpenAI params for the tts models.

        `instructions` is deliberately absent: only gpt-* TTS models
        accept it.
        """
        return [
            "response_format",
            "speed",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map the OpenAI params to the Speech params.

        Delegates to the shared base implementation (copies each supported
        non-default param through unchanged) instead of duplicating it.
        """
        return super().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
20 changes: 15 additions & 5 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
get_optional_params_embeddings,
get_optional_params_image_gen,
get_optional_params_transcription,
get_optional_params_speech,
get_secret,
mock_completion_streaming_obj,
read_config_args,
Expand Down Expand Up @@ -5302,6 +5303,7 @@ def speech( # noqa: PLR0915
model: str,
input: str,
voice: Optional[Union[str, dict]] = None,
instructions: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
Expand All @@ -5328,11 +5330,18 @@ def speech( # noqa: PLR0915
) # type: ignore
kwargs.pop("tags", [])

optional_params = {}
if response_format is not None:
optional_params["response_format"] = response_format
if speed is not None:
optional_params["speed"] = speed # type: ignore
# optional_params = {}
# if response_format is not None:
# optional_params["response_format"] = response_format
# if speed is not None:
# optional_params["speed"] = speed # type: ignore

optional_params = get_optional_params_speech(
model=model,
response_format=response_format,
speed=speed,
instructions=instructions,
)

if timeout is None:
timeout = litellm.request_timeout
Expand Down Expand Up @@ -5401,6 +5410,7 @@ def speech( # noqa: PLR0915
model=model,
input=input,
voice=voice,
instructions=instructions,
optional_params=optional_params,
api_key=api_key,
api_base=api_base,
Expand Down
7 changes: 6 additions & 1 deletion litellm/types/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,14 @@ class Config:
"include",
]


OpenAIImageVariationOptionalParams = Literal["n", "size", "response_format", "user"]

# Optional params accepted by OpenAI's /audio/speech (TTS) endpoint.
# Individual provider speech configs narrow this set per model family —
# e.g. `instructions` is only advertised by the gpt-* speech config,
# not the tts-1 config.
OpenAISpeechOptionalParams = Literal[
    "instructions",
    "response_format",
    "speed",
]


class ComputerToolParam(TypedDict, total=False):
display_height: Required[float]
Expand Down
70 changes: 70 additions & 0 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,76 @@ def _check_valid_arg(supported_params):
return optional_params


def get_optional_params_speech(
    model: str,
    response_format: Optional[str] = None,
    speed: Optional[float] = None,
    instructions: Optional[str] = None,
    custom_llm_provider: Optional[str] = None,
    drop_params: Optional[bool] = None,
    **kwargs,
):
    """
    Build the optional-params dict for an audio speech (TTS) request.

    Collects the non-default OpenAI speech params (response_format, speed,
    instructions), filters/maps them through the provider's speech config
    when one exists, and passes any extra kwargs through unchanged.

    Raises:
        UnsupportedParamsError: if a non-default param is unsupported by the
            provider and neither `drop_params` nor `litellm.drop_params` is set.
    """
    # retrieve all parameters passed to the function
    passed_params = locals()
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    drop_params = passed_params.pop("drop_params")
    special_params = passed_params.pop("kwargs")
    # `model` is passed to the provider call separately — it must not leak
    # into optional_params, or the downstream
    # `create(model=model, **optional_params)` call would receive it twice.
    passed_params.pop("model", None)
    for k, v in special_params.items():
        passed_params[k] = v

    default_params = {
        "response_format": None,
        "speed": None,
        "instructions": None,
    }

    non_default_params = {
        k: v
        for k, v in passed_params.items()
        if (k in default_params and v != default_params[k])
    }
    optional_params = {}

    ## raise exception if a non-default param isn't supported by this provider
    def _check_valid_arg(supported_params):
        if len(non_default_params.keys()) > 0:
            keys = list(non_default_params.keys())
            for k in keys:
                if (
                    drop_params is True or litellm.drop_params is True
                ) and k not in supported_params:  # drop the unsupported non-default values
                    non_default_params.pop(k, None)
                elif k not in supported_params:
                    raise UnsupportedParamsError(
                        status_code=500,
                        message=f"Setting `{k}` is not supported by {custom_llm_provider} for speech. To drop it from the call, set `litellm.drop_params = True`.",
                    )
        return non_default_params

    provider_config: Optional[BaseSpeechConfig] = None
    if custom_llm_provider is not None:
        provider_config = ProviderConfigManager.get_provider_speech_config(
            model=model,
            provider=LlmProviders(custom_llm_provider),
        )

    if provider_config is not None:  # providers with a dedicated speech config
        supported_params = provider_config.get_supported_openai_params(model=model)
        _check_valid_arg(supported_params=supported_params)
        optional_params = provider_config.map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params if drop_params is not None else False,
        )
    else:
        # no provider-specific config: pass the non-default params through
        # unchanged instead of silently dropping response_format/speed/etc.
        optional_params.update(non_default_params)

    for k in passed_params.keys():  # pass additional kwargs without modification
        if k not in default_params.keys():
            optional_params[k] = passed_params[k]
    return optional_params




def get_optional_params_image_gen(
model: Optional[str] = None,
n: Optional[int] = None,
Expand Down
93 changes: 93 additions & 0 deletions tests/local_testing/test_audio_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,96 @@ def test_audio_speech_cost_calc():
]
print(f"standard_logging_payload: {standard_logging_payload}")
assert standard_logging_payload["response_cost"] > 0

@pytest.mark.parametrize(
    "sync_mode",
    [True],
)
@pytest.mark.parametrize(
    "model, api_key, api_base",
    [
        ("openai/gpt-4o-mini-tts", os.getenv("OPENAI_API_KEY"), None),
    ],
)
@pytest.mark.asyncio
async def test_audio_speech_litellm(sync_mode, model, api_base, api_key):
    """Live smoke test: litellm.speech() with `instructions` returns audio bytes."""
    speech_file_path = Path(__file__).parent / "speech.mp3"
    litellm._turn_on_debug()

    call_kwargs = dict(
        model=model,
        voice="alloy",
        instructions="speak the text as though you are like a crazy person, almost goofy and laughing at the end",
        input="say hello to the world",
        api_base=api_base,
        api_key=api_key,
        organization=None,
        project=None,
        max_retries=1,
        timeout=600,
        client=None,
        optional_params={},
    )
    if sync_mode:
        response = litellm.speech(**call_kwargs)

    from litellm.types.llms.openai import HttpxBinaryResponseContent

    print("response", response)

    # response carries the raw audio payload; persist it for manual inspection
    assert isinstance(response, HttpxBinaryResponseContent)
    with open(speech_file_path, "wb") as f:
        f.write(response.content)


@pytest.mark.parametrize(
    "sync_mode",
    [True],
)
@pytest.mark.parametrize(
    "model, api_key, api_base",
    [
        ("openai/gpt-4o-mini-tts", os.getenv("OPENAI_API_KEY"), None),
    ],
)
@pytest.mark.asyncio
async def test_audio_speech_passes_instructions_to_openai(sync_mode, model, api_base, api_key):
    """
    Verify litellm.speech() forwards `instructions` verbatim to the OpenAI
    SDK's audio.speech.create() call. The client is mocked, so no network
    request is made and no real API key is needed.
    """
    litellm._turn_on_debug()
    if sync_mode:
        from openai import OpenAI

        litellm.set_verbose = True
        client = OpenAI(api_key="fake-api-key")

        test_instructions = "speak the text as though you are like a crazy person, almost goofy and laughing at the end"
        with patch.object(
            client.audio.speech, "create"
        ) as mock_client:
            try:
                litellm.speech(
                    model=model,
                    voice="alloy",
                    instructions=test_instructions,
                    input="say hello to the world",
                    api_base=api_base,
                    api_key=api_key,
                    organization=None,
                    project=None,
                    max_retries=1,
                    timeout=600,
                    client=client,
                    optional_params={},
                )
            except Exception as e:
                # expected: the mocked create() returns a MagicMock, so
                # downstream response handling may raise — only the
                # outgoing request kwargs matter for this test
                print(f"Error: {e}")

        mock_client.assert_called_once()
        request_body = mock_client.call_args.kwargs

        print("request_body: ", request_body)

        assert request_body["instructions"] == test_instructions
Loading