Cleanup: retire dead PUSH path + extract hybrid picker

- Delete unreachable best_needs_push block in _handle_combined and the
  four orphaned helpers (_handle_cached_prefill_offload,
  _handle_direct_read_offload, _query_bootstrap_hit,
  _get_bootstrap_client). Their only caller was the retired PUSH gate;
  see REPORT §3.9 errata for the rejected experiments (cc6e562, 4c583f2).

- Extract pick_instance_unified_hybrid as a pure function returning
  (chosen, idx, decision_dict). The decision dict carries the review #7
  breakdown fields (decision, affinity_idx/chosen_idx, cache_hit/ratio,
  avg_num_requests, fallback_score, tie_break_used).

- Add LMetric-fallback tie-breaker (primary score, then new_uncached,
  num_requests, round-robin) so new sessions don't all pin to inst 0
  when BS=0 across the board.

- Drop the lmetric-policy affinity write so --policy lmetric stays
  affinity-free per review #3.

- Mark --max-offload-inflight / --offload-mode / --cache-gate-ratio /
  --decode-iteration-s as [DEPRECATED] in --help; flags remain accepted
  so scripts/bench.sh and legacy launchers don't break.

- Revert uncommitted overload_factor 2.0->1.5 default; H7 sweep already
  rejected this knob (within noise). Future sweeps should go via CLI.

Tests: add 6 hybrid-policy tests in tests/test_proxy_pick.py covering
affinity-hit, overload break, low-cache fallback, tie-break rotation,
lmetric purity, and breakdown field shape. 19/19 pass.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 10:46:57 +08:00
parent 255c8e6884
commit ac6534c3ff
2 changed files with 220 additions and 332 deletions

View File

@@ -51,7 +51,6 @@ class Settings:
max_offload_inflight: int = 4
cache_gate_ratio: float = 0.0
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
migration_discount_cap: int = 5 # max turns to discount
SETTINGS = Settings()
@@ -148,7 +147,9 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
affinity: dict[str, int]) -> tuple[InstanceState, int]:
"""LMetric routing: score = P_tokens × BS (OSDI'26).
Pure per-request load-based routing, no session affinity.
Pure per-request load-based routing, no session affinity (the
session_id/affinity args are accepted for signature compatibility
with pick_instance/pick_instance_unified_hybrid but ignored).
P = pending_prefill_tokens + (input_length - cache_hit)
BS = num_requests (current batch size)
"""
@@ -166,42 +167,85 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
return instances[best_idx], best_idx
_bootstrap_client: httpx.AsyncClient | None = None
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls
_unified_fallback_rr_counter = 0
async def _get_bootstrap_client() -> httpx.AsyncClient:
global _bootstrap_client
if _bootstrap_client is None:
_bootstrap_client = httpx.AsyncClient(
timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
limits=httpx.Limits(max_connections=32, max_keepalive_connections=16),
)
return _bootstrap_client
def pick_instance_unified_hybrid(
instances: list[InstanceState],
token_ids: list[int] | None,
session_id: str | None,
input_length: int,
affinity: dict[str, int],
) -> tuple[InstanceState, int, dict]:
"""Hybrid routing: high-cache affinity, else LMetric with tie-breaker.
Affinity gate (both must hold to stick):
- affinity instance cache_hit / input_length > 0.5
- affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor
async def _query_bootstrap_hit(
inst: InstanceState, token_ids: list[int],
) -> int | None:
"""Query bootstrap's /estimate_hit for real cache hit count.
Fallback ordering (when affinity not used):
primary: score = P_tokens * BS (LMetric)
secondary: new_uncached_tokens (prefer instance with most cache)
tertiary: num_requests (prefer least-loaded)
quaternary: round-robin (avoid degenerate inst-0 pinning
when BS=0 across the board)
Returns hit_tokens on success, None on failure (caller should fallback).
Returns (chosen, idx, decision_dict). decision_dict carries the
review #7 breakdown fields so the caller can merge them verbatim.
"""
if inst.bootstrap_port is None:
return None
parsed = urllib.parse.urlparse(str(inst.client.base_url))
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
try:
client = await _get_bootstrap_client()
resp = await client.post(url, json={
"token_ids": token_ids,
"block_size": BLOCK_SIZE,
})
resp.raise_for_status()
return resp.json()["hit_tokens"]
except Exception:
return None
global _unified_fallback_rr_counter
n = len(instances)
avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0)
decision: dict = {
"decision": "lmetric_fallback",
"affinity_idx": None,
"chosen_idx": None,
"affinity_cache_hit": None,
"affinity_cache_ratio": None,
"affinity_num_requests": None,
"avg_num_requests": avg_reqs,
"fallback_score": None,
"tie_break_used": False,
}
if session_id and session_id in affinity:
a_idx = affinity[session_id]
if a_idx < n:
a_inst = instances[a_idx]
a_hit = a_inst.estimate_cache_hit(token_ids)
a_ratio = a_hit / max(input_length, 1)
decision["affinity_idx"] = a_idx
decision["affinity_cache_hit"] = a_hit
decision["affinity_cache_ratio"] = a_ratio
decision["affinity_num_requests"] = a_inst.num_requests
if (a_ratio > 0.5
and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
decision["decision"] = "affinity"
decision["chosen_idx"] = a_idx
return a_inst, a_idx, decision
keys: list[tuple[int, int, int, int]] = []
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids)
new_prefill = max(0, input_length - cache_hit)
p_tokens = inst.pending_prefill_tokens + new_prefill
bs = inst.num_requests
score = p_tokens * bs
keys.append((score, new_prefill, bs, i))
best_triple = min(k[:3] for k in keys)
tied = [k for k in keys if k[:3] == best_triple]
if len(tied) > 1:
decision["tie_break_used"] = True
_unified_fallback_rr_counter += 1
winner = tied[_unified_fallback_rr_counter % len(tied)]
else:
winner = tied[0]
chosen_idx = winner[3]
decision["fallback_score"] = winner[0]
decision["chosen_idx"] = chosen_idx
return instances[chosen_idx], chosen_idx, decision
def _extract_output_token_ids_from_sse(
@@ -379,8 +423,6 @@ async def lifespan(app: FastAPI):
await reconcile_task
except asyncio.CancelledError:
pass
if _bootstrap_client is not None:
await _bootstrap_client.aclose()
for inst in combined_instances + prefill_instances + decode_instances:
await inst.client.aclose()
@@ -471,314 +513,56 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
"""Unified routing: pick the instance with lowest expected latency.
"""Route a /v1/* request among combined (PD-colocated) instances.
For each instance, estimate:
latency = queue_time + prefill_time + transfer_cost
where prefill_time depends on whether the instance has cache (local),
can receive cache via PUSH (remote), or must do cold prefill.
--policy options:
linear: cache_hit-aware load score + sticky session affinity.
lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity.
unified: hybrid — stick to affinity instance when cache_ratio > 0.5
and it is not overloaded; otherwise fall back to LMetric
with a multi-key tie-breaker.
PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and
commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants
both regressed E2E tail). Re-enabling requires a new transfer mechanism.
"""
policy = getattr(global_args, 'policy', 'linear')
offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
throughput = SETTINGS.prefill_throughput
breakdown: dict = {
"request_id": headers.get("X-Request-Id", ""),
"input_length": input_length,
"t_proxy_recv": _time.monotonic(),
"policy": policy,
}
if policy in ("linear", "lmetric"):
if policy == "lmetric":
chosen, best_idx = pick_instance_lmetric(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
else:
chosen, best_idx = pick_instance(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
cache_hit = chosen.estimate_cache_hit(token_ids)
estimated_new = max(0, input_length - cache_hit)
breakdown = {
"request_id": headers.get("X-Request-Id", ""),
"input_length": input_length,
"cache_hit": cache_hit,
"estimated_new_tokens": estimated_new,
"t_proxy_recv": _time.monotonic(),
"policy": policy,
"route_class": "LOCAL",
"routed_to": chosen.url,
}
if session_id and policy == "lmetric":
# LMetric is intentionally per-request; record last target only for
# stats/debugging, not for future decisions.
if policy == "lmetric":
chosen, best_idx = pick_instance_lmetric(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
elif policy == "unified":
chosen, best_idx, decision = pick_instance_unified_hybrid(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
breakdown.update(decision)
if session_id:
session_affinity_combined[session_id] = best_idx
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
# Hybrid routing: LMetric for load balance + explicit affinity for high-cache sessions
#
# 1. If session has high cache on affinity instance AND instance not overloaded → stick
# 2. Otherwise → LMetric (P × BS) for best load balance
affinity_idx = session_affinity_combined.get(session_id) if session_id else None
use_affinity = False
if affinity_idx is not None and affinity_idx < len(combined_instances):
affinity_inst = combined_instances[affinity_idx]
affinity_cache = affinity_inst.estimate_cache_hit(token_ids)
cache_ratio = affinity_cache / max(input_length, 1)
avg_reqs = max(sum(i.num_requests for i in combined_instances) / len(combined_instances), 1.0)
if (cache_ratio > 0.5
and affinity_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
use_affinity = True
best_idx = affinity_idx
if not use_affinity:
_, best_idx = pick_instance_lmetric(
else: # linear (default)
chosen, best_idx = pick_instance(
combined_instances, token_ids, session_id, input_length,
session_affinity_combined)
best_needs_push = False
chosen = combined_instances[best_idx]
cache_hit = chosen.estimate_cache_hit(token_ids)
estimated_new = max(0, input_length - cache_hit)
breakdown = {
"request_id": headers.get("X-Request-Id", ""),
"input_length": input_length,
breakdown.update({
"cache_hit": cache_hit,
"estimated_new_tokens": estimated_new,
"t_proxy_recv": _time.monotonic(),
"policy": "affinity" if use_affinity else "lmetric",
}
if session_id:
session_affinity_combined[session_id] = best_idx
if best_needs_push:
c_inst = combined_instances[best_cache_idx]
d_inst = chosen
# Query real cache hit from bootstrap (shadow cache is inaccurate)
real_hit = await _query_bootstrap_hit(c_inst, token_ids)
breakdown["shadow_cache_hit"] = best_cache_hit
breakdown["real_cache_hit"] = real_hit
if real_hit is not None:
push_cache_hit = real_hit
else:
push_cache_hit = best_cache_hit # fallback to shadow estimate
# If real hit > 0, proceed with offload
if push_cache_hit > 0:
push_new = max(0, input_length - push_cache_hit)
cache_ratio = push_cache_hit / max(input_length, 1)
if _current_offloads() >= SETTINGS.max_offload_inflight:
breakdown["push_downgraded"] = "cap_reached"
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
if push_new < SETTINGS.heavy_threshold:
breakdown["push_downgraded"] = "below_heavy_threshold"
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
if SETTINGS.cache_gate_ratio > 0 and cache_ratio < SETTINGS.cache_gate_ratio:
breakdown["push_downgraded"] = "cache_gate"
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill')
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["push_cache_hit"] = push_cache_hit
if offload_mode == "cached_prefill":
c_inst.ongoing_tokens += input_length
c_inst.pending_prefill_tokens += push_new
c_inst.num_requests += 1
c_inst.active_p_offloads += 1
breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
return await _handle_cached_prefill_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, push_cache_hit, push_new, breakdown)
else:
d_inst.ongoing_tokens += input_length
d_inst.pending_prefill_tokens += push_new
d_inst.num_requests += 1
c_inst.active_p_offloads += 1
breakdown["route_class"] = "PUSH_MIGRATE"
return await _handle_direct_read_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, push_cache_hit, push_new, breakdown)
# Real hit is 0 — downgrade to LOCAL
breakdown["push_downgraded"] = True
# LOCAL path (also handles downgraded PUSH)
breakdown["route_class"] = "LOCAL"
breakdown["routed_to"] = chosen.url
"route_class": "LOCAL",
"routed_to": chosen.url,
})
return await _handle_local_request(
api, req_data, headers, token_ids, input_length,
chosen, estimated_new, breakdown)
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
input_length, c_inst, d_inst,
cache_hit, estimated_new, breakdown):
"""C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.
Unlike direct_read (D pulls blocks from C), here C's scheduler IS
involved: C prefills (fast, because prefix is cached), pushes KV to
Mooncake store, then D pulls and decodes. This avoids the broken
PUSH path where D waits for RDMA transfer while occupying KV blocks.
"""
request_id = headers.get("X-Request-Id", "")
# Step 1: send blocking prefill to C
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": f"xfer-{request_id}",
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data["min_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
breakdown["t_prefill_sent"] = _time.monotonic()
try:
resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
resp.raise_for_status()
await resp.aclose()
c_inst.record_prefix(token_ids)
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = True
_breakdown_log.append(breakdown)
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
c_inst.ongoing_tokens -= input_length
c_inst.pending_prefill_tokens -= estimated_new
c_inst.num_requests -= 1
raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}")
c_inst.ongoing_tokens -= input_length
c_inst.pending_prefill_tokens -= estimated_new
c_inst.num_requests -= 1
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
# Step 2: send decode to D (pull KV from C via Mooncake)
d_inst.ongoing_tokens += input_length
d_inst.num_requests += 1
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": c_inst.engine_id.get(0, ""),
"transfer_id": f"xfer-{request_id}",
}
breakdown["t_decode_sent"] = _time.monotonic()
async def generate():
first_token = True
sse_buffer = ""
output_token_ids: list[int] = []
try:
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
sse_buffer, chunk)
output_token_ids.extend(new_output_ids)
if first_token:
d_inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
first_token = False
yield chunk
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
finally:
if not first_token:
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
input_length, c_inst, d_inst,
cache_hit, estimated_new, breakdown):
"""HEAVY request: D direct-RDMA-reads cached KV from C_s, then does
local prefill for new tokens + decode. C_s's scheduler is NOT involved.
"""
request_id = headers.get("X-Request-Id", "")
# Align cache_hit to block boundary for remote_num_tokens
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
breakdown["t_offload_sent"] = _time.monotonic()
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
bootstrap_addr = "http://%s:%s" % (parsed.hostname, c_inst.bootstrap_port)
# Send full prompt to D with direct_read flag
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"direct_read": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": c_inst.engine_id.get(0, ""),
"transfer_id": "xfer-" + request_id,
"remote_num_tokens": cached_tokens,
}
async def generate():
first_token = True
sse_buffer = ""
output_token_ids: list[int] = []
try:
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
sse_buffer, chunk)
output_token_ids.extend(new_output_ids)
if first_token:
d_inst.pending_prefill_tokens -= estimated_new
d_inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
first_token = False
yield chunk
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
finally:
if first_token:
d_inst.pending_prefill_tokens -= estimated_new
else:
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.num_requests -= 1
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
session_id, headers):
"""PD-Sep mode with per-stage breakdown profiling."""
@@ -901,20 +685,23 @@ def parse_args():
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
p.add_argument("--policy", type=str, default="linear",
choices=["linear", "lmetric", "unified"],
help="Routing policy: linear, lmetric (P_tokens × BS), or unified cost model")
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
"or unified (hybrid affinity + LMetric fallback)")
p.add_argument("--overload-factor", type=float, default=2.0,
help="Break session affinity when instance load > factor * avg")
# The four flags below are accepted for bench.sh backward compatibility but
# have no effect after the PD-sep offload path was retired (REPORT §3.9,
# commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
# scripts/legacy/*.sh which still pass them through.
p.add_argument("--max-offload-inflight", type=int, default=4,
help="Global cap on concurrent P-role offloads (M3)")
help="[DEPRECATED] PUSH offload retired; no effect")
p.add_argument("--offload-mode", type=str, default="cached_prefill",
choices=["direct_read", "cached_prefill"],
help="direct_read: D pulls KV from C (PUSH). "
"cached_prefill: C prefills then D decodes (PD-sep style).")
help="[DEPRECATED] PUSH offload retired; no effect")
p.add_argument("--cache-gate-ratio", type=float, default=0.0,
help="Min cache_hit/input ratio to allow offload "
"(0.0 disables gate, 1.0 disables offload entirely)")
help="[DEPRECATED] PUSH offload retired; no effect")
p.add_argument("--decode-iteration-s", type=float, default=0.05,
help="Estimated per-request decode iteration time in seconds")
help="[DEPRECATED] PUSH offload retired; no effect")
args = p.parse_args()
args.prefill = []

View File

@@ -180,6 +180,107 @@ def test_pick_instance_lmetric_picks_lowest_score(proxy):
assert idx == 0 and chosen is insts[0]
def test_pick_instance_lmetric_ignores_session_affinity(proxy):
"""Review #3: pure --policy lmetric must remain affinity-free."""
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
# Make inst[1] look much busier than inst[0]; LMetric must still pick 0
# even though affinity points at 1.
insts[0].pending_prefill_tokens = 0
insts[0].num_requests = 0
insts[1].pending_prefill_tokens = 5000
insts[1].num_requests = 4
affinity = {"sess1": 1}
chosen, idx = proxy.pick_instance_lmetric(insts, None, "sess1", 1000, affinity)
assert idx == 0
# Picker must not mutate the affinity dict either.
assert affinity == {"sess1": 1}
def _record_n_blocks(proxy, inst, n: int) -> list[int]:
"""Record n distinct one-block prefixes on inst; return token_ids covering them."""
block_size = proxy.BLOCK_SIZE
tokens: list[int] = []
for b in range(n):
tokens.extend([1000 + b] * block_size)
inst.record_prefix(tokens)
return tokens
def test_hybrid_high_cache_session_sticks_to_affinity(proxy):
"""Hybrid: affinity instance with cache_ratio > 0.5 and no overload → stick."""
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
tokens = _record_n_blocks(proxy, insts[1], 2) # 2 blocks cached on inst[1]
affinity = {"sess1": 1}
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
insts, tokens, "sess1", len(tokens), affinity)
assert idx == 1 and chosen is insts[1]
assert decision["decision"] == "affinity"
assert decision["affinity_idx"] == 1
assert decision["chosen_idx"] == 1
assert decision["affinity_cache_ratio"] > 0.5
assert decision["tie_break_used"] is False
def test_hybrid_high_cache_breaks_on_overload(proxy):
"""Hybrid: affinity num_requests > avg * overload_factor → fall back to LMetric,
and with realistic new-token tail the LMetric fallback steers off the hot instance."""
insts = [
_make_inst(proxy, "http://a"),
_make_inst(proxy, "http://b"),
_make_inst(proxy, "http://c"),
]
cached = _record_n_blocks(proxy, insts[1], 2)
# Append one more uncached block so LMetric sees a real prefill cost on the
# cached instance too (BS multiplier becomes visible). Without this, the
# cached instance scores 0 * BS = 0 regardless of how loaded it is.
tokens = cached + [999_999] * proxy.BLOCK_SIZE
insts[1].num_requests = 300 # avg = 100; 300 > 100 * 2.0 ✓ breaks the gate
affinity = {"sess1": 1}
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
insts, tokens, "sess1", len(tokens), affinity)
assert decision["decision"] == "lmetric_fallback"
assert decision["affinity_idx"] == 1
assert idx != 1, "affinity instance is overloaded; fallback should steer away"
def test_hybrid_low_cache_falls_back(proxy):
"""Hybrid: cache_ratio <= 0.5 on affinity → fall back to LMetric."""
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
tokens = [1] * (proxy.BLOCK_SIZE * 2) # 1024 tokens, nothing cached anywhere
affinity = {"sess1": 1}
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
insts, tokens, "sess1", len(tokens), affinity)
assert decision["decision"] == "lmetric_fallback"
assert decision["affinity_cache_ratio"] == 0.0
def test_hybrid_new_session_tie_break_does_not_always_pick_index_0(proxy):
"""Review #4: when all instances tie (e.g. BS=0), tie-break must rotate."""
insts = [_make_inst(proxy, "http://a") for _ in range(3)]
seen = set()
for _ in range(12):
# No session_id, all empty → score = 0 for everyone → ties → rotate.
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
insts, None, None, 100, {})
seen.add(idx)
assert decision["decision"] == "lmetric_fallback"
assert decision["tie_break_used"] is True
assert seen == {0, 1, 2}, f"tie-breaker did not rotate; only saw {seen}"
def test_hybrid_decision_fields_populated(proxy):
"""Review #7: decision dict must carry the breakdown fields."""
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
_, _, decision = proxy.pick_instance_unified_hybrid(
insts, None, None, 100, {})
expected_keys = {
"decision", "affinity_idx", "chosen_idx",
"affinity_cache_hit", "affinity_cache_ratio", "affinity_num_requests",
"avg_num_requests", "fallback_score", "tie_break_used",
}
assert expected_keys.issubset(decision.keys())
def test_settings_has_runtime_knobs(proxy):
"""D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs."""
s = proxy.SETTINGS