A2: proxy worker-state snapshot and request-id passthrough
Honor incoming X-Request-Id so replayer metrics and proxy breakdown share a join key. Each route decision now captures session_id, the full per-worker candidate-score snapshot (ongoing/pending/num_requests/cached_blocks plus both linear and lmetric scores), the chosen worker's scores, and unix timestamps for first-token and done events. A separate _worker_state_log records one row per decision and is exposed via GET /worker_state; GET /worker_state/latest returns a live snapshot without recording it. Required by Batch 3 (session hot-spot proof) and Batch 5 (failure attribution); the existing breakdown.json had no per-worker state at decision time. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -107,6 +107,42 @@ def _p_offload_penalty(inst: InstanceState) -> int:
|
||||
return inst.active_p_offloads * SETTINGS.heavy_threshold
|
||||
|
||||
|
||||
def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Capture the routing-relevant state of every worker right now.

    Each returned dict holds one worker's live counters together with
    the score the linear and the lmetric policy would each assign to a
    request of `input_length` tokens if it were dispatched at this
    moment. Intended to be cheap enough to run on every request; the
    B3 hot-spot analysis relies on one snapshot per route decision.

    Args:
        instances: Workers to describe; list position becomes "idx".
        token_ids: Prompt tokens used to estimate the per-worker cache
            hit. When None or empty, the cache hit is reported as 0.
        input_length: Prompt length in tokens for the hypothetical
            request being scored.

    Returns:
        One dict per worker, in the same order as `instances`.
    """

    def describe(position: int, worker: InstanceState) -> dict:
        # Without prompt tokens there is nothing to match against the
        # worker's prefix cache, so no hit is assumed.
        if token_ids:
            hit = worker.estimate_cache_hit(token_ids)
        else:
            hit = 0
        # Tokens that would still need prefill; never negative.
        fresh = max(0, input_length - hit)
        return {
            "idx": position,
            "url": worker.url,
            "ongoing_tokens": worker.ongoing_tokens,
            "ongoing_decode_tokens": worker.ongoing_decode_tokens,
            "pending_prefill_tokens": worker.pending_prefill_tokens,
            "num_requests": worker.num_requests,
            "active_p_offloads": worker.active_p_offloads,
            "cached_blocks": len(worker.cached_blocks),
            "cache_hit": hit,
            "new_prefill": fresh,
            "score_linear": (
                worker.ongoing_tokens
                + _p_offload_penalty(worker)
                - CACHE_HIT_ALPHA * hit
            ),
            "score_lmetric": (
                (worker.pending_prefill_tokens + fresh) * worker.num_requests
            ),
        }

    return [describe(pos, worker) for pos, worker in enumerate(instances)]
|
||||
|
||||
|
||||
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
|
||||
session_id: str | None, input_length: int,
|
||||
affinity: dict[str, int]) -> tuple[InstanceState, int]:
|
||||
@@ -308,6 +344,7 @@ session_affinity_prefill: dict[str, int] = {}
|
||||
session_affinity = session_affinity_combined
|
||||
is_pd_sep = False
|
||||
_breakdown_log: list[dict] = []
|
||||
_worker_state_log: list[dict] = []
|
||||
|
||||
|
||||
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
|
||||
@@ -445,7 +482,8 @@ async def _handle(request: Request, api: str):
|
||||
raise HTTPException(status_code=503, detail="Service Unavailable")
|
||||
|
||||
req_data = await request.json()
|
||||
request_id = str(uuid.uuid4())
|
||||
incoming_rid = request.headers.get("X-Request-Id")
|
||||
request_id = incoming_rid or str(uuid.uuid4())
|
||||
prompt = req_data.get("prompt")
|
||||
token_ids = prompt if isinstance(prompt, list) else None
|
||||
input_length = len(token_ids) if token_ids else 0
|
||||
@@ -490,6 +528,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
|
||||
chosen.pending_prefill_tokens -= estimated_new
|
||||
chosen.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
breakdown["t_first_token_unix"] = _time.time()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
chosen.record_prefix(
|
||||
@@ -507,6 +546,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
|
||||
chosen.ongoing_tokens -= input_length
|
||||
chosen.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
breakdown["t_done_unix"] = _time.time()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
@@ -527,13 +567,20 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
both regressed E2E tail). Re-enabling requires a new transfer mechanism.
|
||||
"""
|
||||
policy = getattr(global_args, 'policy', 'linear')
|
||||
t_decision_unix = _time.time()
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
breakdown: dict = {
|
||||
"request_id": headers.get("X-Request-Id", ""),
|
||||
"request_id": request_id,
|
||||
"session_id": session_id,
|
||||
"input_length": input_length,
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
"t_decision_unix": t_decision_unix,
|
||||
"policy": policy,
|
||||
}
|
||||
|
||||
pre_decision_workers = snapshot_workers(
|
||||
combined_instances, token_ids, input_length)
|
||||
|
||||
if policy == "lmetric":
|
||||
chosen, best_idx = pick_instance_lmetric(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
@@ -550,14 +597,29 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
|
||||
cache_hit = chosen.estimate_cache_hit(token_ids)
|
||||
estimated_new = max(0, input_length - cache_hit)
|
||||
chosen_snap = pre_decision_workers[best_idx]
|
||||
cache_hit = chosen_snap["cache_hit"]
|
||||
estimated_new = chosen_snap["new_prefill"]
|
||||
breakdown.update({
|
||||
"cache_hit": cache_hit,
|
||||
"estimated_new_tokens": estimated_new,
|
||||
"route_class": "LOCAL",
|
||||
"routed_to": chosen.url,
|
||||
"chosen_idx": best_idx,
|
||||
"candidate_scores": pre_decision_workers,
|
||||
"chosen_score_linear": chosen_snap["score_linear"],
|
||||
"chosen_score_lmetric": chosen_snap["score_lmetric"],
|
||||
})
|
||||
|
||||
_worker_state_log.append({
|
||||
"t_decision_unix": t_decision_unix,
|
||||
"request_id": request_id,
|
||||
"session_id": session_id,
|
||||
"policy": policy,
|
||||
"chosen_idx": best_idx,
|
||||
"workers": pre_decision_workers,
|
||||
})
|
||||
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
@@ -566,17 +628,41 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
session_id, headers):
|
||||
"""PD-Sep mode with per-stage breakdown profiling."""
|
||||
t_decision_unix = _time.time()
|
||||
breakdown = {
|
||||
"request_id": request_id,
|
||||
"session_id": session_id,
|
||||
"input_length": input_length,
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
"t_decision_unix": t_decision_unix,
|
||||
"policy": "pd_sep",
|
||||
}
|
||||
|
||||
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
|
||||
input_length, session_affinity_prefill)
|
||||
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
|
||||
pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
|
||||
pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)
|
||||
|
||||
p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
|
||||
input_length, session_affinity_prefill)
|
||||
d_idx = min(range(len(decode_instances)),
|
||||
key=lambda i: decode_instances[i].ongoing_tokens)
|
||||
d_inst = decode_instances[d_idx]
|
||||
breakdown["p_inst"] = p_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["candidate_scores_prefill"] = pre_decision_p
|
||||
breakdown["candidate_scores_decode"] = pre_decision_d
|
||||
breakdown["chosen_p_idx"] = p_idx
|
||||
breakdown["chosen_d_idx"] = d_idx
|
||||
|
||||
_worker_state_log.append({
|
||||
"t_decision_unix": t_decision_unix,
|
||||
"request_id": request_id,
|
||||
"session_id": session_id,
|
||||
"policy": "pd_sep",
|
||||
"chosen_p_idx": p_idx,
|
||||
"chosen_d_idx": d_idx,
|
||||
"workers_prefill": pre_decision_p,
|
||||
"workers_decode": pre_decision_d,
|
||||
})
|
||||
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
@@ -592,15 +678,18 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
|
||||
p_inst.ongoing_tokens += input_length
|
||||
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||
breakdown["t_prefill_sent_unix"] = _time.time()
|
||||
|
||||
try:
|
||||
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["t_prefill_done_unix"] = _time.time()
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
p_inst.record_prefix(token_ids)
|
||||
except Exception as e:
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["t_prefill_done_unix"] = _time.time()
|
||||
breakdown["prefill_error"] = True
|
||||
_breakdown_log.append(breakdown)
|
||||
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
|
||||
@@ -621,6 +710,7 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
}
|
||||
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
breakdown["t_decode_sent_unix"] = _time.time()
|
||||
|
||||
async def generate():
|
||||
first_token = True
|
||||
@@ -635,11 +725,13 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
output_token_ids.extend(new_output_ids)
|
||||
if first_token:
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
breakdown["t_first_token_unix"] = _time.time()
|
||||
first_token = False
|
||||
yield chunk
|
||||
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
|
||||
finally:
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
breakdown["t_done_unix"] = _time.time()
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
@@ -652,6 +744,29 @@ async def get_breakdown():
|
||||
return _breakdown_log
|
||||
|
||||
|
||||
@app.get("/worker_state")
async def get_worker_state():
    """Serve the recorded worker-state log.

    The log holds one entry per route decision, each carrying the
    per-worker snapshot taken for that decision.
    """
    return _worker_state_log
|
||||
|
||||
|
||||
@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Report a live per-worker snapshot; nothing is appended to the log.

    When combined instances are configured the response has mode
    "combined" and a single "workers" list. Otherwise it has mode
    "pd_sep" with separate prefill and decode worker lists.
    """
    # Timestamp is taken before the snapshot(s), as a wall-clock value.
    now_unix = _time.time()

    if not combined_instances:
        return {
            "t_unix": now_unix,
            "mode": "pd_sep",
            "workers_prefill": snapshot_workers(prefill_instances),
            "workers_decode": snapshot_workers(decode_instances),
        }

    return {
        "t_unix": now_unix,
        "mode": "combined",
        "workers": snapshot_workers(combined_instances),
    }
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
"""Return per-instance live state for debugging."""
|
||||
|
||||
Reference in New Issue
Block a user