Fix multi-turn replay fidelity: track realized output tokens across all components

The replayer and proxy were building multi-turn prompts from trace tokens,
but the model generates different output tokens. Subsequent turns had wrong
prefix tokens, causing cache misses and invalid experimental measurements.

- replay.py: min_tokens=max_tokens for deterministic length, return_token_ids
  to capture actual output, _apply_realized_prefix for next-turn correction
- proxy: extract output token_ids from SSE, record prompt+output as realized
  prefix in shadow cache, extract _handle_local_request to deduplicate
- bench.sh/launch_elastic_p2p.sh: default elastic mode to unified policy
- mooncake_connector: only send prompt blocks (not stale output blocks),
  track failed_recving_block_ids for error recovery

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 14:47:51 +08:00
parent cc4a9c91e7
commit 9cebdb6b9b
5 changed files with 312 additions and 77 deletions

View File

@@ -84,6 +84,12 @@ class _SessionState:
metrics: list[RequestMetrics] = field(default_factory=list)
@dataclass
class _DispatchResult:
metric: RequestMetrics
output_token_ids: list[int]
_endpoint_counter = 0
@@ -103,14 +109,17 @@ async def _dispatch_request(
req: TraceRequest,
prompt_token_ids: list[int],
sem: asyncio.Semaphore,
) -> RequestMetrics:
) -> _DispatchResult:
"""Send one request via /v1/completions (streaming) and collect metrics."""
endpoint = _pick_endpoint(config)
target_output_tokens = max(1, req.output_length)
payload = {
"model": config.model_name,
"prompt": prompt_token_ids,
"max_tokens": max(1, req.output_length),
"max_tokens": target_output_tokens,
"min_tokens": target_output_tokens,
"temperature": 0,
"return_token_ids": True,
"stream": True,
"stream_options": {"include_usage": True},
}
@@ -122,6 +131,7 @@ async def _dispatch_request(
finish_reason = None
err = None
token_times: list[float] = []
output_token_ids: list[int] = []
req_headers = {"X-Session-Id": req.session_id}
@@ -146,14 +156,24 @@ async def _dispatch_request(
except json.JSONDecodeError:
continue
now = time.perf_counter()
if ttft_s is None:
ttft_s = now - start
choices = chunk.get("choices", [])
if choices:
now = time.perf_counter()
delta = choices[0].get("text", "")
if delta:
chunk_token_ids = choices[0].get("token_ids")
if isinstance(chunk_token_ids, list):
clean_ids = [
int(t) for t in chunk_token_ids
if isinstance(t, int)
]
if clean_ids:
if ttft_s is None:
ttft_s = now - start
output_token_ids.extend(clean_ids)
token_times.extend([now] * len(clean_ids))
elif delta:
if ttft_s is None:
ttft_s = now - start
token_times.append(now)
fr = choices[0].get("finish_reason")
if fr:
@@ -168,8 +188,15 @@ async def _dispatch_request(
end = time.perf_counter()
e2e = end - start
if n_output == 0 and token_times:
if output_token_ids:
n_output = len(output_token_ids)
elif n_output == 0 and token_times:
n_output = len(token_times)
if err is None and n_output != target_output_tokens:
err = (
"output_token_mismatch "
f"requested={target_output_tokens} actual={n_output}"
)
tpot = 0.0
if len(token_times) > 1:
@@ -177,26 +204,42 @@ async def _dispatch_request(
for i in range(len(token_times) - 1)]
tpot = sum(inter_token) / len(inter_token)
return RequestMetrics(
request_id=req.request_id,
session_id=req.session_id,
turn_id=req.turn_id,
trace_timestamp_s=req.timestamp_s,
input_length=req.input_length,
output_length=req.output_length,
request_type=req.request_type,
effective_input_length=len(prompt_token_ids),
cached_tokens=cached_tokens,
latency_s=e2e,
ttft_s=ttft_s,
tpot_s=tpot,
actual_output_tokens=n_output,
requested_output_tokens=req.output_length,
finish_reason=finish_reason,
error=err,
return _DispatchResult(
metric=RequestMetrics(
request_id=req.request_id,
session_id=req.session_id,
turn_id=req.turn_id,
trace_timestamp_s=req.timestamp_s,
input_length=req.input_length,
output_length=req.output_length,
request_type=req.request_type,
effective_input_length=len(prompt_token_ids),
cached_tokens=cached_tokens,
latency_s=e2e,
ttft_s=ttft_s,
tpot_s=tpot,
actual_output_tokens=n_output,
requested_output_tokens=req.output_length,
finish_reason=finish_reason,
error=err,
),
output_token_ids=output_token_ids,
)
def _apply_realized_prefix(
prompt_token_ids: list[int],
realized_context: list[int],
) -> list[int]:
"""Replace the reusable session prefix with engine-realized tokens."""
if not realized_context:
return prompt_token_ids
out = prompt_token_ids.copy()
n = min(len(out), len(realized_context))
out[:n] = realized_context[:n]
return out
def _extract_cached_tokens(usage: dict) -> int:
ct = 0
details = usage.get("prompt_tokens_details")
@@ -220,6 +263,7 @@ async def _run_session(
) -> list[RequestMetrics]:
if session_sem is not None:
await session_sem.acquire()
realized_context: list[int] = []
try:
for req in state.turns:
# Wait until this request's trace timestamp
@@ -228,13 +272,19 @@ async def _run_session(
if elapsed < target_wall:
await asyncio.sleep(target_wall - elapsed)
token_ids = _build_prompt_token_ids(req)
metric = await _dispatch_request(
token_ids = _apply_realized_prefix(
_build_prompt_token_ids(req),
realized_context,
)
result = await _dispatch_request(
client=client, config=config, req=req,
prompt_token_ids=token_ids, sem=request_sem,
)
metric = result.metric
state.metrics.append(metric)
await sink.append(metric)
if metric.error is None:
realized_context = token_ids + result.output_token_ids
finally:
if session_sem is not None:
session_sem.release()