unified_v2: selective per-request PD-sep via Mooncake (E3+E4)

Adds a sixth routing policy, --policy unified_v2, that wraps the
existing unified hybrid picker with a selective PD-sep branch.
When all of the following hold, a request is split into prefill-on-src
and decode-on-chosen, with the KV moved by a Mooncake kv_role=kv_both
transfer:
  1. new_local = input_length - chosen.cache_hit >= 16k
     (B2 microbench shows same-worker TPOT idx >= 3x from this size up)
  2. chosen has live decodes worth protecting (>= 2 in-flight requests,
     or ongoing_decode_tokens >= heavy_threshold)
  3. some other instance holds materially more cache for this prefix
     (>= 8k tokens, and >= 4k more than chosen)
  4. cost(src_interference + RDMA xfer) + 0.2s margin < cost(chosen_interference)
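
The four conditions read naturally as a short-circuit chain. The sketch below is an illustration only, not the code in this commit: `Candidate` and `should_pd_sep` are hypothetical names, gate 2 is simplified to the in-flight request count, and both costs are passed in precomputed so the gate order stays visible. Thresholds mirror the defaults quoted above.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    """Hypothetical per-request view of one instance."""
    cache_hit: int      # tokens of this prompt already cached on the instance
    num_requests: int   # in-flight requests, used as a stand-in for live decodes


def should_pd_sep(input_length: int, chosen: Candidate, src: Candidate,
                  cost_local_s: float, cost_migrate_s: float,
                  min_new_tokens: int = 16_000, min_decodes: int = 2,
                  min_src_cache: int = 8_000, min_extra_cache: int = 4_000,
                  margin_s: float = 0.2) -> tuple[bool, str]:
    """Return (split?, reason). Gates run in order and short-circuit."""
    new_local = max(0, input_length - chosen.cache_hit)
    if new_local < min_new_tokens:
        return False, "new_local_below_threshold"
    if chosen.num_requests < min_decodes:
        return False, "chosen_few_decodes"
    if src.cache_hit < min_src_cache:
        return False, "src_cache_below_threshold"
    if src.cache_hit - chosen.cache_hit < min_extra_cache:
        return False, "src_not_meaningfully_more_cache"
    # Only now is the cost comparison consulted.
    if cost_local_s > cost_migrate_s + margin_s:
        return True, "pd_sep"
    return False, "local_cheaper"
```

Returning a reason string alongside the boolean matches the way the commit records `v2_reason` for every request, which makes it possible to count afterwards which gate rejected most candidates.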

The cost model is the audit-blessed shape from E1's post-mortem:
  - gate on new_tokens (post-cache), NOT input_length (the old PUSH gate)
  - bind to a single transfer mechanism (kv_both peer-to-peer pull)
  - realistic RDMA cost as a function of bytes: 0.3s base +
    bytes / 2.7 GB/s (calibrated against contention_16s_elastic p50)
  - both source and target decode counts considered
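
To make the byte-based RDMA term concrete, here is a small worked sketch. It assumes the defaults introduced in this commit (98304 KV bytes per token, 0.3 s base, 2.7 GB/s effective with GB taken as 1024^3 bytes, as in the diff); `transfer_cost_s` is a hypothetical helper that takes tokens rather than bytes.

```python
GIB = 1024 ** 3

KV_BYTES_PER_TOKEN = 98_304       # default kv_bytes_per_token
RDMA_BASE_OVERHEAD_S = 0.3        # default rdma_base_overhead_s
RDMA_EFFECTIVE_GB_PER_S = 2.7     # default rdma_effective_gb_per_s


def transfer_cost_s(cached_tokens: int) -> float:
    """Fixed handshake cost plus a bandwidth term proportional to KV bytes."""
    transfer_bytes = cached_tokens * KV_BYTES_PER_TOKEN
    return RDMA_BASE_OVERHEAD_S + transfer_bytes / (RDMA_EFFECTIVE_GB_PER_S * GIB)


if __name__ == "__main__":
    for tokens in (8_192, 32_768, 65_536):
        print(f"{tokens:>6} tokens -> {transfer_cost_s(tokens):.2f} s")
```

With these defaults, 32,768 cached tokens are exactly 3 GiB of KV, so the estimate is 0.3 + 3 / 2.7, about 1.41 s. The old constant of 0.1 s would have understated that by more than an order of magnitude.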

E2 mechanism-level patches are not yet applied (this commit is
policy-only). Patches 6.2 / 6.3 / 6.5 remain on the table. Patch 6.6
(per-request xfer timeout, 60s default) is implemented on the proxy side
as an httpx per-chunk read timeout on the dst streaming call, so a stuck
KV transfer fails the request instead of hanging for 600s.
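
A per-chunk read timeout bounds the gap between consecutive chunks, not the total stream duration. The sketch below illustrates that semantics with plain asyncio as a stand-in. It is not the httpx mechanism the proxy uses, and `with_chunk_timeout`, `fake_stream`, and `collect` are hypothetical names.

```python
import asyncio
from typing import AsyncIterator


async def with_chunk_timeout(chunks: AsyncIterator[bytes],
                             per_chunk_timeout_s: float) -> AsyncIterator[bytes]:
    """Re-yield chunks; raise asyncio.TimeoutError if any single read stalls."""
    it = chunks.__aiter__()
    while True:
        try:
            # Each read gets a fresh deadline, so a long but steady stream is fine.
            chunk = await asyncio.wait_for(it.__anext__(), per_chunk_timeout_s)
        except StopAsyncIteration:
            return
        yield chunk


async def fake_stream(gaps_s: list[float]) -> AsyncIterator[bytes]:
    """Hypothetical upstream: waits gaps_s[i] seconds before emitting chunk i."""
    for i, gap in enumerate(gaps_s):
        await asyncio.sleep(gap)
        yield f"chunk-{i}".encode()


async def collect(gaps_s: list[float], per_chunk_timeout_s: float) -> list[bytes]:
    return [c async for c in with_chunk_timeout(fake_stream(gaps_s), per_chunk_timeout_s)]
```

In the PD-sep path the first chunk from dst only arrives after the KV pull completes, so a stalled transfer shows up as a stalled first read. That is why a read timeout is enough to turn a hang into a failed request.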

cache_aware_proxy.py:
  - Settings: kv_bytes_per_token, prefill_throughput_kv_both,
    rdma_base_overhead_s, rdma_effective_gb_per_s, pd_sep_* gating knobs
  - estimate_transfer_cost(bytes) replaces the constant rdma_overhead_s
  - estimate_same_worker_interference_s(new_tokens, num_decodes) reads off
    the B2 penalty curve in 4 bins
  - pick_instance_unified_v2: inherits unified, returns extra
    (src_inst, src_idx) tuple when PD-sep wins the cost compare
  - _handle_combined_pd_sep_v2: prefill on src (do_remote_decode=True,
    max_tokens=1), Mooncake xfer, decode-stream on dst with httpx
    Timeout(read=pd_sep_xfer_timeout_s)
  - --policy unified_v2 added to argparse choices
  - lifespan auto-runs init_prefill_bootstrap when policy is unified_v2
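
How the two estimators combine is easiest to see with numbers. The sketch below re-derives the comparison under the defaults from this commit (4000 tok/s prefill in kv_both mode, the four penalty bins, 0.2 s margin). `interference_s`, `transfer_s`, and `compare` are hypothetical names, and the request sizes are made up for illustration.

```python
GIB = 1024 ** 3
PREFILL_TOK_PER_S_KV_BOTH = 4000.0   # default prefill_throughput_kv_both
KV_BYTES_PER_TOKEN = 98_304          # default kv_bytes_per_token


def interference_s(new_tokens: int, num_decodes: int) -> float:
    """Prefill duration x penalty bin x number of co-located decodes."""
    if num_decodes <= 0:
        return 0.0
    if new_tokens < 4_000:
        penalty = 0.2
    elif new_tokens < 16_000:
        penalty = 0.5
    elif new_tokens < 32_000:
        penalty = 0.8
    else:
        penalty = 0.95
    return new_tokens / PREFILL_TOK_PER_S_KV_BOTH * penalty * num_decodes


def transfer_s(cached_tokens: int) -> float:
    """0.3 s base plus KV bytes over 2.7 GB/s effective bandwidth."""
    return 0.3 + cached_tokens * KV_BYTES_PER_TOKEN / (2.7 * GIB)


def compare(input_length: int, chosen_hit: int, chosen_decodes: int,
            src_hit: int, src_decodes: int, margin_s: float = 0.2):
    """Return (cost_local, cost_migrate, split?) for one request."""
    cost_local = interference_s(input_length - chosen_hit, chosen_decodes)
    cost_migrate = interference_s(input_length - src_hit, src_decodes) + transfer_s(src_hit)
    return cost_local, cost_migrate, cost_local > cost_migrate + margin_s


if __name__ == "__main__":
    # 64k-token prompt, chosen has no cache and 4 decodes, src caches half.
    print(compare(65_536, 0, 4, 32_768, 1))    # quiet src
    print(compare(65_536, 0, 4, 32_768, 30))   # busy src
```

Because the interference term scales with the decode count on whichever instance runs the prefill, a src that is itself busy can lose the comparison even when it holds far more cache. In the second call the 30 decodes on src outweigh the cache advantage and the request stays local.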

b3_isolated_policy.sh:
  - ENABLE_KV_BOTH env var, auto-set when POLICY=unified_v2, threads
    kv_role=kv_both + VLLM_MOONCAKE_BOOTSTRAP_PORT to vllm and
    --bootstrap-ports to the proxy

Tests: 8 new unit tests cover the gating predicates and the cost
estimators; all 32 proxy tests still pass.

Refs: E1 (PUSH post-mortem) + E2 (Mooncake audit) reports.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@@ -19,11 +19,22 @@ BASE_PORT="${BASE_PORT:-8000}"
|
|||||||
GPU_INDICES="${GPU_INDICES:-0 1 2 3 4 5 6 7}"
|
GPU_INDICES="${GPU_INDICES:-0 1 2 3 4 5 6 7}"
|
||||||
EXTRA_VLLM_ARGS="${EXTRA_VLLM_ARGS:---enable-prompt-tokens-details}"
|
EXTRA_VLLM_ARGS="${EXTRA_VLLM_ARGS:---enable-prompt-tokens-details}"
|
||||||
N_INSTANCES=$(echo $GPU_INDICES | wc -w)
|
N_INSTANCES=$(echo $GPU_INDICES | wc -w)
|
||||||
|
# When ENABLE_KV_BOTH=1, vLLM launches with the Mooncake KV connector in
|
||||||
|
# kv_both role and the proxy is given bootstrap ports. This is required
|
||||||
|
# for --policy unified_v2 (per-request PD-sep) but disabled by default
|
||||||
|
# because it adds always-on KV-transfer overhead even when not triggered.
|
||||||
|
ENABLE_KV_BOTH="${ENABLE_KV_BOTH:-0}"
|
||||||
|
BOOTSTRAP_BASE_PORT="${BOOTSTRAP_BASE_PORT:-8998}"
|
||||||
|
|
||||||
POLICY="${1:?usage: $0 <policy> <trace> <rundir>}"
|
POLICY="${1:?usage: $0 <policy> <trace> <rundir>}"
|
||||||
TRACE="${2:?usage: $0 <policy> <trace> <rundir>}"
|
TRACE="${2:?usage: $0 <policy> <trace> <rundir>}"
|
||||||
RUNDIR="${3:?usage: $0 <policy> <trace> <rundir>}"
|
RUNDIR="${3:?usage: $0 <policy> <trace> <rundir>}"
|
||||||
|
|
||||||
|
# Auto-enable kv_both when the policy requires it.
|
||||||
|
if [ "$POLICY" = "unified_v2" ]; then
|
||||||
|
ENABLE_KV_BOTH=1
|
||||||
|
fi
|
||||||
|
|
||||||
mkdir -p "$RUNDIR/engine_state" "$RUNDIR/logs"
|
mkdir -p "$RUNDIR/engine_state" "$RUNDIR/logs"
|
||||||
echo "[isolated] policy=$POLICY trace=$(basename $TRACE) rundir=$RUNDIR"
|
echo "[isolated] policy=$POLICY trace=$(basename $TRACE) rundir=$RUNDIR"
|
||||||
|
|
||||||
@@ -38,23 +49,46 @@ trap cleanup EXIT
|
|||||||
# Hard reset first
|
# Hard reset first
|
||||||
cleanup
|
cleanup
|
||||||
|
|
||||||
echo "[isolated] launching $N_INSTANCES vLLM on GPUs $GPU_INDICES ..."
|
echo "[isolated] launching $N_INSTANCES vLLM on GPUs $GPU_INDICES ENABLE_KV_BOTH=$ENABLE_KV_BOTH ..."
|
||||||
i=0
|
i=0
|
||||||
|
kv_both_extra=""
|
||||||
|
if [ "$ENABLE_KV_BOTH" = "1" ]; then
|
||||||
|
kv_both_extra="--kv-transfer-config {\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_both\"}"
|
||||||
|
fi
|
||||||
for gpu in $GPU_INDICES; do
|
for gpu in $GPU_INDICES; do
|
||||||
port=$((BASE_PORT + i))
|
port=$((BASE_PORT + i))
|
||||||
master=$((29500 + i))
|
master=$((29500 + i))
|
||||||
AGENTIC_STEP_LOG_PATH="$RUNDIR/engine_state/engine_${i}.jsonl" \
|
bp=$((BOOTSTRAP_BASE_PORT + i))
|
||||||
AGENTIC_WORKER_ID="engine_${i}" \
|
if [ "$ENABLE_KV_BOTH" = "1" ]; then
|
||||||
CUDA_VISIBLE_DEVICES=$gpu \
|
PYTHONHASHSEED=42 \
|
||||||
MASTER_PORT=$master \
|
VLLM_MOONCAKE_BOOTSTRAP_PORT=$bp \
|
||||||
nohup "$VENV/vllm" serve "$MODEL" \
|
AGENTIC_STEP_LOG_PATH="$RUNDIR/engine_state/engine_${i}.jsonl" \
|
||||||
--host 0.0.0.0 --port "$port" \
|
AGENTIC_WORKER_ID="engine_${i}" \
|
||||||
--tensor-parallel-size 1 \
|
CUDA_VISIBLE_DEVICES=$gpu \
|
||||||
--trust-remote-code --enable-prefix-caching \
|
MASTER_PORT=$master \
|
||||||
--dtype auto --gpu-memory-utilization 0.9 \
|
nohup "$VENV/vllm" serve "$MODEL" \
|
||||||
--max-model-len 200000 \
|
--host 0.0.0.0 --port "$port" \
|
||||||
$EXTRA_VLLM_ARGS \
|
--tensor-parallel-size 1 \
|
||||||
> "$RUNDIR/logs/vllm_inst_${i}_gpu${gpu}.log" 2>&1 &
|
--trust-remote-code --enable-prefix-caching \
|
||||||
|
--dtype auto --gpu-memory-utilization 0.9 \
|
||||||
|
--max-model-len 200000 \
|
||||||
|
--kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_both"}' \
|
||||||
|
$EXTRA_VLLM_ARGS \
|
||||||
|
> "$RUNDIR/logs/vllm_inst_${i}_gpu${gpu}.log" 2>&1 &
|
||||||
|
else
|
||||||
|
AGENTIC_STEP_LOG_PATH="$RUNDIR/engine_state/engine_${i}.jsonl" \
|
||||||
|
AGENTIC_WORKER_ID="engine_${i}" \
|
||||||
|
CUDA_VISIBLE_DEVICES=$gpu \
|
||||||
|
MASTER_PORT=$master \
|
||||||
|
nohup "$VENV/vllm" serve "$MODEL" \
|
||||||
|
--host 0.0.0.0 --port "$port" \
|
||||||
|
--tensor-parallel-size 1 \
|
||||||
|
--trust-remote-code --enable-prefix-caching \
|
||||||
|
--dtype auto --gpu-memory-utilization 0.9 \
|
||||||
|
--max-model-len 200000 \
|
||||||
|
$EXTRA_VLLM_ARGS \
|
||||||
|
> "$RUNDIR/logs/vllm_inst_${i}_gpu${gpu}.log" 2>&1 &
|
||||||
|
fi
|
||||||
disown
|
disown
|
||||||
sleep 2
|
sleep 2
|
||||||
i=$((i + 1))
|
i=$((i + 1))
|
||||||
@@ -80,10 +114,23 @@ combined_args=""
|
|||||||
for i in $(seq 0 $((N_INSTANCES - 1))); do
|
for i in $(seq 0 $((N_INSTANCES - 1))); do
|
||||||
combined_args="$combined_args http://127.0.0.1:$((BASE_PORT + i))"
|
combined_args="$combined_args http://127.0.0.1:$((BASE_PORT + i))"
|
||||||
done
|
done
|
||||||
|
proxy_extra=""
|
||||||
|
if [ "$ENABLE_KV_BOTH" = "1" ]; then
|
||||||
|
bp_list=""
|
||||||
|
for i in $(seq 0 $((N_INSTANCES - 1))); do
|
||||||
|
if [ -z "$bp_list" ]; then
|
||||||
|
bp_list="$((BOOTSTRAP_BASE_PORT + i))"
|
||||||
|
else
|
||||||
|
bp_list="$bp_list,$((BOOTSTRAP_BASE_PORT + i))"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
proxy_extra="--bootstrap-ports $bp_list"
|
||||||
|
fi
|
||||||
nohup "$VENV/python" "$ROOT/scripts/cache_aware_proxy.py" \
|
nohup "$VENV/python" "$ROOT/scripts/cache_aware_proxy.py" \
|
||||||
--port "$PROXY_PORT" \
|
--port "$PROXY_PORT" \
|
||||||
--combined $combined_args \
|
--combined $combined_args \
|
||||||
--policy "$POLICY" \
|
--policy "$POLICY" \
|
||||||
|
$proxy_extra \
|
||||||
> "$RUNDIR/proxy.log" 2>&1 &
|
> "$RUNDIR/proxy.log" 2>&1 &
|
||||||
disown
|
disown
|
||||||
tries=0
|
tries=0
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class Settings:
|
|||||||
(e.g. by tests/) and __main__ does not run.
|
(e.g. by tests/) and __main__ does not run.
|
||||||
"""
|
"""
|
||||||
prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20)
|
prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20)
|
||||||
rdma_overhead_s: float = 0.1 # RDMA PUSH overhead (~10-50ms measured)
|
rdma_overhead_s: float = 0.1 # legacy floor; v2 uses estimate_transfer_cost
|
||||||
cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
|
cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
|
||||||
heavy_threshold: int = 20000
|
heavy_threshold: int = 20000
|
||||||
overload_factor: float = 2.0
|
overload_factor: float = 2.0
|
||||||
@@ -52,10 +52,91 @@ class Settings:
|
|||||||
cache_gate_ratio: float = 0.0
|
cache_gate_ratio: float = 0.0
|
||||||
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20)
|
||||||
|
|
||||||
|
# --- Patch 6.9: cost-model calibration for unified_v2 ---
|
||||||
|
# Prefill throughput when the engine runs in kv_both mode. Lower than the
|
||||||
|
# plain-mode 7000 tok/s (prefill_throughput) because kv_both adds
|
||||||
|
# always-on overhead (REPORT §3.8 documents ~+16% TPOT vs plain).
|
||||||
|
prefill_throughput_kv_both: float = 4000.0
|
||||||
|
# Calibrated RDMA transfer cost: base + bandwidth term.
|
||||||
|
# Floor from isolated test ≈ 0.3 s (handshake + scheduler step).
|
||||||
|
# Bandwidth term reflects realized effective throughput, not
|
||||||
|
# theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s
|
||||||
|
# effective on the contended kv_both path. v2 uses this lookup
|
||||||
|
# rather than the constant rdma_overhead_s.
|
||||||
|
rdma_base_overhead_s: float = 0.3
|
||||||
|
rdma_effective_gb_per_s: float = 2.7
|
||||||
|
|
||||||
|
# Qwen3-Coder-30B-A3B, bf16: 2 (K and V) × 48 layers × 4 KV heads
|
||||||
|
# × 128 head_dim × 2 bytes = 98304 bytes per token.
|
||||||
|
kv_bytes_per_token: int = 98304
|
||||||
|
|
||||||
|
# --- unified_v2 gating knobs ---
|
||||||
|
pd_sep_min_new_tokens: int = 16000 # B2 idx ≥ 3× starts here
|
||||||
|
pd_sep_min_decodes_protected: int = 2 # require chosen to have live decodes to protect
|
||||||
|
pd_sep_min_src_cache_tokens: int = 8000 # require non-trivial cache to transfer
|
||||||
|
pd_sep_min_extra_cache_tokens: int = 4000 # src must have meaningfully more cache than chosen
|
||||||
|
pd_sep_margin_s: float = 0.2 # require cost gap > 0.2 s before migrating
|
||||||
|
# Patch 6.6: KV-xfer timeout, applied as a per-chunk read timeout on the dst stream (proxy side).
|
||||||
|
pd_sep_xfer_timeout_s: float = 60.0
|
||||||
|
|
||||||
|
|
||||||
SETTINGS = Settings()
|
SETTINGS = Settings()
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_transfer_cost(transfer_bytes: int) -> float:
|
||||||
|
"""Calibrated RDMA transfer cost as a function of bytes.
|
||||||
|
|
||||||
|
Replaces the legacy constant rdma_overhead_s. Calibration sources:
|
||||||
|
- Floor: isolated-test ~0.3 s for a few-block PUSH (scripts/test_direct_read.py)
|
||||||
|
- Bandwidth term: outputs/contention_16s_elastic/breakdown.json shows
|
||||||
|
decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, giving
|
||||||
|
~2.7 GB/s effective on the contended kv_both path.
|
||||||
|
|
||||||
|
The p90 in that same run is 6.7 s (D-side block reservation +
|
||||||
|
scheduler step delays). v2's cost model uses the *median* — being
|
||||||
|
too pessimistic would suppress all PD-sep triggers. The risk of
|
||||||
|
underestimation is mitigated by the pd_sep_margin_s safety factor.
|
||||||
|
"""
|
||||||
|
base = SETTINGS.rdma_base_overhead_s
|
||||||
|
bw_term = transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024 ** 3)
|
||||||
|
return base + bw_term
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_same_worker_interference_s(
|
||||||
|
new_tokens: int,
|
||||||
|
num_decodes: int,
|
||||||
|
) -> float:
|
||||||
|
"""Estimated additional latency on `num_decodes` co-located decodes
|
||||||
|
when a `new_tokens`-token prefill runs on the same worker.
|
||||||
|
|
||||||
|
Derived from B2 microbench (analysis/characterization/window_1_results.md):
|
||||||
|
same-worker prefill of size N steals decode capacity for the
|
||||||
|
prefill's duration. The penalty factor is the fraction of decode
|
||||||
|
steps stolen during the prefill window.
|
||||||
|
|
||||||
|
For new_tokens < 4k: ~0.2 (chunked prefill leaves room)
|
||||||
|
For 4k <= new_tokens < 16k: ~0.5 (mid-regime; B2 TPOT idx reaches 3.4× at 16k)
|
||||||
|
For 16k <= new_tokens < 32k: ~0.8 (B2 TPOT idx peaks at 7.9× at 32k)
|
||||||
|
For new_tokens >= 32k: ~0.95 (B2 TTFT regime: decodes are nearly fully blocked)
|
||||||
|
|
||||||
|
The cost in seconds is roughly: prefill_duration × penalty × n_decodes,
|
||||||
|
because each affected decode loses ~penalty fraction of its capacity
|
||||||
|
during the prefill window.
|
||||||
|
"""
|
||||||
|
if num_decodes <= 0:
|
||||||
|
return 0.0
|
||||||
|
prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both
|
||||||
|
if new_tokens < 4000:
|
||||||
|
penalty = 0.2
|
||||||
|
elif new_tokens < 16000:
|
||||||
|
penalty = 0.5
|
||||||
|
elif new_tokens < 32000:
|
||||||
|
penalty = 0.8
|
||||||
|
else:
|
||||||
|
penalty = 0.95
|
||||||
|
return prefill_dur_s * penalty * num_decodes
|
||||||
|
|
||||||
|
|
||||||
class InstanceState:
|
class InstanceState:
|
||||||
def __init__(self, url: str, bootstrap_port: int | None = None):
|
def __init__(self, url: str, bootstrap_port: int | None = None):
|
||||||
self.url = url
|
self.url = url
|
||||||
@@ -326,6 +407,128 @@ def pick_instance_unified_hybrid(
|
|||||||
return instances[chosen_idx], chosen_idx, decision
|
return instances[chosen_idx], chosen_idx, decision
|
||||||
|
|
||||||
|
|
||||||
|
def pick_instance_unified_v2(
|
||||||
|
instances: list[InstanceState],
|
||||||
|
token_ids: list[int] | None,
|
||||||
|
session_id: str | None,
|
||||||
|
input_length: int,
|
||||||
|
affinity: dict[str, int],
|
||||||
|
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
|
||||||
|
"""unified_v2 = unified hybrid + selective per-request PD-sep trigger.
|
||||||
|
|
||||||
|
Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`.
|
||||||
|
|
||||||
|
Stage 2 asks: is there another instance with materially more cache
|
||||||
|
for this request? If yes, would doing prefill on that instance and
|
||||||
|
transferring KV to `chosen` for decode be cheaper than just doing
|
||||||
|
everything on `chosen`?
|
||||||
|
|
||||||
|
The cost model compares two scenarios in seconds-of-decode-disruption:
|
||||||
|
|
||||||
|
local: same-worker prefill on chosen of (input - chosen.cache_hit)
|
||||||
|
tokens interferes with chosen.num_decodes co-located decodes.
|
||||||
|
|
||||||
|
pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens
|
||||||
|
(smaller, because src has more cache) interferes with
|
||||||
|
src.num_decodes co-located decodes, plus we pay RDMA
|
||||||
|
transfer of src.cache_hit blocks to chosen.
|
||||||
|
|
||||||
|
We migrate only when local cost > pd-sep cost + safety margin AND
|
||||||
|
a set of hard gates (size, cache, decodes) are met.
|
||||||
|
|
||||||
|
Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None
|
||||||
|
the handler should do local routing on `chosen`. When pd_sep is
|
||||||
|
(src_inst, src_idx) the handler should do prefill-on-src,
|
||||||
|
decode-on-chosen via Mooncake.
|
||||||
|
"""
|
||||||
|
chosen, chosen_idx, decision = pick_instance_unified_hybrid(
|
||||||
|
instances, token_ids, session_id, input_length, affinity)
|
||||||
|
|
||||||
|
decision["v2_pd_sep"] = False
|
||||||
|
decision["v2_decision"] = "local"
|
||||||
|
decision["v2_reason"] = None
|
||||||
|
|
||||||
|
if not token_ids:
|
||||||
|
decision["v2_reason"] = "no_token_ids"
|
||||||
|
return chosen, chosen_idx, decision, None
|
||||||
|
|
||||||
|
chosen_cache_hit = chosen.estimate_cache_hit(token_ids)
|
||||||
|
new_local = max(0, input_length - chosen_cache_hit)
|
||||||
|
|
||||||
|
# Hard gate 1: prefill must be large enough that interference
|
||||||
|
# outweighs the fixed RDMA setup cost.
|
||||||
|
if new_local < SETTINGS.pd_sep_min_new_tokens:
|
||||||
|
decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})"
|
||||||
|
return chosen, chosen_idx, decision, None
|
||||||
|
|
||||||
|
# Hard gate 2: chosen must have live decodes worth protecting.
|
||||||
|
if chosen.ongoing_decode_tokens // max(1, SETTINGS.heavy_threshold) < 1 \
|
||||||
|
and chosen.num_requests < SETTINGS.pd_sep_min_decodes_protected:
|
||||||
|
# Heuristic for "num_decodes": prefer num_requests as an upper
|
||||||
|
# bound since we don't track decode count separately at route time.
|
||||||
|
decision["v2_reason"] = f"chosen_few_decodes ({chosen.num_requests})"
|
||||||
|
return chosen, chosen_idx, decision, None
|
||||||
|
|
||||||
|
# Find best alternative cache source.
|
||||||
|
best_src_idx, best_src_hit = -1, 0
|
||||||
|
for i, inst in enumerate(instances):
|
||||||
|
if i == chosen_idx:
|
||||||
|
continue
|
||||||
|
h = inst.estimate_cache_hit(token_ids)
|
||||||
|
if h > best_src_hit:
|
||||||
|
best_src_idx, best_src_hit = i, h
|
||||||
|
|
||||||
|
# Hard gate 3: src must hold meaningful cache.
|
||||||
|
if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens:
|
||||||
|
decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})"
|
||||||
|
return chosen, chosen_idx, decision, None
|
||||||
|
|
||||||
|
# Hard gate 4: src must hold materially more cache than chosen.
|
||||||
|
if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens:
|
||||||
|
decision["v2_reason"] = (
|
||||||
|
f"src_not_meaningfully_more_cache "
|
||||||
|
f"(src={best_src_hit} chosen={chosen_cache_hit})"
|
||||||
|
)
|
||||||
|
return chosen, chosen_idx, decision, None
|
||||||
|
|
||||||
|
src = instances[best_src_idx]
|
||||||
|
new_src = max(0, input_length - best_src_hit)
|
||||||
|
|
||||||
|
# Cost-benefit in seconds-of-decode-disruption.
|
||||||
|
cost_local = estimate_same_worker_interference_s(
|
||||||
|
new_local, chosen.num_requests)
|
||||||
|
cost_src_interf = estimate_same_worker_interference_s(
|
||||||
|
new_src, src.num_requests)
|
||||||
|
transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token
|
||||||
|
cost_xfer = estimate_transfer_cost(transfer_bytes)
|
||||||
|
cost_migrate = cost_src_interf + cost_xfer
|
||||||
|
|
||||||
|
decision["v2_chosen_cache_hit"] = chosen_cache_hit
|
||||||
|
decision["v2_src_idx"] = best_src_idx
|
||||||
|
decision["v2_src_cache_hit"] = best_src_hit
|
||||||
|
decision["v2_new_local"] = new_local
|
||||||
|
decision["v2_new_src"] = new_src
|
||||||
|
decision["v2_cost_local_s"] = cost_local
|
||||||
|
decision["v2_cost_src_interf_s"] = cost_src_interf
|
||||||
|
decision["v2_cost_xfer_s"] = cost_xfer
|
||||||
|
decision["v2_cost_migrate_s"] = cost_migrate
|
||||||
|
|
||||||
|
if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s:
|
||||||
|
decision["v2_pd_sep"] = True
|
||||||
|
decision["v2_decision"] = "pd_sep"
|
||||||
|
decision["v2_reason"] = (
|
||||||
|
f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s "
|
||||||
|
f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
|
||||||
|
)
|
||||||
|
return chosen, chosen_idx, decision, (src, best_src_idx)
|
||||||
|
|
||||||
|
decision["v2_reason"] = (
|
||||||
|
f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s "
|
||||||
|
f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
|
||||||
|
)
|
||||||
|
return chosen, chosen_idx, decision, None
|
||||||
|
|
||||||
|
|
||||||
def _extract_output_token_ids_from_sse(
|
def _extract_output_token_ids_from_sse(
|
||||||
buffer: str,
|
buffer: str,
|
||||||
chunk: bytes,
|
chunk: bytes,
|
||||||
@@ -480,8 +683,15 @@ async def lifespan(app: FastAPI):
|
|||||||
combined_instances.append(InstanceState(url, bp))
|
combined_instances.append(InstanceState(url, bp))
|
||||||
|
|
||||||
# Bootstrap combined instances for offload (need engine_ids for KV transfer)
|
# Bootstrap combined instances for offload (need engine_ids for KV transfer)
|
||||||
if global_args.offload and bp_list:
|
policy = getattr(global_args, 'policy', 'linear')
|
||||||
|
needs_bootstrap = global_args.offload or policy == "unified_v2"
|
||||||
|
if needs_bootstrap and bp_list:
|
||||||
await init_prefill_bootstrap(combined_instances, app.state.ready)
|
await init_prefill_bootstrap(combined_instances, app.state.ready)
|
||||||
|
elif needs_bootstrap and not bp_list:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"--policy {policy} requires --bootstrap-ports for KV transfer; "
|
||||||
|
"got empty bootstrap list."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
app.state.ready.set()
|
app.state.ready.set()
|
||||||
|
|
||||||
@@ -623,6 +833,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
|||||||
pre_decision_workers = snapshot_workers(
|
pre_decision_workers = snapshot_workers(
|
||||||
combined_instances, token_ids, input_length)
|
combined_instances, token_ids, input_length)
|
||||||
|
|
||||||
|
pd_sep_v2: tuple[InstanceState, int] | None = None
|
||||||
if policy == "lmetric":
|
if policy == "lmetric":
|
||||||
chosen, best_idx = pick_instance_lmetric(
|
chosen, best_idx = pick_instance_lmetric(
|
||||||
combined_instances, token_ids, session_id, input_length,
|
combined_instances, token_ids, session_id, input_length,
|
||||||
@@ -642,6 +853,13 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
|||||||
breakdown.update(decision)
|
breakdown.update(decision)
|
||||||
if session_id:
|
if session_id:
|
||||||
session_affinity_combined[session_id] = best_idx
|
session_affinity_combined[session_id] = best_idx
|
||||||
|
elif policy == "unified_v2":
|
||||||
|
chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2(
|
||||||
|
combined_instances, token_ids, session_id, input_length,
|
||||||
|
session_affinity_combined)
|
||||||
|
breakdown.update(decision)
|
||||||
|
if session_id:
|
||||||
|
session_affinity_combined[session_id] = best_idx
|
||||||
else: # linear (default)
|
else: # linear (default)
|
||||||
chosen, best_idx = pick_instance(
|
chosen, best_idx = pick_instance(
|
||||||
combined_instances, token_ids, session_id, input_length,
|
combined_instances, token_ids, session_id, input_length,
|
||||||
@@ -653,7 +871,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
|||||||
breakdown.update({
|
breakdown.update({
|
||||||
"cache_hit": cache_hit,
|
"cache_hit": cache_hit,
|
||||||
"estimated_new_tokens": estimated_new,
|
"estimated_new_tokens": estimated_new,
|
||||||
"route_class": "LOCAL",
|
"route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2",
|
||||||
"routed_to": chosen.url,
|
"routed_to": chosen.url,
|
||||||
"chosen_idx": best_idx,
|
"chosen_idx": best_idx,
|
||||||
"candidate_scores": pre_decision_workers,
|
"candidate_scores": pre_decision_workers,
|
||||||
@@ -667,14 +885,152 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"policy": policy,
|
"policy": policy,
|
||||||
"chosen_idx": best_idx,
|
"chosen_idx": best_idx,
|
||||||
|
"v2_pd_sep": pd_sep_v2 is not None,
|
||||||
"workers": pre_decision_workers,
|
"workers": pre_decision_workers,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if pd_sep_v2 is not None:
|
||||||
|
src_inst, src_idx = pd_sep_v2
|
||||||
|
breakdown["v2_src_url"] = src_inst.url
|
||||||
|
breakdown["v2_src_idx"] = src_idx
|
||||||
|
return await _handle_combined_pd_sep_v2(
|
||||||
|
api, req_data, headers, token_ids, input_length,
|
||||||
|
src_inst, chosen, breakdown,
|
||||||
|
request_id=request_id)
|
||||||
|
|
||||||
return await _handle_local_request(
|
return await _handle_local_request(
|
||||||
api, req_data, headers, token_ids, input_length,
|
api, req_data, headers, token_ids, input_length,
|
||||||
chosen, estimated_new, breakdown)
|
chosen, estimated_new, breakdown)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_combined_pd_sep_v2(
|
||||||
|
api, req_data, headers, token_ids, input_length,
|
||||||
|
src: InstanceState, dst: InstanceState, breakdown: dict,
|
||||||
|
*, request_id: str,
|
||||||
|
):
|
||||||
|
"""Per-request PD-sep among combined instances (unified_v2 path).
|
||||||
|
|
||||||
|
src does cached prefill (max_tokens=1) and ships KV to dst via
|
||||||
|
Mooncake; dst pulls KV and decodes. Both instances must run in
|
||||||
|
kv_role=kv_both with bootstrap server enabled.
|
||||||
|
|
||||||
|
Patch 6.6: the dst streaming call uses a per-chunk read timeout
|
||||||
|
of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails
|
||||||
|
the request instead of hanging for 600 s.
|
||||||
|
"""
|
||||||
|
if src.bootstrap_port is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=(
|
||||||
|
"unified_v2 PD-sep triggered but src instance "
|
||||||
|
f"{src.url} has no bootstrap_port; launch with "
|
||||||
|
"kv_role=kv_both and pass --bootstrap-ports"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reserve load on both endpoints.
|
||||||
|
src.ongoing_tokens += input_length
|
||||||
|
src.num_requests += 1
|
||||||
|
dst.ongoing_tokens += input_length
|
||||||
|
dst.num_requests += 1
|
||||||
|
src_load_held = True
|
||||||
|
dst_load_held = True
|
||||||
|
|
||||||
|
prefill_data = req_data.copy()
|
||||||
|
prefill_data["kv_transfer_params"] = {
|
||||||
|
"do_remote_decode": True,
|
||||||
|
"do_remote_prefill": False,
|
||||||
|
"transfer_id": f"xfer-{request_id}",
|
||||||
|
}
|
||||||
|
prefill_data["stream"] = False
|
||||||
|
prefill_data["max_tokens"] = 1
|
||||||
|
prefill_data["min_tokens"] = 1
|
||||||
|
prefill_data.pop("max_completion_tokens", None)
|
||||||
|
prefill_data.pop("stream_options", None)
|
||||||
|
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||||
|
|
||||||
|
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||||
|
breakdown["t_prefill_sent_unix"] = _time.time()
|
||||||
|
try:
|
||||||
|
resp = await src.client.post(api, json=prefill_data, headers=p_headers)
|
||||||
|
breakdown["t_prefill_done"] = _time.monotonic()
|
||||||
|
breakdown["t_prefill_done_unix"] = _time.time()
|
||||||
|
resp.raise_for_status()
|
||||||
|
await resp.aclose()
|
||||||
|
src.record_prefix(token_ids)
|
||||||
|
except Exception as e:
|
||||||
|
breakdown["t_prefill_done"] = _time.monotonic()
|
||||||
|
breakdown["t_prefill_done_unix"] = _time.time()
|
||||||
|
breakdown["prefill_error"] = True
|
||||||
|
breakdown["error_detail"] = repr(e)[:300]
|
||||||
|
_breakdown_log.append(breakdown)
|
||||||
|
# Release the dst reservation on failure. The src reservation is
|
||||||
|
# released by the `finally` block below; decrementing it here as
|
||||||
|
# well would subtract it twice.
|
||||||
|
dst.ongoing_tokens -= input_length
|
||||||
|
dst.num_requests -= 1
|
||||||
|
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
|
||||||
|
finally:
|
||||||
|
if src_load_held:
|
||||||
|
src.ongoing_tokens -= input_length
|
||||||
|
src.num_requests -= 1
|
||||||
|
src_load_held = False
|
||||||
|
|
||||||
|
parsed = urllib.parse.urlparse(str(src.client.base_url))
|
||||||
|
bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
|
||||||
|
|
||||||
|
decode_data = req_data.copy()
|
||||||
|
decode_data["kv_transfer_params"] = {
|
||||||
|
"do_remote_decode": False,
|
||||||
|
"do_remote_prefill": True,
|
||||||
|
"remote_bootstrap_addr": bootstrap_addr,
|
||||||
|
"remote_engine_id": src.engine_id.get(0, ""),
|
||||||
|
"transfer_id": f"xfer-{request_id}",
|
||||||
|
}
|
||||||
|
|
||||||
|
breakdown["t_decode_sent"] = _time.monotonic()
|
||||||
|
breakdown["t_decode_sent_unix"] = _time.time()
|
||||||
|
|
||||||
|
xfer_timeout = httpx.Timeout(
|
||||||
|
connect=10.0,
|
||||||
|
read=SETTINGS.pd_sep_xfer_timeout_s,
|
||||||
|
write=10.0,
|
||||||
|
pool=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
nonlocal dst_load_held
|
||||||
|
first_token = True
|
||||||
|
sse_buffer = ""
|
||||||
|
output_token_ids: list[int] = []
|
||||||
|
try:
|
||||||
|
async with dst.client.stream(
|
||||||
|
"POST", api, json=decode_data, headers=headers,
|
||||||
|
timeout=xfer_timeout,
|
||||||
|
) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for chunk in resp.aiter_bytes():
|
||||||
|
sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
|
||||||
|
sse_buffer, chunk)
|
||||||
|
output_token_ids.extend(new_output_ids)
|
||||||
|
if first_token:
|
||||||
|
breakdown["t_first_token"] = _time.monotonic()
|
||||||
|
breakdown["t_first_token_unix"] = _time.time()
|
||||||
|
first_token = False
|
||||||
|
yield chunk
|
||||||
|
dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
|
||||||
|
finally:
|
||||||
|
breakdown["t_done"] = _time.monotonic()
|
||||||
|
breakdown["t_done_unix"] = _time.time()
|
||||||
|
if dst_load_held:
|
||||||
|
dst.ongoing_tokens -= input_length
|
||||||
|
dst.num_requests -= 1
|
||||||
|
dst_load_held = False
|
||||||
|
_breakdown_log.append(breakdown)
|
||||||
|
|
||||||
|
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||||
session_id, headers):
|
session_id, headers):
|
||||||
"""PD-Sep mode with per-stage breakdown profiling."""
|
"""PD-Sep mode with per-stage breakdown profiling."""
|
||||||
@@ -849,11 +1205,15 @@ def parse_args():
|
|||||||
p.add_argument("--bootstrap-ports", type=str, default="",
|
p.add_argument("--bootstrap-ports", type=str, default="",
|
||||||
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
|
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
|
||||||
p.add_argument("--policy", type=str, default="linear",
|
p.add_argument("--policy", type=str, default="linear",
|
||||||
choices=["linear", "lmetric", "load_only", "sticky", "unified"],
|
choices=["linear", "lmetric", "load_only", "sticky",
|
||||||
|
"unified", "unified_v2"],
|
||||||
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
|
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
|
||||||
"load_only (B3 control: pure min-num_requests), "
|
"load_only (B3 control: pure min-num_requests), "
|
||||||
"sticky (B3 control: hard session affinity), "
|
"sticky (B3 control: hard session affinity), "
|
||||||
"or unified (hybrid affinity + LMetric fallback)")
|
"unified (hybrid affinity + LMetric fallback), "
|
||||||
|
"or unified_v2 (unified + selective per-request PD-sep "
|
||||||
|
"via Mooncake; requires --bootstrap-ports and "
|
||||||
|
"kv_role=kv_both vLLM launch)")
|
||||||
p.add_argument("--overload-factor", type=float, default=2.0,
|
p.add_argument("--overload-factor", type=float, default=2.0,
|
||||||
help="Break session affinity when instance load > factor * avg")
|
help="Break session affinity when instance load > factor * avg")
|
||||||
# The four flags below are accepted for bench.sh backward compatibility but
|
# The four flags below are accepted for bench.sh backward compatibility but
|
||||||
|
|||||||
@@ -343,6 +343,104 @@ def test_pick_instance_sticky_subsequent_never_breaks(proxy):
|
|||||||
assert idx == 0, "sticky must stay even when pinned instance is saturated"
|
assert idx == 0, "sticky must stay even when pinned instance is saturated"
|
||||||
|
|
||||||
|
|
||||||
|
+def test_unified_v2_falls_through_when_new_tokens_small(proxy):
+    """If post-cache new tokens < threshold, v2 should not PD-sep."""
+    insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
+    # Tiny prompt: 2 blocks = 1024 tokens. Below 16k threshold.
+    tokens = [1] * (proxy.BLOCK_SIZE * 2)
+    chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
+        insts, tokens, None, len(tokens), {})
+    assert pd_sep is None
+    assert decision["v2_decision"] == "local"
+    assert "below_threshold" in decision["v2_reason"]
+
+
+def _setup_v2_scene(proxy, *, chosen_decodes: int, src_cache_blocks: int):
+    """Build a 4-instance scene where inst[0] wins LMetric and inst[2]
+    holds optional cache. All instances have non-zero num_requests so
+    LMetric's bs=0 tie-break doesn't pick an empty instance.
+
+    Returns (insts, prefix_tokens).
+    """
+    insts = [_make_inst(proxy, f"http://h{i}") for i in range(4)]
+    block_size = proxy.BLOCK_SIZE
+    prefix = []
+    for b in range(128):
+        prefix.extend([2000 + b] * block_size)  # 128 × 512 = 65536 tokens
+
+    if src_cache_blocks > 0:
+        insts[2].record_prefix(prefix[: src_cache_blocks * block_size])
+
+    # Make inst[0] have the smallest LMetric score so chosen = inst[0].
+    insts[0].num_requests = max(chosen_decodes, 1)
+    insts[0].pending_prefill_tokens = 0
+    insts[0].ongoing_decode_tokens = chosen_decodes * 5000
+    insts[1].num_requests = 20
+    insts[1].pending_prefill_tokens = 200_000
+    insts[2].num_requests = 30  # src is busy enough not to win LMetric
+    insts[2].pending_prefill_tokens = 200_000
+    insts[3].num_requests = 20
+    insts[3].pending_prefill_tokens = 200_000
+    return insts, prefix
+
+
+def test_unified_v2_falls_through_when_no_alt_cache(proxy):
+    """No other instance has meaningful cache → no PD-sep."""
+    insts, prefix = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=0)
+    chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
+        insts, prefix, None, len(prefix), {})
+    assert idx == 0
+    assert pd_sep is None
+    assert "src_cache" in decision["v2_reason"]
+
+
+def test_unified_v2_triggers_when_src_has_meaningful_cache_and_chosen_has_decodes(proxy):
+    """Classic v2-win case: big prefill, chosen has decodes, alt has cache."""
+    insts, prefix = _setup_v2_scene(proxy, chosen_decodes=5, src_cache_blocks=128)
+    chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
+        insts, prefix, None, len(prefix), {})
+    assert idx == 0
+    assert pd_sep is not None, (
+        f"expected PD-sep, got reason={decision['v2_reason']}"
+    )
+    src, src_idx = pd_sep
+    assert src_idx == 2
+    assert decision["v2_decision"] == "pd_sep"
+    assert decision["v2_src_cache_hit"] >= 60000
+
+
+def test_unified_v2_falls_through_when_chosen_has_no_decodes(proxy):
+    """No decodes on chosen → no benefit from PD-sep."""
+    insts, prefix = _setup_v2_scene(proxy, chosen_decodes=0, src_cache_blocks=128)
+    chosen, idx, decision, pd_sep = proxy.pick_instance_unified_v2(
+        insts, prefix, None, len(prefix), {})
+    assert pd_sep is None
+    assert "few_decodes" in decision["v2_reason"]
+
+
+def test_estimate_transfer_cost_is_calibrated_function(proxy):
+    """RDMA transfer cost grows with bytes, has a non-zero floor."""
+    cost_empty = proxy.estimate_transfer_cost(0)
+    cost_1gb = proxy.estimate_transfer_cost(1024 ** 3)
+    cost_10gb = proxy.estimate_transfer_cost(10 * 1024 ** 3)
+    assert cost_empty >= 0.2, "should have non-zero floor"
+    assert cost_1gb > cost_empty
+    assert cost_10gb > cost_1gb
+    # 10 GB should be roughly 0.3 + 10/2.7 ≈ 4.0 s
+    assert 3.0 < cost_10gb < 5.0
+
+
+def test_estimate_same_worker_interference_grows_with_size(proxy):
+    """Interference cost is monotone in new_tokens up to the saturation regime."""
+    c1 = proxy.estimate_same_worker_interference_s(2000, num_decodes=4)
+    c2 = proxy.estimate_same_worker_interference_s(8000, num_decodes=4)
+    c3 = proxy.estimate_same_worker_interference_s(20000, num_decodes=4)
+    c4 = proxy.estimate_same_worker_interference_s(32000, num_decodes=4)
+    assert c1 < c2 < c3 < c4
+    # Zero decodes -> zero cost regardless of size
+    assert proxy.estimate_same_worker_interference_s(32000, num_decodes=0) == 0.0
+
+
 def test_p_offload_penalty_uses_settings_heavy_threshold(proxy):
     """M2: tweaking SETTINGS.heavy_threshold changes the P-offload penalty."""
     inst = proxy.InstanceState("http://x")
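
For readers checking the arithmetic behind test_estimate_transfer_cost_is_calibrated_function, here is a minimal sketch of a cost function with the shape the commit message describes (0.3 s base plus bytes divided by 2.7 GB/s). It is not the code from cache_aware_proxy.py, whose body is not shown in this diff. The parameter names reuse the Settings field names from the commit message, and treating 1 GB as 2**30 bytes is an assumption made to match the test's "0.3 + 10/2.7" comment.

```python
# Illustrative sketch only: the real estimate_transfer_cost may differ.
# Defaults come from the commit message (0.3 s base, 2.7 GB/s effective).
BYTES_PER_GB = 1024 ** 3  # assumption: the test builds sizes from 1024 ** 3


def estimate_transfer_cost(
    num_bytes: int,
    rdma_base_overhead_s: float = 0.3,
    rdma_effective_gb_per_s: float = 2.7,
) -> float:
    """Return an estimated RDMA transfer time in seconds.

    Fixed setup overhead plus a term linear in the payload size.
    """
    gigabytes = num_bytes / BYTES_PER_GB
    return rdma_base_overhead_s + gigabytes / rdma_effective_gb_per_s


if __name__ == "__main__":
    for size_gb in (0, 1, 10):
        cost = estimate_transfer_cost(size_gb * BYTES_PER_GB)
        print(f"{size_gb:>2} GB -> {cost:.2f} s")
```

Under these assumptions a zero-byte transfer still costs the 0.3 s base, which satisfies the test's 0.2 s floor, and a 10 GB transfer falls inside the 3.0 to 5.0 s window the test asserts.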