Unified routing: single argmin(expected_latency) over all instances

Replace two-phase routing (pick_instance → offload gate) with a single
cost function evaluated per instance:

  latency(D) = queue(D) + prefill_time(D) + transfer_cost(D)

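  - If D has local cache: prefill = (input - local_hit) / throughput, transfer_cost = 0
  - If D can receive PUSH from cache source: prefill = (input - push_hit) / throughput, transfer_cost = rdma_overhead_s (PUSH is chosen only when this total is lower than D's local-cache cost)
  - Otherwise: prefill = input / throughput (cold), transfer_cost = 0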

Choose argmin(latency). If the winner needs PUSH → trigger migration.

Removed:
- WARM/MEDIUM/HEAVY classification (no routing purpose)
- heavy_threshold, overload_factor, max_offload_inflight, cache_gate_ratio
- Interference penalty magic number (0.3)
- Separate pick_instance + offload gate stages

Only 2 measured parameters remain:
- prefill_throughput = 7000 tokens/s (H20 measured)
- rdma_overhead_s = 0.1s (RDMA PUSH measured)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 02:21:34 +08:00
parent 1cd0a18e2c
commit 6b255fad91

View File

@@ -42,12 +42,8 @@ class Settings:
CLI overrides survive even when the module is imported as a library
(e.g. by tests/) and __main__ does not run.
"""
heavy_threshold: int = 20000 # new-token cutoff for HEAVY classification
overload_factor: float = 2.0 # break session affinity above this * avg load
max_offload_inflight: int = 4 # global cap on concurrent P-role offloads
cache_gate_ratio: float = 0.3 # min cache_hit/input ratio to allow offload
prefill_throughput: float = 7000.0 # tokens/s per GPU (H20 measurement)
rdma_overhead_s: float = 0.1 # direct RDMA read overhead (raw memory read ~10-50ms)
prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20)
rdma_overhead_s: float = 0.1 # RDMA PUSH overhead (~10-50ms measured)
cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
@@ -99,15 +95,10 @@ class InstanceState:
def _p_offload_penalty(inst: InstanceState) -> int:
"""Penalty for instances currently doing P-role offloaded prefills.
When an instance is busy with offloaded HEAVY prefills for other
instances, we want to steer WARM/MEDIUM requests away from it so
its GPU is dedicated to prefill (soft PD separation).
"""
"""Penalty for PD-sep mode routing (legacy)."""
if inst.active_p_offloads <= 0:
return 0
return inst.active_p_offloads * SETTINGS.heavy_threshold
return inst.active_p_offloads * 20000
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
@@ -127,7 +118,7 @@ def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
idx = affinity[session_id]
if idx < len(instances):
inst = instances[idx]
if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor
if (inst.ongoing_tokens <= avg_load * 2.0
and inst.active_p_offloads == 0):
return inst, idx
@@ -337,116 +328,99 @@ async def _handle(request: Request, api: str):
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
"""Combined mode with V2 P2P offload.
"""Unified routing: pick the instance with lowest expected latency.
WARM/MEDIUM: route to best instance, co-located P+D (no KV transfer).
HEAVY: C_s (session-sticky, has cache) does FAST prefill,
D (least-loaded C, D != C_s) pulls KV via Mooncake and decodes.
Offload only when D is meaningfully less loaded than C_s.
For each instance, estimate:
latency = queue_time + prefill_time + transfer_cost
where prefill_time depends on whether the instance has cache (local),
can receive cache via PUSH (remote), or must do cold prefill.
"""
policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear'
picker = pick_instance_lmetric if policy == 'lmetric' else pick_instance
best_inst, best_idx = picker(combined_instances, token_ids, session_id,
input_length, session_affinity_combined)
cache_hit = best_inst.estimate_cache_hit(token_ids)
offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
throughput = SETTINGS.prefill_throughput
# Find the best cache source (instance with highest prefix cache hit)
cache_hits = []
for i, inst in enumerate(combined_instances):
hit = inst.estimate_cache_hit(token_ids)
cache_hits.append(hit)
best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
best_cache_hit = cache_hits[best_cache_idx]
# Score each instance by expected latency
best_idx = 0
best_cost = float("inf")
best_needs_push = False
costs = []
for i, inst in enumerate(combined_instances):
queue = inst.pending_prefill_tokens / throughput
local_hit = cache_hits[i]
local_new = max(0, input_length - local_hit)
if offload_enabled and best_cache_hit > 0 and i != best_cache_idx:
# This instance could receive cached blocks via PUSH
push_new = max(0, input_length - best_cache_hit)
push_cost = queue + push_new / throughput + SETTINGS.rdma_overhead_s
local_cost = queue + local_new / throughput
# Use whichever is cheaper (push vs local cache)
if push_cost < local_cost:
cost = push_cost
needs_push = True
else:
cost = local_cost
needs_push = False
else:
cost = queue + local_new / throughput
needs_push = False
costs.append((cost, needs_push))
if cost < best_cost:
best_cost = cost
best_idx = i
best_needs_push = needs_push
chosen = combined_instances[best_idx]
cache_hit = cache_hits[best_idx]
estimated_new = max(0, input_length - cache_hit)
breakdown = {
"request_id": headers.get("X-Request-Id", ""),
"input_length": input_length,
"estimated_new_tokens": estimated_new,
"cache_hit": cache_hit,
"estimated_new_tokens": estimated_new,
"t_proxy_recv": _time.monotonic(),
"chosen_cost": round(best_cost, 2),
}
# Runtime cost-model offload gate: compare co-located vs offload latency
# Co-located = queue(C_s) + prefill(new_tokens)
# Offload = queue(P) + prefill(P_new_tokens) + RDMA_overhead
offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
use_offload = False
offload_reason = "offload_disabled"
if session_id:
session_affinity_combined[session_id] = best_idx
if estimated_new >= SETTINGS.heavy_threshold and offload_enabled:
cache_ratio = cache_hit / max(input_length, 1)
current_offloads = sum(c.active_p_offloads for c in combined_instances)
# P candidate: least-loaded instance excluding C_s, preferring instances
# not already shouldering an active P-role offload.
def _p_pick_score(c: InstanceState) -> int:
return c.ongoing_tokens + c.active_p_offloads * SETTINGS.heavy_threshold
p_candidate = min(
(c for c in combined_instances if c is not best_inst),
key=_p_pick_score,
)
# D candidate: least-loaded excluding both C_s and P
remaining = [c for c in combined_instances if c is not best_inst and c is not p_candidate]
d_candidate = min(remaining, key=lambda c: c.ongoing_tokens) if remaining else p_candidate
# Cost model: compare co-located vs direct-RDMA-read offload
# Co-located cost includes interference: heavy prefill on C_s blocks
# its ongoing decode requests, degrading their TPOT.
cs_queue = best_inst.pending_prefill_tokens / SETTINGS.prefill_throughput
prefill_time = estimated_new / SETTINGS.prefill_throughput
# Interference penalty: if C_s has decode requests, heavy prefill disrupts them
interference = prefill_time * min(best_inst.num_requests, 3) * 0.3
colocated_cost = cs_queue + prefill_time + interference
# Direct RDMA read: D reads cached blocks + prefills new tokens locally
# C_s is not involved → zero interference on C_s's decode
d_queue = d_candidate.pending_prefill_tokens / SETTINGS.prefill_throughput
offload_cost = d_queue + SETTINGS.rdma_overhead_s + prefill_time
breakdown["cache_ratio"] = cache_ratio
breakdown["colocated_cost"] = round(colocated_cost, 2)
breakdown["offload_cost"] = round(offload_cost, 2)
if current_offloads >= SETTINGS.max_offload_inflight:
offload_reason = "cap_reached_%d" % current_offloads
elif offload_cost < colocated_cost:
use_offload = True
offload_reason = "cost_model_%.1fvs%.1f" % (offload_cost, colocated_cost)
else:
offload_reason = "colocated_cheaper_%.1fvs%.1f" % (colocated_cost, offload_cost)
if use_offload:
# Direct RDMA read: D reads cached KV from C_s's GPU, no request to C_s
c_inst = best_inst # has cache (not doing any work)
d_inst = d_candidate
d_idx = combined_instances.index(d_inst)
if best_needs_push:
c_inst = combined_instances[best_cache_idx]
d_inst = chosen
push_cache_hit = best_cache_hit
push_new = max(0, input_length - push_cache_hit)
d_inst.ongoing_tokens += input_length
d_inst.pending_prefill_tokens += estimated_new
d_inst.pending_prefill_tokens += push_new
d_inst.num_requests += 1
c_inst.active_p_offloads += 1
breakdown["route_class"] = "HEAVY_OFFLOAD"
breakdown["offload_reason"] = offload_reason
breakdown["route_class"] = "PUSH_MIGRATE"
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["cache_hit_tokens"] = cache_hit
if session_id:
session_affinity_combined[session_id] = d_idx
breakdown["push_cache_hit"] = push_cache_hit
return await _handle_direct_read_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, cache_hit, estimated_new, breakdown)
c_inst, d_inst, push_cache_hit, push_new, breakdown)
else:
if estimated_new >= SETTINGS.heavy_threshold:
breakdown["route_class"] = "HEAVY_COLO"
breakdown["offload_reason"] = offload_reason
elif estimated_new < 5000:
breakdown["route_class"] = "WARM"
else:
breakdown["route_class"] = "MEDIUM"
breakdown["route_class"] = "LOCAL"
breakdown["routed_to"] = chosen.url
inst = best_inst
breakdown["routed_to"] = inst.url
breakdown["policy"] = policy
inst.ongoing_tokens += input_length
inst.pending_prefill_tokens += estimated_new
inst.num_requests += 1
chosen.ongoing_tokens += input_length
chosen.pending_prefill_tokens += estimated_new
chosen.num_requests += 1
async def generate():
prefill_done = False
@@ -677,13 +651,7 @@ def parse_args():
if __name__ == "__main__":
global_args = parse_args()
SETTINGS.heavy_threshold = global_args.heavy_threshold
SETTINGS.overload_factor = global_args.overload_factor
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
print(
"SETTINGS: heavy=%d overload=%.1f max_offload=%d cache_gate=%.2f"
% (SETTINGS.heavy_threshold, SETTINGS.overload_factor,
SETTINGS.max_offload_inflight, SETTINGS.cache_gate_ratio)
)
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
getattr(global_args, 'offload', False)))
uvicorn.run(app, host=global_args.host, port=global_args.port)