Fix A+C: real cache sync + cached-prefill-on-C architecture
A: Add /estimate_hit endpoint to bootstrap server for real-time cache probing. Proxy queries this before committing to PUSH, eliminating 24% zero-match PUSH requests (shadow cache divergence). C: Add _handle_cached_prefill_offload: C (cache source) does fast cached prefill → KV to Mooncake → D pulls and decodes. Replaces broken direct_read PUSH where D waited for RDMA transfer while occupying KV blocks without doing compute. Also: update §3.9 baseline to plain vLLM with full mean/p50/p90/p99. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -159,6 +159,44 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
|
||||
return instances[best_idx], best_idx
|
||||
|
||||
|
||||
_bootstrap_client: httpx.AsyncClient | None = None
|
||||
|
||||
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls
|
||||
|
||||
|
||||
async def _get_bootstrap_client() -> httpx.AsyncClient:
|
||||
global _bootstrap_client
|
||||
if _bootstrap_client is None:
|
||||
_bootstrap_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
|
||||
limits=httpx.Limits(max_connections=32, max_keepalive_connections=16),
|
||||
)
|
||||
return _bootstrap_client
|
||||
|
||||
|
||||
async def _query_bootstrap_hit(
|
||||
inst: InstanceState, token_ids: list[int],
|
||||
) -> int | None:
|
||||
"""Query bootstrap's /estimate_hit for real cache hit count.
|
||||
|
||||
Returns hit_tokens on success, None on failure (caller should fallback).
|
||||
"""
|
||||
if inst.bootstrap_port is None:
|
||||
return None
|
||||
parsed = urllib.parse.urlparse(str(inst.client.base_url))
|
||||
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
|
||||
try:
|
||||
client = await _get_bootstrap_client()
|
||||
resp = await client.post(url, json={
|
||||
"token_ids": token_ids,
|
||||
"block_size": BLOCK_SIZE,
|
||||
})
|
||||
resp.raise_for_status()
|
||||
return resp.json()["hit_tokens"]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
global_args = None
|
||||
combined_instances: list[InstanceState] = []
|
||||
prefill_instances: list[InstanceState] = []
|
||||
@@ -286,6 +324,8 @@ async def lifespan(app: FastAPI):
|
||||
await reconcile_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if _bootstrap_client is not None:
|
||||
await _bootstrap_client.aclose()
|
||||
for inst in combined_instances + prefill_instances + decode_instances:
|
||||
await inst.client.aclose()
|
||||
|
||||
@@ -397,66 +437,171 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
if best_needs_push:
|
||||
c_inst = combined_instances[best_cache_idx]
|
||||
d_inst = chosen
|
||||
push_cache_hit = best_cache_hit
|
||||
push_new = max(0, input_length - push_cache_hit)
|
||||
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.pending_prefill_tokens += push_new
|
||||
d_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
# Query real cache hit from bootstrap (shadow cache is inaccurate)
|
||||
real_hit = await _query_bootstrap_hit(c_inst, token_ids)
|
||||
breakdown["shadow_cache_hit"] = best_cache_hit
|
||||
breakdown["real_cache_hit"] = real_hit
|
||||
|
||||
breakdown["route_class"] = "PUSH_MIGRATE"
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["push_cache_hit"] = push_cache_hit
|
||||
if real_hit is not None:
|
||||
push_cache_hit = real_hit
|
||||
else:
|
||||
push_cache_hit = best_cache_hit # fallback to shadow estimate
|
||||
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, push_cache_hit, push_new, breakdown)
|
||||
else:
|
||||
breakdown["route_class"] = "LOCAL"
|
||||
breakdown["routed_to"] = chosen.url
|
||||
# If real hit > 0, proceed with cached prefill on C → decode on D
|
||||
if push_cache_hit > 0:
|
||||
push_new = max(0, input_length - push_cache_hit)
|
||||
|
||||
chosen.ongoing_tokens += input_length
|
||||
chosen.pending_prefill_tokens += estimated_new
|
||||
chosen.num_requests += 1
|
||||
c_inst.ongoing_tokens += input_length
|
||||
c_inst.pending_prefill_tokens += push_new
|
||||
c_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
|
||||
async def generate():
|
||||
prefill_done = False
|
||||
try:
|
||||
for attempt in range(MAX_STREAM_RETRIES):
|
||||
try:
|
||||
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
chosen.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
chosen.record_prefix(token_ids)
|
||||
break
|
||||
except (httpx.ConnectError, httpx.RemoteProtocolError):
|
||||
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
|
||||
raise
|
||||
await asyncio.sleep(RETRY_DELAY_S)
|
||||
finally:
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
chosen.ongoing_decode_tokens -= input_length
|
||||
chosen.ongoing_tokens -= input_length
|
||||
chosen.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["push_cache_hit"] = push_cache_hit
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
return await _handle_cached_prefill_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, push_cache_hit, push_new, breakdown)
|
||||
|
||||
# Real hit is 0 — downgrade to LOCAL
|
||||
breakdown["push_downgraded"] = True
|
||||
|
||||
# LOCAL path (also handles downgraded PUSH)
|
||||
breakdown["route_class"] = "LOCAL"
|
||||
breakdown["routed_to"] = chosen.url
|
||||
|
||||
chosen.ongoing_tokens += input_length
|
||||
chosen.pending_prefill_tokens += estimated_new
|
||||
chosen.num_requests += 1
|
||||
|
||||
async def generate():
|
||||
prefill_done = False
|
||||
try:
|
||||
for attempt in range(MAX_STREAM_RETRIES):
|
||||
try:
|
||||
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
chosen.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
chosen.record_prefix(token_ids)
|
||||
break
|
||||
except (httpx.ConnectError, httpx.RemoteProtocolError):
|
||||
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
|
||||
raise
|
||||
await asyncio.sleep(RETRY_DELAY_S)
|
||||
finally:
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
chosen.ongoing_decode_tokens -= input_length
|
||||
chosen.ongoing_tokens -= input_length
|
||||
chosen.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
|
||||
|
||||
|
||||
async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
"""C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.
|
||||
|
||||
Unlike direct_read (D pulls blocks from C), here C's scheduler IS
|
||||
involved: C prefills (fast, because prefix is cached), pushes KV to
|
||||
Mooncake store, then D pulls and decodes. This avoids the broken
|
||||
PUSH path where D waits for RDMA transfer while occupying KV blocks.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
|
||||
# Step 1: send blocking prefill to C
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
prefill_data["stream"] = False
|
||||
prefill_data["max_tokens"] = 1
|
||||
prefill_data.pop("max_completion_tokens", None)
|
||||
prefill_data.pop("stream_options", None)
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
|
||||
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||
|
||||
try:
|
||||
resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers)
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
c_inst.record_prefix(token_ids)
|
||||
except Exception as e:
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["prefill_error"] = True
|
||||
_breakdown_log.append(breakdown)
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
c_inst.ongoing_tokens -= input_length
|
||||
c_inst.pending_prefill_tokens -= estimated_new
|
||||
c_inst.num_requests -= 1
|
||||
raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}")
|
||||
|
||||
c_inst.ongoing_tokens -= input_length
|
||||
c_inst.pending_prefill_tokens -= estimated_new
|
||||
c_inst.num_requests -= 1
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
|
||||
# Step 2: send decode to D (pull KV from C via Mooncake)
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.num_requests += 1
|
||||
|
||||
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
|
||||
bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"
|
||||
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": c_inst.engine_id.get(0, ""),
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
|
||||
async def generate():
|
||||
first_token = True
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if first_token:
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
first_token = False
|
||||
yield chunk
|
||||
d_inst.record_prefix(token_ids)
|
||||
finally:
|
||||
if not first_token:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
|
||||
Reference in New Issue
Block a user