Approach B v2: TTFT-based migration trigger

Replace the num_requests threshold with the recent median TTFT (time to
first token) as the migration trigger. Track a rolling window of TTFTs per
instance (last 8 requests) and trigger migration when the median exceeds
5s (configurable). The target is the instance with the lowest recent TTFT,
and it must offer a > 2x improvement to justify migration.

This is a more meaningful trigger than the instantaneous num_requests
signal because TTFT directly measures the user-facing impact of contention.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 21:54:06 +08:00
parent 45b82272c3
commit 5772149d36

View File

@@ -19,7 +19,7 @@ import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from collections import OrderedDict, deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
@@ -53,6 +53,8 @@ class Settings:
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
migration_discount_cap: int = 5 # max turns to discount
migration_request_factor: float = 1.5 # trigger migration when num_requests > avg * factor
migration_ttft_threshold: float = 5.0 # trigger migration when recent TTFT median > this (seconds)
migration_ttft_window: int = 8 # number of recent TTFTs to track per instance
SETTINGS = Settings()
@@ -75,6 +77,7 @@ class InstanceState:
self.dp_size = 1
# OrderedDict acts as an LRU keyed by block hash; value is unused.
self.cached_blocks: OrderedDict[int, None] = OrderedDict()
self.recent_ttfts: deque[float] = deque(maxlen=SETTINGS.migration_ttft_window)
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
@@ -448,7 +451,11 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
t_first = _time.monotonic()
breakdown["t_first_token"] = t_first
t_recv = breakdown.get("t_proxy_recv")
if t_recv:
chosen.recent_ttfts.append(t_first - t_recv)
prefill_done = True
yield chunk
chosen.record_prefix(
@@ -517,26 +524,30 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
best_cache_hit = cache_hits[best_cache_idx]
# Session-level migration: force-migrate overloaded sessions
# Session-level migration: force-migrate when instance has high recent TTFT
if (offload_enabled and session_id
and session_id in session_affinity_combined):
mig_src_idx = session_affinity_combined[session_id]
if mig_src_idx < len(combined_instances):
mig_src = combined_instances[mig_src_idx]
avg_reqs = max(
sum(i.num_requests for i in combined_instances)
/ len(combined_instances), 1)
src_cache_ratio = cache_hits[mig_src_idx] / max(input_length, 1)
src_ttfts = mig_src.recent_ttfts
src_ttft_med = sorted(src_ttfts)[len(src_ttfts) // 2] if len(src_ttfts) >= 3 else 0
if (mig_src.num_requests > avg_reqs * SETTINGS.migration_request_factor
if (src_ttft_med > SETTINGS.migration_ttft_threshold
and src_cache_ratio > 0.5
and input_length >= SETTINGS.heavy_threshold):
mig_tgt_idx = min(
range(len(combined_instances)),
key=lambda i: combined_instances[i].num_requests)
# Find instance with lowest recent TTFT
def _inst_ttft_score(i: int) -> float:
    """Score instance ``i`` by the middle value of its recent TTFT samples.

    Returns the element at index ``len // 2`` of the sorted samples (the
    upper of the two middle values when the count is even). Instances with
    fewer than two samples score 0.0, so they rank first when this function
    is used as a ``min()`` key.

    NOTE(review): the source-side median in the enclosing function requires
    at least 3 samples, while this helper requires only 2 -- confirm the
    asymmetry is intended.
    """
    samples = combined_instances[i].recent_ttfts
    count = len(samples)
    if count >= 2:
        ordered = sorted(samples)
        return ordered[count // 2]
    # Not enough history to judge this instance.
    return 0.0
mig_tgt_idx = min(range(len(combined_instances)), key=_inst_ttft_score)
mig_tgt = combined_instances[mig_tgt_idx]
tgt_ttft_med = _inst_ttft_score(mig_tgt_idx)
if mig_tgt.num_requests < mig_src.num_requests - 2:
if tgt_ttft_med < src_ttft_med * 0.5:
estimated_new = max(0, input_length - cache_hits[mig_src_idx])
breakdown = {
"request_id": headers.get("X-Request-Id", ""),
@@ -1004,6 +1015,8 @@ def parse_args():
help="Estimated per-request decode iteration time in seconds")
p.add_argument("--migration-request-factor", type=float, default=1.5,
help="Trigger session migration when num_requests > avg * factor")
p.add_argument("--migration-ttft-threshold", type=float, default=5.0,
help="Trigger migration when instance median TTFT > this (seconds)")
args = p.parse_args()
args.prefill = []
@@ -1027,6 +1040,7 @@ if __name__ == "__main__":
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
SETTINGS.migration_request_factor = getattr(global_args, 'migration_request_factor', 1.5)
SETTINGS.migration_ttft_threshold = getattr(global_args, 'migration_ttft_threshold', 5.0)
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
getattr(global_args, 'offload', False)))