"""Minimal coverage for scripts/cache_aware_proxy pick_instance + cache LRU (S1).""" from __future__ import annotations import importlib.util import sys import types from pathlib import Path import pytest PROXY_PATH = Path(__file__).resolve().parent.parent / "scripts" / "cache_aware_proxy.py" def _install_stub_modules() -> None: """Provide minimal stand-ins for fastapi/uvicorn/httpx so the proxy module imports cleanly without the full server deps.""" if "uvicorn" not in sys.modules: sys.modules["uvicorn"] = types.ModuleType("uvicorn") if "fastapi" not in sys.modules: fastapi_mod = types.ModuleType("fastapi") class _FastAPI: def __init__(self, *a, **kw): self.state = types.SimpleNamespace() def post(self, *a, **kw): def deco(fn): return fn return deco def get(self, *a, **kw): def deco(fn): return fn return deco class _HTTPException(Exception): def __init__(self, status_code=500, detail=""): self.status_code = status_code self.detail = detail class _Request: # not actually instantiated by the routing tests pass fastapi_mod.FastAPI = _FastAPI fastapi_mod.HTTPException = _HTTPException fastapi_mod.Request = _Request sys.modules["fastapi"] = fastapi_mod responses_mod = types.ModuleType("fastapi.responses") class _StreamingResponse: def __init__(self, *a, **kw): pass responses_mod.StreamingResponse = _StreamingResponse sys.modules["fastapi.responses"] = responses_mod if "httpx" not in sys.modules: httpx_mod = types.ModuleType("httpx") class _AsyncClient: def __init__(self, *a, **kw): pass async def aclose(self): pass class _Limits: def __init__(self, *a, **kw): pass httpx_mod.AsyncClient = _AsyncClient httpx_mod.Limits = _Limits sys.modules["httpx"] = httpx_mod @pytest.fixture(scope="module") def proxy(): _install_stub_modules() spec = importlib.util.spec_from_file_location("cache_aware_proxy", PROXY_PATH) if spec is None or spec.loader is None: pytest.skip(f"cannot load proxy module at {PROXY_PATH}") mod = importlib.util.module_from_spec(spec) sys.modules["cache_aware_proxy"] = mod try: spec.loader.exec_module(mod) except ModuleNotFoundError as exc: pytest.skip(f"proxy dependency missing: {exc}") return mod def _make_inst(proxy, url: str, ongoing_tokens: int = 0, active_p_offloads: int = 0): inst = proxy.InstanceState(url) inst.ongoing_tokens = ongoing_tokens inst.active_p_offloads = active_p_offloads return inst def test_record_prefix_evicts_oldest_block(proxy): """LRU bound on cached_blocks must evict the oldest entry once full.""" inst = proxy.InstanceState("http://x") saved = proxy.SETTINGS.cache_capacity_blocks proxy.SETTINGS.cache_capacity_blocks = 2 try: block_size = proxy.BLOCK_SIZE # Three distinct one-block prefixes; first must be evicted. prefix_a = [1] * block_size prefix_b = [2] * block_size prefix_c = [3] * block_size inst.record_prefix(prefix_a) inst.record_prefix(prefix_b) inst.record_prefix(prefix_c) assert len(inst.cached_blocks) == 2 # A should have been evicted. assert inst.estimate_cache_hit(prefix_a) == 0 assert inst.estimate_cache_hit(prefix_b) == block_size assert inst.estimate_cache_hit(prefix_c) == block_size finally: proxy.SETTINGS.cache_capacity_blocks = saved def test_estimate_cache_hit_touches_lru(proxy): """A cache hit must move the block to the MRU position.""" inst = proxy.InstanceState("http://x") saved = proxy.SETTINGS.cache_capacity_blocks proxy.SETTINGS.cache_capacity_blocks = 2 try: block_size = proxy.BLOCK_SIZE a = [1] * block_size b = [2] * block_size c = [3] * block_size inst.record_prefix(a) inst.record_prefix(b) # Touch A so it becomes MRU; B is now LRU. assert inst.estimate_cache_hit(a) == block_size # Insert C: B should be evicted, A should remain. inst.record_prefix(c) assert inst.estimate_cache_hit(a) == block_size assert inst.estimate_cache_hit(b) == 0 finally: proxy.SETTINGS.cache_capacity_blocks = saved def test_pick_instance_session_affinity_sticks(proxy): insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] affinity = {"sess1": 1} chosen, idx = proxy.pick_instance(insts, None, "sess1", 1000, affinity) assert idx == 1 and chosen is insts[1] def test_pick_instance_session_affinity_breaks_on_overload(proxy): """When the pinned instance is heavily overloaded, fallback to load-aware pick.""" insts = [ _make_inst(proxy, "http://a", ongoing_tokens=100), _make_inst(proxy, "http://b", ongoing_tokens=1_000_000), _make_inst(proxy, "http://c", ongoing_tokens=100), ] affinity = {"sess1": 1} chosen, idx = proxy.pick_instance(insts, None, "sess1", 1000, affinity) # avg ~333k; B at 1M is ~3x avg, well above OVERLOAD_FACTOR=2.0 -> fallback. assert idx != 1 assert chosen is not insts[1] def test_pick_instance_p_offload_penalty_steers_away(proxy): """Instances actively running offloaded HEAVY prefills get penalized.""" insts = [ _make_inst(proxy, "http://a", ongoing_tokens=0, active_p_offloads=2), _make_inst(proxy, "http://b", ongoing_tokens=1000), ] chosen, idx = proxy.pick_instance(insts, None, None, 5000, {}) # B's 1000-token load is much smaller than A's 2 * HEAVY_THRESHOLD penalty. assert idx == 1 and chosen is insts[1] def test_pick_instance_lmetric_picks_lowest_score(proxy): insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] insts[0].pending_prefill_tokens = 0 insts[0].num_requests = 0 insts[1].pending_prefill_tokens = 5000 insts[1].num_requests = 4 chosen, idx = proxy.pick_instance_lmetric(insts, None, None, 1000, {}) # Empty instance has score = 1000 * 0 = 0; busy one has (5000+1000)*4. assert idx == 0 and chosen is insts[0] def test_pick_instance_lmetric_ignores_session_affinity(proxy): """Review #3: pure --policy lmetric must remain affinity-free.""" insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] # Make inst[1] look much busier than inst[0]; LMetric must still pick 0 # even though affinity points at 1. insts[0].pending_prefill_tokens = 0 insts[0].num_requests = 0 insts[1].pending_prefill_tokens = 5000 insts[1].num_requests = 4 affinity = {"sess1": 1} chosen, idx = proxy.pick_instance_lmetric(insts, None, "sess1", 1000, affinity) assert idx == 0 # Picker must not mutate the affinity dict either. assert affinity == {"sess1": 1} def _record_n_blocks(proxy, inst, n: int) -> list[int]: """Record n distinct one-block prefixes on inst; return token_ids covering them.""" block_size = proxy.BLOCK_SIZE tokens: list[int] = [] for b in range(n): tokens.extend([1000 + b] * block_size) inst.record_prefix(tokens) return tokens def test_hybrid_high_cache_session_sticks_to_affinity(proxy): """Hybrid: affinity instance with cache_ratio > 0.5 and no overload → stick.""" insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] tokens = _record_n_blocks(proxy, insts[1], 2) # 2 blocks cached on inst[1] affinity = {"sess1": 1} chosen, idx, decision = proxy.pick_instance_unified_hybrid( insts, tokens, "sess1", len(tokens), affinity) assert idx == 1 and chosen is insts[1] assert decision["decision"] == "affinity" assert decision["affinity_idx"] == 1 assert decision["chosen_idx"] == 1 assert decision["affinity_cache_ratio"] > 0.5 assert decision["tie_break_used"] is False def test_hybrid_high_cache_breaks_on_overload(proxy): """Hybrid: affinity num_requests > avg * overload_factor → fall back to LMetric, and with realistic new-token tail the LMetric fallback steers off the hot instance.""" insts = [ _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b"), _make_inst(proxy, "http://c"), ] cached = _record_n_blocks(proxy, insts[1], 2) # Append one more uncached block so LMetric sees a real prefill cost on the # cached instance too (BS multiplier becomes visible). Without this, the # cached instance scores 0 * BS = 0 regardless of how loaded it is. tokens = cached + [999_999] * proxy.BLOCK_SIZE insts[1].num_requests = 300 # avg = 100; 300 > 100 * 2.0 ✓ breaks the gate affinity = {"sess1": 1} chosen, idx, decision = proxy.pick_instance_unified_hybrid( insts, tokens, "sess1", len(tokens), affinity) assert decision["decision"] == "lmetric_fallback" assert decision["affinity_idx"] == 1 assert idx != 1, "affinity instance is overloaded; fallback should steer away" def test_hybrid_low_cache_falls_back(proxy): """Hybrid: cache_ratio <= 0.5 on affinity → fall back to LMetric.""" insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] tokens = [1] * (proxy.BLOCK_SIZE * 2) # 1024 tokens, nothing cached anywhere affinity = {"sess1": 1} chosen, idx, decision = proxy.pick_instance_unified_hybrid( insts, tokens, "sess1", len(tokens), affinity) assert decision["decision"] == "lmetric_fallback" assert decision["affinity_cache_ratio"] == 0.0 def test_hybrid_new_session_tie_break_does_not_always_pick_index_0(proxy): """Review #4: when all instances tie (e.g. BS=0), tie-break must rotate.""" insts = [_make_inst(proxy, "http://a") for _ in range(3)] seen = set() for _ in range(12): # No session_id, all empty → score = 0 for everyone → ties → rotate. chosen, idx, decision = proxy.pick_instance_unified_hybrid( insts, None, None, 100, {}) seen.add(idx) assert decision["decision"] == "lmetric_fallback" assert decision["tie_break_used"] is True assert seen == {0, 1, 2}, f"tie-breaker did not rotate; only saw {seen}" def test_hybrid_decision_fields_populated(proxy): """Review #7: decision dict must carry the breakdown fields.""" insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] _, _, decision = proxy.pick_instance_unified_hybrid( insts, None, None, 100, {}) expected_keys = { "decision", "affinity_idx", "chosen_idx", "affinity_cache_hit", "affinity_cache_ratio", "affinity_num_requests", "avg_num_requests", "fallback_score", "tie_break_used", } assert expected_keys.issubset(decision.keys()) def test_settings_has_runtime_knobs(proxy): """D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs.""" s = proxy.SETTINGS for field in ( "heavy_threshold", "overload_factor", "max_offload_inflight", "cache_gate_ratio", "prefill_throughput", "rdma_overhead_s", "cache_capacity_blocks", ): assert hasattr(s, field), f"SETTINGS missing {field}" # Runtime mutability matters for tests + __main__ override. saved = s.cache_gate_ratio s.cache_gate_ratio = 0.55 assert proxy.SETTINGS.cache_gate_ratio == 0.55 s.cache_gate_ratio = saved def test_pick_instance_load_only_picks_min_num_requests(proxy): insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b"), _make_inst(proxy, "http://c")] insts[0].num_requests = 5 insts[1].num_requests = 2 insts[2].num_requests = 8 chosen, idx = proxy.pick_instance_load_only(insts, None, "sess1", 1000, {}) assert idx == 1 and chosen is insts[1] def test_pick_instance_load_only_ignores_cache_hits(proxy): insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] block_size = proxy.BLOCK_SIZE prefix = [123] * (block_size * 4) insts[0].record_prefix(prefix) insts[0].num_requests = 10 insts[1].num_requests = 0 chosen, idx = proxy.pick_instance_load_only(insts, prefix, None, len(prefix), {}) assert idx == 1, "load_only must ignore cache hit on inst[0]" def test_pick_instance_sticky_first_turn_picks_min_load(proxy): insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] insts[0].num_requests = 10 insts[1].num_requests = 2 affinity = {} chosen, idx = proxy.pick_instance_sticky(insts, None, "sess1", 1000, affinity) assert idx == 1 assert affinity == {"sess1": 1} def test_pick_instance_sticky_subsequent_never_breaks(proxy): """Once assigned, sticky must never re-route even under massive overload.""" insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] affinity = {"sess1": 0} insts[0].num_requests = 1_000_000 insts[1].num_requests = 0 chosen, idx = proxy.pick_instance_sticky(insts, None, "sess1", 1000, affinity) assert idx == 0, "sticky must stay even when pinned instance is saturated" def test_p_offload_penalty_uses_settings_heavy_threshold(proxy): """M2: tweaking SETTINGS.heavy_threshold changes the P-offload penalty.""" inst = proxy.InstanceState("http://x") inst.active_p_offloads = 3 saved = proxy.SETTINGS.heavy_threshold try: proxy.SETTINGS.heavy_threshold = 10000 assert proxy._p_offload_penalty(inst) == 30000 proxy.SETTINGS.heavy_threshold = 50000 assert proxy._p_offload_penalty(inst) == 150000 finally: proxy.SETTINGS.heavy_threshold = saved