Approach B: session-level lazy migration trigger
When a request arrives for a session on an overloaded instance, force migration if three conditions hold: 1. Instance busy: num_requests > avg * migration_request_factor (1.5x) 2. Session has cache value: cache_ratio > 50% 3. Request is HEAVY (>= heavy_threshold) 4. A meaningfully less-loaded target exists (num_requests gap > 2) This bypasses the cost model for migration decisions — the cost model's cache-inflated costs prevented migration even when instances had 150s queue times with 99% cache hit. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -52,6 +52,7 @@ class Settings:
|
||||
cache_gate_ratio: float = 0.0
|
||||
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
||||
migration_discount_cap: int = 5 # max turns to discount
|
||||
migration_request_factor: float = 1.5 # trigger migration when num_requests > avg * factor
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -515,6 +516,50 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
|
||||
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
|
||||
best_cache_hit = cache_hits[best_cache_idx]
|
||||
|
||||
# Session-level migration: force-migrate overloaded sessions
|
||||
if (offload_enabled and session_id
|
||||
and session_id in session_affinity_combined):
|
||||
mig_src_idx = session_affinity_combined[session_id]
|
||||
if mig_src_idx < len(combined_instances):
|
||||
mig_src = combined_instances[mig_src_idx]
|
||||
avg_reqs = max(
|
||||
sum(i.num_requests for i in combined_instances)
|
||||
/ len(combined_instances), 1)
|
||||
src_cache_ratio = cache_hits[mig_src_idx] / max(input_length, 1)
|
||||
|
||||
if (mig_src.num_requests > avg_reqs * SETTINGS.migration_request_factor
|
||||
and src_cache_ratio > 0.5
|
||||
and input_length >= SETTINGS.heavy_threshold):
|
||||
mig_tgt_idx = min(
|
||||
range(len(combined_instances)),
|
||||
key=lambda i: combined_instances[i].num_requests)
|
||||
mig_tgt = combined_instances[mig_tgt_idx]
|
||||
|
||||
if mig_tgt.num_requests < mig_src.num_requests - 2:
|
||||
estimated_new = max(0, input_length - cache_hits[mig_src_idx])
|
||||
breakdown = {
|
||||
"request_id": headers.get("X-Request-Id", ""),
|
||||
"input_length": input_length,
|
||||
"cache_hit": cache_hits[mig_tgt_idx],
|
||||
"estimated_new_tokens": estimated_new,
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
"policy": "session_migrate",
|
||||
"push_cache_hit": cache_hits[mig_src_idx],
|
||||
"c_inst": mig_src.url,
|
||||
"routed_to": mig_tgt.url,
|
||||
}
|
||||
session_affinity_combined[session_id] = mig_tgt_idx
|
||||
offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill')
|
||||
if offload_mode == "cached_prefill":
|
||||
return await _handle_cached_prefill_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
mig_tgt, mig_src, estimated_new, breakdown)
|
||||
else:
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
mig_tgt, mig_src, estimated_new, breakdown)
|
||||
|
||||
def _current_offloads() -> int:
|
||||
return sum(i.active_p_offloads for i in combined_instances)
|
||||
|
||||
@@ -957,6 +1002,8 @@ def parse_args():
|
||||
"(0.0 disables gate, 1.0 disables offload entirely)")
|
||||
p.add_argument("--decode-iteration-s", type=float, default=0.05,
|
||||
help="Estimated per-request decode iteration time in seconds")
|
||||
p.add_argument("--migration-request-factor", type=float, default=1.5,
|
||||
help="Trigger session migration when num_requests > avg * factor")
|
||||
args = p.parse_args()
|
||||
|
||||
args.prefill = []
|
||||
@@ -979,6 +1026,7 @@ if __name__ == "__main__":
|
||||
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
|
||||
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
|
||||
SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
|
||||
SETTINGS.migration_request_factor = getattr(global_args, 'migration_request_factor', 1.5)
|
||||
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
|
||||
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
|
||||
getattr(global_args, 'offload', False)))
|
||||
|
||||
Reference in New Issue
Block a user