Hybrid routing: LMetric for LB + explicit affinity for high-cache sessions

Replace the full unified cost model with a simpler hybrid:
- If session has >50% cache on affinity instance AND instance not overloaded
  (num_requests <= avg * overload_factor) → stick to affinity
- Otherwise → use LMetric (P × BS) for best load balance

This combines LMetric's superior load balance with explicit session
affinity for high-value sessions that have significant cache accumulation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-25 09:05:08 +08:00
parent 448361cf83
commit 255c8e6884

View File

@@ -511,74 +511,33 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
# Compute cache hits for all instances
cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
best_cache_hit = cache_hits[best_cache_idx]
def _current_offloads() -> int:
return sum(i.active_p_offloads for i in combined_instances)
def _push_allowed(cache_hit: int) -> bool:
if _current_offloads() >= SETTINGS.max_offload_inflight:
return False
push_new = max(0, input_length - cache_hit)
if push_new < SETTINGS.heavy_threshold:
return False
if SETTINGS.cache_gate_ratio > 0:
cache_ratio = cache_hit / max(input_length, 1)
if cache_ratio < SETTINGS.cache_gate_ratio:
return False
return True
def _instance_cost(i: int) -> tuple[float, bool]:
"""Expected latency if this request goes to instance i."""
inst = combined_instances[i]
contention = inst.num_requests * SETTINGS.decode_iteration_s
prefill_queue = inst.pending_prefill_tokens / throughput
local_hit = cache_hits[i]
local_new = max(0, input_length - local_hit)
local_cost = contention + prefill_queue + local_new / throughput
if (offload_enabled and best_cache_hit > 0 and _push_allowed(best_cache_hit)
and i != best_cache_idx and local_hit < best_cache_hit):
push_new = max(0, input_length - best_cache_hit)
target_contention = inst.num_requests * SETTINGS.decode_iteration_s
push_cost = target_contention + push_new / throughput + SETTINGS.rdma_overhead_s
if session_id and session_id in session_affinity_combined:
turn_discount = min(SETTINGS.migration_discount_cap, 3) * SETTINGS.decode_iteration_s
push_cost -= turn_discount
if push_cost < local_cost:
return push_cost, True
return local_cost, False
# Session affinity: prefer the last-used instance if its cost is reasonable
avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0)
# Hybrid routing: LMetric for load balance + explicit affinity for high-cache sessions
#
# 1. If session has high cache on affinity instance AND instance not overloaded → stick
# 2. Otherwise → LMetric (P × BS) for best load balance
affinity_idx = session_affinity_combined.get(session_id) if session_id else None
use_affinity = False
if affinity_idx is not None and affinity_idx < len(combined_instances):
affinity_inst = combined_instances[affinity_idx]
# Hard gate: break affinity if instance is overloaded regardless of cache
if affinity_inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor:
affinity_cost, affinity_push = _instance_cost(affinity_idx)
all_costs = [_instance_cost(i) for i in range(len(combined_instances))]
global_best_cost = min(c for c, _ in all_costs)
if affinity_cost <= global_best_cost * SETTINGS.overload_factor:
affinity_cache = affinity_inst.estimate_cache_hit(token_ids)
cache_ratio = affinity_cache / max(input_length, 1)
avg_reqs = max(sum(i.num_requests for i in combined_instances) / len(combined_instances), 1.0)
if (cache_ratio > 0.5
and affinity_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
use_affinity = True
best_idx = affinity_idx
best_cost = affinity_cost
best_needs_push = affinity_push
else:
best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0])
best_cost, best_needs_push = all_costs[best_idx]
else:
all_costs = [_instance_cost(i) for i in range(len(combined_instances))]
best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0])
best_cost, best_needs_push = all_costs[best_idx]
else:
all_costs = [_instance_cost(i) for i in range(len(combined_instances))]
best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0])
best_cost, best_needs_push = all_costs[best_idx]
if not use_affinity:
_, best_idx = pick_instance_lmetric(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
best_needs_push = False
chosen = combined_instances[best_idx]
cache_hit = cache_hits[best_idx]
cache_hit = chosen.estimate_cache_hit(token_ids)
estimated_new = max(0, input_length - cache_hit)
breakdown = {
@@ -587,8 +546,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
"cache_hit": cache_hit,
"estimated_new_tokens": estimated_new,
"t_proxy_recv": _time.monotonic(),
"policy": policy,
"chosen_cost": round(best_cost, 2),
"policy": "affinity" if use_affinity else "lmetric",
}
if session_id: