Fix multi-turn replay fidelity: track realized output tokens across all components

The replayer and proxy were building multi-turn prompts from trace tokens,
but the model generates different output tokens than the trace recorded. As a
result, subsequent turns carried prefix tokens that did not match the actual
conversation, causing cache misses and invalid experimental measurements.

- replay.py: min_tokens=max_tokens for deterministic length, return_token_ids
  to capture actual output, _apply_realized_prefix for next-turn correction
- proxy: parse output token_ids from the SSE stream, record prompt+output as
  the realized prefix in the shadow cache, and factor out
  _handle_local_request to deduplicate code
- bench.sh/launch_elastic_p2p.sh: default elastic mode to unified policy
- mooncake_connector: only send prompt blocks (not stale output blocks),
  track failed_recving_block_ids for error recovery

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 14:47:51 +08:00
parent cc4a9c91e7
commit 9cebdb6b9b
5 changed files with 312 additions and 77 deletions

View File

@@ -249,6 +249,10 @@ class MooncakeConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
return self.connector_worker.get_finished()
# Return the set of block IDs that the worker reports as having load errors.
# This is a pure delegation to connector_worker.get_block_ids_with_load_errors();
# the assert rejects calls made while connector_worker is None.
def get_block_ids_with_load_errors(self) -> set[int]:
assert self.connector_worker is not None
return self.connector_worker.get_block_ids_with_load_errors()
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
@@ -512,10 +516,13 @@ class MooncakeConnectorScheduler:
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
delay_free_blocks = len(block_ids) > 0
block_size = self.vllm_config.cache_config.block_size
prompt_blocks = (request.num_prompt_tokens + block_size - 1) // block_size
send_block_ids = block_ids[:prompt_blocks]
delay_free_blocks = len(send_block_ids) > 0
if delay_free_blocks:
self._reqs_need_send[request.request_id] = (request, block_ids)
self._reqs_need_send[request.request_id] = (request, send_block_ids)
return delay_free_blocks, None
@@ -620,6 +627,7 @@ class MooncakeConnectorWorker:
self.finished_sending_reqs: set[ReqId] = set()
self.finished_recving_reqs: set[ReqId] = set()
self.failed_recving_block_ids: set[int] = set()
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
@@ -1063,6 +1071,11 @@ class MooncakeConnectorWorker:
self.finished_recving_reqs = set()
return finished_recving_reqs
# Hand back the block IDs accumulated in failed_recving_block_ids and reset
# the attribute, so a later call only reports IDs added after this one.
def get_block_ids_with_load_errors(self) -> set[int]:
failed = self.failed_recving_block_ids
# Rebind to a fresh set rather than clearing in place, so the set being
# returned to the caller keeps its contents.
self.failed_recving_block_ids = set()
return failed
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
finished_sending_reqs = self.finished_sending_reqs
self.finished_sending_reqs = set()
@@ -1176,6 +1189,10 @@ class MooncakeConnectorWorker:
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
except Exception as e:
logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
for req_id in req_ids:
pull_meta = pull_metas[req_id]
self.failed_recving_block_ids.update(pull_meta.local_block_ids)
self.finished_recving_reqs.add(pull_meta.d_req_id)
return
def process_pulling_result(
@@ -1201,6 +1218,12 @@ class MooncakeConnectorWorker:
response.err_reqs,
response.err_msg,
)
for req_id in response.err_reqs:
pull_meta = pull_metas.get(req_id)
if pull_meta is None:
continue
self.failed_recving_block_ids.update(pull_meta.local_block_ids)
self.finished_recving_reqs.add(pull_meta.d_req_id)
async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str):
url = remote_bootstrap_addr + "/query"
@@ -1322,11 +1345,13 @@ class MooncakeConnectorWorker:
logger.info("direct_push %s: %d blocks pushed from C", req_id, matched)
else:
logger.debug("direct_push %s: %d matched, pushed=%s", req_id, matched, pushed)
self.failed_recving_block_ids.update(local_block_ids)
self.finished_recving_reqs.add(req_id)
except Exception as e:
logger.error("direct_push %s failed: %s", req_id, e)
self.failed_recving_block_ids.update(pm.local_block_ids)
self.finished_recving_reqs.add(req_id)
async def _start_load_kv(