From d92cbd5cf840f23e30d55d4a2bd3503ab5c22e40 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Wed, 7 Aug 2024 22:21:21 +0000 Subject: [PATCH 01/22] first commit --- jetstream/core/metrics/prometheus.py | 100 ++++++++++-------- jetstream/core/orchestrator.py | 123 +++++++++++------------ jetstream/entrypoints/http/api_server.py | 13 +-- jetstream/tests/core/test_server.py | 6 +- 4 files changed, 124 insertions(+), 118 deletions(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index dc8a00e9..712eb20b 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -14,9 +14,14 @@ """Contains common functions for configuring Jetstream server metrics""" +import logging import os +from typing import Optional import shortuuid -from prometheus_client import Counter, Gauge, Histogram +from prometheus_client import Counter, Gauge, Histogram, start_http_server + +from jetstream.core import config_lib + from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS @@ -24,12 +29,24 @@ class JetstreamMetricsCollector: """Wrapper class should be used to assure all metrics have proper tags""" _id: str = os.getenv("HOSTNAME", shortuuid.uuid()) + _metrics_server_config: Optional[config_lib.MetricsServerConfig] = None def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls) return cls.instance + def start_http_server( + self, metrics_server_config: config_lib.MetricsServerConfig + ): + self._metrics_server_config = metrics_server_config + logging.info( + "Starting Prometheus server on port %d", + self._metrics_server_config.port, + ) + start_http_server(self._metrics_server_config.port) + self._serving_metrics = True + # Metric definitions _prefill_backlog = Gauge( name="jetstream_prefill_backlog_size", @@ -37,18 +54,27 @@ def __new__(cls): labelnames=["id"], ) + def get_prefill_backlog_metric(self): + return self._prefill_backlog.labels(id=self._id) + _transfer_backlog = Gauge( name="jetstream_transfer_backlog_size", documentation="Size of transfer queue", labelnames=["id", "idx"], ) + def get_transfer_backlog_metric(self, idx: int): + return self._transfer_backlog.labels(id=self._id, idx=idx) + _generate_backlog = Gauge( name="jetstream_generate_backlog_size", documentation="Size of generate queue", labelnames=["id", "idx"], ) + def get_generate_backlog_metric(self, idx: int): + return self._generate_backlog.labels(id=self._id, idx=idx) + _queue_duration = Histogram( name="jetstream_queue_duration", documentation="The total time each request spends enqueued in seconds", @@ -70,23 +96,37 @@ def __new__(cls): ], ) + def get_queue_duration(self): + return self._queue_duration.labels(id=self._id) + _slots_used_percentage = Gauge( name="jetstream_slots_used_percentage", documentation="The percentage of decode slots currently being used", labelnames=["id", "idx"], ) + def get_slots_used_percentage_metric(self, idx: int): + return self._slots_used_percentage.labels(id=self._id, idx=idx) + _server_startup_latency = Gauge( name="jetstream_server_startup_latency", documentation="Total time taken to start the Jetstream server", labelnames=["id"], ) + + def get_server_startup_latency_metric(self): + return self._server_startup_latency.labels(id=self._id) + _request_input_length = Histogram( name="jetstream_request_input_length", documentation="Number of input tokens per request", labelnames=["id"], buckets=DEFAULT_PREFILL_BUCKETS, ) + + def get_request_input_length(self): + 
return self._request_input_length.labels(id=self._id) + _request_output_length = Histogram( name="jetstream_request_output_length", documentation="Number of output tokens per request", @@ -114,12 +154,19 @@ def __new__(cls): 2000000, ], ) + + def get_request_output_length(self): + return self._request_output_length.labels(id=self._id) + _request_success_count = Counter( name="jetstream_request_success_count", documentation="Number of requests successfully completed", labelnames=["id"], ) + def get_request_success_count_metric(self): + return self._request_success_count.labels(id=self._id) + _time_to_first_token = Histogram( name="jetstream_time_to_first_token", documentation="Time to first token per request in seconds", @@ -144,6 +191,9 @@ def __new__(cls): ], ) + def get_time_to_first_token(self): + return self._time_to_first_token.labels(id=self._id) + _time_per_output_token = Histogram( name="jetstream_time_per_output_token", documentation="Average time per output token per request in seconds", @@ -165,6 +215,9 @@ def __new__(cls): ], ) + def get_time_per_output_token(self): + return self._time_per_output_token.labels(id=self._id) + _time_per_prefill_token = Histogram( name="jetstream_time_per_prefill_token", documentation="Prefill time per token per request in seconds", @@ -186,6 +239,9 @@ def __new__(cls): ], ) + def get_time_per_prefill_token(self): + return self._time_per_prefill_token.labels(id=self._id) + _time_per_request = Histogram( name="jetstream_time_per_request", documentation="End to end request latency in seconds", @@ -193,6 +249,9 @@ def __new__(cls): buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], ) + def get_time_per_request(self): + return self._time_per_request.labels(id=self._id) + _wait_time_per_request = Histogram( name="jetstream_wait_time_per_request", documentation="Time each request is not being prefilled or decoded", @@ -214,44 +273,5 @@ def __new__(cls): ], ) - def get_prefill_backlog_metric(self): - return self._prefill_backlog.labels(id=self._id) - - def get_transfer_backlog_metric(self, idx: int): - return self._transfer_backlog.labels(id=self._id, idx=idx) - - def get_generate_backlog_metric(self, idx: int): - return self._generate_backlog.labels(id=self._id, idx=idx) - - def get_queue_duration(self): - return self._queue_duration.labels(id=self._id) - - def get_slots_used_percentage_metric(self, idx: int): - return self._slots_used_percentage.labels(id=self._id, idx=idx) - - def get_server_startup_latency_metric(self): - return self._server_startup_latency.labels(id=self._id) - - def get_time_to_first_token(self): - return self._time_to_first_token.labels(id=self._id) - - def get_time_per_output_token(self): - return self._time_per_output_token.labels(id=self._id) - - def get_time_per_prefill_token(self): - return self._time_per_prefill_token.labels(id=self._id) - - def get_time_per_request(self): - return self._time_per_request.labels(id=self._id) - def get_wait_time_per_request(self): return self._wait_time_per_request.labels(id=self._id) - - def get_request_input_length(self): - return self._request_input_length.labels(id=self._id) - - def get_request_output_length(self): - return self._request_output_length.labels(id=self._id) - - def get_request_success_count_metric(self): - return self._request_success_count.labels(id=self._id) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index cefabd05..c5dc47ed 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -223,7 +223,7 @@ 
class Driver: _jax_padding = True # All metrics we want to monitor should be collected with this - _metrics_collector: JetstreamMetricsCollector | None = None + _metrics_collector: JetstreamMetricsCollector = JetstreamMetricsCollector() def __init__( self, @@ -255,16 +255,16 @@ def __init__( self._prefill_params = prefill_params self._generate_params = generate_params self._interleaved_mode = interleaved_mode - self._metrics_collector = metrics_collector + if metrics_collector is not None: + self._metrics_collector = metrics_collector # Stages 1-4 represent the life cycle of a request. # Stage 1 # At first, a request is placed here in order to get prefilled. self._prefill_backlog = queue.Queue() - if self._metrics_collector: - self._metrics_collector.get_prefill_backlog_metric().set_function( - lambda: float(self._prefill_backlog.qsize()) - ) + self._metrics_collector.get_prefill_backlog_metric().set_function( + lambda: float(self._prefill_backlog.qsize()) + ) # Stage 2 # After prefilling, it is placed here in order to get transferred to @@ -278,11 +278,10 @@ def __init__( queue.Queue(1 if self._interleaved_mode else 4) for i in range(len(self._prefill_engines)) ] - if self._metrics_collector: - for idx, backlog in enumerate(self._transfer_backlogs): - self._metrics_collector.get_transfer_backlog_metric(idx).set_function( - functools.partial(float, backlog.qsize()) - ) + for idx, backlog in enumerate(self._transfer_backlogs): + self._metrics_collector.get_transfer_backlog_metric(idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 3 # Each generate engine accesses its own generate backlog. # Interleaved Mode: Max size is 1 to increase the HBM utilization @@ -297,11 +296,10 @@ def __init__( ) for idx, engine in enumerate(self._generate_engines) } - if self._metrics_collector: - for idx, backlog in self._generate_backlogs.items(): - self._metrics_collector.get_generate_backlog_metric(idx).set_function( - functools.partial(float, backlog.qsize()) - ) + for idx, backlog in self._generate_backlogs.items(): + self._metrics_collector.get_generate_backlog_metric(idx).set_function( + functools.partial(float, backlog.qsize()) + ) # Stage 4 # After generation, ActiveRequests are placed on the detokenization backlog # for tokens to be sent into each ActiveRequest's return channel. @@ -545,17 +543,14 @@ def _prefill_thread(self, idx: int): idx, my_transfer_backlog.qsize(), ) - if self._metrics_collector: - self._metrics_collector.get_request_input_length().observe(true_length) - - if self._metrics_collector: - self._metrics_collector.get_time_per_prefill_token().observe( - ( - request.metadata.transfer_enqueue_time - - request.metadata.prefill_dequeue_time - ) - / true_length - ) + self._metrics_collector.get_request_input_length().observe(true_length) + self._metrics_collector.get_time_per_prefill_token().observe( + ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + / true_length + ) del prefill_result del request @@ -650,12 +645,11 @@ def _generate_thread(self, idx: int): max_concurrent_decodes = generate_engine.max_concurrent_decodes - if self._metrics_collector: - self._metrics_collector.get_slots_used_percentage_metric( - idx - ).set_function( - lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) - ) + self._metrics_collector.get_slots_used_percentage_metric( + idx + ).set_function( + lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) + ) # Check if there are any free my_slots. 
We don't want to block here since # we can still generate if we can't insert. We do this in a while loop to @@ -798,10 +792,9 @@ def _detokenize_thread(self, idx: int): request.enqueue_samples(results) first_token_return_time = time.perf_counter() - if self._metrics_collector: - self._metrics_collector.get_time_to_first_token().observe( - first_token_return_time - request.metadata.prefill_dequeue_time - ) + self._metrics_collector.get_time_to_first_token().observe( + first_token_return_time - request.metadata.prefill_dequeue_time + ) logging.info( "TTFT duration: %fms", (first_token_return_time - request.metadata.prefill_dequeue_time) @@ -831,39 +824,37 @@ def _detokenize_thread(self, idx: int): if request.complete.all(): request.metadata.complete_time = time.perf_counter() request.return_channel.close() - if self._metrics_collector: - self._metrics_collector.get_request_output_length().observe( - result_tokens.get_result_at_slot(slot).lengths + self._metrics_collector.get_request_output_length().observe( + result_tokens.get_result_at_slot(slot).lengths + ) + self._metrics_collector.get_request_success_count_metric().inc() + self._metrics_collector.get_time_per_output_token().observe( + ( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + / result_tokens.get_result_at_slot(slot).lengths + ) + self._metrics_collector.get_time_per_request().observe( + request.metadata.complete_time + - request.metadata.transfer_enqueue_time + ) + + if request.metadata.start_time: + total_time = ( + request.metadata.complete_time - request.metadata.start_time ) - self._metrics_collector.get_request_success_count_metric().inc() - self._metrics_collector.get_time_per_output_token().observe( - ( - request.metadata.complete_time - - request.metadata.transfer_enqueue_time - ) - / result_tokens.get_result_at_slot(slot).lengths + prefill_time = ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time ) - self._metrics_collector.get_time_per_request().observe( + generate_time = ( request.metadata.complete_time - - request.metadata.transfer_enqueue_time + - request.metadata.generate_dequeue_time + ) + self._metrics_collector.get_wait_time_per_request().observe( + total_time - prefill_time - generate_time ) - - if request.metadata.start_time: - total_time = ( - request.metadata.complete_time - - request.metadata.start_time - ) - prefill_time = ( - request.metadata.transfer_enqueue_time - - request.metadata.prefill_dequeue_time - ) - generate_time = ( - request.metadata.complete_time - - request.metadata.generate_dequeue_time - ) - self._metrics_collector.get_wait_time_per_request().observe( - total_time - prefill_time - generate_time - ) # Place the slot back on the free queue. my_live_requests[slot] = None my_slots.put(slot, block=False) # This should always have space. 
diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index aaced235..664f9ed6 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -23,7 +23,6 @@ from fastapi import APIRouter, Response import fastapi from fastapi.responses import StreamingResponse -from prometheus_client import start_http_server import uvicorn from google.protobuf.json_format import Parse @@ -100,18 +99,12 @@ def server(argv: Sequence[str]): print(f"server_config: {server_config}") del argv - metrics_server_config: config_lib.MetricsServerConfig | None = None # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = None + metrics_collector: JetstreamMetricsCollector = JetstreamMetricsCollector() if flags.FLAGS.prometheus_port != 0: - metrics_server_config = config_lib.MetricsServerConfig( - port=flags.FLAGS.prometheus_port + metrics_collector.start_http_server( + config_lib.MetricsServerConfig(port=flags.FLAGS.prometheus_port) ) - logging.info( - "Starting Prometheus server on port %d", metrics_server_config.port - ) - start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 2fdddce9..008e05ad 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -92,7 +92,7 @@ async def test_server( # if prometheus not configured, assert no metrics collector on Driver if metrics_enabled is not True: - assert server._driver._metrics_collector is None # pylint: disable=protected-access + assert server._driver._metrics_collector._metrics_server_config is None # pylint: disable=protected-access async with grpc.aio.secure_channel( f"localhost:{port}", grpc.local_channel_credentials() @@ -124,7 +124,9 @@ async def test_server( counter += 1 # assert prometheus server is running and responding if metrics_enabled is True: - assert server._driver._metrics_collector is not None # pylint: disable=protected-access + assert ( + server._driver._metrics_collector._metrics_server_config is not None # pylint: disable=protected-access + ) assert ( requests.get( f"http://localhost:{metrics_port}", timeout=5 From 4923ae7b4a1633785b043b3fd7050599a633f682 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 16:37:12 +0000 Subject: [PATCH 02/22] cleanup --- jetstream/core/metrics/prometheus.py | 1 - jetstream/core/orchestrator.py | 2 +- jetstream/core/server_lib.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 712eb20b..b300d035 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -45,7 +45,6 @@ def start_http_server( self._metrics_server_config.port, ) start_http_server(self._metrics_server_config.port) - self._serving_metrics = True # Metric definitions _prefill_backlog = Gauge( diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index c5dc47ed..ebdab3d4 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -233,7 +233,7 @@ def __init__( generate_params: Optional[list[Any]] = None, interleaved_mode: bool = False, jax_padding: bool = True, - metrics_collector: JetstreamMetricsCollector | None = None, + metrics_collector: Optional[JetstreamMetricsCollector] = None, is_ray_backend: bool 
= False, ): if prefill_engines is None: diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 22180f09..eeb02d85 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -200,7 +200,7 @@ def run( logging.info("Kicking off gRPC server.") # Setup Prometheus server metrics_collector: JetstreamMetricsCollector = None - if metrics_server_config and metrics_server_config.port: + if metrics_server_config is not None: logging.info( "Starting Prometheus server on port %d", metrics_server_config.port ) From 0b6233b4ee39c974a582233c0a9d4eb2822c7e20 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 16:59:14 +0000 Subject: [PATCH 03/22] change test conditions --- jetstream/core/server_lib.py | 3 --- jetstream/tests/core/test_server.py | 12 ++++++------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index eeb02d85..9a8ba629 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -201,9 +201,6 @@ def run( # Setup Prometheus server metrics_collector: JetstreamMetricsCollector = None if metrics_server_config is not None: - logging.info( - "Starting Prometheus server on port %d", metrics_server_config.port - ) start_http_server(metrics_server_config.port) metrics_collector = JetstreamMetricsCollector() else: diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 008e05ad..3899a2ad 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -90,10 +90,6 @@ async def test_server( ) ###################### Requester side ###################################### - # if prometheus not configured, assert no metrics collector on Driver - if metrics_enabled is not True: - assert server._driver._metrics_collector._metrics_server_config is None # pylint: disable=protected-access - async with grpc.aio.secure_channel( f"localhost:{port}", grpc.local_channel_credentials() ) as channel: @@ -125,13 +121,17 @@ async def test_server( # assert prometheus server is running and responding if metrics_enabled is True: assert ( - server._driver._metrics_collector._metrics_server_config is not None # pylint: disable=protected-access + requests.get( + f"http://localhost:{metrics_port}", timeout=5 + ).status_code + == requests.status_codes.codes["ok"] ) + else: assert ( requests.get( f"http://localhost:{metrics_port}", timeout=5 ).status_code - == requests.status_codes.codes["ok"] + == requests.status_codes.codes["not_found"] ) server.stop() From b5308629524b34f79b90274f2905c394e521fa05 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 17:01:17 +0000 Subject: [PATCH 04/22] change test conditions --- jetstream/tests/core/test_server.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 3899a2ad..35a0517e 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -118,21 +118,15 @@ async def test_server( assert output_text == expected_text[counter] assert output_token_id == expected_token_ids[counter] counter += 1 - # assert prometheus server is running and responding - if metrics_enabled is True: - assert ( - requests.get( - f"http://localhost:{metrics_port}", timeout=5 - ).status_code - == requests.status_codes.codes["ok"] - ) - else: - assert ( - requests.get( - f"http://localhost:{metrics_port}", timeout=5 - ).status_code - == 
requests.status_codes.codes["not_found"] - ) + # assert appropriate responsiveness of the prometheus server + assert ( + requests.get( + f"http://localhost:{metrics_port}", timeout=5 + ).status_code + == requests.status_codes.codes[ + "ok" if metrics_enabled else "not_found" + ] + ) server.stop() def test_jax_profiler_server(self): From 26e80902f3058fccb30ec8807d015208bd33b14d Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 17:11:31 +0000 Subject: [PATCH 05/22] change test conditions --- jetstream/tests/core/test_server.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 35a0517e..20d82997 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -119,14 +119,16 @@ async def test_server( assert output_token_id == expected_token_ids[counter] counter += 1 # assert appropriate responsiveness of the prometheus server - assert ( - requests.get( - f"http://localhost:{metrics_port}", timeout=5 - ).status_code - == requests.status_codes.codes[ - "ok" if metrics_enabled else "not_found" - ] - ) + try: + response = requests.get( + f"http://localhost:{metrics_port}", timeout=5 + ).response + assert ( + response.status_code == requests.status_codes.codes["ok"] + and metrics_enabled + ) + except requests.exceptions.MaxRetryError: + assert not metrics_enabled server.stop() def test_jax_profiler_server(self): From ae968b6e1e6dbf653fd9617b52bef505ca02cc34 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 17:15:18 +0000 Subject: [PATCH 06/22] wrong error type --- jetstream/tests/core/test_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 20d82997..cd50ace4 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -127,7 +127,7 @@ async def test_server( response.status_code == requests.status_codes.codes["ok"] and metrics_enabled ) - except requests.exceptions.MaxRetryError: + except requests.exceptions.ConnectionError: assert not metrics_enabled server.stop() From ecb58f44f5f048644f6e409fa639c55a327779bd Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 17:19:40 +0000 Subject: [PATCH 07/22] no attribute response --- jetstream/tests/core/test_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index cd50ace4..30040063 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -122,7 +122,7 @@ async def test_server( try: response = requests.get( f"http://localhost:{metrics_port}", timeout=5 - ).response + ) assert ( response.status_code == requests.status_codes.codes["ok"] and metrics_enabled From 4bb965cd31fd0bacdc7912a916c210288f420499 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 17:31:16 +0000 Subject: [PATCH 08/22] Add documentation --- ...ity-prometheus-metrics-in-jetstream-server.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md index 04d7be4c..464bc8f3 100644 --- a/docs/observability-prometheus-metrics-in-jetstream-server.md +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -50,6 +50,22 @@ jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 
0.0 jetstream_slots_used_percentage{id="",idx="0"} 0.04166666666666663 ``` +Currently the following metrics are supported: + - `jetstream_prefill_backlog_size`: Size of prefill queue + - `jetstream_transfer_backlog_size`: Size of transfer queue + - `jetstream_generate_backlog_size`: Size of generate queue + - `jetstream_queue_duration`: The total time each request spends enqueued in seconds + - `jetstream_slots_used_percentage`: The percentage of decode slots currently being used + - `jetstream_server_startup_latency`: Total time taken to start the Jetstream server + - `jetstream_request_input_length`: Number of input tokens per request + - `jetstream_request_output_length`: Number of output tokens per request + - `jetstream_request_success_count`: Number of requests successfully completed + - `jetstream_time_to_first_token`: Time to first token per request in seconds + - `jetstream_time_per_output_token`: Average time per output token per request in seconds + - `jetstream_time_per_prefill_token`: Prefill time per token per request in seconds + - `jetstream_time_per_request`: End to end request latency in seconds + - `jetstream_wait_time_per_request`: Time each request is not being prefilled or decoded + ## Observe metrics on GKE clusters The following applies only for Jetstream deployed on a GKE cluster. Currently [Google Cloud Managed Service for Prometheus](https://cloud.google.com/stackdriver/docs/managed-prometheus) is enabled by default on all GKE clusters, it determines scrape targets via the [PodMonitoring](https://github.com/GoogleCloudPlatform/prometheus-engine/blob/v0.10.0/doc/api.md#podmonitoring) custom resource. After you deployed the JetStream GKE workload, you need to apply the PodMonitoring resource to your cluster as follows: From 228e3489e9a4eb320c0d31448cdecea754fc79b4 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 17:39:12 +0000 Subject: [PATCH 09/22] fmt --- jetstream/tests/core/test_server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 30040063..7a95cfcd 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -120,9 +120,7 @@ async def test_server( counter += 1 # assert appropriate responsiveness of the prometheus server try: - response = requests.get( - f"http://localhost:{metrics_port}", timeout=5 - ) + response = requests.get(f"http://localhost:{metrics_port}", timeout=5) assert ( response.status_code == requests.status_codes.codes["ok"] and metrics_enabled From e49e15fb93199d2d91bc5bcb333edf66583fb4b9 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 21:39:41 +0000 Subject: [PATCH 10/22] remove JetstreamMetricsCollector class --- jetstream/core/config_lib.py | 5 - jetstream/core/metrics/prometheus.py | 454 ++++++++++------------- jetstream/core/orchestrator.py | 41 +- jetstream/core/server_lib.py | 18 +- jetstream/entrypoints/http/api_server.py | 8 +- 5 files changed, 227 insertions(+), 299 deletions(-) diff --git a/jetstream/core/config_lib.py b/jetstream/core/config_lib.py index f3022d01..53d0e37e 100644 --- a/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -48,11 +48,6 @@ class InstantiatedEngines: interleaved_engines: List[engine_api.Engine] -@dataclasses.dataclass -class MetricsServerConfig: - port: uint16 - - # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼# diff --git a/jetstream/core/metrics/prometheus.py 
b/jetstream/core/metrics/prometheus.py index b300d035..10b157fb 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -14,263 +14,215 @@ """Contains common functions for configuring Jetstream server metrics""" -import logging import os -from typing import Optional import shortuuid -from prometheus_client import Counter, Gauge, Histogram, start_http_server +from prometheus_client import Counter, Gauge, Histogram from jetstream.core import config_lib from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS - -class JetstreamMetricsCollector: - """Wrapper class should be used to assure all metrics have proper tags""" - - _id: str = os.getenv("HOSTNAME", shortuuid.uuid()) - _metrics_server_config: Optional[config_lib.MetricsServerConfig] = None - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(JetstreamMetricsCollector, cls).__new__(cls) - return cls.instance - - def start_http_server( - self, metrics_server_config: config_lib.MetricsServerConfig - ): - self._metrics_server_config = metrics_server_config - logging.info( - "Starting Prometheus server on port %d", - self._metrics_server_config.port, +# Initialize the unique ID for labeling metrics +_id = os.getenv("HOSTNAME", shortuuid.uuid()) + +# Registry for storing metric objects +_metrics_registry = { + "jetstream_prefill_backlog_size": Gauge( + name="jetstream_prefill_backlog_size", + documentation="Size of prefill queue", + labelnames=["id"], + ), + "jetstream_transfer_backlog_size": Gauge( + name="jetstream_transfer_backlog_size", + documentation="Size of transfer queue", + labelnames=["id", "idx"], + ), + "jetstream_generate_backlog_size": Gauge( + name="jetstream_generate_backlog_size", + documentation="Size of generate queue", + labelnames=["id", "idx"], + ), + "jetstream_queue_duration": Histogram( + name="jetstream_queue_duration", + documentation="The total time each request spends enqueued in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ), + "jetstream_slots_used_percentage": Gauge( + name="jetstream_slots_used_percentage", + documentation="The percentage of decode slots currently being used", + labelnames=["id", "idx"], + ), + "jetstream_server_startup_latency": Gauge( + name="jetstream_server_startup_latency", + documentation="Total time taken to start the Jetstream server", + labelnames=["id"], + ), + "jetstream_request_input_length": Histogram( + name="jetstream_request_input_length", + documentation="Number of input tokens per request", + labelnames=["id"], + buckets=DEFAULT_PREFILL_BUCKETS, + ), + "jetstream_request_output_length": Histogram( + name="jetstream_request_output_length", + documentation="Number of output tokens per request", + labelnames=["id"], + buckets=[ + 1, + 2, + 5, + 10, + 20, + 50, + 100, + 200, + 500, + 1000, + 2000, + 5000, + 10000, + 20000, + 50000, + 100000, + 200000, + 500000, + 1000000, + 2000000, + ], + ), + "jetstream_request_success_count": Counter( + name="jetstream_request_success_count", + documentation="Number of requests successfully completed", + labelnames=["id"], + ), + "jetstream_time_to_first_token": Histogram( + name="jetstream_time_to_first_token", + documentation="Time to first token per request in seconds", + labelnames=["id"], + buckets=[ + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ], + ), + 
"jetstream_time_per_output_token": Histogram( + name="jetstream_time_per_output_token", + documentation="Average time per output token per request in seconds", + labelnames=["id"], + buckets=[ + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + ], + ), + "jetstream_time_per_prefill_token": Histogram( + name="jetstream_time_per_prefill_token", + documentation="Prefill time per token per request in seconds", + labelnames=["id"], + buckets=[ + 0.00001, + 0.00002, + 0.00005, + 0.0001, + 0.0002, + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + ], + ), + "jetstream_time_per_request": Histogram( + name="jetstream_time_per_request", + documentation="End to end request latency in seconds", + labelnames=["id"], + buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], + ), + "jetstream_wait_time_per_request": Histogram( + name="jetstream_wait_time_per_request", + documentation="Time each request is not being prefilled or decoded", + labelnames=["id"], + buckets=[ + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1.0, + 2.0, + 5.0, + 10.0, + 20.0, + 50.0, + 100.0, + ], + ), + "jetstream_total_tokens_in_current_batch": Gauge( + name="jetstream_total_tokens_in_current_batch", + documentation="Total number of tokens in the decode batch", + labelnames=["id", "idx"], + ), +} + + +# Function to retrieve a metric with specified labels +def get_metric(metric_name, **labels): + if metric_name not in _metrics_registry: + raise ValueError(f"Metric {metric_name} not found in registry.") + + metric = _metrics_registry[metric_name] + + # Automatically add the 'id' label if it's required by the metric + if "id" in metric._labelnames: + labels["id"] = _id + + # Check for any missing labels + missing_labels = set(metric._labelnames) - labels.keys() + if missing_labels: + raise ValueError( + f"Missing labels for metric {metric_name}: {', '.join(missing_labels)}" ) - start_http_server(self._metrics_server_config.port) - - # Metric definitions - _prefill_backlog = Gauge( - name="jetstream_prefill_backlog_size", - documentation="Size of prefill queue", - labelnames=["id"], - ) - - def get_prefill_backlog_metric(self): - return self._prefill_backlog.labels(id=self._id) - - _transfer_backlog = Gauge( - name="jetstream_transfer_backlog_size", - documentation="Size of transfer queue", - labelnames=["id", "idx"], - ) - - def get_transfer_backlog_metric(self, idx: int): - return self._transfer_backlog.labels(id=self._id, idx=idx) - - _generate_backlog = Gauge( - name="jetstream_generate_backlog_size", - documentation="Size of generate queue", - labelnames=["id", "idx"], - ) - - def get_generate_backlog_metric(self, idx: int): - return self._generate_backlog.labels(id=self._id, idx=idx) - - _queue_duration = Histogram( - name="jetstream_queue_duration", - documentation="The total time each request spends enqueued in seconds", - labelnames=["id"], - buckets=[ - 0.01, - 0.02, - 0.05, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 5.0, - 10.0, - 20.0, - 50.0, - 100.0, - ], - ) - - def get_queue_duration(self): - return self._queue_duration.labels(id=self._id) - - _slots_used_percentage = Gauge( - name="jetstream_slots_used_percentage", - documentation="The percentage of decode slots currently being used", - labelnames=["id", "idx"], - ) - - def get_slots_used_percentage_metric(self, idx: int): - return self._slots_used_percentage.labels(id=self._id, idx=idx) - - _server_startup_latency = Gauge( - name="jetstream_server_startup_latency", - documentation="Total time 
taken to start the Jetstream server", - labelnames=["id"], - ) - - def get_server_startup_latency_metric(self): - return self._server_startup_latency.labels(id=self._id) - - _request_input_length = Histogram( - name="jetstream_request_input_length", - documentation="Number of input tokens per request", - labelnames=["id"], - buckets=DEFAULT_PREFILL_BUCKETS, - ) - - def get_request_input_length(self): - return self._request_input_length.labels(id=self._id) - - _request_output_length = Histogram( - name="jetstream_request_output_length", - documentation="Number of output tokens per request", - labelnames=["id"], - buckets=[ - 1, - 2, - 5, - 10, - 20, - 50, - 100, - 200, - 500, - 1000, - 2000, - 5000, - 10000, - 20000, - 50000, - 100000, - 200000, - 500000, - 1000000, - 2000000, - ], - ) - - def get_request_output_length(self): - return self._request_output_length.labels(id=self._id) - - _request_success_count = Counter( - name="jetstream_request_success_count", - documentation="Number of requests successfully completed", - labelnames=["id"], - ) - - def get_request_success_count_metric(self): - return self._request_success_count.labels(id=self._id) - - _time_to_first_token = Histogram( - name="jetstream_time_to_first_token", - documentation="Time to first token per request in seconds", - labelnames=["id"], - buckets=[ - 0.001, - 0.005, - 0.01, - 0.02, - 0.04, - 0.06, - 0.08, - 0.1, - 0.25, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - ], - ) - - def get_time_to_first_token(self): - return self._time_to_first_token.labels(id=self._id) - - _time_per_output_token = Histogram( - name="jetstream_time_per_output_token", - documentation="Average time per output token per request in seconds", - labelnames=["id"], - buckets=[ - 0.01, - 0.025, - 0.05, - 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, - ], - ) - - def get_time_per_output_token(self): - return self._time_per_output_token.labels(id=self._id) - - _time_per_prefill_token = Histogram( - name="jetstream_time_per_prefill_token", - documentation="Prefill time per token per request in seconds", - labelnames=["id"], - buckets=[ - 0.00001, - 0.00002, - 0.00005, - 0.0001, - 0.0002, - 0.0005, - 0.001, - 0.002, - 0.005, - 0.01, - 0.02, - 0.05, - 0.1, - ], - ) - - def get_time_per_prefill_token(self): - return self._time_per_prefill_token.labels(id=self._id) - - _time_per_request = Histogram( - name="jetstream_time_per_request", - documentation="End to end request latency in seconds", - labelnames=["id"], - buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0], - ) - - def get_time_per_request(self): - return self._time_per_request.labels(id=self._id) - - _wait_time_per_request = Histogram( - name="jetstream_wait_time_per_request", - documentation="Time each request is not being prefilled or decoded", - labelnames=["id"], - buckets=[ - 0.01, - 0.02, - 0.05, - 0.1, - 0.2, - 0.5, - 1.0, - 2.0, - 5.0, - 10.0, - 20.0, - 50.0, - 100.0, - ], - ) - def get_wait_time_per_request(self): - return self._wait_time_per_request.labels(id=self._id) + return metric.labels(**labels) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index ebdab3d4..3036bfae 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -94,7 +94,7 @@ from jetstream.core.utils import async_multifuture from jetstream.core.utils.return_sample import ReturnSample from jetstream.engine import engine_api, tokenizer_api, token_utils -from jetstream.core.metrics.prometheus import 
JetstreamMetricsCollector +from jetstream.core.metrics.prometheus import JetstreamMetricsCollector, get_metric import numpy as np root = logging.getLogger() @@ -222,9 +222,6 @@ class Driver: # todo: remove jax_padding after all then engine migrate to np padding _jax_padding = True - # All metrics we want to monitor should be collected with this - _metrics_collector: JetstreamMetricsCollector = JetstreamMetricsCollector() - def __init__( self, prefill_engines: Optional[list[engine_api.Engine]] = None, @@ -233,7 +230,6 @@ def __init__( generate_params: Optional[list[Any]] = None, interleaved_mode: bool = False, jax_padding: bool = True, - metrics_collector: Optional[JetstreamMetricsCollector] = None, is_ray_backend: bool = False, ): if prefill_engines is None: @@ -255,14 +251,12 @@ def __init__( self._prefill_params = prefill_params self._generate_params = generate_params self._interleaved_mode = interleaved_mode - if metrics_collector is not None: - self._metrics_collector = metrics_collector # Stages 1-4 represent the life cycle of a request. # Stage 1 # At first, a request is placed here in order to get prefilled. self._prefill_backlog = queue.Queue() - self._metrics_collector.get_prefill_backlog_metric().set_function( + get_metric("jetstream_prefill_backlog").set_function( lambda: float(self._prefill_backlog.qsize()) ) @@ -279,7 +273,7 @@ def __init__( for i in range(len(self._prefill_engines)) ] for idx, backlog in enumerate(self._transfer_backlogs): - self._metrics_collector.get_transfer_backlog_metric(idx).set_function( + get_metric("jetstream_transfer_backlog", idx=idx).set_function( functools.partial(float, backlog.qsize()) ) # Stage 3 @@ -297,7 +291,7 @@ def __init__( for idx, engine in enumerate(self._generate_engines) } for idx, backlog in self._generate_backlogs.items(): - self._metrics_collector.get_generate_backlog_metric(idx).set_function( + get_metric("jetstream_generate_backlog", idx=idx).set_function( functools.partial(float, backlog.qsize()) ) # Stage 4 @@ -543,8 +537,8 @@ def _prefill_thread(self, idx: int): idx, my_transfer_backlog.qsize(), ) - self._metrics_collector.get_request_input_length().observe(true_length) - self._metrics_collector.get_time_per_prefill_token().observe( + get_metric("jetstream_request_input_length").observe(true_length) + get_metric("jetstream_time_per_prefill_token").observe( ( request.metadata.transfer_enqueue_time - request.metadata.prefill_dequeue_time @@ -645,9 +639,7 @@ def _generate_thread(self, idx: int): max_concurrent_decodes = generate_engine.max_concurrent_decodes - self._metrics_collector.get_slots_used_percentage_metric( - idx - ).set_function( + get_metric("jetstream_slots_used_percentage", idx=idx).set_function( lambda: float(1 - (my_slots.qsize() / max_concurrent_decodes)) ) @@ -679,11 +671,8 @@ def _generate_thread(self, idx: int): if new_request is None: break new_request.metadata.generate_dequeue_time = time.perf_counter() - if ( - self._metrics_collector - and new_request.metadata.start_time is not None - ): - self._metrics_collector.get_queue_duration().observe( + if new_request.metadata.start_time is not None: + get_metric("jetstream_queue_duration").observe( # Time in prefill queue new_request.metadata.prefill_dequeue_time - new_request.metadata.prefill_enqueue_time @@ -792,7 +781,7 @@ def _detokenize_thread(self, idx: int): request.enqueue_samples(results) first_token_return_time = time.perf_counter() - self._metrics_collector.get_time_to_first_token().observe( + get_metric("jetstream_time_to_first_token").observe( 
first_token_return_time - request.metadata.prefill_dequeue_time ) logging.info( @@ -824,18 +813,18 @@ def _detokenize_thread(self, idx: int): if request.complete.all(): request.metadata.complete_time = time.perf_counter() request.return_channel.close() - self._metrics_collector.get_request_output_length().observe( + get_metric("jetstream_request_output_length").observe( result_tokens.get_result_at_slot(slot).lengths ) - self._metrics_collector.get_request_success_count_metric().inc() - self._metrics_collector.get_time_per_output_token().observe( + get_metric("jetstream_request_success_count").inc() + get_metric("jetstream_time_per_output_token").observe( ( request.metadata.complete_time - request.metadata.transfer_enqueue_time ) / result_tokens.get_result_at_slot(slot).lengths ) - self._metrics_collector.get_time_per_request().observe( + get_metric("jetstream_time_per_request").observe( request.metadata.complete_time - request.metadata.transfer_enqueue_time ) @@ -852,7 +841,7 @@ def _detokenize_thread(self, idx: int): request.metadata.complete_time - request.metadata.generate_dequeue_time ) - self._metrics_collector.get_wait_time_per_request().observe( + get_metric("jetstream_wait_time_per_request").observe( total_time - prefill_time - generate_time ) # Place the slot back on the free queue. diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 9a8ba629..7cd9bcf8 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -32,7 +32,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.metrics.prometheus import JetstreamMetricsCollector, get_metric from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import aot_utils, engine_api @@ -97,7 +97,6 @@ def create_driver( config: Type[config_lib.ServerConfig], devices: Any, jax_padding: bool = True, - metrics_collector: JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, ): """Creates a driver with a specified config. @@ -106,7 +105,6 @@ def create_driver( config: A ServerConfig to config engine, model, device slices, etc. devices: Device objects, will be used to get engine with proper slicing. jax_padding: The flag to enable JAX padding during tokenization. - metrics_collector: The JetStream Promethus metric collector. enable_model_warmup: The flag to enable model server warmup with AOT. Returns: @@ -161,7 +159,6 @@ def create_driver( generate_params=generate_params, interleaved_mode=interleaved_mode, jax_padding=jax_padding, - metrics_collector=metrics_collector, is_ray_backend=config.is_ray_backend, ) @@ -199,18 +196,14 @@ def run( server_start_time = time.time() logging.info("Kicking off gRPC server.") # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = None if metrics_server_config is not None: start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" ) - driver = create_driver( - config, devices, jax_padding, metrics_collector, enable_model_warmup - ) + driver = create_driver(config, devices, jax_padding, enable_model_warmup) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64. 
threads = threads or max(driver.get_total_concurrent_requests(), 64) @@ -219,10 +212,9 @@ def run( jetstream_server.start() - if metrics_collector: - metrics_collector.get_server_startup_latency_metric().set( - time.time() - server_start_time - ) + get_metric("get_server_startup_latency_metric").set( + time.time() - server_start_time + ) # Setup Jax Profiler if enable_jax_profiler: diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index 664f9ed6..5f8e8b8f 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -20,6 +20,7 @@ from typing import Sequence from absl import app as abslapp from absl import flags +from prometheus_client import start_http_server from fastapi import APIRouter, Response import fastapi from fastapi.responses import StreamingResponse @@ -100,11 +101,11 @@ def server(argv: Sequence[str]): del argv # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = JetstreamMetricsCollector() if flags.FLAGS.prometheus_port != 0: - metrics_collector.start_http_server( - config_lib.MetricsServerConfig(port=flags.FLAGS.prometheus_port) + logging.info( + "Starting Prometheus server on port %d", metrics_server_config.port ) + start_http_server(port=flags.FLAGS.prometheus_port) else: logging.info( "Not starting Prometheus server: --prometheus_port flag not set" @@ -115,7 +116,6 @@ def server(argv: Sequence[str]): driver=server_lib.create_driver( config=server_config, devices=devices, - metrics_collector=metrics_collector, ) ) From f94b084fb66a4cce07a865bfbcb9c214b388d478 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 21:57:38 +0000 Subject: [PATCH 11/22] removed unused imports --- jetstream/core/config_lib.py | 1 - jetstream/core/metrics/prometheus.py | 6 ++---- jetstream/core/orchestrator.py | 2 +- jetstream/core/server_lib.py | 5 ++++- jetstream/entrypoints/http/api_server.py | 5 ++--- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jetstream/core/config_lib.py b/jetstream/core/config_lib.py index 53d0e37e..e9daffac 100644 --- a/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -17,7 +17,6 @@ import dataclasses import functools from typing import Any, Callable, List, Tuple, Type -from numpy import uint16 from jetstream.engine import engine_api from jetstream.engine import mock_engine diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 10b157fb..3dab9251 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -18,8 +18,6 @@ import shortuuid from prometheus_client import Counter, Gauge, Histogram -from jetstream.core import config_lib - from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS # Initialize the unique ID for labeling metrics @@ -215,11 +213,11 @@ def get_metric(metric_name, **labels): metric = _metrics_registry[metric_name] # Automatically add the 'id' label if it's required by the metric - if "id" in metric._labelnames: + if "id" in metric._labelnames: # pylint: disable=protected-access labels["id"] = _id # Check for any missing labels - missing_labels = set(metric._labelnames) - labels.keys() + missing_labels = set(metric._labelnames) - labels.keys() # pylint: disable=protected-access if missing_labels: raise ValueError( f"Missing labels for metric {metric_name}: {', '.join(missing_labels)}" diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 3036bfae..9ea80c84 100644 --- 
a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -94,7 +94,7 @@ from jetstream.core.utils import async_multifuture from jetstream.core.utils.return_sample import ReturnSample from jetstream.engine import engine_api, tokenizer_api, token_utils -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector, get_metric +from jetstream.core.metrics.prometheus import get_metric import numpy as np root = logging.getLogger() diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 7cd9bcf8..a9a42026 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -32,7 +32,7 @@ import jax from jetstream.core import config_lib from jetstream.core import orchestrator -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector, get_metric +from jetstream.core.metrics.prometheus import get_metric from jetstream.core.proto import jetstream_pb2_grpc from jetstream.engine import aot_utils, engine_api @@ -197,6 +197,9 @@ def run( logging.info("Kicking off gRPC server.") # Setup Prometheus server if metrics_server_config is not None: + logging.info( + "Starting Prometheus server on port %d", metrics_server_config.port + ) start_http_server(metrics_server_config.port) else: logging.info( diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index 5f8e8b8f..3261ba0f 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -27,8 +27,7 @@ import uvicorn from google.protobuf.json_format import Parse -from jetstream.core import config_lib, orchestrator, server_lib -from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core import orchestrator, server_lib from jetstream.core.proto import jetstream_pb2 from jetstream.entrypoints.config import get_server_config from jetstream.entrypoints.http.protocol import DecodeRequest @@ -103,7 +102,7 @@ def server(argv: Sequence[str]): # Setup Prometheus server if flags.FLAGS.prometheus_port != 0: logging.info( - "Starting Prometheus server on port %d", metrics_server_config.port + "Starting Prometheus server on port %d", flags.FLAGS.prometheus_port ) start_http_server(port=flags.FLAGS.prometheus_port) else: From 3c7a205118f8eac4b66d967717294a2bb1e7bfcc Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 22:02:13 +0000 Subject: [PATCH 12/22] undo meticsserverconfig class removal --- jetstream/core/config_lib.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jetstream/core/config_lib.py b/jetstream/core/config_lib.py index e9daffac..f34d3dd1 100644 --- a/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -18,6 +18,8 @@ import functools from typing import Any, Callable, List, Tuple, Type +from numpy import uint16 + from jetstream.engine import engine_api from jetstream.engine import mock_engine @@ -75,6 +77,11 @@ class InterleavedCPUTestServer(ServerConfig): ) +@dataclasses.dataclass +class MetricsServerConfig: + port: uint16 + + # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼# From d63ca41a6f20db3d1b591992b953a881871c35bf Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 22:07:31 +0000 Subject: [PATCH 13/22] misnamed metrics --- jetstream/core/orchestrator.py | 6 +++--- jetstream/core/server_lib.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 9ea80c84..62b78828 100644 --- 
a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -256,7 +256,7 @@ def __init__( # Stage 1 # At first, a request is placed here in order to get prefilled. self._prefill_backlog = queue.Queue() - get_metric("jetstream_prefill_backlog").set_function( + get_metric("jetstream_prefill_backlog_size").set_function( lambda: float(self._prefill_backlog.qsize()) ) @@ -273,7 +273,7 @@ def __init__( for i in range(len(self._prefill_engines)) ] for idx, backlog in enumerate(self._transfer_backlogs): - get_metric("jetstream_transfer_backlog", idx=idx).set_function( + get_metric("jetstream_transfer_backlog_size", idx=idx).set_function( functools.partial(float, backlog.qsize()) ) # Stage 3 @@ -291,7 +291,7 @@ def __init__( for idx, engine in enumerate(self._generate_engines) } for idx, backlog in self._generate_backlogs.items(): - get_metric("jetstream_generate_backlog", idx=idx).set_function( + get_metric("jetstream_generate_backlog_size", idx=idx).set_function( functools.partial(float, backlog.qsize()) ) # Stage 4 diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index a9a42026..a12c4e52 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -215,9 +215,7 @@ def run( jetstream_server.start() - get_metric("get_server_startup_latency_metric").set( - time.time() - server_start_time - ) + get_metric("get_server_startup_latency").set(time.time() - server_start_time) # Setup Jax Profiler if enable_jax_profiler: From 8f2504ebd3070f4db2b6359213e843e89fe0dd10 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 22:11:56 +0000 Subject: [PATCH 14/22] misnamed metric --- jetstream/core/server_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index a12c4e52..8a2ac1da 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -215,7 +215,7 @@ def run( jetstream_server.start() - get_metric("get_server_startup_latency").set(time.time() - server_start_time) + get_metric("jetstream_server_startup_latency").set(time.time() - server_start_time) # Setup Jax Profiler if enable_jax_profiler: From 2a67bc8fd1b821a0a533e4b4db78b133b7288a46 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 22:18:07 +0000 Subject: [PATCH 15/22] fmt --- jetstream/core/server_lib.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 8a2ac1da..888b3408 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -215,7 +215,9 @@ def run( jetstream_server.start() - get_metric("jetstream_server_startup_latency").set(time.time() - server_start_time) + get_metric("jetstream_server_startup_latency").set( + time.time() - server_start_time + ) # Setup Jax Profiler if enable_jax_profiler: From c272e0902f7a8d80e817c89a6b244d42f67da100 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 22:42:02 +0000 Subject: [PATCH 16/22] revert a few changes --- jetstream/core/config_lib.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/jetstream/core/config_lib.py b/jetstream/core/config_lib.py index f34d3dd1..f3022d01 100644 --- a/jetstream/core/config_lib.py +++ b/jetstream/core/config_lib.py @@ -17,7 +17,6 @@ import dataclasses import functools from typing import Any, Callable, List, Tuple, Type - from numpy import uint16 from jetstream.engine import engine_api @@ -49,6 +48,11 @@ class InstantiatedEngines: 
interleaved_engines: List[engine_api.Engine] +@dataclasses.dataclass +class MetricsServerConfig: + port: uint16 + + # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼# @@ -77,11 +81,6 @@ class InterleavedCPUTestServer(ServerConfig): ) -@dataclasses.dataclass -class MetricsServerConfig: - port: uint16 - - # ▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼▼# From 7b95aaf123747af76757d5e04f45b15bc0870786 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Thu, 8 Aug 2024 22:42:58 +0000 Subject: [PATCH 17/22] remove newline --- jetstream/core/metrics/prometheus.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py index 3dab9251..23694836 100644 --- a/jetstream/core/metrics/prometheus.py +++ b/jetstream/core/metrics/prometheus.py @@ -17,7 +17,6 @@ import os import shortuuid from prometheus_client import Counter, Gauge, Histogram - from jetstream.engine.token_utils import DEFAULT_PREFILL_BUCKETS # Initialize the unique ID for labeling metrics From 25212aa711ae0da9271ec330ddf10e8ffbee22eb Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Fri, 9 Aug 2024 22:13:23 +0000 Subject: [PATCH 18/22] helper functions --- Makefile | 2 +- jetstream/core/metrics/utils.py | 40 +++++++++++++++++++++++++++++++++ jetstream/core/orchestrator.py | 36 +++++------------------------ 3 files changed, 46 insertions(+), 32 deletions(-) create mode 100644 jetstream/core/metrics/utils.py diff --git a/Makefile b/Makefile index a7699a53..2445a9d2 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ generate-and-prepend-preambles: done format: - $(PIP) install pyink + pyink --pyink-indentation 2 --line-length 80 --verbose . # Code checking related targets diff --git a/jetstream/core/metrics/utils.py b/jetstream/core/metrics/utils.py new file mode 100644 index 00000000..5761808c --- /dev/null +++ b/jetstream/core/metrics/utils.py @@ -0,0 +1,40 @@ +from jetstream.core.orchestrator import ActiveRequest + + +def get_time_per_prefill_token(request: ActiveRequest, true_length: int): + return ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) / true_length + + +def get_queue_duration(request: ActiveRequest): + return ( + # Time in prefill queue + request.metadata.prefill_dequeue_time + - request.metadata.prefill_enqueue_time + # Time in transfer queue + + request.metadata.transfer_dequeue_time + - request.metadata.transfer_enqueue_time + # Time in generate queue + + request.metadata.generate_dequeue_time + - request.metadata.generate_enqueue_time + ) + + +def get_tpot(request: ActiveRequest, result_tokens): + return ( + request.metadata.complete_time - request.metadata.transfer_enqueue_time + ) / result_tokens.get_result_at_slot(slot).lengths + + +def get_wait_time(request: ActiveRequest): + total_time = request.metadata.complete_time - request.metadata.start_time + prefill_time = ( + request.metadata.transfer_enqueue_time + - request.metadata.prefill_dequeue_time + ) + generate_time = ( + request.metadata.complete_time - request.metadata.generate_dequeue_time + ) + return total_time - prefill_time - generate_time diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 62b78828..12665a7e 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -89,6 +89,7 @@ import grpc import jax +from jetstream.core.metrics.utils import get_queue_duration, get_time_per_prefill_token, get_tpot, get_wait_time from 
jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.utils import async_multifuture @@ -539,11 +540,7 @@ def _prefill_thread(self, idx: int): ) get_metric("jetstream_request_input_length").observe(true_length) get_metric("jetstream_time_per_prefill_token").observe( - ( - request.metadata.transfer_enqueue_time - - request.metadata.prefill_dequeue_time - ) - / true_length + get_time_per_prefill_token(request, true_length) ) del prefill_result @@ -673,15 +670,7 @@ def _generate_thread(self, idx: int): new_request.metadata.generate_dequeue_time = time.perf_counter() if new_request.metadata.start_time is not None: get_metric("jetstream_queue_duration").observe( - # Time in prefill queue - new_request.metadata.prefill_dequeue_time - - new_request.metadata.prefill_enqueue_time - # Time in transfer queue - + new_request.metadata.transfer_dequeue_time - - new_request.metadata.transfer_enqueue_time - # Time in generate queue - + new_request.metadata.generate_dequeue_time - - new_request.metadata.generate_enqueue_time + get_queue_duration(new_request) ) # Got free slot and new request, use them. except queue.Empty: @@ -818,11 +807,7 @@ def _detokenize_thread(self, idx: int): ) get_metric("jetstream_request_success_count").inc() get_metric("jetstream_time_per_output_token").observe( - ( - request.metadata.complete_time - - request.metadata.transfer_enqueue_time - ) - / result_tokens.get_result_at_slot(slot).lengths + get_tpot(request, result_tokens) ) get_metric("jetstream_time_per_request").observe( request.metadata.complete_time @@ -830,19 +815,8 @@ def _detokenize_thread(self, idx: int): ) if request.metadata.start_time: - total_time = ( - request.metadata.complete_time - request.metadata.start_time - ) - prefill_time = ( - request.metadata.transfer_enqueue_time - - request.metadata.prefill_dequeue_time - ) - generate_time = ( - request.metadata.complete_time - - request.metadata.generate_dequeue_time - ) get_metric("jetstream_wait_time_per_request").observe( - total_time - prefill_time - generate_time + get_wait_time(request) ) # Place the slot back on the free queue. my_live_requests[slot] = None From f0c8ef8dc8490462a5c5ea055e51b4ec823b2b1f Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Fri, 9 Aug 2024 22:14:09 +0000 Subject: [PATCH 19/22] revert --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 2445a9d2..a7699a53 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ generate-and-prepend-preambles: done format: - + $(PIP) install pyink pyink --pyink-indentation 2 --line-length 80 --verbose . 
# Code checking related targets From 62d1b16ad418a509b3e161895a1d9583f133135f Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Fri, 9 Aug 2024 22:19:24 +0000 Subject: [PATCH 20/22] remove circular import --- jetstream/core/metrics/utils.py | 11 ++++------- jetstream/core/orchestrator.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/jetstream/core/metrics/utils.py b/jetstream/core/metrics/utils.py index 5761808c..307bce61 100644 --- a/jetstream/core/metrics/utils.py +++ b/jetstream/core/metrics/utils.py @@ -1,14 +1,11 @@ -from jetstream.core.orchestrator import ActiveRequest - - -def get_time_per_prefill_token(request: ActiveRequest, true_length: int): +def get_time_per_prefill_token(request, true_length: int): return ( request.metadata.transfer_enqueue_time - request.metadata.prefill_dequeue_time ) / true_length -def get_queue_duration(request: ActiveRequest): +def get_queue_duration(request): return ( # Time in prefill queue request.metadata.prefill_dequeue_time @@ -22,13 +19,13 @@ def get_queue_duration(request: ActiveRequest): ) -def get_tpot(request: ActiveRequest, result_tokens): +def get_tpot(request, result_tokens, slot): return ( request.metadata.complete_time - request.metadata.transfer_enqueue_time ) / result_tokens.get_result_at_slot(slot).lengths -def get_wait_time(request: ActiveRequest): +def get_wait_time(request): total_time = request.metadata.complete_time - request.metadata.start_time prefill_time = ( request.metadata.transfer_enqueue_time diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 12665a7e..6791fd67 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -807,7 +807,7 @@ def _detokenize_thread(self, idx: int): ) get_metric("jetstream_request_success_count").inc() get_metric("jetstream_time_per_output_token").observe( - get_tpot(request, result_tokens) + get_tpot(request, result_tokens, slot) ) get_metric("jetstream_time_per_request").observe( request.metadata.complete_time From e8bf58c7c60498024e6a54a94097787d1c2c4eab Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Fri, 9 Aug 2024 22:24:35 +0000 Subject: [PATCH 21/22] reamble + docstring --- jetstream/core/metrics/utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/jetstream/core/metrics/utils.py b/jetstream/core/metrics/utils.py index 307bce61..7775e2b5 100644 --- a/jetstream/core/metrics/utils.py +++ b/jetstream/core/metrics/utils.py @@ -1,3 +1,19 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Contains helper functions for configuring Jetstream server metrics""" + def get_time_per_prefill_token(request, true_length: int): return ( request.metadata.transfer_enqueue_time From 419ecb8b20d7df8638eb0409e2fa69dbbba8d590 Mon Sep 17 00:00:00 2001 From: Brendan Slabe Date: Fri, 9 Aug 2024 22:28:55 +0000 Subject: [PATCH 22/22] fmt --- jetstream/core/metrics/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jetstream/core/metrics/utils.py b/jetstream/core/metrics/utils.py index 7775e2b5..2aa7f4b4 100644 --- a/jetstream/core/metrics/utils.py +++ b/jetstream/core/metrics/utils.py @@ -14,6 +14,7 @@ """Contains helper functions for configuring Jetstream server metrics""" + def get_time_per_prefill_token(request, true_length: int): return ( request.metadata.transfer_enqueue_time
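
The patches above only show the new timing helpers being wired into the orchestrator. The short sketch below is a hedged illustration of what they compute in isolation, assuming the series has been applied so that jetstream.core.metrics.utils is importable. FakeMetadata and FakeRequest are hypothetical stand-ins that merely mirror the attribute names the helpers read from ActiveRequest.metadata; they are not real JetStream classes, and the timestamps are made-up values chosen so the arithmetic is easy to follow.

# Usage sketch (not part of the patch series); assumes the patches above
# have been applied so jetstream.core.metrics.utils exists.
import dataclasses

from jetstream.core.metrics.utils import (
    get_queue_duration,
    get_time_per_prefill_token,
    get_wait_time,
)


@dataclasses.dataclass
class FakeMetadata:
  """Hypothetical stand-in mirroring ActiveRequest.metadata timestamps."""

  start_time: float
  prefill_enqueue_time: float
  prefill_dequeue_time: float
  transfer_enqueue_time: float
  transfer_dequeue_time: float
  generate_enqueue_time: float
  generate_dequeue_time: float
  complete_time: float


@dataclasses.dataclass
class FakeRequest:
  """Hypothetical stand-in for ActiveRequest; only .metadata is used."""

  metadata: FakeMetadata


request = FakeRequest(
    FakeMetadata(
        start_time=0.0,
        prefill_enqueue_time=0.0,
        prefill_dequeue_time=0.2,  # waited 0.2s in the prefill queue
        transfer_enqueue_time=0.5,  # prefill itself took 0.3s
        transfer_dequeue_time=0.6,  # waited 0.1s in the transfer queue
        generate_enqueue_time=0.6,
        generate_dequeue_time=0.9,  # waited 0.3s for a decode slot
        complete_time=2.0,  # decoding took 1.1s
    )
)

# Prefill seconds per token: (0.5 - 0.2) / 128.
print(get_time_per_prefill_token(request, true_length=128))
# Time spent queued across all three stages: 0.2 + 0.1 + 0.3 = 0.6s.
print(get_queue_duration(request))
# Wait time = total (2.0) - prefill (0.3) - generate (1.1) = 0.6s.
print(get_wait_time(request))

Dropping the ActiveRequest import and annotations in PATCH 20/22 is what keeps these helpers importable from orchestrator.py without a circular dependency, at the cost of static typing on the request argument.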