unified_v2: selective per-request PD-sep via Mooncake (E3+E4)

Adds a sixth routing policy, --policy unified_v2, that wraps the
existing unified hybrid picker with a selective PD-sep branch.
When all of the following hold, a request is split into prefill-on-src
and decode-on-chosen via a Mooncake kv_role=kv_both transfer:
  1. new_local = input_length - chosen.cache_hit > 16k
     (B2 microbench shows same-worker TTFT idx >= 3x from this size up)
  2. chosen has live decodes worth protecting (>= 2 in-flight)
  3. some other instance holds materially more cache for this prefix
     (>= 8k tokens, and >= 4k more than chosen)
  4. cost(src_interference + RDMA xfer) + 0.2s margin < cost(chosen_interference)
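
As a rough illustration of the four conditions above, the sketch below checks them in order and reports the first one that fails. The `Gates` class, `should_pd_sep`, and the reason strings are hypothetical names chosen for this example, not the functions in cache_aware_proxy.py. The two cost inputs are assumed to be precomputed in seconds, and condition 2 is simplified to a plain decode count.

```python
from dataclasses import dataclass


@dataclass
class Gates:
    # Defaults mirror the thresholds listed above; field names are hypothetical.
    min_new_tokens: int = 16_000
    min_decodes: int = 2
    min_src_cache: int = 8_000
    min_extra_cache: int = 4_000
    margin_s: float = 0.2


def should_pd_sep(input_length, chosen_hit, chosen_decodes, src_hit,
                  cost_local_s, cost_migrate_s, g=Gates()):
    """Return (decision, reason) after checking the four conditions in order."""
    new_local = max(0, input_length - chosen_hit)
    # 1. The post-cache prefill must be large (assumption: exactly 16k passes).
    if new_local < g.min_new_tokens:
        return False, "new_local_below_threshold"
    # 2. chosen must have live decodes worth protecting.
    if chosen_decodes < g.min_decodes:
        return False, "chosen_few_decodes"
    # 3. Another instance must hold enough cache, and materially more than chosen.
    if src_hit < g.min_src_cache or src_hit - chosen_hit < g.min_extra_cache:
        return False, "src_cache_insufficient"
    # 4. Migrating must beat local routing by more than the safety margin.
    if cost_local_s <= cost_migrate_s + g.margin_s:
        return False, "cost_not_favourable"
    return True, "pd_sep"


print(should_pd_sep(65_536, 0, 5, 65_536, 77.8, 2.5))
```

Returning the first failing condition as a reason string keeps every non-migration explainable in the per-request log.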

The cost model is the audit-blessed shape from E1's post-mortem:
  - gate on new_tokens (post-cache), NOT input_length (the old PUSH gate)
  - bind to a single transfer mechanism (kv_both peer-to-peer pull)
  - realistic RDMA cost as a function of bytes: 0.3s base +
    bytes / 2.7 GB/s (calibrated against contention_16s_elastic p50)
  - both source and target decode counts considered
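
To make the byte-based RDMA term concrete, here is a small worked sketch using the constants quoted in this commit (0.3 s base, 2.7 GB/s effective, 98304 KV bytes per token). The function name `transfer_cost_s` is hypothetical. Treating "GB" as 1024**3 bytes is an assumption taken from the diff's use of `1024 ** 3`.

```python
KV_BYTES_PER_TOKEN = 98_304   # 2 (K and V) x 48 layers x 4 KV heads x 128 head_dim x 2 bytes (bf16)
RDMA_BASE_S = 0.3             # fixed handshake / scheduler-step floor
RDMA_GIB_PER_S = 2.7          # effective bandwidth on the contended path


def transfer_cost_s(cached_tokens: int) -> float:
    """Estimated seconds to move the KV cache for `cached_tokens` tokens."""
    nbytes = cached_tokens * KV_BYTES_PER_TOKEN
    return RDMA_BASE_S + nbytes / (RDMA_GIB_PER_S * 1024 ** 3)


for tokens in (8_000, 32_768, 65_536):
    print(f"{tokens:>6} tokens -> {transfer_cost_s(tokens):.2f} s")
# Prints:
#   8000 tokens -> 0.57 s
#  32768 tokens -> 1.41 s
#  65536 tokens -> 2.52 s
```

Even for a 64k-token prefix the bandwidth term is only a few seconds, so the fixed 0.3 s base matters mostly for small transfers.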

E2 mechanism-level patches are not yet applied (this commit is policy-only).
Patches 6.2 / 6.3 / 6.5 remain on the table. Patch 6.6 (per-request
xfer timeout, 60s default) is implemented on the proxy side as an
httpx per-chunk read timeout on the dst streaming call, so a stuck
KV transfer fails the request instead of hanging for 600s.
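
The commit itself relies on httpx's read timeout. The sketch below uses only asyncio to illustrate what "per-chunk" means here: the limit applies to the wait for each chunk, not to the whole response. `slow_chunks`, `read_with_chunk_timeout`, and `run` are hypothetical helpers written for this example.

```python
import asyncio


async def slow_chunks(delays):
    """Hypothetical stand-in for a streaming response body."""
    for i, delay in enumerate(delays):
        await asyncio.sleep(delay)
        yield f"chunk-{i}".encode()


async def read_with_chunk_timeout(stream, per_chunk_timeout_s):
    """Collect chunks; raise asyncio.TimeoutError if any single chunk is late."""
    it = stream.__aiter__()
    out = []
    while True:
        try:
            # The timer restarts for every chunk, so a long but steadily
            # progressing stream is never cut off.
            chunk = await asyncio.wait_for(it.__anext__(), per_chunk_timeout_s)
        except StopAsyncIteration:
            return out
        out.append(chunk)


def run(delays, timeout_s):
    try:
        return asyncio.run(read_with_chunk_timeout(slow_chunks(delays), timeout_s))
    except asyncio.TimeoutError:
        return "timeout"


print(run([0.01, 0.01], 0.5))   # [b'chunk-0', b'chunk-1']
print(run([0.01, 1.0], 0.1))    # timeout
```

With this semantics a stalled KV transfer surfaces as a missing first chunk, while a long generation that keeps emitting tokens is unaffected.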

cache_aware_proxy.py:
  - Settings: kv_bytes_per_token, prefill_throughput_kv_both,
    rdma_base_overhead_s, rdma_effective_gb_per_s, pd_sep_* gating knobs
  - estimate_transfer_cost(bytes) replaces the constant rdma_overhead_s
  - estimate_same_worker_interference_s(new_tokens, num_decodes) reads off
    the B2 penalty curve in 4 bins
  - pick_instance_unified_v2: inherits unified, returns extra
    (src_inst, src_idx) tuple when PD-sep wins the cost compare
  - _handle_combined_pd_sep_v2: prefill on src (do_remote_decode=True,
    max_tokens=1), Mooncake xfer, decode-stream on dst with httpx
    Timeout(read=pd_sep_xfer_timeout_s)
  - --policy unified_v2 added to argparse choices
  - lifespan auto-runs init_prefill_bootstrap when policy is unified_v2
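
A worked sketch of the 4-bin interference estimate and the final cost compare may help when reading the diff. The bin edges and penalties follow the code in this commit; `interference_s` and the scenario numbers are hypothetical. The transfer cost is entered as a literal (0.3 s base plus 6 GiB at 2.7 GiB/s for a 65,536-token prefix).

```python
PREFILL_TOK_PER_S = 4000.0   # prefill_throughput_kv_both from this commit
MARGIN_S = 0.2               # pd_sep_margin_s


def interference_s(new_tokens: int, num_decodes: int) -> float:
    """Seconds of decode disruption: prefill duration x penalty x decodes."""
    if num_decodes <= 0:
        return 0.0
    if new_tokens < 4_000:
        penalty = 0.2
    elif new_tokens < 16_000:
        penalty = 0.5
    elif new_tokens < 32_000:
        penalty = 0.8
    else:
        penalty = 0.95
    return new_tokens / PREFILL_TOK_PER_S * penalty * num_decodes


# Hypothetical scenario: 65,536-token prompt, chosen has no cache and
# 5 live decodes; src holds the whole prefix, so it prefills 0 new tokens.
cost_local = interference_s(65_536, 5)
cost_src = interference_s(0, 30)
cost_xfer = 0.3 + 6 / 2.7
cost_migrate = cost_src + cost_xfer

print(f"local={cost_local:.2f}s migrate={cost_migrate:.2f}s "
      f"pd_sep={cost_local > cost_migrate + MARGIN_S}")
# Prints: local=77.82s migrate=2.52s pd_sep=True
```

Because the estimate multiplies by the number of co-located decodes, a busy chosen instance pushes the local cost up quickly, and that multiplication is what tips the comparison toward migration.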

b3_isolated_policy.sh:
  - ENABLE_KV_BOTH env var, auto-set when POLICY=unified_v2, threads
    kv_role=kv_both + VLLM_MOONCAKE_BOOTSTRAP_PORT to vllm and
    --bootstrap-ports to the proxy
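
The port layout that the script threads through can be sketched as follows, assuming the defaults in the diff (BASE_PORT 8000, BOOTSTRAP_BASE_PORT 8998). The helper names are hypothetical; the script itself builds the same values in bash.

```python
def instance_ports(base_port: int, bootstrap_base_port: int, n_instances: int):
    """Return (http_port, bootstrap_port) for each instance index."""
    return [(base_port + i, bootstrap_base_port + i) for i in range(n_instances)]


def proxy_bootstrap_args(bootstrap_base_port: int, n_instances: int) -> list[str]:
    """Build the --bootstrap-ports argument as a comma-separated list."""
    ports = [str(bootstrap_base_port + i) for i in range(n_instances)]
    return ["--bootstrap-ports", ",".join(ports)]


print(instance_ports(8000, 8998, 2))
# [(8000, 8998), (8001, 8999)]
print(proxy_bootstrap_args(8998, 8))
# ['--bootstrap-ports', '8998,8999,9000,9001,9002,9003,9004,9005']
```

Instance i's bootstrap port is the value exported as VLLM_MOONCAKE_BOOTSTRAP_PORT for that vLLM process. The order of the proxy's list has to match the order of the instance URLs.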

Tests: 8 new unit tests cover the gating predicates and the cost
estimators; all 32 proxy tests still pass.

Refs: E1 (PUSH post-mortem) + E2 (Mooncake audit) reports.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@@ -19,11 +19,22 @@ BASE_PORT="${BASE_PORT:-8000}"
|
||||
GPU_INDICES="${GPU_INDICES:-0 1 2 3 4 5 6 7}"
|
||||
EXTRA_VLLM_ARGS="${EXTRA_VLLM_ARGS:---enable-prompt-tokens-details}"
|
||||
N_INSTANCES=$(echo $GPU_INDICES | wc -w)
|
||||
# When ENABLE_KV_BOTH=1, vLLM launches with the Mooncake KV connector in
|
||||
# kv_both role and the proxy is given bootstrap ports. This is required
|
||||
# for --policy unified_v2 (per-request PD-sep) but disabled by default
|
||||
# because it adds always-on KV-transfer overhead even when not triggered.
|
||||
ENABLE_KV_BOTH="${ENABLE_KV_BOTH:-0}"
|
||||
BOOTSTRAP_BASE_PORT="${BOOTSTRAP_BASE_PORT:-8998}"
|
||||
|
||||
POLICY="${1:?usage: $0 <policy> <trace> <rundir>}"
|
||||
TRACE="${2:?usage: $0 <policy> <trace> <rundir>}"
|
||||
RUNDIR="${3:?usage: $0 <policy> <trace> <rundir>}"
|
||||
|
||||
# Auto-enable kv_both when the policy requires it.
|
||||
if [ "$POLICY" = "unified_v2" ]; then
|
||||
ENABLE_KV_BOTH=1
|
||||
fi
|
||||
|
||||
mkdir -p "$RUNDIR/engine_state" "$RUNDIR/logs"
|
||||
echo "[isolated] policy=$POLICY trace=$(basename $TRACE) rundir=$RUNDIR"
|
||||
|
||||
@@ -38,23 +49,46 @@ trap cleanup EXIT
|
||||
# Hard reset first
|
||||
cleanup
|
||||
|
||||
echo "[isolated] launching $N_INSTANCES vLLM on GPUs $GPU_INDICES ..."
|
||||
echo "[isolated] launching $N_INSTANCES vLLM on GPUs $GPU_INDICES ENABLE_KV_BOTH=$ENABLE_KV_BOTH ..."
|
||||
i=0
|
||||
kv_both_extra=""
|
||||
if [ "$ENABLE_KV_BOTH" = "1" ]; then
|
||||
kv_both_extra="--kv-transfer-config {\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_both\"}"
|
||||
fi
|
||||
for gpu in $GPU_INDICES; do
|
||||
port=$((BASE_PORT + i))
|
||||
master=$((29500 + i))
|
||||
AGENTIC_STEP_LOG_PATH="$RUNDIR/engine_state/engine_${i}.jsonl" \
|
||||
AGENTIC_WORKER_ID="engine_${i}" \
|
||||
CUDA_VISIBLE_DEVICES=$gpu \
|
||||
MASTER_PORT=$master \
|
||||
nohup "$VENV/vllm" serve "$MODEL" \
|
||||
--host 0.0.0.0 --port "$port" \
|
||||
--tensor-parallel-size 1 \
|
||||
--trust-remote-code --enable-prefix-caching \
|
||||
--dtype auto --gpu-memory-utilization 0.9 \
|
||||
--max-model-len 200000 \
|
||||
$EXTRA_VLLM_ARGS \
|
||||
> "$RUNDIR/logs/vllm_inst_${i}_gpu${gpu}.log" 2>&1 &
|
||||
bp=$((BOOTSTRAP_BASE_PORT + i))
|
||||
if [ "$ENABLE_KV_BOTH" = "1" ]; then
|
||||
PYTHONHASHSEED=42 \
|
||||
VLLM_MOONCAKE_BOOTSTRAP_PORT=$bp \
|
||||
AGENTIC_STEP_LOG_PATH="$RUNDIR/engine_state/engine_${i}.jsonl" \
|
||||
AGENTIC_WORKER_ID="engine_${i}" \
|
||||
CUDA_VISIBLE_DEVICES=$gpu \
|
||||
MASTER_PORT=$master \
|
||||
nohup "$VENV/vllm" serve "$MODEL" \
|
||||
--host 0.0.0.0 --port "$port" \
|
||||
--tensor-parallel-size 1 \
|
||||
--trust-remote-code --enable-prefix-caching \
|
||||
--dtype auto --gpu-memory-utilization 0.9 \
|
||||
--max-model-len 200000 \
|
||||
--kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_both"}' \
|
||||
$EXTRA_VLLM_ARGS \
|
||||
> "$RUNDIR/logs/vllm_inst_${i}_gpu${gpu}.log" 2>&1 &
|
||||
else
|
||||
AGENTIC_STEP_LOG_PATH="$RUNDIR/engine_state/engine_${i}.jsonl" \
|
||||
AGENTIC_WORKER_ID="engine_${i}" \
|
||||
CUDA_VISIBLE_DEVICES=$gpu \
|
||||
MASTER_PORT=$master \
|
||||
nohup "$VENV/vllm" serve "$MODEL" \
|
||||
--host 0.0.0.0 --port "$port" \
|
||||
--tensor-parallel-size 1 \
|
||||
--trust-remote-code --enable-prefix-caching \
|
||||
--dtype auto --gpu-memory-utilization 0.9 \
|
||||
--max-model-len 200000 \
|
||||
$EXTRA_VLLM_ARGS \
|
||||
> "$RUNDIR/logs/vllm_inst_${i}_gpu${gpu}.log" 2>&1 &
|
||||
fi
|
||||
disown
|
||||
sleep 2
|
||||
i=$((i + 1))
|
||||
@@ -80,10 +114,23 @@ combined_args=""
|
||||
for i in $(seq 0 $((N_INSTANCES - 1))); do
|
||||
combined_args="$combined_args http://127.0.0.1:$((BASE_PORT + i))"
|
||||
done
|
||||
proxy_extra=""
|
||||
if [ "$ENABLE_KV_BOTH" = "1" ]; then
|
||||
bp_list=""
|
||||
for i in $(seq 0 $((N_INSTANCES - 1))); do
|
||||
if [ -z "$bp_list" ]; then
|
||||
bp_list="$((BOOTSTRAP_BASE_PORT + i))"
|
||||
else
|
||||
bp_list="$bp_list,$((BOOTSTRAP_BASE_PORT + i))"
|
||||
fi
|
||||
done
|
||||
proxy_extra="--bootstrap-ports $bp_list"
|
||||
fi
|
||||
nohup "$VENV/python" "$ROOT/scripts/cache_aware_proxy.py" \
|
||||
--port "$PROXY_PORT" \
|
||||
--combined $combined_args \
|
||||
--policy "$POLICY" \
|
||||
$proxy_extra \
|
||||
> "$RUNDIR/proxy.log" 2>&1 &
|
||||
disown
|
||||
tries=0
|
||||
|
||||
@@ -44,7 +44,7 @@ class Settings:
|
||||
(e.g. by tests/) and __main__ does not run.
|
||||
"""
|
||||
prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20)
|
||||
rdma_overhead_s: float = 0.1 # RDMA PUSH overhead (~10-50ms measured)
|
||||
rdma_overhead_s: float = 0.1 # legacy floor; v2 uses estimate_transfer_cost
|
||||
cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
|
||||
heavy_threshold: int = 20000
|
||||
overload_factor: float = 2.0
|
||||
@@ -52,10 +52,91 @@ class Settings:
|
||||
cache_gate_ratio: float = 0.0
|
||||
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
||||
|
||||
# --- Patch 6.9: cost-model calibration for unified_v2 ---
|
||||
# Throughput when the engine runs in kv_both mode. Lower than the
|
||||
# plain-mode 7000 tok/s (prefill_throughput) because kv_both adds always-on overhead
|
||||
# (REPORT §3.8 documents ~+16% TPOT vs plain).
|
||||
prefill_throughput_kv_both: float = 4000.0
|
||||
# Calibrated RDMA transfer cost: base + bandwidth term.
|
||||
# Floor from isolated test ≈ 0.3 s (handshake + scheduler step).
|
||||
# Bandwidth term reflects realized effective throughput, not
|
||||
# theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s
|
||||
# effective on the contended kv_both path. v2 uses this lookup
|
||||
# rather than the constant rdma_overhead_s.
|
||||
rdma_base_overhead_s: float = 0.3
|
||||
rdma_effective_gb_per_s: float = 2.7
|
||||
|
||||
# Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2):
|
||||
# 2 × 48 × 4 × 128 × 2 = 98304 bytes per token.
|
||||
kv_bytes_per_token: int = 98304
|
||||
|
||||
# --- unified_v2 gating knobs ---
|
||||
pd_sep_min_new_tokens: int = 16000 # B2 idx ≥ 3× starts here
|
||||
pd_sep_min_decodes_protected: int = 2 # require chosen to have live decodes worth protecting
|
||||
pd_sep_min_src_cache_tokens: int = 8000 # require non-trivial cache to transfer
|
||||
pd_sep_min_extra_cache_tokens: int = 4000 # src must have meaningfully more cache than chosen
|
||||
pd_sep_margin_s: float = 0.2 # require cost gap > 0.2 s before migrating
|
||||
# Patch 6.6: per-request KV-xfer timeout (proxy side), enforced as a per-chunk read timeout on the dst stream.
|
||||
pd_sep_xfer_timeout_s: float = 60.0
|
||||
|
||||
|
||||
SETTINGS = Settings()
|
||||
|
||||
|
||||
def estimate_transfer_cost(transfer_bytes: int) -> float:
|
||||
"""Calibrated RDMA transfer cost as a function of bytes.
|
||||
|
||||
Replaces the legacy constant rdma_overhead_s. Calibration sources:
|
||||
- Floor: isolated-test ~0.3 s for a few-block PUSH (scripts/test_direct_read.py)
|
||||
- Bandwidth term: outputs/contention_16s_elastic/breakdown.json shows
|
||||
decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, giving
|
||||
~2.7 GB/s effective on the contended kv_both path.
|
||||
|
||||
The p90 in that same run is 6.7 s (D-side block reservation +
|
||||
scheduler step delays). v2's cost model uses the *median* — being
|
||||
too pessimistic would suppress all PD-sep triggers. The risk of
|
||||
underestimation is mitigated by the pd_sep_margin_s safety factor.
|
||||
"""
|
||||
base = SETTINGS.rdma_base_overhead_s
|
||||
bw_term = transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024 ** 3)
|
||||
return base + bw_term
|
||||
|
||||
|
||||
def estimate_same_worker_interference_s(
|
||||
new_tokens: int,
|
||||
num_decodes: int,
|
||||
) -> float:
|
||||
"""Estimated additional latency on `num_decodes` co-located decodes
|
||||
when a `new_tokens`-token prefill runs on the same worker.
|
||||
|
||||
Derived from B2 microbench (analysis/characterization/window_1_results.md):
|
||||
same-worker prefill of size N steals decode capacity for the
|
||||
prefill's duration. The penalty factor is the fraction of decode
|
||||
steps stolen during the prefill window.
|
||||
|
||||
For new_tokens < 4k: ~0.2 (chunked prefill leaves room)
|
||||
For 4k ≤ new_tokens < 16k: ~0.5 (mid-regime, B2 TPOT idx 3.4× at 16k)
|
||||
For 16k ≤ new_tokens < 32k: ~0.8 (B2 peak TPOT idx 7.9× at 32k)
|
||||
For new_tokens ≥ 32k: ~0.95 (B2 TTFT regime; decodes are nearly fully blocked)
|
||||
|
||||
The cost in seconds is roughly: prefill_duration × penalty × n_decodes,
|
||||
because each affected decode loses ~penalty fraction of its capacity
|
||||
during the prefill window.
|
||||
"""
|
||||
if num_decodes <= 0:
|
||||
return 0.0
|
||||
prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both
|
||||
if new_tokens < 4000:
|
||||
penalty = 0.2
|
||||
elif new_tokens < 16000:
|
||||
penalty = 0.5
|
||||
elif new_tokens < 32000:
|
||||
penalty = 0.8
|
||||
else:
|
||||
penalty = 0.95
|
||||
return prefill_dur_s * penalty * num_decodes
|
||||
|
||||
|
||||
class InstanceState:
|
||||
def __init__(self, url: str, bootstrap_port: int | None = None):
|
||||
self.url = url
|
||||
@@ -326,6 +407,128 @@ def pick_instance_unified_hybrid(
|
||||
return instances[chosen_idx], chosen_idx, decision
|
||||
|
||||
|
||||
def pick_instance_unified_v2(
|
||||
instances: list[InstanceState],
|
||||
token_ids: list[int] | None,
|
||||
session_id: str | None,
|
||||
input_length: int,
|
||||
affinity: dict[str, int],
|
||||
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
|
||||
"""unified_v2 = unified hybrid + selective per-request PD-sep trigger.
|
||||
|
||||
Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`.
|
||||
|
||||
Stage 2 asks: is there another instance with materially more cache
|
||||
for this request? If yes, would doing prefill on that instance and
|
||||
transferring KV to `chosen` for decode be cheaper than just doing
|
||||
everything on `chosen`?
|
||||
|
||||
The cost model compares two scenarios in seconds-of-decode-disruption:
|
||||
|
||||
local: same-worker prefill on chosen of (input - chosen.cache_hit)
|
||||
tokens interferes with chosen.num_decodes co-located decodes.
|
||||
|
||||
pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens
|
||||
(smaller, because src has more cache) interferes with
|
||||
src.num_decodes co-located decodes, plus we pay RDMA
|
||||
transfer of src.cache_hit blocks to chosen.
|
||||
|
||||
We migrate only when local cost > pd-sep cost + safety margin AND
|
||||
a set of hard gates (size, cache, decodes) are met.
|
||||
|
||||
Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None
|
||||
the handler should do local routing on `chosen`. When pd_sep is
|
||||
(src_inst, src_idx) the handler should do prefill-on-src,
|
||||
decode-on-chosen via Mooncake.
|
||||
"""
|
||||
chosen, chosen_idx, decision = pick_instance_unified_hybrid(
|
||||
instances, token_ids, session_id, input_length, affinity)
|
||||
|
||||
decision["v2_pd_sep"] = False
|
||||
decision["v2_decision"] = "local"
|
||||
decision["v2_reason"] = None
|
||||
|
||||
if not token_ids:
|
||||
decision["v2_reason"] = "no_token_ids"
|
||||
return chosen, chosen_idx, decision, None
|
||||
|
||||
chosen_cache_hit = chosen.estimate_cache_hit(token_ids)
|
||||
new_local = max(0, input_length - chosen_cache_hit)
|
||||
|
||||
# Hard gate 1: prefill must be large enough that interference
|
||||
# outweighs the fixed RDMA setup cost.
|
||||
if new_local < SETTINGS.pd_sep_min_new_tokens:
|
||||
decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})"
|
||||
return chosen, chosen_idx, decision, None
|
||||
|
||||
# Hard gate 2: chosen must have live decodes worth protecting.
|
||||
if chosen.ongoing_decode_tokens // max(1, SETTINGS.heavy_threshold) < 1 \
|
||||
and chosen.num_requests < SETTINGS.pd_sep_min_decodes_protected:
|
||||
# Heuristic for "num_decodes": prefer num_requests as an upper
|
||||
# bound since we don't track decode count separately at route time.
|
||||
decision["v2_reason"] = f"chosen_few_decodes ({chosen.num_requests})"
|
||||
return chosen, chosen_idx, decision, None
|
||||
|
||||
# Find best alternative cache source.
|
||||
best_src_idx, best_src_hit = -1, 0
|
||||
for i, inst in enumerate(instances):
|
||||
if i == chosen_idx:
|
||||
continue
|
||||
h = inst.estimate_cache_hit(token_ids)
|
||||
if h > best_src_hit:
|
||||
best_src_idx, best_src_hit = i, h
|
||||
|
||||
# Hard gate 3: src must hold meaningful cache.
|
||||
if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens:
|
||||
decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})"
|
||||
return chosen, chosen_idx, decision, None
|
||||
|
||||
# Hard gate 4: src must hold materially more cache than chosen.
|
||||
if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens:
|
||||
decision["v2_reason"] = (
|
||||
f"src_not_meaningfully_more_cache "
|
||||
f"(src={best_src_hit} chosen={chosen_cache_hit})"
|
||||
)
|
||||
return chosen, chosen_idx, decision, None
|
||||
|
||||
src = instances[best_src_idx]
|
||||
new_src = max(0, input_length - best_src_hit)
|
||||
|
||||
# Cost-benefit in seconds-of-decode-disruption.
|
||||
cost_local = estimate_same_worker_interference_s(
|
||||
new_local, chosen.num_requests)
|
||||
cost_src_interf = estimate_same_worker_interference_s(
|
||||
new_src, src.num_requests)
|
||||
transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token
|
||||
cost_xfer = estimate_transfer_cost(transfer_bytes)
|
||||
cost_migrate = cost_src_interf + cost_xfer
|
||||
|
||||
decision["v2_chosen_cache_hit"] = chosen_cache_hit
|
||||
decision["v2_src_idx"] = best_src_idx
|
||||
decision["v2_src_cache_hit"] = best_src_hit
|
||||
decision["v2_new_local"] = new_local
|
||||
decision["v2_new_src"] = new_src
|
||||
decision["v2_cost_local_s"] = cost_local
|
||||
decision["v2_cost_src_interf_s"] = cost_src_interf
|
||||
decision["v2_cost_xfer_s"] = cost_xfer
|
||||
decision["v2_cost_migrate_s"] = cost_migrate
|
||||
|
||||
if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s:
|
||||
decision["v2_pd_sep"] = True
|
||||
decision["v2_decision"] = "pd_sep"
|
||||
decision["v2_reason"] = (
|
||||
f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s "
|
||||
f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
|
||||
)
|
||||
return chosen, chosen_idx, decision, (src, best_src_idx)
|
||||
|
||||
decision["v2_reason"] = (
|
||||
f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s "
|
||||
f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
|
||||
)
|
||||
return chosen, chosen_idx, decision, None
|
||||
|
||||
|
||||
def _extract_output_token_ids_from_sse(
|
||||
buffer: str,
|
||||
chunk: bytes,
|
||||
@@ -480,8 +683,15 @@ async def lifespan(app: FastAPI):
|
||||
combined_instances.append(InstanceState(url, bp))
|
||||
|
||||
# Bootstrap combined instances for offload (need engine_ids for KV transfer)
|
||||
if global_args.offload and bp_list:
|
||||
policy = getattr(global_args, 'policy', 'linear')
|
||||
needs_bootstrap = global_args.offload or policy == "unified_v2"
|
||||
if needs_bootstrap and bp_list:
|
||||
await init_prefill_bootstrap(combined_instances, app.state.ready)
|
||||
elif needs_bootstrap and not bp_list:
|
||||
raise RuntimeError(
|
||||
f"--policy {policy} requires --bootstrap-ports for KV transfer; "
|
||||
"got empty bootstrap list."
|
||||
)
|
||||
else:
|
||||
app.state.ready.set()
|
||||
|
||||
@@ -623,6 +833,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
pre_decision_workers = snapshot_workers(
|
||||
combined_instances, token_ids, input_length)
|
||||
|
||||
pd_sep_v2: tuple[InstanceState, int] | None = None
|
||||
if policy == "lmetric":
|
||||
chosen, best_idx = pick_instance_lmetric(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
@@ -642,6 +853,13 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
breakdown.update(decision)
|
||||
if session_id:
|
||||
session_affinity_combined[session_id] = best_idx
|
||||
elif policy == "unified_v2":
|
||||
chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
session_affinity_combined)
|
||||
breakdown.update(decision)
|
||||
if session_id:
|
||||
session_affinity_combined[session_id] = best_idx
|
||||
else: # linear (default)
|
||||
chosen, best_idx = pick_instance(
|
||||
combined_instances, token_ids, session_id, input_length,
|
||||
@@ -653,7 +871,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
breakdown.update({
|
||||
"cache_hit": cache_hit,
|
||||
"estimated_new_tokens": estimated_new,
|
||||
"route_class": "LOCAL",
|
||||
"route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2",
|
||||
"routed_to": chosen.url,
|
||||
"chosen_idx": best_idx,
|
||||
"candidate_scores": pre_decision_workers,
|
||||
@@ -667,14 +885,152 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
"session_id": session_id,
|
||||
"policy": policy,
|
||||
"chosen_idx": best_idx,
|
||||
"v2_pd_sep": pd_sep_v2 is not None,
|
||||
"workers": pre_decision_workers,
|
||||
})
|
||||
|
||||
if pd_sep_v2 is not None:
|
||||
src_inst, src_idx = pd_sep_v2
|
||||
breakdown["v2_src_url"] = src_inst.url
|
||||
breakdown["v2_src_idx"] = src_idx
|
||||
return await _handle_combined_pd_sep_v2(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
src_inst, chosen, breakdown,
|
||||
request_id=request_id)
|
||||
|
||||
return await _handle_local_request(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
chosen, estimated_new, breakdown)
|
||||
|
||||
|
||||
async def _handle_combined_pd_sep_v2(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
src: InstanceState, dst: InstanceState, breakdown: dict,
|
||||
*, request_id: str,
|
||||
):
|
||||
"""Per-request PD-sep among combined instances (unified_v2 path).
|
||||
|
||||
src does cached prefill (max_tokens=1) and ships KV to dst via
|
||||
Mooncake; dst pulls KV and decodes. Both instances must run in
|
||||
kv_role=kv_both with bootstrap server enabled.
|
||||
|
||||
Patch 6.6: the dst streaming call uses a per-chunk read timeout
|
||||
of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails
|
||||
the request instead of hanging for 600 s.
|
||||
"""
|
||||
if src.bootstrap_port is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=(
|
||||
"unified_v2 PD-sep triggered but src instance "
|
||||
f"{src.url} has no bootstrap_port; launch with "
|
||||
"kv_role=kv_both and pass --bootstrap-ports"
|
||||
),
|
||||
)
|
||||
|
||||
# Reserve load on both endpoints.
|
||||
src.ongoing_tokens += input_length
|
||||
src.num_requests += 1
|
||||
dst.ongoing_tokens += input_length
|
||||
dst.num_requests += 1
|
||||
src_load_held = True
|
||||
dst_load_held = True
|
||||
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
prefill_data["stream"] = False
|
||||
prefill_data["max_tokens"] = 1
|
||||
prefill_data["min_tokens"] = 1
|
||||
prefill_data.pop("max_completion_tokens", None)
|
||||
prefill_data.pop("stream_options", None)
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
|
||||
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||
breakdown["t_prefill_sent_unix"] = _time.time()
|
||||
try:
|
||||
resp = await src.client.post(api, json=prefill_data, headers=p_headers)
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["t_prefill_done_unix"] = _time.time()
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
src.record_prefix(token_ids)
|
||||
except Exception as e:
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["t_prefill_done_unix"] = _time.time()
|
||||
breakdown["prefill_error"] = True
|
||||
breakdown["error_detail"] = repr(e)[:300]
|
||||
_breakdown_log.append(breakdown)
|
||||
# Release only dst here; the finally block below releases src (releasing both here would decrement src twice).
|
||||
dst.ongoing_tokens -= input_length
|
||||
dst.num_requests -= 1
|
||||
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
|
||||
finally:
|
||||
if src_load_held:
|
||||
src.ongoing_tokens -= input_length
|
||||
src.num_requests -= 1
|
||||
src_load_held = False
|
||||
|
||||
parsed = urllib.parse.urlparse(str(src.client.base_url))
|
||||
bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
|
||||
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": src.engine_id.get(0, ""),
|
||||
"transfer_id": f"xfer-{request_id}",
|
||||
}
|
||||
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
breakdown["t_decode_sent_unix"] = _time.time()
|
||||
|
||||
xfer_timeout = httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=SETTINGS.pd_sep_xfer_timeout_s,
|
||||
write=10.0,
|
||||
pool=10.0,
|
||||
)
|
||||
|
||||
async def generate():
|
||||
nonlocal dst_load_held
|
||||
first_token = True
|
||||
sse_buffer = ""
|
||||
output_token_ids: list[int] = []
|
||||
try:
|
||||
async with dst.client.stream(
|
||||
"POST", api, json=decode_data, headers=headers,
|
||||
timeout=xfer_timeout,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
|
||||
sse_buffer, chunk)
|
||||
output_token_ids.extend(new_output_ids)
|
||||
if first_token:
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
breakdown["t_first_token_unix"] = _time.time()
|
||||
first_token = False
|
||||
yield chunk
|
||||
dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
|
||||
finally:
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
breakdown["t_done_unix"] = _time.time()
|
||||
if dst_load_held:
|
||||
dst.ongoing_tokens -= input_length
|
||||
dst.num_requests -= 1
|
||||
dst_load_held = False
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
session_id, headers):
|
||||
"""PD-Sep mode with per-stage breakdown profiling."""
|
||||
@@ -849,11 +1205,15 @@ def parse_args():
|
||||
p.add_argument("--bootstrap-ports", type=str, default="",
|
||||
help="Comma-separated bootstrap ports for combined instances (for offload mode and --policy unified_v2)")
|
||||
p.add_argument("--policy", type=str, default="linear",
|
||||
choices=["linear", "lmetric", "load_only", "sticky", "unified"],
|
||||
choices=["linear", "lmetric", "load_only", "sticky",
|
||||
"unified", "unified_v2"],
|
||||
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
|
||||
"load_only (B3 control: pure min-num_requests), "
|
||||
"sticky (B3 control: hard session affinity), "
|
||||
"or unified (hybrid affinity + LMetric fallback)")
|
||||
"unified (hybrid affinity + LMetric fallback), "
|
||||
"or unified_v2 (unified + selective per-request PD-sep "
|
||||
"via Mooncake; requires --bootstrap-ports and "
|
||||
"kv_role=kv_both vLLM launch)")
|
||||
p.add_argument("--overload-factor", type=float, default=2.0,
|
||||
help="Break session affinity when instance load > factor * avg")
|
||||
# The four flags below are accepted for bench.sh backward compatibility but
|
||||
|
||||
@@ -343,6 +343,104 @@ def test_pick_instance_sticky_subsequent_never_breaks(proxy):
|
||||
assert idx == 0, "sticky must stay even when pinned instance is saturated"
|
||||
|
||||
|
||||
def test_unified_v2_falls_through_when_new_tokens_small(proxy):
|
||||
"""If post-cache new tokens < threshold, v2 should not PD-sep."""
|
||||
insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
|
||||
# Tiny prompt: 2 blocks = 1024 tokens. Below 16k threshold.
|
||||
tokens = [1] * (proxy.BLOCK_SIZE * 2)
|
||||
chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
|
||||
insts, tokens, None, len(tokens), {})
|
||||
assert pd_sep is None
|
||||
assert decision["v2_decision"] == "local"
|
||||
assert "below_threshold" in decision["v2_reason"]
|
||||
|
||||
|
||||
def _setup_v2_scene(proxy, *, chosen_decodes: int, src_cache_blocks: int):
|
||||
"""Build a 4-instance scene where inst[0] wins LMetric and inst[2]
|
||||
holds optional cache. All instances have non-zero num_requests so
|
||||
LMetric's bs=0 tie-break doesn't pick an empty instance.
|
||||
|
||||
Returns (insts, prefix_tokens).
|
||||
"""
|
||||
insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
|
||||
block_size = proxy.BLOCK_SIZE
|
||||
prefix = []
|
||||
for b in range(128):
|
||||
prefix.extend([2000 + b] * block_size) # 128 × 512 = 65536 tokens
|
||||
|
||||
if src_cache_blocks > 0:
|
||||
insts[2].record_prefix(prefix[: src_cache_blocks * block_size])
|
||||
|
||||
# Make inst[0] have the smallest LMetric score so chosen = inst[0].
|
||||
insts[0].num_requests = max(chosen_decodes, 1)
|
||||
insts[0].pending_prefill_tokens = 0
|
||||
insts[0].ongoing_decode_tokens = chosen_decodes * 5000
|
||||
insts[1].num_requests = 20
|
||||
insts[1].pending_prefill_tokens = 200_000
|
||||
insts[2].num_requests = 30 # src is busy enough not to win LMetric
|
||||
insts[2].pending_prefill_tokens = 200_000
|
||||
insts[3].num_requests = 20
|
||||
insts[3].pending_prefill_tokens = 200_000
|
||||
return insts, prefix
|
||||
|
||||
|
||||
def test_unified_v2_falls_through_when_no_alt_cache(proxy):
|
||||
"""No other instance has meaningful cache → no PD-sep."""
|
||||
insts, prefix = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=0)
|
||||
chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
|
||||
insts, prefix, None, len(prefix), {})
|
||||
assert idx == 0
|
||||
assert pd_sep is None
|
||||
assert "src_cache" in decision["v2_reason"]
|
||||
|
||||
|
||||
def test_unified_v2_triggers_when_src_has_meaningful_cache_and_chosen_has_decodes(proxy):
|
||||
"""Classic v2-win case: big prefill, chosen has decodes, alt has cache."""
|
||||
insts, prefix = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=128)
|
||||
chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
|
||||
insts, prefix, None, len(prefix), {})
|
||||
assert idx == 0
|
||||
assert pd_sep is not None, (
|
||||
f"expected PD-sep, got reason={decision['v2_reason']}"
|
||||
)
|
||||
src, src_idx = pd_sep
|
||||
assert src_idx == 2
|
||||
assert decision["v2_decision"] == "pd_sep"
|
||||
assert decision["v2_src_cache_hit"] >= 60000
|
||||
|
||||
|
||||
def test_unified_v2_falls_through_when_chosen_has_no_decodes(proxy):
|
||||
"""No decodes on chosen → no benefit from PD-sep."""
|
||||
insts, prefix = _setup_v2_scene(proxy, chosen_decodes=0, src_cache_blocks=128)
|
||||
chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
|
||||
insts, prefix, None, len(prefix), {})
|
||||
assert pd_sep is None
|
||||
assert "few_decodes" in decision["v2_reason"]
|
||||
|
||||
|
||||
def test_estimate_transfer_cost_is_calibrated_function(proxy):
|
||||
"""RDMA transfer cost grows with bytes, has a non-zero floor."""
|
||||
cost_empty = proxy.estimate_transfer_cost(0)
|
||||
cost_1gb = proxy.estimate_transfer_cost(1024 ** 3)
|
||||
cost_10gb = proxy.estimate_transfer_cost(10 * 1024 ** 3)
|
||||
assert cost_empty >= 0.2, "should have non-zero floor"
|
||||
assert cost_1gb > cost_empty
|
||||
assert cost_10gb > cost_1gb
|
||||
# 10 GB should be roughly 0.3 + 10/2.7 ≈ 4.0 s
|
||||
assert 3.0 < cost_10gb < 5.0
|
||||
|
||||
|
||||
def test_estimate_same_worker_interference_grows_with_size(proxy):
|
||||
"""Interference cost is monotone in new_tokens up to the saturation regime."""
|
||||
c1 = proxy.estimate_same_worker_interference_s(2000, num_decodes=4)
|
||||
c2 = proxy.estimate_same_worker_interference_s(8000, num_decodes=4)
|
||||
c3 = proxy.estimate_same_worker_interference_s(20000, num_decodes=4)
|
||||
c4 = proxy.estimate_same_worker_interference_s(32000, num_decodes=4)
|
||||
assert c1 < c2 < c3 < c4
|
||||
# Zero decodes -> zero cost regardless of size
|
||||
assert proxy.estimate_same_worker_interference_s(32000, num_decodes=0) == 0.0
|
||||
|
||||
|
||||
def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
|
||||
"""M2: tweaking SETTINGS.heavy_threshold changes the P-offload penalty."""
|
||||
inst = proxy.InstanceState("http://x")
|
||||
|
||||