Honor the incoming X-Request-Id header so replayer metrics and the proxy breakdown share a join key. Each route decision now captures session_id, the full per-worker candidate-score snapshot (ongoing/pending/num_requests/cached_blocks plus both linear and lmetric scores), the chosen score, and Unix timestamps for the first-token and done events. A separate _worker_state_log records one row per decision and is exposed via GET /worker_state; GET /worker_state/latest returns a live snapshot without recording it. Required by Batch 3 (session hot-spot proof) and Batch 5 (failure attribution); the existing breakdown.json had no per-worker state at decision time. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
"""Tests for A2 proxy instrumentation: worker snapshot + request_id passthrough."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import sys
|
|
import types
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
# Absolute path of the proxy script under test: <dir>/scripts/cache_aware_proxy.py,
# where <dir> is the parent of this test file's directory (presumably the repo
# root -- TODO confirm against the repository layout).
PROXY_PATH = Path(__file__).resolve().parent.parent.joinpath(
    "scripts", "cache_aware_proxy.py"
)
|
|
|
|
|
|
def _needs_stub(name: str) -> bool:
    """Return True when module *name* is neither imported nor importable.

    ``sys.modules`` is checked first: ``importlib.util.find_spec`` raises
    ``ValueError`` for an already-imported module whose ``__spec__`` is None,
    which is the case for the bare ``types.ModuleType`` stubs created below.
    """
    if name in sys.modules:
        return False
    return importlib.util.find_spec(name) is None


def _install_stub_modules() -> None:
    """Register minimal stand-ins for the proxy's web-stack imports.

    This lets ``scripts/cache_aware_proxy.py`` be imported in an environment
    without ``uvicorn``, ``fastapi`` or ``httpx``. A stub is installed only
    when the real package is neither already imported nor importable. The
    stubs live in ``sys.modules`` for the rest of the interpreter session, so
    installing one over an available package would shadow it for every later
    importer.

    The stubs provide only what is defined here:

    * ``uvicorn``: an empty module.
    * ``fastapi.FastAPI``: has a ``state`` namespace and pass-through
      ``post``/``get`` decorators.
    * ``fastapi.HTTPException`` (``status_code``, ``detail``) and
      ``fastapi.Request``.
    * ``fastapi.responses.StreamingResponse``: accepts any arguments.
    * ``httpx.AsyncClient``: has an async no-op ``aclose``.
    * ``httpx.Limits``.

    Each stub is guarded, so calling this repeatedly is harmless.
    """
    if _needs_stub("uvicorn"):
        sys.modules["uvicorn"] = types.ModuleType("uvicorn")

    if _needs_stub("fastapi"):
        fastapi_mod = types.ModuleType("fastapi")

        class _FastAPI:
            def __init__(self, *a, **kw):
                self.state = types.SimpleNamespace()

            def post(self, *a, **kw):
                # Route decorators return the handler unchanged.
                def deco(fn):
                    return fn

                return deco

            def get(self, *a, **kw):
                def deco(fn):
                    return fn

                return deco

        class _HTTPException(Exception):
            def __init__(self, status_code=500, detail=""):
                self.status_code = status_code
                self.detail = detail

        class _Request:
            pass

        fastapi_mod.FastAPI = _FastAPI
        fastapi_mod.HTTPException = _HTTPException
        fastapi_mod.Request = _Request
        sys.modules["fastapi"] = fastapi_mod

        responses_mod = types.ModuleType("fastapi.responses")

        class _StreamingResponse:
            def __init__(self, *a, **kw):
                pass

        responses_mod.StreamingResponse = _StreamingResponse
        sys.modules["fastapi.responses"] = responses_mod

    if _needs_stub("httpx"):
        httpx_mod = types.ModuleType("httpx")

        class _AsyncClient:
            def __init__(self, *a, **kw):
                pass

            async def aclose(self):
                pass

        class _Limits:
            def __init__(self, *a, **kw):
                pass

        httpx_mod.AsyncClient = _AsyncClient
        httpx_mod.Limits = _Limits
        sys.modules["httpx"] = httpx_mod
|
|
|
|
|
|
@pytest.fixture(scope="module")
def proxy():
    """Load ``scripts/cache_aware_proxy.py`` as module ``cache_aware_proxy``.

    Web-stack dependencies are stubbed first via ``_install_stub_modules``.
    The module is registered in ``sys.modules`` before it executes, as in the
    importlib "import a source file directly" recipe. Whatever entry that
    name had beforehand is put back both when loading fails and at module
    teardown, so a half-executed module does not stay importable. The same
    applies to this test-only copy.

    Dependent tests are skipped when the proxy file is absent, no loader can
    be built for it, or executing it hits a ``ModuleNotFoundError``. Any
    other import-time exception propagates.
    """
    module_name = "cache_aware_proxy"
    _install_stub_modules()

    # spec_from_file_location returns a spec for any *.py path, even a missing
    # one; without this check a missing file would surface as an error from
    # exec_module instead of the intended skip.
    if not PROXY_PATH.is_file():
        pytest.skip(f"cannot load proxy module at {PROXY_PATH}")
    spec = importlib.util.spec_from_file_location(module_name, PROXY_PATH)
    if spec is None or spec.loader is None:
        pytest.skip(f"cannot load proxy module at {PROXY_PATH}")

    mod = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)

    def _restore_sys_modules() -> None:
        # Undo our registration, reinstating any module that held the name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous

    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except ModuleNotFoundError as exc:
        _restore_sys_modules()
        pytest.skip(f"proxy dependency missing: {exc}")
    except BaseException:
        _restore_sys_modules()
        raise

    try:
        yield mod
    finally:
        _restore_sys_modules()
|
|
|
|
|
|
def _make_inst(proxy, url, **kw):
    """Build ``proxy.InstanceState(url)`` and override attributes from *kw*.

    Every keyword is applied verbatim with ``setattr``. A misspelled name
    therefore creates a new attribute rather than raising.
    """
    instance = proxy.InstanceState(url)
    for attr_name in kw:
        setattr(instance, attr_name, kw[attr_name])
    return instance
|
|
|
|
|
|
def test_snapshot_workers_includes_all_required_fields(proxy):
    """snapshot_workers returns one entry per worker, in input order, each
    carrying every required per-worker field."""
    light = _make_inst(proxy, "http://a", ongoing_tokens=100, num_requests=2,
                       pending_prefill_tokens=500)
    heavy = _make_inst(proxy, "http://b", ongoing_tokens=2000, num_requests=10,
                       pending_prefill_tokens=8000)

    snap = proxy.snapshot_workers([light, heavy], None, 1000)
    assert len(snap) == 2

    required = frozenset((
        "idx", "url", "ongoing_tokens", "ongoing_decode_tokens",
        "pending_prefill_tokens", "num_requests", "active_p_offloads",
        "cached_blocks", "cache_hit", "new_prefill",
        "score_linear", "score_lmetric",
    ))
    for entry in snap:
        assert required.issubset(entry.keys()), f"missing fields in {entry}"

    # Entries must follow the order of the instances passed in.
    assert [entry["url"] for entry in snap] == ["http://a", "http://b"]
|
|
|
|
|
|
def test_snapshot_workers_lmetric_score_reflects_p_tokens_times_bs(proxy):
    """score_lmetric equals (pending_prefill_tokens + request tokens) * num_requests.

    A worker with num_requests=0 scores 0. A worker with 4000 pending prefill
    tokens and num_requests=5 scores (4000 + 1000) * 5 for a 1000-token request.
    """
    request_tokens = 1000
    idle = _make_inst(proxy, "http://a", pending_prefill_tokens=0, num_requests=0)
    busy = _make_inst(proxy, "http://b", pending_prefill_tokens=4000, num_requests=5)

    snap = proxy.snapshot_workers([idle, busy], None, request_tokens)
    idle_entry, busy_entry = snap[0], snap[1]

    assert idle_entry["score_lmetric"] == 0
    assert busy_entry["score_lmetric"] == (4000 + request_tokens) * 5
|
|
|
|
|
|
def test_snapshot_workers_cache_hit_propagates(proxy):
    """When token_ids carry a cached prefix, snapshot must record the hit.

    Worker b has recorded a two-block prefix and worker a has not. The
    snapshot must credit b with the whole prefix as a hit and nothing left to
    prefill. It must charge a for prefilling every token.
    """
    cold, warm = _make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")
    n_tokens = proxy.BLOCK_SIZE * 2
    prefix = [42] * n_tokens
    warm.record_prefix(prefix)

    snap = proxy.snapshot_workers([cold, warm], prefix, len(prefix))

    assert snap[0]["cache_hit"] == 0
    assert snap[1]["cache_hit"] == n_tokens
    assert snap[0]["new_prefill"] == n_tokens
    assert snap[1]["new_prefill"] == 0
|
|
|
|
|
|
def test_worker_state_log_is_initially_empty_and_appendable(proxy):
    """The proxy module exposes a global _worker_state_log list.

    Despite its name, this test does not assert the log is empty. The loaded
    module is shared by every test here (module-scoped fixture) and may
    already hold entries. The test checks that the attribute exists, is a
    list, and grows by exactly one on append. It then removes its sentinel so
    the shared state is left as found.
    """
    assert hasattr(proxy, "_worker_state_log")
    log = proxy._worker_state_log
    assert isinstance(log, list)

    before = len(log)
    log.append({"sentinel": True})
    try:
        assert len(log) == before + 1
    finally:
        log.pop()
|