Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added azureai client #36

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,28 @@ python llm_correctness.py \

```


### AzureAI Compatible APIs
```bash
export AZUREAI_API_KEY=secret_abcdefg
export AZUREAI_API_BASE="https://api.endpoints.ai.azure.com/v1"

python token_benchmark_ray.py \
--model "Llama-2-70b-chat" \
--mean-input-tokens 550 \
--stddev-input-tokens 150 \
--mean-output-tokens 150 \
--stddev-output-tokens 10 \
--max-num-completed-requests 2 \
--timeout 600 \
--num-concurrent-requests 1 \
--results-dir "result_outputs" \
--llm-api azureai \
--additional-sampling-params '{}'

```


see `python token_benchmark_ray.py --help` for more details on the arguments.

## Correctness Test
Expand Down Expand Up @@ -338,6 +360,7 @@ python llm_correctness.py \

```


## Saving Results

The results of the load test and correctness test are saved in the results directory specified by the `--results-dir` argument. The results are saved in 2 files, one with the summary metrics of the test, and one with metrics from each individual request that is returned.
Expand Down
3 changes: 3 additions & 0 deletions src/llmperf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
)
from llmperf.ray_clients.sagemaker_client import SageMakerClient
from llmperf.ray_clients.vertexai_client import VertexAIClient
from llmperf.ray_clients.azureai_chat_completion import AzureAIChatCompletionsClient
from llmperf.ray_llm_client import LLMClient


Expand All @@ -28,6 +29,8 @@ def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]:
clients = [SageMakerClient.remote() for _ in range(num_clients)]
elif llm_api == "vertexai":
clients = [VertexAIClient.remote() for _ in range(num_clients)]
elif llm_api == "azureai":
clients = [AzureAIChatCompletionsClient.remote() for _ in range(num_clients)]
elif llm_api in SUPPORTED_APIS:
clients = [LiteLLMClient.remote() for _ in range(num_clients)]
else:
Expand Down
119 changes: 119 additions & 0 deletions src/llmperf/ray_clients/azureai_chat_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import json
import os
import time
from typing import Any, Dict

import ray
import requests

from llmperf.ray_llm_client import LLMClient
from llmperf.models import RequestConfig
from llmperf import common_metrics

@ray.remote
class AzureAIChatCompletionsClient(LLMClient):
"""Client for AzureAI Chat Completions API."""

def llm_request(self, request_config: RequestConfig) -> Dict[str, Any]:
prompt = request_config.prompt
prompt, prompt_len = prompt

message = [
{"role": "system", "content": ""},
{"role": "user", "content": prompt},
]
model = request_config.model
body = {
"model": model,
"messages": message,
"stream": True,
}
sampling_params = request_config.sampling_params
body.update(sampling_params or {})
time_to_next_token = []
tokens_received = 0
ttft = 0
error_response_code = -1
generated_text = ""
error_msg = ""
output_throughput = 0
total_request_time = 0

metrics = {}

metrics[common_metrics.ERROR_CODE] = None
metrics[common_metrics.ERROR_MSG] = ""

start_time = time.monotonic()
most_recent_received_token_time = time.monotonic()
address = os.environ.get("AZUREAI_API_BASE")
if not address:
raise ValueError("the environment variable AZUREAI_API_BASE must be set.")
key = os.environ.get("AZUREAI_API_KEY")
if not key:
raise ValueError("the environment variable AZUREAI_API_KEY must be set.")
headers = {"Authorization": f"Bearer {key}"}
if not address:
raise ValueError("No host provided.")
if not address.endswith("/"):
address = address + "/"
address += "chat/completions"
try:
with requests.post(
address,
json=body,
stream=True,
timeout=180,
headers=headers,
) as response:
if response.status_code != 200:
error_msg = response.text
error_response_code = response.status_code
response.raise_for_status()
for chunk in response.iter_lines(chunk_size=None):
chunk = chunk.strip()

if not chunk:
continue
stem = "data: "
chunk = chunk[len(stem) :]
if chunk == b"[DONE]":
continue
tokens_received += 1
data = json.loads(chunk)

if "error" in data:
error_msg = data["error"]["message"]
error_response_code = data["error"]["code"]
raise RuntimeError(data["error"]["message"])

delta = data["choices"][0]["delta"]
if delta.get("content", None):
if not ttft:
ttft = time.monotonic() - start_time
time_to_next_token.append(ttft)
else:
time_to_next_token.append(
time.monotonic() - most_recent_received_token_time
)
most_recent_received_token_time = time.monotonic()
generated_text += delta["content"]

total_request_time = time.monotonic() - start_time
output_throughput = tokens_received / total_request_time

except Exception as e:
metrics[common_metrics.ERROR_MSG] = error_msg
metrics[common_metrics.ERROR_CODE] = error_response_code
print(f"Warning Or Error: {e}")
print(error_response_code)

metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token) #This should be same as metrics[common_metrics.E2E_LAT]. Leave it here for now
metrics[common_metrics.TTFT] = ttft
metrics[common_metrics.E2E_LAT] = total_request_time
metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput
metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len
metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received
metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len

return metrics, generated_text, request_config