Approach A: contention-aware cost model with migration discount

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 17:24:27 +08:00
parent e13391eeab
commit e06de5144b

View File

@@ -50,6 +50,8 @@ class Settings:
overload_factor: float = 2.0
max_offload_inflight: int = 4
cache_gate_ratio: float = 0.0
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
migration_discount_cap: int = 5 # max turns to discount
SETTINGS = Settings()
@@ -531,15 +533,20 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
def _instance_cost(i: int) -> tuple[float, bool]:
"""Expected latency if this request goes to instance i."""
inst = combined_instances[i]
queue = (inst.pending_prefill_tokens + inst.ongoing_decode_tokens) / throughput
contention = inst.num_requests * SETTINGS.decode_iteration_s
prefill_queue = inst.pending_prefill_tokens / throughput
local_hit = cache_hits[i]
local_new = max(0, input_length - local_hit)
local_cost = queue + local_new / throughput
local_cost = contention + prefill_queue + local_new / throughput
if (offload_enabled and best_cache_hit > 0 and _push_allowed(best_cache_hit)
and i != best_cache_idx and local_hit < best_cache_hit):
push_new = max(0, input_length - best_cache_hit)
push_cost = queue + push_new / throughput + SETTINGS.rdma_overhead_s
target_contention = inst.num_requests * SETTINGS.decode_iteration_s
push_cost = target_contention + push_new / throughput + SETTINGS.rdma_overhead_s
if session_id and session_id in session_affinity_combined:
turn_discount = min(SETTINGS.migration_discount_cap, 3) * SETTINGS.decode_iteration_s
push_cost -= turn_discount
if push_cost < local_cost:
return push_cost, True
return local_cost, False
@@ -948,6 +955,8 @@ def parse_args():
p.add_argument("--cache-gate-ratio", type=float, default=0.0,
help="Min cache_hit/input ratio to allow offload "
"(0.0 disables gate, 1.0 disables offload entirely)")
p.add_argument("--decode-iteration-s", type=float, default=0.05,
help="Estimated per-request decode iteration time in seconds")
args = p.parse_args()
args.prefill = []
@@ -969,6 +978,7 @@ if __name__ == "__main__":
SETTINGS.overload_factor = global_args.overload_factor
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
getattr(global_args, 'offload', False)))