Skip to content

[WIP][V1][P/D] Enhance Performance for P2pNcclConnector #20074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,35 @@
import os
import socket
import threading
import time
import uuid
from typing import Any

import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request

count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)

prefill_cv = threading.Condition()
decode_cv = threading.Condition()

DEFAULT_PING_SECONDS = 5


def _remove_oldest_instances(instances: dict[str, Any]) -> None:
oldest_key = next(iter(instances), None)
while oldest_key is not None:
value = instances[oldest_key]
if value[1] > time.time():
break
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
instances.pop(oldest_key, None)
oldest_key = next(iter(instances), None)


def _listen_for_register(poller, router_socket):
while True:
Expand All @@ -30,19 +45,33 @@ def _listen_for_register(poller, router_socket):
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
node = prefill_instances.pop(data["http_address"], None)
prefill_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(prefill_instances)

elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
node = decode_instances.pop(data["http_address"], None)
decode_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(decode_instances)
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)

if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}")


def start_service_discovery(hostname, port):
if not hostname:
Expand Down Expand Up @@ -104,12 +133,14 @@ async def handle_request():
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]

global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]

print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def inject_kv_into_layer(

# Load the KV for each request each layer
for request in metadata.requests:
is_success = True
for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache_layer = attn_layer.kv_cache[ \
Expand All @@ -202,10 +203,14 @@ def inject_kv_into_layer(
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
is_success = False
break

inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
if is_success:
logger.info(
"🔵KV Cache is injected into layer, %s", request.request_id)

def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
Expand Down Expand Up @@ -265,9 +270,8 @@ def extract_kv_from_layer(
kv_cache, remote_address)

def wait_for_save(self):
if self.is_producer:
assert self.p2p_nccl_engine is not None
self.p2p_nccl_engine.wait_for_sent()
"""P2pNcclConnector does not save explicitly."""
return

def get_finished(
self, finished_req_ids: set[str],
Expand Down Expand Up @@ -317,18 +321,18 @@ def get_num_new_matched_tokens(
num_external_tokens = (len(request.prompt_token_ids) - 1 -
num_computed_tokens)

if num_external_tokens < 0:
num_external_tokens = 0
if num_external_tokens <= 0:
return 0, False

return num_external_tokens, False
return num_external_tokens, True

def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
if not self.is_producer and num_external_tokens > 0:
if not self.is_producer and num_external_tokens == 0:
self._requests_need_load[request.request_id] = (
request, blocks.get_block_ids()[0])

Expand Down Expand Up @@ -414,14 +418,6 @@ def build_connector_meta(
block_ids=block_ids,
block_size=self._block_size)

# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)

self._requests_need_load.clear()
return meta

Expand All @@ -443,7 +439,7 @@ def request_finished(

self.chunked_prefill.pop(request.request_id, None)

return False, None
return self.is_producer, None

# ==============================
# Static methods
Expand Down
Loading
Loading