Files
replaysim/patches/frontier-vllm-v1-prefix-cache-chunked-prefill.patch

439 lines
21 KiB
Diff

diff --git a/frontier/entities/request.py b/frontier/entities/request.py
index a173caf..eee588b 100644
--- a/frontier/entities/request.py
+++ b/frontier/entities/request.py
@@ -301,7 +301,7 @@ class Request(BaseEntity):
@property
@check_scheduled
def e2e_time_normalized(self) -> float:
- return self.e2e_time / self.num_decode_tokens
+ return self.e2e_time / self.user_facing_num_decode_tokens  # not the replay-rewritten count
@property
@check_scheduled
@@ -315,7 +315,7 @@ class Request(BaseEntity):
@property
@check_scheduled
def execution_time_normalized(self) -> float:
- return self.execution_time / self.num_decode_tokens
+ return self.execution_time / self.user_facing_num_decode_tokens
@property
@check_scheduled
@@ -329,7 +329,7 @@ class Request(BaseEntity):
@property
@check_scheduled
def model_execution_time_normalized(self) -> float:
- return self.model_execution_time / self.num_decode_tokens
+ return self.model_execution_time / self.user_facing_num_decode_tokens
@property
def arrived_at(self) -> float:
@@ -886,10 +886,13 @@ class Request(BaseEntity):
Average time per output token excluding the first token.
Returns 0 if there's only one or no decode tokens.
"""
- if self._num_decode_tokens <= 1 or self._first_decode_token_completed_at == 0:
+ if (
+ self._user_facing_num_decode_tokens <= 1  # same count MetricsStore's TPOT guard uses
+ or self._first_decode_token_completed_at == 0
+ ):
return 0
total_decode_time = self._completed_at - self._first_decode_token_completed_at
- return total_decode_time / (self._num_decode_tokens - 1)
+ return total_decode_time / (self._user_facing_num_decode_tokens - 1)
def on_kv_cache_transfer_start(self, transfer_start_time: float) -> None:
"""Record the earliest request-level KV transfer start timestamp."""
diff --git a/frontier/metrics/metrics_store.py b/frontier/metrics/metrics_store.py
index 422e2c2..eaa6308 100644
--- a/frontier/metrics/metrics_store.py
+++ b/frontier/metrics/metrics_store.py
@@ -2374,8 +2374,8 @@ class MetricsStore:
"cluster_type": cluster_type.name,
"arrived_at": float(request.arrived_at),
"arrived_at_ms": float(request.arrived_at) * 1000.0,
- "num_prefill_tokens": int(request.num_prefill_tokens),
- "num_decode_tokens": int(request.num_decode_tokens),
+ "num_prefill_tokens": int(request.user_facing_num_prefill_tokens),  # key name kept for output compatibility
+ "num_decode_tokens": int(request.user_facing_num_decode_tokens),
}
)
@@ -2558,8 +2558,8 @@ class MetricsStore:
* 1000.0,
"completed_at": float(request.completed_at),
"completed_at_ms": float(request.completed_at) * 1000.0,
- "num_prefill_tokens": int(request.num_prefill_tokens),
- "num_decode_tokens": int(request.num_decode_tokens),
+ "num_prefill_tokens": int(request.user_facing_num_prefill_tokens),  # key name kept for output compatibility
+ "num_decode_tokens": int(request.user_facing_num_decode_tokens),
"request_e2e_time_s": float(request.e2e_time),
"request_e2e_time_ms": float(request.e2e_time) * 1000.0,
"request_waiting_time_total_s": float(request_waiting_time_total),
@@ -2675,7 +2675,7 @@ class MetricsStore:
RequestMetricsTimeDistributions.REQUEST_EXECUTION_PLUS_PREEMPTION_TIME_NORMALIZED
].put(
request.id,
- request.execution_time / request.num_decode_tokens,
+ request.execution_time / request.user_facing_num_decode_tokens,  # NOTE(review): no > 0 guard visible here, unlike below
)
if request.is_prefill_complete:
@@ -2688,23 +2688,23 @@ class MetricsStore:
# Guard against division by zero (defensive programming)
# Normal requests should always have num_prefill_tokens >= 1, but
# this protects against edge cases or synthetic test requests
- if request.num_prefill_tokens > 0:
+ if request.user_facing_num_prefill_tokens > 0:
self._request_metrics_time_distributions[
RequestMetricsTimeDistributions.PREFILL_EXECUTION_PLUS_PREEMPTION_PER_TOKEN
].put(
request.id,
(request.prefill_completed_at - request.scheduled_at)
- / request.num_prefill_tokens,
+ / request.user_facing_num_prefill_tokens,
)
#
# Guard against division by zero for decode tokens
- if request.num_decode_tokens > 0:
+ if request.user_facing_num_decode_tokens > 0:
self._request_metrics_time_distributions[
RequestMetricsTimeDistributions.DECODE_E2E_TIME_PER_TOKEN
].put(
request.id,
(request.completed_at - request.prefill_completed_at)
- / request.num_decode_tokens,
+ / request.user_facing_num_decode_tokens,
)
self._request_metrics_histogram[
@@ -2958,7 +2958,7 @@ class MetricsStore:
RequestMetricsTimeDistributions.TTFT_DECODE_FIRST_TOKEN
].put(request.id, max(0, ttft_decode_first)) # Ensure non-negative
- if request.num_decode_tokens > 1:
+ if request.user_facing_num_decode_tokens > 1:
if request.first_decode_token_completed_at <= 0:
raise ValueError(
f"Missing first token completion timestamp for request_id={request.id}"
@@ -2983,7 +2983,7 @@ class MetricsStore:
# TPOT (Time Per Output Token) metrics
if (
- request.num_decode_tokens > 1
+ request.user_facing_num_decode_tokens > 1
and request.first_decode_token_completed_at > 0
):
self._request_metrics_time_distributions[
@@ -2997,7 +2997,7 @@ class MetricsStore:
)
# M2N transfer time per token (excluding first token)
tpot_transfer = request.total_m2n_transfer_time / (
- request.num_decode_tokens - 1
+ request.user_facing_num_decode_tokens - 1  # > 1 is ensured by the enclosing TPOT condition
)
self._request_metrics_time_distributions[
RequestMetricsTimeDistributions.TPOT_TRANSFER
diff --git a/frontier/scheduler/replica_scheduler/base_replica_scheduler.py b/frontier/scheduler/replica_scheduler/base_replica_scheduler.py
index c0c4085..2d362bb 100644
--- a/frontier/scheduler/replica_scheduler/base_replica_scheduler.py
+++ b/frontier/scheduler/replica_scheduler/base_replica_scheduler.py
@@ -599,6 +599,9 @@ class BaseReplicaScheduler(ABC):
"running_requests": self._debug_request_collection_state(
getattr(self, "_running_requests", None)
),
+ "active_batch_request_counts": dict(  # shallow copy, not a live reference
+ getattr(self, "_active_batch_request_counts", {})  # {} when the attribute is absent
+ ),
"allocation_map": self._debug_allocation_map_state(
self._allocation_map
),
diff --git a/frontier/scheduler/replica_scheduler/vllm_v1_engine_replica_scheduler.py b/frontier/scheduler/replica_scheduler/vllm_v1_engine_replica_scheduler.py
index ac2e062..f61e088 100644
--- a/frontier/scheduler/replica_scheduler/vllm_v1_engine_replica_scheduler.py
+++ b/frontier/scheduler/replica_scheduler/vllm_v1_engine_replica_scheduler.py
@@ -2569,10 +2569,9 @@ class VLLMv1EngineReplicaScheduler(BaseReplicaScheduler):
else -1
)
- # Record preemption statistics in the request entity
- # This must be done BEFORE resetting num_processed_tokens
+ # Record preemption statistics before any token accounting is reset
+ # below, so they capture the request's pre-preemption state.
victim.record_preemption(self._cluster_type, num_computed_tokens_before)
- victim.advance_runtime_epoch()
# Remove from running requests
if victim in self._running_requests:
@@ -2582,9 +2581,33 @@ class VLLMv1EngineReplicaScheduler(BaseReplicaScheduler):
if victim.id in self._allocation_map:
self._free_request_resources(victim)
- # Mark as preempted and reset computed tokens
+ # Mark as preempted and reset computed tokens. A decode-phase victim is
+ # restarted by folding every processed token into a new prefill (as in
+ # vLLM's recompute path) and treating the remainder as decode tokens.
+ # Without this the request would stay prefill-complete with zero
+ # processed tokens, and waiting admission could compute a zero-token step.
+ if victim.is_prefill_complete:
+ total_tokens = int(victim.total_tokens)
+ replay_prefill_tokens = int(victim.num_processed_tokens)  # recomputed as prefill
+ victim.advance_runtime_epoch()
+ victim._num_prefill_tokens = replay_prefill_tokens
+ victim._num_decode_tokens = max(total_tokens - replay_prefill_tokens, 0)
+ victim._num_processed_tokens = 0
+ victim._num_prefill_tokens_cached = 0
+ victim._scheduled = False
+ victim._completed = False
+ victim._is_prefill_complete = False
+ victim._current_decode_token_index = 1
+ victim._completed_layer_count = 0
+ victim._af_roundtrip_inflight = False
+ victim._num_restarts += 1
+ else:  # preempted mid-prefill: restart the prompt from zero processed tokens
+ victim.advance_runtime_epoch()
+ victim._num_processed_tokens = 0
+ victim._num_prefill_tokens_cached = 0
+ victim._scheduled = False
+ victim._is_prefill_complete = False
victim._preempted = True
- victim._num_processed_tokens = 0 # Reset computed tokens as in vLLM v1
self._scheduled_num_computed_tokens_by_request.pop(victim.id, None)
# Record re-entry to waiting queue for waiting time tracking
@@ -3139,6 +3162,19 @@ class VLLMv1EngineReplicaScheduler(BaseReplicaScheduler):
if num_new_tokens <= 0:
waiting_queue.popleft()
+ if not request.completed:
+ skipped_waiting_requests.append(request)  # keep instead of dropping; presumably re-queued later
+ logger.warning(
+ "[WAITING-ZERO-TOKEN-SKIP] Preserving unfinished req=%s "
+ "with num_new_tokens=%s, preempted=%s, "
+ "is_prefill_complete=%s, processed=%s/%s",
+ request.id,
+ num_new_tokens,
+ getattr(request, "preempted", False),
+ getattr(request, "is_prefill_complete", False),
+ getattr(request, "num_processed_tokens", None),
+ getattr(request, "total_tokens", None),
+ )
continue
# Try to allocate (no preemption for waiting requests in Phase 2)
@@ -4360,12 +4396,14 @@ class VLLMv1EngineReplicaScheduler(BaseReplicaScheduler):
)
waiting_len = len(self._waiting_requests)
running_len = len(self._running_requests)
+ active_batch_request_len = len(self._get_active_batch_request_counts())  # must be 0 to be idle
logger.info(
f"[RS-IDLE-CHECK][replica={self._replica_id}][dp={self._dp_id}] "
f"num_pending_requests={self.num_pending_requests}, waiting_requests={waiting_len}, "
f"running_requests={running_len}, allocated_blocks={len(self._allocation_map)}, "
- f"num_running_batches={self._num_running_batches}, stages_empty={stages_empty}, af_immediate_len={af_len}"
+ f"num_running_batches={self._num_running_batches}, stages_empty={stages_empty}, "
+ f"af_immediate_len={af_len}, active_batch_requests={active_batch_request_len}"
)
# If AF immediate queue has pending batches, the replica is not idle
if af_len > 0:
@@ -4374,6 +4412,7 @@ class VLLMv1EngineReplicaScheduler(BaseReplicaScheduler):
self.num_pending_requests == 0
and waiting_len == 0
and running_len == 0
+ and active_batch_request_len == 0
and len(self._allocation_map) == 0
and self._num_running_batches == 0
and stages_empty
diff --git a/frontier/scheduler/replica_stage_scheduler/replica_stage_schduler.py b/frontier/scheduler/replica_stage_scheduler/replica_stage_schduler.py
index 2344fe3..22c29d4 100644
--- a/frontier/scheduler/replica_stage_scheduler/replica_stage_schduler.py
+++ b/frontier/scheduler/replica_stage_scheduler/replica_stage_schduler.py
@@ -48,7 +48,7 @@ class ReplicaStageScheduler:
return self._is_last_stage
def is_empty(self) -> bool:
- return len(self._batch_queue) == 0
+ return len(self._batch_queue) == 0 and not self._is_busy  # _is_busy presumably marks an in-flight batch
def get_debug_state(self) -> dict:
"""Return scheduler state for fail-fast sequential-end diagnostics."""
diff --git a/frontier/simulator.py b/frontier/simulator.py
index b1e14fd..083dcba 100644
--- a/frontier/simulator.py
+++ b/frontier/simulator.py
@@ -543,6 +543,188 @@ class Simulator:
+ json.dumps(payload, indent=2, sort_keys=True, default=str)
)
+ def _build_sequential_incomplete_request_report(
+ self,
+ *,
+ completed_requests: int,
+ total_requests: int,
+ ) -> str:
+ """Build a structured report for drained sequential runs with open requests.
+
+ Args:
+ completed_requests: Completed-request count from the metrics store.
+ total_requests: Total-request count from the metrics store.
+
+ Returns:
+ A header line followed by an indented JSON payload holding the
+ simulator state, each cluster scheduler's debug state, and a
+ snapshot of every generated request not in the completed-id set.
+
+ Raises:
+ RuntimeError: If the global scheduler has no ``_cluster_schedulers``
+ or a cluster scheduler has no ``get_debug_state()``.
+ """
+ message = "Sequential simulation drained before all requests completed"
+ if not hasattr(self._global_scheduler, "_cluster_schedulers"):
+ raise RuntimeError(
+ "Global scheduler missing _cluster_schedulers for sequential diagnostics"
+ )
+
+ cluster_states = []
+ for cluster_type, cluster_scheduler in sorted(
+ self._global_scheduler._cluster_schedulers.items(),
+ key=lambda item: item[0].name,
+ ):
+ if not hasattr(cluster_scheduler, "get_debug_state"):
+ raise RuntimeError(
+ f"Cluster scheduler {cluster_type.name} missing get_debug_state()"
+ )
+ cluster_states.append(
+ {
+ "cluster_key": cluster_type.name,
+ "state": cluster_scheduler.get_debug_state(),
+ }
+ )
+
+ # Both lookups fall back silently: without _completed_request_ids every
+ # generated request is reported missing, and without _all_requests none
+ # are. The *_tracked flags in the payload make either case visible.
+ completed_ids_tracked = hasattr(self._metric_store, "_completed_request_ids")
+ all_requests_tracked = hasattr(self, "_all_requests")
+ completed_ids = set(getattr(self._metric_store, "_completed_request_ids", set()))
+ generated_requests = list(getattr(self, "_all_requests", []))
+ missing_requests = [
+ request for request in generated_requests if request.id not in completed_ids
+ ]
+ missing_request_summaries = [
+ self._build_request_debug_snapshot(request) for request in missing_requests
+ ]
+ global_scheduler_is_empty = self._global_scheduler.is_empty
+ # is_empty is a method on other schedulers (ReplicaStageScheduler.is_empty);
+ # call it when callable so the payload holds a bool, not a bound-method repr.
+ if callable(global_scheduler_is_empty):
+ global_scheduler_is_empty = global_scheduler_is_empty()
+ payload = {
+ "message": message,
+ "simulation_time": self._time,
+ "terminate": self._terminate,
+ "event_queue_length": len(self._event_queue),
+ "global_scheduler_is_empty": global_scheduler_is_empty,
+ "completed_requests": completed_requests,
+ "total_requests": total_requests,
+ # Counter-based gap from the metrics store; it can disagree with the
+ # id-based tracked_missing_request_count if the two drift apart.
+ "missing_request_count": total_requests - completed_requests,
+ "tracked_missing_request_count": len(missing_requests),
+ "completed_request_ids_tracked": completed_ids_tracked,
+ "all_requests_tracked": all_requests_tracked,
+ "completed_request_ids": sorted(completed_ids),
+ "missing_request_ids": [request.id for request in missing_requests],
+ "missing_requests": missing_request_summaries,
+ "clusters": cluster_states,
+ }
+ return (
+ f"{message}:\n"
+ + json.dumps(payload, indent=2, sort_keys=True, default=str)
+ )
+
+ def _build_request_debug_snapshot(self, request) -> dict:
+ """Return a JSON-friendly snapshot of ``request`` for drain diagnostics.
+
+ Fields are read with ``getattr(..., None)`` so a missing attribute
+ yields ``None``; this only absorbs ``AttributeError``. The
+ ``@check_scheduled`` metric properties are deliberately not read.
+ Internal token counts are reported next to the user-facing ones
+ because decode-phase preemption rewrites the internal counts to
+ replay already-processed tokens as prefill.
+ """
+ return {
+ "id": request.id,
+ "session_id": getattr(request, "session_id", None),
+ "arrived_at": getattr(request, "arrived_at", None),
+ "num_prefill_tokens": getattr(request, "num_prefill_tokens", None),
+ "num_decode_tokens": getattr(request, "num_decode_tokens", None),
+ "user_facing_num_prefill_tokens": getattr(
+ request, "user_facing_num_prefill_tokens", None
+ ),
+ "user_facing_num_decode_tokens": getattr(
+ request, "user_facing_num_decode_tokens", None
+ ),
+ "total_tokens": getattr(request, "total_tokens", None),
+ "num_processed_tokens": getattr(request, "num_processed_tokens", None),
+ "num_processed_prefill_tokens": getattr(
+ request, "num_processed_prefill_tokens", None
+ ),
+ "num_processed_decode_tokens": getattr(
+ request, "num_processed_decode_tokens", None
+ ),
+ "remaining_decode_tokens": getattr(request, "remaining_decode_tokens", None),
+ "num_prefill_tokens_cached": getattr(
+ request, "num_prefill_tokens_cached", None
+ ),
+ "is_prefill_complete": getattr(request, "is_prefill_complete", None),
+ "scheduled": getattr(request, "scheduled", None),
+ "preempted": getattr(request, "preempted", None),
+ "completed": getattr(request, "completed", None),
+ "completed_at": getattr(request, "_completed_at", None),
+ "prefill_completed_at": getattr(request, "_prefill_completed_at", None),
+ "latest_stage_scheduled_at": getattr(
+ request, "_latest_stage_scheduled_at", None
+ ),
+ "latest_stage_completed_at": getattr(
+ request, "_latest_stage_completed_at", None
+ ),
+ "latest_iteration_scheduled_at": getattr(
+ request, "_latest_iteration_scheduled_at", None
+ ),
+ "latest_iteration_completed_at": getattr(
+ request, "_latest_iteration_completed_at", None
+ ),
+ "current_decode_token_index": getattr(
+ request, "_current_decode_token_index", None
+ ),
+ "completed_layer_count": getattr(request, "_completed_layer_count", None),
+ "runtime_epoch": getattr(request, "runtime_epoch", None),
+ "execution_epoch": getattr(request, "execution_epoch", None),
+ "num_restarts": getattr(request, "num_restarts", None),
+ "cluster_arrival_times": {
+ cluster_type.name: list(times)
+ for cluster_type, times in getattr(
+ request, "_cluster_arrival_times", {}
+ ).items()
+ },
+ "cluster_scheduled_at": {
+ cluster_type.name: list(times)
+ for cluster_type, times in getattr(
+ request, "_scheduled_at", {}
+ ).items()
+ },
+ "cluster_scheduling_delay": {
+ cluster_type.name: list(times)
+ for cluster_type, times in getattr(
+ request, "_scheduling_delay", {}
+ ).items()
+ },
+ "cluster_execution_time": {
+ cluster_type.name: list(times)
+ for cluster_type, times in getattr(
+ request, "_execution_time", {}
+ ).items()
+ },
+ "preemption_count": {
+ cluster_type.name: count
+ for cluster_type, count in getattr(
+ request, "_preemption_count", {}
+ ).items()
+ },
+ "tokens_at_preemption": {
+ cluster_type.name: list(tokens)
+ for cluster_type, tokens in getattr(
+ request, "_tokens_at_preemption", {}
+ ).items()
+ },
+ }
+
def _try_promote_terminal_pdaf_scheduler_work(self) -> bool:
"""Promote terminal PD-AF DECODE_FFN groups when no event can fill them.
@@ -708,6 +890,27 @@ class Simulator:
self._trace_store.close()
raise RuntimeError(report)
+ # Unless the run was terminated, every request the metrics store counts
+ # must be complete here; otherwise fail with a per-request report.
+ total_requests = self._metric_store.get_total_requests()
+ completed_requests = self._metric_store.get_completed_requests()
+ if (
+ total_requests > 0
+ and completed_requests < total_requests
+ and not self._terminate
+ ):
+ report = self._build_sequential_incomplete_request_report(
+ completed_requests=completed_requests,
+ total_requests=total_requests,
+ )
+ logger.error(report)
+ if self._sequential_event_loggers:
+ for logger_instance in self._sequential_event_loggers.values():
+ logger_instance.write_summary()
+ if self._trace_store:
+ self._trace_store.close()
+ raise RuntimeError(report)
+
if self._sequential_event_loggers:
for logger_instance in self._sequential_event_loggers.values():
logger_instance.write_summary()