Revert Approach B (session migration): overhead exceeds LB benefit
Reverts 3 commits: e991960, 5772149, 5b1d360. 57 migrations were triggered, but the PD-sep overhead (C queue + KV transfer + D cold start) caused HEAVY TTFT p90 to regress from 15.9s to 59.1s. The migration mechanism needs fundamental rework before it can help. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -19,7 +19,7 @@ import os
|
|||||||
import time as _time
|
import time as _time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from collections import OrderedDict, deque
|
from collections import OrderedDict
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -52,9 +52,6 @@ class Settings:
|
|||||||
cache_gate_ratio: float = 0.0
|
cache_gate_ratio: float = 0.0
|
||||||
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
||||||
migration_discount_cap: int = 5 # max turns to discount
|
migration_discount_cap: int = 5 # max turns to discount
|
||||||
migration_request_factor: float = 1.5 # trigger migration when num_requests > avg * factor
|
|
||||||
migration_ttft_threshold: float = 5.0 # trigger migration when recent TTFT median > this (seconds)
|
|
||||||
migration_ttft_window: int = 8 # number of recent TTFTs to track per instance
|
|
||||||
|
|
||||||
|
|
||||||
SETTINGS = Settings()
|
SETTINGS = Settings()
|
||||||
@@ -77,7 +74,6 @@ class InstanceState:
|
|||||||
self.dp_size = 1
|
self.dp_size = 1
|
||||||
# OrderedDict acts as an LRU keyed by block hash; value is unused.
|
# OrderedDict acts as an LRU keyed by block hash; value is unused.
|
||||||
self.cached_blocks: OrderedDict[int, None] = OrderedDict()
|
self.cached_blocks: OrderedDict[int, None] = OrderedDict()
|
||||||
self.recent_ttfts: deque[float] = deque(maxlen=SETTINGS.migration_ttft_window)
|
|
||||||
|
|
||||||
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
|
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
|
||||||
if not token_ids or len(token_ids) < BLOCK_SIZE:
|
if not token_ids or len(token_ids) < BLOCK_SIZE:
|
||||||
@@ -451,11 +447,7 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
|
|||||||
if not prefill_done:
|
if not prefill_done:
|
||||||
chosen.pending_prefill_tokens -= estimated_new
|
chosen.pending_prefill_tokens -= estimated_new
|
||||||
chosen.ongoing_decode_tokens += input_length
|
chosen.ongoing_decode_tokens += input_length
|
||||||
t_first = _time.monotonic()
|
breakdown["t_first_token"] = _time.monotonic()
|
||||||
breakdown["t_first_token"] = t_first
|
|
||||||
t_recv = breakdown.get("t_proxy_recv")
|
|
||||||
if t_recv:
|
|
||||||
chosen.recent_ttfts.append(t_first - t_recv)
|
|
||||||
prefill_done = True
|
prefill_done = True
|
||||||
yield chunk
|
yield chunk
|
||||||
chosen.record_prefix(
|
chosen.record_prefix(
|
||||||
@@ -523,57 +515,6 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
|||||||
cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
|
cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
|
||||||
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
|
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
|
||||||
best_cache_hit = cache_hits[best_cache_idx]
|
best_cache_hit = cache_hits[best_cache_idx]
|
||||||
|
|
||||||
# Session-level migration: force-migrate when instance has high recent TTFT
|
|
||||||
if (offload_enabled and session_id
|
|
||||||
and session_id in session_affinity_combined):
|
|
||||||
mig_src_idx = session_affinity_combined[session_id]
|
|
||||||
if mig_src_idx < len(combined_instances):
|
|
||||||
mig_src = combined_instances[mig_src_idx]
|
|
||||||
src_cache_ratio = cache_hits[mig_src_idx] / max(input_length, 1)
|
|
||||||
src_ttfts = mig_src.recent_ttfts
|
|
||||||
src_ttft_med = sorted(src_ttfts)[len(src_ttfts) // 2] if len(src_ttfts) >= 3 else 0
|
|
||||||
|
|
||||||
if (src_ttft_med > SETTINGS.migration_ttft_threshold
|
|
||||||
and src_cache_ratio > 0.5
|
|
||||||
and input_length >= SETTINGS.heavy_threshold):
|
|
||||||
# Find instance with lowest recent TTFT
|
|
||||||
def _inst_ttft_score(i: int) -> float:
|
|
||||||
t = combined_instances[i].recent_ttfts
|
|
||||||
if len(t) < 2:
|
|
||||||
return 0.0
|
|
||||||
return sorted(t)[len(t) // 2]
|
|
||||||
mig_tgt_idx = min(range(len(combined_instances)), key=_inst_ttft_score)
|
|
||||||
mig_tgt = combined_instances[mig_tgt_idx]
|
|
||||||
tgt_ttft_med = _inst_ttft_score(mig_tgt_idx)
|
|
||||||
|
|
||||||
if tgt_ttft_med < src_ttft_med * 0.5:
|
|
||||||
estimated_new = max(0, input_length - cache_hits[mig_src_idx])
|
|
||||||
breakdown = {
|
|
||||||
"request_id": headers.get("X-Request-Id", ""),
|
|
||||||
"input_length": input_length,
|
|
||||||
"cache_hit": cache_hits[mig_tgt_idx],
|
|
||||||
"estimated_new_tokens": estimated_new,
|
|
||||||
"t_proxy_recv": _time.monotonic(),
|
|
||||||
"policy": "session_migrate",
|
|
||||||
"push_cache_hit": cache_hits[mig_src_idx],
|
|
||||||
"c_inst": mig_src.url,
|
|
||||||
"routed_to": mig_tgt.url,
|
|
||||||
}
|
|
||||||
session_affinity_combined[session_id] = mig_tgt_idx
|
|
||||||
offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill')
|
|
||||||
push_cache_hit = cache_hits[mig_src_idx]
|
|
||||||
if offload_mode == "cached_prefill":
|
|
||||||
return await _handle_cached_prefill_offload(
|
|
||||||
api, req_data, headers, token_ids, input_length,
|
|
||||||
mig_src, mig_tgt, push_cache_hit, estimated_new,
|
|
||||||
breakdown)
|
|
||||||
else:
|
|
||||||
return await _handle_direct_read_offload(
|
|
||||||
api, req_data, headers, token_ids, input_length,
|
|
||||||
mig_src, mig_tgt, push_cache_hit, estimated_new,
|
|
||||||
breakdown)
|
|
||||||
|
|
||||||
def _current_offloads() -> int:
|
def _current_offloads() -> int:
|
||||||
return sum(i.active_p_offloads for i in combined_instances)
|
return sum(i.active_p_offloads for i in combined_instances)
|
||||||
|
|
||||||
@@ -1016,10 +957,6 @@ def parse_args():
|
|||||||
"(0.0 disables gate, 1.0 disables offload entirely)")
|
"(0.0 disables gate, 1.0 disables offload entirely)")
|
||||||
p.add_argument("--decode-iteration-s", type=float, default=0.05,
|
p.add_argument("--decode-iteration-s", type=float, default=0.05,
|
||||||
help="Estimated per-request decode iteration time in seconds")
|
help="Estimated per-request decode iteration time in seconds")
|
||||||
p.add_argument("--migration-request-factor", type=float, default=1.5,
|
|
||||||
help="Trigger session migration when num_requests > avg * factor")
|
|
||||||
p.add_argument("--migration-ttft-threshold", type=float, default=5.0,
|
|
||||||
help="Trigger migration when instance median TTFT > this (seconds)")
|
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
args.prefill = []
|
args.prefill = []
|
||||||
@@ -1042,8 +979,6 @@ if __name__ == "__main__":
|
|||||||
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
|
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
|
||||||
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
|
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
|
||||||
SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
|
SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
|
||||||
SETTINGS.migration_request_factor = getattr(global_args, 'migration_request_factor', 1.5)
|
|
||||||
SETTINGS.migration_ttft_threshold = getattr(global_args, 'migration_ttft_threshold', 5.0)
|
|
||||||
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
|
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
|
||||||
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
|
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
|
||||||
getattr(global_args, 'offload', False)))
|
getattr(global_args, 'offload', False)))
|
||||||
|
|||||||
Reference in New Issue
Block a user