MB5 PD-disagg pipeline: working end-to-end
Three independent bugs were blocking PD-disagg smoke; each fix is
isolated so the next PD experiment doesn't re-hit them.
1. mb5_launch.sh
- stop_all() also kills mb5_pd_proxy.py (our vendored copy),
not just the upstream filename, and asserts ports 8000-8007 +
PROXY_PORT are free before launching — stale proxies were
silently passing the readiness check.
- Proxy readiness uses a generic "any HTTP response" probe;
the proxy only exposes POST /v1/completions and
/v1/chat/completions, so a 404 on /v1/models is expected.
2. mb5_pd_proxy.py (vendored from third_party so deploy.sh ships it)
- Force min_tokens=1 on the prefill leg. Clients that set
min_tokens == max_tokens (our replayer does) collide with
vLLM's min_tokens<=max_tokens check after the proxy caps
max_tokens=1.
3. instrument_kv_snapshot.py
- Adds a second patch target: initialize
MooncakeConnectorWorker.bootstrap_server = None in __init__.
vLLM 0.18.1 only sets it under the is_kv_producer branch, so
kv_consumer hits AttributeError as soon as the first remote
prefill request lands.
- apply/revert refactored to iterate over (path, patches) pairs.
plot_kv_pool_timeline.py also handles snapshot files that never
captured a running request (would otherwise IndexError on an empty
stackplot input).
Smoke: 4P+4D × 20 reqs → 20/20 success, mean 3.9s, p99 17s, 8 PIDs
all writing snapshots (601 total), well above the 8C baseline.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -32,6 +32,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
DEFAULT_VENV = Path("/home/admin/cpfs/wjh/agentic-kv-fresh/.venv")
|
DEFAULT_VENV = Path("/home/admin/cpfs/wjh/agentic-kv-fresh/.venv")
|
||||||
TARGET_REL = "lib/python3.12/site-packages/vllm/v1/core/sched/scheduler.py"
|
TARGET_REL = "lib/python3.12/site-packages/vllm/v1/core/sched/scheduler.py"
|
||||||
|
MOONCAKE_REL = (
|
||||||
|
"lib/python3.12/site-packages/vllm/distributed/kv_transfer/"
|
||||||
|
"kv_connector/v1/mooncake/mooncake_connector.py"
|
||||||
|
)
|
||||||
|
|
||||||
START_MARK = "# MB5_INSTRUMENT_START"
|
START_MARK = "# MB5_INSTRUMENT_START"
|
||||||
END_MARK = "# MB5_INSTRUMENT_END"
|
END_MARK = "# MB5_INSTRUMENT_END"
|
||||||
@@ -165,45 +169,66 @@ SCHED_RET_REPLACE = f""" {START_MARK}
|
|||||||
def _agentic_emit_step_log("""
|
def _agentic_emit_step_log("""
|
||||||
|
|
||||||
|
|
||||||
PATCHES = [
|
SCHED_PATCHES = [
|
||||||
("header", HEADER_ANCHOR, HEADER_ANCHOR + HEADER_INSERT),
|
("header", HEADER_ANCHOR, HEADER_ANCHOR + HEADER_INSERT),
|
||||||
("schedule() return", SCHED_RET_TARGET, SCHED_RET_REPLACE),
|
("schedule() return", SCHED_RET_TARGET, SCHED_RET_REPLACE),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# ---------- Patch 3: vLLM 0.18.1 kv_consumer AttributeError fix --------------
|
||||||
|
# In MooncakeConnectorWorker.__init__, `self.bootstrap_server` is only assigned
|
||||||
|
# inside the `is_kv_producer` branch (around line 615). For kv_consumer roles
|
||||||
|
# the attribute is never set, but later code paths (e.g. line ~1060) check
|
||||||
|
# `if self.bootstrap_server is not None:` and raise AttributeError. We initialize it
|
||||||
|
# unconditionally just before the role-conditional branch.
|
||||||
|
MOONCAKE_ANCHOR = " self.reqs_need_send: dict[TransferId, SendBlockMeta] = {}\n"
|
||||||
|
MOONCAKE_INSERT = (
|
||||||
|
f" {START_MARK}\n"
|
||||||
|
f" self.bootstrap_server = None # vLLM 0.18.1 kv_consumer fix\n"
|
||||||
|
f" {END_MARK}\n"
|
||||||
|
)
|
||||||
|
|
||||||
def find_target(venv_or_path: Path) -> Path:
|
MOONCAKE_PATCHES = [
|
||||||
candidates = [venv_or_path, DEFAULT_VENV / TARGET_REL]
|
("kv_consumer bootstrap_server init", MOONCAKE_ANCHOR,
|
||||||
|
MOONCAKE_ANCHOR + MOONCAKE_INSERT),
|
||||||
|
]
|
||||||
|
|
||||||
|
PATCH_FILES = [
|
||||||
|
(TARGET_REL, SCHED_PATCHES),
|
||||||
|
(MOONCAKE_REL, MOONCAKE_PATCHES),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def find_target(venv_or_path: Path, rel_path: str) -> Path:
|
||||||
|
candidates = [venv_or_path / rel_path, DEFAULT_VENV / rel_path]
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
if c.is_file():
|
if c.is_file():
|
||||||
return c
|
return c
|
||||||
if c.is_dir():
|
raise FileNotFoundError(
|
||||||
sub = c / TARGET_REL
|
f"cannot find {rel_path} under {venv_or_path}"
|
||||||
if sub.is_file():
|
)
|
||||||
return sub
|
|
||||||
raise FileNotFoundError(f"cannot find vllm V1 scheduler at {venv_or_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def is_patched(text: str) -> bool:
|
def is_patched(text: str) -> bool:
|
||||||
return START_MARK in text
|
return START_MARK in text
|
||||||
|
|
||||||
|
|
||||||
def apply(target: Path) -> None:
|
def apply_one(target: Path, patches: list) -> None:
|
||||||
text = target.read_text()
|
text = target.read_text()
|
||||||
if is_patched(text):
|
if is_patched(text):
|
||||||
print(f"[mb5-instr] already patched: {target}")
|
print(f"[mb5-instr] already patched: {target}")
|
||||||
return
|
return
|
||||||
new = text
|
new = text
|
||||||
for name, src, dst in PATCHES:
|
for name, src, dst in patches:
|
||||||
if src not in new:
|
if src not in new:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"patch {name!r}: anchor not found in {target}."
|
f"patch {name!r}: anchor not found in {target}."
|
||||||
)
|
)
|
||||||
new = new.replace(src, dst, 1)
|
new = new.replace(src, dst, 1)
|
||||||
target.write_text(new)
|
target.write_text(new)
|
||||||
print(f"[mb5-instr] applied {len(PATCHES)} patches -> {target}")
|
print(f"[mb5-instr] applied {len(patches)} patches -> {target}")
|
||||||
|
|
||||||
|
|
||||||
def revert(target: Path) -> None:
|
def revert_one(target: Path) -> None:
|
||||||
text = target.read_text()
|
text = target.read_text()
|
||||||
if not is_patched(text):
|
if not is_patched(text):
|
||||||
print(f"[mb5-instr] not patched (nothing to revert): {target}")
|
print(f"[mb5-instr] not patched (nothing to revert): {target}")
|
||||||
@@ -225,16 +250,18 @@ def main() -> None:
|
|||||||
p.add_argument("--check", action="store_true")
|
p.add_argument("--check", action="store_true")
|
||||||
p.add_argument("--venv", type=Path, default=DEFAULT_VENV)
|
p.add_argument("--venv", type=Path, default=DEFAULT_VENV)
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
target = find_target(args.venv)
|
for rel_path, patches in PATCH_FILES:
|
||||||
if args.apply:
|
target = find_target(args.venv, rel_path)
|
||||||
apply(target)
|
if args.apply:
|
||||||
elif args.revert:
|
apply_one(target, patches)
|
||||||
revert(target)
|
elif args.revert:
|
||||||
elif args.check:
|
revert_one(target)
|
||||||
text = target.read_text()
|
elif args.check:
|
||||||
print(f"[mb5-instr] {'PATCHED' if is_patched(text) else 'CLEAN'}: {target}")
|
text = target.read_text()
|
||||||
else:
|
state = 'PATCHED' if is_patched(text) else 'CLEAN'
|
||||||
p.error("specify --apply / --revert / --check")
|
print(f"[mb5-instr] {state}: {target}")
|
||||||
|
else:
|
||||||
|
p.error("specify --apply / --revert / --check")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ VENV="${FRESH_ROOT}/.venv"
|
|||||||
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
INSTRUMENT="${SCRIPT_DIR}/instrument_kv_snapshot.py"
|
INSTRUMENT="${SCRIPT_DIR}/instrument_kv_snapshot.py"
|
||||||
PROXY_SRC="${SCRIPT_DIR}/../../third_party/vllm/examples/online_serving/disaggregated_serving/mooncake_connector/mooncake_connector_proxy.py"
|
PROXY_SRC="${SCRIPT_DIR}/mb5_pd_proxy.py"
|
||||||
|
|
||||||
CONFIG="${CONFIG:-8C}"
|
CONFIG="${CONFIG:-8C}"
|
||||||
RUN_LABEL="${RUN_LABEL:-default}"
|
RUN_LABEL="${RUN_LABEL:-default}"
|
||||||
@@ -44,10 +44,21 @@ BASE_BP=8998
|
|||||||
BASE_MASTER=29500
|
BASE_MASTER=29500
|
||||||
|
|
||||||
stop_all() {
|
stop_all() {
|
||||||
|
pkill -9 -f "mb5_pd_proxy.py" 2>/dev/null || true
|
||||||
pkill -9 -f "mooncake_connector_proxy.py" 2>/dev/null || true
|
pkill -9 -f "mooncake_connector_proxy.py" 2>/dev/null || true
|
||||||
pkill -9 -f "vllm serve" 2>/dev/null || true
|
pkill -9 -f "vllm serve" 2>/dev/null || true
|
||||||
pkill -9 -f "EngineCore" 2>/dev/null || true
|
pkill -9 -f "EngineCore" 2>/dev/null || true
|
||||||
sleep 3
|
sleep 3
|
||||||
|
# Hard guarantee: required ports must be free before we start. If they
|
||||||
|
# aren't, an earlier run left a stale process holding the socket and the
|
||||||
|
# readiness check would (silently) probe the stale proxy.
|
||||||
|
for port in 8000 8001 8002 8003 8004 8005 8006 8007 "${PROXY_PORT}"; do
|
||||||
|
if ss -ltn 2>/dev/null | awk '{print $4}' | grep -qE "[:.]${port}\$"; then
|
||||||
|
echo "[mb5] FATAL port ${port} still in use after stop_all; manual cleanup needed"
|
||||||
|
ss -ltnp 2>/dev/null | grep -E "[:.]${port}\$" || true
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
}
|
}
|
||||||
|
|
||||||
case "${1:-start}" in
|
case "${1:-start}" in
|
||||||
@@ -183,9 +194,11 @@ if [ "${ROLES}" = "pd" ]; then
|
|||||||
nohup python "${PROXY_SRC}" "${proxy_args[@]}" --port "${PROXY_PORT}" --host 0.0.0.0 \
|
nohup python "${PROXY_SRC}" "${proxy_args[@]}" --port "${PROXY_PORT}" --host 0.0.0.0 \
|
||||||
> "${LOGS_DIR}/proxy.log" 2>&1 &
|
> "${LOGS_DIR}/proxy.log" 2>&1 &
|
||||||
disown
|
disown
|
||||||
# wait for proxy
|
# wait for proxy. The proxy only registers POST /v1/completions and
|
||||||
|
# /v1/chat/completions, so /health and /v1/models return 404 — accept any
|
||||||
|
# HTTP response as "alive".
|
||||||
tries=0
|
tries=0
|
||||||
while ! curl -sf "http://127.0.0.1:${PROXY_PORT}/v1/models" >/dev/null 2>&1; do
|
# Use curl's exit status: without -f it exits 0 for any HTTP response
# (404 included) and non-zero when it cannot connect. Matching the
# -w "%{http_code}" output against ^[0-9] is wrong because curl prints
# "000" on connection failure, which made the probe pass immediately.
while ! curl -s -o /dev/null "http://127.0.0.1:${PROXY_PORT}/" 2>/dev/null; do
|
||||||
tries=$((tries+1))
|
tries=$((tries+1))
|
||||||
if [ ${tries} -gt 60 ]; then
|
if [ ${tries} -gt 60 ]; then
|
||||||
echo "[mb5] FATAL proxy did not come up in 2 min"
|
echo "[mb5] FATAL proxy did not come up in 2 min"
|
||||||
@@ -194,7 +207,7 @@ if [ "${ROLES}" = "pd" ]; then
|
|||||||
fi
|
fi
|
||||||
sleep 2
|
sleep 2
|
||||||
done
|
done
|
||||||
echo " proxy port=${PROXY_PORT} ready"
|
echo " proxy port=${PROXY_PORT} ready (HTTP responding)"
|
||||||
ENDPOINTS="http://127.0.0.1:${PROXY_PORT}"
|
ENDPOINTS="http://127.0.0.1:${PROXY_PORT}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
381
microbench/fresh_setup/mb5_pd_proxy.py
Normal file
381
microbench/fresh_setup/mb5_pd_proxy.py
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import ipaddress
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import urllib
|
||||||
|
import uuid
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_wrap_ipv6_address(address: str) -> str:
    """Return *address* in square brackets if it parses as an IPv6 literal.

    Any value that ``ipaddress.IPv6Address`` rejects (IPv4 literals, DNS
    names, hosts that are already bracketed) is returned unchanged.
    """
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        # Not an IPv6 literal: hand it back untouched.
        return address
    return "[" + address + "]"
|
||||||
|
|
||||||
|
|
||||||
|
def make_http_path(host: str, port: int) -> str:
    """Build the ``http://<host>:<port>`` base URL for *host* and *port*."""
    return "http://{}:{}".format(host, port)
|
||||||
|
|
||||||
|
|
||||||
|
def prefiller_cycle(prefill_clients: list[Any]):
    """Yield ``(prefill_client, dp_rank)`` pairs round-robin, forever.

    Each pass walks *prefill_clients* in list order and yields one pair for
    every rank in ``range(prefill_client["dp_size"])``. The list and each
    ``dp_size`` are re-read on every pass, so values filled in after the
    generator was created are picked up.

    Raises:
        RuntimeError: if a complete pass yields nothing (empty list, or every
            ``dp_size`` is 0). Without this check the loop would spin forever
            without yielding and block whoever called ``next()``.
    """
    while True:
        yielded_any = False
        for prefill_client in prefill_clients:
            for dp_rank in range(prefill_client["dp_size"]):
                yielded_any = True
                yield prefill_client, dp_rank
        if not yielded_any:
            raise RuntimeError("no prefill client with dp_size > 0 is available")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_prefiller_info(prefill_clients: list, ready: asyncio.Event):
    """Discover each prefiller's data-parallel layout, then set *ready*.

    For every entry of *prefill_clients*, in order:

    1. Poll ``GET /health`` on its ``"client"`` once per second until the
       request succeeds; any exception counts as "not ready yet".
    2. ``GET <bootstrap_addr>/query`` and read the JSON body, a mapping of
       dp-rank (string) to an entry that has an ``"engine_id"`` key. A
       failure here is not retried and propagates to the caller.
    3. Store the engine ids under ``"dp_engine_id"`` keyed by integer rank
       and the number of ranks under ``"dp_size"``.

    *ready* is set only after all entries have been processed.
    """
    for info in prefill_clients:
        healthy = False
        while not healthy:
            try:
                # Wait for prefill service to be ready
                health = await info["client"].get("/health")
                health.raise_for_status()
            except Exception:
                await asyncio.sleep(1)
            else:
                healthy = True

        query = await info["client"].get(info["bootstrap_addr"] + "/query")
        query.raise_for_status()
        ranks = query.json()

        for rank, entry in ranks.items():
            info["dp_engine_id"][int(rank)] = entry["engine_id"]
        dp_size = len(ranks)
        info["dp_size"] = dp_size
        print(f"Inited prefiller {info['url']} with dp_size={dp_size}")

    ready.set()
    print("All prefiller instances are ready.")
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the backend HTTP clients on startup and close them on shutdown.

    Startup stores on ``app.state``:

    * ``prefill_clients`` / ``decode_clients``: one dict per ``--prefill`` /
      ``--decode`` entry of the module-level ``global_args``, each holding an
      ``httpx.AsyncClient`` under ``"client"``. Prefill entries also carry
      ``"url"``, ``"bootstrap_addr"`` (a missing bootstrap port falls back to
      8998) and an initially empty ``"dp_engine_id"`` map.
    * ``ready``: the ``asyncio.Event`` handed to ``get_prefiller_info``.
    * ``prefiller_info_task``: the background task running
      ``get_prefiller_info``.
    * ``prefill_iterator`` / ``decode_iterator``: the round-robin iterators.

    Shutdown cancels the background task if it is still pending and closes
    every client, also when the application exits through an exception.
    """
    # ``import urllib`` alone does not guarantee that the ``urllib.parse``
    # submodule is loaded, so import the function explicitly.
    from urllib.parse import urlparse

    def _new_client(base_url: str) -> httpx.AsyncClient:
        # One client per backend; timeout and both pool limits are passed as
        # None, exactly as before.
        return httpx.AsyncClient(
            timeout=None,
            base_url=base_url,
            limits=httpx.Limits(
                max_connections=None,
                max_keepalive_connections=None,
            ),
        )

    # Startup: Initialize client pools for prefiller and decoder services
    app.state.prefill_clients = []
    app.state.decode_clients = []
    app.state.ready = asyncio.Event()

    # Create prefill clients
    for url, bootstrap_port in global_args.prefill:
        hostname = maybe_wrap_ipv6_address(urlparse(url).hostname)
        app.state.prefill_clients.append(
            {
                "client": _new_client(url),
                "url": url,
                "bootstrap_addr": make_http_path(hostname, bootstrap_port or 8998),
                "dp_engine_id": {},
            }
        )

    # Create decode clients
    for url in global_args.decode:
        app.state.decode_clients.append({"client": _new_client(url)})

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected before it ends.
    app.state.prefiller_info_task = asyncio.create_task(
        get_prefiller_info(app.state.prefill_clients, app.state.ready)
    )

    # Initialize round-robin iterators
    app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients)
    app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))

    print(
        f"Got {len(app.state.prefill_clients)} prefill clients "
        f"and {len(app.state.decode_clients)} decode clients."
    )

    try:
        yield
    finally:
        # Shutdown: stop the discovery task before closing the clients it uses.
        if not app.state.prefiller_info_task.done():
            app.state.prefiller_info_task.cancel()

        for client_info in app.state.prefill_clients:
            await client_info["client"].aclose()

        for client_info in app.state.decode_clients:
            await client_info["client"].aclose()
|
||||||
|
|
||||||
|
|
||||||
|
# Create the FastAPI app; the lifespan handler sets up and tears down the client pools
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(argv=None):
    """Parse the proxy's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the original zero-argument
            call did.

    Returns:
        The argparse namespace with ``port``, ``host``, the raw
        ``prefill_raw`` / ``decode_raw`` groups, plus ``prefill`` (list of
        ``(url, bootstrap_port)`` tuples) and ``decode`` (list of URLs).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
    parser.add_argument("--host", type=str, default="127.0.0.1")

    # For prefiller instances
    parser.add_argument(
        "--prefill",
        nargs="+",
        action="append",
        dest="prefill_raw",
        metavar=("URL", "bootstrap_port"),
        help=(
            "Prefill server URL and optional bootstrap port. "
            "Can be specified multiple times. "
            "Format: --prefill URL [BOOTSTRAP_PORT]. "
            "BOOTSTRAP_PORT can be a port number, "
            "'none', or omitted (same as 'none'; "
            "the proxy then falls back to port 8998)."
        ),
    )

    # For decoder instances
    parser.add_argument(
        "--decode",
        nargs=1,
        action="append",
        dest="decode_raw",
        metavar=("URL",),
        help="Decode server URL. Can be specified multiple times.",
    )

    args = parser.parse_args(argv)
    args.prefill = _parse_prefill_urls(args.prefill_raw)
    args.decode = _parse_decode_urls(args.decode_raw)

    return args
|
||||||
|
|
||||||
|
|
||||||
|
# From sglang router_args.py
|
||||||
|
def _parse_prefill_urls(prefill_list):
|
||||||
|
"""Parse prefill URLs from --prefill arguments.
|
||||||
|
|
||||||
|
Format: --prefill URL [BOOTSTRAP_PORT]
|
||||||
|
Example:
|
||||||
|
--prefill http://prefill1:8080 9000 # With bootstrap port
|
||||||
|
--prefill http://prefill2:8080 none # Explicitly no bootstrap port
|
||||||
|
--prefill http://prefill3:8080 # Defaults to no bootstrap port
|
||||||
|
"""
|
||||||
|
if not prefill_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
prefill_urls = []
|
||||||
|
for prefill_args in prefill_list:
|
||||||
|
url = prefill_args[0]
|
||||||
|
|
||||||
|
# Handle optional bootstrap port
|
||||||
|
if len(prefill_args) >= 2:
|
||||||
|
bootstrap_port_str = prefill_args[1]
|
||||||
|
# Handle 'none' as None
|
||||||
|
if bootstrap_port_str.lower() == "none":
|
||||||
|
bootstrap_port = None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
bootstrap_port = int(bootstrap_port_str)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" # noqa: E501
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
# No bootstrap port specified, default to None
|
||||||
|
bootstrap_port = None
|
||||||
|
|
||||||
|
prefill_urls.append((url, bootstrap_port))
|
||||||
|
|
||||||
|
return prefill_urls
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_decode_urls(decode_list):
|
||||||
|
"""Parse decode URLs from --decode arguments.
|
||||||
|
|
||||||
|
Format: --decode URL
|
||||||
|
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||||
|
"""
|
||||||
|
if not decode_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# decode_list is a list of single-element lists due to nargs=1
|
||||||
|
return [url[0] for url in decode_list]
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_client(app, service_type: str):
    """Pick the next backend for *service_type* in round-robin order.

    Args:
        app: The FastAPI app instance whose ``state`` holds the iterators.
        service_type: Either 'prefill' or 'decode'.

    Returns:
        For 'prefill', the next item of ``app.state.prefill_iterator``.
        For 'decode', the entry of ``app.state.decode_clients`` selected by
        the next index from ``app.state.decode_iterator``.

    Raises:
        ValueError: for any other *service_type*.
    """
    if service_type == "prefill":
        return next(app.state.prefill_iterator)
    if service_type == "decode":
        return app.state.decode_clients[next(app.state.decode_iterator)]
    raise ValueError(f"Unknown service type: {service_type}")
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request_to_service(
    client_info: dict, dp_rank: int, endpoint: str, req_data: dict, request_id: str
):
    """Run the prefill leg of a request on one prefill backend.

    A shallow copy of *req_data* is sent, so the caller's dict is left
    untouched. The copy is rewritten to a non-streaming, single-token request
    tagged for remote decode.

    Args:
        client_info: Dict whose ``"client"`` entry is the HTTP client to use.
        dp_rank: Data-parallel rank, sent as ``X-data-parallel-rank``.
        endpoint: Path to POST to on the backend.
        req_data: The client's original request body.
        request_id: Sent as ``X-Request-Id`` and used in the transfer id.

    Raises:
        Whatever ``raise_for_status()`` raises for an error response.
    """
    payload = req_data.copy()
    payload["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    payload["stream"] = False
    payload["max_tokens"] = 1
    # MB5 fix: clients (our replayer) may set min_tokens to enforce a fixed
    # output length. After the proxy caps max_tokens=1 on the prefill leg,
    # any min_tokens > 1 violates vLLM's `min_tokens <= max_tokens` check.
    if "min_tokens" in payload:
        payload["min_tokens"] = 1
    if "max_completion_tokens" in payload:
        payload["max_completion_tokens"] = 1
    payload.pop("stream_options", None)

    headers = {
        "X-Request-Id": request_id,
        "X-data-parallel-rank": str(dp_rank),
    }
    # Only send credentials when a key is configured; formatting an unset
    # variable would otherwise produce the literal header "Bearer None".
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    response = await client_info["client"].post(
        endpoint, json=payload, headers=headers
    )
    try:
        response.raise_for_status()
    finally:
        # Release the connection back to the pool even on an error status.
        await response.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_service_response(
    prefill_client_info: dict,
    prefill_dp_rank: int,
    decode_client_info: dict,
    endpoint: str,
    req_data: dict,
    request_id: str,
):
    """Stream the decode leg of a request, yielding raw response chunks.

    The request sent to the decode backend is a shallow copy of *req_data*
    with ``kv_transfer_params`` pointing at the prefiller that handles the
    same ``request_id``. Copying matters because the prefill leg is given
    the same dict and must not see the decode-side parameters.

    Args:
        prefill_client_info: Prefill entry providing ``"bootstrap_addr"`` and
            the ``"dp_engine_id"`` map.
        prefill_dp_rank: Key into ``"dp_engine_id"`` for the chosen prefiller.
        decode_client_info: Dict whose ``"client"`` entry is the HTTP client.
        endpoint: Path to POST to on the decode backend.
        req_data: The client's original request body (not modified).
        request_id: Sent as ``X-Request-Id`` and used in the transfer id.

    Yields:
        Byte chunks from the decode backend's response body.

    Raises:
        Whatever ``raise_for_status()`` raises for an error response.
    """
    headers = {"X-Request-Id": request_id}
    # Only send credentials when a key is configured; formatting an unset
    # variable would otherwise produce the literal header "Bearer None".
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = req_data.copy()
    payload["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": prefill_client_info["bootstrap_addr"],
        "remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank],
        "transfer_id": f"xfer-{request_id}",
    }

    async with decode_client_info["client"].stream(
        "POST", endpoint, json=payload, headers=headers
    ) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_completions(api: str, request: Request):
    """Proxy one completion request: prefill on one backend, decode on another.

    The prefill leg is started as a background task and is not awaited; the
    decode leg is returned as a streaming response. Both legs share the same
    freshly generated request id.

    Args:
        api: Backend path to forward to (also used in the error log line).
        request: The incoming request; its JSON body is forwarded.

    Returns:
        A ``StreamingResponse`` relaying the decode backend's bytes.

    Raises:
        HTTPException: 503 while the ``ready`` event on the app state is not
            set.
        Exception: anything raised while dispatching is printed together with
            its traceback and re-raised.
    """
    state = request.app.state
    if not state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info, prefill_dp_rank = get_next_client(request.app, "prefill")

        # Send request to prefill service
        prefill_task = asyncio.create_task(
            send_request_to_service(
                prefill_client_info, prefill_dp_rank, api, req_data, request_id
            )
        )
        # Keep a strong reference until the task finishes: the event loop
        # only holds weak references, so a fire-and-forget task may be
        # garbage-collected mid-flight.
        pending = getattr(state, "prefill_tasks", None)
        if pending is None:
            pending = state.prefill_tasks = set()
        pending.add(prefill_task)
        prefill_task.add_done_callback(pending.discard)

        decode_client_info = get_next_client(request.app, "decode")

        # Stream response from decode service
        decode_stream = stream_service_response(
            prefill_client_info,
            prefill_dp_rank,
            decode_client_info,
            api,
            req_data,
            request_id=request_id,
        )
        return StreamingResponse(decode_stream, media_type="application/json")

    except Exception as e:
        import traceback

        print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
        print(e)
        print(traceback.format_exc())
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Serve ``POST /v1/completions`` through the shared proxy handler."""
    response = await _handle_completions("/v1/completions", request)
    return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    """Serve ``POST /v1/chat/completions`` through the shared proxy handler."""
    response = await _handle_completions("/v1/chat/completions", request)
    return response
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point. ``global_args`` is a module-level name (no
    # ``global`` statement is needed at this scope); lifespan() reads it
    # when the server starts.
    global_args = parse_args()

    import uvicorn

    uvicorn.run(
        app,
        host=global_args.host,
        port=global_args.port,
    )
|
||||||
@@ -70,6 +70,29 @@ def plot_one_instance(snaps: list[dict], out: Path, title: str) -> None:
|
|||||||
# Sort by first-seen time so the band order follows arrival
|
# Sort by first-seen time so the band order follows arrival
|
||||||
all_req_ids.sort(key=lambda r: req_first_seen[r])
|
all_req_ids.sort(key=lambda r: req_first_seen[r])
|
||||||
|
|
||||||
|
if not all_req_ids:
|
||||||
|
# No requests ever ran on this instance; plot a flat used_blocks line
|
||||||
|
# instead of the stackplot (which can't handle empty input).
|
||||||
|
fig, ax1 = plt.subplots(figsize=(13, 4))
|
||||||
|
used = [s["used_blocks"] for s in snaps]
|
||||||
|
ax1.plot(times, used, color="#888", lw=1.5, label="used_blocks (no running reqs sampled)")
|
||||||
|
ax1.axhline(total_blocks, color="#444", lw=1.5, ls="-",
|
||||||
|
label=f"pool total = {total_blocks} blocks")
|
||||||
|
ax1.axhline(total_blocks * 0.9, color="#c44e52", lw=1.2, ls="--", alpha=0.7,
|
||||||
|
label="90% capacity")
|
||||||
|
ax1.set_ylabel("KV blocks")
|
||||||
|
ax1.set_ylim(0, total_blocks * 1.05)
|
||||||
|
ax1.set_xlabel("wall-clock since first snapshot (s)")
|
||||||
|
ax1.set_title(title + " [no per-request data; instance idle?]")
|
||||||
|
ax1.legend(loc="upper right", fontsize=9)
|
||||||
|
ax1.grid(True, alpha=0.3)
|
||||||
|
out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(out, dpi=120)
|
||||||
|
plt.close(fig)
|
||||||
|
print(f"wrote {out} (n_snapshots={len(snaps)}, 0 running reqs ever)")
|
||||||
|
return
|
||||||
|
|
||||||
matrix = np.zeros((len(all_req_ids), len(times)), dtype=np.int64)
|
matrix = np.zeros((len(all_req_ids), len(times)), dtype=np.int64)
|
||||||
req_to_row = {r: i for i, r in enumerate(all_req_ids)}
|
req_to_row = {r: i for i, r in enumerate(all_req_ids)}
|
||||||
for j, s in enumerate(snaps):
|
for j, s in enumerate(snaps):
|
||||||
|
|||||||
Reference in New Issue
Block a user