From a9c7310f4a2efe6c7dab93a875340d9b69449579 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 28 May 2026 00:14:22 +0800 Subject: [PATCH] MB5 PD-disagg pipeline: working end-to-end MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three independent bugs were blocking PD-disagg smoke; each fix is isolated so the next PD experiment doesn't re-hit them. 1. mb5_launch.sh - stop_all() also kills mb5_pd_proxy.py (our vendored copy), not just the upstream filename, and asserts ports 8000-8007 + PROXY_PORT are free before launching — stale proxies were silently passing the readiness check. - Proxy readiness uses a generic "any HTTP response" probe; mooncake_connector_proxy only exposes /v1/completions so /v1/models 404 is expected. 2. mb5_pd_proxy.py (vendored from third_party so deploy.sh ships it) - Force min_tokens=1 on the prefill leg. Clients that set min_tokens == max_tokens (our replayer does) collide with vLLM's min_tokens<=max_tokens check after the proxy caps max_tokens=1. 3. instrument_kv_snapshot.py - Adds a second patch target: initialize MooncakeConnectorWorker.bootstrap_server = None in __init__. vLLM 0.18.1 only sets it under the is_kv_producer branch, so kv_consumer hits AttributeError as soon as the first remote prefill request lands. - apply/revert refactored to iterate over (path, patches) pairs. plot_kv_pool_timeline.py also handles snapshot files that never captured a running request (would otherwise IndexError on an empty stackplot input). Smoke: 4P+4D × 20 reqs → 20/20 success, mean 3.9s, p99 17s, 8 PIDs all writing snapshots (601 total), well above the 8C baseline. Co-Authored-By: Claude Opus 4.7 --- .../fresh_setup/instrument_kv_snapshot.py | 71 +++- microbench/fresh_setup/mb5_launch.sh | 21 +- microbench/fresh_setup/mb5_pd_proxy.py | 381 ++++++++++++++++++ .../fresh_setup/plot_kv_pool_timeline.py | 23 ++ 4 files changed, 470 insertions(+), 26 deletions(-) create mode 100644 microbench/fresh_setup/mb5_pd_proxy.py diff --git a/microbench/fresh_setup/instrument_kv_snapshot.py b/microbench/fresh_setup/instrument_kv_snapshot.py index fe0221f..d90d257 100644 --- a/microbench/fresh_setup/instrument_kv_snapshot.py +++ b/microbench/fresh_setup/instrument_kv_snapshot.py @@ -32,6 +32,10 @@ from pathlib import Path DEFAULT_VENV = Path("/home/admin/cpfs/wjh/agentic-kv-fresh/.venv") TARGET_REL = "lib/python3.12/site-packages/vllm/v1/core/sched/scheduler.py" +MOONCAKE_REL = ( + "lib/python3.12/site-packages/vllm/distributed/kv_transfer/" + "kv_connector/v1/mooncake/mooncake_connector.py" +) START_MARK = "# MB5_INSTRUMENT_START" END_MARK = "# MB5_INSTRUMENT_END" @@ -165,45 +169,66 @@ SCHED_RET_REPLACE = f""" {START_MARK} def _agentic_emit_step_log(""" -PATCHES = [ +SCHED_PATCHES = [ ("header", HEADER_ANCHOR, HEADER_ANCHOR + HEADER_INSERT), ("schedule() return", SCHED_RET_TARGET, SCHED_RET_REPLACE), ] +# ---------- Patch 3: vLLM 0.18.1 kv_consumer AttributeError fix -------------- +# In MooncakeConnectorWorker.__init__, `self.bootstrap_server` is only assigned +# inside the `is_kv_producer` branch (around line 615). For kv_consumer roles +# the attribute is never set, but later code paths (e.g. line ~1060) check +# `if self.bootstrap_server is not None:` and AttributeError. We initialize it +# unconditionally just before the role-conditional branch. +MOONCAKE_ANCHOR = " self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}\n" +MOONCAKE_INSERT = ( + f" {START_MARK}\n" + f" self.bootstrap_server = None # vLLM 0.18.1 kv_consumer fix\n" + f" {END_MARK}\n" +) -def find_target(venv_or_path: Path) -> Path: - candidates = [venv_or_path, DEFAULT_VENV / TARGET_REL] +MOONCAKE_PATCHES = [ + ("kv_consumer bootstrap_server init", MOONCAKE_ANCHOR, + MOONCAKE_ANCHOR + MOONCAKE_INSERT), +] + +PATCH_FILES = [ + (TARGET_REL, SCHED_PATCHES), + (MOONCAKE_REL, MOONCAKE_PATCHES), +] + + +def find_target(venv_or_path: Path, rel_path: str) -> Path: + candidates = [venv_or_path / rel_path, DEFAULT_VENV / rel_path] for c in candidates: if c.is_file(): return c - if c.is_dir(): - sub = c / TARGET_REL - if sub.is_file(): - return sub - raise FileNotFoundError(f"cannot find vllm V1 scheduler at {venv_or_path}") + raise FileNotFoundError( + f"cannot find {rel_path} under {venv_or_path}" + ) def is_patched(text: str) -> bool: return START_MARK in text -def apply(target: Path) -> None: +def apply_one(target: Path, patches: list) -> None: text = target.read_text() if is_patched(text): print(f"[mb5-instr] already patched: {target}") return new = text - for name, src, dst in PATCHES: + for name, src, dst in patches: if src not in new: raise RuntimeError( f"patch {name!r}: anchor not found in {target}." ) new = new.replace(src, dst, 1) target.write_text(new) - print(f"[mb5-instr] applied {len(PATCHES)} patches -> {target}") + print(f"[mb5-instr] applied {len(patches)} patches -> {target}") -def revert(target: Path) -> None: +def revert_one(target: Path) -> None: text = target.read_text() if not is_patched(text): print(f"[mb5-instr] not patched (nothing to revert): {target}") @@ -225,16 +250,18 @@ def main() -> None: p.add_argument("--check", action="store_true") p.add_argument("--venv", type=Path, default=DEFAULT_VENV) args = p.parse_args() - target = find_target(args.venv) - if args.apply: - apply(target) - elif args.revert: - revert(target) - elif args.check: - text = target.read_text() - print(f"[mb5-instr] {'PATCHED' if is_patched(text) else 'CLEAN'}: {target}") - else: - p.error("specify --apply / --revert / --check") + for rel_path, patches in PATCH_FILES: + target = find_target(args.venv, rel_path) + if args.apply: + apply_one(target, patches) + elif args.revert: + revert_one(target) + elif args.check: + text = target.read_text() + state = 'PATCHED' if is_patched(text) else 'CLEAN' + print(f"[mb5-instr] {state}: {target}") + else: + p.error("specify --apply / --revert / --check") if __name__ == "__main__": diff --git a/microbench/fresh_setup/mb5_launch.sh b/microbench/fresh_setup/mb5_launch.sh index 45bb988..d530d2c 100755 --- a/microbench/fresh_setup/mb5_launch.sh +++ b/microbench/fresh_setup/mb5_launch.sh @@ -28,7 +28,7 @@ VENV="${FRESH_ROOT}/.venv" MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" INSTRUMENT="${SCRIPT_DIR}/instrument_kv_snapshot.py" -PROXY_SRC="${SCRIPT_DIR}/../../third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py" +PROXY_SRC="${SCRIPT_DIR}/mb5_pd_proxy.py" CONFIG="${CONFIG:-8C}" RUN_LABEL="${RUN_LABEL:-default}" @@ -44,10 +44,21 @@ BASE_BP=8998 BASE_MASTER=29500 stop_all() { + pkill -9 -f "mb5_pd_proxy.py" 2>/dev/null || true pkill -9 -f "mooncake_connector_proxy.py" 2>/dev/null || true pkill -9 -f "vllm serve" 2>/dev/null || true pkill -9 -f "EngineCore" 2>/dev/null || true sleep 3 + # Hard guarantee: required ports must be free before we start. If they + # aren't, an earlier run left a stale process holding the socket and the + # readiness check would (silently) probe the stale proxy. + for port in 8000 8001 8002 8003 8004 8005 8006 8007 "${PROXY_PORT}"; do + if ss -ltn 2>/dev/null | awk '{print $4}' | grep -qE "[:.]${port}\$"; then + echo "[mb5] FATAL port ${port} still in use after stop_all; manual cleanup needed" + ss -ltnp 2>/dev/null | grep -E "[:.]${port}\$" || true + exit 1 + fi + done } case "${1:-start}" in @@ -183,9 +194,11 @@ if [ "${ROLES}" = "pd" ]; then nohup python "${PROXY_SRC}" "${proxy_args[@]}" --port "${PROXY_PORT}" --host 0.0.0.0 \ > "${LOGS_DIR}/proxy.log" 2>&1 & disown - # wait for proxy + # wait for proxy. Official mooncake_connector_proxy only handles + # /v1/completions, so /health and /v1/models return 404 — accept any + # HTTP response as "alive". tries=0 - while ! curl -sf "http://127.0.0.1:${PROXY_PORT}/v1/models" >/dev/null 2>&1; do + while ! curl -s -o /dev/null -w "%{http_code}" "http://127.0.0.1:${PROXY_PORT}/" 2>/dev/null | grep -qE "^[0-9]"; do tries=$((tries+1)) if [ ${tries} -gt 60 ]; then echo "[mb5] FATAL proxy did not come up in 2 min" @@ -194,7 +207,7 @@ if [ "${ROLES}" = "pd" ]; then fi sleep 2 done - echo " proxy port=${PROXY_PORT} ready" + echo " proxy port=${PROXY_PORT} ready (HTTP responding)" ENDPOINTS="http://127.0.0.1:${PROXY_PORT}" fi diff --git a/microbench/fresh_setup/mb5_pd_proxy.py b/microbench/fresh_setup/mb5_pd_proxy.py new file mode 100644 index 0000000..d28e3fa --- /dev/null +++ b/microbench/fresh_setup/mb5_pd_proxy.py @@ -0,0 +1,381 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import asyncio +import ipaddress +import itertools +import os +import urllib +import uuid +from contextlib import asynccontextmanager +from typing import Any + +import httpx +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse + + +def maybe_wrap_ipv6_address(address: str) -> str: + try: + ipaddress.IPv6Address(address) + return f"[{address}]" + except ValueError: + return address + + +def make_http_path(host: str, port: int) -> str: + return f"http://{host}:{port}" + + +def prefiller_cycle(prefill_clients: list[Any]): + while True: + for prefill_client in prefill_clients: + for i in range(prefill_client["dp_size"]): + yield prefill_client, i + + +async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event): + for prefill_client in prefill_clients: + while True: + try: + # Wait for prefill service to be ready + response = await prefill_client["client"].get("/health") + response.raise_for_status() + except Exception: + await asyncio.sleep(1) + continue + + response = await prefill_client["client"].get( + prefill_client["bootstrap_addr"] + "/query" + ) + response.raise_for_status() + data = response.json() + break + + for dp_rank, dp_entry in data.items(): + prefill_client["dp_engine_id"][int(dp_rank)] = dp_entry["engine_id"] + dp_size = len(data) + prefill_client["dp_size"] = dp_size + print(f"Inited prefiller {prefill_client['url']} with dp_size={dp_size}") + + ready.set() + print("All prefiller instances are ready.") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. + """ + # Startup: Initialize client pools for prefiller and decoder services + app.state.prefill_clients = [] + app.state.decode_clients = [] + app.state.ready = asyncio.Event() + + # Create prefill clients + for i, (url, bootstrap_port) in enumerate(global_args.prefill): + parsed_url = urllib.parse.urlparse(url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + "url": url, + "bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998), + "dp_engine_id": {}, + } + ) + + # Create decode clients + for i, url in enumerate(global_args.decode): + parsed_url = urllib.parse.urlparse(url) + hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + app.state.decode_clients.append( + { + "client": httpx.AsyncClient( + timeout=None, + base_url=url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), + } + ) + + asyncio.create_task(get_prefiller_info(app.state.prefill_clients, app.state.ready)) + + # Initialize round-robin iterators + app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) + + print( + f"Got {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients." + ) + + yield + + # Shutdown: Close all clients + for client_info in app.state.prefill_clients: + await client_info["client"].aclose() + + for client_info in app.state.decode_clients: + await client_info["client"].aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI + parser.add_argument("--host", type=str, default="127.0.0.1") + + # For prefiller instances + parser.add_argument( + "--prefill", + nargs="+", + action="append", + dest="prefill_raw", + metavar=("URL", "bootstrap_port"), + help=( + "Prefill server URL and optional bootstrap port. " + "Can be specified multiple times. " + "Format: --prefill URL [BOOTSTRAP_PORT]. " + "BOOTSTRAP_PORT can be a port number, " + "'none', or omitted (defaults to none)." + ), + ) + + # For decoder instances + parser.add_argument( + "--decode", + nargs=1, + action="append", + dest="decode_raw", + metavar=("URL",), + help="Decode server URL. Can be specified multiple times.", + ) + + args = parser.parse_args() + args.prefill = _parse_prefill_urls(args.prefill_raw) + args.decode = _parse_decode_urls(args.decode_raw) + + return args + + +# From sglang router_args.py +def _parse_prefill_urls(prefill_list): + """Parse prefill URLs from --prefill arguments. + + Format: --prefill URL [BOOTSTRAP_PORT] + Example: + --prefill http://prefill1:8080 9000 # With bootstrap port + --prefill http://prefill2:8080 none # Explicitly no bootstrap port + --prefill http://prefill3:8080 # Defaults to no bootstrap port + """ + if not prefill_list: + return [] + + prefill_urls = [] + for prefill_args in prefill_list: + url = prefill_args[0] + + # Handle optional bootstrap port + if len(prefill_args) >= 2: + bootstrap_port_str = prefill_args[1] + # Handle 'none' as None + if bootstrap_port_str.lower() == "none": + bootstrap_port = None + else: + try: + bootstrap_port = int(bootstrap_port_str) + except ValueError as e: + raise ValueError( + f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" # noqa: E501 + ) from e + else: + # No bootstrap port specified, default to None + bootstrap_port = None + + prefill_urls.append((url, bootstrap_port)) + + return prefill_urls + + +def _parse_decode_urls(decode_list): + """Parse decode URLs from --decode arguments. + + Format: --decode URL + Example: --decode http://decode1:8081 --decode http://decode2:8081 + """ + if not decode_list: + return [] + + # decode_list is a list of single-element lists due to nargs=1 + return [url[0] for url in decode_list] + + +def get_next_client(app, service_type: str): + """ + Get the next client in round-robin fashion. + + Args: + app: The FastAPI app instance + service_type: Either 'prefill' or 'decode' + + Returns: + The next client to use + """ + if service_type == "prefill": + return next(app.state.prefill_iterator) + elif service_type == "decode": + client_idx = next(app.state.decode_iterator) + return app.state.decode_clients[client_idx] + else: + raise ValueError(f"Unknown service type: {service_type}") + + +async def send_request_to_service( + client_info: dict, dp_rank: int, endpoint: str, req_data: dict, request_id: str +): + """ + Send a request to a service using a client from the pool. + """ + req_data = req_data.copy() + req_data["kv_transfer_params"] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "transfer_id": f"xfer-{request_id}", + } + req_data["stream"] = False + req_data["max_tokens"] = 1 + # MB5 fix: clients (our replayer) may set min_tokens to enforce a fixed + # output length. After the proxy caps max_tokens=1 on the prefill leg, + # any min_tokens > 1 violates vLLM's `min_tokens <= max_tokens` check. + if "min_tokens" in req_data: + req_data["min_tokens"] = 1 + if "max_completion_tokens" in req_data: + req_data["max_completion_tokens"] = 1 + if "stream_options" in req_data: + del req_data["stream_options"] + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + "X-data-parallel-rank": str(dp_rank), + } + + response = await client_info["client"].post( + endpoint, json=req_data, headers=headers + ) + response.raise_for_status() + + # CRITICAL: Release connection back to pool + await response.aclose() + + +async def stream_service_response( + prefill_client_info: dict, + prefill_dp_rank: int, + decode_client_info: dict, + endpoint: str, + req_data: dict, + request_id: str, +): + """ + Asynchronously stream response from a service using a client from the pool. + """ + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + + req_data["kv_transfer_params"] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"], + "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank], + "transfer_id": f"xfer-{request_id}", + } + + async with decode_client_info["client"].stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +async def _handle_completions(api: str, request: Request): + if not app.state.ready.is_set(): + raise HTTPException(status_code=503, detail="Service Unavailable") + + try: + req_data = await request.json() + request_id = str(uuid.uuid4()) + + # Get the next prefill client in round-robin fashion + prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill") + + # Send request to prefill service + asyncio.create_task( + send_request_to_service( + prefill_client_info, prefill_dp_rank, api, req_data, request_id + ) + ) + + decode_client_info = get_next_client(request.app, "decode") + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response( + prefill_client_info, + prefill_dp_rank, + decode_client_info, + api, + req_data, + request_id=request_id, + ): + yield chunk + + return StreamingResponse(generate_stream(), media_type="application/json") + + except Exception as e: + import sys + import traceback + + exc_info = sys.exc_info() + print(f"Error occurred in disagg prefill proxy server - {api} endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle_completions("/v1/completions", request) + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + return await _handle_completions("/v1/chat/completions", request) + + +if __name__ == "__main__": + global global_args + global_args = parse_args() + + import uvicorn + + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/microbench/fresh_setup/plot_kv_pool_timeline.py b/microbench/fresh_setup/plot_kv_pool_timeline.py index 3dc2304..36a38c2 100644 --- a/microbench/fresh_setup/plot_kv_pool_timeline.py +++ b/microbench/fresh_setup/plot_kv_pool_timeline.py @@ -70,6 +70,29 @@ def plot_one_instance(snaps: list[dict], out: Path, title: str) -> None: # Sort by first-seen time so the band order follows arrival all_req_ids.sort(key=lambda r: req_first_seen[r]) + if not all_req_ids: + # No requests ever ran on this instance; plot a flat used_blocks line + # instead of the stackplot (which can't handle empty input). + fig, ax1 = plt.subplots(figsize=(13, 4)) + used = [s["used_blocks"] for s in snaps] + ax1.plot(times, used, color="#888", lw=1.5, label="used_blocks (no running reqs sampled)") + ax1.axhline(total_blocks, color="#444", lw=1.5, ls="-", + label=f"pool total = {total_blocks} blocks") + ax1.axhline(total_blocks * 0.9, color="#c44e52", lw=1.2, ls="--", alpha=0.7, + label="90% capacity") + ax1.set_ylabel("KV blocks") + ax1.set_ylim(0, total_blocks * 1.05) + ax1.set_xlabel("wall-clock since first snapshot (s)") + ax1.set_title(title + " [no per-request data; instance idle?]") + ax1.legend(loc="upper right", fontsize=9) + ax1.grid(True, alpha=0.3) + out.parent.mkdir(parents=True, exist_ok=True) + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + print(f"wrote {out} (n_snapshots={len(snaps)}, 0 running reqs ever)") + return + matrix = np.zeros((len(all_req_ids), len(times)), dtype=np.int64) req_to_row = {r: i for i, r in enumerate(all_req_ids)} for j, s in enumerate(snaps):