"""Tests for A2 proxy instrumentation: worker snapshot + request_id passthrough.""" from __future__ import annotations import importlib.util import sys import types from pathlib import Path import pytest PROXY_PATH = Path(__file__).resolve().parent.parent / "scripts" / "cache_aware_proxy.py" def _install_stub_modules() -> None: if "uvicorn" not in sys.modules: sys.modules["uvicorn"] = types.ModuleType("uvicorn") if "fastapi" not in sys.modules: fastapi_mod = types.ModuleType("fastapi") class _FastAPI: def __init__(self, *a, **kw): self.state = types.SimpleNamespace() def post(self, *a, **kw): def deco(fn): return fn return deco def get(self, *a, **kw): def deco(fn): return fn return deco class _HTTPException(Exception): def __init__(self, status_code=500, detail=""): self.status_code = status_code self.detail = detail class _Request: pass fastapi_mod.FastAPI = _FastAPI fastapi_mod.HTTPException = _HTTPException fastapi_mod.Request = _Request sys.modules["fastapi"] = fastapi_mod responses_mod = types.ModuleType("fastapi.responses") class _StreamingResponse: def __init__(self, *a, **kw): pass responses_mod.StreamingResponse = _StreamingResponse sys.modules["fastapi.responses"] = responses_mod if "httpx" not in sys.modules: httpx_mod = types.ModuleType("httpx") class _AsyncClient: def __init__(self, *a, **kw): pass async def aclose(self): pass class _Limits: def __init__(self, *a, **kw): pass httpx_mod.AsyncClient = _AsyncClient httpx_mod.Limits = _Limits sys.modules["httpx"] = httpx_mod @pytest.fixture(scope="module") def proxy(): _install_stub_modules() spec = importlib.util.spec_from_file_location("cache_aware_proxy", PROXY_PATH) if spec is None or spec.loader is None: pytest.skip(f"cannot load proxy module at {PROXY_PATH}") mod = importlib.util.module_from_spec(spec) sys.modules["cache_aware_proxy"] = mod try: spec.loader.exec_module(mod) except ModuleNotFoundError as exc: pytest.skip(f"proxy dependency missing: {exc}") return mod def _make_inst(proxy, url, **kw): inst = proxy.InstanceState(url) for k, v in kw.items(): setattr(inst, k, v) return inst def test_snapshot_workers_includes_all_required_fields(proxy): insts = [ _make_inst(proxy, "http://a", ongoing_tokens=100, num_requests=2, pending_prefill_tokens=500), _make_inst(proxy, "http://b", ongoing_tokens=2000, num_requests=10, pending_prefill_tokens=8000), ] snap = proxy.snapshot_workers(insts, None, 1000) assert len(snap) == 2 required = { "idx", "url", "ongoing_tokens", "ongoing_decode_tokens", "pending_prefill_tokens", "num_requests", "active_p_offloads", "cached_blocks", "cache_hit", "new_prefill", "score_linear", "score_lmetric", } for entry in snap: assert required.issubset(entry.keys()), f"missing fields in {entry}" assert snap[0]["url"] == "http://a" assert snap[1]["url"] == "http://b" def test_snapshot_workers_lmetric_score_reflects_p_tokens_times_bs(proxy): insts = [ _make_inst(proxy, "http://a", pending_prefill_tokens=0, num_requests=0), _make_inst(proxy, "http://b", pending_prefill_tokens=4000, num_requests=5), ] snap = proxy.snapshot_workers(insts, None, 1000) assert snap[0]["score_lmetric"] == 0 assert snap[1]["score_lmetric"] == (4000 + 1000) * 5 def test_snapshot_workers_cache_hit_propagates(proxy): """When token_ids carry a cached prefix, snapshot must record the hit.""" insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] block_size = proxy.BLOCK_SIZE prefix = [42] * block_size * 2 insts[1].record_prefix(prefix) snap = proxy.snapshot_workers(insts, prefix, len(prefix)) assert snap[0]["cache_hit"] == 0 assert snap[1]["cache_hit"] == block_size * 2 assert snap[0]["new_prefill"] == len(prefix) assert snap[1]["new_prefill"] == 0 def test_worker_state_log_is_initially_empty_and_appendable(proxy): """The proxy module exposes a global _worker_state_log list.""" assert hasattr(proxy, "_worker_state_log") assert isinstance(proxy._worker_state_log, list) snapshot_count_before = len(proxy._worker_state_log) proxy._worker_state_log.append({"sentinel": True}) try: assert len(proxy._worker_state_log) == snapshot_count_before + 1 finally: proxy._worker_state_log.pop()