Fix A+C: real cache sync + cached-prefill-on-C architecture
A: Add /estimate_hit endpoint to bootstrap server for real-time cache probing. Proxy queries this before committing to PUSH, eliminating 24% zero-match PUSH requests (shadow cache divergence). C: Add _handle_cached_prefill_offload: C (cache source) does fast cached prefill → KV to Mooncake → D pulls and decodes. Replaces broken direct_read PUSH where D waited for RDMA transfer while occupying KV blocks without doing compute. Also: update §3.9 baseline to plain vLLM with full mean/p50/p90/p99. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
28
REPORT.md
28
REPORT.md
@@ -373,20 +373,28 @@ Only 2 measured parameters: `prefill_throughput=7000 tok/s`, `rdma_overhead=0.1s
|
||||
|
||||
**Results (eval_unified_v3, 850/850, 0 errors):**
|
||||
|
||||
| Metric | Baseline | **Unified v3** | Delta |
|
||||
|--------|----------|---------------|-------|
|
||||
| TTFT mean | 4.35s | **3.24s** | **-25.5%** |
|
||||
| TTFT p50 | 0.95s | **0.78s** | **-17.9%** |
|
||||
| TTFT p90 | 12.47s | **7.79s** | **-37.5%** |
|
||||
| TPOT p90 | 0.177 | 0.204 | +14.9% |
|
||||
| E2E mean | 19.10s | **17.69s** | **-7.4%** |
|
||||
| E2E p50 | 6.44s | **5.48s** | **-14.9%** |
|
||||
Baseline = `eval_baseline_linear` (plain vLLM, no Mooncake, linear policy, 850 req, same trace).
|
||||
|
||||
| Metric | Baseline (plain) | **Unified v3 (kv_both)** | Delta |
|
||||
|--------|-----------------|-------------------------|-------|
|
||||
| TTFT mean | 4.348s | **3.277s** | **-24.6%** |
|
||||
| TTFT p50 | 0.945s | **0.793s** | **-16.1%** |
|
||||
| TTFT p90 | 12.468s | **8.472s** | **-32.1%** |
|
||||
| TTFT p99 | 48.149s | **41.587s** | **-13.6%** |
|
||||
| TPOT mean | 0.116s | 0.112s | -3.1% |
|
||||
| TPOT p50 | 0.071s | 0.077s | +8.9% |
|
||||
| TPOT p90 | 0.177s | 0.198s | +11.7% |
|
||||
| TPOT p99 | 1.018s | **0.816s** | **-19.9%** |
|
||||
| E2E mean | 19.10s | 19.81s | +3.7% |
|
||||
| E2E p50 | 6.443s | **5.599s** | **-13.1%** |
|
||||
| E2E p90 | 42.27s | 47.48s | +12.3% |
|
||||
| E2E p99 | 192.2s | 238.0s | +23.8% |
|
||||
|
||||
**Routing**: 723 LOCAL + 116 PUSH_MIGRATE (13.8%). All 116 pushes had cache (avg 25k tokens) — no cold offloads. The unified cost model naturally avoids cold migration because `cold + RDMA > cold` (RDMA adds overhead without reducing prefill).
|
||||
|
||||
**Tradeoff**: TPOT p90 +15% from kv_both background threads + PUSH operations. In exchange: TTFT -38%, E2E -15% at p50.
|
||||
**Tradeoff**: TTFT uniformly improves (p50 -16%, p90 -32%). TPOT mixed: p50/p90 worse (+9%/+12%), but p99 improves (-20%) — PUSH migration relieves the heaviest tail prefills. **E2E tail degrades significantly** (p90 +12%, p99 +24%): kv_both always-on overhead + PUSH transfer latency on migrated requests inflates E2E for long requests, offsetting the TTFT gain. The p50 benefit (-13%) comes from the majority of LOCAL requests getting faster prefill due to reduced queue contention.
|
||||
|
||||
**Output**: `outputs/eval_unified_v3/` on dash0.
|
||||
**Output**: `outputs/eval_unified_v3/` on dash0, baseline from `outputs/eval_baseline_linear/`.
|
||||
|
||||
## 4. System-Level Analysis
|
||||
|
||||
|
||||
@@ -159,6 +159,44 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
|
||||
return instances[best_idx], best_idx
|
||||
|
||||
|
||||
_bootstrap_client: httpx.AsyncClient | None = None
|
||||
|
||||
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls
|
||||
|
||||
|
||||
async def _get_bootstrap_client() -> httpx.AsyncClient:
|
||||
global _bootstrap_client
|
||||
if _bootstrap_client is None:
|
||||
_bootstrap_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
|
||||
limits=httpx.Limits(max_connections=32, max_keepalive_connections=16),
|
||||
)
|
||||
return _bootstrap_client
|
||||
|
||||
|
||||
async def _query_bootstrap_hit(
|
||||
inst: InstanceState, token_ids: list[int],
|
||||
) -> int | None:
|
||||
"""Query bootstrap's /estimate_hit for real cache hit count.
|
||||
|
||||
Returns hit_tokens on success, None on failure (caller should fallback).
|
||||
"""
|
||||
if inst.bootstrap_port is None:
|
||||
return None
|
||||
parsed = urllib.parse.urlparse(str(inst.client.base_url))
|
||||
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
|
||||
try:
|
||||
client = await _get_bootstrap_client()
|
||||
resp = await client.post(url, json={
|
||||
"token_ids": token_ids,
|
||||
"block_size": BLOCK_SIZE,
|
||||
})
|
||||
resp.raise_for_status()
|
||||
return resp.json()["hit_tokens"]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
global_args = None
|
||||
combined_instances: list[InstanceState] = []
|
||||
prefill_instances: list[InstanceState] = []
|
||||
@@ -286,6 +324,8 @@ async def lifespan(app: FastAPI):
|
||||
await reconcile_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if _bootstrap_client is not None:
|
||||
await _bootstrap_client.aclose()
|
||||
for inst in combined_instances + prefill_instances + decode_instances:
|
||||
await inst.client.aclose()
|
||||
|
||||
@@ -397,66 +437,171 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
if best_needs_push:
|
||||
c_inst = combined_instances[best_cache_idx]
|
||||
d_inst = chosen
|
||||
push_cache_hit = best_cache_hit
|
||||
push_new = max(0, input_length - push_cache_hit)
|
||||
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.pending_prefill_tokens += push_new
|
||||
d_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
# Query real cache hit from bootstrap (shadow cache is inaccurate)
|
||||
real_hit = await _query_bootstrap_hit(c_inst, token_ids)
|
||||
breakdown["shadow_cache_hit"] = best_cache_hit
|
||||
breakdown["real_cache_hit"] = real_hit
|
||||
|
||||
breakdown["route_class"] = "PUSH_MIGRATE"
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["push_cache_hit"] = push_cache_hit
|
||||
if real_hit is not None:
|
||||
push_cache_hit = real_hit
|
||||
else:
|
||||
push_cache_hit = best_cache_hit # fallback to shadow estimate
|
||||
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, push_cache_hit, push_new, breakdown)
|
||||
else:
|
||||
breakdown["route_class"] = "LOCAL"
|
||||
breakdown["routed_to"] = chosen.url
|
||||
# If real hit > 0, proceed with cached prefill on C → decode on D
|
||||
if push_cache_hit > 0:
|
||||
push_new = max(0, input_length - push_cache_hit)
|
||||
|
||||
chosen.ongoing_tokens += input_length
|
||||
chosen.pending_prefill_tokens += estimated_new
|
||||
chosen.num_requests += 1
|
||||
c_inst.ongoing_tokens += input_length
|
||||
c_inst.pending_prefill_tokens += push_new
|
||||
c_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
|
||||
async def generate():
|
||||
prefill_done = False
|
||||
try:
|
||||
for attempt in range(MAX_STREAM_RETRIES):
|
||||
try:
|
||||
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
chosen.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
chosen.record_prefix(token_ids)
|
||||
break
|
||||
except (httpx.ConnectError, httpx.RemoteProtocolError):
|
||||
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
|
||||
raise
|
||||
await asyncio.sleep(RETRY_DELAY_S)
|
||||
finally:
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
chosen.ongoing_decode_tokens -= input_length
|
||||
chosen.ongoing_tokens -= input_length
|
||||
chosen.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["push_cache_hit"] = push_cache_hit
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
return await _handle_cached_prefill_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, push_cache_hit, push_new, breakdown)
|
||||
|
||||
# Real hit is 0 — downgrade to LOCAL
|
||||
breakdown["push_downgraded"] = True
|
||||
|
||||
# LOCAL path (also handles downgraded PUSH)
|
||||
breakdown["route_class"] = "LOCAL"
|
||||
breakdown["routed_to"] = chosen.url
|
||||
|
||||
chosen.ongoing_tokens += input_length
|
||||
chosen.pending_prefill_tokens += estimated_new
|
||||
chosen.num_requests += 1
|
||||
|
||||
async def generate():
|
||||
prefill_done = False
|
||||
try:
|
||||
for attempt in range(MAX_STREAM_RETRIES):
|
||||
try:
|
||||
async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
chosen.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
chosen.record_prefix(token_ids)
|
||||
break
|
||||
except (httpx.ConnectError, httpx.RemoteProtocolError):
|
||||
if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
|
||||
raise
|
||||
await asyncio.sleep(RETRY_DELAY_S)
|
||||
finally:
|
||||
if not prefill_done:
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
chosen.ongoing_decode_tokens -= input_length
|
||||
chosen.ongoing_tokens -= input_length
|
||||
chosen.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
|
||||
|
||||
|
||||
async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
"""C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.
|
||||
|
||||
Unlike direct_read (D pulls blocks from C), here C's scheduler IS
|
||||
involved: C prefills (fast, because prefix is cached), pushes KV to
|
||||
Mooncake store, then D pulls and decodes. This avoids the broken
|
||||
PUSH path where D waits for RDMA transfer while occupying KV blocks.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
|
||||
# Step 1: send blocking prefill to C
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
prefill_data["stream"] = False
|
||||
prefill_data["max_tokens"] = 1
|
||||
prefill_data.pop("max_completion_tokens", None)
|
||||
prefill_data.pop("stream_options", None)
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
|
||||
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||
|
||||
try:
|
||||
resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers)
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
c_inst.record_prefix(token_ids)
|
||||
except Exception as e:
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["prefill_error"] = True
|
||||
_breakdown_log.append(breakdown)
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
c_inst.ongoing_tokens -= input_length
|
||||
c_inst.pending_prefill_tokens -= estimated_new
|
||||
c_inst.num_requests -= 1
|
||||
raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}")
|
||||
|
||||
c_inst.ongoing_tokens -= input_length
|
||||
c_inst.pending_prefill_tokens -= estimated_new
|
||||
c_inst.num_requests -= 1
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
|
||||
# Step 2: send decode to D (pull KV from C via Mooncake)
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.num_requests += 1
|
||||
|
||||
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
|
||||
bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"
|
||||
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": c_inst.engine_id.get(0, ""),
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
|
||||
async def generate():
|
||||
first_token = True
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if first_token:
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
first_token = False
|
||||
yield chunk
|
||||
d_inst.record_prefix(token_ids)
|
||||
finally:
|
||||
if not first_token:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
|
||||
@@ -61,6 +61,15 @@ class PushBlocksResponse(BaseModel):
|
||||
pushed: bool
|
||||
|
||||
|
||||
class EstimateHitRequest(BaseModel):
|
||||
token_ids: list[int]
|
||||
block_size: int = 512
|
||||
|
||||
|
||||
class EstimateHitResponse(BaseModel):
|
||||
hit_tokens: int
|
||||
|
||||
|
||||
class UnpinBlocksRequest(BaseModel):
|
||||
pin_token: str
|
||||
|
||||
@@ -98,6 +107,7 @@ class MooncakeBootstrapServer:
|
||||
self.app.post("/query_blocks")(self.query_blocks)
|
||||
self.app.post("/unpin_blocks")(self.unpin_blocks)
|
||||
self.app.post("/push_blocks")(self.push_blocks)
|
||||
self.app.post("/estimate_hit")(self.estimate_hit)
|
||||
|
||||
def start(self):
|
||||
if self.server_thread:
|
||||
@@ -286,6 +296,36 @@ class MooncakeBootstrapServer:
|
||||
self._pinned.pop(req.pin_token, None)
|
||||
return {"status": "ok"}
|
||||
|
||||
async def estimate_hit(self, req: EstimateHitRequest):
|
||||
"""Read-only probe: how many prefix-contiguous tokens are cached?"""
|
||||
if self._kv_info is None:
|
||||
raise HTTPException(503, "Worker KV info not registered yet")
|
||||
|
||||
block_size = req.block_size or self._kv_info.get("block_size", 512)
|
||||
n_tokens = len(req.token_ids)
|
||||
num_blocks = n_tokens // block_size
|
||||
if num_blocks == 0 or not self._hash_table:
|
||||
return EstimateHitResponse(hit_tokens=0)
|
||||
|
||||
import vllm.v1.core.kv_cache_utils as kv_utils
|
||||
from vllm.utils.hashing import sha256
|
||||
|
||||
prev_hash = kv_utils.NONE_HASH
|
||||
hit_blocks = 0
|
||||
for i in range(num_blocks):
|
||||
block_tokens = tuple(
|
||||
req.token_ids[i * block_size:(i + 1) * block_size])
|
||||
block_hash = kv_utils.hash_block_tokens(
|
||||
sha256, prev_hash, block_tokens, None)
|
||||
prev_hash = block_hash
|
||||
|
||||
if self._hash_table.get(block_hash.hex()) is not None:
|
||||
hit_blocks += 1
|
||||
else:
|
||||
break
|
||||
|
||||
return EstimateHitResponse(hit_tokens=hit_blocks * block_size)
|
||||
|
||||
async def push_blocks(self, req: PushBlocksRequest):
|
||||
"""Query matching blocks by token_ids, then PUSH them to D via RDMA write."""
|
||||
if self._kv_info is None or self._transfer_engine is None:
|
||||
|
||||
Reference in New Issue
Block a user