A2: proxy worker-state snapshot and request-id passthrough

Honor incoming X-Request-Id so replayer metrics and proxy breakdown
share a join key. Each route decision now captures session_id, the
full per-worker candidate-score snapshot (ongoing/pending/num_requests
/cached_blocks plus both linear and lmetric scores), the chosen score,
and unix timestamps for first-token and done events. A separate
_worker_state_log records one row per decision and is exposed via
GET /worker_state; GET /worker_state/latest returns a live snapshot
without recording it.

Required by Batch 3 (session hot-spot proof) and Batch 5 (failure
attribution); existing breakdown.json had no per-worker state at
decision time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 16:19:01 +08:00
parent d57e338366
commit fe556b5d98
2 changed files with 267 additions and 7 deletions

View File

@@ -107,6 +107,42 @@ def _p_offload_penalty(inst: InstanceState) -> int:
return inst.active_p_offloads * SETTINGS.heavy_threshold
def snapshot_workers(
instances: list[InstanceState],
token_ids: list[int] | None = None,
input_length: int = 0,
) -> list[dict]:
"""Per-worker state at route-decision time.
All routing-relevant counters plus the score each policy would
have produced for `input_length` if it were dispatched now. Cheap
enough to call on every request; B3 hot-spot analysis depends on
this being captured per decision.
"""
snap: list[dict] = []
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids) if token_ids else 0
new_prefill = max(0, input_length - cache_hit)
snap.append({
"idx": i,
"url": inst.url,
"ongoing_tokens": inst.ongoing_tokens,
"ongoing_decode_tokens": inst.ongoing_decode_tokens,
"pending_prefill_tokens": inst.pending_prefill_tokens,
"num_requests": inst.num_requests,
"active_p_offloads": inst.active_p_offloads,
"cached_blocks": len(inst.cached_blocks),
"cache_hit": cache_hit,
"new_prefill": new_prefill,
"score_linear": (inst.ongoing_tokens
+ _p_offload_penalty(inst)
- CACHE_HIT_ALPHA * cache_hit),
"score_lmetric": (inst.pending_prefill_tokens + new_prefill)
* inst.num_requests,
})
return snap
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]:
@@ -308,6 +344,7 @@ session_affinity_prefill: dict[str, int] = {}
session_affinity = session_affinity_combined
is_pd_sep = False
_breakdown_log: list[dict] = []
_worker_state_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
@@ -445,7 +482,8 @@ async def _handle(request: Request, api: str):
raise HTTPException(status_code=503, detail="Service Unavailable")
req_data = await request.json()
request_id = str(uuid.uuid4())
incoming_rid = request.headers.get("X-Request-Id")
request_id = incoming_rid or str(uuid.uuid4())
prompt = req_data.get("prompt")
token_ids = prompt if isinstance(prompt, list) else None
input_length = len(token_ids) if token_ids else 0
@@ -490,6 +528,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
breakdown["t_first_token_unix"] = _time.time()
prefill_done = True
yield chunk
chosen.record_prefix(
@@ -507,6 +546,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
chosen.ongoing_tokens -= input_length
chosen.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
breakdown["t_done_unix"] = _time.time()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
@@ -527,13 +567,20 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
both regressed E2E tail). Re-enabling requires a new transfer mechanism.
"""
policy = getattr(global_args, 'policy', 'linear')
t_decision_unix = _time.time()
request_id = headers.get("X-Request-Id", "")
breakdown: dict = {
"request_id": headers.get("X-Request-Id", ""),
"request_id": request_id,
"session_id": session_id,
"input_length": input_length,
"t_proxy_recv": _time.monotonic(),
"t_decision_unix": t_decision_unix,
"policy": policy,
}
pre_decision_workers = snapshot_workers(
combined_instances, token_ids, input_length)
if policy == "lmetric":
chosen, best_idx = pick_instance_lmetric(
combined_instances, token_ids, session_id, input_length,
@@ -550,14 +597,29 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
cache_hit = chosen.estimate_cache_hit(token_ids)
estimated_new = max(0, input_length - cache_hit)
chosen_snap = pre_decision_workers[best_idx]
cache_hit = chosen_snap["cache_hit"]
estimated_new = chosen_snap["new_prefill"]
breakdown.update({
"cache_hit": cache_hit,
"estimated_new_tokens": estimated_new,
"route_class": "LOCAL",
"routed_to": chosen.url,
"chosen_idx": best_idx,
"candidate_scores": pre_decision_workers,
"chosen_score_linear": chosen_snap["score_linear"],
"chosen_score_lmetric": chosen_snap["score_lmetric"],
})
_worker_state_log.append({
"t_decision_unix": t_decision_unix,
"request_id": request_id,
"session_id": session_id,
"policy": policy,
"chosen_idx": best_idx,
"workers": pre_decision_workers,
})
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
@@ -566,17 +628,41 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
session_id, headers):
"""PD-Sep mode with per-stage breakdown profiling."""
t_decision_unix = _time.time()
breakdown = {
"request_id": request_id,
"session_id": session_id,
"input_length": input_length,
"t_proxy_recv": _time.monotonic(),
"t_decision_unix": t_decision_unix,
"policy": "pd_sep",
}
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
input_length, session_affinity_prefill)
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)
p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
input_length, session_affinity_prefill)
d_idx = min(range(len(decode_instances)),
key=lambda i: decode_instances[i].ongoing_tokens)
d_inst = decode_instances[d_idx]
breakdown["p_inst"] = p_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["candidate_scores_prefill"] = pre_decision_p
breakdown["candidate_scores_decode"] = pre_decision_d
breakdown["chosen_p_idx"] = p_idx
breakdown["chosen_d_idx"] = d_idx
_worker_state_log.append({
"t_decision_unix": t_decision_unix,
"request_id": request_id,
"session_id": session_id,
"policy": "pd_sep",
"chosen_p_idx": p_idx,
"chosen_d_idx": d_idx,
"workers_prefill": pre_decision_p,
"workers_decode": pre_decision_d,
})
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
@@ -592,15 +678,18 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
p_inst.ongoing_tokens += input_length
breakdown["t_prefill_sent"] = _time.monotonic()
breakdown["t_prefill_sent_unix"] = _time.time()
try:
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["t_prefill_done_unix"] = _time.time()
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["t_prefill_done_unix"] = _time.time()
breakdown["prefill_error"] = True
_breakdown_log.append(breakdown)
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
@@ -621,6 +710,7 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
}
breakdown["t_decode_sent"] = _time.monotonic()
breakdown["t_decode_sent_unix"] = _time.time()
async def generate():
first_token = True
@@ -635,11 +725,13 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
output_token_ids.extend(new_output_ids)
if first_token:
breakdown["t_first_token"] = _time.monotonic()
breakdown["t_first_token_unix"] = _time.time()
first_token = False
yield chunk
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
finally:
breakdown["t_done"] = _time.monotonic()
breakdown["t_done_unix"] = _time.time()
d_inst.ongoing_tokens -= input_length
_breakdown_log.append(breakdown)
@@ -652,6 +744,29 @@ async def get_breakdown():
return _breakdown_log
@app.get("/worker_state")
async def get_worker_state():
"""Return per-decision worker-state snapshot log (one entry per route decision)."""
return _worker_state_log
@app.get("/worker_state/latest")
async def get_worker_state_latest():
"""Return current per-worker state snapshot without recording it."""
if combined_instances:
return {
"t_unix": _time.time(),
"mode": "combined",
"workers": snapshot_workers(combined_instances),
}
return {
"t_unix": _time.time(),
"mode": "pd_sep",
"workers_prefill": snapshot_workers(prefill_instances),
"workers_decode": snapshot_workers(decode_instances),
}
@app.get("/stats")
async def get_stats():
"""Return per-instance live state for debugging."""