diff --git a/aphrodite/endpoints/openai/api_server.py b/aphrodite/endpoints/openai/api_server.py index 5e88bf690..10d1635df 100644 --- a/aphrodite/endpoints/openai/api_server.py +++ b/aphrodite/endpoints/openai/api_server.py @@ -1,59 +1,62 @@ -# Adapted from -# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py - import argparse import asyncio -import codecs import json -import time -from http import HTTPStatus -from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union +import sys +from contextlib import asynccontextmanager +import os +import importlib +import inspect from aioprometheus import MetricsMiddleware from aioprometheus.asgi.starlette import metrics import fastapi import uvicorn -from fastapi import Request, Response, Header, HTTPException, Depends +from http import HTTPStatus +from fastapi import Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel +from fastapi.responses import JSONResponse, StreamingResponse, Response from aphrodite.engine.args_tools import AsyncEngineArgs from aphrodite.engine.async_aphrodite import AsyncAphrodite from aphrodite.engine.metrics import add_global_metrics_labels from aphrodite.endpoints.openai.protocol import ( - CompletionRequest, CompletionResponse, CompletionResponseChoice, - CompletionResponseStreamChoice, CompletionStreamResponse, - ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) + CompletionRequest, ChatCompletionRequest, ErrorResponse, Prompt) from aphrodite.common.logger import init_logger -from aphrodite.common.outputs import RequestOutput -from aphrodite.common.sampling_params import SamplingParams -from aphrodite.transformers_utils.tokenizer import get_tokenizer -from aphrodite.common.utils import random_uuid -from aphrodite.common.logits_processor import BiasLogitsProcessor -from aphrodite.common.grammar import (GrammarLogitsProcessor, - RayRemoteGrammarLogitsProcessor) +from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat +from aphrodite.endpoints.openai.serving_completions import OpenAIServingCompletion +from aphrodite.endpoints.openai.tools import OpenAIToolsPrompter TIMEOUT_KEEP_ALIVE = 5 # seconds +aphrodite_engine = None +aphrodite_engine_args = None +openai_serving_chat: OpenAIServingChat = None +openai_serving_completion: OpenAIServingCompletion = None logger = init_logger(__name__) -served_model = None -app = fastapi.FastAPI() -engine = None -response_role = None + + +@asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + + async def _force_log(): + while True: + await asyncio.sleep(10) + await aphrodite_engine.do_log_stats() + + if not aphrodite_engine_args.disable_log_stats: + asyncio.create_task(_force_log()) + + yield + + +app = fastapi.FastAPI(lifespan=lifespan) def parse_args(): parser = argparse.ArgumentParser( description="Aphrodite OpenAI-Compatible RESTful API server.") - parser.add_argument("--host", - type=str, - default="localhost", - help="host name") + parser.add_argument("--host", type=str, default=None, help="host name") parser.add_argument("--port", type=int, default=2242, help="port number") parser.add_argument("--allow-credentials", 
action="store_true", @@ -70,775 +73,162 @@ def parse_args(): type=json.loads, default=["*"], help="allowed headers") + parser.add_argument( + "--api-keys", + type=str, + default=None, + help= + "If provided, the server will require this key to be presented in the header." + ) parser.add_argument("--served-model-name", type=str, default=None, help="The model name used in the API. If not " "specified, the model name will be the same as " "the huggingface name.") - parser.add_argument("--api-keys", - nargs="*", - help="Authorization API Keys for the server.") parser.add_argument("--chat-template", type=str, default=None, help="The file path to the chat template, " "or the template in single-line form " - "for the specified model.") + "for the specified model") + parser.add_argument("--tools-template", + type=str, + default=None, + help="The file path to alternative tools template") + parser.add_argument("--enable-api-tools", + action="store_true", + help="Enable OpenAI-like tools API " + "(only function calls are currently supported)") parser.add_argument("--response-role", type=str, default="assistant", help="The role name to return if " - "`request.add_generation_prompt=True.") + "`request.add_generation_prompt=true`.") parser.add_argument("--ssl-keyfile", type=str, default=None, - help="SSL key file path.") + help="The file path to the SSL key file") parser.add_argument("--ssl-certfile", type=str, default=None, - help="SSL cert file path.") + help="The file path to the SSL cert file") + parser.add_argument( + "--dev-mode", + action="store_true", + help= + "Enable API internals and templates reloading but do not deallocate the engine. This should only be used for development purpose." + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--middleware", + type=str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, Aphrodite will add it to the server using @app.middleware('http'). " + "If a class is provided, Aphrodite will add it to the server using app.add_middleware(). 
" + ) parser = AsyncEngineArgs.add_cli_args(parser) return parser.parse_args() - -app.add_middleware(MetricsMiddleware) # trace HTTP server metrics -app.add_route("/metrics", metrics) - - -def _verify_api_key(x_api_key: str = Header(None), - authorization: str = Header(None)): - if not EXPECTED_API_KEYS: # If no keys are provided - return "NoKey" # Return a default value - if x_api_key and x_api_key in EXPECTED_API_KEYS: - return x_api_key - elif authorization: - scheme, _, token = authorization.partition(" ") - if scheme.lower() == "bearer" and token in EXPECTED_API_KEYS: - return token - raise HTTPException( - status_code=401, - detail="Invalid API Key", - ) - - -def create_error_response(status_code: HTTPStatus, - message: str) -> JSONResponse: - return JSONResponse(ErrorResponse(message=message, - type="invalid_request_error").dict(), - status_code=status_code.value) - - -def load_chat_template(args, tokenizer): # pylint: disable=redefined-outer-name - if args.chat_template is not None: - try: - with open(args.chat_template, "r") as f: - chat_template = f.read() - except OSError: - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - chat_template = codecs.decode(args.chat_template, "unicode_escape") - - tokenizer.chat_template = chat_template - logger.info( - f"Using supplied chat template:\n{tokenizer.chat_template}") - elif tokenizer.chat_template is not None: - logger.info(f"Using default chat template:\n{tokenizer.chat_template}") - else: - logger.warning("No chat template provided. Chat API will not work.") +def _loadServingServices(): + """ Load or reload the OpenAI service. + This function should only be called once on initialization, but may be called to reload the API internals. + Reloading must be used for development purpose only. """ + global openai_serving_chat + global openai_serving_completion + if openai_serving_chat is not None: + del openai_serving_chat + if openai_serving_completion is not None: + del openai_serving_completion + + openai_tools_prompter = OpenAIToolsPrompter( + template_path=args.tools_template) if args.enable_api_tools else None + openai_serving_chat = OpenAIServingChat( + engine=aphrodite_engine, + served_model=served_model, + response_role=args.response_role, + chat_template=args.chat_template, + openai_tools_prompter=openai_tools_prompter, + dev_mode=args.dev_mode) + openai_serving_completion = OpenAIServingCompletion( + aphrodite_engine, served_model) + +app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics +app.add_route("/metrics", metrics) # Exposes HTTP metrics @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request, exc): # pylint: disable=unused-argument - return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) - - -async def check_model(request) -> Optional[JSONResponse]: - if request.model == served_model: - return - ret = create_error_response( - HTTPStatus.NOT_FOUND, - f"The model `{request.model}` does not exist.", - ) - return ret - - -async def check_length( - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None -) -> Tuple[List[int], Optional[JSONResponse]]: - assert (not (prompt is None and prompt_ids is None) - and not (prompt is not None and prompt_ids is not None) - ), "Either prompt or prompt_ids should be provided." 
- input_ids = prompt_ids if prompt_ids is not None else tokenizer( - prompt).input_ids - token_num = len(input_ids) - - if request.max_tokens is None: - request.max_tokens = max_model_len - token_num - if token_num + request.max_tokens > max_model_len: - return input_ids, create_error_response( - HTTPStatus.BAD_REQUEST, - f"This model's maximum context length is {max_model_len} tokens. " - f"However, you requested {request.max_tokens + token_num} tokens " - f"({token_num} in the messages, " - f"{request.max_tokens} in the completion). " - f"Please reduce the length of the messages or completion.", - ) - else: - return input_ids, None +async def validation_exception_handler(_, exc): + err = openai_serving_chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @app.get("/health") async def health() -> Response: - """Health check route for K8s""" + """Health check.""" return Response(status_code=200) +if "--dev-mode" in sys.argv: -class Prompt(BaseModel): - prompt: str - + @app.get("/privileged") + async def privileged() -> Response: + """Reload the API internals. Dangerous!""" + logger.warning("privileged called.") + _loadServingServices() + return Response(status_code=200) @app.post("/v1/tokenize") -async def tokenize_text( - prompt: Prompt, - # pylint: disable=unused-argument - api_key: str = Depends(_verify_api_key)): - """Tokenize prompt using the tokenizer. - Returns: - value: The number of tokens in the prompt. - ids: The token IDs of the prompt. - """ - try: - tokenized_prompt = tokenizer.tokenize(prompt.prompt) - token_ids = tokenizer.convert_tokens_to_ids(tokenized_prompt) - return {"value": len(tokenized_prompt), "ids": token_ids} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) from e - +async def tokenize(prompt: Prompt): + tokenized = await openai_serving_chat.tokenize_text(prompt) + return JSONResponse(content=tokenized.model_dump()) @app.get("/v1/models") -async def show_available_models( - # pylint: disable=unused-argument - api_key: str = Depends(_verify_api_key)): - """Show available models. 
Right now we only have one model.""" - model_cards = [ - ModelCard(id=served_model, - root=served_model, - permission=[ModelPermission()]) - ] - return ModelList(data=model_cards) - - -def create_logprobs( - token_ids: List[int], - top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, - num_output_top_logprobs: Optional[int] = None, - initial_text_offset: int = 0, -) -> LogProbs: - """Create OpenAI-style logprobs.""" - logprobs = LogProbs() - last_token_len = 0 - if num_output_top_logprobs: - logprobs.top_logprobs = [] - for i, token_id in enumerate(token_ids): - step_top_logprobs = top_logprobs[i] - if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id] - else: - token_logprob = None - token = tokenizer.convert_ids_to_tokens(token_id) - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) - if len(logprobs.text_offset) == 0: - logprobs.text_offset.append(initial_text_offset) - else: - logprobs.text_offset.append(logprobs.text_offset[-1] + - last_token_len) - last_token_len = len(token) - - if num_output_top_logprobs: - logprobs.top_logprobs.append({ - tokenizer.convert_ids_to_tokens(i): p - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) - - logprobs.top_logprobs = [{ - k: v if v > -1000 else -1000 - for k, v in top_logprob.items() - } for top_logprob in logprobs.top_logprobs if top_logprob is not None] - - return logprobs +async def show_available_models(): + models = await openai_serving_chat.show_available_models() + return JSONResponse(content=models.model_dump()) @app.post("/v1/chat/completions") -async def create_chat_completion( - request: ChatCompletionRequest, - raw_request: Request, - # pylint: disable=unused-argument - api_key: str = Depends(_verify_api_key)): - """Completion API similar to OpenAI's API. - - See https://platform.openai.com/docs/api-reference/chat/create - for the API specification. This API mimics the OpenAI ChatCompletion API. 
- - NOTE: Currently we do not support the following features: - - function_call (Users should implement this by themselves) - """ - - error_check_ret = await check_model(request) - if error_check_ret is not None: - return error_check_ret - - try: - prompt = tokenizer.apply_chat_template( - conversation=request.messages, - tokenize=False, - add_generation_prompt=request.add_generation_prompt) - except Exception as e: # pylint: disable=broad-except - logger.error(f"Error in applying chat template from request: {str(e)}") - return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - - token_ids, error_check_ret = await check_length(request, prompt=prompt) - if error_check_ret is not None: - return error_check_ret - - if not request.logit_bias: - logit_processors = [] - else: - biases = dict( - map(lambda bias: (int(bias[0]), bias[1]), - request.logit_bias.items())) - logit_processors = [BiasLogitsProcessor(biases)] - - model_name = request.model - request_id = f"cmpl-{random_uuid()}" - created_time = int(time.monotonic()) - chunk_object_type = "chat.completion.chunk" - - # We disable top_k at -1, add this conversion for - # compatibility - if request.top_k == 0: - request.top_k = -1 - try: - sampling_params = SamplingParams( - n=request.n, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - repetition_penalty=request.repetition_penalty, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - top_a=request.top_a, - min_p=request.min_p, - tfs=request.tfs, - eta_cutoff=request.eta_cutoff, - epsilon_cutoff=request.epsilon_cutoff, - typical_p=request.typical_p, - mirostat_mode=request.mirostat_mode, - mirostat_tau=request.mirostat_tau, - mirostat_eta=request.mirostat_eta, - dynatemp_range=request.dynatemp_range, - dynatemp_exponent=request.dynatemp_exponent, - smoothing_factor=request.smoothing_factor, - stop=request.stop, - stop_token_ids=request.stop_token_ids, - include_stop_str_in_output=request.include_stop_str_in_output, - max_tokens=request.max_tokens, - best_of=request.best_of, - ignore_eos=request.ignore_eos, - use_beam_search=request.use_beam_search, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=request. 
- spaces_between_special_tokens, # pylint: disable=line-too-long - custom_token_bans=request.custom_token_bans, - logprobs=request.logprobs, - prompt_logprobs=request.prompt_logprobs, - logits_processors=logit_processors, - ) - except ValueError as e: - return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - - result_generator = engine.generate(prompt, sampling_params, request_id, - token_ids) - - def get_role() -> str: - if request.add_generation_prompt: - return response_role - else: - return request.messages[-1]["role"] - - async def completion_stream_generator() -> AsyncGenerator[str, None]: - # Send first response for each request.n (index) with the role - role = get_role() - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, delta=DeltaMessage(role=role), finish_reason=None) - chunk = ChatCompletionStreamResponse(id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - data = chunk.json(exclude_unset=True, ensure_ascii=False) - yield f"data: {data}\n\n" - - # Send response to echo the input portion of the last message - if request.echo: - last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] - if last_msg_content: - for i in range(request.n): - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=last_msg_content), - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - data = chunk.json(exclude_unset=True, ensure_ascii=False) - yield f"data: {data}\n\n" - - # Send response for each token for each request.n (index) - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - finish_reason_sent = [False] * request.n - async for res in result_generator: - res: RequestOutput - for output in res.outputs: - i = output.index - - if finish_reason_sent[i]: - continue - - if output.finish_reason is None: - # Send token-by-token response for each request.n - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(content=delta_text), - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - data = chunk.json(exclude_unset=True, ensure_ascii=False) - yield f"data: {data}\n\n" - else: - # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], - ) - choice_data = ChatCompletionResponseStreamChoice( - index=i, delta=[], finish_reason=output.finish_reason) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if final_usage is not None: - chunk.usage = final_usage - data = chunk.json(exclude_unset=True, - exclude_none=True, - ensure_ascii=False) - yield f"data: {data}\n\n" - finish_reason_sent[i] = True - # Send the final done message after all response.n are finished - yield 
"data: [DONE]\n\n" - - async def completion_full_generator(): - final_res: RequestOutput = None - async for res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await engine.abort(request_id) - return create_error_response(HTTPStatus.BAD_REQUEST, - "Client disconnected") - final_res = res - assert final_res is not None - - choices = [] - role = get_role() - for output in final_res.outputs: - choice_data = ChatCompletionResponseChoice( - index=output.index, - message=ChatMessage(role=role, content=output.text), - finish_reason=output.finish_reason, - ) - choices.append(choice_data) - - if request.echo: - last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] - - for choice in choices: - full_message = last_msg_content + choice.message.content - choice.message.content = full_message - - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - response = ChatCompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) - - return response - - # Streaming response +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + generator = await openai_serving_chat.create_chat_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) if request.stream: - return StreamingResponse(completion_stream_generator(), + return StreamingResponse(content=generator, media_type="text/event-stream") else: - return await completion_full_generator() + return JSONResponse(content=generator.model_dump()) @app.post("/v1/completions") -async def create_completion( - request: CompletionRequest, - raw_request: Request, - # pylint: disable=unused-argument - api_key: str = Depends(_verify_api_key)): - """Completion API similar to OpenAI's API. - - See https://platform.openai.com/docs/api-reference/completions/create - for the API specification. This API mimics the OpenAI Completion API. - - NOTE: Currently we do not support the following features: - - echo (since the Aphrodite engine does not currently support - getting the logprobs of prompt tokens) - - suffix (the language models we currently support do not support - suffix) - """ - - error_check_ret = await check_model(request) - if error_check_ret is not None: - return error_check_ret - - if not request.logit_bias: - logit_processors = [] - else: - biases = dict( - map(lambda bias: (int(bias[0]), bias[1]), - request.logit_bias.items())) - logit_processors = [BiasLogitsProcessor(biases)] - - if request.grammar: - if engine.worker_use_ray: - grammar_logits_processor = RayRemoteGrammarLogitsProcessor( - tokenizer=tokenizer, grammar=request.grammar) - else: - grammar_logits_processor = GrammarLogitsProcessor( - tokenizer=tokenizer, grammar=request.grammar) - logit_processors = [grammar_logits_processor] - else: - logit_processors = [] - - # OpenAI API supports echoing the prompt when max_tokens is 0. 
- echo_without_generation = request.echo and request.max_tokens == 0 - - if request.suffix is not None: - # The language models we currently support do not support suffix. - return create_error_response(HTTPStatus.BAD_REQUEST, - "suffix is not currently supported") - - model_name = request.model - request_id = f"cmpl-{random_uuid()}" - - use_token_ids = False - if isinstance(request.prompt, list): - if len(request.prompt) == 0: - return create_error_response(HTTPStatus.BAD_REQUEST, - "please provide at least one prompt") - first_element = request.prompt[0] - if isinstance(first_element, int): - use_token_ids = True - prompt = request.prompt - elif isinstance(first_element, (str, list)): - # TODO: handles multiple prompt case in list[list[int]] - if len(request.prompt) > 1: - return create_error_response( - HTTPStatus.BAD_REQUEST, - "multiple prompts in a batch is not currently supported") - use_token_ids = not isinstance(first_element, str) - prompt = request.prompt[0] - else: - prompt = request.prompt - - if use_token_ids: - _, error_check_ret = await check_length(request, prompt_ids=prompt) - else: - token_ids, error_check_ret = await check_length(request, prompt=prompt) - if error_check_ret is not None: - return error_check_ret - - created_time = int(time.monotonic()) - - # We disable top_k at -1, add this conversion for - # compatibility - if request.top_k == 0: - request.top_k = -1 - - try: - sampling_params = SamplingParams( - n=request.n, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - repetition_penalty=request.repetition_penalty, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - top_a=request.top_a, - min_p=request.min_p, - tfs=request.tfs, - eta_cutoff=request.eta_cutoff, - epsilon_cutoff=request.epsilon_cutoff, - typical_p=request.typical_p, - mirostat_mode=request.mirostat_mode, - mirostat_tau=request.mirostat_tau, - mirostat_eta=request.mirostat_eta, - dynatemp_range=request.dynatemp_range, - dynatemp_exponent=request.dynatemp_exponent, - smoothing_factor=request.smoothing_factor, - stop=request.stop, - stop_token_ids=request.stop_token_ids, - include_stop_str_in_output=request.include_stop_str_in_output, - max_tokens=request.max_tokens - if not echo_without_generation else 1, - best_of=request.best_of, - ignore_eos=request.ignore_eos, - use_beam_search=request.use_beam_search, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=request. - spaces_between_special_tokens, # pylint: disable=line-too-long - custom_token_bans=request.custom_token_bans, - logprobs=request.logprobs, - prompt_logprobs=request.prompt_logprobs if request.echo else None, - logits_processors=logit_processors, - ) - except ValueError as e: - return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - - if use_token_ids: - result_generator = engine.generate(None, - sampling_params, - request_id, - prompt_token_ids=prompt) - else: - result_generator = engine.generate(prompt, sampling_params, request_id, - token_ids) - - # Similar to the OpenAI API, when n != best_of, we do not stream the - # results. In addition, we do not stream the results when use beam search. 
- stream = (request.stream - and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search) - - def create_stream_response_json( - index: int, - text: str, - logprobs: Optional[LogProbs] = None, - finish_reason: Optional[str] = None, - usage: Optional[UsageInfo] = None, - ) -> str: - choice_data = CompletionResponseStreamChoice( - index=index, - text=text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - response = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - ) - if usage is not None: - response.usage = usage - response_json = response.json(exclude_unset=True, ensure_ascii=False) - - return response_json - - async def completion_stream_generator() -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - has_echoed = [False] * request.n - async for res in result_generator: - res: RequestOutput - for output in res.outputs: - i = output.index - delta_text = output.text[len(previous_texts[i]):] - token_ids = output.token_ids[previous_num_tokens[i]:] - if request.logprobs is not None: - top_logprobs = output.logprobs[previous_num_tokens[i]:] - else: - top_logprobs = None - offsets = len(previous_texts[i]) - if request.echo and not has_echoed[i]: - if not echo_without_generation: - delta_text = res.prompt + delta_text - token_ids = res.prompt_token_ids + token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs + top_logprobs - else: # only just return the prompt - delta_text = res.prompt - token_ids = res.prompt_token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs - has_echoed[i] = True - if request.logprobs is not None: - logprobs = create_logprobs( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=offsets, - ) - else: - logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - finish_reason = output.finish_reason - response_json = create_stream_response_json( - index=i, - text=delta_text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - yield f"data: {response_json}\n\n" - if output.finish_reason is not None: - logprobs = (LogProbs() - if request.logprobs is not None else None) - prompt_tokens = len(res.prompt_token_ids) - completion_tokens = len(output.token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - response_json = create_stream_response_json( - index=i, - text="", - logprobs=logprobs, - finish_reason=output.finish_reason, - usage=final_usage, - ) - yield f"data: {response_json}\n\n" - yield "data: [DONE]\n\n" - - # Streaming response - if stream: - return StreamingResponse(completion_stream_generator(), - media_type="text/event-stream") - - # Non-streaming response - final_res: RequestOutput = None - async for res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. 
- await engine.abort(request_id) - return create_error_response(HTTPStatus.BAD_REQUEST, - "Client disconnected") - final_res = res - assert final_res is not None - choices = [] - prompt_token_ids = final_res.prompt_token_ids - prompt_logprobs = final_res.prompt_logprobs - prompt_text = final_res.prompt - for output in final_res.outputs: - if request.logprobs is not None: - if not echo_without_generation: - token_ids = output.token_ids - top_logprobs = output.logprobs - if request.echo: - token_ids = prompt_token_ids + token_ids - top_logprobs = prompt_logprobs + top_logprobs - else: - token_ids = prompt_token_ids - top_logprobs = prompt_logprobs - logprobs = create_logprobs( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - ) - else: - logprobs = None - if not echo_without_generation: - output_text = output.text - if request.echo: - output_text = prompt_text + output_text - else: - output_text = prompt_text - choice_data = CompletionResponseChoice( - index=output.index, - text=output_text, - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - choices.append(choice_data) - - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - response = CompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) - +async def create_completion(request: CompletionRequest, raw_request: Request): + generator = await openai_serving_completion.create_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) if request.stream: - # When user requests streaming but we don't stream, we still need to - # return a streaming response with a single event. 
- response_json = response.json(ensure_ascii=False) - - async def fake_stream_generator() -> AsyncGenerator[str, None]: - yield f"data: {response_json}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse(fake_stream_generator(), + return StreamingResponse(content=generator, media_type="text/event-stream") - - return response + else: + return JSONResponse(content=generator.model_dump()) if __name__ == "__main__": args = parse_args() - global EXPECTED_API_KEYS # pylint: disable=global-at-module-level - EXPECTED_API_KEYS = args.api_keys + app.add_middleware( CORSMiddleware, allow_origins=args.allowed_origins, @@ -847,30 +237,52 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: allow_headers=args.allowed_headers, ) - logger.debug(f"args: {args}") + if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys: + + @app.middleware("http") + async def authentication(request: Request, call_next): + if not request.url.path.startswith("/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError( + f"Invalid middleware {middleware}. Must be a function or a class." + ) + + logger.info(f"args: {args}") + if args.dev_mode: + logger.warning( + "\n" + "######################################################################\n" + "dev-mode enabled. This should only be used for development purpose.\n" + "If It's not the case, you should disable this!\n" + "######################################################################\n" + ) if args.served_model_name is not None: served_model = args.served_model_name else: served_model = args.model - response_role = args.response_role - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncAphrodite.from_engine_args(engine_args) - engine_model_config = asyncio.run(engine.get_model_config()) - max_model_len = engine_model_config.max_model_len - - # A separate tokenizer to map token IDs to strings. 
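# --- Illustrative sketch (not part of this patch) ---
# A module that could be passed to the new "--middleware" flag, e.g.
# "--middleware my_middleware.add_header" or "--middleware my_middleware.ASGILogger".
# The module and object names are hypothetical. Per the registration loop above,
# a coroutine function is wired in with @app.middleware("http") and a class with
# app.add_middleware(), so the two forms below match those call conventions.
from fastapi import Request


async def add_header(request: Request, call_next):
    """Function form: registered via app.middleware('http')."""
    response = await call_next(request)
    response.headers["X-Example"] = "aphrodite"
    return response


class ASGILogger:
    """Class form: a plain ASGI middleware registered via app.add_middleware()."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            print("request path:", scope.get("path"))
        await self.app(scope, receive, send)
# --- end sketch ---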
- tokenizer = get_tokenizer( - engine_model_config.tokenizer, - tokenizer_mode=engine_model_config.tokenizer_mode, - trust_remote_code=engine_model_config.trust_remote_code) - - load_chat_template(args, tokenizer) + aphrodite_engine_args = AsyncEngineArgs.from_cli_args(args) + aphrodite_engine = AsyncAphrodite.from_engine_args(aphrodite_engine_args) + _loadServingServices() - add_global_metrics_labels(model_name=engine_args.model) + # Register labels for metrics + add_global_metrics_labels(model_name=aphrodite_engine_args.model) + app.root_path = args.root_path uvicorn.run(app, host=args.host, port=args.port, diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py index e275d0fb3..a2739992a 100644 --- a/aphrodite/endpoints/openai/protocol.py +++ b/aphrodite/endpoints/openai/protocol.py @@ -1,11 +1,12 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field from aphrodite.common.utils import random_uuid +from aphrodite.common.sampling_params import SamplingParams class ErrorResponse(BaseModel): @@ -13,7 +14,7 @@ class ErrorResponse(BaseModel): message: str type: str param: Optional[str] = None - code: Optional[str] = None + code: int class ModelPermission(BaseModel): @@ -51,10 +52,59 @@ class UsageInfo(BaseModel): total_tokens: int = 0 completion_tokens: Optional[int] = 0 +class Function(BaseModel): + name: str + arguments: str + + +class ChatCompletionMessageToolCall(BaseModel): + id: str + type: str + function: Function + + +class FunctionDefinition(BaseModel): + name: str + description: str + parameters: Optional[Any] = None + # See : https://json-schema.org/understanding-json-schema/reference/object + + +class ChatCompletionToolParam(BaseModel): + type: str = "function" + function: FunctionDefinition = None + + +class ChatCompletionSystemMessage(BaseModel): + role: Literal["system"] + content: str + name: Optional[str] = None + + +class ChatCompletionUserMessage(BaseModel): + role: Literal["user"] + content: Union[str, List[str]] + name: Optional[str] = None + + +class ChatCompletionAssistantMessage(BaseModel): + role: Literal["assistant"] + content: Optional[str] = None + name: Optional[str] = None + tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None + + +class ChatCompletionToolMessage(BaseModel): + role: Literal["tool"] + content: str + tool_call_id: str class ChatCompletionRequest(BaseModel): model: str - messages: Union[str, List[Dict[str, str]]] + messages: List[Union[ChatCompletionToolMessage, + ChatCompletionAssistantMessage, + ChatCompletionUserMessage, + ChatCompletionSystemMessage]] temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tfs: Optional[float] = 1.0 @@ -91,6 +141,40 @@ class ChatCompletionRequest(BaseModel): spaces_between_special_tokens: Optional[bool] = True add_generation_prompt: Optional[bool] = True echo: Optional[bool] = False + tools: Optional[List[ChatCompletionToolParam]] = None + tool_choice: Optional[str] = None + + def to_sampling_params(self) -> SamplingParams: + return SamplingParams( + n=self.n, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p, + tfs=self.tfs, + eta_cutoff=self.eta_cutoff, + epsilon_cutoff=self.epsilon_cutoff, + typical_p=self.typical_p, + presence_penalty=self.presence_penalty, + 
frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + best_of=self.best_of, + top_k=self.top_k, + top_a=self.top_a, + min_p=self.min_p, + mirostat_mode=self.mirostat_mode, + mirostat_tau=self.mirostat_tau, + mirostat_eta=self.mirostat_eta, + dynatemp_range=self.dynatemp_range, + dynatemp_exponent=self.dynatemp_exponent, + smoothing_factor=self.smoothing_factor, + ignore_eos=self.ignore_eos, + use_beam_search=self.use_beam_search, + stop_token_ids=self.stop_token_ids, + custom_token_bans=self.custom_token_bans, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output + ) class CompletionRequest(BaseModel): @@ -136,13 +220,47 @@ class CompletionRequest(BaseModel): spaces_between_special_tokens: Optional[bool] = True grammar: Optional[str] = None + def to_sampling_params(self) -> SamplingParams: + return SamplingParams( + n=self.n, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p, + tfs=self.tfs, + eta_cutoff=self.eta_cutoff, + epsilon_cutoff=self.epsilon_cutoff, + typical_p=self.typical_p, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + best_of=self.best_of, + top_k=self.top_k, + top_a=self.top_a, + min_p=self.min_p, + mirostat_mode=self.mirostat_mode, + mirostat_tau=self.mirostat_tau, + mirostat_eta=self.mirostat_eta, + dynatemp_range=self.dynatemp_range, + dynatemp_exponent=self.dynatemp_exponent, + smoothing_factor=self.smoothing_factor, + ignore_eos=self.ignore_eos, + use_beam_search=self.use_beam_search, + stop_token_ids=self.stop_token_ids, + custom_token_bans=self.custom_token_bans, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + logprobs=self.logprobs, + prompt_logprobs=self.logprobs if self.echo else None, + logits_processors=self.grammar, + ) + class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, - float]]] = Field(default_factory=list) + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None class CompletionResponseChoice(BaseModel): @@ -174,18 +292,19 @@ class CompletionStreamResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[CompletionResponseStreamChoice] - usage: Optional[UsageInfo] + usage: Optional[UsageInfo] = Field(default=None) class ChatMessage(BaseModel): role: str - content: str + content: Optional[str] = None + tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - finish_reason: Optional[Literal["stop", "length"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None class ChatCompletionResponse(BaseModel): @@ -196,16 +315,23 @@ class ChatCompletionResponse(BaseModel): choices: List[ChatCompletionResponseChoice] usage: UsageInfo +class ChoiceDeltaToolCall(BaseModel): + index: int + id: str + type: str + function: Function + class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ChoiceDeltaToolCall]] = None class 
ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage - finish_reason: Optional[Literal["stop", "length"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None class ChatCompletionStreamResponse(BaseModel): @@ -214,5 +340,7 @@ class ChatCompletionStreamResponse(BaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: List[ChatCompletionResponseStreamChoice] - usage: Optional[UsageInfo] = Field( - default=None, description="data about request and response") + usage: Optional[UsageInfo] = Field(default=None) + +class Prompt(BaseModel): + prompt: str \ No newline at end of file diff --git a/aphrodite/endpoints/openai/serving_chat.py b/aphrodite/endpoints/openai/serving_chat.py new file mode 100644 index 000000000..bac879c03 --- /dev/null +++ b/aphrodite/endpoints/openai/serving_chat.py @@ -0,0 +1,453 @@ +import time +import codecs +import asyncio +from fastapi import Request +from typing import AsyncGenerator, AsyncIterator, Union + +from aphrodite.common.logger import init_logger +from aphrodite.common.utils import random_uuid +from aphrodite.engine.async_aphrodite import AsyncAphrodite +from aphrodite.endpoints.openai.protocol import ( + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionAssistantMessage, ChatCompletionToolMessage, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, + UsageInfo) +from aphrodite.common.outputs import RequestOutput +from aphrodite.endpoints.openai.serving_engine import OpenAIServing +from aphrodite.endpoints.openai.tools import OpenAIToolsPrompter, ChatPromptCapture + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__(self, + engine: AsyncAphrodite, + served_model: str, + response_role: str, + chat_template=None, + openai_tools_prompter: OpenAIToolsPrompter = None, + dev_mode: bool = False): + super().__init__(engine=engine, served_model=served_model) + self.dev_mode = dev_mode + self.response_role = response_role + self.openai_tools_prompter = openai_tools_prompter + + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running( + ): # If the current is instanced by Ray Serve, there is already a running event loop + event_loop.create_task(self._load_chat_template(chat_template)) + else: # When using Aphrodite without parallelism or engine_use_ray + asyncio.run(self._load_chat_template(chat_template)) + + async def create_chat_completion( + self, request: ChatCompletionRequest, raw_request: Request + ) -> Union[ErrorResponse, AsyncGenerator[str, None], + ChatCompletionResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI ChatCompletion API. + + NOTE: Currently we do not support the following features:) + - logit_bias (to be supported by Aphrodite engine) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.logit_bias is not None and len(request.logit_bias) > 0: + # TODO: support logit_bias in Aphrodite engine. 
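# --- Illustrative request sketch (not part of this patch) ---
# A chat completion call exercising the new "tools" fields; the server must be
# started with "--enable-api-tools". The model name, function name, and schema
# are placeholders; "parameters" follows JSON Schema, as noted on
# FunctionDefinition. Add an Authorization header if "--api-keys" is in use.
import requests

payload = {
    "model": "my-served-model",
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}

resp = requests.post("http://localhost:2242/v1/chat/completions", json=payload)
choice = resp.json()["choices"][0]
# When the model emits a call, finish_reason is "tool_calls" and the arguments
# arrive as a JSON string inside message.tool_calls[*].function.
print(choice["finish_reason"], choice["message"].get("tool_calls"))
# --- end sketch ---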
+ return self.create_error_response( + "logit_bias is not currently supported") + + if self.openai_tools_prompter is not None: + self.openai_tools_prompter.inject_prompt(request) + + # FIXME: The tokenizer only accepts "role" and "content" attributes. + # So we manually copy other attributes into "content" when needed. + for m in request.messages: + if isinstance(m, ChatCompletionAssistantMessage + ) and m.tool_calls is not None: + m.content = self.openai_tools_prompter.content_from_assistant( + m) + elif isinstance(m, ChatCompletionToolMessage + ) and m.tool_call_id is not None: + m.content = self.openai_tools_prompter.content_from_tool(m) + + try: + prompt = self.tokenizer.apply_chat_template( + conversation=request.messages, + tokenize=False, + add_generation_prompt=request.add_generation_prompt) + except Exception as e: + logger.error( + f"Error in applying chat template from request: {str(e)}") + return self.create_error_response(str(e)) + + if self.dev_mode: # ease the templates development + logger.info("\n######## Development info (dev-mode) ########") + logger.info("- Request:\n%s" % str(request.model_dump())) + logger.info("") + logger.info("- Prompt:\n%s" % str(prompt)) + logger.info("##############################################") + + request_id = f"cmpl-{random_uuid()}" + try: + token_ids = self._validate_prompt_and_tokenize(request, + prompt=prompt) + sampling_params = request.to_sampling_params() + except ValueError as e: + return self.create_error_response(str(e)) + + result_generator = self.engine.generate(prompt, sampling_params, + request_id, token_ids) + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, result_generator, request_id) + else: + return await self.chat_completion_full_generator( + request, raw_request, result_generator, request_id) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + else: + return request.messages[-1].role + + async def chat_completion_stream_generator( + self, request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], request_id: str + ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: + + model_name = request.model + created_time = int(time.monotonic()) + chunk_object_type = "chat.completion.chunk" + + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request) + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=DeltaMessage(role=role), finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[ + -1].content and request.messages[-1].role == role: + last_msg_content = request.messages[-1].content + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: 
{data}\n\n" + + if self.openai_tools_prompter is not None and request.tools is not None: + tools_capture_texts = [ChatPromptCapture()] * request.n + else: + tools_capture_texts = None + + # Send response for each token for each request.n (index) + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + current_capture = tools_capture_texts[ + i] if tools_capture_texts is not None else None + + if current_capture is not None and current_capture.after_new_function_call: + current_capture.after_new_function_call = False + # If the last token is a new line char right after a function call, we ignore it. + # Otherwise, each function call creates a line break in the content part of the response. + if output.text[len(previous_texts[i]):] == "\n": + previous_texts[i] = output.text + continue + + # Manage tools calling + if self.openai_tools_prompter is not None and \ + request.tools is not None and \ + output.finish_reason is None: + if len(current_capture.content) == 0: + current_token: str = output.text[len(previous_texts[i] + ):] + if self.openai_tools_prompter.func_call_token_pre( + ) in current_token: + start_pos: int = current_token.index( + self.openai_tools_prompter.func_call_token_pre( + )) + current_capture.content = current_token[ + start_pos:] # With some models the completion may start by a space. + current_capture.prefix_size = len( + output.text) - len(current_capture.content) + current_capture.maybe_function_call = True + else: # Maybe a function call... + current_token: str = output.text[ + len(current_capture.content) + + current_capture.prefix_size:] + current_capture.content += current_token + if len( + current_capture.content + ) < self.openai_tools_prompter.func_call_token_size(): + pass + elif not current_capture.is_function_call: + if current_capture.content.startswith( + self.openai_tools_prompter.func_call_token( + )): # Function call ! + current_capture.is_function_call = True + else: # This is not a function call... 
+ current_capture.reset(False) + else: # Currently extracting the function call + if current_capture.content.rfind("}", -6) != -1: + c1 = current_capture.content.count("{") + c2 = current_capture.content.count("}") + if c1 == c2: # We have the complete call block + previous_texts[i] = output.text + current_capture.make_calls_list( + self.openai_tools_prompter) + current_capture.reset(False) + current_capture.after_new_function_call = True + else: + pass + if current_capture is None or ( + not current_capture.maybe_function_call): + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + + if output.finish_reason is None: + if len(delta_text) > 0: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + else: + if output.finish_reason == "stop" and ( + current_capture is not None and + (current_capture.num_calls() > 0)): + tools_calls_list = current_capture.to_ChoiceDeltaToolCallList( + ) + + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=None, tool_calls=tools_calls_list), + finish_reason="tool_calls") + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + chunk.usage = UsageInfo( + prompt_tokens=len(res.prompt_token_ids), + completion_tokens=len(output.token_ids), + total_tokens=len(res.prompt_token_ids) + + len(output.token_ids), + ) + data = chunk.model_dump_json(exclude_unset=True, + exclude_none=True) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + + previous_num_tokens[i], + ) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.model_dump_json(exclude_unset=True, + exclude_none=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, request: ChatCompletionRequest, raw_request: Request, + result_generator: AsyncIterator[RequestOutput], + request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: + + model_name = request.model + created_time = int(time.monotonic()) + final_res: RequestOutput = None + + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
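# --- Simplified sketch (not part of this patch) ---
# The ChatPromptCapture logic above (and the non-streaming scan below) treats a
# tool call as complete once the "{" and "}" counts balance after the function
# call marker. This standalone toy function shows that idea; CALL_TOKEN is a
# placeholder, the real marker comes from OpenAIToolsPrompter.func_call_token().
import json
from typing import Optional

CALL_TOKEN = "<<tool_call>>"  # placeholder marker


def extract_call_block(text: str) -> Optional[dict]:
    pos = text.find(CALL_TOKEN)
    if pos < 0:
        return None
    start = text.find("{", pos)
    if start < 0:
        return None
    depth = 0
    for end in range(start, len(text)):
        if text[end] == "{":
            depth += 1
        elif text[end] == "}":
            depth -= 1
            if depth == 0:  # braces balanced: the call block is complete
                return json.loads(text[start:end + 1])
    return None  # block not complete yet (keep accumulating tokens)


print(extract_call_block(
    '<<tool_call>> {"name": "get_weather", "arguments": {"city": "Paris"}}'))
# --- end sketch ---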
+ await self.engine.abort(request_id) + return self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + + choices = [] + role = self.get_chat_request_role(request) + for output in final_res.outputs: + tools_calls_validation = False + + # Manage tools calling + if self.openai_tools_prompter is not None and \ + request.tools is not None: + current_capture = ChatPromptCapture() + + start_pos = 0 + while True: + pos = output.text.find( + self.openai_tools_prompter.func_call_token(), + start_pos, -1) + if pos < 0: + break + start_bloc = output.text.find("{", pos, -1) + if start_bloc < 0: + break + if (start_bloc - + (pos + + self.openai_tools_prompter.func_call_token_size()) + ) > 1: + break + count = 1 + bloc_end = start_bloc + 1 + for it_ch in range(start_bloc + 1, len(output.text), 1): + ch = output.text[it_ch] + bloc_end += 1 + if ch == "{": + count += 1 + elif ch == "}": + count -= 1 + if count == 0: # We have the complete call block + current_capture.content = output.text[ + start_bloc:bloc_end] + current_capture.make_calls_list( + self.openai_tools_prompter) + current_capture.reset(False) + break + start_pos = bloc_end + 1 + + if current_capture.num_calls() > 0: + tools_calls_validation = True + tools_calls_list = current_capture.to_ChatCompletionMessageToolCallList( + ) + message = ChatMessage(role=role, + content=None, + tool_calls=tools_calls_list) + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + finish_reason="tool_calls") + choices.append(choice_data) + if not tools_calls_validation: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, content=output.text), + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[ + -1].content and request.messages[-1].role == role: + last_msg_content = request.messages[-1].content + + for choice in choices: + full_message = last_msg_content + choice.message.content + choice.message.content = full_message + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + async def _load_chat_template(self, chat_template): + while True: + if self.tokenizer is not None: + if chat_template is not None: + try: + with open(chat_template, "r") as f: + self.tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + self.tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info( + f"Using supplied chat template:\n{self.tokenizer.chat_template}" + ) + elif self.tokenizer.chat_template is not None: + logger.info( + f"Using default chat template:\n{self.tokenizer.chat_template}" + ) + else: + logger.warning( + "No chat template provided. 
Chat API will not work.") + break + else: + logger.info("Waiting for the tokenizer initialization...") + await asyncio.sleep(0.100) + diff --git a/aphrodite/endpoints/openai/serving_completions.py b/aphrodite/endpoints/openai/serving_completions.py new file mode 100644 index 000000000..dbf5be906 --- /dev/null +++ b/aphrodite/endpoints/openai/serving_completions.py @@ -0,0 +1,365 @@ +import asyncio +import time +from fastapi import Request +from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple + +from aphrodite.common.logger import init_logger +from aphrodite.common.utils import random_uuid +from aphrodite.common.grammar import GrammarLogitsProcessor, RayRemoteGrammarLogitsProcessor +from aphrodite.engine.async_aphrodite import AsyncAphrodite +from .protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + LogProbs, + UsageInfo, +) +from aphrodite.common.outputs import RequestOutput +from aphrodite.endpoints.openai.serving_engine import OpenAIServing + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] +TypeTopLogProbs = List[Optional[Dict[int, float]]] +TypeCreateLogProbsFn = Callable[ + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + + +async def completion_stream_generator( + request: CompletionRequest, + raw_request: Request, + on_abort, + result_generator: AsyncIterator[Tuple[int, RequestOutput]], + create_logprobs_fn: TypeCreateLogProbsFn, + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, +) -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n * num_prompts + previous_num_tokens = [0] * request.n * num_prompts + has_echoed = [False] * request.n * num_prompts + + async for prompt_idx, res in result_generator: + + # Abort the request if the client disconnects. + if await raw_request.is_disconnected(): + await on_abort(f"{request_id}-{prompt_idx}") + raise StopAsyncIteration() + + for output in res.outputs: + i = output.index + prompt_idx * request.n + # TODO: optimize the performance by avoiding full text O(n^2) sending. 
+ + if request.echo and request.max_tokens == 0: + # only return the prompt + delta_text = res.prompt + delta_token_ids = res.prompt_token_ids + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + elif request.echo and request.max_tokens > 0 and not has_echoed[i]: + # echo the prompt and first token + delta_text = res.prompt + output.text + delta_token_ids = res.prompt_token_ids + output.token_ids + top_logprobs = res.prompt_logprobs + (output.logprobs or []) + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text[len(previous_texts[i]):] + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs is not None: + assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" + logprobs = create_logprobs_fn( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + ]).model_dump_json(exclude_unset=True) + yield f"data: {response_json}\n\n" + + if output.finish_reason is not None: # return final usage + logprobs = LogProbs() if request.logprobs is not None else None + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + ], + usage=final_usage, + ).model_dump_json(exclude_unset=True) + yield f"data: {response_json}\n\n" + + yield "data: [DONE]\n\n" + + +def parse_prompt_format(prompt) -> Tuple[bool, list]: + # get the prompt, openai supports the following + # "a string, array of strings, array of tokens, or array of token arrays." 
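+    # Illustrative examples of the four accepted shapes:
+    #   "Hello"              -> one text prompt
+    #   ["Hello", "Hi"]      -> two text prompts
+    #   [1, 2, 3]            -> one pre-tokenized prompt
+    #   [[1, 2], [3, 4, 5]]  -> two pre-tokenized prompts
+    # Returns (prompt_is_tokens, prompts), with `prompts` always normalized
+    # to a list.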
+ prompt_is_tokens = False + prompts = [prompt] # case 1: a string + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + elif isinstance(prompt[0], str): + prompt_is_tokens = False + prompts = prompt # case 2: array of strings + elif isinstance(prompt[0], int): + prompt_is_tokens = True + prompts = [prompt] # case 3: array of tokens + elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int): + prompt_is_tokens = True + prompts = prompt # case 4: array of token arrays + else: + raise ValueError( + "prompt must be a string, array of strings, array of tokens, or array of token arrays" + ) + return prompt_is_tokens, prompts + + +def request_output_to_completion_response( + final_res_batch: List[RequestOutput], + request: CompletionRequest, + create_logprobs_fn: TypeCreateLogProbsFn, + request_id: str, + created_time: int, + model_name: str, +) -> CompletionResponse: + choices = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + for final_res in final_res_batch: + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + + for output in final_res.outputs: + if request.echo and request.max_tokens == 0: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + output_text = prompt_text + elif request.echo and request.max_tokens > 0: + token_ids = prompt_token_ids + output.token_ids + top_logprobs = prompt_logprobs + output.logprobs + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + top_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + logprobs = create_logprobs_fn( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens += len(prompt_token_ids) + num_generated_tokens += sum( + len(output.token_ids) for output in final_res.outputs) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + +def merge_async_iterators(*iterators): + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i, iterator): + async for item in iterator: + await queue.put((i, item)) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + while not all(finished) or not queue.empty(): + item = await queue.get() + yield item + await asyncio.gather(*_tasks) + + return consumer() + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__(self, engine: AsyncAphrodite, served_model: str): + super().__init__(engine=engine, served_model=served_model) + + async def create_completion(self, request: CompletionRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. 
+ + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following features: + - suffix (the language models we currently support do not support + suffix) + - logit_bias (to be supported by Aphrodite engine) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + engine = self.engine + # Return error for unsupported features. + if request.suffix is not None: + return self.create_error_response( + "suffix is not currently supported") + if request.logit_bias is not None and len(request.logit_bias) > 0: + return self.create_error_response( + "logit_bias is not currently supported") + if request.grammar: + return self.create_error_response( + "grammar is not currently supported") + # if engine.worker_use_ray: + # grammar_logits_processor = RayRemoteGrammarLogitsProcessor( + # tokenizer=self.tokenizer, grammar=request.grammar) + # else: + # grammar_logits_processor = GrammarLogitsProcessor( + # tokenizer=self.tokenizer, grammar=request.grammar) + # logit_processors = [grammar_logits_processor] + # else: + # logit_processors = [] + + + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + + # Schedule the request and get the result generator. + generators = [] + try: + sampling_params = request.to_sampling_params() + prompt_is_tokens, prompts = parse_prompt_format(request.prompt) + + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + input_ids = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + input_ids = self._validate_prompt_and_tokenize( + request, prompt=prompt) + + generators.append( + self.engine.generate(None, + sampling_params, + f"{request_id}-{i}", + prompt_token_ids=input_ids)) + except ValueError as e: + return self.create_error_response(str(e)) + + result_generator: AsyncIterator[Tuple[ + int, RequestOutput]] = merge_async_iterators(*generators) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. + stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + # Streaming response + if stream: + return completion_stream_generator(request, + raw_request, + self.engine.abort, + result_generator, + self._create_logprobs, + request_id, + created_time, + model_name, + num_prompts=len(prompts)) + + # Non-streaming response + final_res_batch: RequestOutput = [None] * len(prompts) + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_completion_response( + final_res_batch, request, self._create_logprobs, request_id, + created_time, model_name) + + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. 
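+        # This path is taken when `stream=True` was requested but streaming was
+        # disabled above (n != best_of, or beam search): the fully generated
+        # response is wrapped in a single SSE data event followed by [DONE], so
+        # the client still receives a well-formed event stream.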
+ if request.stream: + response_json = response.model_dump_json() + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response diff --git a/aphrodite/endpoints/openai/serving_engine.py b/aphrodite/endpoints/openai/serving_engine.py new file mode 100644 index 000000000..035b1a12a --- /dev/null +++ b/aphrodite/endpoints/openai/serving_engine.py @@ -0,0 +1,158 @@ +import asyncio +from http import HTTPStatus +from typing import Dict, List, Optional, Union +from fastapi import HTTPException + +from aphrodite.common.logger import init_logger +from aphrodite.transformers_utils.tokenizer import get_tokenizer +from aphrodite.engine.async_aphrodite import AsyncAphrodite +from aphrodite.endpoints.openai.protocol import (CompletionRequest, + ChatCompletionRequest, + ErrorResponse, LogProbs, + ModelCard, ModelList, + ModelPermission, Prompt) + +logger = init_logger(__name__) + + +class OpenAIServing: + + def __init__(self, engine: AsyncAphrodite, served_model: str): + self.engine = engine + self.served_model = served_model + + self.max_model_len = 0 + self.tokenizer = None + + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running( + ): # If the current is instanced by Ray Serve, there is already a running event loop + event_loop.create_task(self._post_init()) + else: # When using single aphrodite without engine_use_ray + asyncio.run(self._post_init()) + + async def _post_init(self): + engine_model_config = await self.engine.get_model_config() + self.max_model_len = engine_model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + engine_model_config.tokenizer, + tokenizer_mode=engine_model_config.tokenizer_mode, + trust_remote_code=engine_model_config.trust_remote_code) + + async def show_available_models(self) -> ModelList: + """Show available models. Right now we only have one model.""" + model_cards = [ + ModelCard(id=self.served_model, + root=self.served_model, + permission=[ModelPermission()]) + ] + return ModelList(data=model_cards) + + async def tokenize_text(self, prompt: Prompt + ) -> Dict[str, Union[int, List[int]]]: + """Tokenize a prompt using the model's tokenizer. + Returns: + value: The number of tokens in the prompt. + ids: The token IDs of the prompt. 
+ """ + if self.tokenizer is None: + raise HTTPException(status_code=500, detail="Tokenizer is not set.") + + try: + tokenized_prompt = self.tokenizer.tokenize(prompt.prompt) + token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_prompt) + return {"value": len(tokenized_prompt), "ids": token_ids} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + def _create_logprobs( + self, + token_ids: List[int], + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + num_output_top_logprobs: Optional[int] = None, + initial_text_offset: int = 0, + ) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + if num_output_top_logprobs: + logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is not None: + token_logprob = step_top_logprobs[token_id] + else: + token_logprob = None + token = self.tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + + last_token_len) + last_token_len = len(token) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + self.tokenizer.convert_ids_to_tokens(i): p + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + + logprobs.top_logprobs = [{ + k: v if v > -1000 else -1000 + for k, v in top_logprob.items() + } for top_logprob in logprobs.top_logprobs if top_logprob is not None] + + return logprobs + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) + + async def _check_model(self, request) -> Optional[ErrorResponse]: + if request.model == self.served_model: + return + return self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + def _validate_prompt_and_tokenize( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None) -> List[int]: + if not (prompt or prompt_ids): + raise ValueError("Either prompt or prompt_ids should be provided.") + if (prompt and prompt_ids): + raise ValueError( + "Only one of prompt or prompt_ids should be provided.") + + input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( + prompt).input_ids + token_num = len(input_ids) + + if request.max_tokens is None: + request.max_tokens = self.max_model_len - token_num + + if token_num + request.max_tokens > self.max_model_len: + raise ValueError( + f"This model's maximum context length is {self.max_model_len} tokens. " + f"However, you requested {request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). 
" + f"Please reduce the length of the messages or completion.", ) + else: + return input_ids diff --git a/aphrodite/endpoints/openai/templates/tools_function.jinja b/aphrodite/endpoints/openai/templates/tools_function.jinja new file mode 100644 index 000000000..b82be0fed --- /dev/null +++ b/aphrodite/endpoints/openai/templates/tools_function.jinja @@ -0,0 +1,60 @@ +{%- set func_call_token = "!function_call:" -%} {#- The special prefix to functions calls, be aware of extra space or new lines ! -#} + +{%- if CONTEXT == CALL_TOKEN -%} {#- return only the func_call_token value. Needed by the implementation. No data included -#} + {{- func_call_token -}} +{%- endif -%} {#- CONTEXT == CALL_TOKEN -#} + +{%- if CONTEXT == CALLS_NOTIF -%} {#- Format the notification of the function call. Data: tool_calls = ChatCompletionMessageToolCall -#} + {%- for call in tool_calls -%} + {%- if call.function.arguments == None or call.function.arguments|count == 0 -%} + {{- call.id }} was called with no argument + {%- else -%} + {{- call.id }} was called with arguments : {{- call.function.arguments -}} + {%- endif -%} + {%- raw %} +{% endraw -%} + {%- endfor -%} +{%- endif -%} {#- CONTEXT == CALLS_NOTIF -#} + +{%- if CONTEXT == TOOL_RESPONSE -%} {#- Format of the response of the function call. Data: message = ChatCompletionToolMessage -#} + {{- message.content -}} +{%- endif -%} {#- CONTEXT == TOOL_RESPONSE -#} + +{%- if CONTEXT == FORCE_CALL -%} {#- One tool call defined request. Data: tool = ChatCompletionToolParam -#} +You must call the following function at least one time to answer the question. You may call it multiple times if needed: + {%- if tool.function.parameters == None or tool.function.parameters|count == 0 -%} {#- without parameter #} + {'name': "{{tool.function.name}}", 'description': "{{tool.function.description}}", 'arguments': null}, + {%- else -%} {#- with parameters #} + {'name': "{{tool.function.name}}", 'description': "{{tool.function.description}}", 'arguments': { {{tool.function.parameters}} {{ '}}' }}, + {%- endif %} {#- tool.function.parameters #} +{%- endif -%} {#- CONTEXT == FORCE_CALL -#} + +{%- if CONTEXT == FUNCTIONS_LIST -%} {#- Functions list generation Data: tools_list = List[ChatCompletionToolParam] -#} + {%- raw -%}The following is a list of external functions that may be called to complete certain tasks: +[ + {%- endraw -%} + {%- for tool in tools_list -%} + {%- if tool.function.parameters == None or tool.function.parameters|count == 0 -%} {#- without parameter #} + {'name': "{{tool.function.name}}", 'description': "{{tool.function.description}}", 'arguments': null}, + {%- else -%} {#- with parameters #} + {'name': "{{tool.function.name}}", 'description': "{{tool.function.description}}", 'arguments': { {{tool.function.parameters}} {{ '}}' }}, + {% endif -%} {#- tool.function.parameters #} + {%- endfor -%} + {%- raw %} +] +End of list + +* Whenever the user asks you something, you can either respond directly or invoke a function if it is present in the previous list. +* The decision to invoke a function is yours, only invoke a function if it is necessary to answer the user's question +* If you need to call at least one function, your message should contain only a list of function calls and nothing else; the function calls are the response. 
+    {%- endraw %}
+{%- endif -%} {#- CONTEXT == FUNCTIONS_LIST -#}
+
+{%- if CONTEXT == FORCE_CALL or CONTEXT == FUNCTIONS_LIST -%}
+To call a function, the message must start with "{{func_call_token}}" followed by a JSON object like this:
+* With arguments:
+{{func_call_token}}{"call": "function_name", "arguments": {"arg1": "value1"}}
+* Without arguments:
+{{func_call_token}}{"call": "function_name", "arguments": null}
+End of functions instructions
+{%- endif -%} {#- CONTEXT == FORCE_CALL or CONTEXT == FUNCTIONS_LIST -#}
\ No newline at end of file
diff --git a/aphrodite/endpoints/openai/tools.py b/aphrodite/endpoints/openai/tools.py
new file mode 100644
index 000000000..8b70cba36
--- /dev/null
+++ b/aphrodite/endpoints/openai/tools.py
@@ -0,0 +1,202 @@
+import os
+import json
+import jinja2
+from enum import Enum
+from typing import List, Union
+from aphrodite.common.logger import init_logger
+from .protocol import (ChatCompletionRequest, ChatCompletionToolParam,
+                       ChoiceDeltaToolCall, ChatCompletionMessageToolCall,
+                       Function, ChatCompletionAssistantMessage,
+                       ChatCompletionToolMessage)
+
+logger = init_logger(__name__)
+
+
+class ToolsCallsTemplateContext(Enum):
+    """ Rendering context passed to the template to select what to generate. """
+    CALL_TOKEN = 1
+    FUNCTIONS_LIST = 2
+    FORCE_CALL = 3
+    CALLS_NOTIF = 4
+    TOOL_RESPONSE = 5
+
+
+class ToolsCallsTemplate:
+
+    def __init__(self, template_path=None):
+        self.trim_blocks = True
+        self.lstrip_blocks = True
+        if template_path is None:
+            template_path = os.path.dirname(
+                __file__) + "/templates/tools_function.jinja"
+        self.environment = jinja2.Environment(
+            loader=jinja2.FileSystemLoader(os.path.dirname(template_path)))
+        self.template = self.environment.get_template(
+            os.path.basename(template_path))
+        self.template.globals[
+            "FUNCTIONS_LIST"] = ToolsCallsTemplateContext.FUNCTIONS_LIST
+        self.template.globals[
+            "FORCE_CALL"] = ToolsCallsTemplateContext.FORCE_CALL
+        self.template.globals[
+            "CALL_TOKEN"] = ToolsCallsTemplateContext.CALL_TOKEN
+        self.template.globals[
+            "CALLS_NOTIF"] = ToolsCallsTemplateContext.CALLS_NOTIF
+        self.template.globals[
+            "TOOL_RESPONSE"] = ToolsCallsTemplateContext.TOOL_RESPONSE
+
+    def get_func_call_token(self) -> str:
+        """ Return the special token used to find function calls. 
""" + return self.template.render( + CONTEXT=ToolsCallsTemplateContext.CALL_TOKEN) + + def render_toolcalls(self, tool_calls: [ChatCompletionMessageToolCall]): + return self.template.render( + CONTEXT=ToolsCallsTemplateContext.CALLS_NOTIF, + tool_calls=tool_calls) + + def render_toolmessage(self, message: ChatCompletionToolMessage): + return self.template.render( + CONTEXT=ToolsCallsTemplateContext.TOOL_RESPONSE, message=message) + + def render_toolslist(self, tool_choice: Union[str, None], + tools_list: [ChatCompletionToolParam]) -> str: + if isinstance(tool_choice, str) and tool_choice == "auto": + tool_choice = None + if tool_choice is not None: + for tool in tools_list: + # Search if the tool_choice is in the tools_list + if tool.type == "function" and tool.function.name == tool_choice: + return self.template.render( + CONTEXT=ToolsCallsTemplateContext.FORCE_CALL, + tool=tool) + return None + else: + return self.template.render( + CONTEXT=ToolsCallsTemplateContext.FUNCTIONS_LIST, + tools_list=tools_list) + + +class OpenAIToolsPrompter: + """ + https://platform.openai.com/docs/assistants/tools + """ + + def __init__(self, template_path=None): + self.template = ToolsCallsTemplate(template_path) + self.call_token_str = self.template.get_func_call_token() + if self.call_token_str is None: + logger.error("There is something wrong with the tools template.") + else: + self.call_token_pre = self.call_token_str[0] + + def func_call_token_pre(self) -> str: + return self.call_token_pre + + def func_call_token_size(self) -> int: + return len(self.call_token_str) + + def func_call_token(self) -> str: + return self.call_token_str + + def content_from_assistant(self, + message: ChatCompletionAssistantMessage) -> str: + text = self.template.render_toolcalls(message.tool_calls) + if message.content is None: + return text + else: + return message.content + "\n" + text + + def content_from_tool(self, message: ChatCompletionToolMessage) -> str: + return self.template.render_toolmessage(message) + + def inject_prompt(self, request: ChatCompletionRequest): + """ Generate and inject the prompt for tools calls. 
""" + if request.tools is not None and self.call_token_str is not None and len( + request.tools): + select_tool_choice = request.tool_choice if ( + request.tool_choice is not None + and request.tool_choice != "auto") else None + text_inject = self.template.render_toolslist( + tool_choice=select_tool_choice, tools_list=request.tools) + if isinstance(request.messages, str): + request.messages = text_inject + request.messages + elif isinstance(request.messages, + List) and len(request.messages) >= 1: + request.messages[ + 0].content = text_inject + request.messages[0].content + +class ChatPromptCapture: + + def __init__(self): + self.content: str = "" + self.maybe_function_call = False + self.is_function_call = False + self.prefix_size = 0 + self.calls_list: List[{}] = [] + self.after_new_function_call = False + + def reset(self, reset_calls_list=False): + self.content = "" + self.maybe_function_call = False + self.is_function_call = False + self.prefix_size = 0 + if reset_calls_list: + self.calls_list = [] + + def num_calls(self): + return len(self.calls_list) + + def make_calls_list(self, prompter: OpenAIToolsPrompter): + calls_list = self.content.split(prompter.func_call_token()) + for v_call in calls_list: + if len(v_call): + try: + call_dict = json.loads(v_call) + if "name" in call_dict: + self.calls_list.append(call_dict) + except json.decoder.JSONDecodeError: + # Simply ignore invalid functions calls... + pass + + def to_ChatCompletionMessageToolCall( + self, call_id: int) -> Union[ChatCompletionMessageToolCall, None]: + if len(self.calls_list) and call_id < len(self.calls_list): + call = self.calls_list[call_id] + arguments = call.get("arguments") or call.get("parameters") + if arguments is not None: + arguments = json.dumps(arguments) + function_call = Function(name=call["name"], + arguments=json.dumps(arguments) + if arguments is not None else "") + return ChatCompletionMessageToolCall(id="call_" + call["name"] + + "_" + str(call_id), + type="function", + function=function_call) + return None + + def to_ChatCompletionMessageToolCallList( + self) -> [ChatCompletionMessageToolCall]: + calls_count = self.num_calls() + tools_calls_list = [] + for call_id in range(calls_count): + tools_calls_list.append( + self.to_ChatCompletionMessageToolCall(call_id=call_id)) + return tools_calls_list + + def to_ChoiceDeltaToolCall( + self, call_id: int) -> Union[ChoiceDeltaToolCall, None]: + mesg = self.to_ChatCompletionMessageToolCall(call_id) + if mesg is not None: + return ChoiceDeltaToolCall(index=call_id, + id=mesg.id, + type=mesg.type, + function=mesg.function) + return None + + def to_ChoiceDeltaToolCallList(self): + calls_count = self.num_calls() + tools_calls_list = [] + for call_id in range(calls_count): + tools_calls_list.append( + self.to_ChoiceDeltaToolCall(call_id=call_id)) + return tools_calls_list diff --git a/examples/function_call.py b/examples/function_call.py new file mode 100644 index 000000000..0bdf64856 --- /dev/null +++ b/examples/function_call.py @@ -0,0 +1,155 @@ +""" +Inspired by the OpenAI example found here: + https://platform.openai.com/docs/guides/function-calling/parallel-function-calling +""" + +import datetime +from openai import OpenAI +import json + +client = OpenAI(api_key="EMPTY", base_url="http://localhost:2242/v1") +models = client.models.list() +model = models.data[0].id +stream = True + + +def get_current_date_utc(): + print("Calling get_current_date_utc client side.") + return datetime.datetime.now(datetime.timezone.utc).strftime( + "The current UTC datetime 
is (day: %A, date (day/month/year): %d/%m/%Y, time: %H:%M)." + ) + + +# Example dummy function hard coded to return the same weather +# In production, this could be your backend API or an external API +def get_current_weather(location, unit="fahrenheit"): + """Get the current weather in a given location""" + print("Calling get_current_weather client side.") + if "tokyo" in location.lower(): + return json.dumps({ + "location": "Tokyo", + "temperature": "10", + "unit": unit + }) + elif "san francisco" in location.lower(): + return json.dumps({ + "location": "San Francisco", + "temperature": "72", + "unit": unit + }) + elif "paris" in location.lower(): + return json.dumps({ + "location": "Paris", + "temperature": "22", + "unit": unit + }) + else: + return json.dumps({"location": location, "temperature": "unknown"}) + + +def run_conversation(): + # Step 1: send the conversation and available functions to the model + # messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + messages = [{ + "role": + "user", + "content": + "What's the weather like in San Francisco, Tokyo, and Paris ? We also need to know the current date." + }] + tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": + "string", + "description": + "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + }, + }, + "required": ["location"], + }, + }, + }, { + "type": "function", + "function": { + "name": "get_current_date_utc", + "description": "Get the current UTC time", + }, + }] + response = client.chat.completions.create( + model=model, + messages=messages, + tools=tools, + stream=stream, + tool_choice="auto", # auto is default, but we'll be explicit + ) + response_message = "" + tool_calls = None + if stream: + text_message = "" + for chunk in response: + if chunk.choices[0].finish_reason is not None: + if chunk.choices[0].finish_reason == "tool_calls": + tool_calls = chunk.choices[0].delta.tool_calls + break + if chunk.choices[0].delta.content is not None: + text_message += chunk.choices[0].delta.content + response_message = {"role": "assistant", "content": text_message} + else: + if not len(response.choices): + return None + response_message = response.choices[0].message + # print(str(response_message)) + tool_calls = response_message.tool_calls + + # Step 2: check if the model wanted to call a function + if tool_calls: + # Step 3: call the function + # Note: the JSON response may not always be valid; be sure to handle errors + available_functions = { + "get_current_weather": get_current_weather, + "get_current_date_utc": get_current_date_utc, + } + messages.append( + response_message) # extend conversation with assistant's reply + # Step 4: send the info for each function call and function response to the model + for tool_call in tool_calls: + function_name = tool_call.function.name + function_to_call = available_functions[function_name] + if function_name == "get_current_weather": + function_args = json.loads(tool_call.function.arguments) + function_response = function_to_call( + location=function_args.get("location"), + unit=function_args.get("unit"), + ) + else: + function_response = function_to_call() + + messages.append({ + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": function_response, + }) # 
extend conversation with function response + second_response = client.chat.completions.create( + model=model, + messages=messages, + ) # get a new response from the model where it can see the function response + + for it_msg, msg in enumerate(messages): + print("Message %i:\n %s\n" % (it_msg, str(msg))) + + return second_response + + +result = run_conversation() +print("Final response:\n%s" % result) diff --git a/requirements-rocm.txt b/requirements-rocm.txt index c62b79f18..616bc2a95 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -12,4 +12,4 @@ einops # Required for phi-1_5 transformers >= 4.34.0 # Required for Mistral. fastapi uvicorn[standard] -pydantic == 1.10.13 # Required for OpenAI server. \ No newline at end of file +pydantic >= 2.0 # Required for OpenAI server. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b7262f323..47065f90e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,7 @@ uvicorn openai # for fastapi's openai proxy emulation xformers >= 0.0.24 einops # Required for phi-1_5 -fschat >= 0.2.23 -pydantic == 1.10.13 +pydantic >= 2.0 fastapi colorlog einops # for phi
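As a rough illustration of the wire format that the new tools template asks the model to emit (the "!function_call:" prefix followed by a JSON object), the standalone sketch below parses such calls using only the json module. It is not the server implementation (the server relies on ChatPromptCapture and brace counting); extract_tool_calls and the sample string are hypothetical, and the sketch accepts both the "call" key requested by the template and the "name" key that make_calls_list filters on.

import json

FUNC_CALL_TOKEN = "!function_call:"  # assumed to match func_call_token in the template


def extract_tool_calls(model_output: str) -> list:
    """Collect every JSON object that follows the call token in a model reply."""
    decoder = json.JSONDecoder()
    calls = []
    for chunk in model_output.split(FUNC_CALL_TOKEN)[1:]:
        start = chunk.find("{")
        if start < 0:
            continue
        try:
            # raw_decode parses the first JSON value and ignores trailing text.
            call, _ = decoder.raw_decode(chunk[start:])
        except json.JSONDecodeError:
            continue  # ignore malformed calls, as the server-side capture does
        if "call" in call or "name" in call:
            calls.append(call)
    return calls


sample = ('Let me check that for you.\n'
          '!function_call:{"call": "get_current_weather", '
          '"arguments": {"location": "Paris", "unit": "celsius"}}')
print(extract_tool_calls(sample))
# -> [{'call': 'get_current_weather', 'arguments': {'location': 'Paris', 'unit': 'celsius'}}]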