From 6939658c96aa6ebdabf21ee287227b6e5b56e539 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Wed, 18 Jun 2025 13:52:42 -0400 Subject: [PATCH 1/5] Use a configurable template for OpenAI request --- .pre-commit-config.yaml | 1 + pyproject.toml | 1 + src/guidellm/backend/openai.py | 31 ++++++++++--------------------- src/guidellm/config.py | 6 ++++++ 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61b765a2..7d2d00ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,7 @@ repos: setuptools, setuptools-git-versioning, transformers, + jinja2, # dev dependencies pytest, diff --git a/pyproject.toml b/pyproject.toml index a78b1fc5..8310f9ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "pyyaml>=6.0.0", "rich", "transformers", + "jinja2>=3.1.6", ] [project.optional-dependencies] diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index e3f23963..08043929 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -6,6 +6,7 @@ from typing import Any, Literal, Optional, Union import httpx +import jinja2 from loguru import logger from PIL import Image @@ -123,6 +124,8 @@ def __init__( self.extra_query = extra_query self.extra_body = extra_body self._async_client: Optional[httpx.AsyncClient] = None + j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) + self.request_template = j2_env.from_string(settings.openai.request_template) @property def target(self) -> str: @@ -422,29 +425,15 @@ def _completions_payload( max_output_tokens: Optional[int], **kwargs, ) -> dict: - payload = body or {} + payload = json.loads( + self.request_template.render( + model=self.model, + output_tokens=(max_output_tokens or self.max_output_tokens), + ) + ) + payload.update(body or {}) payload.update(orig_kwargs or {}) payload.update(kwargs) - payload["model"] = self.model - payload["stream"] = 
True - payload["stream_options"] = { - "include_usage": True, - } - - if max_output_tokens or self.max_output_tokens: - logger.debug( - "{} adding payload args for setting output_token_count: {}", - self.__class__.__name__, - max_output_tokens or self.max_output_tokens, - ) - payload["max_tokens"] = max_output_tokens or self.max_output_tokens - payload["max_completion_tokens"] = payload["max_tokens"] - - if max_output_tokens: - # only set stop and ignore_eos if max_output_tokens set at request level - # otherwise the instance value is just the max to enforce we stay below - payload["stop"] = None - payload["ignore_eos"] = True return payload diff --git a/src/guidellm/config.py b/src/guidellm/config.py index ed7e782b..e744281e 100644 --- a/src/guidellm/config.py +++ b/src/guidellm/config.py @@ -85,6 +85,12 @@ class OpenAISettings(BaseModel): project: Optional[str] = None base_url: str = "http://localhost:8000" max_output_tokens: int = 16384 + request_template: str = ( + '{"model": "{{ model }}", {% if output_tokens %} ' + '"max_tokens": {{ output_tokens }}, "stop": null, ' + '"ignore_eos": true, {% endif %} ' + '"stream": true, "stream_options": {"include_usage": true}}' + ) class Settings(BaseSettings): From d927b3c86a33861ef7d2239d816f85a26028c7ee Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Wed, 18 Jun 2025 14:37:19 -0400 Subject: [PATCH 2/5] Create j2 env at runtime to avoid pickle error --- src/guidellm/backend/openai.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index 08043929..f5121800 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -124,8 +124,7 @@ def __init__( self.extra_query = extra_query self.extra_body = extra_body self._async_client: Optional[httpx.AsyncClient] = None - j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) - self.request_template = 
j2_env.from_string(settings.openai.request_template)
+        self.request_template = settings.openai.request_template
 
     @property
     def target(self) -> str:
@@ -425,8 +424,10 @@ def _completions_payload(
         max_output_tokens: Optional[int],
         **kwargs,
     ) -> dict:
+        j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True)
+        request_template = j2_env.from_string(self.request_template)
         payload = json.loads(
-            self.request_template.render(
+            request_template.render(
                 model=self.model,
                 output_tokens=(max_output_tokens or self.max_output_tokens),
             )

From 7e27784565f8fb30b272b15c0287da3eb4aa9d83 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Wed, 18 Jun 2025 15:12:41 -0400
Subject: [PATCH 3/5] Cache j2 request template

---
 src/guidellm/backend/openai.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index f5121800..d341b43f 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -2,6 +2,7 @@
 import json
 import time
 from collections.abc import AsyncGenerator
+from functools import cached_property
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
 
@@ -124,7 +125,12 @@ def __init__(
         self.extra_query = extra_query
         self.extra_body = extra_body
         self._async_client: Optional[httpx.AsyncClient] = None
-        self.request_template = settings.openai.request_template
+        self._request_template_str = settings.openai.request_template
+
+    @cached_property
+    def request_template(self) -> jinja2.Template:
+        j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True)
+        return j2_env.from_string(self._request_template_str)
 
     @property
     def target(self) -> str:
@@ -424,10 +430,8 @@ def _completions_payload(
         max_output_tokens: Optional[int],
         **kwargs,
     ) -> dict:
-        j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True)
-        request_template = j2_env.from_string(self.request_template)
         payload = json.loads(
-            
request_template.render( + self.request_template.render( model=self.model, output_tokens=(max_output_tokens or self.max_output_tokens), ) From b579c6247909ed8816fe33ddd1eb20658af7befd Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Wed, 18 Jun 2025 15:27:36 -0400 Subject: [PATCH 4/5] Clear cache before pickle --- src/guidellm/backend/openai.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index d341b43f..23e7c1c3 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -127,6 +127,13 @@ def __init__( self._async_client: Optional[httpx.AsyncClient] = None self._request_template_str = settings.openai.request_template + def __getstate__(self) -> object: + state = self.__dict__.copy() + # Templates are not serializable + # so we delete it before pickling + state.pop("request_template", None) + return state + @cached_property def request_template(self) -> jinja2.Template: j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) @@ -163,6 +170,7 @@ def info(self) -> dict[str, Any]: "project": self.project, "text_completions_path": TEXT_COMPLETIONS_PATH, "chat_completions_path": CHAT_COMPLETIONS_PATH, + "request_template": self._request_template_str, } async def check_setup(self): From a671d18e87540405cc6c3ee96d96530605267b34 Mon Sep 17 00:00:00 2001 From: Samuel Monson Date: Wed, 18 Jun 2025 16:16:08 -0400 Subject: [PATCH 5/5] Switch to j2 sandbox --- src/guidellm/backend/openai.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py index 23e7c1c3..8ac1e64d 100644 --- a/src/guidellm/backend/openai.py +++ b/src/guidellm/backend/openai.py @@ -8,6 +8,7 @@ import httpx import jinja2 +from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger from PIL import Image @@ -136,7 +137,17 @@ def __getstate__(self) -> object: @cached_property def 
request_template(self) -> jinja2.Template: - j2_env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True) + # Thanks to HuggingFace Tokenizers for this implementation + def tojson(x, ensure_ascii=False): + # We override the built-in tojson filter because Jinja's + # default filter escapes HTML characters + return json.dumps(x, ensure_ascii=ensure_ascii) + + j2_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) + + # Define custom filter functions + j2_env.filters["tojson"] = tojson + return j2_env.from_string(self._request_template_str) @property