diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md
index 04d7be4c..464bc8f3 100644
--- a/docs/observability-prometheus-metrics-in-jetstream-server.md
+++ b/docs/observability-prometheus-metrics-in-jetstream-server.md
@@ -50,6 +50,23 @@ jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0
 jetstream_slots_used_percentage{id="",idx="0"} 0.04166666666666663
 ```
 
+Currently the following metrics are supported:
+ - `jetstream_prefill_backlog_size`: Size of prefill queue
+ - `jetstream_transfer_backlog_size`: Size of transfer queue
+ - `jetstream_generate_backlog_size`: Size of generate queue
+ - `jetstream_queue_duration`: The total time each request spends enqueued in seconds
+ - `jetstream_slots_used_percentage`: The percentage of decode slots currently being used
+ - `jetstream_server_startup_latency`: Total time taken to start the Jetstream server
+ - `jetstream_request_input_length`: Number of input tokens per request
+ - `jetstream_request_output_length`: Number of output tokens per request
+ - `jetstream_request_success_count`: Number of requests successfully completed
+ - `jetstream_time_to_first_token`: Time to first token per request in seconds
+ - `jetstream_time_per_output_token`: Average time per output token per request in seconds
+ - `jetstream_time_per_prefill_token`: Prefill time per token per request in seconds
+ - `jetstream_time_per_request`: End to end request latency in seconds
+ - `jetstream_wait_time_per_request`: Time each request is not being prefilled or decoded
+ - `jetstream_total_tokens_in_current_batch`: Total number of tokens in the decode batch
+
 ## Observe metrics on GKE clusters
 
 The following applies only for Jetstream deployed on a GKE cluster. Currently [Google Cloud Managed Service for Prometheus](https://cloud.google.com/stackdriver/docs/managed-prometheus) is enabled by default on all GKE clusters, it determines scrape targets via the [PodMonitoring](https://github.com/GoogleCloudPlatform/prometheus-engine/blob/v0.10.0/doc/api.md#podmonitoring) custom resource.
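For readers verifying their setup, here is a minimal sketch (not part of the patch) of scraping the exposed endpoint and listing the JetStream metric families. The localhost URL and port 9090 are assumptions; substitute the value passed via `--prometheus_port`, or your GKE scrape target.

```python
# Editorial sketch: scrape the metrics endpoint and print the JetStream
# series it currently exposes. Assumes --prometheus_port=9090 on localhost.
import requests
from prometheus_client.parser import text_string_to_metric_families

resp = requests.get("http://localhost:9090/metrics", timeout=5)
resp.raise_for_status()

for family in text_string_to_metric_families(resp.text):
    if family.name.startswith("jetstream_"):
        for sample in family.samples:
            print(sample.name, dict(sample.labels), sample.value)
```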
After you deployed the JetStream GKE workload, you need to apply the PodMonitoring resource to your cluster as follows: diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index dc8a00e9..23694836 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -19,239 +19,207 @@ from prometheus_client import Counter, Gauge, Histogram from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS - -class JetstreamMetricsCollector: - """Wrapper class should be used to assure all metrics have proper tags""" - - _id: str = os.getenv("HOSTNAME", shortuuid.uuid()) - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls) - return cls.instance - - # Metric definitions - _prefill_backlog = Gauge( - name="jetstream_prefill_backlog_size", - documentation="Size of prefill queue", - labelnames=["id"], - ) - - _transfer_backlog = Gauge( - name="jetstream_transfer_backlog_size", - documentation="Size of transfer queue", - labelnames=["id", "idx"], - ) - - _generate_backlog = Gauge( - name="jetstream_generate_backlog_size", - documentation="Size of generate queue", - labelnames=["id", "idx"], - ) - - _queue_duration = Histogram( - name="jetstream_queue_duration", - documentation="The total time each request spends enqueued in seconds", - labelnames=["id"], - buckets=[ - 0.01, - 0.02, - 0.05, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 5.0, - 10.0, - 20.0, - 50.0, - 100.0, - ], - ) - - _slots_used_percentage = Gauge( - name="jetstream_slots_used_percentage", - documentation="The percentage of decode slots currently being used", - labelnames=["id", "idx"], - ) - - _server_startup_latency = Gauge( - name="jetstream_server_startup_latency", - documentation="Total time taken to start the Jetstream server", - labelnames=["id"], - ) - _request_input_length = Histogram( - name="jetstream_request_input_length", - documentation="Number of input tokens per request", - labelnames=["id"], - buckets=DEFAULT_PREFILL_BUCKETS, - ) - _request_output_length = Histogram( - name="jetstream_request_output_length", - documentation="Number of output tokens per request", - labelnames=["id"], - buckets=[ - 1, - 2, - 5, - 10, - 20, - 50, - 100, - 200, - 500, - 1000, - 2000, - 5000, - 10000, - 20000, - 50000, - 100000, - 200000, - 500000, - 1000000, - 2000000, - ], - ) - _request_success_count = Counter( - name="jetstream_request_success_count", - documentation="Number of requests successfully completed", - labelnames=["id"], - ) - - _time_to_first_token = Histogram( - name="jetstream_time_to_first_token", - documentation="Time to first token per request in seconds", - labelnames=["id"], - buckets=[ - 0.001, - 0.005, - 0.01, - 0.02, - 0.04, - 0.06, - 0.08, - 0.1, - 0.25, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - ], - ) - - _time_per_output_token = Histogram( - name="jetstream_time_per_output_token", - documentation="Average time per output token per request in seconds", - labelnames=["id"], - buckets=[ - 0.01, - 0.025, - 0.05, - 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, - ], - ) - - _time_per_prefill_token = Histogram( - name="jetstream_time_per_prefill_token", - documentation="Prefill time per token per request in seconds", - labelnames=["id"], - buckets=[ - 0.00001, - 0.00002, - 0.00005, - 0.0001, - 0.0002, - 0.0005, - 0.001, - 0.002, - 0.005, - 0.01, - 0.02, - 0.05, - 0.1, - ], - ) - - _time_per_request = Histogram( - name="jetstream_time_per_request", - 
documentation="End to end request latency in seconds", - labelnames=["id"], - buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], - ) - - _wait_time_per_request = Histogram( - name="jetstream_wait_time_per_request", - documentation="Time each request is not being prefilled or decoded", - labelnames=["id"], - buckets=[ - 0.01, - 0.02, - 0.05, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 5.0, - 10.0, - 20.0, - 50.0, - 100.0, - ], - ) - - def get_prefill_backlog_metric(self): - return self._prefill_backlog.labels(id=self._id) - - def get_transfer_backlog_metric(self, idx: int): - return self._transfer_backlog.labels(id=self._id, idx=idx) - - def get_generate_backlog_metric(self, idx: int): - return self._generate_backlog.labels(id=self._id, idx=idx) - - def get_queue_duration(self): - return self._queue_duration.labels(id=self._id) - - def get_slots_used_percentage_metric(self, idx: int): - return self._slots_used_percentage.labels(id=self._id, idx=idx) - - def get_server_startup_latency_metric(self): - return self._server_startup_latency.labels(id=self._id) - - def get_time_to_first_token(self): - return self._time_to_first_token.labels(id=self._id) - - def get_time_per_output_token(self): - return self._time_per_output_token.labels(id=self._id) - - def get_time_per_prefill_token(self): - return self._time_per_prefill_token.labels(id=self._id) - - def get_time_per_request(self): - return self._time_per_request.labels(id=self._id) - - def get_wait_time_per_request(self): - return self._wait_time_per_request.labels(id=self._id) - - def get_request_input_length(self): - return self._request_input_length.labels(id=self._id) - - def get_request_output_length(self): - return self._request_output_length.labels(id=self._id) - - def get_request_success_count_metric(self): - return self._request_success_count.labels(id=self._id) +# Initialize the unique ID for labeling metrics +_id = os.getenv("HOSTNAME", shortuuid.uuid()) + +# Registry for storing metric objects +_metrics_registry = { + "jetstream_prefill_backlog_size": Gauge( + name="jetstream_prefill_backlog_size", + documentation="Size of prefill queue", + labelnames=["id"], + ), + "jetstream_transfer_backlog_size": Gauge( + name="jetstream_transfer_backlog_size", + documentation="Size of transfer queue", + labelnames=["id", "idx"], + ), + "jetstream_generate_backlog_size": Gauge( + name="jetstream_generate_backlog_size", + documentation="Size of generate queue", + labelnames=["id", "idx"], + ), + "jetstream_queue_duration": Histogram( + name="jetstream_queue_duration", + documentation="The total time each request spends enqueued in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ), + "jetstream_slots_used_percentage": Gauge( + name="jetstream_slots_used_percentage", + documentation="The percentage of decode slots currently being used", + labelnames=["id", "idx"], + ), + "jetstream_server_startup_latency": Gauge( + name="jetstream_server_startup_latency", + documentation="Total time taken to start the Jetstream server", + labelnames=["id"], + ), + "jetstream_request_input_length": Histogram( + name="jetstream_request_input_length", + documentation="Number of input tokens per request", + labelnames=["id"], + buckets=DEFAULT_PREFILL_BUCKETS, + ), + "jetstream_request_output_length": Histogram( + name="jetstream_request_output_length", + documentation="Number of output tokens per request", + labelnames=["id"], + buckets=[ + 1, + 2, + 5, + 10, + 
20, + 50, + 100, + 200, + 500, + 1000, + 2000, + 5000, + 10000, + 20000, + 50000, + 100000, + 200000, + 500000, + 1000000, + 2000000, + ], + ), + "jetstream_request_success_count": Counter( + name="jetstream_request_success_count", + documentation="Number of requests successfully completed", + labelnames=["id"], + ), + "jetstream_time_to_first_token": Histogram( + name="jetstream_time_to_first_token", + documentation="Time to first token per request in seconds", + labelnames=["id"], + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ], + ), + "jetstream_time_per_output_token": Histogram( + name="jetstream_time_per_output_token", + documentation="Average time per output token per request in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ), + "jetstream_time_per_prefill_token": Histogram( + name="jetstream_time_per_prefill_token", + documentation="Prefill time per token per request in seconds", + labelnames=["id"], + buckets=[ + 0.00001, + 0.00002, + 0.00005, + 0.0001, + 0.0002, + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + ], + ), + "jetstream_time_per_request": Histogram( + name="jetstream_time_per_request", + documentation="End to end request latency in seconds", + labelnames=["id"], + buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], + ), + "jetstream_wait_time_per_request": Histogram( + name="jetstream_wait_time_per_request", + documentation="Time each request is not being prefilled or decoded", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ), + "jetstream_total_tokens_in_current_batch": Gauge( + name="jetstream_total_tokens_in_current_batch", + documentation="Total number of tokens in the decode batch", + labelnames=["id", "idx"], + ), +} + + +# Function to retrieve a metric with specified labels +def get_metric(metric_name, **labels): + if metric_name not in _metrics_registry: + raise ValueError(f"Metric {metric_name} not found in registry.") + + metric = _metrics_registry[metric_name] + + # Automatically add the 'id' label if it's required by the metric + if "id" in metric._labelnames: # pylint: disable=protected-access + labels["id"] = _id + + # Check for any missing labels + missing_labels = set(metric._labelnames) - labels.keys() # pylint: disable=protected-access + if missing_labels: + raise ValueError( + f"Missing labels for metric {metric_name}: {', '.join(missing_labels)}" + ) + + return metric.labels(**labels) diff --git a/jetstream/core/metrics/utils.py b/jetstream/core/metrics/utils.py new file mode 100644 index 00000000..2aa7f4b4 --- /dev/null +++ b/jetstream/core/metrics/utils.py @@ -0,0 +1,54 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Contains helper functions for configuring Jetstream server metrics""" + + +def get_time_per_prefill_token(request, true_length: int): + return ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) / true_length + + +def get_queue_duration(request): + return ( + # Time in prefill queue + request.metadata.prefill_dequeue_time + - request.metadata.prefill_enqueue_time + # Time in transfer queue + + request.metadata.transfer_dequeue_time + - request.metadata.transfer_enqueue_time + # Time in generate queue + + request.metadata.generate_dequeue_time + - request.metadata.generate_enqueue_time + ) + + +def get_tpot(request, result_tokens, slot): + return ( + request.metadata.complete_time - request.metadata.transfer_enqueue_time + ) / result_tokens.get_result_at_slot(slot).lengths + + +def get_wait_time(request): + total_time = request.metadata.complete_time - request.metadata.start_time + prefill_time = ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + generate_time = ( + request.metadata.complete_time - request.metadata.generate_dequeue_time + ) + return total_time - prefill_time - generate_time diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index cefabd05..6791fd67 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -89,12 +89,13 @@ import grpc import jax +from jetstream.core.metrics.utils import get_queue_duration, get_time_per_prefill_token, get_tpot, get_wait_time from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.utils import async_multifuture from jetstream.core.utils.return_sample import ReturnSample from jetstream.engine import engine_api, tokenizer_api, token_utils -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.metrics.prometheus import get_metric import numpy as np root = logging.getLogger() @@ -222,9 +223,6 @@ class Driver: # todo: remove jax_padding after all then engine migrate to np padding _jax_padding = True - # All metrics we want to monitor should be collected with this - _metrics_collector: JetstreamMetricsCollector | None = None - def __init__( self, prefill_engines: Optional[list[engine_api.Engine]] = None, @@ -233,7 +231,6 @@ def __init__( generate_params: Optional[list[Any]] = None, interleaved_mode: bool = False, jax_padding: bool = True, - metrics_collector: JetstreamMetricsCollector | None = None, is_ray_backend: bool = False, ): if prefill_engines is None: @@ -255,16 +252,14 @@ def __init__( self._prefill_params = prefill_params self._generate_params = generate_params self._interleaved_mode = interleaved_mode - self._metrics_collector = metrics_collector # Stages 1-4 represent the life cycle of a request. # Stage 1 # At first, a request is placed here in order to get prefilled. 
self._prefill_backlog = queue.Queue() - if self._metrics_collector: - self._metrics_collector.get_prefill_backlog_metric().set_function( - lambda: float(self._prefill_backlog.qsize()) - ) + get_metric("jetstream_prefill_backlog_size").set_function( + lambda: float(self._prefill_backlog.qsize()) + ) # Stage 2 # After prefilling, it is placed here in order to get transferred to @@ -278,11 +273,10 @@ def __init__( queue.Queue(1 if self._interleaved_mode else 4) for i in range(len(self._prefill_engines)) ] - if self._metrics_collector: - for idx, backlog in enumerate(self._transfer_backlogs): - self._metrics_collector.get_transfer_backlog_metric(idx).set_function( - functools.partial(float, backlog.qsize()) - ) + for idx, backlog in enumerate(self._transfer_backlogs): + get_metric("jetstream_transfer_backlog_size", idx=idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 3 # Each generate engine accesses its own generate backlog. # Interleaved Mode: Max size is 1 to increase the HBM utilization @@ -297,11 +291,10 @@ def __init__( ) for idx, engine in enumerate(self._generate_engines) } - if self._metrics_collector: - for idx, backlog in self._generate_backlogs.items(): - self._metrics_collector.get_generate_backlog_metric(idx).set_function( - functools.partial(float, backlog.qsize()) - ) + for idx, backlog in self._generate_backlogs.items(): + get_metric("jetstream_generate_backlog_size", idx=idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 4 # After generation, ActiveRequests are placed on the detokenization backlog # for tokens to be sent into each ActiveRequest's return channel. @@ -545,17 +538,10 @@ def _prefill_thread(self, idx: int): idx, my_transfer_backlog.qsize(), ) - if self._metrics_collector: - self._metrics_collector.get_request_input_length().observe(true_length) - - if self._metrics_collector: - self._metrics_collector.get_time_per_prefill_token().observe( - ( - request.metadata.transfer_enqueue_time - - request.metadata.prefill_dequeue_time - ) - / true_length - ) + get_metric("jetstream_request_input_length").observe(true_length) + get_metric("jetstream_time_per_prefill_token").observe( + get_time_per_prefill_token(request, true_length) + ) del prefill_result del request @@ -650,12 +636,9 @@ def _generate_thread(self, idx: int): max_concurrent_decodes = generate_engine.max_concurrent_decodes - if self._metrics_collector: - self._metrics_collector.get_slots_used_percentage_metric( - idx - ).set_function( - lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) - ) + get_metric("jetstream_slots_used_percentage", idx=idx).set_function( + lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) + ) # Check if there are any free my_slots. We don't want to block here since # we can still generate if we can't insert. 
We do this in a while loop to @@ -685,20 +668,9 @@ def _generate_thread(self, idx: int): if new_request is None: break new_request.metadata.generate_dequeue_time = time.perf_counter() - if ( - self._metrics_collector - and new_request.metadata.start_time is not None - ): - self._metrics_collector.get_queue_duration().observe( - # Time in prefill queue - new_request.metadata.prefill_dequeue_time - - new_request.metadata.prefill_enqueue_time - # Time in transfer queue - + new_request.metadata.transfer_dequeue_time - - new_request.metadata.transfer_enqueue_time - # Time in generate queue - + new_request.metadata.generate_dequeue_time - - new_request.metadata.generate_enqueue_time + if new_request.metadata.start_time is not None: + get_metric("jetstream_queue_duration").observe( + get_queue_duration(new_request) ) # Got free slot and new request, use them. except queue.Empty: @@ -798,10 +770,9 @@ def _detokenize_thread(self, idx: int): request.enqueue_samples(results) first_token_return_time = time.perf_counter() - if self._metrics_collector: - self._metrics_collector.get_time_to_first_token().observe( - first_token_return_time - request.metadata.prefill_dequeue_time - ) + get_metric("jetstream_time_to_first_token").observe( + first_token_return_time - request.metadata.prefill_dequeue_time + ) logging.info( "TTFT duration: %fms", (first_token_return_time - request.metadata.prefill_dequeue_time) @@ -831,39 +802,22 @@ def _detokenize_thread(self, idx: int): if request.complete.all(): request.metadata.complete_time = time.perf_counter() request.return_channel.close() - if self._metrics_collector: - self._metrics_collector.get_request_output_length().observe( - result_tokens.get_result_at_slot(slot).lengths - ) - self._metrics_collector.get_request_success_count_metric().inc() - self._metrics_collector.get_time_per_output_token().observe( - ( - request.metadata.complete_time - - request.metadata.transfer_enqueue_time - ) - / result_tokens.get_result_at_slot(slot).lengths + get_metric("jetstream_request_output_length").observe( + result_tokens.get_result_at_slot(slot).lengths + ) + get_metric("jetstream_request_success_count").inc() + get_metric("jetstream_time_per_output_token").observe( + get_tpot(request, result_tokens, slot) + ) + get_metric("jetstream_time_per_request").observe( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + + if request.metadata.start_time: + get_metric("jetstream_wait_time_per_request").observe( + get_wait_time(request) ) - self._metrics_collector.get_time_per_request().observe( - request.metadata.complete_time - - request.metadata.transfer_enqueue_time - ) - - if request.metadata.start_time: - total_time = ( - request.metadata.complete_time - - request.metadata.start_time - ) - prefill_time = ( - request.metadata.transfer_enqueue_time - - request.metadata.prefill_dequeue_time - ) - generate_time = ( - request.metadata.complete_time - - request.metadata.generate_dequeue_time - ) - self._metrics_collector.get_wait_time_per_request().observe( - total_time - prefill_time - generate_time - ) # Place the slot back on the free queue. my_live_requests[slot] = None my_slots.put(slot, block=False) # This should always have space. 
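For reference, a short usage sketch of the new module-level `get_metric` helper that the orchestrator call sites above rely on. It assumes the `jetstream` package (with this change applied) is importable; the metric names and required labels come from `_metrics_registry` in `jetstream/core/metrics/prometheus.py`.

```python
from jetstream.core.metrics.prometheus import get_metric

# Per-host metrics only need the metric name; the "id" label is filled in
# automatically from HOSTNAME (or a generated shortuuid).
get_metric("jetstream_prefill_backlog_size").set(3)
get_metric("jetstream_time_to_first_token").observe(0.12)

# Per-engine metrics additionally require the "idx" label.
get_metric("jetstream_transfer_backlog_size", idx=0).set_function(lambda: 0.0)

# Omitting a required label raises, e.g.:
# ValueError: Missing labels for metric jetstream_transfer_backlog_size: idx
try:
    get_metric("jetstream_transfer_backlog_size")
except ValueError as err:
    print(err)
```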
diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 22180f09..888b3408 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -32,7 +32,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.metrics.prometheus import get_metric from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import aot_utils, engine_api @@ -97,7 +97,6 @@ def create_driver( config: Type[config_lib.ServerConfig], devices: Any, jax_padding: bool = True, - metrics_collector: JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, ): """Creates a driver with a specified config. @@ -106,7 +105,6 @@ def create_driver( config: A ServerConfig to config engine, model, device slices, etc. devices: Device objects, will be used to get engine with proper slicing. jax_padding: The flag to enable JAX padding during tokenization. - metrics_collector: The JetStream Promethus metric collector. enable_model_warmup: The flag to enable model server warmup with AOT. Returns: @@ -161,7 +159,6 @@ def create_driver( generate_params=generate_params, interleaved_mode=interleaved_mode, jax_padding=jax_padding, - metrics_collector=metrics_collector, is_ray_backend=config.is_ray_backend, ) @@ -199,21 +196,17 @@ def run( server_start_time = time.time() logging.info("Kicking off gRPC server.") # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = None - if metrics_server_config and metrics_server_config.port: + if metrics_server_config is not None: logging.info( "Starting Prometheus server on port %d", metrics_server_config.port ) start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" ) - driver = create_driver( - config, devices, jax_padding, metrics_collector, enable_model_warmup - ) + driver = create_driver(config, devices, jax_padding, enable_model_warmup) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. 
threads = threads or max(driver.get_total_concurrent_requests(), 64) @@ -222,10 +215,9 @@ def run( jetstream_server.start() - if metrics_collector: - metrics_collector.get_server_startup_latency_metric().set( - time.time() - server_start_time - ) + get_metric("jetstream_server_startup_latency").set( + time.time() - server_start_time + ) # Setup Jax Profiler if enable_jax_profiler: diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index aaced235..3261ba0f 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -20,15 +20,14 @@ from typing import Sequence from absl import app as abslapp from absl import flags +from prometheus_client import start_http_server from fastapi import APIRouter, Response import fastapi from fastapi.responses import StreamingResponse -from prometheus_client import start_http_server import uvicorn from google.protobuf.json_format import Parse -from jetstream.core import config_lib, orchestrator, server_lib -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core import orchestrator, server_lib from jetstream.core.proto import jetstream_pb2 from jetstream.entrypoints.config import get_server_config from jetstream.entrypoints.http.protocol import DecodeRequest @@ -100,18 +99,12 @@ def server(argv: Sequence[str]): print(f"server_config: {server_config}") del argv - metrics_server_config: config_lib.MetricsServerConfig | None = None # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = None if flags.FLAGS.prometheus_port != 0: - metrics_server_config = config_lib.MetricsServerConfig( - port=flags.FLAGS.prometheus_port - ) logging.info( - "Starting Prometheus server on port %d", metrics_server_config.port + "Starting Prometheus server on port %d", flags.FLAGS.prometheus_port ) - start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() + start_http_server(port=flags.FLAGS.prometheus_port) else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" @@ -122,7 +115,6 @@ def server(argv: Sequence[str]): driver=server_lib.create_driver( config=server_config, devices=devices, - metrics_collector=metrics_collector, ) ) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 2fdddce9..7a95cfcd 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -90,10 +90,6 @@ async def test_server( ) ###################### Requester side ###################################### - # if prometheus not configured, assert no metrics collector on Driver - if metrics_enabled is not True: - assert server._driver._metrics_collector is None # pylint: disable=protected-access - async with grpc.aio.secure_channel( f"localhost:{port}", grpc.local_channel_credentials() ) as channel: @@ -122,15 +118,15 @@ async def test_server( assert output_text == expected_text[counter] assert output_token_id == expected_token_ids[counter] counter += 1 - # assert prometheus server is running and responding - if metrics_enabled is True: - assert server._driver._metrics_collector is not None # pylint: disable=protected-access + # assert appropriate responsiveness of the prometheus server + try: + response = requests.get(f"http://localhost:{metrics_port}", timeout=5) assert ( - requests.get( - f"http://localhost:{metrics_port}", timeout=5 - ).status_code - == requests.status_codes.codes["ok"] + response.status_code == 
requests.status_codes.codes["ok"] + and metrics_enabled ) + except requests.exceptions.ConnectionError: + assert not metrics_enabled server.stop() def test_jax_profiler_server(self):
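To make the latency accounting introduced in `jetstream/core/metrics/utils.py` easier to follow, here is a self-contained sketch using hypothetical timestamps; it assumes the `jetstream` package (with this change applied) is importable, and a real `ActiveRequest` carries these fields on `request.metadata`.

```python
# Editorial sketch: exercise the new timing helpers with made-up timestamps.
from types import SimpleNamespace

from jetstream.core.metrics.utils import get_queue_duration, get_wait_time

request = SimpleNamespace(
    metadata=SimpleNamespace(
        start_time=0.00,
        prefill_enqueue_time=0.01,
        prefill_dequeue_time=0.05,   # waited 0.04 s in the prefill queue
        transfer_enqueue_time=0.20,  # prefill itself took 0.15 s
        transfer_dequeue_time=0.22,
        generate_enqueue_time=0.22,
        generate_dequeue_time=0.30,
        complete_time=1.30,          # decode took 1.00 s
    )
)

# Time spent in the prefill, transfer, and generate queues: 0.04 + 0.02 + 0.08
print(f"queue_duration = {get_queue_duration(request):.2f} s")  # 0.14
# Total time minus prefill time and decode time: 1.30 - 0.15 - 1.00
print(f"wait_time = {get_wait_time(request):.2f} s")            # 0.15
```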