A2: proxy worker-state snapshot and request-id passthrough

Honor incoming X-Request-Id so replayer metrics and proxy breakdown
share a join key. Each route decision now captures session_id, the
full per-worker candidate-score snapshot (ongoing/pending/num_requests
/cached_blocks plus both linear and lmetric scores), the chosen score,
and unix timestamps for first-token and done events. A separate
_worker_state_log records one row per decision and is exposed via
GET /worker_state; GET /worker_state/latest returns a live snapshot
without recording it.

Required by Batch 3 (session hot-spot proof) and Batch 5 (failure
attribution); existing breakdown.json had no per-worker state at
decision time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 16:19:01 +08:00
parent d57e338366
commit fe556b5d98
2 changed files with 267 additions and 7 deletions

View File

@@ -107,6 +107,42 @@ def _p_offload_penalty(inst: InstanceState) -> int:
return inst.active_p_offloads * SETTINGS.heavy_threshold return inst.active_p_offloads * SETTINGS.heavy_threshold
def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Describe every worker as it looks at route-decision time.

    One dict is produced per entry of ``instances``, in the same order,
    holding the worker's routing counters together with the linear and
    lmetric scores computed for a request of ``input_length`` tokens.

    Args:
        instances: Workers to describe; ``idx`` in each row is the
            position of the worker in this list.
        token_ids: Prompt tokens used to estimate the per-worker cache
            hit. When empty or ``None`` the cache hit is reported as 0.
        input_length: Prompt length in tokens; ``new_prefill`` is this
            value minus the cache hit, floored at 0.

    Returns:
        A list of plain dicts, one per worker.
    """

    def _describe(position, worker) -> dict:
        # Only consult the worker's prefix cache when there are tokens
        # to match against.
        hit = worker.estimate_cache_hit(token_ids) if token_ids else 0
        fresh_tokens = max(0, input_length - hit)
        return {
            "idx": position,
            "url": worker.url,
            "ongoing_tokens": worker.ongoing_tokens,
            "ongoing_decode_tokens": worker.ongoing_decode_tokens,
            "pending_prefill_tokens": worker.pending_prefill_tokens,
            "num_requests": worker.num_requests,
            "active_p_offloads": worker.active_p_offloads,
            "cached_blocks": len(worker.cached_blocks),
            "cache_hit": hit,
            "new_prefill": fresh_tokens,
            # Linear score: load plus offload penalty, discounted by the
            # weighted cache hit.
            "score_linear": (worker.ongoing_tokens
                             + _p_offload_penalty(worker)
                             - CACHE_HIT_ALPHA * hit),
            # lmetric score: prefill backlog (including this request)
            # multiplied by the current request count.
            "score_lmetric": (worker.pending_prefill_tokens + fresh_tokens)
                             * worker.num_requests,
        }

    return [_describe(position, worker)
            for position, worker in enumerate(instances)]
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None, def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int, session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]: affinity: dict[str, int]) -> tuple[InstanceState, int]:
@@ -308,6 +344,7 @@ session_affinity_prefill: dict[str, int] = {}
session_affinity = session_affinity_combined session_affinity = session_affinity_combined
is_pd_sep = False is_pd_sep = False
_breakdown_log: list[dict] = [] _breakdown_log: list[dict] = []
_worker_state_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event): async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
@@ -445,7 +482,8 @@ async def _handle(request: Request, api: str):
raise HTTPException(status_code=503, detail="Service Unavailable") raise HTTPException(status_code=503, detail="Service Unavailable")
req_data = await request.json() req_data = await request.json()
request_id = str(uuid.uuid4()) incoming_rid = request.headers.get("X-Request-Id")
request_id = incoming_rid or str(uuid.uuid4())
prompt = req_data.get("prompt") prompt = req_data.get("prompt")
token_ids = prompt if isinstance(prompt, list) else None token_ids = prompt if isinstance(prompt, list) else None
input_length = len(token_ids) if token_ids else 0 input_length = len(token_ids) if token_ids else 0
@@ -490,6 +528,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
chosen.pending_prefill_tokens -= estimated_new chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic() breakdown["t_first_token"] = _time.monotonic()
breakdown["t_first_token_unix"] = _time.time()
prefill_done = True prefill_done = True
yield chunk yield chunk
chosen.record_prefix( chosen.record_prefix(
@@ -507,6 +546,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
chosen.ongoing_tokens -= input_length chosen.ongoing_tokens -= input_length
chosen.num_requests -= 1 chosen.num_requests -= 1
breakdown["t_done"] = _time.monotonic() breakdown["t_done"] = _time.monotonic()
breakdown["t_done_unix"] = _time.time()
_breakdown_log.append(breakdown) _breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream") return StreamingResponse(generate(), media_type="text/event-stream")
@@ -527,13 +567,20 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
both regressed E2E tail). Re-enabling requires a new transfer mechanism. both regressed E2E tail). Re-enabling requires a new transfer mechanism.
""" """
policy = getattr(global_args, 'policy', 'linear') policy = getattr(global_args, 'policy', 'linear')
t_decision_unix = _time.time()
request_id = headers.get("X-Request-Id", "")
breakdown: dict = { breakdown: dict = {
"request_id": headers.get("X-Request-Id", ""), "request_id": request_id,
"session_id": session_id,
"input_length": input_length, "input_length": input_length,
"t_proxy_recv": _time.monotonic(), "t_proxy_recv": _time.monotonic(),
"t_decision_unix": t_decision_unix,
"policy": policy, "policy": policy,
} }
pre_decision_workers = snapshot_workers(
combined_instances, token_ids, input_length)
if policy == "lmetric": if policy == "lmetric":
chosen, best_idx = pick_instance_lmetric( chosen, best_idx = pick_instance_lmetric(
combined_instances, token_ids, session_id, input_length, combined_instances, token_ids, session_id, input_length,
@@ -550,14 +597,29 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
combined_instances, token_ids, session_id, input_length, combined_instances, token_ids, session_id, input_length,
session_affinity_combined) session_affinity_combined)
cache_hit = chosen.estimate_cache_hit(token_ids) chosen_snap = pre_decision_workers[best_idx]
estimated_new = max(0, input_length - cache_hit) cache_hit = chosen_snap["cache_hit"]
estimated_new = chosen_snap["new_prefill"]
breakdown.update({ breakdown.update({
"cache_hit": cache_hit, "cache_hit": cache_hit,
"estimated_new_tokens": estimated_new, "estimated_new_tokens": estimated_new,
"route_class": "LOCAL", "route_class": "LOCAL",
"routed_to": chosen.url, "routed_to": chosen.url,
"chosen_idx": best_idx,
"candidate_scores": pre_decision_workers,
"chosen_score_linear": chosen_snap["score_linear"],
"chosen_score_lmetric": chosen_snap["score_lmetric"],
}) })
_worker_state_log.append({
"t_decision_unix": t_decision_unix,
"request_id": request_id,
"session_id": session_id,
"policy": policy,
"chosen_idx": best_idx,
"workers": pre_decision_workers,
})
return await _handle_local_request( return await _handle_local_request(
api, req_data, headers, token_ids, input_length, api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown) chosen, estimated_new, breakdown)
@@ -566,17 +628,41 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length, async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
session_id, headers): session_id, headers):
"""PD-Sep mode with per-stage breakdown profiling.""" """PD-Sep mode with per-stage breakdown profiling."""
t_decision_unix = _time.time()
breakdown = { breakdown = {
"request_id": request_id, "request_id": request_id,
"session_id": session_id,
"input_length": input_length, "input_length": input_length,
"t_proxy_recv": _time.monotonic(), "t_proxy_recv": _time.monotonic(),
"t_decision_unix": t_decision_unix,
"policy": "pd_sep",
} }
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id, pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
input_length, session_affinity_prefill) pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
input_length, session_affinity_prefill)
d_idx = min(range(len(decode_instances)),
key=lambda i: decode_instances[i].ongoing_tokens)
d_inst = decode_instances[d_idx]
breakdown["p_inst"] = p_inst.url breakdown["p_inst"] = p_inst.url
breakdown["d_inst"] = d_inst.url breakdown["d_inst"] = d_inst.url
breakdown["candidate_scores_prefill"] = pre_decision_p
breakdown["candidate_scores_decode"] = pre_decision_d
breakdown["chosen_p_idx"] = p_idx
breakdown["chosen_d_idx"] = d_idx
_worker_state_log.append({
"t_decision_unix": t_decision_unix,
"request_id": request_id,
"session_id": session_id,
"policy": "pd_sep",
"chosen_p_idx": p_idx,
"chosen_d_idx": d_idx,
"workers_prefill": pre_decision_p,
"workers_decode": pre_decision_d,
})
prefill_data = req_data.copy() prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = { prefill_data["kv_transfer_params"] = {
@@ -592,15 +678,18 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
p_inst.ongoing_tokens += input_length p_inst.ongoing_tokens += input_length
breakdown["t_prefill_sent"] = _time.monotonic() breakdown["t_prefill_sent"] = _time.monotonic()
breakdown["t_prefill_sent_unix"] = _time.time()
try: try:
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic() breakdown["t_prefill_done"] = _time.monotonic()
breakdown["t_prefill_done_unix"] = _time.time()
resp.raise_for_status() resp.raise_for_status()
await resp.aclose() await resp.aclose()
p_inst.record_prefix(token_ids) p_inst.record_prefix(token_ids)
except Exception as e: except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic() breakdown["t_prefill_done"] = _time.monotonic()
breakdown["t_prefill_done_unix"] = _time.time()
breakdown["prefill_error"] = True breakdown["prefill_error"] = True
_breakdown_log.append(breakdown) _breakdown_log.append(breakdown)
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
@@ -621,6 +710,7 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
} }
breakdown["t_decode_sent"] = _time.monotonic() breakdown["t_decode_sent"] = _time.monotonic()
breakdown["t_decode_sent_unix"] = _time.time()
async def generate(): async def generate():
first_token = True first_token = True
@@ -635,11 +725,13 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
output_token_ids.extend(new_output_ids) output_token_ids.extend(new_output_ids)
if first_token: if first_token:
breakdown["t_first_token"] = _time.monotonic() breakdown["t_first_token"] = _time.monotonic()
breakdown["t_first_token_unix"] = _time.time()
first_token = False first_token = False
yield chunk yield chunk
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids)) d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
finally: finally:
breakdown["t_done"] = _time.monotonic() breakdown["t_done"] = _time.monotonic()
breakdown["t_done_unix"] = _time.time()
d_inst.ongoing_tokens -= input_length d_inst.ongoing_tokens -= input_length
_breakdown_log.append(breakdown) _breakdown_log.append(breakdown)
@@ -652,6 +744,29 @@ async def get_breakdown():
return _breakdown_log return _breakdown_log
@app.get("/worker_state")
async def get_worker_state():
    """Serve the accumulated worker-state log.

    The response is the module-level ``_worker_state_log`` list itself:
    one row for every route decision that appended to it.
    """
    return _worker_state_log
@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Serve a live snapshot of worker state; nothing is appended to the log.

    When ``combined_instances`` is non-empty the payload reports mode
    ``"combined"`` with a single ``workers`` list. Otherwise it reports
    mode ``"pd_sep"`` with separate prefill and decode lists. Snapshots
    are taken without token ids, so cache-hit fields are 0.
    """
    taken_at = _time.time()

    if not combined_instances:
        return {
            "t_unix": taken_at,
            "mode": "pd_sep",
            "workers_prefill": snapshot_workers(prefill_instances),
            "workers_decode": snapshot_workers(decode_instances),
        }

    return {
        "t_unix": taken_at,
        "mode": "combined",
        "workers": snapshot_workers(combined_instances),
    }
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats():
"""Return per-instance live state for debugging.""" """Return per-instance live state for debugging."""

View File

@@ -0,0 +1,145 @@
"""Tests for A2 proxy instrumentation: worker snapshot + request_id passthrough."""
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
import pytest
# Absolute path of the proxy script under test: scripts/cache_aware_proxy.py,
# resolved from the directory one level above the one containing this file.
PROXY_PATH = Path(__file__).resolve().parent.parent / "scripts" / "cache_aware_proxy.py"
def _install_stub_modules() -> None:
    """Register lightweight stand-ins for uvicorn, fastapi and httpx.

    A package is stubbed only when its name is absent from
    ``sys.modules``; an entry that is already present is left untouched.
    The stubs expose just the names assigned below (``FastAPI``,
    ``HTTPException``, ``Request``, ``StreamingResponse``,
    ``AsyncClient``, ``Limits``). The ``proxy`` fixture calls this before
    loading the proxy script, presumably so the script imports cleanly
    when the real packages are not installed.
    """
    loaded = sys.modules

    loaded.setdefault("uvicorn", types.ModuleType("uvicorn"))

    if "fastapi" not in loaded:

        def _passthrough(fn):
            # Route decorators on the stub app hand the handler back as-is.
            return fn

        class _FastAPI:
            def __init__(self, *a, **kw):
                self.state = types.SimpleNamespace()

            def post(self, *a, **kw):
                return _passthrough

            def get(self, *a, **kw):
                return _passthrough

        class _HTTPException(Exception):
            def __init__(self, status_code=500, detail=""):
                self.status_code = status_code
                self.detail = detail

        class _Request:
            pass

        class _StreamingResponse:
            def __init__(self, *a, **kw):
                pass

        fastapi_stub = types.ModuleType("fastapi")
        fastapi_stub.FastAPI = _FastAPI
        fastapi_stub.HTTPException = _HTTPException
        fastapi_stub.Request = _Request

        responses_stub = types.ModuleType("fastapi.responses")
        responses_stub.StreamingResponse = _StreamingResponse

        # The submodule is registered under its dotted name so that
        # ``from fastapi.responses import ...`` resolves from sys.modules.
        loaded["fastapi"] = fastapi_stub
        loaded["fastapi.responses"] = responses_stub

    if "httpx" not in loaded:

        class _AsyncClient:
            def __init__(self, *a, **kw):
                pass

            async def aclose(self):
                pass

        class _Limits:
            def __init__(self, *a, **kw):
                pass

        httpx_stub = types.ModuleType("httpx")
        httpx_stub.AsyncClient = _AsyncClient
        httpx_stub.Limits = _Limits
        loaded["httpx"] = httpx_stub
@pytest.fixture(scope="module")
def proxy():
    """Load the proxy script from ``PROXY_PATH`` as module ``cache_aware_proxy``.

    Stub modules are installed first, then the script is executed from
    its file path. The test is skipped when no import spec/loader can be
    built or when executing the script raises ``ModuleNotFoundError``.

    Returns:
        The executed proxy module.
    """
    _install_stub_modules()
    module_name = "cache_aware_proxy"
    spec = importlib.util.spec_from_file_location(module_name, PROXY_PATH)
    if spec is None or spec.loader is None:
        pytest.skip(f"cannot load proxy module at {PROXY_PATH}")
    mod = importlib.util.module_from_spec(spec)
    # Register before executing so code that looks the module up by name
    # during import can find it.
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except ModuleNotFoundError as exc:
        # Do not leave a half-initialised module registered for later
        # importers when the load is abandoned.
        sys.modules.pop(module_name, None)
        pytest.skip(f"proxy dependency missing: {exc}")
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return mod
def _make_inst(proxy, url, **kw):
    """Build ``proxy.InstanceState(url)`` and overwrite attributes from ``kw``.

    Each keyword argument is assigned onto the new instance with
    ``setattr``, in the order given; the instance is then returned.
    """
    worker = proxy.InstanceState(url)
    for attr_name in kw:
        setattr(worker, attr_name, kw[attr_name])
    return worker
def test_snapshot_workers_includes_all_required_fields(proxy):
    """Every snapshot row carries the full field set, in input order."""
    lightly_loaded = _make_inst(
        proxy, "http://a",
        ongoing_tokens=100, num_requests=2, pending_prefill_tokens=500,
    )
    heavily_loaded = _make_inst(
        proxy, "http://b",
        ongoing_tokens=2000, num_requests=10, pending_prefill_tokens=8000,
    )

    snap = proxy.snapshot_workers([lightly_loaded, heavily_loaded], None, 1000)

    required = {
        "idx",
        "url",
        "ongoing_tokens",
        "ongoing_decode_tokens",
        "pending_prefill_tokens",
        "num_requests",
        "active_p_offloads",
        "cached_blocks",
        "cache_hit",
        "new_prefill",
        "score_linear",
        "score_lmetric",
    }

    assert len(snap) == 2
    for entry in snap:
        assert required.issubset(entry.keys()), f"missing fields in {entry}"

    # Rows come back in the same order the instances were passed in.
    first, second = snap[0], snap[1]
    assert first["url"] == "http://a"
    assert second["url"] == "http://b"
def test_snapshot_workers_lmetric_score_reflects_p_tokens_times_bs(proxy):
    """score_lmetric equals (pending prefill + new prefill) * num_requests."""
    idle = _make_inst(proxy, "http://a",
                      pending_prefill_tokens=0, num_requests=0)
    busy = _make_inst(proxy, "http://b",
                      pending_prefill_tokens=4000, num_requests=5)

    snap = proxy.snapshot_workers([idle, busy], None, 1000)
    idle_score = snap[0]["score_lmetric"]
    busy_score = snap[1]["score_lmetric"]

    # A worker with zero requests scores 0 regardless of the prompt size.
    assert idle_score == 0
    assert busy_score == (4000 + 1000) * 5
def test_snapshot_workers_cache_hit_propagates(proxy):
    """A prefix recorded on one worker shows up as that worker's cache hit."""
    cold = _make_inst(proxy, "http://a")
    warm = _make_inst(proxy, "http://b")

    # Two full blocks of identical tokens, recorded on the second worker only.
    two_blocks = proxy.BLOCK_SIZE * 2
    prefix = [42] * two_blocks
    warm.record_prefix(prefix)

    snap = proxy.snapshot_workers([cold, warm], prefix, len(prefix))
    cold_row, warm_row = snap[0], snap[1]

    assert cold_row["cache_hit"] == 0
    assert warm_row["cache_hit"] == two_blocks
    assert cold_row["new_prefill"] == len(prefix)
    assert warm_row["new_prefill"] == 0
def test_worker_state_log_is_initially_empty_and_appendable(proxy):
    """The proxy module exposes a list named ``_worker_state_log``.

    Checks that the attribute exists, is a list, and grows by one on
    append. The sentinel row is always removed afterwards. Emptiness
    itself is not asserted.
    """
    assert hasattr(proxy, "_worker_state_log")
    log = proxy._worker_state_log
    assert isinstance(log, list)

    size_before = len(log)
    log.append({"sentinel": True})
    try:
        assert len(log) == size_before + 1
    finally:
        # Drop the sentinel so other tests see the log unchanged.
        log.pop()