unified_v2: selective per-request PD-sep via Mooncake (E3+E4)

Adds a sixth routing policy --policy unified_v2 that wraps the
existing unified hybrid picker with a selective PD-sep branch.
When all of the following hold, a request is split prefill-on-src,
decode-on-chosen via Mooncake kv_role=kv_both transfer:

  1. new_local = input_length - chosen.cache_hit >= 16k
     (B2 microbench shows same-worker TTFT idx >= 3x from this size up)
  2. chosen has live decodes worth protecting (>= 2 in-flight requests,
     or ongoing decode tokens >= heavy_threshold)
  3. some other instance holds materially more cache for this prefix
     (>= 8k tokens, and >= 4k more than chosen)
  4. cost(src_interference + RDMA xfer) + 0.2s margin < cost(chosen_interference)
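
The first three conditions are cheap threshold checks, so they can be read as an ordered series of early exits that run before any cost is computed. A minimal sketch, assuming the default knob values from this commit; `GateKnobs` and `first_failed_gate` are hypothetical names used only for illustration, and gate 2 is simplified here to a plain in-flight count:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class GateKnobs:
    # Defaults mirror the pd_sep_* settings added in this commit.
    min_new_tokens: int = 16000
    min_decodes_protected: int = 2
    min_src_cache_tokens: int = 8000
    min_extra_cache_tokens: int = 4000


def first_failed_gate(
    input_length: int,
    chosen_cache_hit: int,
    chosen_decodes: int,
    best_src_cache_hit: int,
    knobs: GateKnobs = GateKnobs(),
) -> Optional[str]:
    """Return the name of the first hard gate that fails, or None if all pass."""
    new_local = max(0, input_length - chosen_cache_hit)
    if new_local < knobs.min_new_tokens:
        return "new_local_below_threshold"
    if chosen_decodes < knobs.min_decodes_protected:
        return "chosen_few_decodes"
    if best_src_cache_hit < knobs.min_src_cache_tokens:
        return "src_cache_below_threshold"
    if best_src_cache_hit - chosen_cache_hit < knobs.min_extra_cache_tokens:
        return "src_not_meaningfully_more_cache"
    return None


# 40k-token request, nothing cached on chosen, 30k cached on another instance.
print(first_failed_gate(40000, 0, 4, 30000))
# Same cache layout, but the prefill is too small to be worth migrating.
print(first_failed_gate(12000, 0, 4, 30000))
```

Only a request that clears all of these checks reaches the cost comparison in condition 4.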

The cost model is the audit-blessed shape from E1's post-mortem:
- gate on new_tokens (post-cache), NOT input_length (the old PUSH gate)
- bind to a single transfer mechanism (kv_both peer-to-peer pull)
- realistic RDMA cost as a function of bytes: 0.3s base +
  bytes / 2.7 GB/s (calibrated against contention_16s_elastic p50)
- both source and target decode counts considered
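
To make the cost compare concrete, the arithmetic below plugs this commit's default settings into the model for one invented request. The request shape (40k input tokens, 30k of them cached on src, 4 decodes on chosen, 1 on src) is an assumption chosen for illustration, not a measured case, and the function names are shorthand for the estimators in cache_aware_proxy.py:

```python
KV_BYTES_PER_TOKEN = 98304        # kv_bytes_per_token
RDMA_BASE_S = 0.3                 # rdma_base_overhead_s
RDMA_GB_PER_S = 2.7               # rdma_effective_gb_per_s
PREFILL_TOK_PER_S = 4000.0        # prefill_throughput_kv_both
MARGIN_S = 0.2                    # pd_sep_margin_s


def transfer_cost_s(cached_tokens: int) -> float:
    """Base overhead plus bytes divided by effective bandwidth."""
    transfer_bytes = cached_tokens * KV_BYTES_PER_TOKEN
    return RDMA_BASE_S + transfer_bytes / (RDMA_GB_PER_S * 1024 ** 3)


def interference_s(new_tokens: int, num_decodes: int) -> float:
    """Prefill duration x penalty bin x number of co-located decodes."""
    if num_decodes <= 0:
        return 0.0
    if new_tokens < 4000:
        penalty = 0.2
    elif new_tokens < 16000:
        penalty = 0.5
    elif new_tokens < 32000:
        penalty = 0.8
    else:
        penalty = 0.95
    return new_tokens / PREFILL_TOK_PER_S * penalty * num_decodes


# Hypothetical request: 40k tokens, chosen has no cache, src holds 30k.
input_length, chosen_hit, src_hit = 40000, 0, 30000
cost_local = interference_s(input_length - chosen_hit, num_decodes=4)
cost_migrate = interference_s(input_length - src_hit, num_decodes=1) \
    + transfer_cost_s(src_hit)
use_pd_sep = cost_local > cost_migrate + MARGIN_S

print(f"local={cost_local:.2f}s migrate={cost_migrate:.2f}s pd_sep={use_pd_sep}")
```

With these inputs the local path costs roughly 38 s of decode disruption against roughly 2.6 s for the migrate path, so the cost compare favors PD-sep by a wide margin.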

E2 mechanism-level patches not yet applied (this commit is policy-only).
Patches 6.2 / 6.3 / 6.5 remain on the table. Patch 6.6 (per-request
xfer timeout, 60s default) is implemented on the proxy side as an
httpx per-chunk read timeout on the dst streaming call, so a stuck
KV transfer fails the request instead of hanging for 600s.
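
A per-chunk read timeout bounds the gap between consecutive chunks, not the total stream duration. The sketch below illustrates that behavior with asyncio only; it is a stdlib analogue, not the httpx call this commit uses, and `slow_stream` and `read_with_chunk_timeout` are hypothetical helpers:

```python
import asyncio
from typing import AsyncIterator, List


async def slow_stream(gaps: List[float]) -> AsyncIterator[bytes]:
    """Yield one chunk after each gap (seconds), imitating a token stream."""
    for i, gap in enumerate(gaps):
        await asyncio.sleep(gap)
        yield f"chunk-{i}".encode()


async def read_with_chunk_timeout(
    stream: AsyncIterator[bytes], per_chunk_timeout_s: float
) -> List[bytes]:
    """Collect chunks, raising asyncio.TimeoutError if any single wait stalls."""
    it = stream.__aiter__()

    async def next_chunk() -> bytes:
        return await it.__anext__()

    chunks: List[bytes] = []
    while True:
        try:
            # The timer restarts for every chunk.
            chunk = await asyncio.wait_for(next_chunk(), per_chunk_timeout_s)
        except StopAsyncIteration:
            break
        chunks.append(chunk)
    return chunks


# Six chunks 0.05 s apart: total time exceeds 0.2 s, yet no single gap does.
ok = asyncio.run(read_with_chunk_timeout(slow_stream([0.05] * 6), 0.2))
print(len(ok))

# A 0.5 s stall before the second chunk trips the 0.2 s per-chunk limit.
try:
    asyncio.run(read_with_chunk_timeout(slow_stream([0.05, 0.5]), 0.2))
except asyncio.TimeoutError:
    print("stalled stream timed out")
```

In the proxy the stall of interest is the wait for the first chunk, which covers the KV transfer; a long but steadily streaming decode is not cut off by this limit.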

cache_aware_proxy.py:
- Settings: kv_bytes_per_token, prefill_throughput_kv_both,
  rdma_base_overhead_s, rdma_effective_gb_per_s, pd_sep_* gating knobs
- estimate_transfer_cost(bytes) replaces the constant rdma_overhead_s
- estimate_same_worker_interference_s(new_tokens, num_decodes) reads off
  the B2 penalty curve in 4 bins
- pick_instance_unified_v2: inherits unified, returns extra
  (src_inst, src_idx) tuple when PD-sep wins the cost compare
- _handle_combined_pd_sep_v2: prefill on src (do_remote_decode=True,
  max_tokens=1), Mooncake xfer, decode-stream on dst with httpx
  Timeout(read=pd_sep_xfer_timeout_s)
- --policy unified_v2 added to argparse choices
- lifespan auto-runs init_prefill_bootstrap when policy is unified_v2

b3_isolated_policy.sh:
- ENABLE_KV_BOTH env var, auto-set when POLICY=unified_v2, threads
  kv_role=kv_both + VLLM_MOONCAKE_BOOTSTRAP_PORT to vllm and
  --bootstrap-ports to the proxy

Tests: 8 new unit tests cover the gating predicates and the cost
estimators; all 32 proxy tests still pass.

Refs: E1 (PUSH post-mortem) + E2 (Mooncake audit) reports.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-26 09:25:45 +08:00
parent c63dc151a0
commit 19f69a9d2e
3 changed files with 523 additions and 18 deletions


@@ -44,7 +44,7 @@ class Settings:
(e.g. by tests/) and __main__ does not run.
"""
prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20)
-rdma_overhead_s: float = 0.1 # RDMA PUSH overhead (~10-50ms measured)
+rdma_overhead_s: float = 0.1 # legacy floor; v2 uses estimate_transfer_cost
cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
heavy_threshold: int = 20000
overload_factor: float = 2.0
@@ -52,10 +52,91 @@ class Settings:
cache_gate_ratio: float = 0.0
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
# --- Patch 6.9: cost-model calibration for unified_v2 ---
# Prefill throughput when the engine runs in kv_both mode. Lower than
# the plain-mode 7000 tok/s (prefill_throughput) because kv_both adds
# always-on overhead (REPORT §3.8 documents ~+16% TPOT vs plain).
prefill_throughput_kv_both: float = 4000.0
# Calibrated RDMA transfer cost: base + bandwidth term.
# Floor from isolated test ≈ 0.3 s (handshake + scheduler step).
# Bandwidth term reflects realized effective throughput, not
# theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s
# effective on the contended kv_both path. v2 uses this lookup
# rather than the constant rdma_overhead_s.
rdma_base_overhead_s: float = 0.3
rdma_effective_gb_per_s: float = 2.7
# Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2):
# 2 × 48 × 4 × 128 × 2 = 98304 bytes per token.
kv_bytes_per_token: int = 98304
# --- unified_v2 gating knobs ---
pd_sep_min_new_tokens: int = 16000 # B2 idx ≥ 3× starts here
pd_sep_min_decodes_protected: int = 2 # require chosen has live decodes to protect
pd_sep_min_src_cache_tokens: int = 8000 # require non-trivial cache to transfer
pd_sep_min_extra_cache_tokens: int = 4000 # src must have meaningfully more cache than chosen
pd_sep_margin_s: float = 0.2 # require cost gap > 0.2 s before migrating
# Patch 6.6: per-chunk read timeout on the dst decode stream (proxy side);
# bounds how long a stalled KV transfer can block, not total request time.
pd_sep_xfer_timeout_s: float = 60.0
SETTINGS = Settings()
def estimate_transfer_cost(transfer_bytes: int) -> float:
"""Calibrated RDMA transfer cost as a function of bytes.
Replaces the legacy constant rdma_overhead_s. Calibration sources:
- Floor: isolated-test ~0.3 s for a few-block PUSH (scripts/test_direct_read.py)
- Bandwidth term: outputs/contention_16s_elastic/breakdown.json shows
decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, giving
~2.7 GB/s effective on the contended kv_both path.
The p90 in that same run is 6.7 s (D-side block reservation +
scheduler step delays). v2's cost model uses the *median* — being
too pessimistic would suppress all PD-sep triggers. The risk of
underestimation is mitigated by the pd_sep_margin_s safety factor.
"""
base = SETTINGS.rdma_base_overhead_s
bw_term = transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024 ** 3)
return base + bw_term
def estimate_same_worker_interference_s(
new_tokens: int,
num_decodes: int,
) -> float:
"""Estimated additional latency on `num_decodes` co-located decodes
when a `new_tokens`-token prefill runs on the same worker.
Derived from B2 microbench (analysis/characterization/window_1_results.md):
same-worker prefill of size N steals decode capacity for the
prefill's duration. The penalty factor is the fraction of decode
steps stolen during the prefill window.
For new_tokens < 4k: ~0.2 (chunked prefill leaves room)
For 4k <= new_tokens < 16k: ~0.5 (mid-regime; B2 TPOT idx reaches 3.4× at 16k)
For 16k <= new_tokens < 32k: ~0.8 (B2 TPOT idx peaks at 7.9× at 32k)
For new_tokens >= 32k: ~0.95 (B2 TTFT regime, decodes are nearly fully blocked)
The cost in seconds is roughly: prefill_duration × penalty × n_decodes,
because each affected decode loses ~penalty fraction of its capacity
during the prefill window.
"""
if num_decodes <= 0:
return 0.0
prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both
if new_tokens < 4000:
penalty = 0.2
elif new_tokens < 16000:
penalty = 0.5
elif new_tokens < 32000:
penalty = 0.8
else:
penalty = 0.95
return prefill_dur_s * penalty * num_decodes
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
@@ -326,6 +407,128 @@ def pick_instance_unified_hybrid(
return instances[chosen_idx], chosen_idx, decision
def pick_instance_unified_v2(
instances: list[InstanceState],
token_ids: list[int] | None,
session_id: str | None,
input_length: int,
affinity: dict[str, int],
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
"""unified_v2 = unified hybrid + selective per-request PD-sep trigger.
Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`.
Stage 2 asks: is there another instance with materially more cache
for this request? If yes, would doing prefill on that instance and
transferring KV to `chosen` for decode be cheaper than just doing
everything on `chosen`?
The cost model compares two scenarios in seconds-of-decode-disruption:
local: same-worker prefill on chosen of (input - chosen.cache_hit)
tokens interferes with chosen.num_decodes co-located decodes.
pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens
(smaller, because src has more cache) interferes with
src.num_decodes co-located decodes, plus we pay RDMA
transfer of src.cache_hit blocks to chosen.
We migrate only when local cost > pd-sep cost + safety margin AND
every hard gate (size, decodes, cache) passes.
Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None
the handler should do local routing on `chosen`. When pd_sep is
(src_inst, src_idx) the handler should do prefill-on-src,
decode-on-chosen via Mooncake.
"""
chosen, chosen_idx, decision = pick_instance_unified_hybrid(
instances, token_ids, session_id, input_length, affinity)
decision["v2_pd_sep"] = False
decision["v2_decision"] = "local"
decision["v2_reason"] = None
if not token_ids:
decision["v2_reason"] = "no_token_ids"
return chosen, chosen_idx, decision, None
chosen_cache_hit = chosen.estimate_cache_hit(token_ids)
new_local = max(0, input_length - chosen_cache_hit)
# Hard gate 1: prefill must be large enough that interference
# outweighs the fixed RDMA setup cost.
if new_local < SETTINGS.pd_sep_min_new_tokens:
decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})"
return chosen, chosen_idx, decision, None
# Hard gate 2: chosen must have live decodes worth protecting.
if chosen.ongoing_decode_tokens // max(1, SETTINGS.heavy_threshold) < 1 \
and chosen.num_requests < SETTINGS.pd_sep_min_decodes_protected:
# Heuristic for "num_decodes": prefer num_requests as an upper
# bound since we don't track decode count separately at route time.
decision["v2_reason"] = f"chosen_few_decodes ({chosen.num_requests})"
return chosen, chosen_idx, decision, None
# Find best alternative cache source.
best_src_idx, best_src_hit = -1, 0
for i, inst in enumerate(instances):
if i == chosen_idx:
continue
h = inst.estimate_cache_hit(token_ids)
if h > best_src_hit:
best_src_idx, best_src_hit = i, h
# Hard gate 3: src must hold meaningful cache.
if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens:
decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})"
return chosen, chosen_idx, decision, None
# Hard gate 4: src must hold materially more cache than chosen.
if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens:
decision["v2_reason"] = (
f"src_not_meaningfully_more_cache "
f"(src={best_src_hit} chosen={chosen_cache_hit})"
)
return chosen, chosen_idx, decision, None
src = instances[best_src_idx]
new_src = max(0, input_length - best_src_hit)
# Cost-benefit in seconds-of-decode-disruption.
cost_local = estimate_same_worker_interference_s(
new_local, chosen.num_requests)
cost_src_interf = estimate_same_worker_interference_s(
new_src, src.num_requests)
transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token
cost_xfer = estimate_transfer_cost(transfer_bytes)
cost_migrate = cost_src_interf + cost_xfer
decision["v2_chosen_cache_hit"] = chosen_cache_hit
decision["v2_src_idx"] = best_src_idx
decision["v2_src_cache_hit"] = best_src_hit
decision["v2_new_local"] = new_local
decision["v2_new_src"] = new_src
decision["v2_cost_local_s"] = cost_local
decision["v2_cost_src_interf_s"] = cost_src_interf
decision["v2_cost_xfer_s"] = cost_xfer
decision["v2_cost_migrate_s"] = cost_migrate
if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s:
decision["v2_pd_sep"] = True
decision["v2_decision"] = "pd_sep"
decision["v2_reason"] = (
f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s "
f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
)
return chosen, chosen_idx, decision, (src, best_src_idx)
decision["v2_reason"] = (
f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s "
f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
)
return chosen, chosen_idx, decision, None
def _extract_output_token_ids_from_sse(
buffer: str,
chunk: bytes,
@@ -480,8 +683,15 @@ async def lifespan(app: FastAPI):
combined_instances.append(InstanceState(url, bp))
# Bootstrap combined instances for offload (need engine_ids for KV transfer)
-if global_args.offload and bp_list:
+policy = getattr(global_args, 'policy', 'linear')
+needs_bootstrap = global_args.offload or policy == "unified_v2"
+if needs_bootstrap and bp_list:
await init_prefill_bootstrap(combined_instances, app.state.ready)
+elif needs_bootstrap and not bp_list:
+raise RuntimeError(
+f"--policy {policy} requires --bootstrap-ports for KV transfer; "
+"got empty bootstrap list."
+)
else:
app.state.ready.set()
@@ -623,6 +833,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
pre_decision_workers = snapshot_workers(
combined_instances, token_ids, input_length)
pd_sep_v2: tuple[InstanceState, int] | None = None
if policy == "lmetric":
chosen, best_idx = pick_instance_lmetric(
combined_instances, token_ids, session_id, input_length,
@@ -642,6 +853,13 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
breakdown.update(decision)
if session_id:
session_affinity_combined[session_id] = best_idx
elif policy == "unified_v2":
chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
breakdown.update(decision)
if session_id:
session_affinity_combined[session_id] = best_idx
else: # linear (default)
chosen, best_idx = pick_instance(
combined_instances, token_ids, session_id, input_length,
@@ -653,7 +871,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
breakdown.update({
"cache_hit": cache_hit,
"estimated_new_tokens": estimated_new,
-"route_class": "LOCAL",
+"route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2",
"routed_to": chosen.url,
"chosen_idx": best_idx,
"candidate_scores": pre_decision_workers,
@@ -667,14 +885,152 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
"session_id": session_id,
"policy": policy,
"chosen_idx": best_idx,
"v2_pd_sep": pd_sep_v2 is not None,
"workers": pre_decision_workers,
})
if pd_sep_v2 is not None:
src_inst, src_idx = pd_sep_v2
breakdown["v2_src_url"] = src_inst.url
breakdown["v2_src_idx"] = src_idx
return await _handle_combined_pd_sep_v2(
api, req_data, headers, token_ids, input_length,
src_inst, chosen, breakdown,
request_id=request_id)
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
async def _handle_combined_pd_sep_v2(
api, req_data, headers, token_ids, input_length,
src: InstanceState, dst: InstanceState, breakdown: dict,
*, request_id: str,
):
"""Per-request PD-sep among combined instances (unified_v2 path).
src does cached prefill (max_tokens=1) and ships KV to dst via
Mooncake; dst pulls KV and decodes. Both instances must run in
kv_role=kv_both with bootstrap server enabled.
Patch 6.6: the dst streaming call uses a per-chunk read timeout
of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails
the request instead of hanging for 600 s.
"""
if src.bootstrap_port is None:
raise HTTPException(
status_code=500,
detail=(
"unified_v2 PD-sep triggered but src instance "
f"{src.url} has no bootstrap_port; launch with "
"kv_role=kv_both and pass --bootstrap-ports"
),
)
# Reserve load on both endpoints.
src.ongoing_tokens += input_length
src.num_requests += 1
dst.ongoing_tokens += input_length
dst.num_requests += 1
src_load_held = True
dst_load_held = True
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": f"xfer-{request_id}",
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data["min_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
breakdown["t_prefill_sent"] = _time.monotonic()
breakdown["t_prefill_sent_unix"] = _time.time()
try:
resp = await src.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["t_prefill_done_unix"] = _time.time()
resp.raise_for_status()
await resp.aclose()
src.record_prefix(token_ids)
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["t_prefill_done_unix"] = _time.time()
breakdown["prefill_error"] = True
breakdown["error_detail"] = repr(e)[:300]
_breakdown_log.append(breakdown)
# Release only dst here; the finally block below releases src, and
# releasing src in both places would decrement its counters twice.
dst.ongoing_tokens -= input_length
dst.num_requests -= 1
dst_load_held = False
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
finally:
if src_load_held:
src.ongoing_tokens -= input_length
src.num_requests -= 1
src_load_held = False
parsed = urllib.parse.urlparse(str(src.client.base_url))
bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": src.engine_id.get(0, ""),
"transfer_id": f"xfer-{request_id}",
}
breakdown["t_decode_sent"] = _time.monotonic()
breakdown["t_decode_sent_unix"] = _time.time()
xfer_timeout = httpx.Timeout(
connect=10.0,
read=SETTINGS.pd_sep_xfer_timeout_s,
write=10.0,
pool=10.0,
)
async def generate():
nonlocal dst_load_held
first_token = True
sse_buffer = ""
output_token_ids: list[int] = []
try:
async with dst.client.stream(
"POST", api, json=decode_data, headers=headers,
timeout=xfer_timeout,
) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
sse_buffer, chunk)
output_token_ids.extend(new_output_ids)
if first_token:
breakdown["t_first_token"] = _time.monotonic()
breakdown["t_first_token_unix"] = _time.time()
first_token = False
yield chunk
dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
finally:
breakdown["t_done"] = _time.monotonic()
breakdown["t_done_unix"] = _time.time()
if dst_load_held:
dst.ongoing_tokens -= input_length
dst.num_requests -= 1
dst_load_held = False
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
session_id, headers):
"""PD-Sep mode with per-stage breakdown profiling."""
@@ -849,11 +1205,15 @@ def parse_args():
p.add_argument("--bootstrap-ports", type=str, default="",
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
p.add_argument("--policy", type=str, default="linear",
-choices=["linear", "lmetric", "load_only", "sticky", "unified"],
+choices=["linear", "lmetric", "load_only", "sticky",
+"unified", "unified_v2"],
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
"load_only (B3 control: pure min-num_requests), "
"sticky (B3 control: hard session affinity), "
-"or unified (hybrid affinity + LMetric fallback)")
+"unified (hybrid affinity + LMetric fallback), "
+"or unified_v2 (unified + selective per-request PD-sep "
+"via Mooncake; requires --bootstrap-ports and "
+"kv_role=kv_both vLLM launch)")
p.add_argument("--overload-factor", type=float, default=2.0,
help="Break session affinity when instance load > factor * avg")
# The four flags below are accepted for bench.sh backward compatibility but