proxy: Settings dataclass + cache-ratio gate + P-pick offload penalty (B4, M2, M3, D5)

- Replace mutable module constants (HEAVY_THRESHOLD/OVERLOAD_FACTOR/
  MAX_OFFLOAD_INFLIGHT/PREFILL_THROUGHPUT/RDMA_OVERHEAD_S/
  CACHE_CAPACITY_BLOCKS) with a Settings dataclass + SETTINGS singleton.
  __main__ now mutates SETTINGS so CLI overrides survive even when the
  module is imported as a library (e.g. by tests/) (D5).
- Add --max-offload-inflight CLI flag (M3) and read it from SETTINGS.
- Add --cache-gate-ratio CLI flag and a real gate before the cost-model
  branch: if cache_hit/input_length < ratio, mark cache_gate_REASON and
  fall back to colocated. cache_ratio is no longer a write-only field
  (B4).
- P candidate selection penalises instances already running offloaded
  HEAVY prefills, so back-to-back HEAVY requests don't pile onto the
  same P (M2).
- bench.sh forwards --max-offload-inflight / --cache-gate-ratio to the
  proxy.
- Tests cover SETTINGS knobs + the heavy_threshold-driven P-offload
  penalty.
This commit is contained in:
2026-05-23 21:11:17 +08:00
parent 0701f84c00
commit c843f2e3db
3 changed files with 110 additions and 27 deletions

View File

@@ -35,6 +35,8 @@ HEAVY_THRESHOLD=20000
NO_OFFLOAD=false
OVERLOAD_FACTOR_ARG=""
MAX_BATCHED_TOKENS=""
MAX_OFFLOAD_INFLIGHT=""
CACHE_GATE_RATIO=""
# Parse args
while [[ $# -gt 0 ]]; do
@@ -49,6 +51,8 @@ while [[ $# -gt 0 ]]; do
--no-offload) NO_OFFLOAD=true; shift ;;
--overload-factor) OVERLOAD_FACTOR_ARG="$2"; shift 2 ;;
--max-batched-tokens) MAX_BATCHED_TOKENS="$2"; shift 2 ;;
--max-offload-inflight) MAX_OFFLOAD_INFLIGHT="$2"; shift 2 ;;
--cache-gate-ratio) CACHE_GATE_RATIO="$2"; shift 2 ;;
*) echo "Unknown: $1"; exit 1 ;;
esac
done
@@ -207,6 +211,12 @@ launch_proxy() {
if [ -n "$OVERLOAD_FACTOR_ARG" ]; then
extra_args="$extra_args --overload-factor $OVERLOAD_FACTOR_ARG"
fi
if [ -n "$MAX_OFFLOAD_INFLIGHT" ]; then
extra_args="$extra_args --max-offload-inflight $MAX_OFFLOAD_INFLIGHT"
fi
if [ -n "$CACHE_GATE_RATIO" ]; then
extra_args="$extra_args --cache-gate-ratio $CACHE_GATE_RATIO"
fi
if [ "$MODE" = "elastic" ]; then
local bp_list=""
for i in $(seq 0 $((N_INSTANCES - 1))); do

View File

@@ -20,6 +20,7 @@ import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from dataclasses import dataclass
import httpx
import uvicorn
@@ -28,12 +29,26 @@ from fastapi.responses import StreamingResponse
BLOCK_SIZE = 512
CACHE_HIT_ALPHA = 1.0
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
OVERLOAD_FACTOR = 2.0 # default; overridden by --overload-factor
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent P-role offloads
PREFILL_THROUGHPUT = 7000 # tokens/s per GPU (from H20 measurements)
RDMA_OVERHEAD_S = 2.0 # seconds of RDMA transfer + decode start overhead
CACHE_CAPACITY_BLOCKS = 200000 # per-instance LRU cap on shadow cached_blocks
@dataclass
class Settings:
"""Runtime-tunable knobs. Populated from argparse in __main__.
All routing/offload code reads from the SETTINGS singleton so that
CLI overrides survive even when the module is imported as a library
(e.g. by tests/) and __main__ does not run.
"""
heavy_threshold: int = 20000 # new-token cutoff for HEAVY classification
overload_factor: float = 2.0 # break session affinity above this * avg load
max_offload_inflight: int = 4 # global cap on concurrent P-role offloads
cache_gate_ratio: float = 0.3 # min cache_hit/input ratio to allow offload
prefill_throughput: float = 7000.0 # tokens/s per GPU (H20 measurement)
rdma_overhead_s: float = 2.0 # RDMA transfer + decode-start overhead
cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
SETTINGS = Settings()
class InstanceState:
@@ -76,7 +91,7 @@ class InstanceState:
self.cached_blocks.move_to_end(bh)
else:
self.cached_blocks[bh] = None
if len(self.cached_blocks) > CACHE_CAPACITY_BLOCKS:
if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
self.cached_blocks.popitem(last=False)
@@ -89,7 +104,7 @@ def _p_offload_penalty(inst: InstanceState) -> int:
"""
if inst.active_p_offloads <= 0:
return 0
return inst.active_p_offloads * HEAVY_THRESHOLD
return inst.active_p_offloads * SETTINGS.heavy_threshold
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
@@ -109,7 +124,7 @@ def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
idx = affinity[session_id]
if idx < len(instances):
inst = instances[idx]
if (inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR
if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor
and inst.active_p_offloads == 0):
return inst, idx
@@ -317,32 +332,44 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
use_offload = False
offload_reason = "offload_disabled"
if estimated_new >= HEAVY_THRESHOLD and offload_enabled:
if estimated_new >= SETTINGS.heavy_threshold and offload_enabled:
cache_ratio = cache_hit / max(input_length, 1)
current_offloads = sum(c.active_p_offloads for c in combined_instances)
# P candidate: least-loaded instance (excluding C_s)
p_candidate = min((c for c in combined_instances if c is not best_inst),
key=lambda c: c.ongoing_tokens)
# P candidate: least-loaded instance excluding C_s, preferring instances
# not already shouldering an active P-role offload.
def _p_pick_score(c: InstanceState) -> int:
return c.ongoing_tokens + c.active_p_offloads * SETTINGS.heavy_threshold
p_candidate = min(
(c for c in combined_instances if c is not best_inst),
key=_p_pick_score,
)
# D candidate: least-loaded excluding both C_s and P
remaining = [c for c in combined_instances if c is not best_inst and c is not p_candidate]
d_candidate = min(remaining, key=lambda c: c.ongoing_tokens) if remaining else p_candidate
# Cost model: compare co-located vs offload expected latency
# Co-located: queue on C_s + prefill new tokens on C_s
cs_queue = best_inst.pending_prefill_tokens / PREFILL_THROUGHPUT
colocated_cost = cs_queue + estimated_new / PREFILL_THROUGHPUT
cs_queue = best_inst.pending_prefill_tokens / SETTINGS.prefill_throughput
colocated_cost = cs_queue + estimated_new / SETTINGS.prefill_throughput
# Offload: prefill on P (may or may not have cache) + RDMA + decode start
p_queue = p_candidate.pending_prefill_tokens / PREFILL_THROUGHPUT
p_queue = p_candidate.pending_prefill_tokens / SETTINGS.prefill_throughput
p_cache_hit = p_candidate.estimate_cache_hit(token_ids) if token_ids else 0
p_new_tokens = max(0, input_length - p_cache_hit)
offload_cost = p_queue + p_new_tokens / PREFILL_THROUGHPUT + RDMA_OVERHEAD_S
offload_cost = p_queue + p_new_tokens / SETTINGS.prefill_throughput + SETTINGS.rdma_overhead_s
breakdown["cache_ratio"] = cache_ratio
breakdown["colocated_cost"] = round(colocated_cost, 2)
breakdown["offload_cost"] = round(offload_cost, 2)
if current_offloads >= MAX_OFFLOAD_INFLIGHT:
# H4 cache-ratio gate: if C_s does not have a meaningful cached prefix,
# offload pays full RDMA without saving prefill compute, so block it.
# Set --cache-gate-ratio 0.0 to disable, 1.0 to never offload.
if cache_ratio < SETTINGS.cache_gate_ratio:
offload_reason = "cache_gate_%.2f<%.2f" % (cache_ratio, SETTINGS.cache_gate_ratio)
elif current_offloads >= SETTINGS.max_offload_inflight:
offload_reason = "cap_reached_%d" % current_offloads
elif offload_cost < colocated_cost:
use_offload = True
@@ -374,7 +401,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, cache_hit, estimated_new, breakdown)
else:
if estimated_new >= HEAVY_THRESHOLD:
if estimated_new >= SETTINGS.heavy_threshold:
breakdown["route_class"] = "HEAVY_COLO"
breakdown["offload_reason"] = offload_reason
elif estimated_new < 5000:
@@ -589,6 +616,11 @@ def parse_args():
help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
p.add_argument("--overload-factor", type=float, default=2.0,
help="Break session affinity when instance load > factor * avg")
p.add_argument("--max-offload-inflight", type=int, default=4,
help="Global cap on concurrent P-role offloads (M3)")
p.add_argument("--cache-gate-ratio", type=float, default=0.3,
help="Min cache_hit/input ratio to allow offload "
"(0.0 disables gate, 1.0 disables offload entirely)")
args = p.parse_args()
args.prefill = []
@@ -606,6 +638,13 @@ def parse_args():
if __name__ == "__main__":
global_args = parse_args()
HEAVY_THRESHOLD = global_args.heavy_threshold
OVERLOAD_FACTOR = global_args.overload_factor
SETTINGS.heavy_threshold = global_args.heavy_threshold
SETTINGS.overload_factor = global_args.overload_factor
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
print(
"SETTINGS: heavy=%d overload=%.1f max_offload=%d cache_gate=%.2f"
% (SETTINGS.heavy_threshold, SETTINGS.overload_factor,
SETTINGS.max_offload_inflight, SETTINGS.cache_gate_ratio)
)
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -95,8 +95,8 @@ def _make_inst(proxy, url: str, ongoing_tokens: int = 0,
def test_record_prefix_evicts_oldest_block(proxy):
"""LRU bound on cached_blocks must evict the oldest entry once full."""
inst = proxy.InstanceState("http://x")
saved = proxy.CACHE_CAPACITY_BLOCKS
proxy.CACHE_CAPACITY_BLOCKS = 2
saved = proxy.SETTINGS.cache_capacity_blocks
proxy.SETTINGS.cache_capacity_blocks = 2
try:
block_size = proxy.BLOCK_SIZE
# Three distinct one-block prefixes; first must be evicted.
@@ -112,14 +112,14 @@ def test_record_prefix_evicts_oldest_block(proxy):
assert inst.estimate_cache_hit(prefix_b) == block_size
assert inst.estimate_cache_hit(prefix_c) == block_size
finally:
proxy.CACHE_CAPACITY_BLOCKS = saved
proxy.SETTINGS.cache_capacity_blocks = saved
def test_estimate_cache_hit_touches_lru(proxy):
"""A cache hit must move the block to the MRU position."""
inst = proxy.InstanceState("http://x")
saved = proxy.CACHE_CAPACITY_BLOCKS
proxy.CACHE_CAPACITY_BLOCKS = 2
saved = proxy.SETTINGS.cache_capacity_blocks
proxy.SETTINGS.cache_capacity_blocks = 2
try:
block_size = proxy.BLOCK_SIZE
a = [1] * block_size
@@ -134,7 +134,7 @@ def test_estimate_cache_hit_touches_lru(proxy):
assert inst.estimate_cache_hit(a) == block_size
assert inst.estimate_cache_hit(b) == 0
finally:
proxy.CACHE_CAPACITY_BLOCKS = saved
proxy.SETTINGS.cache_capacity_blocks = saved
def test_pick_instance_session_affinity_sticks(proxy):
@@ -178,3 +178,37 @@ def test_pick_instance_lmetric_picks_lowest_score(proxy):
chosen, idx = proxy.pick_instance_lmetric(insts, None, None, 1000, {})
# Empty instance has score = 1000 * 0 = 0; busy one has (5000+1000)*4.
assert idx == 0 and chosen is insts[0]
def test_settings_has_runtime_knobs(proxy):
"""D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs."""
s = proxy.SETTINGS
for field in (
"heavy_threshold",
"overload_factor",
"max_offload_inflight",
"cache_gate_ratio",
"prefill_throughput",
"rdma_overhead_s",
"cache_capacity_blocks",
):
assert hasattr(s, field), f"SETTINGS missing {field}"
# Runtime mutability matters for tests + __main__ override.
saved = s.cache_gate_ratio
s.cache_gate_ratio = 0.55
assert proxy.SETTINGS.cache_gate_ratio == 0.55
s.cache_gate_ratio = saved
def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
"""M2: tweaking SETTINGS.heavy_threshold changes the P-offload penalty."""
inst = proxy.InstanceState("http://x")
inst.active_p_offloads = 3
saved = proxy.SETTINGS.heavy_threshold
try:
proxy.SETTINGS.heavy_threshold = 10000
assert proxy._p_offload_penalty(inst) == 30000
proxy.SETTINGS.heavy_threshold = 50000
assert proxy._p_offload_penalty(inst) == 150000
finally:
proxy.SETTINGS.heavy_threshold = saved