MB5 PD-disagg pipeline: working end-to-end

Three independent bugs were blocking PD-disagg smoke; each fix is
isolated so the next PD experiment doesn't re-hit them.

1. mb5_launch.sh
   - stop_all() also kills mb5_pd_proxy.py (our vendored copy),
     not just the upstream filename, and asserts ports 8000-8007 +
     PROXY_PORT are free before launching — stale proxies were
     silently passing the readiness check.
   - Proxy readiness uses a generic "any HTTP response" probe against /;
     the proxy only serves POST /v1/completions and /v1/chat/completions,
     so a 404 on /, /health or /v1/models is expected.

2. mb5_pd_proxy.py (vendored from third_party so deploy.sh ships it)
   - Force min_tokens=1 on the prefill leg. Clients that set
     min_tokens == max_tokens (our replayer does) collide with
     vLLM's min_tokens<=max_tokens check after the proxy caps
     max_tokens=1.

3. instrument_kv_snapshot.py
   - Adds a second patch target: initialize
     MooncakeConnectorWorker.bootstrap_server = None in __init__.
     vLLM 0.18.1 only sets it under the is_kv_producer branch, so
     kv_consumer hits AttributeError as soon as the first remote
     prefill request lands.
   - apply/revert refactored to iterate over (path, patches) pairs.

plot_kv_pool_timeline.py also handles snapshot files that never
captured a running request (would otherwise IndexError on an empty
stackplot input).

Smoke: 4P+4D × 20 reqs → 20/20 success, mean 3.9s, p99 17s, 8 PIDs
all writing snapshots (601 total), well above the 8C baseline.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-28 00:14:22 +08:00
parent e0d3b5150a
commit a9c7310f4a
4 changed files with 470 additions and 26 deletions

View File

@@ -32,6 +32,10 @@ from pathlib import Path
DEFAULT_VENV = Path("/home/admin/cpfs/wjh/agentic-kv-fresh/.venv")
TARGET_REL = "lib/python3.12/site-packages/vllm/v1/core/sched/scheduler.py"
MOONCAKE_REL = (
"lib/python3.12/site-packages/vllm/distributed/kv_transfer/"
"kv_connector/v1/mooncake/mooncake_connector.py"
)
START_MARK = "# MB5_INSTRUMENT_START"
END_MARK = "# MB5_INSTRUMENT_END"
@@ -165,45 +169,66 @@ SCHED_RET_REPLACE = f""" {START_MARK}
def _agentic_emit_step_log("""
PATCHES = [
SCHED_PATCHES = [
("header", HEADER_ANCHOR, HEADER_ANCHOR + HEADER_INSERT),
("schedule() return", SCHED_RET_TARGET, SCHED_RET_REPLACE),
]
# ---------- Patch 3: vLLM 0.18.1 kv_consumer AttributeError fix --------------
# In MooncakeConnectorWorker.__init__, `self.bootstrap_server` is only assigned
# inside the `is_kv_producer` branch (around line 615). For kv_consumer roles
# the attribute is never set, but later code paths (e.g. line ~1060) check
# `if self.bootstrap_server is not None:` and raise AttributeError. We initialize it
# unconditionally just before the role-conditional branch.
MOONCAKE_ANCHOR = " self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}\n"
MOONCAKE_INSERT = (
f" {START_MARK}\n"
f" self.bootstrap_server = None # vLLM 0.18.1 kv_consumer fix\n"
f" {END_MARK}\n"
)
def find_target(venv_or_path: Path) -> Path:
candidates = [venv_or_path, DEFAULT_VENV / TARGET_REL]
MOONCAKE_PATCHES = [
("kv_consumer bootstrap_server init", MOONCAKE_ANCHOR,
MOONCAKE_ANCHOR + MOONCAKE_INSERT),
]
PATCH_FILES = [
(TARGET_REL, SCHED_PATCHES),
(MOONCAKE_REL, MOONCAKE_PATCHES),
]
def find_target(venv_or_path: Path, rel_path: str) -> Path:
candidates = [venv_or_path / rel_path, DEFAULT_VENV / rel_path]
for c in candidates:
if c.is_file():
return c
if c.is_dir():
sub = c / TARGET_REL
if sub.is_file():
return sub
raise FileNotFoundError(f"cannot find vllm V1 scheduler at {venv_or_path}")
raise FileNotFoundError(
f"cannot find {rel_path} under {venv_or_path}"
)
def is_patched(text: str) -> bool:
    """Return True when *text* contains ``START_MARK``.

    The presence of the start marker anywhere in the file contents is the
    only signal used to decide that instrumentation was already applied.
    """
    return START_MARK in text
def apply(target: Path) -> None:
def apply_one(target: Path, patches: list) -> None:
text = target.read_text()
if is_patched(text):
print(f"[mb5-instr] already patched: {target}")
return
new = text
for name, src, dst in PATCHES:
for name, src, dst in patches:
if src not in new:
raise RuntimeError(
f"patch {name!r}: anchor not found in {target}."
)
new = new.replace(src, dst, 1)
target.write_text(new)
print(f"[mb5-instr] applied {len(PATCHES)} patches -> {target}")
print(f"[mb5-instr] applied {len(patches)} patches -> {target}")
def revert(target: Path) -> None:
def revert_one(target: Path) -> None:
text = target.read_text()
if not is_patched(text):
print(f"[mb5-instr] not patched (nothing to revert): {target}")
@@ -225,16 +250,18 @@ def main() -> None:
p.add_argument("--check", action="store_true")
p.add_argument("--venv", type=Path, default=DEFAULT_VENV)
args = p.parse_args()
target = find_target(args.venv)
if args.apply:
apply(target)
elif args.revert:
revert(target)
elif args.check:
text = target.read_text()
print(f"[mb5-instr] {'PATCHED' if is_patched(text) else 'CLEAN'}: {target}")
else:
p.error("specify --apply / --revert / --check")
for rel_path, patches in PATCH_FILES:
target = find_target(args.venv, rel_path)
if args.apply:
apply_one(target, patches)
elif args.revert:
revert_one(target)
elif args.check:
text = target.read_text()
state = 'PATCHED' if is_patched(text) else 'CLEAN'
print(f"[mb5-instr] {state}: {target}")
else:
p.error("specify --apply / --revert / --check")
if __name__ == "__main__":

View File

@@ -28,7 +28,7 @@ VENV="${FRESH_ROOT}/.venv"
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
INSTRUMENT="${SCRIPT_DIR}/instrument_kv_snapshot.py"
PROXY_SRC="${SCRIPT_DIR}/../../third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py"
PROXY_SRC="${SCRIPT_DIR}/mb5_pd_proxy.py"
CONFIG="${CONFIG:-8C}"
RUN_LABEL="${RUN_LABEL:-default}"
@@ -44,10 +44,21 @@ BASE_BP=8998
BASE_MASTER=29500
# Kill every process a previous run may have left behind (both proxy
# filenames, `vllm serve`, EngineCore), then verify that ports 8000-8007 and
# PROXY_PORT are no longer listening. Exits the script with status 1 if any
# of those ports is still bound.
stop_all() {
    # `|| true` keeps a "no process matched" result from counting as a failure.
    pkill -9 -f "mb5_pd_proxy.py" 2>/dev/null || true
    pkill -9 -f "mooncake_connector_proxy.py" 2>/dev/null || true
    pkill -9 -f "vllm serve" 2>/dev/null || true
    pkill -9 -f "EngineCore" 2>/dev/null || true
    # Pause before inspecting the listening sockets.
    sleep 3
    # Hard guarantee: required ports must be free before we start. If they
    # aren't, an earlier run left a stale process holding the socket and the
    # readiness check would (silently) probe the stale proxy.
    for port in 8000 8001 8002 8003 8004 8005 8006 8007 "${PROXY_PORT}"; do
        # Column 4 of `ss -ltn` is the local address; match a trailing
        # ":<port>" or ".<port>".
        if ss -ltn 2>/dev/null | awk '{print $4}' | grep -qE "[:.]${port}\$"; then
            echo "[mb5] FATAL port ${port} still in use after stop_all; manual cleanup needed"
            # Show the owning process (-p) to help with the manual cleanup.
            ss -ltnp 2>/dev/null | grep -E "[:.]${port}\$" || true
            exit 1
        fi
    done
}
case "${1:-start}" in
@@ -183,9 +194,11 @@ if [ "${ROLES}" = "pd" ]; then
nohup python "${PROXY_SRC}" "${proxy_args[@]}" --port "${PROXY_PORT}" --host 0.0.0.0 \
> "${LOGS_DIR}/proxy.log" 2>&1 &
disown
# wait for proxy
# wait for proxy. Official mooncake_connector_proxy only handles
# /v1/completions, so /health and /v1/models return 404 — accept any
# HTTP response as "alive".
tries=0
while ! curl -sf "http://127.0.0.1:${PROXY_PORT}/v1/models" >/dev/null 2>&1; do
while ! curl -s -o /dev/null -w "%{http_code}" "http://127.0.0.1:${PROXY_PORT}/" 2>/dev/null | grep -qE "^[0-9]"; do
tries=$((tries+1))
if [ ${tries} -gt 60 ]; then
echo "[mb5] FATAL proxy did not come up in 2 min"
@@ -194,7 +207,7 @@ if [ "${ROLES}" = "pd" ]; then
fi
sleep 2
done
echo " proxy port=${PROXY_PORT} ready"
echo " proxy port=${PROXY_PORT} ready (HTTP responding)"
ENDPOINTS="http://127.0.0.1:${PROXY_PORT}"
fi

View File

@@ -0,0 +1,381 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import asyncio
import ipaddress
import itertools
import os
import urllib
import uuid
from contextlib import asynccontextmanager
from typing import Any
import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
def maybe_wrap_ipv6_address(address: str) -> str:
    """Return *address* wrapped in ``[...]`` if it is an IPv6 literal.

    Anything that ``ipaddress.IPv6Address`` rejects (IPv4 literals, host
    names, ...) is handed back untouched.
    """
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        # Not an IPv6 literal: usable in a URL authority as-is.
        return address
    return f"[{address}]"
def make_http_path(host: str, port: int) -> str:
    """Build the ``http://<host>:<port>`` base URL for *host* and *port*."""
    return "http://{}:{}".format(host, port)
def prefiller_cycle(prefill_clients: list[Any]):
    """Yield ``(prefill_client, dp_rank)`` pairs round-robin, forever.

    Each pass walks *prefill_clients* in order and, for every client, yields
    one pair per data-parallel rank in ``range(client["dp_size"])``.
    ``dp_size`` is read lazily on every pass, so it may be filled in after
    the generator has been created.

    Raises:
        RuntimeError: when a complete pass produces nothing (the list is
            empty or every ``dp_size`` is 0). Without this guard the loop
            would spin forever without yielding, blocking whoever called
            ``next()`` on the generator.
    """
    while True:
        produced = False
        for prefill_client in prefill_clients:
            for dp_rank in range(prefill_client["dp_size"]):
                produced = True
                yield prefill_client, dp_rank
        if not produced:
            raise RuntimeError("no prefill instances available to schedule")
async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event):
    """Discover every prefiller's data-parallel layout, then set *ready*.

    For each entry, in order:
      1. poll ``GET /health`` once per second until it succeeds;
      2. fetch ``<bootstrap_addr>/query`` (a failure here is not retried and
         propagates to the caller);
      3. store ``dp_engine_id[int(rank)] = reply[rank]["engine_id"]`` for
         every rank in the reply, and ``dp_size = len(reply)``, on the entry.

    *ready* is set only after all entries have been processed.
    """
    for entry in prefill_clients:
        healthy = False
        while not healthy:
            try:
                # Wait for prefill service to be ready
                probe = await entry["client"].get("/health")
                probe.raise_for_status()
                healthy = True
            except Exception:
                await asyncio.sleep(1)

        reply = await entry["client"].get(entry["bootstrap_addr"] + "/query")
        reply.raise_for_status()
        layout = reply.json()

        for rank in layout:
            entry["dp_engine_id"][int(rank)] = layout[rank]["engine_id"]
        dp_size = len(layout)
        entry["dp_size"] = dp_size
        print(f"Inited prefiller {entry['url']} with dp_size={dp_size}")

    ready.set()
    print("All prefiller instances are ready.")
def _new_unbounded_client(url: str) -> httpx.AsyncClient:
    """Create an ``httpx.AsyncClient`` rooted at *url* with no timeout and no
    connection-pool limits (the settings used for every upstream client)."""
    return httpx.AsyncClient(
        timeout=None,
        base_url=url,
        limits=httpx.Limits(
            max_connections=None,
            max_keepalive_connections=None,
        ),
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.

    Startup populates ``app.state`` with:
      * ``prefill_clients`` / ``decode_clients`` - one dict per upstream taken
        from ``global_args.prefill`` / ``global_args.decode``;
      * ``ready`` - an ``asyncio.Event`` set by ``get_prefiller_info`` once
        every prefiller has been queried;
      * ``prefiller_info_task`` - the background discovery task;
      * ``prefill_iterator`` / ``decode_iterator`` - round-robin iterators.

    Shutdown cancels the discovery task and closes every HTTP client, even if
    the application exits with an error.
    """
    # Startup: Initialize client pools for prefiller and decoder services
    app.state.prefill_clients = []
    app.state.decode_clients = []
    app.state.ready = asyncio.Event()

    # Create prefill clients
    for url, bootstrap_port in global_args.prefill:
        hostname = maybe_wrap_ipv6_address(urllib.parse.urlparse(url).hostname)
        app.state.prefill_clients.append(
            {
                "client": _new_unbounded_client(url),
                "url": url,
                # A missing bootstrap port falls back to 8998.
                "bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998),
                "dp_engine_id": {},
            }
        )

    # Create decode clients
    for url in global_args.decode:
        app.state.decode_clients.append({"client": _new_unbounded_client(url)})

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    app.state.prefiller_info_task = asyncio.create_task(
        get_prefiller_info(app.state.prefill_clients, app.state.ready)
    )

    # Initialize round-robin iterators
    app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients)
    app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))

    print(
        f"Got {len(app.state.prefill_clients)} prefill clients "
        f"and {len(app.state.decode_clients)} decode clients."
    )

    try:
        yield
    finally:
        # Shutdown: stop discovery if it is still polling. gather() with
        # return_exceptions=True absorbs both the cancellation and any error
        # the task already finished with, so cleanup always continues.
        app.state.prefiller_info_task.cancel()
        await asyncio.gather(app.state.prefiller_info_task, return_exceptions=True)

        # Shutdown: Close all clients
        for client_info in app.state.prefill_clients:
            await client_info["client"].aclose()
        for client_info in app.state.decode_clients:
            await client_info["client"].aclose()
# Create the FastAPI app and register the lifespan handler for startup/shutdown
app = FastAPI(lifespan=lifespan)
def parse_args():
    """Parse the proxy's command line.

    Returns the ``argparse`` namespace with two derived attributes added:
    ``prefill`` (the result of ``_parse_prefill_urls`` on the raw
    ``--prefill`` groups) and ``decode`` (the result of ``_parse_decode_urls``
    on the raw ``--decode`` groups).
    """
    prefill_help = (
        "Prefill server URL and optional bootstrap port. "
        "Can be specified multiple times. "
        "Format: --prefill URL [BOOTSTRAP_PORT]. "
        "BOOTSTRAP_PORT can be a port number, "
        "'none', or omitted (defaults to none)."
    )
    decode_help = "Decode server URL. Can be specified multiple times."

    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=8000)
    # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
    cli.add_argument("--host", type=str, default="127.0.0.1")

    # For prefiller instances
    cli.add_argument(
        "--prefill",
        nargs="+",
        action="append",
        dest="prefill_raw",
        metavar=("URL", "bootstrap_port"),
        help=prefill_help,
    )

    # For decoder instances
    cli.add_argument(
        "--decode",
        nargs=1,
        action="append",
        dest="decode_raw",
        metavar=("URL",),
        help=decode_help,
    )

    namespace = cli.parse_args()
    namespace.prefill = _parse_prefill_urls(namespace.prefill_raw)
    namespace.decode = _parse_decode_urls(namespace.decode_raw)
    return namespace
# From sglang router_args.py
def _parse_prefill_urls(prefill_list):
    """Turn raw ``--prefill`` groups into ``(url, bootstrap_port)`` tuples.

    Each group is ``[URL]`` or ``[URL, BOOTSTRAP_PORT, ...]``:

        --prefill http://prefill1:8080 9000  -> ("http://prefill1:8080", 9000)
        --prefill http://prefill2:8080 none  -> ("http://prefill2:8080", None)
        --prefill http://prefill3:8080       -> ("http://prefill3:8080", None)

    A falsy *prefill_list* gives ``[]``. A port that is neither an integer
    nor ``none`` (any letter case) raises ``ValueError``.
    """
    if not prefill_list:
        return []

    def _to_port(bootstrap_port_str):
        # 'none' (case-insensitive) explicitly means "no bootstrap port".
        if bootstrap_port_str.lower() == "none":
            return None
        try:
            return int(bootstrap_port_str)
        except ValueError as e:
            raise ValueError(
                f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"  # noqa: E501
            ) from e

    return [
        (group[0], _to_port(group[1]) if len(group) >= 2 else None)
        for group in prefill_list
    ]
def _parse_decode_urls(decode_list):
    """Flatten raw ``--decode`` groups into a plain list of URLs.

    Example: ``--decode http://decode1:8081 --decode http://decode2:8081``
    arrives as ``[["http://decode1:8081"], ["http://decode2:8081"]]`` and is
    returned as ``["http://decode1:8081", "http://decode2:8081"]``.
    A falsy *decode_list* gives ``[]``.
    """
    urls = []
    if decode_list:
        # nargs=1 wraps each URL in its own single-element list; unwrap it.
        for group in decode_list:
            urls.append(group[0])
    return urls
def get_next_client(app, service_type: str):
    """
    Pick the next upstream in round-robin order.

    Args:
        app: The FastAPI app instance (its ``state`` holds the iterators)
        service_type: Either 'prefill' or 'decode'

    Returns:
        For 'prefill', the next item of ``app.state.prefill_iterator``;
        for 'decode', the entry of ``app.state.decode_clients`` selected by
        the next index from ``app.state.decode_iterator``.

    Raises:
        ValueError: for any other *service_type*.
    """
    if service_type not in ("prefill", "decode"):
        raise ValueError(f"Unknown service type: {service_type}")
    if service_type == "prefill":
        return next(app.state.prefill_iterator)
    decode_index = next(app.state.decode_iterator)
    return app.state.decode_clients[decode_index]
async def send_request_to_service(
    client_info: dict, dp_rank: int, endpoint: str, req_data: dict, request_id: str
):
    """
    Run the prefill leg of a request against one prefill instance.

    A copy of *req_data* is rewritten so the instance performs only the
    prefill (non-streaming, a single token, marked for remote decode) and is
    POSTed to *endpoint* through ``client_info["client"]``. The caller's dict
    is left untouched. A non-success status raises via ``raise_for_status``.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
        "X-data-parallel-rank": str(dp_rank),
    }

    payload = req_data.copy()
    payload["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    payload["stream"] = False
    payload["max_tokens"] = 1
    # MB5 fix: clients (our replayer) may set min_tokens to enforce a fixed
    # output length. With max_tokens capped to 1 for the prefill leg, any
    # min_tokens > 1 violates vLLM's `min_tokens <= max_tokens` check, so
    # every token-count field the client supplied is pinned to 1 as well.
    for pinned_key in ("min_tokens", "max_completion_tokens"):
        if pinned_key in payload:
            payload[pinned_key] = 1
    # stream is forced off above, so drop the streaming options as well.
    payload.pop("stream_options", None)

    response = await client_info["client"].post(
        endpoint, json=payload, headers=headers
    )
    response.raise_for_status()

    # CRITICAL: Release connection back to pool
    await response.aclose()
async def stream_service_response(
    prefill_client_info: dict,
    prefill_dp_rank: int,
    decode_client_info: dict,
    endpoint: str,
    req_data: dict,
    request_id: str,
):
    """
    Asynchronously stream the decode leg's response body, chunk by chunk.

    A copy of *req_data* is tagged with ``kv_transfer_params`` pointing at the
    prefill instance chosen for this request (its bootstrap address and the
    engine id recorded for *prefill_dp_rank*) and POSTed to *endpoint* on
    ``decode_client_info["client"]``. The response bytes are yielded as they
    arrive. A non-success status raises via ``raise_for_status``.

    The caller's *req_data* is not modified, matching
    ``send_request_to_service``, which also works on a copy.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }

    # Copy first: the same dict is also handed to the prefill leg, so the
    # decode-specific transfer parameters must not leak into it.
    decode_req = req_data.copy()
    decode_req["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"],
        "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank],
        "transfer_id": f"xfer-{request_id}",
    }

    async with decode_client_info["client"].stream(
        "POST", endpoint, json=decode_req, headers=headers
    ) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
# Strong references to in-flight prefill tasks. The event loop only keeps weak
# references, so a fire-and-forget task could otherwise be garbage-collected
# before it finishes.
_PREFILL_TASKS: set = set()


def _on_prefill_done(task) -> None:
    """Drop a finished prefill task and report its failure, if any."""
    _PREFILL_TASKS.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        print(f"Error occurred in disagg prefill proxy server - prefill leg: {error!r}")


async def _handle_completions(api: str, request: Request):
    """Serve one completion request through a prefill/decode pair.

    Responds 503 until prefiller discovery has set ``app.state.ready``.
    Otherwise the JSON body is sent to the next prefill instance as a
    background task (not awaited) and the next decode instance's response is
    streamed back to the client.

    Args:
        api: Upstream path to call on both legs (e.g. "/v1/completions").
        request: The incoming FastAPI request.

    Raises:
        HTTPException: 503 while the proxy is not ready.
        Exception: anything raised while setting the request up is printed
            with its traceback and re-raised.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill")

        # Send request to prefill service. The task is tracked so it cannot be
        # collected mid-flight and so a failed prefill leg is reported instead
        # of disappearing silently.
        prefill_task = asyncio.create_task(
            send_request_to_service(
                prefill_client_info, prefill_dp_rank, api, req_data, request_id
            )
        )
        _PREFILL_TASKS.add(prefill_task)
        prefill_task.add_done_callback(_on_prefill_done)

        decode_client_info = get_next_client(request.app, "decode")

        # Stream response from decode service
        return StreamingResponse(
            stream_service_response(
                prefill_client_info,
                prefill_dp_rank,
                decode_client_info,
                api,
                req_data,
                request_id=request_id,
            ),
            media_type="application/json",
        )

    except Exception as e:
        import traceback

        print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
        print(e)
        print(traceback.format_exc())
        raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """POST /v1/completions: delegate to ``_handle_completions`` with that path."""
    return await _handle_completions("/v1/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    """POST /v1/chat/completions: delegate to ``_handle_completions`` with that path."""
    return await _handle_completions("/v1/chat/completions", request)
if __name__ == "__main__":
    # Script entry point: parse the CLI, publish the result as the module
    # global `global_args` (read by lifespan() at startup), then serve `app`.
    # The `global` statement has no effect at module level; the assignment
    # below is what creates the name.
    global global_args
    global_args = parse_args()

    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -70,6 +70,29 @@ def plot_one_instance(snaps: list[dict], out: Path, title: str) -> None:
# Sort by first-seen time so the band order follows arrival
all_req_ids.sort(key=lambda r: req_first_seen[r])
if not all_req_ids:
# No requests ever ran on this instance; plot a flat used_blocks line
# instead of the stackplot (which can't handle empty input).
fig, ax1 = plt.subplots(figsize=(13, 4))
used = [s["used_blocks"] for s in snaps]
ax1.plot(times, used, color="#888", lw=1.5, label="used_blocks (no running reqs sampled)")
ax1.axhline(total_blocks, color="#444", lw=1.5, ls="-",
label=f"pool total = {total_blocks} blocks")
ax1.axhline(total_blocks * 0.9, color="#c44e52", lw=1.2, ls="--", alpha=0.7,
label="90% capacity")
ax1.set_ylabel("KV blocks")
ax1.set_ylim(0, total_blocks * 1.05)
ax1.set_xlabel("wall-clock since first snapshot (s)")
ax1.set_title(title + " [no per-request data; instance idle?]")
ax1.legend(loc="upper right", fontsize=9)
ax1.grid(True, alpha=0.3)
out.parent.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
print(f"wrote {out} (n_snapshots={len(snaps)}, 0 running reqs ever)")
return
matrix = np.zeros((len(all_req_ids), len(times)), dtype=np.int64)
req_to_row = {r: i for i, r in enumerate(all_req_ids)}
for j, s in enumerate(snaps):