Fix A+C: real cache sync + cached-prefill-on-C architecture

A: Add /estimate_hit endpoint to bootstrap server for real-time cache
   probing. Proxy queries this before committing to PUSH, eliminating
   24% zero-match PUSH requests (shadow cache divergence).

C: Add _handle_cached_prefill_offload: C (cache source) does fast
   cached prefill → KV to Mooncake → D pulls and decodes.
   Replaces broken direct_read PUSH where D waited for RDMA transfer
   while occupying KV blocks without doing compute.

Also: update §3.9 baseline to plain vLLM with full mean/p50/p90/p99.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 11:22:38 +08:00
parent 2b9eae0d54
commit cdf83493ab
3 changed files with 252 additions and 59 deletions

View File

@@ -373,20 +373,28 @@ Only 2 measured parameters: `prefill_throughput=7000 tok/s`, `rdma_overhead=0.1s
**Results (eval_unified_v3, 850/850, 0 errors):**
| Metric | Baseline | **Unified v3** | Delta |
|--------|----------|---------------|-------|
| TTFT mean | 4.35s | **3.24s** | **-25.5%** |
| TTFT p50 | 0.95s | **0.78s** | **-17.9%** |
| TTFT p90 | 12.47s | **7.79s** | **-37.5%** |
| TPOT p90 | 0.177 | 0.204 | +14.9% |
| E2E mean | 19.10s | **17.69s** | **-7.4%** |
| E2E p50 | 6.44s | **5.48s** | **-14.9%** |
Baseline = `eval_baseline_linear` (plain vLLM, no Mooncake, linear policy, 850 req, same trace).
| Metric | Baseline (plain) | **Unified v3 (kv_both)** | Delta |
|--------|-----------------|-------------------------|-------|
| TTFT mean | 4.348s | **3.277s** | **-24.6%** |
| TTFT p50 | 0.945s | **0.793s** | **-16.1%** |
| TTFT p90 | 12.468s | **8.472s** | **-32.1%** |
| TTFT p99 | 48.149s | **41.587s** | **-13.6%** |
| TPOT mean | 0.116s | 0.112s | -3.1% |
| TPOT p50 | 0.071s | 0.077s | +8.9% |
| TPOT p90 | 0.177s | 0.198s | +11.7% |
| TPOT p99 | 1.018s | **0.816s** | **-19.9%** |
| E2E mean | 19.10s | 19.81s | +3.7% |
| E2E p50 | 6.443s | **5.599s** | **-13.1%** |
| E2E p90 | 42.27s | 47.48s | +12.3% |
| E2E p99 | 192.2s | 238.0s | +23.8% |
**Routing**: 723 LOCAL + 116 PUSH_MIGRATE (13.8%). All 116 pushes had cache (avg 25k tokens) no cold offloads. The unified cost model naturally avoids cold migration because `cold + RDMA > cold` (RDMA adds overhead without reducing prefill).
**Tradeoff**: TPOT p90 +15% from kv_both background threads + PUSH operations. In exchange: TTFT -38%, E2E -15% at p50.
**Tradeoff**: TTFT uniformly improves (p50 -16%, p90 -32%). TPOT mixed: p50/p90 worse (+9%/+12%), but p99 improves (-20%) PUSH migration relieves the heaviest tail prefills. **E2E tail degrades significantly** (p90 +12%, p99 +24%): kv_both always-on overhead + PUSH transfer latency on migrated requests inflates E2E for long requests, offsetting the TTFT gain. The p50 benefit (-13%) comes from the majority of LOCAL requests getting faster prefill due to reduced queue contention.
**Output**: `outputs/eval_unified_v3/` on dash0.
**Output**: `outputs/eval_unified_v3/` on dash0, baseline from `outputs/eval_baseline_linear/`.
## 4. System-Level Analysis

View File

@@ -159,6 +159,44 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
return instances[best_idx], best_idx
_bootstrap_client: httpx.AsyncClient | None = None
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls
async def _get_bootstrap_client() -> httpx.AsyncClient:
global _bootstrap_client
if _bootstrap_client is None:
_bootstrap_client = httpx.AsyncClient(
timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
limits=httpx.Limits(max_connections=32, max_keepalive_connections=16),
)
return _bootstrap_client
async def _query_bootstrap_hit(
inst: InstanceState, token_ids: list[int],
) -> int | None:
"""Query bootstrap's /estimate_hit for real cache hit count.
Returns hit_tokens on success, None on failure (caller should fallback).
"""
if inst.bootstrap_port is None:
return None
parsed = urllib.parse.urlparse(str(inst.client.base_url))
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
try:
client = await _get_bootstrap_client()
resp = await client.post(url, json={
"token_ids": token_ids,
"block_size": BLOCK_SIZE,
})
resp.raise_for_status()
return resp.json()["hit_tokens"]
except Exception:
return None
global_args = None
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
@@ -286,6 +324,8 @@ async def lifespan(app: FastAPI):
await reconcile_task
except asyncio.CancelledError:
pass
if _bootstrap_client is not None:
await _bootstrap_client.aclose()
for inst in combined_instances + prefill_instances + decode_instances:
await inst.client.aclose()
@@ -397,66 +437,171 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
if best_needs_push:
c_inst = combined_instances[best_cache_idx]
d_inst = chosen
push_cache_hit = best_cache_hit
push_new = max(0, input_length - push_cache_hit)
d_inst.ongoing_tokens += input_length
d_inst.pending_prefill_tokens += push_new
d_inst.num_requests += 1
c_inst.active_p_offloads += 1
# Query real cache hit from bootstrap (shadow cache is inaccurate)
real_hit = await _query_bootstrap_hit(c_inst, token_ids)
breakdown["shadow_cache_hit"] = best_cache_hit
breakdown["real_cache_hit"] = real_hit
breakdown["route_class"] = "PUSH_MIGRATE"
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["push_cache_hit"] = push_cache_hit
if real_hit is not None:
push_cache_hit = real_hit
else:
push_cache_hit = best_cache_hit # fallback to shadow estimate
return await _handle_direct_read_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, push_cache_hit, push_new, breakdown)
else:
breakdown["route_class"] = "LOCAL"
breakdown["routed_to"] = chosen.url
# If real hit > 0, proceed with cached prefill on C → decode on D
if push_cache_hit > 0:
push_new = max(0, input_length - push_cache_hit)
chosen.ongoing_tokens += input_length
chosen.pending_prefill_tokens += estimated_new
chosen.num_requests += 1
c_inst.ongoing_tokens += input_length
c_inst.pending_prefill_tokens += push_new
c_inst.num_requests += 1
c_inst.active_p_offloads += 1
async def generate():
prefill_done = False
try:
for attempt in range(MAX_STREAM_RETRIES):
try:
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
chosen.record_prefix(token_ids)
break
except (httpx.ConnectError, httpx.RemoteProtocolError):
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
raise
await asyncio.sleep(RETRY_DELAY_S)
finally:
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
else:
chosen.ongoing_decode_tokens -= input_length
chosen.ongoing_tokens -= input_length
chosen.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["push_cache_hit"] = push_cache_hit
return StreamingResponse(generate(), media_type="text/event-stream")
return await _handle_cached_prefill_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, push_cache_hit, push_new, breakdown)
# Real hit is 0 — downgrade to LOCAL
breakdown["push_downgraded"] = True
# LOCAL path (also handles downgraded PUSH)
breakdown["route_class"] = "LOCAL"
breakdown["routed_to"] = chosen.url
chosen.ongoing_tokens += input_length
chosen.pending_prefill_tokens += estimated_new
chosen.num_requests += 1
async def generate():
prefill_done = False
try:
for attempt in range(MAX_STREAM_RETRIES):
try:
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
chosen.record_prefix(token_ids)
break
except (httpx.ConnectError, httpx.RemoteProtocolError):
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
raise
await asyncio.sleep(RETRY_DELAY_S)
finally:
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
else:
chosen.ongoing_decode_tokens -= input_length
chosen.ongoing_tokens -= input_length
chosen.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
input_length, c_inst, d_inst,
cache_hit, estimated_new, breakdown):
"""C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.
Unlike direct_read (D pulls blocks from C), here C's scheduler IS
involved: C prefills (fast, because prefix is cached), pushes KV to
Mooncake store, then D pulls and decodes. This avoids the broken
PUSH path where D waits for RDMA transfer while occupying KV blocks.
"""
request_id = headers.get("X-Request-Id", "")
# Step 1: send blocking prefill to C
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": f"xfer-{request_id}",
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
breakdown["t_prefill_sent"] = _time.monotonic()
try:
resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
resp.raise_for_status()
await resp.aclose()
c_inst.record_prefix(token_ids)
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = True
_breakdown_log.append(breakdown)
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
c_inst.ongoing_tokens -= input_length
c_inst.pending_prefill_tokens -= estimated_new
c_inst.num_requests -= 1
raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}")
c_inst.ongoing_tokens -= input_length
c_inst.pending_prefill_tokens -= estimated_new
c_inst.num_requests -= 1
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
# Step 2: send decode to D (pull KV from C via Mooncake)
d_inst.ongoing_tokens += input_length
d_inst.num_requests += 1
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": c_inst.engine_id.get(0, ""),
"transfer_id": f"xfer-{request_id}",
}
breakdown["t_decode_sent"] = _time.monotonic()
async def generate():
first_token = True
try:
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if first_token:
d_inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
first_token = False
yield chunk
d_inst.record_prefix(token_ids)
finally:
if not first_token:
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
input_length, c_inst, d_inst,
cache_hit, estimated_new, breakdown):

View File

@@ -61,6 +61,15 @@ class PushBlocksResponse(BaseModel):
pushed: bool
class EstimateHitRequest(BaseModel):
token_ids: list[int]
block_size: int = 512
class EstimateHitResponse(BaseModel):
hit_tokens: int
class UnpinBlocksRequest(BaseModel):
pin_token: str
@@ -98,6 +107,7 @@ class MooncakeBootstrapServer:
self.app.post("/query_blocks")(self.query_blocks)
self.app.post("/unpin_blocks")(self.unpin_blocks)
self.app.post("/push_blocks")(self.push_blocks)
self.app.post("/estimate_hit")(self.estimate_hit)
def start(self):
if self.server_thread:
@@ -286,6 +296,36 @@ class MooncakeBootstrapServer:
self._pinned.pop(req.pin_token, None)
return {"status": "ok"}
async def estimate_hit(self, req: EstimateHitRequest):
"""Read-only probe: how many prefix-contiguous tokens are cached?"""
if self._kv_info is None:
raise HTTPException(503, "Worker KV info not registered yet")
block_size = req.block_size or self._kv_info.get("block_size", 512)
n_tokens = len(req.token_ids)
num_blocks = n_tokens // block_size
if num_blocks == 0 or not self._hash_table:
return EstimateHitResponse(hit_tokens=0)
import vllm.v1.core.kv_cache_utils as kv_utils
from vllm.utils.hashing import sha256
prev_hash = kv_utils.NONE_HASH
hit_blocks = 0
for i in range(num_blocks):
block_tokens = tuple(
req.token_ids[i * block_size:(i + 1) * block_size])
block_hash = kv_utils.hash_block_tokens(
sha256, prev_hash, block_tokens, None)
prev_hash = block_hash
if self._hash_table.get(block_hash.hex()) is not None:
hit_blocks += 1
else:
break
return EstimateHitResponse(hit_tokens=hit_blocks * block_size)
async def push_blocks(self, req: PushBlocksRequest):
"""Query matching blocks by token_ids, then PUSH them to D via RDMA write."""
if self._kv_info is None or self._transfer_engine is None: