unified_v2: selective per-request PD-sep via Mooncake (E3+E4)
Adds a sixth routing policy --policy unified_v2 that wraps the
existing unified hybrid picker with a selective PD-sep branch.
When all of the following hold, a request is split prefill-on-src,
decode-on-chosen via Mooncake kv_role=kv_both transfer:
1. new_local = input_length - chosen.cache_hit > 16k
   (B2 microbench shows same-worker TTFT idx >= 3x from this size up)
2. chosen has live decodes worth protecting (>= 2 in-flight)
3. some other instance holds materially more cache for this prefix
   (>= 8k tokens, and >= 4k more than chosen)
4. cost(src_interference + RDMA xfer) + 0.2s margin < cost(chosen_interference)
The cost model is the audit-blessed shape from E1's post-mortem:
- gate on new_tokens (post-cache), NOT input_length (the old PUSH gate)
- bind to a single transfer mechanism (kv_both peer-to-peer pull)
- realistic RDMA cost as a function of bytes: 0.3s base +
  bytes / 2.7 GB/s (calibrated against contention_16s_elastic p50)
- both source and target decode counts considered
E2 mechanism-level patches not yet applied (this commit is policy-only).
Patches 6.2 / 6.3 / 6.5 remain on the table. Patch 6.6 (per-request
xfer timeout, 60s default) is implemented on the proxy side as an
httpx per-chunk read timeout on the dst streaming call, so a stuck
KV transfer fails the request instead of hanging for 600s.
cache_aware_proxy.py:
- Settings: kv_bytes_per_token, prefill_throughput_kv_both,
  rdma_base_overhead_s, rdma_effective_gb_per_s, pd_sep_* gating knobs
- estimate_transfer_cost(bytes) replaces the constant rdma_overhead_s
- estimate_same_worker_interference_s(new_tokens, num_decodes) reads off
  the B2 penalty curve in 4 bins
- pick_instance_unified_v2: inherits unified, returns extra
  (src_inst, src_idx) tuple when PD-sep wins the cost compare
- _handle_combined_pd_sep_v2: prefill on src (do_remote_decode=True,
  max_tokens=1), Mooncake xfer, decode-stream on dst with httpx
  Timeout(read=pd_sep_xfer_timeout_s)
- --policy unified_v2 added to argparse choices
- lifespan auto-runs init_prefill_bootstrap when policy is unified_v2
b3_isolated_policy.sh:
- ENABLE_KV_BOTH env var, auto-set when POLICY=unified_v2, threads
  kv_role=kv_both + VLLM_MOONCAKE_BOOTSTRAP_PORT to vllm and
  --bootstrap-ports to the proxy
Tests: 8 new unit tests cover the gating predicates and the cost
estimators; all 32 proxy tests still pass.
Refs: E1 (PUSH post-mortem) + E2 (Mooncake audit) reports.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -343,6 +343,104 @@ def test_pick_instance_sticky_subsequent_never_breaks(proxy):
|
||||
assert idx == 0, "sticky must stay even when pinned instance is saturated"
|
||||
|
||||
|
||||
def test_unified_v2_falls_through_when_new_tokens_small(proxy):
    """A prompt well under the PD-sep size gate must be routed locally."""
    instances = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
    # Two blocks of tokens (1024 when BLOCK_SIZE is 512) -- far below the
    # 16k new-token threshold, so the PD-sep branch should not be taken.
    prompt = [1] * (proxy.BLOCK_SIZE * 2)

    outcome = proxy.pick_instance_unified_v2(
        instances, prompt, None, len(prompt), {})
    _chosen, _idx, decision, pd_sep = outcome

    assert pd_sep is None
    assert decision["v2_decision"] == "local"
    assert "below_threshold" in decision["v2_reason"]
|
||||
|
||||
|
||||
def _setup_v2_scene(proxy, *, chosen_decodes: int, src_cache_blocks: int):
    """Assemble the four-instance scene shared by the unified_v2 tests.

    inst[0] is given the lightest load so that it should come out as the
    LMetric winner; inst[2] optionally records the leading
    ``src_cache_blocks`` blocks of the prefix. Every instance ends up with
    a non-zero ``num_requests`` so that LMetric's bs=0 tie-break cannot
    select an empty instance.

    Returns:
        ``(insts, prefix_tokens)`` -- the instance list and a prefix of
        128 blocks, each block filled with its own token id.
    """
    insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
    block_size = proxy.BLOCK_SIZE

    # 128 blocks of block_size tokens each (65536 tokens at block size 512);
    # block b is filled with the single token id 2000 + b.
    prefix = [2000 + b for b in range(128) for _ in range(block_size)]

    if src_cache_blocks > 0:
        cached_len = src_cache_blocks * block_size
        insts[2].record_prefix(prefix[:cached_len])

    # Lightest load goes on inst[0] so it has the smallest LMetric score.
    target = insts[0]
    target.num_requests = max(chosen_decodes, 1)
    target.pending_prefill_tokens = 0
    target.ongoing_decode_tokens = chosen_decodes * 5000

    # The remaining instances are heavily loaded; inst[2] (the cache source)
    # is the busiest so it does not win LMetric itself.
    for peer_idx, peer_requests in ((1, 20), (2, 30), (3, 20)):
        peer = insts[peer_idx]
        peer.num_requests = peer_requests
        peer.pending_prefill_tokens = 200_000

    return insts, prefix
|
||||
|
||||
|
||||
def test_unified_v2_falls_through_when_no_alt_cache(proxy):
    """Without cache recorded on another instance, v2 must not PD-sep."""
    scene = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=0)
    insts, prefix = scene

    outcome = proxy.pick_instance_unified_v2(
        insts, prefix, None, len(prefix), {})
    _chosen, idx, decision, pd_sep = outcome

    assert idx == 0
    assert pd_sep is None
    assert "src_cache" in decision["v2_reason"]
|
||||
|
||||
|
||||
def test_unified_v2_triggers_when_src_has_meaningful_cache_and_chosen_has_decodes(proxy):
    """The v2-win case: large prefill, decodes on chosen, cache on inst[2]."""
    scene = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=128)
    insts, prefix = scene

    outcome = proxy.pick_instance_unified_v2(
        insts, prefix, None, len(prefix), {})
    _chosen, idx, decision, pd_sep = outcome

    assert idx == 0
    # The message is built only on failure, so it surfaces the picker's
    # stated reason for staying local.
    assert pd_sep is not None, (
        f"expected PD-sep, got reason={decision['v2_reason']}"
    )

    _src, src_idx = pd_sep
    assert src_idx == 2
    assert decision["v2_decision"] == "pd_sep"
    assert decision["v2_src_cache_hit"] >= 60000
|
||||
|
||||
|
||||
def test_unified_v2_falls_through_when_chosen_has_no_decodes(proxy):
    """With zero decodes on the chosen instance, v2 must stay local."""
    scene = _setup_v2_scene(proxy, chosen_decodes=0, src_cache_blocks=128)
    insts, prefix = scene

    outcome = proxy.pick_instance_unified_v2(
        insts, prefix, None, len(prefix), {})
    _chosen, _idx, decision, pd_sep = outcome

    assert pd_sep is None
    assert "few_decodes" in decision["v2_reason"]
|
||||
|
||||
|
||||
def test_estimate_transfer_cost_is_calibrated_function(proxy):
    """Transfer cost has a non-zero floor and increases with byte count."""
    gib = 1024 ** 3
    floor_cost = proxy.estimate_transfer_cost(0)
    one_gib_cost = proxy.estimate_transfer_cost(gib)
    ten_gib_cost = proxy.estimate_transfer_cost(10 * gib)

    assert floor_cost >= 0.2, "should have non-zero floor"
    assert one_gib_cost > floor_cost
    assert ten_gib_cost > one_gib_cost
    # Expected magnitude for ~10 GB: about 0.3 + 10/2.7, i.e. roughly 4.0 s.
    assert 3.0 < ten_gib_cost < 5.0
|
||||
|
||||
|
||||
def test_estimate_same_worker_interference_grows_with_size(proxy):
    """Interference cost rises with new_tokens and is zero with no decodes."""
    estimate = proxy.estimate_same_worker_interference_s
    sizes = (2000, 8000, 20000, 32000)
    c1, c2, c3, c4 = [estimate(size, num_decodes=4) for size in sizes]
    assert c1 < c2 < c3 < c4

    # No decodes running means the cost is zero, whatever the prefill size.
    assert estimate(32000, num_decodes=0) == 0.0
|
||||
|
||||
|
||||
def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
|
||||
"""M2: tweaking SETTINGS.heavy_threshold changes the P-offload penalty."""
|
||||
inst = proxy.InstanceState("http://x")
|
||||
|
||||
Reference in New Issue
Block a user