From 4b50c5a08d9e6110c18059de359d6ae57c075c9c Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sun, 24 May 2026 16:25:02 +0800 Subject: [PATCH] Fix unified cost model: include decode load in queue + hard overload gate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two bugs caused elastic to concentrate load on cached instances (10x token imbalance vs 2.7x baseline): 1. _instance_cost queue only counted pending_prefill_tokens, missing ongoing_decode_tokens entirely — instances with 50 decoding requests appeared idle to the cost model. 2. Cache hits made overloaded instances look "cheap", creating a positive feedback loop: more sessions → more cache → lower cost → more routing. Added a hard gate (ongoing_tokens > avg * overload_factor) that breaks affinity before the cost model runs, matching linear policy behavior. Result: token imbalance 10.3x → 2.6x, TTFT p90 -37% vs baseline. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/cache_aware_proxy.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index 64f2a0d..497fa69 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -531,7 +531,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h def _instance_cost(i: int) -> tuple[float, bool]: """Expected latency if this request goes to instance i.""" inst = combined_instances[i] - queue = inst.pending_prefill_tokens / throughput + queue = (inst.pending_prefill_tokens + inst.ongoing_decode_tokens) / throughput local_hit = cache_hits[i] local_new = max(0, input_length - local_hit) local_cost = queue + local_new / throughput @@ -545,18 +545,24 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h return local_cost, False # Session affinity: prefer the last-used instance if its cost is reasonable + avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0) affinity_idx = session_affinity_combined.get(session_id) if session_id else None if affinity_idx is not None and affinity_idx < len(combined_instances): - affinity_cost, affinity_push = _instance_cost(affinity_idx) - # Compare with the globally best option - all_costs = [_instance_cost(i) for i in range(len(combined_instances))] - global_best_cost = min(c for c, _ in all_costs) - # Use affinity if it's within 2x of the best option - if affinity_cost <= global_best_cost * SETTINGS.overload_factor: - best_idx = affinity_idx - best_cost = affinity_cost - best_needs_push = affinity_push + affinity_inst = combined_instances[affinity_idx] + # Hard gate: break affinity if instance is overloaded regardless of cache + if affinity_inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor: + affinity_cost, affinity_push = _instance_cost(affinity_idx) + all_costs = [_instance_cost(i) for i in range(len(combined_instances))] + global_best_cost = min(c for c, _ in all_costs) + if affinity_cost <= global_best_cost * SETTINGS.overload_factor: + best_idx = affinity_idx + best_cost = affinity_cost + best_needs_push = affinity_push + else: + best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0]) + best_cost, best_needs_push = all_costs[best_idx] else: + all_costs = [_instance_cost(i) for i in range(len(combined_instances))] best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0]) best_cost, best_needs_push = all_costs[best_idx] else: