Fix A+C: real cache sync + cached-prefill-on-C architecture

A: Add /estimate_hit endpoint to bootstrap server for real-time cache
   probing. Proxy queries this before committing to PUSH, eliminating
   24% zero-match PUSH requests (shadow cache divergence).

C: Add _handle_cached_prefill_offload: C (cache source) does fast
   cached prefill → KV to Mooncake → D pulls and decodes.
   Replaces broken direct_read PUSH where D waited for RDMA transfer
   while occupying KV blocks without doing compute.

Also: update §3.9 baseline to plain vLLM with full mean/p50/p90/p99.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 11:22:38 +08:00
parent 2b9eae0d54
commit cdf83493ab
3 changed files with 252 additions and 59 deletions

View File

@@ -159,6 +159,44 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
return instances[best_idx], best_idx
# Shared HTTP client for bootstrap-server calls. Stays None until
# _get_bootstrap_client() builds it on first use.
_bootstrap_client: httpx.AsyncClient | None = None
# Seconds; passed to httpx.Timeout when the shared bootstrap client is built.
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls
async def _get_bootstrap_client() -> httpx.AsyncClient:
    """Return the module-wide bootstrap HTTP client, building it on first use.

    The client is stored in the ``_bootstrap_client`` global so later calls
    reuse the same connection pool.
    """
    global _bootstrap_client
    if _bootstrap_client is not None:
        return _bootstrap_client

    request_timeout = httpx.Timeout(BOOTSTRAP_TIMEOUT_S)
    pool_limits = httpx.Limits(max_connections=32, max_keepalive_connections=16)
    _bootstrap_client = httpx.AsyncClient(
        timeout=request_timeout,
        limits=pool_limits,
    )
    return _bootstrap_client
async def _query_bootstrap_hit(
    inst: InstanceState, token_ids: list[int],
) -> int | None:
    """Ask an instance's bootstrap server how many prompt tokens are cached.

    POSTs ``token_ids`` and ``BLOCK_SIZE`` to ``/estimate_hit`` on the host
    taken from ``inst.client.base_url`` and the port ``inst.bootstrap_port``.

    Args:
        inst: Instance to probe; must expose ``bootstrap_port`` and
            ``client.base_url``.
        token_ids: Token ids of the prompt being routed.

    Returns:
        The ``hit_tokens`` integer from the response, or ``None`` when the
        instance has no bootstrap port, the host cannot be determined, the
        request fails, or the response is malformed. Callers should treat
        ``None`` as "unknown" and fall back to their own estimate.
    """
    if inst.bootstrap_port is None:
        return None

    parsed = urllib.parse.urlparse(str(inst.client.base_url))
    if not parsed.hostname:
        # Without a host the request could only fail; skip the round trip.
        return None
    url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
    payload = {
        "token_ids": token_ids,
        "block_size": BLOCK_SIZE,
    }

    try:
        client = await _get_bootstrap_client()
        resp = await client.post(url, json=payload)
        resp.raise_for_status()
        hit_tokens = resp.json()["hit_tokens"]
    except Exception:
        # Deliberately best-effort: the probe is advisory, so any transport,
        # HTTP-status or decoding failure degrades to "unknown" instead of
        # failing the request being routed.
        return None

    # Honour the declared ``int | None`` contract: a null/str/list payload
    # would otherwise blow up later in the caller's arithmetic. ``bool`` is
    # rejected explicitly because it is a subclass of ``int``.
    if isinstance(hit_tokens, bool) or not isinstance(hit_tokens, int):
        return None
    return hit_tokens
# Starts as None; presumably replaced with the parsed launch arguments at
# startup — TODO confirm (the assignment is not visible in this chunk).
global_args = None
# Module-level instance registries, initially empty. They are iterated at
# shutdown to close each instance's HTTP client; where they are populated is
# not visible in this chunk.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
@@ -286,6 +324,8 @@ async def lifespan(app: FastAPI):
await reconcile_task
except asyncio.CancelledError:
pass
if _bootstrap_client is not None:
await _bootstrap_client.aclose()
for inst in combined_instances + prefill_instances + decode_instances:
await inst.client.aclose()
@@ -397,66 +437,171 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
if best_needs_push:
c_inst = combined_instances[best_cache_idx]
d_inst = chosen
push_cache_hit = best_cache_hit
push_new = max(0, input_length - push_cache_hit)
d_inst.ongoing_tokens += input_length
d_inst.pending_prefill_tokens += push_new
d_inst.num_requests += 1
c_inst.active_p_offloads += 1
# Query real cache hit from bootstrap (shadow cache is inaccurate)
real_hit = await _query_bootstrap_hit(c_inst, token_ids)
breakdown["shadow_cache_hit"] = best_cache_hit
breakdown["real_cache_hit"] = real_hit
breakdown["route_class"] = "PUSH_MIGRATE"
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["push_cache_hit"] = push_cache_hit
if real_hit is not None:
push_cache_hit = real_hit
else:
push_cache_hit = best_cache_hit # fallback to shadow estimate
return await _handle_direct_read_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, push_cache_hit, push_new, breakdown)
else:
breakdown["route_class"] = "LOCAL"
breakdown["routed_to"] = chosen.url
# If real hit > 0, proceed with cached prefill on C → decode on D
if push_cache_hit > 0:
push_new = max(0, input_length - push_cache_hit)
chosen.ongoing_tokens += input_length
chosen.pending_prefill_tokens += estimated_new
chosen.num_requests += 1
c_inst.ongoing_tokens += input_length
c_inst.pending_prefill_tokens += push_new
c_inst.num_requests += 1
c_inst.active_p_offloads += 1
async def generate():
prefill_done = False
try:
for attempt in range(MAX_STREAM_RETRIES):
try:
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
chosen.record_prefix(token_ids)
break
except (httpx.ConnectError, httpx.RemoteProtocolError):
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
raise
await asyncio.sleep(RETRY_DELAY_S)
finally:
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
else:
chosen.ongoing_decode_tokens -= input_length
chosen.ongoing_tokens -= input_length
chosen.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["push_cache_hit"] = push_cache_hit
return StreamingResponse(generate(), media_type="text/event-stream")
return await _handle_cached_prefill_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, push_cache_hit, push_new, breakdown)
# Real hit is 0 — downgrade to LOCAL
breakdown["push_downgraded"] = True
# LOCAL path (also handles downgraded PUSH)
breakdown["route_class"] = "LOCAL"
breakdown["routed_to"] = chosen.url
chosen.ongoing_tokens += input_length
chosen.pending_prefill_tokens += estimated_new
chosen.num_requests += 1
async def generate():
prefill_done = False
try:
for attempt in range(MAX_STREAM_RETRIES):
try:
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
chosen.record_prefix(token_ids)
break
except (httpx.ConnectError, httpx.RemoteProtocolError):
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
raise
await asyncio.sleep(RETRY_DELAY_S)
finally:
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
else:
chosen.ongoing_decode_tokens -= input_length
chosen.ongoing_tokens -= input_length
chosen.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
# NOTE(review): this constant is not referenced by the handlers visible in
# this chunk — confirm where the limit is actually enforced.
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
                                         input_length, c_inst, d_inst,
                                         cache_hit, estimated_new, breakdown):
    """C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.

    Unlike direct_read (D pulls blocks from C), here C's scheduler IS
    involved: C prefills (fast, because prefix is cached), pushes KV to
    Mooncake store, then D pulls and decodes. This avoids the broken
    PUSH path where D waits for RDMA transfer while occupying KV blocks.

    Args:
        api: Request path forwarded to both instances.
        req_data: Original request body; shallow-copied, never mutated.
        headers: Incoming request headers (a mapping). ``X-Request-Id`` is
            used to derive the transfer id shared by both legs.
        token_ids: Prompt token ids, recorded as a prefix on C after a
            successful prefill and on D after the stream completes.
        input_length: Prompt length in tokens, used for load accounting.
        c_inst: Cache-source instance that runs the prefill. This function
            releases ``ongoing_tokens``, ``pending_prefill_tokens``,
            ``num_requests`` and ``active_p_offloads`` on it, so the caller
            is expected to have charged them beforehand.
        d_inst: Instance that pulls the KV and streams the decode.
        cache_hit: Accepted for signature parity; not read by this function.
        estimated_new: Amount released from C's ``pending_prefill_tokens``.
        breakdown: Per-request timing/route dict; appended to
            ``_breakdown_log`` exactly once (on prefill failure, or when the
            decode stream finishes).

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) relaying D's output.

    Raises:
        HTTPException: 502 when the prefill request to C fails.
    """
    import uuid  # local import: only needed for the transfer-id fallback

    # Both legs must agree on the transfer id. Without a request id every
    # request would share the id "xfer-", so fall back to a random one.
    request_id = headers.get("X-Request-Id", "") or uuid.uuid4().hex
    transfer_id = f"xfer-{request_id}"

    # Step 1: send blocking prefill to C
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": transfer_id,
    }
    # Prefill only: one non-streamed token is enough to populate the KV.
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)

    p_headers = {**headers, "X-data-parallel-rank": "0"}

    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        c_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(
            status_code=502, detail=f"Prefill on C failed: {e}") from e
    finally:
        # Release C's load accounting on every exit path. ``finally`` also
        # covers task cancellation (a BaseException, not caught above),
        # which would otherwise leave C looking permanently busy.
        c_inst.ongoing_tokens -= input_length
        c_inst.pending_prefill_tokens -= estimated_new
        c_inst.num_requests -= 1
        c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)

    # Step 2: send decode to D (pull KV from C via Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.num_requests += 1

    parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": c_inst.engine_id.get(0, ""),
        "transfer_id": transfer_id,
    }

    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        # First bytes from D: the request is now decoding.
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
            d_inst.record_prefix(token_ids)
        finally:
            # Undo D's accounting whether the stream completed or aborted.
            if not first_token:
                d_inst.ongoing_decode_tokens -= input_length
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
input_length, c_inst, d_inst,
cache_hit, estimated_new, breakdown):