Files
agentic-kvc/tests/test_proxy_pick.py
Gahow Wang 4b833d33b7 unified_v2.1: relax gates + add unified_kv_both isolation control
v2.0 ran on B3 and triggered PD-sep only 2 / 1214 times (0.2%). The
gates were too conservative; the v2-vs-v1 latency gap (TTFT p90
7.35 -> 8.96 s) is therefore probably attributable to kv_both
always-on overhead, not to the PD-sep mechanism itself. v2.1 has two
fixes plus an isolation control.

Bug fix:
- The "chosen has live decodes worth protecting" gate combined
  num_requests and ongoing_decode_tokens with AND, falling through
  when EITHER was small. Under agentic workloads each worker rarely
  stacks more than 1-2 concurrent requests, so the gate killed 84%
  of v2.0 candidates that reached it. Replace it with a single
  ongoing_decode_tokens == 0 check ("chosen_no_active_decode"), which
  keeps the same intent with much higher recall.

Threshold relaxation (B2 microbench is the calibration source):
- pd_sep_min_new_tokens: 16000 -> 8000 (at 8k new tokens B2 already
  shows TPOT idx 1.9x and TTFT idx 12x, so migrating is clearly worth it)
- pd_sep_min_decodes_protected: 2 -> 1
- pd_sep_min_src_cache_tokens: 8000 -> 4000
- pd_sep_min_extra_cache_tokens: 4000 -> 2000

Isolation control:
- New --policy unified_kv_both option. Uses the exact same picker as
  --policy unified but the vLLMs are launched in kv_role=kv_both
  (the same launch mode unified_v2 requires), but PD-sep never fires.
  Comparing it against unified_v2 attributes any v2 effect to the
  PD-sep branch alone rather than to the always-on kv_both overhead.
- Both unified_kv_both and unified_v2 auto-enable kv_both launch in
  b3_isolated_policy.sh.

Tests:
- Updated the existing "chosen has no decodes" test for the new
  gate name and semantic.
- All 24 proxy tests pass.

Refs: window_1_results/v2_breakdown analysis (88.7% of candidates
caught by old new_local_below_threshold; 84% of the remainder
caught by the old few_decodes gate).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-26 10:40:57 +08:00

456 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Coverage for scripts/cache_aware_proxy: instance pickers, cache LRU (S1), and cost helpers."""
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
# The proxy script under test: scripts/cache_aware_proxy.py in the directory
# two levels above this file (i.e. the parent of tests/).
PROXY_PATH = Path(__file__).resolve().parents[1].joinpath("scripts", "cache_aware_proxy.py")
def _install_stub_modules() -> None:
"""Provide minimal stand-ins for fastapi/uvicorn/httpx so the proxy
module imports cleanly without the full server deps."""
if "uvicorn" not in sys.modules:
sys.modules["uvicorn"] = types.ModuleType("uvicorn")
if "fastapi" not in sys.modules:
fastapi_mod = types.ModuleType("fastapi")
class _FastAPI:
def __init__(self, *a, **kw):
self.state = types.SimpleNamespace()
def post(self, *a, **kw):
def deco(fn): return fn
return deco
def get(self, *a, **kw):
def deco(fn): return fn
return deco
class _HTTPException(Exception):
def __init__(self, status_code=500, detail=""):
self.status_code = status_code
self.detail = detail
class _Request: # not actually instantiated by the routing tests
pass
fastapi_mod.FastAPI = _FastAPI
fastapi_mod.HTTPException = _HTTPException
fastapi_mod.Request = _Request
sys.modules["fastapi"] = fastapi_mod
responses_mod = types.ModuleType("fastapi.responses")
class _StreamingResponse:
def __init__(self, *a, **kw): pass
responses_mod.StreamingResponse = _StreamingResponse
sys.modules["fastapi.responses"] = responses_mod
if "httpx" not in sys.modules:
httpx_mod = types.ModuleType("httpx")
class _AsyncClient:
def __init__(self, *a, **kw): pass
async def aclose(self): pass
class _Limits:
def __init__(self, *a, **kw): pass
httpx_mod.AsyncClient = _AsyncClient
httpx_mod.Limits = _Limits
sys.modules["httpx"] = httpx_mod
@pytest.fixture(scope="module")
def proxy():
    """Load ``scripts/cache_aware_proxy.py`` as a module for the tests below.

    Server dependencies are replaced by stubs (see ``_install_stub_modules``).
    The tests are skipped, not errored, when the proxy file is absent or one
    of its remaining imports is unavailable.

    On teardown, any stub module this fixture registered is removed from
    ``sys.modules`` again. Otherwise later test modules in the same session
    would silently get the fakes instead of the real packages.
    """
    stub_names = ("uvicorn", "fastapi", "fastapi.responses", "httpx")
    absent_before = [name for name in stub_names if name not in sys.modules]
    _install_stub_modules()
    module_name = "cache_aware_proxy"
    try:
        # spec_from_file_location() does not check that the file exists; a
        # missing file would otherwise surface as FileNotFoundError from
        # exec_module() and error every test instead of skipping them.
        if not PROXY_PATH.is_file():
            pytest.skip(f"cannot load proxy module at {PROXY_PATH}")
        spec = importlib.util.spec_from_file_location(module_name, PROXY_PATH)
        if spec is None or spec.loader is None:
            pytest.skip(f"cannot load proxy module at {PROXY_PATH}")
        mod = importlib.util.module_from_spec(spec)
        # Registered before execution (as in the importlib recipe) so code
        # that looks itself up via sys.modules[__module__] while loading works.
        sys.modules[module_name] = mod
        try:
            spec.loader.exec_module(mod)
        except ModuleNotFoundError as exc:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(module_name, None)
            pytest.skip(f"proxy dependency missing: {exc}")
        yield mod
    finally:
        for name in absent_before:
            # Only drop our own stubs: a bare types.ModuleType has
            # __spec__ = None, whereas a real imported package has a spec.
            if getattr(sys.modules.get(name), "__spec__", True) is None:
                del sys.modules[name]
def _make_inst(proxy, url: str, ongoing_tokens: int = 0,
active_p_offloads: int = 0):
inst = proxy.InstanceState(url)
inst.ongoing_tokens = ongoing_tokens
inst.active_p_offloads = active_p_offloads
return inst
def test_record_prefix_evicts_oldest_block(proxy):
    """With room for two blocks, recording a third evicts the first (LRU bound)."""
    inst = proxy.InstanceState("http://x")
    settings = proxy.SETTINGS
    original_capacity = settings.cache_capacity_blocks
    settings.cache_capacity_blocks = 2
    try:
        bs = proxy.BLOCK_SIZE
        # Three single-block prefixes with distinct token values.
        oldest, middle, newest = ([marker] * bs for marker in (1, 2, 3))
        for prefix in (oldest, middle, newest):
            inst.record_prefix(prefix)
        assert len(inst.cached_blocks) == 2
        # Only the first-recorded prefix should be gone.
        assert inst.estimate_cache_hit(oldest) == 0
        for survivor in (middle, newest):
            assert inst.estimate_cache_hit(survivor) == bs
    finally:
        settings.cache_capacity_blocks = original_capacity
def test_estimate_cache_hit_touches_lru(proxy):
    """Looking a block up via estimate_cache_hit refreshes its LRU position."""
    inst = proxy.InstanceState("http://x")
    settings = proxy.SETTINGS
    original_capacity = settings.cache_capacity_blocks
    settings.cache_capacity_blocks = 2
    try:
        bs = proxy.BLOCK_SIZE
        first, second, third = ([marker] * bs for marker in (1, 2, 3))
        inst.record_prefix(first)
        inst.record_prefix(second)
        # A hit on `first` promotes it to MRU, leaving `second` as the LRU.
        assert inst.estimate_cache_hit(first) == bs
        # Recording `third` must therefore evict `second`, not `first`.
        inst.record_prefix(third)
        assert inst.estimate_cache_hit(first) == bs
        assert inst.estimate_cache_hit(second) == 0
    finally:
        settings.cache_capacity_blocks = original_capacity
def test_pick_instance_session_affinity_sticks(proxy):
    """A session pinned to a non-overloaded instance is routed back to it."""
    insts = [_make_inst(proxy, url) for url in ("http://a", "http://b")]
    chosen, idx = proxy.pick_instance(insts, None, "sess1", 1000, {"sess1": 1})
    assert idx == 1
    assert chosen is insts[1]
def test_pick_instance_session_affinity_breaks_on_overload(proxy):
    """A pinned instance far above the average load loses its session affinity."""
    loads = {"http://a": 100, "http://b": 1_000_000, "http://c": 100}
    insts = [_make_inst(proxy, url, ongoing_tokens=tokens)
             for url, tokens in loads.items()]
    chosen, idx = proxy.pick_instance(insts, None, "sess1", 1000, {"sess1": 1})
    # Mean load is ~333k, so B (1M) sits at ~3x the mean -- well above
    # OVERLOAD_FACTOR=2.0 -- and the picker must fall back to a load-aware pick.
    assert idx != 1
    assert chosen is not insts[1]
def test_pick_instance_p_offload_penalty_steers_away(proxy):
    """An instance running offloaded HEAVY prefills is penalised below a lightly loaded one."""
    offloader = _make_inst(proxy, "http://a", ongoing_tokens=0, active_p_offloads=2)
    light = _make_inst(proxy, "http://b", ongoing_tokens=1000)
    chosen, idx = proxy.pick_instance([offloader, light], None, None, 5000, {})
    # 1000 ongoing tokens on B is far below A's 2 * HEAVY_THRESHOLD penalty.
    assert idx == 1
    assert chosen is light
def test_pick_instance_lmetric_picks_lowest_score(proxy):
    """LMetric routes to the instance with the smallest score."""
    idle, busy = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    idle.pending_prefill_tokens, idle.num_requests = 0, 0
    busy.pending_prefill_tokens, busy.num_requests = 5000, 4
    chosen, idx = proxy.pick_instance_lmetric([idle, busy], None, None, 1000, {})
    # Idle scores 1000 * 0 = 0; busy scores (5000 + 1000) * 4.
    assert idx == 0
    assert chosen is idle
def test_pick_instance_lmetric_ignores_session_affinity(proxy):
    """Review #3: the pure --policy lmetric picker never honours session affinity."""
    idle, busy = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    # Affinity points at the busy instance; LMetric must still choose the idle one.
    idle.pending_prefill_tokens, idle.num_requests = 0, 0
    busy.pending_prefill_tokens, busy.num_requests = 5000, 4
    affinity = {"sess1": 1}
    _, idx = proxy.pick_instance_lmetric([idle, busy], None, "sess1", 1000, affinity)
    assert idx == 0
    # The picker must also leave the caller's affinity map untouched.
    assert affinity == {"sess1": 1}
def _record_n_blocks(proxy, inst, n: int) -> list[int]:
"""Record n distinct one-block prefixes on inst; return token_ids covering them."""
block_size = proxy.BLOCK_SIZE
tokens: list[int] = []
for b in range(n):
tokens.extend([1000 + b] * block_size)
inst.record_prefix(tokens)
return tokens
def test_hybrid_high_cache_session_sticks_to_affinity(proxy):
    """Hybrid: keep the affinity instance when its cache_ratio > 0.5 and it is not overloaded."""
    insts = [_make_inst(proxy, url) for url in ("http://a", "http://b")]
    # Two blocks recorded on inst[1] only; the prompt is exactly that prefix.
    tokens = _record_n_blocks(proxy, insts[1], 2)
    chosen, idx, decision = proxy.pick_instance_unified_hybrid(
        insts, tokens, "sess1", len(tokens), {"sess1": 1})
    assert idx == 1 and chosen is insts[1]
    expected = {"decision": "affinity", "affinity_idx": 1, "chosen_idx": 1}
    for key, value in expected.items():
        assert decision[key] == value
    assert decision["affinity_cache_ratio"] > 0.5
    assert decision["tie_break_used"] is False
def test_hybrid_high_cache_breaks_on_overload(proxy):
    """Hybrid: an overloaded affinity instance triggers the LMetric fallback.

    Overload means num_requests > avg * overload_factor. With a realistic
    uncached tail on the prompt, the fallback then steers off the hot instance.
    """
    insts = [_make_inst(proxy, f"http://{name}") for name in "abc"]
    cached = _record_n_blocks(proxy, insts[1], 2)
    # One extra uncached block gives the cached instance a real prefill cost,
    # so the BS multiplier becomes visible; otherwise it would score
    # 0 * BS = 0 no matter how loaded it is.
    tokens = cached + [999_999] * proxy.BLOCK_SIZE
    insts[1].num_requests = 300  # avg = 100; 300 > 100 * 2.0 breaks the gate
    _, idx, decision = proxy.pick_instance_unified_hybrid(
        insts, tokens, "sess1", len(tokens), {"sess1": 1})
    assert decision["decision"] == "lmetric_fallback"
    assert decision["affinity_idx"] == 1
    assert idx != 1, "affinity instance is overloaded; fallback should steer away"
def test_hybrid_low_cache_falls_back(proxy):
    """Hybrid: an affinity instance with cache_ratio <= 0.5 does not keep the session."""
    insts = [_make_inst(proxy, url) for url in ("http://a", "http://b")]
    # Two blocks' worth of tokens that no instance has cached.
    tokens = [1] * (2 * proxy.BLOCK_SIZE)
    _, _, decision = proxy.pick_instance_unified_hybrid(
        insts, tokens, "sess1", len(tokens), {"sess1": 1})
    assert decision["decision"] == "lmetric_fallback"
    assert decision["affinity_cache_ratio"] == 0.0
def test_hybrid_new_session_tie_break_does_not_always_pick_index_0(proxy):
    """Review #4: equal scores (e.g. BS=0 everywhere) must not always resolve to index 0."""
    insts = [_make_inst(proxy, "http://a") for _ in range(3)]
    picked = []
    for _ in range(12):
        # No session and no load: every score is 0, so each pick is a tie.
        _, idx, decision = proxy.pick_instance_unified_hybrid(insts, None, None, 100, {})
        assert decision["decision"] == "lmetric_fallback"
        assert decision["tie_break_used"] is True
        picked.append(idx)
    seen = set(picked)
    assert seen == {0, 1, 2}, f"tie-breaker did not rotate; only saw {seen}"
def test_hybrid_decision_fields_populated(proxy):
    """Review #7: the hybrid decision dict exposes every breakdown field."""
    insts = [_make_inst(proxy, url) for url in ("http://a", "http://b")]
    _, _, decision = proxy.pick_instance_unified_hybrid(insts, None, None, 100, {})
    required = (
        "decision", "affinity_idx", "chosen_idx",
        "affinity_cache_hit", "affinity_cache_ratio", "affinity_num_requests",
        "avg_num_requests", "fallback_score", "tie_break_used",
    )
    missing = [key for key in required if key not in decision]
    assert not missing
def test_settings_has_runtime_knobs(proxy):
    """D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs.

    Also checks that SETTINGS can be mutated at runtime (tests and the
    ``__main__`` override rely on that). The original ``cache_gate_ratio`` is
    restored in a ``finally`` so a failing assertion cannot leak the modified
    value into later tests sharing the module-scoped ``proxy`` fixture.
    """
    s = proxy.SETTINGS
    for field in (
        "heavy_threshold",
        "overload_factor",
        "max_offload_inflight",
        "cache_gate_ratio",
        "prefill_throughput",
        "rdma_overhead_s",
        "cache_capacity_blocks",
    ):
        assert hasattr(s, field), f"SETTINGS missing {field}"
    saved = s.cache_gate_ratio
    try:
        s.cache_gate_ratio = 0.55
        assert proxy.SETTINGS.cache_gate_ratio == 0.55
    finally:
        s.cache_gate_ratio = saved
def test_pick_instance_load_only_picks_min_num_requests(proxy):
    """load_only routes to the instance with the fewest in-flight requests."""
    request_counts = {"http://a": 5, "http://b": 2, "http://c": 8}
    insts = []
    for url, count in request_counts.items():
        inst = _make_inst(proxy, url)
        inst.num_requests = count
        insts.append(inst)
    chosen, idx = proxy.pick_instance_load_only(insts, None, "sess1", 1000, {})
    assert idx == 1
    assert chosen is insts[1]
def test_pick_instance_load_only_ignores_cache_hits(proxy):
    """load_only must not favour an instance just because it already caches the prefix."""
    cached_but_busy, idle = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    prefix = [123] * (4 * proxy.BLOCK_SIZE)
    cached_but_busy.record_prefix(prefix)
    cached_but_busy.num_requests = 10
    idle.num_requests = 0
    _, idx = proxy.pick_instance_load_only(
        [cached_but_busy, idle], prefix, None, len(prefix), {})
    assert idx == 1, "load_only must ignore cache hit on inst[0]"
def test_pick_instance_sticky_first_turn_picks_min_load(proxy):
    """A session's first request goes to the least-loaded instance and gets pinned there."""
    busy, light = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    busy.num_requests, light.num_requests = 10, 2
    affinity = {}
    _, idx = proxy.pick_instance_sticky([busy, light], None, "sess1", 1000, affinity)
    assert idx == 1
    assert affinity == {"sess1": 1}
def test_pick_instance_sticky_subsequent_never_breaks(proxy):
    """Once a session is pinned, sticky keeps it there even under massive overload."""
    pinned, idle = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    affinity = {"sess1": 0}
    pinned.num_requests = 1_000_000
    idle.num_requests = 0
    _, idx = proxy.pick_instance_sticky([pinned, idle], None, "sess1", 1000, affinity)
    assert idx == 0, "sticky must stay even when pinned instance is saturated"
def test_unified_v2_falls_through_when_new_tokens_small(proxy):
    """If post-cache new tokens are below the PD-sep threshold, v2 stays local."""
    insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
    # Tiny, uncached prompt of just 2 blocks -- far below the
    # pd_sep_min_new_tokens threshold (the number is deliberately not
    # hard-coded here because that threshold is a tunable).
    tokens = [1] * (proxy.BLOCK_SIZE * 2)
    _, _, decision, pd_sep = proxy.pick_instance_unified_v2(
        insts, tokens, None, len(tokens), {})
    assert pd_sep is None
    assert decision["v2_decision"] == "local"
    assert "below_threshold" in decision["v2_reason"]
def _setup_v2_scene(proxy, *, chosen_decodes: int, src_cache_blocks: int):
    """Build the 4-instance scene used by the unified_v2 PD-sep tests.

    inst[0] is arranged to win LMetric and carries ``chosen_decodes * 5000``
    ongoing decode tokens; inst[2] optionally caches the first
    ``src_cache_blocks`` blocks of the prompt. Every instance has a non-zero
    ``num_requests`` so LMetric's bs=0 tie-break cannot pick an empty one.

    Returns:
        ``(insts, prefix_tokens)``; the prompt is 128 distinct blocks
        (128 x 512 = 65536 tokens at BLOCK_SIZE=512).
    """
    insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
    bs = proxy.BLOCK_SIZE
    prefix = [2000 + block for block in range(128) for _ in range(bs)]
    if src_cache_blocks > 0:
        insts[2].record_prefix(prefix[: src_cache_blocks * bs])
    chosen = insts[0]
    # Give inst[0] the smallest LMetric score so it becomes `chosen`.
    chosen.num_requests = max(chosen_decodes, 1)
    chosen.pending_prefill_tokens = 0
    chosen.ongoing_decode_tokens = chosen_decodes * 5000
    # Everyone else carries a heavy pending prefill; inst[2] (the cache
    # source) is busy enough that it cannot win LMetric itself.
    for i, n_req in ((1, 20), (2, 30), (3, 20)):
        insts[i].num_requests = n_req
        insts[i].pending_prefill_tokens = 200_000
    return insts, prefix
def test_unified_v2_falls_through_when_no_alt_cache(proxy):
    """Without meaningful cache on any other instance, v2 does not PD-sep."""
    scene, prompt = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=0)
    _, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
        scene, prompt, None, len(prompt), {})
    assert idx == 0
    assert pd_sep is None
    assert "src_cache" in decision["v2_reason"]
def test_unified_v2_triggers_when_src_has_meaningful_cache_and_chosen_has_decodes(proxy):
    """Canonical PD-sep case: a big prefill, live decodes on the chosen
    instance, and the whole prompt cached on another instance."""
    scene, prompt = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=128)
    _, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
        scene, prompt, None, len(prompt), {})
    assert idx == 0
    assert pd_sep is not None, (
        f"expected PD-sep, got reason={decision['v2_reason']}"
    )
    _, src_idx = pd_sep
    assert src_idx == 2
    assert decision["v2_decision"] == "pd_sep"
    assert decision["v2_src_cache_hit"] >= 60000
def test_unified_v2_falls_through_when_chosen_has_no_decodes(proxy):
    """With no ongoing decode work on the chosen instance, PD-sep has nothing to protect."""
    scene, prompt = _setup_v2_scene(proxy, chosen_decodes=0, src_cache_blocks=128)
    _, _, decision, pd_sep = proxy.pick_instance_unified_v2(
        scene, prompt, None, len(prompt), {})
    assert pd_sep is None
    # Gate name since v2.1: "chosen_no_active_decode".
    assert "no_active_decode" in decision["v2_reason"]
def test_estimate_transfer_cost_is_calibrated_function(proxy):
    """RDMA transfer cost has a non-zero floor and grows with payload size."""
    gib = 1024 ** 3
    cost_empty, cost_1gb, cost_10gb = (
        proxy.estimate_transfer_cost(nbytes) for nbytes in (0, gib, 10 * gib))
    assert cost_empty >= 0.2, "should have non-zero floor"
    assert cost_empty < cost_1gb < cost_10gb
    # Calibration check: 10 GB should be roughly 0.3 + 10/2.7 ~= 4.0 s.
    assert 3.0 < cost_10gb < 5.0
def test_estimate_same_worker_interference_grows_with_size(proxy):
    """Interference cost rises with new_tokens (4 decodes) and is zero with no decodes."""
    small, medium, large, huge = (
        proxy.estimate_same_worker_interference_s(size, num_decodes=4)
        for size in (2000, 8000, 20000, 32000))
    assert small < medium < large < huge
    # No live decodes means nothing to interfere with, whatever the size.
    assert proxy.estimate_same_worker_interference_s(32000, num_decodes=0) == 0.0
def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
    """M2: the P-offload penalty follows SETTINGS.heavy_threshold at call time."""
    inst = proxy.InstanceState("http://x")
    inst.active_p_offloads = 3
    settings = proxy.SETTINGS
    original_threshold = settings.heavy_threshold
    try:
        for threshold, expected_penalty in ((10000, 30000), (50000, 150000)):
            settings.heavy_threshold = threshold
            assert proxy._p_offload_penalty(inst) == expected_penalty
    finally:
        settings.heavy_threshold = original_threshold