Revert Approach B (session migration): overhead exceeds LB benefit

Reverts 3 commits: e991960, 5772149, 5b1d360.

57 migrations triggered but PD-sep overhead (C queue + KV transfer + D
cold start) caused HEAVY TTFT p90 to regress from 15.9s to 59.1s.
Migration mechanism needs fundamental rework before it can help.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 23:43:47 +08:00
parent 5b1d36080a
commit cc6e5625bb

View File

@@ -19,7 +19,7 @@ import os
import time as _time import time as _time
import urllib.parse import urllib.parse
import uuid import uuid
from collections import OrderedDict, deque from collections import OrderedDict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
@@ -52,9 +52,6 @@ class Settings:
cache_gate_ratio: float = 0.0 cache_gate_ratio: float = 0.0
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20) decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
migration_discount_cap: int = 5 # max turns to discount migration_discount_cap: int = 5 # max turns to discount
migration_request_factor: float = 1.5 # trigger migration when num_requests > avg * factor
migration_ttft_threshold: float = 5.0 # trigger migration when recent TTFT median > this (seconds)
migration_ttft_window: int = 8 # number of recent TTFTs to track per instance
SETTINGS = Settings() SETTINGS = Settings()
@@ -77,7 +74,6 @@ class InstanceState:
self.dp_size = 1 self.dp_size = 1
# OrderedDict acts as an LRU keyed by block hash; value is unused. # OrderedDict acts as an LRU keyed by block hash; value is unused.
self.cached_blocks: OrderedDict[int, None] = OrderedDict() self.cached_blocks: OrderedDict[int, None] = OrderedDict()
self.recent_ttfts: deque[float] = deque(maxlen=SETTINGS.migration_ttft_window)
def estimate_cache_hit(self, token_ids: list[int] | None) -> int: def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE: if not token_ids or len(token_ids) < BLOCK_SIZE:
@@ -451,11 +447,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
if not prefill_done: if not prefill_done:
chosen.pending_prefill_tokens -= estimated_new chosen.pending_prefill_tokens -= estimated_new
chosen.ongoing_decode_tokens += input_length chosen.ongoing_decode_tokens += input_length
t_first = _time.monotonic() breakdown["t_first_token"] = _time.monotonic()
breakdown["t_first_token"] = t_first
t_recv = breakdown.get("t_proxy_recv")
if t_recv:
chosen.recent_ttfts.append(t_first - t_recv)
prefill_done = True prefill_done = True
yield chunk yield chunk
chosen.record_prefix( chosen.record_prefix(
@@ -523,57 +515,6 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances] cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i]) best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
best_cache_hit = cache_hits[best_cache_idx] best_cache_hit = cache_hits[best_cache_idx]
# Session-level migration: force-migrate when instance has high recent TTFT
if (offload_enabled and session_id
and session_id in session_affinity_combined):
mig_src_idx = session_affinity_combined[session_id]
if mig_src_idx < len(combined_instances):
mig_src = combined_instances[mig_src_idx]
src_cache_ratio = cache_hits[mig_src_idx] / max(input_length, 1)
src_ttfts = mig_src.recent_ttfts
src_ttft_med = sorted(src_ttfts)[len(src_ttfts) // 2] if len(src_ttfts) >= 3 else 0
if (src_ttft_med > SETTINGS.migration_ttft_threshold
and src_cache_ratio > 0.5
and input_length >= SETTINGS.heavy_threshold):
# Find instance with lowest recent TTFT
def _inst_ttft_score(i: int) -> float:
t = combined_instances[i].recent_ttfts
if len(t) < 2:
return 0.0
return sorted(t)[len(t) // 2]
mig_tgt_idx = min(range(len(combined_instances)), key=_inst_ttft_score)
mig_tgt = combined_instances[mig_tgt_idx]
tgt_ttft_med = _inst_ttft_score(mig_tgt_idx)
if tgt_ttft_med < src_ttft_med * 0.5:
estimated_new = max(0, input_length - cache_hits[mig_src_idx])
breakdown = {
"request_id": headers.get("X-Request-Id", ""),
"input_length": input_length,
"cache_hit": cache_hits[mig_tgt_idx],
"estimated_new_tokens": estimated_new,
"t_proxy_recv": _time.monotonic(),
"policy": "session_migrate",
"push_cache_hit": cache_hits[mig_src_idx],
"c_inst": mig_src.url,
"routed_to": mig_tgt.url,
}
session_affinity_combined[session_id] = mig_tgt_idx
offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill')
push_cache_hit = cache_hits[mig_src_idx]
if offload_mode == "cached_prefill":
return await _handle_cached_prefill_offload(
api, req_data, headers, token_ids, input_length,
mig_src, mig_tgt, push_cache_hit, estimated_new,
breakdown)
else:
return await _handle_direct_read_offload(
api, req_data, headers, token_ids, input_length,
mig_src, mig_tgt, push_cache_hit, estimated_new,
breakdown)
def _current_offloads() -> int: def _current_offloads() -> int:
return sum(i.active_p_offloads for i in combined_instances) return sum(i.active_p_offloads for i in combined_instances)
@@ -1016,10 +957,6 @@ def parse_args():
"(0.0 disables gate, 1.0 disables offload entirely)") "(0.0 disables gate, 1.0 disables offload entirely)")
p.add_argument("--decode-iteration-s", type=float, default=0.05, p.add_argument("--decode-iteration-s", type=float, default=0.05,
help="Estimated per-request decode iteration time in seconds") help="Estimated per-request decode iteration time in seconds")
p.add_argument("--migration-request-factor", type=float, default=1.5,
help="Trigger session migration when num_requests > avg * factor")
p.add_argument("--migration-ttft-threshold", type=float, default=5.0,
help="Trigger migration when instance median TTFT > this (seconds)")
args = p.parse_args() args = p.parse_args()
args.prefill = [] args.prefill = []
@@ -1042,8 +979,6 @@ if __name__ == "__main__":
SETTINGS.max_offload_inflight = global_args.max_offload_inflight SETTINGS.max_offload_inflight = global_args.max_offload_inflight
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05) SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
SETTINGS.migration_request_factor = getattr(global_args, 'migration_request_factor', 1.5)
SETTINGS.migration_ttft_threshold = getattr(global_args, 'migration_ttft_threshold', 5.0)
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % ( print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s, SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
getattr(global_args, 'offload', False))) getattr(global_args, 'offload', False)))