diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index e225ef7..012c179 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -70,41 +70,37 @@ _inst_cumulative_tokens: list[int] = [] def pick_instance(instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int]) -> tuple[InstanceState, int]: - """Session-sticky + KV-size balanced placement. + """Session-sticky with load-aware override. - Turn 2+: session affinity (sticky to same instance for KV reuse). - Turn 1 (new session): place on instance with least cumulative token load - (greedy bin packing), with cache-hit tiebreak. + Turn 2+: use session affinity UNLESS pinned instance is overloaded + (ongoing_tokens > 2x average), in which case pick least-loaded. + Turn 1: pick instance with best score (load + cache combined). """ global _inst_cumulative_tokens if not _inst_cumulative_tokens: _inst_cumulative_tokens = [0] * len(instances) - # Session affinity for turn 2+ + avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0) + OVERLOAD_FACTOR = 2.0 + + # Session affinity for turn 2+ (with load override) if session_id and session_id in affinity: idx = affinity[session_id] if idx < len(instances): - return instances[idx], idx + inst = instances[idx] + # Stick if not overloaded + if inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR: + return inst, idx + # Overloaded: fall through to score-based selection - # New session: balanced placement - # Primary: least cumulative tokens (long-term balance) - # Secondary: cache hit (tiebreak for prefix reuse) - min_load = min(_inst_cumulative_tokens) - # Candidates within 10% of min load - threshold = min_load + max(min_load * 0.1, 10000) - candidates = [i for i in range(len(instances)) - if _inst_cumulative_tokens[i] <= threshold] - - if not candidates: - candidates = list(range(len(instances))) - - # Among candidates, pick best cache hit - best_idx = candidates[0] - best_hit = 0 - for i in candidates: - hit = instances[i].estimate_cache_hit(token_ids) - if hit > best_hit: - best_hit = hit + # Score = ongoing_tokens - ALPHA * cache_hit_tokens + # Balances load (lower is better) with cache affinity (higher hit is better) + best_idx, best_score = 0, float("inf") + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) + score = inst.ongoing_tokens - CACHE_HIT_ALPHA * cache_hit + if score < best_score: + best_score = score best_idx = i _inst_cumulative_tokens[best_idx] += input_length