proxy: Settings dataclass + cache-ratio gate + P-pick offload penalty (B4, M2, M3, D5)
- Replace mutable module constants (HEAVY_THRESHOLD, OVERLOAD_FACTOR, MAX_OFFLOAD_INFLIGHT, PREFILL_THROUGHPUT, RDMA_OVERHEAD_S, CACHE_CAPACITY_BLOCKS) with a Settings dataclass + SETTINGS singleton. __main__ now mutates SETTINGS so CLI overrides survive even when the module is imported as a library (e.g. by tests/) (D5).
- Add --max-offload-inflight CLI flag (M3) and read it from SETTINGS.
- Add --cache-gate-ratio CLI flag and a real gate before the cost-model branch: if cache_hit/input_length < ratio, record an offload_reason starting with "cache_gate_" and fall back to colocated. cache_ratio is no longer a write-only field (B4).
- P candidate selection penalises instances already running offloaded HEAVY prefills, so back-to-back HEAVY requests don't pile onto the same P (M2).
- bench.sh forwards --max-offload-inflight / --cache-gate-ratio to the proxy.
- Tests cover SETTINGS knobs + the heavy_threshold-driven P-offload penalty.
This commit is contained in:
@@ -35,6 +35,8 @@ HEAVY_THRESHOLD=20000
|
||||
NO_OFFLOAD=false
|
||||
OVERLOAD_FACTOR_ARG=""
|
||||
MAX_BATCHED_TOKENS=""
|
||||
MAX_OFFLOAD_INFLIGHT=""
|
||||
CACHE_GATE_RATIO=""
|
||||
|
||||
# Parse args
|
||||
while [[ $# -gt 0 ]]; do
|
||||
@@ -49,6 +51,8 @@ while [[ $# -gt 0 ]]; do
|
||||
--no-offload) NO_OFFLOAD=true; shift ;;
|
||||
--overload-factor) OVERLOAD_FACTOR_ARG="$2"; shift 2 ;;
|
||||
--max-batched-tokens) MAX_BATCHED_TOKENS="$2"; shift 2 ;;
|
||||
--max-offload-inflight) MAX_OFFLOAD_INFLIGHT="$2"; shift 2 ;;
|
||||
--cache-gate-ratio) CACHE_GATE_RATIO="$2"; shift 2 ;;
|
||||
*) echo "Unknown: $1"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
@@ -207,6 +211,12 @@ launch_proxy() {
|
||||
if [ -n "$OVERLOAD_FACTOR_ARG" ]; then
|
||||
extra_args="$extra_args --overload-factor $OVERLOAD_FACTOR_ARG"
|
||||
fi
|
||||
if [ -n "$MAX_OFFLOAD_INFLIGHT" ]; then
|
||||
extra_args="$extra_args --max-offload-inflight $MAX_OFFLOAD_INFLIGHT"
|
||||
fi
|
||||
if [ -n "$CACHE_GATE_RATIO" ]; then
|
||||
extra_args="$extra_args --cache-gate-ratio $CACHE_GATE_RATIO"
|
||||
fi
|
||||
if [ "$MODE" = "elastic" ]; then
|
||||
local bp_list=""
|
||||
for i in $(seq 0 $((N_INSTANCES - 1))); do
|
||||
|
||||
@@ -20,6 +20,7 @@ import urllib.parse
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
@@ -28,12 +29,26 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
CACHE_HIT_ALPHA = 1.0
|
||||
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
|
||||
OVERLOAD_FACTOR = 2.0 # default; overridden by --overload-factor
|
||||
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent P-role offloads
|
||||
PREFILL_THROUGHPUT = 7000 # tokens/s per GPU (from H20 measurements)
|
||||
RDMA_OVERHEAD_S = 2.0 # seconds of RDMA transfer + decode start overhead
|
||||
CACHE_CAPACITY_BLOCKS = 200000 # per-instance LRU cap on shadow cached_blocks
|
||||
|
||||
|
||||
@dataclass
class Settings:
    """Mutable bundle of the proxy's routing and offload tuning knobs.

    Code in this module reads these values from the module-level
    ``SETTINGS`` instance at call time, so an override takes effect
    immediately, whether ``__main__`` applies it from argparse or a test
    that imports the module sets it directly.
    """

    # New-token cutoff for HEAVY classification.
    heavy_threshold: int = 20000

    # Session affinity is broken when an instance's load exceeds
    # this factor times the average load.
    overload_factor: float = 2.0

    # Global cap on concurrent P-role offloads.
    max_offload_inflight: int = 4

    # Minimum cache_hit / input_length ratio required to allow offload.
    cache_gate_ratio: float = 0.3

    # Prefill speed in tokens/s per GPU (H20 measurement).
    prefill_throughput: float = 7000.0

    # Seconds of RDMA transfer plus decode-start overhead.
    rdma_overhead_s: float = 2.0

    # Per-instance LRU cap on the shadow cached_blocks map.
    cache_capacity_blocks: int = 200000
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
|
||||
|
||||
class InstanceState:
|
||||
@@ -76,7 +91,7 @@ class InstanceState:
|
||||
self.cached_blocks.move_to_end(bh)
|
||||
else:
|
||||
self.cached_blocks[bh] = None
|
||||
if len(self.cached_blocks) > CACHE_CAPACITY_BLOCKS:
|
||||
if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
|
||||
self.cached_blocks.popitem(last=False)
|
||||
|
||||
|
||||
@@ -89,7 +104,7 @@ def _p_offload_penalty(inst: InstanceState) -> int:
|
||||
"""
|
||||
if inst.active_p_offloads <= 0:
|
||||
return 0
|
||||
return inst.active_p_offloads * HEAVY_THRESHOLD
|
||||
return inst.active_p_offloads * SETTINGS.heavy_threshold
|
||||
|
||||
|
||||
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
|
||||
@@ -109,7 +124,7 @@ def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
|
||||
idx = affinity[session_id]
|
||||
if idx < len(instances):
|
||||
inst = instances[idx]
|
||||
if (inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR
|
||||
if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor
|
||||
and inst.active_p_offloads == 0):
|
||||
return inst, idx
|
||||
|
||||
@@ -317,32 +332,44 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
use_offload = False
|
||||
offload_reason = "offload_disabled"
|
||||
|
||||
if estimated_new >= HEAVY_THRESHOLD and offload_enabled:
|
||||
if estimated_new >= SETTINGS.heavy_threshold and offload_enabled:
|
||||
cache_ratio = cache_hit / max(input_length, 1)
|
||||
current_offloads = sum(c.active_p_offloads for c in combined_instances)
|
||||
# P candidate: least-loaded instance (excluding C_s)
|
||||
p_candidate = min((c for c in combined_instances if c is not best_inst),
|
||||
key=lambda c: c.ongoing_tokens)
|
||||
|
||||
# P candidate: least-loaded instance excluding C_s, preferring instances
|
||||
# not already shouldering an active P-role offload.
|
||||
def _p_pick_score(c: InstanceState) -> int:
|
||||
return c.ongoing_tokens + c.active_p_offloads * SETTINGS.heavy_threshold
|
||||
|
||||
p_candidate = min(
|
||||
(c for c in combined_instances if c is not best_inst),
|
||||
key=_p_pick_score,
|
||||
)
|
||||
# D candidate: least-loaded excluding both C_s and P
|
||||
remaining = [c for c in combined_instances if c is not best_inst and c is not p_candidate]
|
||||
d_candidate = min(remaining, key=lambda c: c.ongoing_tokens) if remaining else p_candidate
|
||||
|
||||
# Cost model: compare co-located vs offload expected latency
|
||||
# Co-located: queue on C_s + prefill new tokens on C_s
|
||||
cs_queue = best_inst.pending_prefill_tokens / PREFILL_THROUGHPUT
|
||||
colocated_cost = cs_queue + estimated_new / PREFILL_THROUGHPUT
|
||||
cs_queue = best_inst.pending_prefill_tokens / SETTINGS.prefill_throughput
|
||||
colocated_cost = cs_queue + estimated_new / SETTINGS.prefill_throughput
|
||||
|
||||
# Offload: prefill on P (may or may not have cache) + RDMA + decode start
|
||||
p_queue = p_candidate.pending_prefill_tokens / PREFILL_THROUGHPUT
|
||||
p_queue = p_candidate.pending_prefill_tokens / SETTINGS.prefill_throughput
|
||||
p_cache_hit = p_candidate.estimate_cache_hit(token_ids) if token_ids else 0
|
||||
p_new_tokens = max(0, input_length - p_cache_hit)
|
||||
offload_cost = p_queue + p_new_tokens / PREFILL_THROUGHPUT + RDMA_OVERHEAD_S
|
||||
offload_cost = p_queue + p_new_tokens / SETTINGS.prefill_throughput + SETTINGS.rdma_overhead_s
|
||||
|
||||
breakdown["cache_ratio"] = cache_ratio
|
||||
breakdown["colocated_cost"] = round(colocated_cost, 2)
|
||||
breakdown["offload_cost"] = round(offload_cost, 2)
|
||||
|
||||
if current_offloads >= MAX_OFFLOAD_INFLIGHT:
|
||||
# H4 cache-ratio gate: if C_s does not have a meaningful cached prefix,
|
||||
# offload pays full RDMA without saving prefill compute, so block it.
|
||||
# Set --cache-gate-ratio 0.0 to disable, 1.0 to never offload.
|
||||
if cache_ratio < SETTINGS.cache_gate_ratio:
|
||||
offload_reason = "cache_gate_%.2f<%.2f" % (cache_ratio, SETTINGS.cache_gate_ratio)
|
||||
elif current_offloads >= SETTINGS.max_offload_inflight:
|
||||
offload_reason = "cap_reached_%d" % current_offloads
|
||||
elif offload_cost < colocated_cost:
|
||||
use_offload = True
|
||||
@@ -374,7 +401,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, cache_hit, estimated_new, breakdown)
|
||||
else:
|
||||
if estimated_new >= HEAVY_THRESHOLD:
|
||||
if estimated_new >= SETTINGS.heavy_threshold:
|
||||
breakdown["route_class"] = "HEAVY_COLO"
|
||||
breakdown["offload_reason"] = offload_reason
|
||||
elif estimated_new < 5000:
|
||||
@@ -589,6 +616,11 @@ def parse_args():
|
||||
help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
|
||||
p.add_argument("--overload-factor", type=float, default=2.0,
|
||||
help="Break session affinity when instance load > factor * avg")
|
||||
p.add_argument("--max-offload-inflight", type=int, default=4,
|
||||
help="Global cap on concurrent P-role offloads (M3)")
|
||||
p.add_argument("--cache-gate-ratio", type=float, default=0.3,
|
||||
help="Min cache_hit/input ratio to allow offload "
|
||||
"(0.0 disables gate, 1.0 disables offload entirely)")
|
||||
args = p.parse_args()
|
||||
|
||||
args.prefill = []
|
||||
@@ -606,6 +638,13 @@ def parse_args():
|
||||
|
||||
if __name__ == "__main__":
|
||||
global_args = parse_args()
|
||||
HEAVY_THRESHOLD = global_args.heavy_threshold
|
||||
OVERLOAD_FACTOR = global_args.overload_factor
|
||||
SETTINGS.heavy_threshold = global_args.heavy_threshold
|
||||
SETTINGS.overload_factor = global_args.overload_factor
|
||||
SETTINGS.max_offload_inflight = global_args.max_offload_inflight
|
||||
SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
|
||||
print(
|
||||
"SETTINGS: heavy=%d overload=%.1f max_offload=%d cache_gate=%.2f"
|
||||
% (SETTINGS.heavy_threshold, SETTINGS.overload_factor,
|
||||
SETTINGS.max_offload_inflight, SETTINGS.cache_gate_ratio)
|
||||
)
|
||||
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
|
||||
@@ -95,8 +95,8 @@ def _make_inst(proxy, url: str, ongoing_tokens: int = 0,
|
||||
def test_record_prefix_evicts_oldest_block(proxy):
|
||||
"""LRU bound on cached_blocks must evict the oldest entry once full."""
|
||||
inst = proxy.InstanceState("http://x")
|
||||
saved = proxy.CACHE_CAPACITY_BLOCKS
|
||||
proxy.CACHE_CAPACITY_BLOCKS = 2
|
||||
saved = proxy.SETTINGS.cache_capacity_blocks
|
||||
proxy.SETTINGS.cache_capacity_blocks = 2
|
||||
try:
|
||||
block_size = proxy.BLOCK_SIZE
|
||||
# Three distinct one-block prefixes; first must be evicted.
|
||||
@@ -112,14 +112,14 @@ def test_record_prefix_evicts_oldest_block(proxy):
|
||||
assert inst.estimate_cache_hit(prefix_b) == block_size
|
||||
assert inst.estimate_cache_hit(prefix_c) == block_size
|
||||
finally:
|
||||
proxy.CACHE_CAPACITY_BLOCKS = saved
|
||||
proxy.SETTINGS.cache_capacity_blocks = saved
|
||||
|
||||
|
||||
def test_estimate_cache_hit_touches_lru(proxy):
|
||||
"""A cache hit must move the block to the MRU position."""
|
||||
inst = proxy.InstanceState("http://x")
|
||||
saved = proxy.CACHE_CAPACITY_BLOCKS
|
||||
proxy.CACHE_CAPACITY_BLOCKS = 2
|
||||
saved = proxy.SETTINGS.cache_capacity_blocks
|
||||
proxy.SETTINGS.cache_capacity_blocks = 2
|
||||
try:
|
||||
block_size = proxy.BLOCK_SIZE
|
||||
a = [1] * block_size
|
||||
@@ -134,7 +134,7 @@ def test_estimate_cache_hit_touches_lru(proxy):
|
||||
assert inst.estimate_cache_hit(a) == block_size
|
||||
assert inst.estimate_cache_hit(b) == 0
|
||||
finally:
|
||||
proxy.CACHE_CAPACITY_BLOCKS = saved
|
||||
proxy.SETTINGS.cache_capacity_blocks = saved
|
||||
|
||||
|
||||
def test_pick_instance_session_affinity_sticks(proxy):
|
||||
@@ -178,3 +178,37 @@ def test_pick_instance_lmetric_picks_lowest_score(proxy):
|
||||
chosen, idx = proxy.pick_instance_lmetric(insts, None, None, 1000, {})
|
||||
# Empty instance has score = 1000 * 0 = 0; busy one has (5000+1000)*4.
|
||||
assert idx == 0 and chosen is insts[0]
|
||||
|
||||
|
||||
def test_settings_has_runtime_knobs(proxy):
    """D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs.

    Checks that every expected knob is an attribute of ``proxy.SETTINGS``
    and that a value written through one reference is visible through the
    module attribute, i.e. the singleton is mutable at runtime.
    """
    s = proxy.SETTINGS
    for field in (
        "heavy_threshold",
        "overload_factor",
        "max_offload_inflight",
        "cache_gate_ratio",
        "prefill_throughput",
        "rdma_overhead_s",
        "cache_capacity_blocks",
    ):
        assert hasattr(s, field), f"SETTINGS missing {field}"
    # Runtime mutability matters for tests + __main__ override.
    saved = s.cache_gate_ratio
    try:
        s.cache_gate_ratio = 0.55
        assert proxy.SETTINGS.cache_gate_ratio == 0.55
    finally:
        # Restore even when the assertion fails, so the temporary override
        # cannot leak into tests that run afterwards.
        s.cache_gate_ratio = saved
|
||||
|
||||
|
||||
def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
    """M2: the P-offload penalty follows SETTINGS.heavy_threshold.

    With three active P-role offloads on the instance, the penalty is
    expected to equal three times whatever threshold is currently set.
    """
    busy = proxy.InstanceState("http://x")
    busy.active_p_offloads = 3
    original_threshold = proxy.SETTINGS.heavy_threshold
    # (threshold to install, penalty expected for 3 active offloads)
    cases = ((10000, 30000), (50000, 150000))
    try:
        for threshold, expected_penalty in cases:
            proxy.SETTINGS.heavy_threshold = threshold
            assert proxy._p_offload_penalty(busy) == expected_penalty
    finally:
        # Put the shared singleton back the way it was found.
        proxy.SETTINGS.heavy_threshold = original_threshold
|
||||
|
||||
Reference in New Issue
Block a user