From e9919605af6de5723a75e8629e1e113fa6b70975 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sun, 24 May 2026 17:34:06 +0800 Subject: [PATCH] Approach B: session-level lazy migration trigger MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a request arrives for a session on an overloaded instance, force migration if all four conditions hold: 1. Instance busy: num_requests > avg * migration_request_factor (1.5x) 2. Session has cache value: cache_ratio > 50% 3. Request is HEAVY (>= heavy_threshold) 4. A meaningfully less-loaded target exists (num_requests gap > 2) This bypasses the cost model for migration decisions — the cost model's cache-inflated costs prevented migration even when instances had 150s queue times with 99% cache hit. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/cache_aware_proxy.py | 48 ++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index 89e0591..493b9a4 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -52,6 +52,7 @@ class Settings: cache_gate_ratio: float = 0.0 decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20) migration_discount_cap: int = 5 # max turns to discount + migration_request_factor: float = 1.5 # trigger migration when num_requests > avg * factor SETTINGS = Settings() @@ -515,6 +516,50 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances] best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i]) best_cache_hit = cache_hits[best_cache_idx] + + # Session-level migration: force-migrate overloaded sessions + if (offload_enabled and session_id + and session_id in session_affinity_combined): + mig_src_idx = session_affinity_combined[session_id] + if mig_src_idx < len(combined_instances): + mig_src = 
combined_instances[mig_src_idx] + avg_reqs = max( + sum(i.num_requests for i in combined_instances) + / len(combined_instances), 1) + src_cache_ratio = cache_hits[mig_src_idx] / max(input_length, 1) + + if (mig_src.num_requests > avg_reqs * SETTINGS.migration_request_factor + and src_cache_ratio > 0.5 + and input_length >= SETTINGS.heavy_threshold): + mig_tgt_idx = min( + range(len(combined_instances)), + key=lambda i: combined_instances[i].num_requests) + mig_tgt = combined_instances[mig_tgt_idx] + + if mig_tgt.num_requests < mig_src.num_requests - 2: + estimated_new = max(0, input_length - cache_hits[mig_src_idx]) + breakdown = { + "request_id": headers.get("X-Request-Id", ""), + "input_length": input_length, + "cache_hit": cache_hits[mig_tgt_idx], + "estimated_new_tokens": estimated_new, + "t_proxy_recv": _time.monotonic(), + "policy": "session_migrate", + "push_cache_hit": cache_hits[mig_src_idx], + "c_inst": mig_src.url, + "routed_to": mig_tgt.url, + } + session_affinity_combined[session_id] = mig_tgt_idx + offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill') + if offload_mode == "cached_prefill": + return await _handle_cached_prefill_offload( + api, req_data, headers, token_ids, input_length, + mig_tgt, mig_src, estimated_new, breakdown) + else: + return await _handle_direct_read_offload( + api, req_data, headers, token_ids, input_length, + mig_tgt, mig_src, estimated_new, breakdown) + def _current_offloads() -> int: return sum(i.active_p_offloads for i in combined_instances) @@ -957,6 +1002,8 @@ def parse_args(): "(0.0 disables gate, 1.0 disables offload entirely)") p.add_argument("--decode-iteration-s", type=float, default=0.05, help="Estimated per-request decode iteration time in seconds") + p.add_argument("--migration-request-factor", type=float, default=1.5, + help="Trigger session migration when num_requests > avg * factor") args = p.parse_args() args.prefill = [] @@ -979,6 +1026,7 @@ if __name__ == "__main__": 
SETTINGS.max_offload_inflight = global_args.max_offload_inflight SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05) + SETTINGS.migration_request_factor = getattr(global_args, 'migration_request_factor', 1.5) print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % ( SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s, getattr(global_args, 'offload', False)))