feat: Client Added for Predibase #40

Open · wants to merge 2 commits into base: main
5 changes: 4 additions & 1 deletion src/llmperf/common.py
@@ -5,10 +5,11 @@
 )
 from llmperf.ray_clients.sagemaker_client import SageMakerClient
 from llmperf.ray_clients.vertexai_client import VertexAIClient
+from llmperf.ray_clients.predibase_client import PrediBaseClient
 from llmperf.ray_llm_client import LLMClient


-SUPPORTED_APIS = ["openai", "anthropic", "litellm"]
+SUPPORTED_APIS = ["openai", "anthropic", "litellm", "predibase"]


 def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]:
@@ -28,6 +29,8 @@ def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]:
         clients = [SageMakerClient.remote() for _ in range(num_clients)]
     elif llm_api == "vertexai":
         clients = [VertexAIClient.remote() for _ in range(num_clients)]
+    elif llm_api == "predibase":
+        clients = [PrediBaseClient.remote() for _ in range(num_clients)]
     elif llm_api in SUPPORTED_APIS:
         clients = [LiteLLMClient.remote() for _ in range(num_clients)]
     else:
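With the registration above, Predibase clients are constructed the same way as the other backends. A minimal sketch (the value of num_clients is arbitrary here):

from llmperf.common import construct_clients

# "predibase" is now accepted as an llm_api value; this returns num_clients
# Ray actor handles, one PrediBaseClient per concurrent request stream.
clients = construct_clients(llm_api="predibase", num_clients=2)
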
127 changes: 127 additions & 0 deletions src/llmperf/ray_clients/predibase_client.py
@@ -0,0 +1,127 @@
import io
import json
import os
import time
from typing import Any, Dict

import ray
import requests

from llmperf.ray_llm_client import LLMClient
from llmperf.models import RequestConfig
from llmperf import common_metrics



@ray.remote
class PrediBaseClient(LLMClient):
    """Client for a Predibase text-generation deployment that streams tokens via `generate_stream`."""

    def llm_request(self, request_config: RequestConfig) -> Dict[str, Any]:
        prompt = request_config.prompt
        prompt, prompt_len = prompt

        if not request_config.sampling_params:
            raise ValueError("Set sampling_params to set the parameters in the request body.")
        # The endpoint expects `max_new_tokens` rather than the OpenAI-style
        # `max_tokens` key that llmperf passes in sampling_params.
        if "max_tokens" in request_config.sampling_params:
            request_config.sampling_params["max_new_tokens"] = request_config.sampling_params.pop("max_tokens")

        body = {
            "inputs": prompt,
            "parameters": request_config.sampling_params,
        }

        time_to_next_token = []
        tokens_received = 0
        ttft = 0
        error_response_code = -1
        generated_text = ""
        error_msg = ""
        output_throughput = 0
        total_request_time = 0

        metrics = {}

        metrics[common_metrics.ERROR_CODE] = None
        metrics[common_metrics.ERROR_MSG] = ""

        start_time = time.monotonic()
        most_recent_received_token_time = time.monotonic()

        address = os.environ.get("PREDIBASE_API_BASE")
        key = os.environ.get("PREDIBASE_API_KEY")

        if not address:
            raise ValueError("The environment variable PREDIBASE_API_BASE must be set.")

        headers = {"Content-Type": "application/json"}
        if not key:
            print("Warning: PREDIBASE_API_KEY is not set.")
        else:
            headers["Authorization"] = f"Bearer {key}"

        if not address.endswith("/"):
            address = address + "/"
        address += "generate_stream"

        try:
            with requests.post(
                address,
                json=body,
                stream=True,
                timeout=180,
                headers=headers,
            ) as response:
                if response.status_code != 200:
                    error_msg = response.text
                    error_response_code = response.status_code
                    response.raise_for_status()

                # Each server-sent event line looks like `data:{...}`; strip the
                # prefix and parse the JSON payload for the streamed token.
                for chunk in response.iter_lines(chunk_size=None):
                    chunk = chunk.strip()

                    if not chunk:
                        continue
                    stem = b"data:"
                    if chunk.startswith(stem):
                        chunk = chunk[len(stem):]
                    if chunk == b"[DONE]":
                        continue
                    tokens_received += 1
                    data = json.loads(chunk)
                    if "error" in data:
                        error_msg = data["error"]
                        raise RuntimeError(error_msg)

                    delta = data["token"]
                    if delta.get("text", None):
                        if not ttft:
                            ttft = time.monotonic() - start_time
                            time_to_next_token.append(ttft)
                        else:
                            time_to_next_token.append(
                                time.monotonic() - most_recent_received_token_time
                            )
                        most_recent_received_token_time = time.monotonic()
                        generated_text += delta["text"]

            total_request_time = time.monotonic() - start_time
            output_throughput = tokens_received / total_request_time

        except Exception as e:
            metrics[common_metrics.ERROR_MSG] = error_msg
            metrics[common_metrics.ERROR_CODE] = error_response_code
            print(f"Warning or error: {e}")
            print(error_response_code)

        metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token)  # This should be the same as metrics[common_metrics.E2E_LAT]. Leave it here for now.
        metrics[common_metrics.TTFT] = ttft
        metrics[common_metrics.E2E_LAT] = total_request_time
        metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput
        metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len
        metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received
        metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len

        return metrics, generated_text, request_config
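
For reference, a minimal driver sketch for exercising the new client directly rather than through the benchmark script. The endpoint URL, API key, model name, and prompt length below are placeholders, and the RequestConfig fields are assumed to match llmperf.models.RequestConfig; only prompt and sampling_params are actually read by this client.

import os
import ray

from llmperf.models import RequestConfig
from llmperf.ray_clients.predibase_client import PrediBaseClient

# Placeholder endpoint and token; the exact URL depends on your Predibase deployment.
os.environ["PREDIBASE_API_BASE"] = "https://<your-predibase-host>/<deployment-path>"
os.environ["PREDIBASE_API_KEY"] = "<api-token>"

ray.init(ignore_reinit_error=True)
client = PrediBaseClient.remote()

request_config = RequestConfig(
    model="my-deployment",                    # not read by this client; kept for the shared schema
    prompt=("Write a haiku about GPUs.", 7),  # (prompt text, prompt length in tokens)
    sampling_params={"max_tokens": 64},       # renamed to max_new_tokens inside llm_request
)

metrics, generated_text, _ = ray.get(client.llm_request.remote(request_config))
print(generated_text)
print(metrics)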