diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index 89e0591..dd36459 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -511,74 +511,33 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h api, req_data, headers, token_ids, input_length, chosen, estimated_new, breakdown) - # Compute cache hits for all instances - cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances] - best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i]) - best_cache_hit = cache_hits[best_cache_idx] - def _current_offloads() -> int: - return sum(i.active_p_offloads for i in combined_instances) - - def _push_allowed(cache_hit: int) -> bool: - if _current_offloads() >= SETTINGS.max_offload_inflight: - return False - push_new = max(0, input_length - cache_hit) - if push_new < SETTINGS.heavy_threshold: - return False - if SETTINGS.cache_gate_ratio > 0: - cache_ratio = cache_hit / max(input_length, 1) - if cache_ratio < SETTINGS.cache_gate_ratio: - return False - return True - - def _instance_cost(i: int) -> tuple[float, bool]: - """Expected latency if this request goes to instance i.""" - inst = combined_instances[i] - contention = inst.num_requests * SETTINGS.decode_iteration_s - prefill_queue = inst.pending_prefill_tokens / throughput - local_hit = cache_hits[i] - local_new = max(0, input_length - local_hit) - local_cost = contention + prefill_queue + local_new / throughput - - if (offload_enabled and best_cache_hit > 0 and _push_allowed(best_cache_hit) - and i != best_cache_idx and local_hit < best_cache_hit): - push_new = max(0, input_length - best_cache_hit) - target_contention = inst.num_requests * SETTINGS.decode_iteration_s - push_cost = target_contention + push_new / throughput + SETTINGS.rdma_overhead_s - if session_id and session_id in session_affinity_combined: - turn_discount = min(SETTINGS.migration_discount_cap, 3) * SETTINGS.decode_iteration_s - push_cost -= turn_discount - if push_cost < local_cost: - return push_cost, True - return local_cost, False - - # Session affinity: prefer the last-used instance if its cost is reasonable - avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0) + # Hybrid routing: LMetric for load balance + explicit affinity for high-cache sessions + # + # 1. If session has high cache on affinity instance AND instance not overloaded → stick + # 2. Otherwise → LMetric (P × BS) for best load balance affinity_idx = session_affinity_combined.get(session_id) if session_id else None + use_affinity = False + if affinity_idx is not None and affinity_idx < len(combined_instances): affinity_inst = combined_instances[affinity_idx] - # Hard gate: break affinity if instance is overloaded regardless of cache - if affinity_inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor: - affinity_cost, affinity_push = _instance_cost(affinity_idx) - all_costs = [_instance_cost(i) for i in range(len(combined_instances))] - global_best_cost = min(c for c, _ in all_costs) - if affinity_cost <= global_best_cost * SETTINGS.overload_factor: - best_idx = affinity_idx - best_cost = affinity_cost - best_needs_push = affinity_push - else: - best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0]) - best_cost, best_needs_push = all_costs[best_idx] - else: - all_costs = [_instance_cost(i) for i in range(len(combined_instances))] - best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0]) - best_cost, best_needs_push = all_costs[best_idx] - else: - all_costs = [_instance_cost(i) for i in range(len(combined_instances))] - best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0]) - best_cost, best_needs_push = all_costs[best_idx] + affinity_cache = affinity_inst.estimate_cache_hit(token_ids) + cache_ratio = affinity_cache / max(input_length, 1) + avg_reqs = max(sum(i.num_requests for i in combined_instances) / len(combined_instances), 1.0) + + if (cache_ratio > 0.5 + and affinity_inst.num_requests <= avg_reqs * SETTINGS.overload_factor): + use_affinity = True + best_idx = affinity_idx + + if not use_affinity: + _, best_idx = pick_instance_lmetric( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + + best_needs_push = False chosen = combined_instances[best_idx] - cache_hit = cache_hits[best_idx] + cache_hit = chosen.estimate_cache_hit(token_ids) estimated_new = max(0, input_length - cache_hit) breakdown = { @@ -587,8 +546,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h "cache_hit": cache_hit, "estimated_new_tokens": estimated_new, "t_proxy_recv": _time.monotonic(), - "policy": policy, - "chosen_cost": round(best_cost, 2), + "policy": "affinity" if use_affinity else "lmetric", } if session_id: