diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index dd36459..3006099 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -51,7 +51,6 @@ class Settings: max_offload_inflight: int = 4 cache_gate_ratio: float = 0.0 decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20) - migration_discount_cap: int = 5 # max turns to discount SETTINGS = Settings() @@ -148,7 +147,9 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | affinity: dict[str, int]) -> tuple[InstanceState, int]: """LMetric routing: score = P_tokens × BS (OSDI'26). - Pure per-request load-based routing, no session affinity. + Pure per-request load-based routing, no session affinity (the + session_id/affinity args are accepted for signature compatibility + with pick_instance/pick_instance_unified_hybrid but ignored). P = pending_prefill_tokens + (input_length - cache_hit) BS = num_requests (current batch size) """ @@ -166,42 +167,85 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | return instances[best_idx], best_idx -_bootstrap_client: httpx.AsyncClient | None = None - -BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls +_unified_fallback_rr_counter = 0 -async def _get_bootstrap_client() -> httpx.AsyncClient: - global _bootstrap_client - if _bootstrap_client is None: - _bootstrap_client = httpx.AsyncClient( - timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S), - limits=httpx.Limits(max_connections=32, max_keepalive_connections=16), - ) - return _bootstrap_client +def pick_instance_unified_hybrid( + instances: list[InstanceState], + token_ids: list[int] | None, + session_id: str | None, + input_length: int, + affinity: dict[str, int], +) -> tuple[InstanceState, int, dict]: + """Hybrid routing: high-cache affinity, else LMetric with tie-breaker. + Affinity gate (both must hold to stick): + - affinity instance cache_hit / input_length > 0.5 + - affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor -async def _query_bootstrap_hit( - inst: InstanceState, token_ids: list[int], -) -> int | None: - """Query bootstrap's /estimate_hit for real cache hit count. + Fallback ordering (when affinity not used): + primary: score = P_tokens * BS (LMetric) + secondary: new_uncached_tokens (prefer instance with most cache) + tertiary: num_requests (prefer least-loaded) + quaternary: round-robin (avoid degenerate inst-0 pinning + when BS=0 across the board) - Returns hit_tokens on success, None on failure (caller should fallback). + Returns (chosen, idx, decision_dict). decision_dict carries the + review #7 breakdown fields so the caller can merge them verbatim. """ - if inst.bootstrap_port is None: - return None - parsed = urllib.parse.urlparse(str(inst.client.base_url)) - url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit" - try: - client = await _get_bootstrap_client() - resp = await client.post(url, json={ - "token_ids": token_ids, - "block_size": BLOCK_SIZE, - }) - resp.raise_for_status() - return resp.json()["hit_tokens"] - except Exception: - return None + global _unified_fallback_rr_counter + n = len(instances) + avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0) + + decision: dict = { + "decision": "lmetric_fallback", + "affinity_idx": None, + "chosen_idx": None, + "affinity_cache_hit": None, + "affinity_cache_ratio": None, + "affinity_num_requests": None, + "avg_num_requests": avg_reqs, + "fallback_score": None, + "tie_break_used": False, + } + + if session_id and session_id in affinity: + a_idx = affinity[session_id] + if a_idx < n: + a_inst = instances[a_idx] + a_hit = a_inst.estimate_cache_hit(token_ids) + a_ratio = a_hit / max(input_length, 1) + decision["affinity_idx"] = a_idx + decision["affinity_cache_hit"] = a_hit + decision["affinity_cache_ratio"] = a_ratio + decision["affinity_num_requests"] = a_inst.num_requests + if (a_ratio > 0.5 + and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor): + decision["decision"] = "affinity" + decision["chosen_idx"] = a_idx + return a_inst, a_idx, decision + + keys: list[tuple[int, int, int, int]] = [] + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) + new_prefill = max(0, input_length - cache_hit) + p_tokens = inst.pending_prefill_tokens + new_prefill + bs = inst.num_requests + score = p_tokens * bs + keys.append((score, new_prefill, bs, i)) + + best_triple = min(k[:3] for k in keys) + tied = [k for k in keys if k[:3] == best_triple] + if len(tied) > 1: + decision["tie_break_used"] = True + _unified_fallback_rr_counter += 1 + winner = tied[_unified_fallback_rr_counter % len(tied)] + else: + winner = tied[0] + chosen_idx = winner[3] + decision["fallback_score"] = winner[0] + decision["chosen_idx"] = chosen_idx + return instances[chosen_idx], chosen_idx, decision def _extract_output_token_ids_from_sse( @@ -379,8 +423,6 @@ async def lifespan(app: FastAPI): await reconcile_task except asyncio.CancelledError: pass - if _bootstrap_client is not None: - await _bootstrap_client.aclose() for inst in combined_instances + prefill_instances + decode_instances: await inst.client.aclose() @@ -471,314 +513,56 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length, async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers): - """Unified routing: pick the instance with lowest expected latency. + """Route a /v1/* request among combined (PD-colocated) instances. - For each instance, estimate: - latency = queue_time + prefill_time + transfer_cost - where prefill_time depends on whether the instance has cache (local), - can receive cache via PUSH (remote), or must do cold prefill. + --policy options: + linear: cache_hit-aware load score + sticky session affinity. + lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity. + unified: hybrid — stick to affinity instance when cache_ratio > 0.5 + and it is not overloaded; otherwise fall back to LMetric + with a multi-key tie-breaker. + + PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and + commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants + both regressed E2E tail). Re-enabling requires a new transfer mechanism. """ policy = getattr(global_args, 'policy', 'linear') - offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2 - throughput = SETTINGS.prefill_throughput + breakdown: dict = { + "request_id": headers.get("X-Request-Id", ""), + "input_length": input_length, + "t_proxy_recv": _time.monotonic(), + "policy": policy, + } - if policy in ("linear", "lmetric"): - if policy == "lmetric": - chosen, best_idx = pick_instance_lmetric( - combined_instances, token_ids, session_id, input_length, - session_affinity_combined) - else: - chosen, best_idx = pick_instance( - combined_instances, token_ids, session_id, input_length, - session_affinity_combined) - cache_hit = chosen.estimate_cache_hit(token_ids) - estimated_new = max(0, input_length - cache_hit) - breakdown = { - "request_id": headers.get("X-Request-Id", ""), - "input_length": input_length, - "cache_hit": cache_hit, - "estimated_new_tokens": estimated_new, - "t_proxy_recv": _time.monotonic(), - "policy": policy, - "route_class": "LOCAL", - "routed_to": chosen.url, - } - if session_id and policy == "lmetric": - # LMetric is intentionally per-request; record last target only for - # stats/debugging, not for future decisions. + if policy == "lmetric": + chosen, best_idx = pick_instance_lmetric( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + elif policy == "unified": + chosen, best_idx, decision = pick_instance_unified_hybrid( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + breakdown.update(decision) + if session_id: session_affinity_combined[session_id] = best_idx - return await _handle_local_request( - api, req_data, headers, token_ids, input_length, - chosen, estimated_new, breakdown) - - # Hybrid routing: LMetric for load balance + explicit affinity for high-cache sessions - # - # 1. If session has high cache on affinity instance AND instance not overloaded → stick - # 2. Otherwise → LMetric (P × BS) for best load balance - affinity_idx = session_affinity_combined.get(session_id) if session_id else None - use_affinity = False - - if affinity_idx is not None and affinity_idx < len(combined_instances): - affinity_inst = combined_instances[affinity_idx] - affinity_cache = affinity_inst.estimate_cache_hit(token_ids) - cache_ratio = affinity_cache / max(input_length, 1) - avg_reqs = max(sum(i.num_requests for i in combined_instances) / len(combined_instances), 1.0) - - if (cache_ratio > 0.5 - and affinity_inst.num_requests <= avg_reqs * SETTINGS.overload_factor): - use_affinity = True - best_idx = affinity_idx - - if not use_affinity: - _, best_idx = pick_instance_lmetric( + else: # linear (default) + chosen, best_idx = pick_instance( combined_instances, token_ids, session_id, input_length, session_affinity_combined) - best_needs_push = False - - chosen = combined_instances[best_idx] cache_hit = chosen.estimate_cache_hit(token_ids) estimated_new = max(0, input_length - cache_hit) - - breakdown = { - "request_id": headers.get("X-Request-Id", ""), - "input_length": input_length, + breakdown.update({ "cache_hit": cache_hit, "estimated_new_tokens": estimated_new, - "t_proxy_recv": _time.monotonic(), - "policy": "affinity" if use_affinity else "lmetric", - } - - if session_id: - session_affinity_combined[session_id] = best_idx - - if best_needs_push: - c_inst = combined_instances[best_cache_idx] - d_inst = chosen - - # Query real cache hit from bootstrap (shadow cache is inaccurate) - real_hit = await _query_bootstrap_hit(c_inst, token_ids) - breakdown["shadow_cache_hit"] = best_cache_hit - breakdown["real_cache_hit"] = real_hit - - if real_hit is not None: - push_cache_hit = real_hit - else: - push_cache_hit = best_cache_hit # fallback to shadow estimate - - # If real hit > 0, proceed with offload - if push_cache_hit > 0: - push_new = max(0, input_length - push_cache_hit) - cache_ratio = push_cache_hit / max(input_length, 1) - - if _current_offloads() >= SETTINGS.max_offload_inflight: - breakdown["push_downgraded"] = "cap_reached" - return await _handle_local_request( - api, req_data, headers, token_ids, input_length, - chosen, estimated_new, breakdown) - if push_new < SETTINGS.heavy_threshold: - breakdown["push_downgraded"] = "below_heavy_threshold" - return await _handle_local_request( - api, req_data, headers, token_ids, input_length, - chosen, estimated_new, breakdown) - if SETTINGS.cache_gate_ratio > 0 and cache_ratio < SETTINGS.cache_gate_ratio: - breakdown["push_downgraded"] = "cache_gate" - return await _handle_local_request( - api, req_data, headers, token_ids, input_length, - chosen, estimated_new, breakdown) - - offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill') - breakdown["c_inst"] = c_inst.url - breakdown["d_inst"] = d_inst.url - breakdown["push_cache_hit"] = push_cache_hit - - if offload_mode == "cached_prefill": - c_inst.ongoing_tokens += input_length - c_inst.pending_prefill_tokens += push_new - c_inst.num_requests += 1 - c_inst.active_p_offloads += 1 - breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD" - return await _handle_cached_prefill_offload( - api, req_data, headers, token_ids, input_length, - c_inst, d_inst, push_cache_hit, push_new, breakdown) - else: - d_inst.ongoing_tokens += input_length - d_inst.pending_prefill_tokens += push_new - d_inst.num_requests += 1 - c_inst.active_p_offloads += 1 - breakdown["route_class"] = "PUSH_MIGRATE" - return await _handle_direct_read_offload( - api, req_data, headers, token_ids, input_length, - c_inst, d_inst, push_cache_hit, push_new, breakdown) - - # Real hit is 0 — downgrade to LOCAL - breakdown["push_downgraded"] = True - - # LOCAL path (also handles downgraded PUSH) - breakdown["route_class"] = "LOCAL" - breakdown["routed_to"] = chosen.url + "route_class": "LOCAL", + "routed_to": chosen.url, + }) return await _handle_local_request( api, req_data, headers, token_ids, input_length, chosen, estimated_new, breakdown) -PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill - - -async def _handle_cached_prefill_offload(api, req_data, headers, token_ids, - input_length, c_inst, d_inst, - cache_hit, estimated_new, breakdown): - """C does fast cached prefill → KV to Mooncake → D pulls KV and decodes. - - Unlike direct_read (D pulls blocks from C), here C's scheduler IS - involved: C prefills (fast, because prefix is cached), pushes KV to - Mooncake store, then D pulls and decodes. This avoids the broken - PUSH path where D waits for RDMA transfer while occupying KV blocks. - """ - request_id = headers.get("X-Request-Id", "") - - # Step 1: send blocking prefill to C - prefill_data = req_data.copy() - prefill_data["kv_transfer_params"] = { - "do_remote_decode": True, - "do_remote_prefill": False, - "transfer_id": f"xfer-{request_id}", - } - prefill_data["stream"] = False - prefill_data["max_tokens"] = 1 - prefill_data["min_tokens"] = 1 - prefill_data.pop("max_completion_tokens", None) - prefill_data.pop("stream_options", None) - p_headers = {**headers, "X-data-parallel-rank": "0"} - - breakdown["t_prefill_sent"] = _time.monotonic() - - try: - resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers) - breakdown["t_prefill_done"] = _time.monotonic() - resp.raise_for_status() - await resp.aclose() - c_inst.record_prefix(token_ids) - except Exception as e: - breakdown["t_prefill_done"] = _time.monotonic() - breakdown["prefill_error"] = True - _breakdown_log.append(breakdown) - c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1) - c_inst.ongoing_tokens -= input_length - c_inst.pending_prefill_tokens -= estimated_new - c_inst.num_requests -= 1 - raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}") - - c_inst.ongoing_tokens -= input_length - c_inst.pending_prefill_tokens -= estimated_new - c_inst.num_requests -= 1 - c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1) - - # Step 2: send decode to D (pull KV from C via Mooncake) - d_inst.ongoing_tokens += input_length - d_inst.num_requests += 1 - - parsed = urllib.parse.urlparse(str(c_inst.client.base_url)) - bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}" - - decode_data = req_data.copy() - decode_data["kv_transfer_params"] = { - "do_remote_decode": False, - "do_remote_prefill": True, - "remote_bootstrap_addr": bootstrap_addr, - "remote_engine_id": c_inst.engine_id.get(0, ""), - "transfer_id": f"xfer-{request_id}", - } - - breakdown["t_decode_sent"] = _time.monotonic() - - async def generate(): - first_token = True - sse_buffer = "" - output_token_ids: list[int] = [] - try: - async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: - resp.raise_for_status() - async for chunk in resp.aiter_bytes(): - sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( - sse_buffer, chunk) - output_token_ids.extend(new_output_ids) - if first_token: - d_inst.ongoing_decode_tokens += input_length - breakdown["t_first_token"] = _time.monotonic() - first_token = False - yield chunk - d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids)) - finally: - if not first_token: - d_inst.ongoing_decode_tokens -= input_length - d_inst.ongoing_tokens -= input_length - d_inst.num_requests -= 1 - breakdown["t_done"] = _time.monotonic() - _breakdown_log.append(breakdown) - - return StreamingResponse(generate(), media_type="text/event-stream") - - -async def _handle_direct_read_offload(api, req_data, headers, token_ids, - input_length, c_inst, d_inst, - cache_hit, estimated_new, breakdown): - """HEAVY request: D direct-RDMA-reads cached KV from C_s, then does - local prefill for new tokens + decode. C_s's scheduler is NOT involved. - """ - request_id = headers.get("X-Request-Id", "") - - # Align cache_hit to block boundary for remote_num_tokens - cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE - breakdown["t_offload_sent"] = _time.monotonic() - - parsed = urllib.parse.urlparse(str(c_inst.client.base_url)) - bootstrap_addr = "http://%s:%s" % (parsed.hostname, c_inst.bootstrap_port) - - # Send full prompt to D with direct_read flag - decode_data = req_data.copy() - decode_data["kv_transfer_params"] = { - "do_remote_decode": False, - "do_remote_prefill": True, - "direct_read": True, - "remote_bootstrap_addr": bootstrap_addr, - "remote_engine_id": c_inst.engine_id.get(0, ""), - "transfer_id": "xfer-" + request_id, - "remote_num_tokens": cached_tokens, - } - - async def generate(): - first_token = True - sse_buffer = "" - output_token_ids: list[int] = [] - try: - async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: - resp.raise_for_status() - async for chunk in resp.aiter_bytes(): - sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( - sse_buffer, chunk) - output_token_ids.extend(new_output_ids) - if first_token: - d_inst.pending_prefill_tokens -= estimated_new - d_inst.ongoing_decode_tokens += input_length - breakdown["t_first_token"] = _time.monotonic() - first_token = False - yield chunk - d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids)) - finally: - if first_token: - d_inst.pending_prefill_tokens -= estimated_new - else: - d_inst.ongoing_decode_tokens -= input_length - d_inst.ongoing_tokens -= input_length - d_inst.num_requests -= 1 - c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1) - breakdown["t_done"] = _time.monotonic() - _breakdown_log.append(breakdown) - - return StreamingResponse(generate(), media_type="text/event-stream") - - async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length, session_id, headers): """PD-Sep mode with per-stage breakdown profiling.""" @@ -901,20 +685,23 @@ def parse_args(): help="Comma-separated bootstrap ports for combined instances (for offload mode)") p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric", "unified"], - help="Routing policy: linear, lmetric (P_tokens × BS), or unified cost model") + help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), " + "or unified (hybrid affinity + LMetric fallback)") p.add_argument("--overload-factor", type=float, default=2.0, help="Break session affinity when instance load > factor * avg") + # The four flags below are accepted for bench.sh backward compatibility but + # have no effect after the PD-sep offload path was retired (REPORT §3.9, + # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and + # scripts/legacy/*.sh which still pass them through. p.add_argument("--max-offload-inflight", type=int, default=4, - help="Global cap on concurrent P-role offloads (M3)") + help="[DEPRECATED] PUSH offload retired; no effect") p.add_argument("--offload-mode", type=str, default="cached_prefill", choices=["direct_read", "cached_prefill"], - help="direct_read: D pulls KV from C (PUSH). " - "cached_prefill: C prefills then D decodes (PD-sep style).") + help="[DEPRECATED] PUSH offload retired; no effect") p.add_argument("--cache-gate-ratio", type=float, default=0.0, - help="Min cache_hit/input ratio to allow offload " - "(0.0 disables gate, 1.0 disables offload entirely)") + help="[DEPRECATED] PUSH offload retired; no effect") p.add_argument("--decode-iteration-s", type=float, default=0.05, - help="Estimated per-request decode iteration time in seconds") + help="[DEPRECATED] PUSH offload retired; no effect") args = p.parse_args() args.prefill = [] diff --git a/tests/test_proxy_pick.py b/tests/test_proxy_pick.py index 539f703..e076753 100644 --- a/tests/test_proxy_pick.py +++ b/tests/test_proxy_pick.py @@ -180,6 +180,107 @@ def test_pick_instance_lmetric_picks_lowest_score(proxy): assert idx == 0 and chosen is insts[0] +def test_pick_instance_lmetric_ignores_session_affinity(proxy): + """Review #3: pure --policy lmetric must remain affinity-free.""" + insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] + # Make inst[1] look much busier than inst[0]; LMetric must still pick 0 + # even though affinity points at 1. + insts[0].pending_prefill_tokens = 0 + insts[0].num_requests = 0 + insts[1].pending_prefill_tokens = 5000 + insts[1].num_requests = 4 + affinity = {"sess1": 1} + chosen, idx = proxy.pick_instance_lmetric(insts, None, "sess1", 1000, affinity) + assert idx == 0 + # Picker must not mutate the affinity dict either. + assert affinity == {"sess1": 1} + + +def _record_n_blocks(proxy, inst, n: int) -> list[int]: + """Record n distinct one-block prefixes on inst; return token_ids covering them.""" + block_size = proxy.BLOCK_SIZE + tokens: list[int] = [] + for b in range(n): + tokens.extend([1000 + b] * block_size) + inst.record_prefix(tokens) + return tokens + + +def test_hybrid_high_cache_session_sticks_to_affinity(proxy): + """Hybrid: affinity instance with cache_ratio > 0.5 and no overload → stick.""" + insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] + tokens = _record_n_blocks(proxy, insts[1], 2) # 2 blocks cached on inst[1] + affinity = {"sess1": 1} + chosen, idx, decision = proxy.pick_instance_unified_hybrid( + insts, tokens, "sess1", len(tokens), affinity) + assert idx == 1 and chosen is insts[1] + assert decision["decision"] == "affinity" + assert decision["affinity_idx"] == 1 + assert decision["chosen_idx"] == 1 + assert decision["affinity_cache_ratio"] > 0.5 + assert decision["tie_break_used"] is False + + +def test_hybrid_high_cache_breaks_on_overload(proxy): + """Hybrid: affinity num_requests > avg * overload_factor → fall back to LMetric, + and with realistic new-token tail the LMetric fallback steers off the hot instance.""" + insts = [ + _make_inst(proxy, "http://a"), + _make_inst(proxy, "http://b"), + _make_inst(proxy, "http://c"), + ] + cached = _record_n_blocks(proxy, insts[1], 2) + # Append one more uncached block so LMetric sees a real prefill cost on the + # cached instance too (BS multiplier becomes visible). Without this, the + # cached instance scores 0 * BS = 0 regardless of how loaded it is. + tokens = cached + [999_999] * proxy.BLOCK_SIZE + insts[1].num_requests = 300 # avg = 100; 300 > 100 * 2.0 ✓ breaks the gate + affinity = {"sess1": 1} + chosen, idx, decision = proxy.pick_instance_unified_hybrid( + insts, tokens, "sess1", len(tokens), affinity) + assert decision["decision"] == "lmetric_fallback" + assert decision["affinity_idx"] == 1 + assert idx != 1, "affinity instance is overloaded; fallback should steer away" + + +def test_hybrid_low_cache_falls_back(proxy): + """Hybrid: cache_ratio <= 0.5 on affinity → fall back to LMetric.""" + insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] + tokens = [1] * (proxy.BLOCK_SIZE * 2) # 1024 tokens, nothing cached anywhere + affinity = {"sess1": 1} + chosen, idx, decision = proxy.pick_instance_unified_hybrid( + insts, tokens, "sess1", len(tokens), affinity) + assert decision["decision"] == "lmetric_fallback" + assert decision["affinity_cache_ratio"] == 0.0 + + +def test_hybrid_new_session_tie_break_does_not_always_pick_index_0(proxy): + """Review #4: when all instances tie (e.g. BS=0), tie-break must rotate.""" + insts = [_make_inst(proxy, "http://a") for _ in range(3)] + seen = set() + for _ in range(12): + # No session_id, all empty → score = 0 for everyone → ties → rotate. + chosen, idx, decision = proxy.pick_instance_unified_hybrid( + insts, None, None, 100, {}) + seen.add(idx) + assert decision["decision"] == "lmetric_fallback" + assert decision["tie_break_used"] is True + assert seen == {0, 1, 2}, f"tie-breaker did not rotate; only saw {seen}" + + +def test_hybrid_decision_fields_populated(proxy): + """Review #7: decision dict must carry the breakdown fields.""" + insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")] + _, _, decision = proxy.pick_instance_unified_hybrid( + insts, None, None, 100, {}) + expected_keys = { + "decision", "affinity_idx", "chosen_idx", + "affinity_cache_hit", "affinity_cache_ratio", "affinity_num_requests", + "avg_num_requests", "fallback_score", "tie_break_used", + } + assert expected_keys.issubset(decision.keys()) + + def test_settings_has_runtime_knobs(proxy): """D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs.""" s = proxy.SETTINGS