feat: function calling #241

Open · wants to merge 3 commits into base: main
aphrodite/endpoints/openai/api_server.py: 936 changes (174 additions, 762 deletions)

Large diffs are not rendered by default.
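Since the api_server.py diff is collapsed here, the server-side wiring is not visible. As a rough, hypothetical sketch (the helper name and the JSON convention for spotting a function call are assumptions, not code from this PR), the handler presumably maps generated text onto the tool-call models defined in protocol.py below:

import json

from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionMessageToolCall, ChatMessage, Function)


def message_from_generation(text: str) -> ChatMessage:
    # Hypothetical helper: if the model emitted a JSON object describing a
    # function call, surface it through the new tool_calls field; otherwise
    # return the text as ordinary assistant content. The assumed shape
    # {"name": ..., "arguments": {...}} is illustrative, not the PR's format.
    try:
        payload = json.loads(text)
        call = ChatCompletionMessageToolCall(
            id=f"call-{random_uuid()}",
            type="function",
            function=Function(
                name=payload["name"],
                arguments=json.dumps(payload.get("arguments", {}))))
        return ChatMessage(role="assistant", tool_calls=[call])
    except (ValueError, KeyError, TypeError):
        return ChatMessage(role="assistant", content=text)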

aphrodite/endpoints/openai/protocol.py: 150 changes (139 additions, 11 deletions)
@@ -1,19 +1,20 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
-from typing import Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field
 
 from aphrodite.common.utils import random_uuid
+from aphrodite.common.sampling_params import SamplingParams
 
 
 class ErrorResponse(BaseModel):
     object: str = "error"
     message: str
     type: str
     param: Optional[str] = None
-    code: Optional[str] = None
+    code: int
 
 
 class ModelPermission(BaseModel):
@@ -51,10 +52,59 @@ class UsageInfo(BaseModel):
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
 
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    type: str
+    function: Function
+
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: Optional[Any] = None
+    # See : https://json-schema.org/understanding-json-schema/reference/object
+
+
+class ChatCompletionToolParam(BaseModel):
+    type: str = "function"
+    function: FunctionDefinition = None
+
+
+class ChatCompletionSystemMessage(BaseModel):
+    role: Literal["system"]
+    content: str
+    name: Optional[str] = None
+
+
+class ChatCompletionUserMessage(BaseModel):
+    role: Literal["user"]
+    content: Union[str, List[str]]
+    name: Optional[str] = None
+
+
+class ChatCompletionAssistantMessage(BaseModel):
+    role: Literal["assistant"]
+    content: Optional[str] = None
+    name: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
+
+
+class ChatCompletionToolMessage(BaseModel):
+    role: Literal["tool"]
+    content: str
+    tool_call_id: str
+
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: Union[str, List[Dict[str, str]]]
+    messages: List[Union[ChatCompletionToolMessage,
+                         ChatCompletionAssistantMessage,
+                         ChatCompletionUserMessage,
+                         ChatCompletionSystemMessage]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     tfs: Optional[float] = 1.0
@@ -91,6 +141,40 @@ class ChatCompletionRequest(BaseModel):
     spaces_between_special_tokens: Optional[bool] = True
     add_generation_prompt: Optional[bool] = True
     echo: Optional[bool] = False
+    tools: Optional[List[ChatCompletionToolParam]] = None
+    tool_choice: Optional[str] = None
+
+    def to_sampling_params(self) -> SamplingParams:
+        return SamplingParams(
+            n=self.n,
+            max_tokens=self.max_tokens,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            tfs=self.tfs,
+            eta_cutoff=self.eta_cutoff,
+            epsilon_cutoff=self.epsilon_cutoff,
+            typical_p=self.typical_p,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            top_a=self.top_a,
+            min_p=self.min_p,
+            mirostat_mode=self.mirostat_mode,
+            mirostat_tau=self.mirostat_tau,
+            mirostat_eta=self.mirostat_eta,
+            dynatemp_range=self.dynatemp_range,
+            dynatemp_exponent=self.dynatemp_exponent,
+            smoothing_factor=self.smoothing_factor,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            stop_token_ids=self.stop_token_ids,
+            custom_token_bans=self.custom_token_bans,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+            include_stop_str_in_output=self.include_stop_str_in_output
+        )
 
 
 class CompletionRequest(BaseModel):
@@ -136,13 +220,47 @@ class CompletionRequest(BaseModel):
     spaces_between_special_tokens: Optional[bool] = True
     grammar: Optional[str] = None
 
+    def to_sampling_params(self) -> SamplingParams:
+        return SamplingParams(
+            n=self.n,
+            max_tokens=self.max_tokens,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            tfs=self.tfs,
+            eta_cutoff=self.eta_cutoff,
+            epsilon_cutoff=self.epsilon_cutoff,
+            typical_p=self.typical_p,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            top_a=self.top_a,
+            min_p=self.min_p,
+            mirostat_mode=self.mirostat_mode,
+            mirostat_tau=self.mirostat_tau,
+            mirostat_eta=self.mirostat_eta,
+            dynatemp_range=self.dynatemp_range,
+            dynatemp_exponent=self.dynatemp_exponent,
+            smoothing_factor=self.smoothing_factor,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            stop_token_ids=self.stop_token_ids,
+            custom_token_bans=self.custom_token_bans,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+            include_stop_str_in_output=self.include_stop_str_in_output,
+            logprobs=self.logprobs,
+            prompt_logprobs=self.logprobs if self.echo else None,
+            logits_processors=self.grammar,
+        )
 
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str,
-                                     float]]] = Field(default_factory=list)
+    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -174,18 +292,19 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: Optional[UsageInfo]
+    usage: Optional[UsageInfo] = Field(default=None)
 
 
 class ChatMessage(BaseModel):
     role: str
-    content: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -196,16 +315,23 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
 
+class ChoiceDeltaToolCall(BaseModel):
+    index: int
+    id: str
+    type: str
+    function: Function
+
 
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
@@ -214,5 +340,7 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
-    usage: Optional[UsageInfo] = Field(
-        default=None, description="data about request and response")
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
+class Prompt(BaseModel):
+    prompt: str
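
Taken together, the new models cover the full round trip: the client declares tools on the request, the server answers with tool_calls and a "tool_calls" finish_reason, and the client replies with a ChatCompletionToolMessage carrying the matching tool_call_id. A minimal request-side sketch follows (the import path comes from this diff; the model name and weather-tool schema are illustrative only, and this assumes the request fields elided above keep their defaults):

from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest, ChatCompletionToolParam, FunctionDefinition)

# One callable tool, described by a JSON-Schema "parameters" object as
# referenced in the FunctionDefinition comment.
weather_tool = ChatCompletionToolParam(
    function=FunctionDefinition(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        }))

# messages now validates against the typed per-role models rather than a
# bare List[Dict[str, str]]; plain dicts are coerced via their "role" field.
request = ChatCompletionRequest(
    model="example-model",  # illustrative placeholder
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Tokyo?"},
    ],
    tools=[weather_tool],
    tool_choice="auto",
)

# The request object now carries its own mapping onto engine sampling
# parameters.
params = request.to_sampling_params()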