Fix multi-turn replay fidelity: track realized output tokens across all components
The replayer and proxy were building multi-turn prompts from trace tokens, but the model generates different output tokens. Subsequent turns had wrong prefix tokens, causing cache misses and invalid experimental measurements. - replay.py: min_tokens=max_tokens for deterministic length, return_token_ids to capture actual output, _apply_realized_prefix for next-turn correction - proxy: extract output token_ids from SSE, record prompt+output as realized prefix in shadow cache, extract _handle_local_request to deduplicate - bench.sh/launch_elastic_p2p.sh: default elastic mode to unified policy - mooncake_connector: only send prompt blocks (not stale output blocks), track failed_recving_block_ids for error recovery Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -84,6 +84,12 @@ class _SessionState:
|
||||
metrics: list[RequestMetrics] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DispatchResult:
|
||||
metric: RequestMetrics
|
||||
output_token_ids: list[int]
|
||||
|
||||
|
||||
_endpoint_counter = 0
|
||||
|
||||
|
||||
@@ -103,14 +109,17 @@ async def _dispatch_request(
|
||||
req: TraceRequest,
|
||||
prompt_token_ids: list[int],
|
||||
sem: asyncio.Semaphore,
|
||||
) -> RequestMetrics:
|
||||
) -> _DispatchResult:
|
||||
"""Send one request via /v1/completions (streaming) and collect metrics."""
|
||||
endpoint = _pick_endpoint(config)
|
||||
target_output_tokens = max(1, req.output_length)
|
||||
payload = {
|
||||
"model": config.model_name,
|
||||
"prompt": prompt_token_ids,
|
||||
"max_tokens": max(1, req.output_length),
|
||||
"max_tokens": target_output_tokens,
|
||||
"min_tokens": target_output_tokens,
|
||||
"temperature": 0,
|
||||
"return_token_ids": True,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
@@ -122,6 +131,7 @@ async def _dispatch_request(
|
||||
finish_reason = None
|
||||
err = None
|
||||
token_times: list[float] = []
|
||||
output_token_ids: list[int] = []
|
||||
|
||||
req_headers = {"X-Session-Id": req.session_id}
|
||||
|
||||
@@ -146,14 +156,24 @@ async def _dispatch_request(
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
now = time.perf_counter()
|
||||
if ttft_s is None:
|
||||
ttft_s = now - start
|
||||
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
now = time.perf_counter()
|
||||
delta = choices[0].get("text", "")
|
||||
if delta:
|
||||
chunk_token_ids = choices[0].get("token_ids")
|
||||
if isinstance(chunk_token_ids, list):
|
||||
clean_ids = [
|
||||
int(t) for t in chunk_token_ids
|
||||
if isinstance(t, int)
|
||||
]
|
||||
if clean_ids:
|
||||
if ttft_s is None:
|
||||
ttft_s = now - start
|
||||
output_token_ids.extend(clean_ids)
|
||||
token_times.extend([now] * len(clean_ids))
|
||||
elif delta:
|
||||
if ttft_s is None:
|
||||
ttft_s = now - start
|
||||
token_times.append(now)
|
||||
fr = choices[0].get("finish_reason")
|
||||
if fr:
|
||||
@@ -168,8 +188,15 @@ async def _dispatch_request(
|
||||
|
||||
end = time.perf_counter()
|
||||
e2e = end - start
|
||||
if n_output == 0 and token_times:
|
||||
if output_token_ids:
|
||||
n_output = len(output_token_ids)
|
||||
elif n_output == 0 and token_times:
|
||||
n_output = len(token_times)
|
||||
if err is None and n_output != target_output_tokens:
|
||||
err = (
|
||||
"output_token_mismatch "
|
||||
f"requested={target_output_tokens} actual={n_output}"
|
||||
)
|
||||
|
||||
tpot = 0.0
|
||||
if len(token_times) > 1:
|
||||
@@ -177,26 +204,42 @@ async def _dispatch_request(
|
||||
for i in range(len(token_times) - 1)]
|
||||
tpot = sum(inter_token) / len(inter_token)
|
||||
|
||||
return RequestMetrics(
|
||||
request_id=req.request_id,
|
||||
session_id=req.session_id,
|
||||
turn_id=req.turn_id,
|
||||
trace_timestamp_s=req.timestamp_s,
|
||||
input_length=req.input_length,
|
||||
output_length=req.output_length,
|
||||
request_type=req.request_type,
|
||||
effective_input_length=len(prompt_token_ids),
|
||||
cached_tokens=cached_tokens,
|
||||
latency_s=e2e,
|
||||
ttft_s=ttft_s,
|
||||
tpot_s=tpot,
|
||||
actual_output_tokens=n_output,
|
||||
requested_output_tokens=req.output_length,
|
||||
finish_reason=finish_reason,
|
||||
error=err,
|
||||
return _DispatchResult(
|
||||
metric=RequestMetrics(
|
||||
request_id=req.request_id,
|
||||
session_id=req.session_id,
|
||||
turn_id=req.turn_id,
|
||||
trace_timestamp_s=req.timestamp_s,
|
||||
input_length=req.input_length,
|
||||
output_length=req.output_length,
|
||||
request_type=req.request_type,
|
||||
effective_input_length=len(prompt_token_ids),
|
||||
cached_tokens=cached_tokens,
|
||||
latency_s=e2e,
|
||||
ttft_s=ttft_s,
|
||||
tpot_s=tpot,
|
||||
actual_output_tokens=n_output,
|
||||
requested_output_tokens=req.output_length,
|
||||
finish_reason=finish_reason,
|
||||
error=err,
|
||||
),
|
||||
output_token_ids=output_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def _apply_realized_prefix(
|
||||
prompt_token_ids: list[int],
|
||||
realized_context: list[int],
|
||||
) -> list[int]:
|
||||
"""Replace the reusable session prefix with engine-realized tokens."""
|
||||
if not realized_context:
|
||||
return prompt_token_ids
|
||||
out = prompt_token_ids.copy()
|
||||
n = min(len(out), len(realized_context))
|
||||
out[:n] = realized_context[:n]
|
||||
return out
|
||||
|
||||
|
||||
def _extract_cached_tokens(usage: dict) -> int:
|
||||
ct = 0
|
||||
details = usage.get("prompt_tokens_details")
|
||||
@@ -220,6 +263,7 @@ async def _run_session(
|
||||
) -> list[RequestMetrics]:
|
||||
if session_sem is not None:
|
||||
await session_sem.acquire()
|
||||
realized_context: list[int] = []
|
||||
try:
|
||||
for req in state.turns:
|
||||
# Wait until this request's trace timestamp
|
||||
@@ -228,13 +272,19 @@ async def _run_session(
|
||||
if elapsed < target_wall:
|
||||
await asyncio.sleep(target_wall - elapsed)
|
||||
|
||||
token_ids = _build_prompt_token_ids(req)
|
||||
metric = await _dispatch_request(
|
||||
token_ids = _apply_realized_prefix(
|
||||
_build_prompt_token_ids(req),
|
||||
realized_context,
|
||||
)
|
||||
result = await _dispatch_request(
|
||||
client=client, config=config, req=req,
|
||||
prompt_token_ids=token_ids, sem=request_sem,
|
||||
)
|
||||
metric = result.metric
|
||||
state.metrics.append(metric)
|
||||
await sink.append(metric)
|
||||
if metric.error is None:
|
||||
realized_context = token_ids + result.output_token_ids
|
||||
finally:
|
||||
if session_sem is not None:
|
||||
session_sem.release()
|
||||
|
||||
Reference in New Issue
Block a user