Files
agentic-pd-hybrid/src/agentic_pd_hybrid/pd_router.py
kzlin c9d350b372 docs: KVC v1-v4 debug journey + raise session soft_cap to 16
Document the iterative debugging from v1 (broken KVC) through v4
(routing fixed + session cap raised), with code-level analysis of
the two main bugs encountered:

1. v2 root cause (previously misdiagnosed as `allow_local_prefill`):
   `--policy default` for KVC mechanism caused replay's round-robin
   policy and the PD router's round-robin to diverge, sending requests
   with `session_params` to a D worker that did not have the session
   open. This resulted in a 56-61% truncation rate, with finish_reason
   "session id X does not exist".
   Fix: use `--policy kv-aware` (sweep_tp1_v3_kvaware.sh) so replay
   emits `x-smg-target-worker` and PD router uses consistent_hashing.

2. v3 new bottleneck: `pd-router-fallback-large-append-session-cap`
   accounted for 52-65% of requests. The root cause was a hardcoded
   `min(4, ...)` in `_decode_session_soft_cap`. With 7 D workers x 4
   sessions = 28 slots for 52 trace sessions, ~24 sessions starved
   permanently (bimodal direct-to-D rate of 0% or 99%).
   Fix: raise the cap to 16 (replay.py).

Also includes the v3 finding that the direct-to-d-session path (P50=0.495s,
TTFT P50=0.043s) already beats the 8-way DP baseline (0.65s/0.093s):
the KVC core mechanism works when fallback paths are avoided.

Files:
- docs/KVC_DEBUG_JOURNEY_V1_TO_V4.md: full journey + code location index
- docs/SWEBENCH_EXPERIMENT_{PROGRESS,RESULTS}.md: prior session notes
- scripts/sweep_tp1_v{2,3,4}*.sh: experiment driver scripts
- src/agentic_pd_hybrid/replay.py: cap 4 -> 16, audit fields
- src/agentic_pd_hybrid/pd_router.py: strip session_params from prefill
- src/agentic_pd_hybrid/metrics.py: truncated_request_count

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 21:10:41 +08:00

467 lines
15 KiB
Python

from __future__ import annotations
import argparse
import asyncio
import random
import urllib.parse
from dataclasses import dataclass
from http import HTTPStatus
from itertools import chain
from typing import AsyncIterator
import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
_STREAM_CHUNK_SIZE = 1024 * 64
@dataclass()
class RouterConfig:
    """Static settings for PD (prefill/decode disaggregated) routing mode.

    Instances are built from the command line by ``main``.
    """

    # Listen address given on the command line.
    host: str
    port: int
    # One ``(prefill_url, bootstrap_port)`` pair per prefill worker.
    prefill_urls: list[tuple[str, int]]
    # Base URL of each decode worker.
    decode_urls: list[str]
    # Stored for reference; prefill selection is round-robin regardless.
    prefill_policy: str = "round_robin"
    # "manual", "consistent_hashing", or anything else for round-robin.
    decode_policy: str = "manual"
    # ``total`` timeout (seconds) for the aiohttp session of one request.
    request_timeout_s: float = 1800.0
class RouterState:
    """Mutable worker-selection state for PD (prefill/decode) routing mode.

    Prefill workers are always picked round-robin: ``config.prefill_policy``
    is stored on the config but never consulted here. Decode workers are
    picked according to ``config.decode_policy`` (see
    ``_select_decode_index``).

    Raises:
        ValueError: if the config lists no prefill or no decode workers.
    """

    def __init__(self, config: RouterConfig):
        if not config.prefill_urls:
            raise ValueError("At least one prefill worker is required")
        if not config.decode_urls:
            raise ValueError("At least one decode worker is required")
        self.config = config
        # Monotonic counters; the modulo is applied at selection time.
        self.prefill_cursor = 0
        self.decode_cursor = 0
        # routing key -> decode index, filled only under the "manual" policy.
        # Entries are never evicted, so this grows with the number of
        # distinct routing keys seen.
        self.sticky_decode_map: dict[str, int] = {}

    def select_pair(self, headers: dict[str, str]) -> tuple[str, int, str]:
        """Choose the workers that will serve one request.

        Args:
            headers: Request headers with lower-cased names.

        Returns:
            ``(prefill_url, bootstrap_port, decode_url)``.
        """
        prefill_url, bootstrap_port = self.config.prefill_urls[
            self.prefill_cursor % len(self.config.prefill_urls)
        ]
        self.prefill_cursor += 1
        decode_index = self._select_decode_index(headers)
        return prefill_url, bootstrap_port, self.config.decode_urls[decode_index]

    def _select_decode_index(self, headers: dict[str, str]) -> int:
        """Return the index into ``config.decode_urls`` for one request.

        * ``consistent_hashing`` with an ``x-smg-target-worker`` header that
          parses as an in-range integer: use that index as-is.
        * ``manual`` with a non-empty ``x-smg-routing-key`` header: reuse
          the index already bound to that key, or bind the next round-robin
          index to it.
        * Anything else, including a malformed or out-of-range target
          index: plain round-robin.
        """
        target_worker = headers.get("x-smg-target-worker")
        routing_key = headers.get("x-smg-routing-key")
        if (
            self.config.decode_policy == "consistent_hashing"
            and target_worker is not None
        ):
            try:
                idx = int(target_worker)
            except ValueError:
                # A non-numeric header must not fail the request; treat it
                # like an out-of-range index and fall back to round-robin.
                idx = -1
            if 0 <= idx < len(self.config.decode_urls):
                return idx
        if self.config.decode_policy == "manual" and routing_key:
            cached = self.sticky_decode_map.get(routing_key)
            if cached is not None:
                return cached
            idx = self.decode_cursor % len(self.config.decode_urls)
            self.decode_cursor += 1
            self.sticky_decode_map[routing_key] = idx
            return idx
        idx = self.decode_cursor % len(self.config.decode_urls)
        self.decode_cursor += 1
        return idx
@dataclass()
class DpRouterConfig:
    """Static settings for DP (data-parallel) routing mode.

    Instances are built from the command line by ``main``.
    """

    # Listen address given on the command line.
    host: str
    port: int
    # Base URL of each backend worker.
    backend_urls: list[str]
    # "manual", "consistent_hashing", or anything else for round-robin.
    backend_policy: str = "round_robin"
    # ``total`` timeout (seconds) for the aiohttp session of one request.
    request_timeout_s: float = 1800.0
class DpRouterState:
    """DP (data-parallel) router: forward each request to exactly one backend.

    The backend is picked according to ``config.backend_policy`` (see
    ``_select_index``).

    Raises:
        ValueError: if the config lists no backend workers.
    """

    def __init__(self, config: DpRouterConfig):
        if not config.backend_urls:
            raise ValueError("At least one backend worker is required")
        self.config = config
        # Monotonic round-robin counter; the modulo is applied at selection.
        self.cursor = 0
        # routing key -> backend index, filled only under the "manual"
        # policy. Entries are never evicted.
        self.sticky_map: dict[str, int] = {}

    def select_backend(self, headers: dict[str, str]) -> str:
        """Return the backend URL for one request (headers lower-cased)."""
        idx = self._select_index(headers)
        return self.config.backend_urls[idx]

    def _select_index(self, headers: dict[str, str]) -> int:
        """Return the index into ``config.backend_urls`` for one request.

        * ``consistent_hashing`` with an ``x-smg-target-worker`` header that
          parses as an in-range integer: use that index as-is.
        * ``manual`` with a non-empty ``x-smg-routing-key`` header: reuse
          the index already bound to that key, or bind the next round-robin
          index to it.
        * Anything else, including a malformed or out-of-range target
          index: plain round-robin.
        """
        target_worker = headers.get("x-smg-target-worker")
        routing_key = headers.get("x-smg-routing-key")
        if (
            self.config.backend_policy == "consistent_hashing"
            and target_worker is not None
        ):
            try:
                idx = int(target_worker)
            except ValueError:
                # A non-numeric header must not fail the request; treat it
                # like an out-of-range index and fall back to round-robin.
                idx = -1
            if 0 <= idx < len(self.config.backend_urls):
                return idx
        if self.config.backend_policy == "manual" and routing_key:
            cached = self.sticky_map.get(routing_key)
            if cached is not None:
                return cached
            idx = self.cursor % len(self.config.backend_urls)
            self.cursor += 1
            self.sticky_map[routing_key] = idx
            return idx
        idx = self.cursor % len(self.config.backend_urls)
        self.cursor += 1
        return idx
# Routing state installed by main(): DP mode sets ``dp_state``, PD mode sets
# ``router_state``; the other one stays None.
router_state: RouterState | None = None
dp_state: DpRouterState | None = None

# ASGI application served by uvicorn in main().
app = FastAPI()
@app.get("/health")
async def health() -> Response:
    """Report that the router process is up; no worker is contacted."""
    ok_response = Response(status_code=200)
    return ok_response
@app.get("/health_generate")
async def health_generate() -> Response:
    """Probe ``/health_generate`` on every worker behind this router.

    DP mode probes every backend URL; PD mode probes every prefill URL and
    every decode URL. All probes run concurrently and every one of them is
    awaited before the client session closes.

    Returns:
        HTTP 200 if every worker answered with HTTP 200, otherwise HTTP 503
        (some worker returned a non-200 status or could not be reached).

    Raises:
        HTTPException: 500 (via ``_require_state``) if neither DP nor PD
            state has been initialized.
    """
    if dp_state is not None:
        servers = list(dp_state.config.backend_urls)
    else:
        state = _require_state()
        servers = list(
            chain(
                (url for url, _ in state.config.prefill_urls),
                state.config.decode_urls,
            )
        )

    async def _probe(session: aiohttp.ClientSession, server: str) -> int:
        # Only the status matters; the body is released on context exit.
        async with session.get(f"{server}/health_generate") as response:
            return response.status

    async with aiohttp.ClientSession() as session:
        # return_exceptions=True so one unreachable worker does not leave
        # the remaining probes pending when the session closes.
        results = await asyncio.gather(
            *(_probe(session, server) for server in servers),
            return_exceptions=True,
        )

    healthy = all(result == HTTPStatus.OK for result in results)
    return Response(status_code=200 if healthy else 503)
@app.get("/v1/models")
async def models() -> ORJSONResponse:
    """Proxy ``GET /v1/models`` from the first configured worker.

    In DP mode that is the first backend URL; in PD mode it is the first
    prefill worker. The upstream JSON body and status code are relayed as-is.
    """
    if dp_state is not None:
        source_url = dp_state.config.backend_urls[0]
    else:
        source_url = _require_state().config.prefill_urls[0][0]
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{source_url}/v1/models") as upstream:
            listing = await upstream.json()
    return ORJSONResponse(listing, status_code=upstream.status)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Response:
    """Relay ``POST /v1/chat/completions`` through ``_forward_to_backend``.

    Header names are lower-cased because the routing helpers look them up
    by lower-case name.
    """
    body = await request.json()
    lowered = {name.lower(): val for name, val in request.headers.items()}
    return await _forward_to_backend(
        endpoint_name="v1/chat/completions",
        headers=lowered,
        request_data=body,
    )
@app.post("/v1/completions")
async def completions(request: Request) -> Response:
    """Relay ``POST /v1/completions`` through ``_forward_to_backend``.

    Header names are lower-cased because the routing helpers look them up
    by lower-case name.
    """
    body = await request.json()
    lowered = {name.lower(): val for name, val in request.headers.items()}
    return await _forward_to_backend(
        endpoint_name="v1/completions",
        headers=lowered,
        request_data=body,
    )
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Relay ``POST /generate`` through ``_forward_to_backend``.

    Header names are lower-cased because the routing helpers look them up
    by lower-case name.
    """
    body = await request.json()
    lowered = {name.lower(): val for name, val in request.headers.items()}
    return await _forward_to_backend(
        endpoint_name="generate",
        headers=lowered,
        request_data=body,
    )
async def _forward_to_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Route one client request and relay the worker's reply.

    DP mode delegates to ``_forward_to_dp_backend``. PD mode picks a
    prefill/decode pair and derives one payload per worker; both payloads
    carry the same bootstrap payload. It then sends the two payloads
    concurrently and returns only the decode worker's reply. The prefill
    reply is read and discarded.

    Args:
        request_data: Parsed JSON body of the client request.
        headers: Client headers with lower-cased names (used for routing).
        endpoint_name: Backend path without a leading slash, e.g. ``"generate"``.
    """
    if dp_state is not None:
        return await _forward_to_dp_backend(
            request_data=request_data,
            headers=headers,
            endpoint_name=endpoint_name,
        )

    state = _require_state()
    prefill_server, bootstrap_port, decode_server = state.select_pair(headers)
    prefill_request, decode_request = _build_backend_requests(
        request_data=request_data,
        prefill_server=prefill_server,
        bootstrap_port=bootstrap_port,
    )
    timeout_s = state.config.request_timeout_s

    if not request_data.get("stream", False):
        client_timeout = aiohttp.ClientTimeout(total=timeout_s)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            # Both requests are in flight at the same time.
            prefill_response, decode_response = await asyncio.gather(
                session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
                session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
            )
            async with prefill_response:
                await prefill_response.read()
            async with decode_response:
                decoded_body = await decode_response.read()
            return Response(
                content=decoded_body,
                status_code=decode_response.status,
                media_type=decode_response.content_type,
            )

    return StreamingResponse(
        _stream_generate(
            prefill_request=prefill_request,
            decode_request=decode_request,
            prefill_server=prefill_server,
            decode_server=decode_server,
            endpoint_name=endpoint_name,
            timeout_s=timeout_s,
        ),
        media_type="text/event-stream",
    )
async def _forward_to_dp_backend(
    *,
    request_data: dict,
    headers: dict[str, str],
    endpoint_name: str,
) -> Response:
    """Relay one request to the single DP backend chosen by ``dp_state``.

    The router-only priority fields are removed before forwarding. Requests
    with a truthy ``"stream"`` field are relayed chunk by chunk. For all
    other requests, the backend's body, status and content type are copied
    into the returned response.
    """
    assert dp_state is not None
    target = dp_state.select_backend(headers)
    outgoing = _strip_internal_fields(request_data)
    timeout_s = dp_state.config.request_timeout_s

    if not request_data.get("stream", False):
        client_timeout = aiohttp.ClientTimeout(total=timeout_s)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                f"{target}/{endpoint_name}", json=outgoing
            ) as upstream:
                content = await upstream.read()
            return Response(
                content=content,
                status_code=upstream.status,
                media_type=upstream.content_type,
            )

    return StreamingResponse(
        _stream_dp_generate(
            request_data=outgoing,
            backend_server=target,
            endpoint_name=endpoint_name,
            timeout_s=timeout_s,
        ),
        media_type="text/event-stream",
    )
async def _stream_dp_generate(
    *,
    request_data: dict,
    backend_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Yield a DP backend's streamed reply in chunks of at most 64 KiB.

    If the backend does not answer 200, its full body is yielded once and
    the generator stops.
    NOTE(review): the enclosing StreamingResponse has presumably already
    chosen its own status by then, so such an error reaches the client only
    as body bytes — confirm.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    target_url = f"{backend_server}/{endpoint_name}"
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(target_url, json=request_data) as upstream:
            if upstream.status == HTTPStatus.OK:
                async for piece in upstream.content.iter_chunked(_STREAM_CHUNK_SIZE):
                    yield piece
            else:
                yield await upstream.read()
async def _stream_generate(
    *,
    prefill_request: dict,
    decode_request: dict,
    prefill_server: str,
    decode_server: str,
    endpoint_name: str,
    timeout_s: float,
) -> AsyncIterator[bytes]:
    """Send the PD request pair concurrently and stream the decode reply.

    The prefill response is held open until the stream ends but its body is
    never read. If the decode worker does not answer 200, its full body is
    yielded once and the generator stops.
    NOTE(review): the enclosing StreamingResponse has presumably already
    chosen its own status by then — confirm.
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        responses = await asyncio.gather(
            session.post(f"{prefill_server}/{endpoint_name}", json=prefill_request),
            session.post(f"{decode_server}/{endpoint_name}", json=decode_request),
        )
        prefill_response, decode_response = responses
        async with prefill_response, decode_response:
            if decode_response.status == HTTPStatus.OK:
                async for piece in decode_response.content.iter_chunked(_STREAM_CHUNK_SIZE):
                    yield piece
            else:
                yield await decode_response.read()
def _build_bootstrap_payload(prefill_server: str, bootstrap_port: int) -> dict[str, object]:
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
if hostname is None:
raise HTTPException(
status_code=500,
detail=f"Unable to parse prefill hostname from {prefill_server}",
)
return {
"bootstrap_host": hostname,
"bootstrap_port": bootstrap_port,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
def _build_backend_requests(
    *,
    request_data: dict,
    prefill_server: str,
    bootstrap_port: int,
) -> tuple[dict, dict]:
    """Derive the prefill-side and decode-side payloads from one request.

    Both payloads are shallow copies of *request_data*. The router-only
    ``smg_*_priority`` keys are removed, and both payloads get the same
    bootstrap payload (host, port, random room). Where given,
    ``smg_prefill_priority`` / ``smg_decode_priority`` become an integer
    ``priority`` on the matching payload. ``session_params`` is kept only
    on the decode payload.

    Returns:
        ``(prefill_request, decode_request)``.
    """
    prefill_request = _strip_internal_fields(request_data)
    decode_request = _strip_internal_fields(request_data)
    shared_bootstrap = _build_bootstrap_payload(prefill_server, bootstrap_port)
    for outgoing in (prefill_request, decode_request):
        outgoing.update(shared_bootstrap)
    # Author's rationale: session_params is only meaningful for the decode
    # worker (streaming-session KV reuse). Sending it to the prefill worker
    # causes the D side to short-circuit with local prefill on already-open
    # sessions, returning truncated responses while P's KV transfer is
    # aborted.
    prefill_request.pop("session_params", None)
    for outgoing, priority_key in (
        (prefill_request, "smg_prefill_priority"),
        (decode_request, "smg_decode_priority"),
    ):
        priority = request_data.get(priority_key)
        if priority is not None:
            outgoing["priority"] = int(priority)
    return prefill_request, decode_request
def _strip_internal_fields(request_data: dict) -> dict:
cleaned = request_data.copy()
cleaned.pop("smg_prefill_priority", None)
cleaned.pop("smg_decode_priority", None)
return cleaned
def _require_state() -> RouterState:
if router_state is None:
raise HTTPException(status_code=500, detail="router not initialized")
return router_state
def main() -> None:
    """Parse command-line arguments, build the routing state, and serve.

    ``--backend`` (repeatable) selects DP mode and takes precedence over the
    PD flags. Otherwise PD mode needs at least one
    ``--prefill URL BOOTSTRAP_PORT`` and at least one ``--decode URL``. The
    chosen state is stored in the module-level ``dp_state`` /
    ``router_state`` before uvicorn starts serving ``app``.
    """
    global router_state, dp_state

    parser = argparse.ArgumentParser(description="Minimal local PD / DP router")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--prefill",
        nargs=2,
        metavar=("URL", "BOOTSTRAP_PORT"),
        action="append",
        default=None,
        help="Prefill worker URL and its bootstrap port (PD mode). Repeat for each worker.",
    )
    parser.add_argument(
        "--decode",
        action="append",
        default=None,
        help="Decode worker URL (PD mode). Repeat for each worker.",
    )
    parser.add_argument(
        "--prefill-policy",
        default="round_robin",
        help="Stored in the router config; prefill workers are currently always chosen round-robin.",
    )
    parser.add_argument(
        "--decode-policy",
        default="manual",
        help=(
            "Decode routing policy for PD mode: manual (sticky per "
            "x-smg-routing-key header) or consistent_hashing (index from "
            "x-smg-target-worker header); other values, or a missing "
            "header, mean round-robin."
        ),
    )
    parser.add_argument(
        "--backend",
        action="append",
        default=None,
        help="Backend URL for DP (data-parallel) mode. Repeat for each worker.",
    )
    parser.add_argument(
        "--backend-policy",
        default="round_robin",
        help="Routing policy for DP mode: round_robin, manual, consistent_hashing.",
    )
    parser.add_argument("--request-timeout-s", type=float, default=1800.0)
    args = parser.parse_args()

    if args.backend:
        # DP mode: simple forward to one of N backends
        dp_state = DpRouterState(
            DpRouterConfig(
                host=args.host,
                port=args.port,
                backend_urls=list(args.backend),
                backend_policy=args.backend_policy,
                request_timeout_s=args.request_timeout_s,
            )
        )
    elif args.prefill and args.decode:
        # PD mode: prefill/decode coordination
        prefill_urls: list[tuple[str, int]] = []
        for url, port in args.prefill:
            try:
                prefill_urls.append((url, int(port)))
            except ValueError:
                # Report bad input as a usage error instead of a traceback.
                parser.error(
                    f"invalid BOOTSTRAP_PORT {port!r} for --prefill {url}: expected an integer"
                )
        router_state = RouterState(
            RouterConfig(
                host=args.host,
                port=args.port,
                prefill_urls=prefill_urls,
                decode_urls=list(args.decode),
                prefill_policy=args.prefill_policy,
                decode_policy=args.decode_policy,
                request_timeout_s=args.request_timeout_s,
            )
        )
    else:
        parser.error("Either --backend (DP mode) or both --prefill and --decode (PD mode) are required")
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()