Files
agentic-kvc/tests/test_proxy_pick.py
Gahow Wang 19f69a9d2e unified_v2: selective per-request PD-sep via Mooncake (E3+E4)
Adds a sixth routing policy --policy unified_v2 that wraps the
existing unified hybrid picker with a selective PD-sep branch.
When all of the following hold, a request is split prefill-on-src,
decode-on-chosen via Mooncake kv_role=kv_both transfer:

  1. new_local = input_length - chosen.cache_hit > 16k
     (B2 microbench shows same-worker TTFT idx >= 3x from this size up)
  2. chosen has live decodes worth protecting (>= 2 in-flight)
  3. some other instance holds materially more cache for this prefix
     (>= 8k tokens, and >= 4k more than chosen)
  4. cost(src_interference + RDMA xfer) + 0.2s margin < cost(chosen_interference)

The cost model is the audit-blessed shape from E1's post-mortem:
- gate on new_tokens (post-cache), NOT input_length (the old PUSH gate)
- bind to a single transfer mechanism (kv_both peer-to-peer pull)
- realistic RDMA cost as a function of bytes: 0.3s base +
  bytes / 2.7 GB/s (calibrated against contention_16s_elastic p50)
- both source and target decode counts considered

E2 mechanism-level patches not yet applied (this commit is policy-only).
Patches 6.2 / 6.3 / 6.5 remain on the table. Patch 6.6 (per-request
xfer timeout, 60s default) is implemented on the proxy side as an
httpx per-chunk read timeout on the dst streaming call, so a stuck
KV transfer fails the request instead of hanging for 600s.

cache_aware_proxy.py:
- Settings: kv_bytes_per_token, prefill_throughput_kv_both,
  rdma_base_overhead_s, rdma_effective_gb_per_s, pd_sep_* gating knobs
- estimate_transfer_cost(bytes) replaces the constant rdma_overhead_s
- estimate_same_worker_interference_s(new_tokens, num_decodes) reads off
  the B2 penalty curve in 4 bins
- pick_instance_unified_v2: inherits unified, returns extra
  (src_inst, src_idx) tuple when PD-sep wins the cost compare
- _handle_combined_pd_sep_v2: prefill on src (do_remote_decode=True,
  max_tokens=1), Mooncake xfer, decode-stream on dst with httpx
  Timeout(read=pd_sep_xfer_timeout_s)
- --policy unified_v2 added to argparse choices
- lifespan auto-runs init_prefill_bootstrap when policy is unified_v2

b3_isolated_policy.sh:
- ENABLE_KV_BOTH env var, auto-set when POLICY=unified_v2, threads
  kv_role=kv_both + VLLM_MOONCAKE_BOOTSTRAP_PORT to vllm and
  --bootstrap-ports to the proxy

Tests: 8 new unit tests cover the gating predicates and the cost
estimators; all 32 proxy tests still pass.

Refs: E1 (PUSH post-mortem) + E2 (Mooncake audit) reports.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-26 09:25:45 +08:00

456 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unit tests for scripts/cache_aware_proxy.py routing and cache bookkeeping (S1+).

Covers the InstanceState prefix-cache LRU, the pick_instance* routing policies
(default, lmetric, unified hybrid, load_only, sticky, and unified_v2 selective
PD-sep), the unified_v2 cost estimators, and the runtime SETTINGS knobs.
The proxy is loaded directly from its file path, with fastapi/uvicorn/httpx
replaced by stubs when they are not already imported, so the routing logic can
be exercised without the full server dependencies.
"""
from __future__ import annotations

import importlib.util
import sys
import types
from pathlib import Path

import pytest

# Location of the module under test: <repo>/scripts/cache_aware_proxy.py,
# resolved relative to the directory that contains this tests/ folder.
PROXY_PATH = Path(__file__).resolve().parent.parent / "scripts" / "cache_aware_proxy.py"
def _install_stub_modules() -> None:
    """Register minimal fakes for uvicorn, fastapi and httpx in sys.modules.

    A fake is only registered when that module name is not already present in
    ``sys.modules``. Each fake exposes just enough surface (app class with
    route decorators, exception/request/response classes, HTTP client and
    limits classes) for the proxy module to import without the full server
    dependencies.
    """
    registry = sys.modules
    if "uvicorn" not in registry:
        registry["uvicorn"] = types.ModuleType("uvicorn")
    if "fastapi" not in registry:
        _register_fastapi_stub(registry)
    if "httpx" not in registry:
        _register_httpx_stub(registry)


def _register_fastapi_stub(registry) -> None:
    """Put fake ``fastapi`` and ``fastapi.responses`` modules into *registry*."""

    def _route(self, *a, **kw):
        # Route decorators hand the endpoint function back unchanged.
        return lambda fn: fn

    class _FastAPI:
        def __init__(self, *a, **kw):
            self.state = types.SimpleNamespace()

        post = _route
        get = _route

    class _HTTPException(Exception):
        def __init__(self, status_code=500, detail=""):
            self.status_code = status_code
            self.detail = detail

    class _Request:  # never instantiated by the routing tests
        pass

    class _StreamingResponse:
        def __init__(self, *a, **kw):
            pass

    app_mod = types.ModuleType("fastapi")
    app_mod.FastAPI = _FastAPI
    app_mod.HTTPException = _HTTPException
    app_mod.Request = _Request
    registry["fastapi"] = app_mod

    responses = types.ModuleType("fastapi.responses")
    responses.StreamingResponse = _StreamingResponse
    registry["fastapi.responses"] = responses


def _register_httpx_stub(registry) -> None:
    """Put a fake ``httpx`` module (AsyncClient, Limits) into *registry*."""

    class _AsyncClient:
        def __init__(self, *a, **kw):
            pass

        async def aclose(self):
            pass

    class _Limits:
        def __init__(self, *a, **kw):
            pass

    client_mod = types.ModuleType("httpx")
    client_mod.AsyncClient = _AsyncClient
    client_mod.Limits = _Limits
    registry["httpx"] = client_mod
@pytest.fixture(scope="module")
def proxy():
    """Load scripts/cache_aware_proxy.py as a module for this test file.

    Server dependencies are replaced by the stubs from ``_install_stub_modules``.
    On module teardown, the stub entries this fixture added and the
    ``cache_aware_proxy`` entry are removed from ``sys.modules`` again (or the
    prior ``cache_aware_proxy`` entry is restored). This stops other test
    modules in the same pytest session from silently importing the stubs, or
    this copy of the proxy, in place of the real packages.

    Skips the dependent tests when the proxy file cannot be loaded or one of
    its non-stubbed imports is missing.
    """
    absent = object()  # distinguishes "no entry" from any stored value (incl. None)
    previous_proxy = sys.modules.get("cache_aware_proxy", absent)
    already_loaded = set(sys.modules)
    _install_stub_modules()
    # Only the names the stub installer actually added are ours to remove later.
    stubbed = set(sys.modules) - already_loaded
    try:
        spec = importlib.util.spec_from_file_location("cache_aware_proxy", PROXY_PATH)
        if spec is None or spec.loader is None:
            pytest.skip(f"cannot load proxy module at {PROXY_PATH}")
        mod = importlib.util.module_from_spec(spec)
        # Registered before exec_module so anything that looks the module up
        # by name while it is executing can find it.
        sys.modules["cache_aware_proxy"] = mod
        try:
            spec.loader.exec_module(mod)
        except ModuleNotFoundError as exc:
            pytest.skip(f"proxy dependency missing: {exc}")
        yield mod
    finally:
        for name in stubbed:
            sys.modules.pop(name, None)
        if previous_proxy is absent:
            sys.modules.pop("cache_aware_proxy", None)
        else:
            sys.modules["cache_aware_proxy"] = previous_proxy
def _make_inst(proxy, url: str, ongoing_tokens: int = 0,
               active_p_offloads: int = 0):
    """Create ``proxy.InstanceState(url)`` with preset load counters.

    Sets ``ongoing_tokens`` and ``active_p_offloads`` (in that order) on the
    fresh instance and returns it.
    """
    state = proxy.InstanceState(url)
    for attr, value in (("ongoing_tokens", ongoing_tokens),
                        ("active_p_offloads", active_p_offloads)):
        setattr(state, attr, value)
    return state
def test_record_prefix_evicts_oldest_block(proxy):
    """Filling cached_blocks past capacity drops the least-recently-used block."""
    state = proxy.InstanceState("http://x")
    original_capacity = proxy.SETTINGS.cache_capacity_blocks
    proxy.SETTINGS.cache_capacity_blocks = 2
    try:
        bs = proxy.BLOCK_SIZE
        # One single-block prefix per fill value; the first recorded is the victim.
        oldest, middle, newest = ([fill] * bs for fill in (1, 2, 3))
        for prefix in (oldest, middle, newest):
            state.record_prefix(prefix)
        assert len(state.cached_blocks) == 2
        assert state.estimate_cache_hit(oldest) == 0  # evicted
        assert state.estimate_cache_hit(middle) == bs
        assert state.estimate_cache_hit(newest) == bs
    finally:
        proxy.SETTINGS.cache_capacity_blocks = original_capacity
def test_estimate_cache_hit_touches_lru(proxy):
    """Looking up a cached block promotes it to most-recently-used."""
    state = proxy.InstanceState("http://x")
    original_capacity = proxy.SETTINGS.cache_capacity_blocks
    proxy.SETTINGS.cache_capacity_blocks = 2
    try:
        bs = proxy.BLOCK_SIZE
        first, second, third = [1] * bs, [2] * bs, [3] * bs
        state.record_prefix(first)
        state.record_prefix(second)
        # A hit on `first` makes it MRU, leaving `second` as the LRU entry.
        assert state.estimate_cache_hit(first) == bs
        # Recording `third` must therefore evict `second`, not `first`.
        state.record_prefix(third)
        assert state.estimate_cache_hit(first) == bs
        assert state.estimate_cache_hit(second) == 0
    finally:
        proxy.SETTINGS.cache_capacity_blocks = original_capacity
def test_pick_instance_session_affinity_sticks(proxy):
    """A session pinned to an idle instance is routed back to that instance."""
    pool = [_make_inst(proxy, url) for url in ("http://a", "http://b")]
    chosen, idx = proxy.pick_instance(pool, None, "sess1", 1000, {"sess1": 1})
    assert idx == 1
    assert chosen is pool[1]
def test_pick_instance_session_affinity_breaks_on_overload(proxy):
    """A heavily overloaded pinned instance is abandoned for a load-aware pick."""
    loads = {"http://a": 100, "http://b": 1_000_000, "http://c": 100}
    pool = [_make_inst(proxy, url, ongoing_tokens=tok) for url, tok in loads.items()]
    chosen, idx = proxy.pick_instance(pool, None, "sess1", 1000, {"sess1": 1})
    # Mean load is ~333k; b's 1M is ~3x that, above an overload factor of 2.0.
    assert idx != 1
    assert chosen is not pool[1]
def test_pick_instance_p_offload_penalty_steers_away(proxy):
    """Instances busy with offloaded HEAVY prefills are penalized by the picker."""
    offloading = _make_inst(proxy, "http://a", ongoing_tokens=0, active_p_offloads=2)
    lightly_loaded = _make_inst(proxy, "http://b", ongoing_tokens=1000)
    chosen, idx = proxy.pick_instance([offloading, lightly_loaded], None, None, 5000, {})
    # b's 1000 ongoing tokens are far below a's P-offload penalty
    # (2 offloads * SETTINGS.heavy_threshold).
    assert idx == 1
    assert chosen is lightly_loaded
def test_pick_instance_lmetric_picks_lowest_score(proxy):
    """LMetric routes to the instance with the smallest load score."""
    idle, busy = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    idle.pending_prefill_tokens, idle.num_requests = 0, 0
    busy.pending_prefill_tokens, busy.num_requests = 5000, 4
    chosen, idx = proxy.pick_instance_lmetric([idle, busy], None, None, 1000, {})
    # idle scores 1000 * 0 = 0, busy scores (5000 + 1000) * 4.
    assert idx == 0
    assert chosen is idle
def test_pick_instance_lmetric_ignores_session_affinity(proxy):
    """Review #3: the pure --policy lmetric picker must not honor affinity."""
    idle, busy = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    # `busy` is far more loaded, yet it is the instance the session is pinned to.
    idle.pending_prefill_tokens, idle.num_requests = 0, 0
    busy.pending_prefill_tokens, busy.num_requests = 5000, 4
    pins = {"sess1": 1}
    _, idx = proxy.pick_instance_lmetric([idle, busy], None, "sess1", 1000, pins)
    assert idx == 0
    # The picker must also leave the affinity mapping untouched.
    assert pins == {"sess1": 1}
def _record_n_blocks(proxy, inst, n: int) -> list[int]:
    """Grow an n-block prefix on *inst*, recording it after each appended block.

    Block ``b`` (0-based) is filled with token id ``1000 + b``. ``record_prefix``
    is called once per appended block, so it sees prefixes of 1..n blocks. The
    same list object is extended in place and returned as the full n-block
    token list.
    """
    bs = proxy.BLOCK_SIZE
    prefix: list[int] = []
    for token_id in range(1000, 1000 + n):
        prefix += [token_id] * bs
        inst.record_prefix(prefix)
    return prefix
def test_hybrid_high_cache_session_sticks_to_affinity(proxy):
    """Hybrid: well-cached, non-overloaded affinity instance is kept (stick)."""
    pool = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
    pinned = pool[1]
    tokens = _record_n_blocks(proxy, pinned, 2)  # both blocks cached on `pinned`
    chosen, idx, decision = proxy.pick_instance_unified_hybrid(
        pool, tokens, "sess1", len(tokens), {"sess1": 1})
    assert idx == 1
    assert chosen is pinned
    expected = {"decision": "affinity", "affinity_idx": 1, "chosen_idx": 1}
    for key, value in expected.items():
        assert decision[key] == value, key
    assert decision["affinity_cache_ratio"] > 0.5
    assert decision["tie_break_used"] is False
def test_hybrid_high_cache_breaks_on_overload(proxy):
    """Hybrid: an overloaded affinity instance triggers the LMetric fallback.

    Affinity is broken when its num_requests exceeds avg * overload_factor; with
    a realistic uncached tail, the LMetric fallback then steers away from the
    hot instance.
    """
    pool = [_make_inst(proxy, f"http://{name}") for name in "abc"]
    hot = pool[1]
    cached = _record_n_blocks(proxy, hot, 2)
    # Append an uncached block so the hot instance also pays a real prefill
    # cost; otherwise its cached score is 0 * BS = 0 however loaded it is.
    tokens = [*cached, *([999_999] * proxy.BLOCK_SIZE)]
    hot.num_requests = 300  # avg = 100, and 300 > 100 * 2.0 breaks the gate
    _, idx, decision = proxy.pick_instance_unified_hybrid(
        pool, tokens, "sess1", len(tokens), {"sess1": 1})
    assert decision["decision"] == "lmetric_fallback"
    assert decision["affinity_idx"] == 1
    assert idx != 1, "affinity instance is overloaded; fallback should steer away"
def test_hybrid_low_cache_falls_back(proxy):
    """Hybrid: affinity instance with cache_ratio <= 0.5 -> LMetric fallback."""
    pool = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
    # Two blocks of a prompt that no instance has cached.
    tokens = [1] * (2 * proxy.BLOCK_SIZE)
    _, _, decision = proxy.pick_instance_unified_hybrid(
        pool, tokens, "sess1", len(tokens), {"sess1": 1})
    assert decision["decision"] == "lmetric_fallback"
    assert decision["affinity_cache_ratio"] == 0.0
def test_hybrid_new_session_tie_break_does_not_always_pick_index_0(proxy):
    """Review #4: ties between instances (e.g. BS=0) must be broken by rotation."""
    pool = [_make_inst(proxy, "http://a") for _ in range(3)]
    seen = set()
    # No session and all instances empty: every score is 0, so every call ties.
    for _attempt in range(12):
        _, idx, decision = proxy.pick_instance_unified_hybrid(pool, None, None, 100, {})
        seen.add(idx)
        assert decision["decision"] == "lmetric_fallback"
        assert decision["tie_break_used"] is True
    assert seen == {0, 1, 2}, f"tie-breaker did not rotate; only saw {seen}"
def test_hybrid_decision_fields_populated(proxy):
    """Review #7: the decision dict exposes its full score breakdown."""
    pool = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
    _, _, decision = proxy.pick_instance_unified_hybrid(pool, None, None, 100, {})
    required = (
        "decision", "affinity_idx", "chosen_idx",
        "affinity_cache_hit", "affinity_cache_ratio", "affinity_num_requests",
        "avg_num_requests", "fallback_score", "tie_break_used",
    )
    missing = [key for key in required if key not in decision.keys()]
    assert not missing
def test_settings_has_runtime_knobs(proxy):
    """D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs.

    Also checks that the shared SETTINGS object can be mutated at runtime (tests
    and the ``__main__`` override rely on this). The mutated value is restored
    in ``finally`` so a failing assertion cannot leak into later tests that
    share the module-scoped ``proxy`` fixture.
    """
    s = proxy.SETTINGS
    for field in (
        "heavy_threshold",
        "overload_factor",
        "max_offload_inflight",
        "cache_gate_ratio",
        "prefill_throughput",
        "rdma_overhead_s",
        "cache_capacity_blocks",
    ):
        assert hasattr(s, field), f"SETTINGS missing {field}"
    saved = s.cache_gate_ratio
    try:
        s.cache_gate_ratio = 0.55
        assert proxy.SETTINGS.cache_gate_ratio == 0.55
    finally:
        s.cache_gate_ratio = saved
def test_pick_instance_load_only_picks_min_num_requests(proxy):
    """load_only routes to the instance with the fewest in-flight requests."""
    pool = [_make_inst(proxy, f"http://{c}") for c in "abc"]
    for inst, reqs in zip(pool, (5, 2, 8)):
        inst.num_requests = reqs
    chosen, idx = proxy.pick_instance_load_only(pool, None, "sess1", 1000, {})
    assert idx == 1
    assert chosen is pool[1]
def test_pick_instance_load_only_ignores_cache_hits(proxy):
    """load_only must not prefer an instance just because it holds the prefix."""
    cached_but_busy = _make_inst(proxy, "http://a")
    cold_but_idle = _make_inst(proxy, "http://b")
    prefix = [123] * (4 * proxy.BLOCK_SIZE)
    cached_but_busy.record_prefix(prefix)
    cached_but_busy.num_requests = 10
    cold_but_idle.num_requests = 0
    _, idx = proxy.pick_instance_load_only(
        [cached_but_busy, cold_but_idle], prefix, None, len(prefix), {})
    assert idx == 1, "load_only must ignore cache hit on inst[0]"
def test_pick_instance_sticky_first_turn_picks_min_load(proxy):
    """sticky: a new session goes to the least-loaded instance and is pinned there."""
    busy, light = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    busy.num_requests, light.num_requests = 10, 2
    pins: dict[str, int] = {}
    _, idx = proxy.pick_instance_sticky([busy, light], None, "sess1", 1000, pins)
    assert idx == 1
    assert pins == {"sess1": 1}
def test_pick_instance_sticky_subsequent_never_breaks(proxy):
    """Once pinned, sticky keeps routing to the same instance even when saturated."""
    saturated, idle = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    pins = {"sess1": 0}
    saturated.num_requests = 1_000_000
    idle.num_requests = 0
    _, idx = proxy.pick_instance_sticky([saturated, idle], None, "sess1", 1000, pins)
    assert idx == 0, "sticky must stay even when pinned instance is saturated"
def test_unified_v2_falls_through_when_new_tokens_small(proxy):
    """v2 stays local when the post-cache new-token count is under the threshold."""
    pool = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
    # Two uncached blocks (1024 tokens at a 512-token block size), far below
    # the 16k new-token gate.
    tokens = [1] * (2 * proxy.BLOCK_SIZE)
    _, _, decision, pd_sep = proxy.pick_instance_unified_v2(
        pool, tokens, None, len(tokens), {})
    assert pd_sep is None
    assert decision["v2_decision"] == "local"
    assert "below_threshold" in decision["v2_reason"]
def _setup_v2_scene(proxy, *, chosen_decodes: int, src_cache_blocks: int):
    """Build the 4-instance scene shared by the unified_v2 PD-sep tests.

    inst[0] is arranged to win the LMetric pick. inst[2] optionally holds the
    first ``src_cache_blocks`` blocks of a 128-block prompt. Every instance gets
    a non-zero num_requests so LMetric's bs=0 tie-break cannot choose an empty
    instance.

    Returns (insts, prefix_tokens).
    """
    insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
    bs = proxy.BLOCK_SIZE
    # 128 distinct blocks with token ids 2000..2127
    # (128 * 512 = 65536 tokens at a 512-token block size).
    prefix = [2000 + b for b in range(128) for _ in range(bs)]
    if src_cache_blocks > 0:
        insts[2].record_prefix(prefix[: src_cache_blocks * bs])
    chosen, src = insts[0], insts[2]
    # Give inst[0] the lowest LMetric score so it becomes the chosen instance.
    chosen.num_requests = max(chosen_decodes, 1)
    chosen.pending_prefill_tokens = 0
    chosen.ongoing_decode_tokens = chosen_decodes * 5000
    # Everyone else carries heavy pending prefill; src is busiest of all so it
    # cannot win LMetric on its own.
    for inst, reqs in ((insts[1], 20), (src, 30), (insts[3], 20)):
        inst.num_requests = reqs
        inst.pending_prefill_tokens = 200_000
    return insts, prefix
def test_unified_v2_falls_through_when_no_alt_cache(proxy):
    """No instance holds a meaningful share of the prompt -> no PD-sep."""
    insts, prompt = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=0)
    _, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
        insts, prompt, None, len(prompt), {})
    assert idx == 0
    assert pd_sep is None
    assert "src_cache" in decision["v2_reason"]
def test_unified_v2_triggers_when_src_has_meaningful_cache_and_chosen_has_decodes(proxy):
    """Classic v2-win case: big prefill, chosen has decodes, alt has cache.

    Also checks that the returned objects are consistent with the indices:
    the chosen (decode) instance is insts[idx], and the PD-sep source
    instance is insts[src_idx].
    """
    insts, prefix = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=128)
    chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
        insts, prefix, None, len(prefix), {})
    assert idx == 0
    assert chosen is insts[0]
    assert pd_sep is not None, (
        f"expected PD-sep, got reason={decision['v2_reason']}"
    )
    src, src_idx = pd_sep
    assert src_idx == 2
    assert src is insts[src_idx]
    assert decision["v2_decision"] == "pd_sep"
    # Nearly all of the 128-block prompt should be reported as cached on src.
    assert decision["v2_src_cache_hit"] >= 60000
def test_unified_v2_falls_through_when_chosen_has_no_decodes(proxy):
    """With no decodes on the chosen instance there is nothing for PD-sep to protect."""
    insts, prompt = _setup_v2_scene(proxy, chosen_decodes=0, src_cache_blocks=128)
    _, _, decision, pd_sep = proxy.pick_instance_unified_v2(
        insts, prompt, None, len(prompt), {})
    assert pd_sep is None
    assert "few_decodes" in decision["v2_reason"]
def test_estimate_transfer_cost_is_calibrated_function(proxy):
    """RDMA transfer cost has a non-zero floor and grows with payload size."""
    gib = 1024 ** 3
    empty, one_gib, ten_gib = (
        proxy.estimate_transfer_cost(size) for size in (0, gib, 10 * gib)
    )
    assert empty >= 0.2, "should have non-zero floor"
    assert one_gib > empty
    assert ten_gib > one_gib
    # 10 GiB should land near 0.3 s base + 10 / 2.7 s, i.e. roughly 4 s.
    assert 3.0 < ten_gib < 5.0
def test_estimate_same_worker_interference_grows_with_size(proxy):
    """Interference cost rises with new_tokens across the sampled prefill sizes."""
    estimate = proxy.estimate_same_worker_interference_s
    costs = [estimate(size, num_decodes=4) for size in (2000, 8000, 20000, 32000)]
    assert costs[0] < costs[1] < costs[2] < costs[3]
    # No decodes to interfere with -> zero cost, whatever the prefill size.
    assert estimate(32000, num_decodes=0) == 0.0
def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
    """M2: the P-offload penalty scales with SETTINGS.heavy_threshold."""
    state = proxy.InstanceState("http://x")
    state.active_p_offloads = 3
    original = proxy.SETTINGS.heavy_threshold
    try:
        for threshold, expected in ((10000, 30000), (50000, 150000)):
            proxy.SETTINGS.heavy_threshold = threshold
            assert proxy._p_offload_penalty(state) == expected
    finally:
        proxy.SETTINGS.heavy_threshold = original