Cleanup: retire dead PUSH path + extract hybrid picker
- Delete unreachable best_needs_push block in _handle_combined and the four orphaned helpers (_handle_cached_prefill_offload, _handle_direct_read_offload, _query_bootstrap_hit, _get_bootstrap_client). Their only caller was the retired PUSH gate; see REPORT §3.9 errata for the rejected experiments (cc6e562,4c583f2). - Extract pick_instance_unified_hybrid as a pure function returning (chosen, idx, decision_dict). The decision dict carries the review #7 breakdown fields (decision, affinity_idx/chosen_idx, cache_hit/ratio, avg_num_requests, fallback_score, tie_break_used). - Add LMetric-fallback tie-breaker (primary score, then new_uncached, num_requests, round-robin) so new sessions don't all pin to inst 0 when BS=0 across the board. - Drop the lmetric-policy affinity write so --policy lmetric stays affinity-free per review #3. - Mark --max-offload-inflight / --offload-mode / --cache-gate-ratio / --decode-iteration-s as [DEPRECATED] in --help; flags remain accepted so scripts/bench.sh and legacy launchers don't break. - Revert uncommitted overload_factor 2.0->1.5 default; H7 sweep already rejected this knob (within noise). Future sweeps should go via CLI. Tests: add 6 hybrid-policy tests in tests/test_proxy_pick.py covering affinity-hit, overload break, low-cache fallback, tie-break rotation, lmetric purity, and breakdown field shape. 19/19 pass. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -51,7 +51,6 @@ class Settings:
|
||||
max_offload_inflight: int = 4
|
||||
cache_gate_ratio: float = 0.0
|
||||
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
||||
migration_discount_cap: int = 5 # max turns to discount
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
@@ -148,7 +147,9 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
|
||||
affinity: dict[str, int]) -> tuple[InstanceState, int]:
|
||||
"""LMetric routing: score = P_tokens × BS (OSDI'26).
|
||||
|
||||
Pure per-request load-based routing, no session affinity.
|
||||
Pure per-request load-based routing, no session affinity (the
|
||||
session_id/affinity args are accepted for signature compatibility
|
||||
with pick_instance/pick_instance_unified_hybrid but ignored).
|
||||
P = pending_prefill_tokens + (input_length - cache_hit)
|
||||
BS = num_requests (current batch size)
|
||||
"""
|
||||
@@ -166,42 +167,85 @@ def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] |
|
||||
return instances[best_idx], best_idx
|
||||
|
||||
|
||||
_bootstrap_client: httpx.AsyncClient | None = None
|
||||
|
||||
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls
|
||||
_unified_fallback_rr_counter = 0
|
||||
|
||||
|
||||
async def _get_bootstrap_client() -> httpx.AsyncClient:
|
||||
global _bootstrap_client
|
||||
if _bootstrap_client is None:
|
||||
_bootstrap_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
|
||||
limits=httpx.Limits(max_connections=32, max_keepalive_connections=16),
|
||||
)
|
||||
return _bootstrap_client
|
||||
def pick_instance_unified_hybrid(
|
||||
instances: list[InstanceState],
|
||||
token_ids: list[int] | None,
|
||||
session_id: str | None,
|
||||
input_length: int,
|
||||
affinity: dict[str, int],
|
||||
) -> tuple[InstanceState, int, dict]:
|
||||
"""Hybrid routing: high-cache affinity, else LMetric with tie-breaker.
|
||||
|
||||
Affinity gate (both must hold to stick):
|
||||
- affinity instance cache_hit / input_length > 0.5
|
||||
- affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor
|
||||
|
||||
async def _query_bootstrap_hit(
|
||||
inst: InstanceState, token_ids: list[int],
|
||||
) -> int | None:
|
||||
"""Query bootstrap's /estimate_hit for real cache hit count.
|
||||
Fallback ordering (when affinity not used):
|
||||
primary: score = P_tokens * BS (LMetric)
|
||||
secondary: new_uncached_tokens (prefer instance with most cache)
|
||||
tertiary: num_requests (prefer least-loaded)
|
||||
quaternary: round-robin (avoid degenerate inst-0 pinning
|
||||
when BS=0 across the board)
|
||||
|
||||
Returns hit_tokens on success, None on failure (caller should fallback).
|
||||
Returns (chosen, idx, decision_dict). decision_dict carries the
|
||||
review #7 breakdown fields so the caller can merge them verbatim.
|
||||
"""
|
||||
if inst.bootstrap_port is None:
|
||||
return None
|
||||
parsed = urllib.parse.urlparse(str(inst.client.base_url))
|
||||
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
|
||||
try:
|
||||
client = await _get_bootstrap_client()
|
||||
resp = await client.post(url, json={
|
||||
"token_ids": token_ids,
|
||||
"block_size": BLOCK_SIZE,
|
||||
})
|
||||
resp.raise_for_status()
|
||||
return resp.json()["hit_tokens"]
|
||||
except Exception:
|
||||
return None
|
||||
global _unified_fallback_rr_counter
|
||||
n = len(instances)
|
||||
avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0)
|
||||
|
||||
decision: dict = {
|
||||
"decision": "lmetric_fallback",
|
||||
"affinity_idx": None,
|
||||
"chosen_idx": None,
|
||||
"affinity_cache_hit": None,
|
||||
"affinity_cache_ratio": None,
|
||||
"affinity_num_requests": None,
|
||||
"avg_num_requests": avg_reqs,
|
||||
"fallback_score": None,
|
||||
"tie_break_used": False,
|
||||
}
|
||||
|
||||
if session_id and session_id in affinity:
|
||||
a_idx = affinity[session_id]
|
||||
if a_idx < n:
|
||||
a_inst = instances[a_idx]
|
||||
a_hit = a_inst.estimate_cache_hit(token_ids)
|
||||
a_ratio = a_hit / max(input_length, 1)
|
||||
decision["affinity_idx"] = a_idx
|
||||
decision["affinity_cache_hit"] = a_hit
|
||||
decision["affinity_cache_ratio"] = a_ratio
|
||||
decision["affinity_num_requests"] = a_inst.num_requests
|
||||
if (a_ratio > 0.5
|
||||
and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
|
||||
decision["decision"] = "affinity"
|
||||
decision["chosen_idx"] = a_idx
|
||||
return a_inst, a_idx, decision
|
||||
|
||||
keys: list[tuple[int, int, int, int]] = []
|
||||
for i, inst in enumerate(instances):
|
||||
cache_hit = inst.estimate_cache_hit(token_ids)
|
||||
new_prefill = max(0, input_length - cache_hit)
|
||||
p_tokens = inst.pending_prefill_tokens + new_prefill
|
||||
bs = inst.num_requests
|
||||
score = p_tokens * bs
|
||||
keys.append((score, new_prefill, bs, i))
|
||||
|
||||
best_triple = min(k[:3] for k in keys)
|
||||
tied = [k for k in keys if k[:3] == best_triple]
|
||||
if len(tied) > 1:
|
||||
decision["tie_break_used"] = True
|
||||
_unified_fallback_rr_counter += 1
|
||||
winner = tied[_unified_fallback_rr_counter % len(tied)]
|
||||
else:
|
||||
winner = tied[0]
|
||||
chosen_idx = winner[3]
|
||||
decision["fallback_score"] = winner[0]
|
||||
decision["chosen_idx"] = chosen_idx
|
||||
return instances[chosen_idx], chosen_idx, decision
|
||||
|
||||
|
||||
def _extract_output_token_ids_from_sse(
|
||||
@@ -379,8 +423,6 @@ async def lifespan(app: FastAPI):
|
||||
await reconcile_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if _bootstrap_client is not None:
|
||||
await _bootstrap_client.aclose()
|
||||
for inst in combined_instances + prefill_instances + decode_instances:
|
||||
await inst.client.aclose()
|
||||
|
||||
@@ -471,314 +513,56 @@ async def _handle_local_request(api, req_data, headers, token_ids, input_length,
|
||||
|
||||
|
||||
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
|
||||
"""Unified routing: pick the instance with lowest expected latency.
|
||||
"""Route a /v1/* request among combined (PD-colocated) instances.
|
||||
|
||||
For each instance, estimate:
|
||||
latency = queue_time + prefill_time + transfer_cost
|
||||
where prefill_time depends on whether the instance has cache (local),
|
||||
can receive cache via PUSH (remote), or must do cold prefill.
|
||||
--policy options:
|
||||
linear: cache_hit-aware load score + sticky session affinity.
|
||||
lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity.
|
||||
unified: hybrid — stick to affinity instance when cache_ratio > 0.5
|
||||
and it is not overloaded; otherwise fall back to LMetric
|
||||
with a multi-key tie-breaker.
|
||||
|
||||
PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and
|
||||
commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants
|
||||
both regressed E2E tail). Re-enabling requires a new transfer mechanism.
|
||||
"""
|
||||
policy = getattr(global_args, 'policy', 'linear')
|
||||
offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
|
||||
throughput = SETTINGS.prefill_throughput
|
||||
breakdown: dict = {
|
||||
"request_id": headers.get("X-Request-Id", ""),
|
||||
"input_length": input_length,
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
"policy": policy,
|
||||
}
|
||||
|
||||
if policy in ("linear", "lmetric"):
|
||||
if policy == "lmetric":
|
||||
chosen, best_idx = pick_instance_lmetric(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
else:
|
||||
chosen, best_idx = pick_instance(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
cache_hit = chosen.estimate_cache_hit(token_ids)
|
||||
estimated_new = max(0, input_length - cache_hit)
|
||||
breakdown = {
|
||||
"request_id": headers.get("X-Request-Id", ""),
|
||||
"input_length": input_length,
|
||||
"cache_hit": cache_hit,
|
||||
"estimated_new_tokens": estimated_new,
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
"policy": policy,
|
||||
"route_class": "LOCAL",
|
||||
"routed_to": chosen.url,
|
||||
}
|
||||
if session_id and policy == "lmetric":
|
||||
# LMetric is intentionally per-request; record last target only for
|
||||
# stats/debugging, not for future decisions.
|
||||
if policy == "lmetric":
|
||||
chosen, best_idx = pick_instance_lmetric(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
elif policy == "unified":
|
||||
chosen, best_idx, decision = pick_instance_unified_hybrid(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
breakdown.update(decision)
|
||||
if session_id:
|
||||
session_affinity_combined[session_id] = best_idx
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
|
||||
# Hybrid routing: LMetric for load balance + explicit affinity for high-cache sessions
|
||||
#
|
||||
# 1. If session has high cache on affinity instance AND instance not overloaded → stick
|
||||
# 2. Otherwise → LMetric (P × BS) for best load balance
|
||||
affinity_idx = session_affinity_combined.get(session_id) if session_id else None
|
||||
use_affinity = False
|
||||
|
||||
if affinity_idx is not None and affinity_idx < len(combined_instances):
|
||||
affinity_inst = combined_instances[affinity_idx]
|
||||
affinity_cache = affinity_inst.estimate_cache_hit(token_ids)
|
||||
cache_ratio = affinity_cache / max(input_length, 1)
|
||||
avg_reqs = max(sum(i.num_requests for i in combined_instances) / len(combined_instances), 1.0)
|
||||
|
||||
if (cache_ratio > 0.5
|
||||
and affinity_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
|
||||
use_affinity = True
|
||||
best_idx = affinity_idx
|
||||
|
||||
if not use_affinity:
|
||||
_, best_idx = pick_instance_lmetric(
|
||||
else: # linear (default)
|
||||
chosen, best_idx = pick_instance(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
|
||||
best_needs_push = False
|
||||
|
||||
chosen = combined_instances[best_idx]
|
||||
cache_hit = chosen.estimate_cache_hit(token_ids)
|
||||
estimated_new = max(0, input_length - cache_hit)
|
||||
|
||||
breakdown = {
|
||||
"request_id": headers.get("X-Request-Id", ""),
|
||||
"input_length": input_length,
|
||||
breakdown.update({
|
||||
"cache_hit": cache_hit,
|
||||
"estimated_new_tokens": estimated_new,
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
"policy": "affinity" if use_affinity else "lmetric",
|
||||
}
|
||||
|
||||
if session_id:
|
||||
session_affinity_combined[session_id] = best_idx
|
||||
|
||||
if best_needs_push:
|
||||
c_inst = combined_instances[best_cache_idx]
|
||||
d_inst = chosen
|
||||
|
||||
# Query real cache hit from bootstrap (shadow cache is inaccurate)
|
||||
real_hit = await _query_bootstrap_hit(c_inst, token_ids)
|
||||
breakdown["shadow_cache_hit"] = best_cache_hit
|
||||
breakdown["real_cache_hit"] = real_hit
|
||||
|
||||
if real_hit is not None:
|
||||
push_cache_hit = real_hit
|
||||
else:
|
||||
push_cache_hit = best_cache_hit # fallback to shadow estimate
|
||||
|
||||
# If real hit > 0, proceed with offload
|
||||
if push_cache_hit > 0:
|
||||
push_new = max(0, input_length - push_cache_hit)
|
||||
cache_ratio = push_cache_hit / max(input_length, 1)
|
||||
|
||||
if _current_offloads() >= SETTINGS.max_offload_inflight:
|
||||
breakdown["push_downgraded"] = "cap_reached"
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
if push_new < SETTINGS.heavy_threshold:
|
||||
breakdown["push_downgraded"] = "below_heavy_threshold"
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
if SETTINGS.cache_gate_ratio > 0 and cache_ratio < SETTINGS.cache_gate_ratio:
|
||||
breakdown["push_downgraded"] = "cache_gate"
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
|
||||
offload_mode = getattr(global_args, 'offload_mode', 'cached_prefill')
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["push_cache_hit"] = push_cache_hit
|
||||
|
||||
if offload_mode == "cached_prefill":
|
||||
c_inst.ongoing_tokens += input_length
|
||||
c_inst.pending_prefill_tokens += push_new
|
||||
c_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
|
||||
return await _handle_cached_prefill_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, push_cache_hit, push_new, breakdown)
|
||||
else:
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.pending_prefill_tokens += push_new
|
||||
d_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
breakdown["route_class"] = "PUSH_MIGRATE"
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, push_cache_hit, push_new, breakdown)
|
||||
|
||||
# Real hit is 0 — downgrade to LOCAL
|
||||
breakdown["push_downgraded"] = True
|
||||
|
||||
# LOCAL path (also handles downgraded PUSH)
|
||||
breakdown["route_class"] = "LOCAL"
|
||||
breakdown["routed_to"] = chosen.url
|
||||
"route_class": "LOCAL",
|
||||
"routed_to": chosen.url,
|
||||
})
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
|
||||
|
||||
|
||||
async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
"""C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.
|
||||
|
||||
Unlike direct_read (D pulls blocks from C), here C's scheduler IS
|
||||
involved: C prefills (fast, because prefix is cached), pushes KV to
|
||||
Mooncake store, then D pulls and decodes. This avoids the broken
|
||||
PUSH path where D waits for RDMA transfer while occupying KV blocks.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
|
||||
# Step 1: send blocking prefill to C
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
prefill_data["stream"] = False
|
||||
prefill_data["max_tokens"] = 1
|
||||
prefill_data["min_tokens"] = 1
|
||||
prefill_data.pop("max_completion_tokens", None)
|
||||
prefill_data.pop("stream_options", None)
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
|
||||
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||
|
||||
try:
|
||||
resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers)
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
c_inst.record_prefix(token_ids)
|
||||
except Exception as e:
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["prefill_error"] = True
|
||||
_breakdown_log.append(breakdown)
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
c_inst.ongoing_tokens -= input_length
|
||||
c_inst.pending_prefill_tokens -= estimated_new
|
||||
c_inst.num_requests -= 1
|
||||
raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}")
|
||||
|
||||
c_inst.ongoing_tokens -= input_length
|
||||
c_inst.pending_prefill_tokens -= estimated_new
|
||||
c_inst.num_requests -= 1
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
|
||||
# Step 2: send decode to D (pull KV from C via Mooncake)
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.num_requests += 1
|
||||
|
||||
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
|
||||
bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"
|
||||
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": c_inst.engine_id.get(0, ""),
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
|
||||
async def generate():
|
||||
first_token = True
|
||||
sse_buffer = ""
|
||||
output_token_ids: list[int] = []
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
|
||||
sse_buffer, chunk)
|
||||
output_token_ids.extend(new_output_ids)
|
||||
if first_token:
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
first_token = False
|
||||
yield chunk
|
||||
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
|
||||
finally:
|
||||
if not first_token:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
"""HEAVY request: D direct-RDMA-reads cached KV from C_s, then does
|
||||
local prefill for new tokens + decode. C_s's scheduler is NOT involved.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
|
||||
# Align cache_hit to block boundary for remote_num_tokens
|
||||
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
|
||||
breakdown["t_offload_sent"] = _time.monotonic()
|
||||
|
||||
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
|
||||
bootstrap_addr = "http://%s:%s" % (parsed.hostname, c_inst.bootstrap_port)
|
||||
|
||||
# Send full prompt to D with direct_read flag
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"direct_read": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": c_inst.engine_id.get(0, ""),
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
"remote_num_tokens": cached_tokens,
|
||||
}
|
||||
|
||||
async def generate():
|
||||
first_token = True
|
||||
sse_buffer = ""
|
||||
output_token_ids: list[int] = []
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
|
||||
sse_buffer, chunk)
|
||||
output_token_ids.extend(new_output_ids)
|
||||
if first_token:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
first_token = False
|
||||
yield chunk
|
||||
d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
|
||||
finally:
|
||||
if first_token:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
session_id, headers):
|
||||
"""PD-Sep mode with per-stage breakdown profiling."""
|
||||
@@ -901,20 +685,23 @@ def parse_args():
|
||||
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
|
||||
p.add_argument("--policy", type=str, default="linear",
|
||||
choices=["linear", "lmetric", "unified"],
|
||||
help="Routing policy: linear, lmetric (P_tokens × BS), or unified cost model")
|
||||
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
|
||||
"or unified (hybrid affinity + LMetric fallback)")
|
||||
p.add_argument("--overload-factor", type=float, default=2.0,
|
||||
help="Break session affinity when instance load > factor * avg")
|
||||
# The four flags below are accepted for bench.sh backward compatibility but
|
||||
# have no effect after the PD-sep offload path was retired (REPORT §3.9,
|
||||
# commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
|
||||
# scripts/legacy/*.sh which still pass them through.
|
||||
p.add_argument("--max-offload-inflight", type=int, default=4,
|
||||
help="Global cap on concurrent P-role offloads (M3)")
|
||||
help="[DEPRECATED] PUSH offload retired; no effect")
|
||||
p.add_argument("--offload-mode", type=str, default="cached_prefill",
|
||||
choices=["direct_read", "cached_prefill"],
|
||||
help="direct_read: D pulls KV from C (PUSH). "
|
||||
"cached_prefill: C prefills then D decodes (PD-sep style).")
|
||||
help="[DEPRECATED] PUSH offload retired; no effect")
|
||||
p.add_argument("--cache-gate-ratio", type=float, default=0.0,
|
||||
help="Min cache_hit/input ratio to allow offload "
|
||||
"(0.0 disables gate, 1.0 disables offload entirely)")
|
||||
help="[DEPRECATED] PUSH offload retired; no effect")
|
||||
p.add_argument("--decode-iteration-s", type=float, default=0.05,
|
||||
help="Estimated per-request decode iteration time in seconds")
|
||||
help="[DEPRECATED] PUSH offload retired; no effect")
|
||||
args = p.parse_args()
|
||||
|
||||
args.prefill = []
|
||||
|
||||
@@ -180,6 +180,107 @@ def test_pick_instance_lmetric_picks_lowest_score(proxy):
|
||||
assert idx == 0 and chosen is insts[0]
|
||||
|
||||
|
||||
def test_pick_instance_lmetric_ignores_session_affinity(proxy):
|
||||
"""Review #3: pure --policy lmetric must remain affinity-free."""
|
||||
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
|
||||
# Make inst[1] look much busier than inst[0]; LMetric must still pick 0
|
||||
# even though affinity points at 1.
|
||||
insts[0].pending_prefill_tokens = 0
|
||||
insts[0].num_requests = 0
|
||||
insts[1].pending_prefill_tokens = 5000
|
||||
insts[1].num_requests = 4
|
||||
affinity = {"sess1": 1}
|
||||
chosen, idx = proxy.pick_instance_lmetric(insts, None, "sess1", 1000, affinity)
|
||||
assert idx == 0
|
||||
# Picker must not mutate the affinity dict either.
|
||||
assert affinity == {"sess1": 1}
|
||||
|
||||
|
||||
def _record_n_blocks(proxy, inst, n: int) -> list[int]:
|
||||
"""Record n distinct one-block prefixes on inst; return token_ids covering them."""
|
||||
block_size = proxy.BLOCK_SIZE
|
||||
tokens: list[int] = []
|
||||
for b in range(n):
|
||||
tokens.extend([1000 + b] * block_size)
|
||||
inst.record_prefix(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
def test_hybrid_high_cache_session_sticks_to_affinity(proxy):
|
||||
"""Hybrid: affinity instance with cache_ratio > 0.5 and no overload → stick."""
|
||||
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
|
||||
tokens = _record_n_blocks(proxy, insts[1], 2) # 2 blocks cached on inst[1]
|
||||
affinity = {"sess1": 1}
|
||||
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
|
||||
insts, tokens, "sess1", len(tokens), affinity)
|
||||
assert idx == 1 and chosen is insts[1]
|
||||
assert decision["decision"] == "affinity"
|
||||
assert decision["affinity_idx"] == 1
|
||||
assert decision["chosen_idx"] == 1
|
||||
assert decision["affinity_cache_ratio"] > 0.5
|
||||
assert decision["tie_break_used"] is False
|
||||
|
||||
|
||||
def test_hybrid_high_cache_breaks_on_overload(proxy):
|
||||
"""Hybrid: affinity num_requests > avg * overload_factor → fall back to LMetric,
|
||||
and with realistic new-token tail the LMetric fallback steers off the hot instance."""
|
||||
insts = [
|
||||
_make_inst(proxy, "http://a"),
|
||||
_make_inst(proxy, "http://b"),
|
||||
_make_inst(proxy, "http://c"),
|
||||
]
|
||||
cached = _record_n_blocks(proxy, insts[1], 2)
|
||||
# Append one more uncached block so LMetric sees a real prefill cost on the
|
||||
# cached instance too (BS multiplier becomes visible). Without this, the
|
||||
# cached instance scores 0 * BS = 0 regardless of how loaded it is.
|
||||
tokens = cached + [999_999] * proxy.BLOCK_SIZE
|
||||
insts[1].num_requests = 300 # avg = 100; 300 > 100 * 2.0 ✓ breaks the gate
|
||||
affinity = {"sess1": 1}
|
||||
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
|
||||
insts, tokens, "sess1", len(tokens), affinity)
|
||||
assert decision["decision"] == "lmetric_fallback"
|
||||
assert decision["affinity_idx"] == 1
|
||||
assert idx != 1, "affinity instance is overloaded; fallback should steer away"
|
||||
|
||||
|
||||
def test_hybrid_low_cache_falls_back(proxy):
|
||||
"""Hybrid: cache_ratio <= 0.5 on affinity → fall back to LMetric."""
|
||||
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
|
||||
tokens = [1] * (proxy.BLOCK_SIZE * 2) # 1024 tokens, nothing cached anywhere
|
||||
affinity = {"sess1": 1}
|
||||
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
|
||||
insts, tokens, "sess1", len(tokens), affinity)
|
||||
assert decision["decision"] == "lmetric_fallback"
|
||||
assert decision["affinity_cache_ratio"] == 0.0
|
||||
|
||||
|
||||
def test_hybrid_new_session_tie_break_does_not_always_pick_index_0(proxy):
|
||||
"""Review #4: when all instances tie (e.g. BS=0), tie-break must rotate."""
|
||||
insts = [_make_inst(proxy, "http://a") for _ in range(3)]
|
||||
seen = set()
|
||||
for _ in range(12):
|
||||
# No session_id, all empty → score = 0 for everyone → ties → rotate.
|
||||
chosen, idx, decision = proxy.pick_instance_unified_hybrid(
|
||||
insts, None, None, 100, {})
|
||||
seen.add(idx)
|
||||
assert decision["decision"] == "lmetric_fallback"
|
||||
assert decision["tie_break_used"] is True
|
||||
assert seen == {0, 1, 2}, f"tie-breaker did not rotate; only saw {seen}"
|
||||
|
||||
|
||||
def test_hybrid_decision_fields_populated(proxy):
|
||||
"""Review #7: decision dict must carry the breakdown fields."""
|
||||
insts = [_make_inst(proxy, "http://a"), _make_inst(proxy, "http://b")]
|
||||
_, _, decision = proxy.pick_instance_unified_hybrid(
|
||||
insts, None, None, 100, {})
|
||||
expected_keys = {
|
||||
"decision", "affinity_idx", "chosen_idx",
|
||||
"affinity_cache_hit", "affinity_cache_ratio", "affinity_num_requests",
|
||||
"avg_num_requests", "fallback_score", "tie_break_used",
|
||||
}
|
||||
assert expected_keys.issubset(decision.keys())
|
||||
|
||||
|
||||
def test_settings_has_runtime_knobs(proxy):
|
||||
"""D5/B4/M3: Settings dataclass exposes the previously-hardcoded knobs."""
|
||||
s = proxy.SETTINGS
|
||||
|
||||
Reference in New Issue
Block a user