Migration correctness smoke tests: direct-read, partial-transfer, NIXL
Standalone smoke tests validating KV-migration correctness paths before trace replay: full migrate-cache, partial-prefill transfer, and a NIXL-connector variant, each with a runner. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
73
microbench/connector_tax/cache_sweep/run_smoke_nixl.sh
Normal file
73
microbench/connector_tax/cache_sweep/run_smoke_nixl.sh
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Smoke test for Nixl-based PD-sep migration (NVLink intra-node via UCX).
|
||||||
|
#
|
||||||
|
# Drops 2 vLLM kv_both NixlConnector instances on GPU 0,1 and runs
|
||||||
|
# smoke_nixl_migrate.py against them with the kv_transfer_params
|
||||||
|
# format Nixl expects (only do_remote_decode on src; proxy must forward
|
||||||
|
# kv_transfer_params from src's response to dst).
|
||||||
|
#
|
||||||
|
# Since smoke_test_migrate_cache.py is currently hard-coded for Mooncake
|
||||||
|
# (transfer_id + remote_bootstrap_addr), we use a tiny Python in-line
|
||||||
|
# variant here that does the Nixl response-forward handshake directly.
|
||||||
|
|
||||||
|
# -u: unset variables are errors; pipefail: a pipeline fails if any stage
# fails. -e is not enabled, so critical commands are checked explicitly.
set -uo pipefail

# All paths can be overridden from the environment.
# NOTE(review): smoke_nixl_migrate.py sends a hard-coded model path in its
# request payload; if MODEL is overridden here, confirm the client still names
# a model the server recognises.
PROJ_DIR="${PROJ_DIR:-/home/admin/cpfs/wjh/agentic-kv}"
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
VENV="$PROJ_DIR/.venv/bin"
LOGS_DIR="${LOGS_DIR:-$PROJ_DIR/outputs/smoke_nixl_$(date +%Y%m%d_%H%M%S)}"

# Fail fast: without the log directory every server launch below would fail
# its output redirection and the health wait would only time out much later.
if ! mkdir -p "$LOGS_DIR"; then
  echo "[smoke-nixl] FATAL: cannot create log dir $LOGS_DIR" >&2
  exit 1
fi
|
||||||
|
|
||||||
|
#######################################
# Stop every vLLM server / EngineCore process on this machine.
# Sends SIGTERM first so the servers get a chance to shut down, waits up to
# CLEANUP_GRACE_S seconds (default 10) for them to disappear, then SIGKILLs
# whatever is left. Matching is by command line (pkill -f), so it also hits
# vLLM processes that this script did not start.
# Globals:   CLEANUP_GRACE_S (read, optional)
# Arguments: none
# Outputs:   one progress line on stdout
# Returns:   0 (all kill failures are ignored: "nothing to kill" is fine)
#######################################
cleanup() {
  local pattern
  local grace="${CLEANUP_GRACE_S:-10}"
  local waited=0

  echo "[smoke-nixl] cleaning up vLLM..."
  for pattern in "vllm serve" "EngineCore"; do
    pkill -TERM -f "$pattern" 2>/dev/null || true
  done

  # Poll once per second until both patterns are gone or the grace expires.
  while (( waited < grace )) \
      && { pgrep -f "vllm serve" >/dev/null 2>&1 \
           || pgrep -f "EngineCore" >/dev/null 2>&1; }; do
    sleep 1
    waited=$((waited + 1))
  done

  for pattern in "vllm serve" "EngineCore"; do
    pkill -KILL -f "$pattern" 2>/dev/null || true
  done
  sleep 2
}
|
||||||
|
# Tear the servers down on every exit path, and clear leftovers from earlier
# runs before claiming the GPUs and ports.
trap cleanup EXIT
cleanup

echo "[smoke-nixl] starting 2 vLLM kv_both NixlConnector on GPU 0,1"
# Indexed by HTTP port: PID of each launched server, so the health wait can
# notice a server that died during startup instead of polling until timeout.
server_pids=()
for i in 0 1; do
  port=$((8000 + i))
  nixl_port=$((5600 + i))
  master=$((29500 + i))
  PYTHONHASHSEED=42 \
  VLLM_NIXL_SIDE_CHANNEL_PORT=$nixl_port \
  CUDA_VISIBLE_DEVICES=$i \
  MASTER_PORT=$master \
  nohup "$VENV/vllm" serve "$MODEL" \
    --host 0.0.0.0 --port "$port" \
    --tensor-parallel-size 1 \
    --trust-remote-code --enable-prefix-caching \
    --dtype auto --gpu-memory-utilization 0.9 \
    --max-model-len 200000 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
    --enable-prompt-tokens-details \
    > "$LOGS_DIR/vllm_inst_${i}_gpu${i}.log" 2>&1 &
  server_pids[$port]=$!
  disown
  sleep 2
done

echo "[smoke-nixl] waiting for health on 8000 and 8001 ..."
for port in 8000 8001; do
  tries=0
  until curl -sf "http://127.0.0.1:$port/health" >/dev/null 2>&1; do
    # Best-effort liveness probe: signal 0 only checks that the PID exists.
    if ! kill -0 "${server_pids[$port]}" 2>/dev/null; then
      echo "[smoke-nixl] FATAL: vLLM on port $port exited during startup; see $LOGS_DIR"
      exit 1
    fi
    tries=$((tries + 1))
    if (( tries > 240 )); then
      echo "[smoke-nixl] FATAL: $port not ready"; exit 1
    fi
    sleep 2
  done
  echo " port=$port ready"
done

# EXTRA_SMOKE_ARGS is a whitespace-separated list of extra client flags.
# Split it into words with globbing disabled so a '*' is passed literally.
set -f
# shellcheck disable=SC2206 — word splitting is the intent here
extra_args=( ${EXTRA_SMOKE_ARGS:-} )
set +f

echo "[smoke-nixl] running smoke_nixl_migrate.py"
"$VENV/python" "$PROJ_DIR/microbench/connector_tax/cache_sweep/smoke_nixl_migrate.py" \
  --src-port 8000 --dst-port 8001 \
  ${extra_args[@]+"${extra_args[@]}"} \
  2>&1 | tee "$LOGS_DIR/smoke_output.log"

# Report the client's status, not tee's.
ec=${PIPESTATUS[0]}
echo "[smoke-nixl] test exit=$ec, logs at $LOGS_DIR"
exit "$ec"
|
||||||
68
microbench/connector_tax/cache_sweep/run_smoke_partial.sh
Normal file
68
microbench/connector_tax/cache_sweep/run_smoke_partial.sh
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Smoke test for Mechanism B (partial KV transfer):
|
||||||
|
# Start 3 vLLM kv_both Mooncake instances on GPU 0,1,2:
|
||||||
|
# - inst_0 = src (port 8000, bp 8998)
|
||||||
|
# - inst_1 = dst_warm (port 8001, bp 8999) — will be pre-warmed
|
||||||
|
# - inst_2 = dst_cold (port 8002, bp 9000) — control, no cache
|
||||||
|
#
|
||||||
|
# Then run smoke_partial_transfer.py which migrates the same prompt
|
||||||
|
# to both warm and cold dst, comparing transfer cost.
|
||||||
|
|
||||||
|
# -u: unset variables are errors; pipefail: a pipeline fails if any stage
# fails. -e is not enabled, so critical commands are checked explicitly.
set -uo pipefail

# All paths can be overridden from the environment.
# NOTE(review): smoke_partial_transfer.py sends a hard-coded model path in its
# request payload; if MODEL is overridden here, confirm the client still names
# a model the server recognises.
PROJ_DIR="${PROJ_DIR:-/home/admin/cpfs/wjh/agentic-kv}"
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
VENV="$PROJ_DIR/.venv/bin"
LOGS_DIR="${LOGS_DIR:-$PROJ_DIR/outputs/smoke_partial_$(date +%Y%m%d_%H%M%S)}"

# Fail fast: without the log directory every server launch below would fail
# its output redirection and the health wait would only time out much later.
if ! mkdir -p "$LOGS_DIR"; then
  echo "[smoke-partial] FATAL: cannot create log dir $LOGS_DIR" >&2
  exit 1
fi
|
||||||
|
|
||||||
|
#######################################
# Stop every vLLM server / EngineCore process on this machine.
# Sends SIGTERM first so the servers get a chance to shut down, waits up to
# CLEANUP_GRACE_S seconds (default 10) for them to disappear, then SIGKILLs
# whatever is left. Matching is by command line (pkill -f), so it also hits
# vLLM processes that this script did not start.
# Globals:   CLEANUP_GRACE_S (read, optional)
# Arguments: none
# Outputs:   one progress line on stdout
# Returns:   0 (all kill failures are ignored: "nothing to kill" is fine)
#######################################
cleanup() {
  local pattern
  local grace="${CLEANUP_GRACE_S:-10}"
  local waited=0

  echo "[smoke-partial] cleaning up vLLM..."
  for pattern in "vllm serve" "EngineCore"; do
    pkill -TERM -f "$pattern" 2>/dev/null || true
  done

  # Poll once per second until both patterns are gone or the grace expires.
  while (( waited < grace )) \
      && { pgrep -f "vllm serve" >/dev/null 2>&1 \
           || pgrep -f "EngineCore" >/dev/null 2>&1; }; do
    sleep 1
    waited=$((waited + 1))
  done

  for pattern in "vllm serve" "EngineCore"; do
    pkill -KILL -f "$pattern" 2>/dev/null || true
  done
  sleep 2
}
|
||||||
|
# Tear the servers down on every exit path, and clear leftovers from earlier
# runs before claiming the GPUs and ports.
trap cleanup EXIT
cleanup

echo "[smoke-partial] starting 3 vLLM kv_both Mooncake on GPU 0,1,2"
# Indexed by HTTP port: PID of each launched server, so the health wait can
# notice a server that died during startup instead of polling until timeout.
server_pids=()
for i in 0 1 2; do
  port=$((8000 + i))
  bp=$((8998 + i))
  master=$((29500 + i))
  PYTHONHASHSEED=42 \
  VLLM_MOONCAKE_BOOTSTRAP_PORT=$bp \
  CUDA_VISIBLE_DEVICES=$i \
  MASTER_PORT=$master \
  nohup "$VENV/vllm" serve "$MODEL" \
    --host 0.0.0.0 --port "$port" \
    --tensor-parallel-size 1 \
    --trust-remote-code --enable-prefix-caching \
    --dtype auto --gpu-memory-utilization 0.9 \
    --max-model-len 200000 \
    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_both"}' \
    --enable-prompt-tokens-details \
    > "$LOGS_DIR/vllm_inst_${i}_gpu${i}.log" 2>&1 &
  server_pids[$port]=$!
  disown
  sleep 2
done

echo "[smoke-partial] waiting for health ..."
for port in 8000 8001 8002; do
  tries=0
  until curl -sf "http://127.0.0.1:$port/health" >/dev/null 2>&1; do
    # Best-effort liveness probe: signal 0 only checks that the PID exists.
    if ! kill -0 "${server_pids[$port]}" 2>/dev/null; then
      echo "[smoke-partial] FATAL: vLLM on port $port exited during startup; see $LOGS_DIR"
      exit 1
    fi
    tries=$((tries + 1))
    if (( tries > 240 )); then echo "FATAL: $port"; exit 1; fi
    sleep 2
  done
  echo " port=$port ready"
done

# EXTRA_SMOKE_ARGS is a whitespace-separated list of extra client flags.
# Split it into words with globbing disabled so a '*' is passed literally.
set -f
# shellcheck disable=SC2206 — word splitting is the intent here
extra_args=( ${EXTRA_SMOKE_ARGS:-} )
set +f

echo "[smoke-partial] running smoke_partial_transfer.py"
"$VENV/python" "$PROJ_DIR/microbench/connector_tax/cache_sweep/smoke_partial_transfer.py" \
  ${extra_args[@]+"${extra_args[@]}"} \
  2>&1 | tee "$LOGS_DIR/smoke_output.log"

# Report the client's status, not tee's.
ec=${PIPESTATUS[0]}
echo "[smoke-partial] exit=$ec, logs at $LOGS_DIR"
exit "$ec"
|
||||||
74
microbench/connector_tax/cache_sweep/run_smoke_sweep.sh
Normal file
74
microbench/connector_tax/cache_sweep/run_smoke_sweep.sh
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Single vLLM warmup, multiple smoke-test iterations under varying load.
|
||||||
|
#
|
||||||
|
# Each iteration uses a distinct --prefix-base to avoid prefix-cache pollution
|
||||||
|
# from prior iterations. We sweep noise levels 0, 8, 32, 64, 128 to see at which
|
||||||
|
# point the migration cache becomes invisible to the follow-up.
|
||||||
|
|
||||||
|
# -u: unset variables are errors; pipefail: a pipeline fails if any stage
# fails. -e is not enabled, so critical commands are checked explicitly.
set -uo pipefail

# All paths can be overridden from the environment.
# NOTE(review): smoke_test_migrate_cache.py sends a hard-coded model path in
# its request payload; if MODEL is overridden here, confirm the client still
# names a model the server recognises.
PROJ_DIR="${PROJ_DIR:-/home/admin/cpfs/wjh/agentic-kv}"
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
VENV="$PROJ_DIR/.venv/bin"
LOGS_DIR="${LOGS_DIR:-$PROJ_DIR/outputs/smoke_sweep_$(date +%Y%m%d_%H%M%S)}"

# Fail fast: without the log directory every server launch below would fail
# its output redirection and the health wait would only time out much later.
if ! mkdir -p "$LOGS_DIR"; then
  echo "[sweep] FATAL: cannot create log dir $LOGS_DIR" >&2
  exit 1
fi
|
||||||
|
|
||||||
|
#######################################
# Stop every vLLM server / EngineCore process on this machine.
# Sends SIGTERM first so the servers get a chance to shut down, waits up to
# CLEANUP_GRACE_S seconds (default 10) for them to disappear, then SIGKILLs
# whatever is left. Matching is by command line (pkill -f), so it also hits
# vLLM processes that this script did not start.
# Globals:   CLEANUP_GRACE_S (read, optional)
# Arguments: none
# Outputs:   one progress line on stdout
# Returns:   0 (all kill failures are ignored: "nothing to kill" is fine)
#######################################
cleanup() {
  local pattern
  local grace="${CLEANUP_GRACE_S:-10}"
  local waited=0

  echo "[sweep] cleaning up vLLM..."
  for pattern in "vllm serve" "EngineCore"; do
    pkill -TERM -f "$pattern" 2>/dev/null || true
  done

  # Poll once per second until both patterns are gone or the grace expires.
  while (( waited < grace )) \
      && { pgrep -f "vllm serve" >/dev/null 2>&1 \
           || pgrep -f "EngineCore" >/dev/null 2>&1; }; do
    sleep 1
    waited=$((waited + 1))
  done

  for pattern in "vllm serve" "EngineCore"; do
    pkill -KILL -f "$pattern" 2>/dev/null || true
  done
  sleep 2
}
|
||||||
|
# Tear the servers down on every exit path, and clear leftovers from earlier
# runs before claiming the GPUs and ports.
trap cleanup EXIT
cleanup

echo "[sweep] starting 2 vLLM kv_both on GPU 0,1"
# Indexed by HTTP port: PID of each launched server, so the health wait can
# notice a server that died during startup instead of polling until timeout.
server_pids=()
for i in 0 1; do
  port=$((8000 + i))
  bp=$((8998 + i))
  master=$((29500 + i))
  PYTHONHASHSEED=42 \
  VLLM_MOONCAKE_BOOTSTRAP_PORT=$bp \
  CUDA_VISIBLE_DEVICES=$i \
  MASTER_PORT=$master \
  nohup "$VENV/vllm" serve "$MODEL" \
    --host 0.0.0.0 --port "$port" \
    --tensor-parallel-size 1 \
    --trust-remote-code --enable-prefix-caching \
    --dtype auto --gpu-memory-utilization 0.9 \
    --max-model-len 200000 \
    --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_both"}' \
    --enable-prompt-tokens-details \
    > "$LOGS_DIR/vllm_inst_${i}_gpu${i}.log" 2>&1 &
  server_pids[$port]=$!
  disown
  sleep 2
done

echo "[sweep] waiting for health ..."
for port in 8000 8001; do
  tries=0
  until curl -sf "http://127.0.0.1:$port/health" >/dev/null 2>&1; do
    # Best-effort liveness probe: signal 0 only checks that the PID exists.
    if ! kill -0 "${server_pids[$port]}" 2>/dev/null; then
      echo "[sweep] FATAL: vLLM on port $port exited during startup; see $LOGS_DIR"
      exit 1
    fi
    tries=$((tries + 1))
    if (( tries > 180 )); then echo "[sweep] FATAL: $port not ready"; exit 1; fi
    sleep 2
  done
  echo " port=$port ready"
done

# One client run per noise level against the same warm servers. Each run gets
# its own --prefix-base so earlier iterations cannot satisfy its cache lookups.
base=100
failed_levels=()
for noise in 0 8 32 64 128; do
  echo ""
  echo "============================================"
  echo "[sweep] iteration noise=$noise prefix_base=$base"
  echo "============================================"
  "$VENV/python" "$PROJ_DIR/microbench/connector_tax/cache_sweep/smoke_test_migrate_cache.py" \
    --src-port 8000 --dst-port 8001 \
    --src-bp 8998 --dst-bp 8999 \
    --noise-reqs "$noise" \
    --prefix-base "$base" \
    2>&1 | tee "$LOGS_DIR/iter_noise${noise}.log" | tail -25
  # Keep the client's own status; only the last 25 lines reach the console,
  # so a failure is otherwise easy to miss.
  rc=${PIPESTATUS[0]}
  echo "[sweep] iteration noise=$noise exit=$rc"
  if (( rc != 0 )); then
    failed_levels+=("$noise")
  fi
  base=$((base + 100000))
done

echo ""
if (( ${#failed_levels[@]} > 0 )); then
  echo "[sweep] non-zero client exit at noise levels: ${failed_levels[*]}"
fi
# The script's own exit status is left unchanged: a non-zero client exit is
# reported above but does not fail the sweep.
echo "[sweep] all iterations done; logs in $LOGS_DIR"
|
||||||
75
microbench/connector_tax/cache_sweep/run_smoke_test.sh
Normal file
75
microbench/connector_tax/cache_sweep/run_smoke_test.sh
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Fast iteration: start 2 vLLM kv_both, run smoke_test_migrate_cache, tear down.
|
||||||
|
#
|
||||||
|
# Usage: [MOONCAKE_PROTOCOL=rdma|tcp|nvlink_intra] [EXTRA_SMOKE_ARGS="..."] bash run_smoke_test.sh
|
||||||
|
#
|
||||||
|
# Iteration overhead: ~3-4 min warmup + a few sec for the test. Cleanly
|
||||||
|
# tears everything down on exit so you can re-run repeatedly.
|
||||||
|
|
||||||
|
# -u: unset variables are errors; pipefail: a pipeline fails if any stage
# fails. -e is not enabled, so critical commands are checked explicitly.
set -uo pipefail

# All paths can be overridden from the environment.
# NOTE(review): smoke_test_migrate_cache.py sends a hard-coded model path in
# its request payload; if MODEL is overridden here, confirm the client still
# names a model the server recognises.
PROJ_DIR="${PROJ_DIR:-/home/admin/cpfs/wjh/agentic-kv}"
MODEL="${MODEL:-/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
VENV="$PROJ_DIR/.venv/bin"
LOGS_DIR="${LOGS_DIR:-$PROJ_DIR/outputs/smoke_test_$(date +%Y%m%d_%H%M%S)}"

# Fail fast: without the log directory every server launch below would fail
# its output redirection and the health wait would only time out much later.
if ! mkdir -p "$LOGS_DIR"; then
  echo "[smoke] FATAL: cannot create log dir $LOGS_DIR" >&2
  exit 1
fi

# MOONCAKE_PROTOCOL controls Mooncake's C++ TransferEngine transport.
# Options exposed: rdma (default), tcp, nvlink_intra (NVLink intra-node).
PROTO="${MOONCAKE_PROTOCOL:-rdma}"
# The value is spliced into a JSON string on the vllm command line, so reject
# anything that could break out of that string (quotes, backslashes, spaces).
if [[ ! "$PROTO" =~ ^[A-Za-z0-9_-]+$ ]]; then
  echo "[smoke] FATAL: invalid MOONCAKE_PROTOCOL: $PROTO" >&2
  exit 1
fi
echo "[smoke] using Mooncake protocol: $PROTO"
|
||||||
|
|
||||||
|
#######################################
# Stop every vLLM server / EngineCore process on this machine.
# Sends SIGTERM first so the servers get a chance to shut down, waits up to
# CLEANUP_GRACE_S seconds (default 10) for them to disappear, then SIGKILLs
# whatever is left. Matching is by command line (pkill -f), so it also hits
# vLLM processes that this script did not start.
# Globals:   CLEANUP_GRACE_S (read, optional)
# Arguments: none
# Outputs:   one progress line on stdout
# Returns:   0 (all kill failures are ignored: "nothing to kill" is fine)
#######################################
cleanup() {
  local pattern
  local grace="${CLEANUP_GRACE_S:-10}"
  local waited=0

  echo "[smoke] cleaning up vLLM..."
  for pattern in "vllm serve" "EngineCore"; do
    pkill -TERM -f "$pattern" 2>/dev/null || true
  done

  # Poll once per second until both patterns are gone or the grace expires.
  while (( waited < grace )) \
      && { pgrep -f "vllm serve" >/dev/null 2>&1 \
           || pgrep -f "EngineCore" >/dev/null 2>&1; }; do
    sleep 1
    waited=$((waited + 1))
  done

  for pattern in "vllm serve" "EngineCore"; do
    pkill -KILL -f "$pattern" 2>/dev/null || true
  done
  sleep 2
}
|
||||||
|
# Tear the servers down on every exit path, and clear leftovers from earlier
# runs before claiming the GPUs and ports.
trap cleanup EXIT
cleanup

echo "[smoke] starting 2 vLLM kv_both on GPU 0,1"
# Indexed by HTTP port: PID of each launched server, so the health wait can
# notice a server that died during startup instead of polling until timeout.
server_pids=()
for i in 0 1; do
  port=$((8000 + i))
  bp=$((8998 + i))
  master=$((29500 + i))
  PYTHONHASHSEED=42 \
  VLLM_MOONCAKE_BOOTSTRAP_PORT=$bp \
  CUDA_VISIBLE_DEVICES=$i \
  MASTER_PORT=$master \
  nohup "$VENV/vllm" serve "$MODEL" \
    --host 0.0.0.0 --port "$port" \
    --tensor-parallel-size 1 \
    --trust-remote-code --enable-prefix-caching \
    --dtype auto --gpu-memory-utilization 0.9 \
    --max-model-len 200000 \
    --kv-transfer-config "{\"kv_connector\":\"MooncakeConnector\",\"kv_role\":\"kv_both\",\"kv_connector_extra_config\":{\"mooncake_protocol\":\"$PROTO\"}}" \
    --enable-prompt-tokens-details \
    > "$LOGS_DIR/vllm_inst_${i}_gpu${i}.log" 2>&1 &
  server_pids[$port]=$!
  disown
  sleep 2
done

echo "[smoke] waiting for health on 8000 and 8001 ..."
for port in 8000 8001; do
  tries=0
  until curl -sf "http://127.0.0.1:$port/health" >/dev/null 2>&1; do
    # Best-effort liveness probe: signal 0 only checks that the PID exists.
    if ! kill -0 "${server_pids[$port]}" 2>/dev/null; then
      echo "[smoke] FATAL: vLLM on port $port exited during startup; see $LOGS_DIR"
      exit 1
    fi
    tries=$((tries + 1))
    if (( tries > 180 )); then
      echo "[smoke] FATAL: $port not ready"; exit 1
    fi
    sleep 2
  done
  echo " port=$port ready"
done

# EXTRA_SMOKE_ARGS is a whitespace-separated list of extra client flags.
# Split it into words with globbing disabled so a '*' is passed literally.
set -f
# shellcheck disable=SC2206 — word splitting is the intent here
extra_args=( ${EXTRA_SMOKE_ARGS:-} )
set +f

echo "[smoke] running migration smoke test"
"$VENV/python" "$PROJ_DIR/microbench/connector_tax/cache_sweep/smoke_test_migrate_cache.py" \
  --src-port 8000 --dst-port 8001 \
  --src-bp 8998 --dst-bp 8999 \
  ${extra_args[@]+"${extra_args[@]}"} \
  2>&1 | tee "$LOGS_DIR/smoke_output.log"

# Report the client's status, not tee's.
ec=${PIPESTATUS[0]}
echo "[smoke] test exit=$ec, logs at $LOGS_DIR"
exit "$ec"
|
||||||
132
microbench/connector_tax/cache_sweep/smoke_nixl_migrate.py
Normal file
132
microbench/connector_tax/cache_sweep/smoke_nixl_migrate.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Smoke test for Nixl PD-sep migration.
|
||||||
|
|
||||||
|
Nixl handshake (vs Mooncake's pre-baked engine_id):
|
||||||
|
1. POST to src with kv_transfer_params={"do_remote_decode": True},
|
||||||
|
max_tokens=1, stream=False.
|
||||||
|
2. src returns kv_transfer_params in the response body containing
|
||||||
|
remote_block_ids, remote_engine_id, remote_host, remote_port,
|
||||||
|
remote_request_id, tp_size.
|
||||||
|
3. POST to dst with the SAME kv_transfer_params dict.
|
||||||
|
4. dst pulls KV via UCX (NVLink intra-node) and decodes.
|
||||||
|
|
||||||
|
Verifies migration correctness + measures KV transfer latency on Nixl
|
||||||
|
so we can ablate vs Mooncake/RDMA on the same workload.
|
||||||
|
"""
|
||||||
|
import asyncio, argparse, json, sys, uuid, time
|
||||||
|
import httpx
|
||||||
|
import random as _r
|
||||||
|
|
||||||
|
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
async def send(client, port, prompt, max_tokens, kv_xfer, stream):
    """POST one completion request to ``http://127.0.0.1:<port>/v1/completions``.

    Args:
        client: async HTTP client exposing ``post`` and ``stream`` (the callers
            in this file pass an ``httpx.AsyncClient``).
        port: local port of the target server.
        prompt: value sent as the ``prompt`` field (callers pass token-id lists).
        max_tokens: value sent as ``max_tokens``.
        kv_xfer: dict sent as ``kv_transfer_params``; omitted when ``None``.
        stream: when true, request a streamed response with usage included.

    Returns:
        Non-streaming: the decoded JSON body. Streaming: the last ``data:``
        event that carried a truthy ``usage``, else the last parsed event,
        else ``{}``.

    Raises:
        Whatever ``raise_for_status()`` raises on an error status.
    """
    import codecs  # only the streaming path needs it

    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        # Previously written as `max_tokens if max_tokens == 1 else 1`, which
        # is 1 for every input.
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if kv_xfer is not None:
        payload["kv_transfer_params"] = kv_xfer
    if stream:
        payload["stream_options"] = {"include_usage": True}
    url = f"http://127.0.0.1:{port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    last_with_usage = None
    last_any = None

    def _consume(event):
        """Parse one blank-line-delimited event and remember it."""
        nonlocal last_with_usage, last_any
        if not event.startswith("data: "):
            return
        data = event[6:].strip()
        if data == "[DONE]":
            return
        try:
            parsed = json.loads(data)
        except ValueError:
            # Malformed event: skip it, keep reading the stream.
            return
        last_any = parsed
        if isinstance(parsed, dict) and parsed.get("usage"):
            last_with_usage = parsed

    # Incremental decoder: a multi-byte character split across two network
    # chunks must not be turned into replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        buf = ""
        async for chunk in resp.aiter_bytes():
            buf += decoder.decode(chunk)
            while "\n\n" in buf:
                event, buf = buf.split("\n\n", 1)
                _consume(event)
        # Flush: the final event may arrive without its trailing blank line.
        buf += decoder.decode(b"", final=True)
        if buf.strip():
            _consume(buf.strip())
    return last_with_usage or last_any or {}
|
||||||
|
|
||||||
|
|
||||||
|
def short(d):
    """Render the usage section of a completion response as one short string.

    Args:
        d: decoded response dict; a falsy value yields ``"no_resp"``.

    Returns:
        ``"cached=<cached>/<prompt_tokens> completion=<completion_tokens>"``.
        ``cached`` is read from ``usage.prompt_tokens_details.cached_tokens``,
        falling back to ``usage.cached_tokens``, and is 0 when both are
        missing or ``None``.
    """
    if not d:
        return "no_resp"
    usage = d.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    # Trailing `or 0`: a present-but-null counter must not print as "None".
    cached = details.get("cached_tokens") or usage.get("cached_tokens") or 0
    return (f"cached={cached}/{usage.get('prompt_tokens', 0)} "
            f"completion={usage.get('completion_tokens', 0)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Run the three-step Nixl migration smoke test and exit with its verdict.

    Steps:
      1. Send the prompt to src with ``kv_transfer_params={"do_remote_decode":
         True}`` and ``max_tokens=1``; take ``kv_transfer_params`` from the
         response body.
      2. Send the same prompt to dst with those parameters forwarded unchanged
         (streamed, ``--decode-tokens`` tokens).
      3. Send prompt + a short extension to dst without transfer parameters and
         read how many prompt tokens it reports as cached.

    Exits 0 when step 3 reports at least 95% of ``--n-prefix-tokens`` cached,
    1 otherwise (including when src returns no ``kv_transfer_params``).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--src-port", type=int, default=8000)
    p.add_argument("--dst-port", type=int, default=8001)
    p.add_argument("--n-prefix-tokens", type=int, default=8192)
    p.add_argument("--n-extension", type=int, default=32)
    p.add_argument("--decode-tokens", type=int, default=16)
    p.add_argument("--prefix-base", type=int, default=100)
    args = p.parse_args()

    # Seeded by --prefix-base: the same base always yields the same prompt.
    rng = _r.Random(f"prefix-{args.prefix_base}")
    prompt = [rng.randint(1024, 99_999) for _ in range(args.n_prefix_tokens)]

    async with httpx.AsyncClient() as client:
        # ----- Step 1: src prefills with do_remote_decode=True -----
        t_src_start = time.monotonic()
        src_resp = await send(
            client, args.src_port, prompt, max_tokens=1,
            kv_xfer={"do_remote_decode": True}, stream=False,
        )
        t_src_done = time.monotonic()
        src_ms = (t_src_done - t_src_start) * 1000
        src_kv = src_resp.get("kv_transfer_params")
        print(f"[1] src prefill ({src_ms:.0f}ms): {short(src_resp)}")
        if not src_kv:
            print(" FAIL: src returned no kv_transfer_params")
            print(f" response keys: {list(src_resp.keys())}")
            sys.exit(1)
        print(f" src kv_transfer_params keys: {list(src_kv.keys())}")
        # remote_block_ids is counted via its first element when that is a
        # list; a flat list is counted directly instead of raising.
        block_ids = src_kv.get("remote_block_ids") or []
        if block_ids and isinstance(block_ids[0], (list, tuple)):
            n_blocks = len(block_ids[0])
        else:
            n_blocks = len(block_ids)
        print(f" remote_block_ids: {n_blocks} blocks")

        # ----- Step 2: dst pulls KV using forwarded kv_transfer_params -----
        t_dst_start = time.monotonic()
        dst_resp = await send(
            client, args.dst_port, prompt, max_tokens=args.decode_tokens,
            kv_xfer=src_kv, stream=True,
        )
        t_dst_done = time.monotonic()
        dst_total_ms = (t_dst_done - t_dst_start) * 1000
        n_completion = (dst_resp.get("usage") or {}).get("completion_tokens", 0)
        print(f"[2] dst decode ({dst_total_ms:.0f}ms, {n_completion} completion tokens): {short(dst_resp)}")
        print(f" [TIMING] proto=nixl src_prefill={int(src_ms)}ms "
              f"dst_total={int(dst_total_ms)}ms (KV xfer + {n_completion}-token decode)")
        print()

        # ----- Step 3: follow-up on dst (no kv_transfer_params) -----
        # One generator for the whole extension. Re-seeding a new Random for
        # every element made all extension tokens the same value.
        ext_rng = _r.Random(f"ext-{args.prefix_base}")
        ext = [ext_rng.randint(1024, 99_999) for _ in range(args.n_extension)]
        follow_prompt = prompt + ext
        fu = await send(client, args.dst_port, follow_prompt, max_tokens=4,
                        kv_xfer=None, stream=False)
        print(f"[3] follow-up dst (cache hit test): {short(fu)}")

        usage_fu = fu.get("usage") or {}
        details_fu = usage_fu.get("prompt_tokens_details") or {}
        # `or 0` keeps the comparison below valid when the field is null.
        cached_fu = (details_fu.get("cached_tokens")
                     or usage_fu.get("cached_tokens") or 0)
        expected_min = int(args.n_prefix_tokens * 0.95)
        verdict = "PASS" if cached_fu >= expected_min else "FAIL"
        print(f"\n=== verdict: {verdict} (follow-up cached={cached_fu}, "
              f"expected >= {expected_min} of {args.n_prefix_tokens}) ===")
        sys.exit(0 if verdict == "PASS" else 1)


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
166
microbench/connector_tax/cache_sweep/smoke_partial_transfer.py
Normal file
166
microbench/connector_tax/cache_sweep/smoke_partial_transfer.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Smoke test for partial KV transfer (Mechanism B).
|
||||||
|
|
||||||
|
Test if vLLM's Mooncake connector actually transfers only the
|
||||||
|
NON-OVERLAPPING portion when dst already has prefix cache.
|
||||||
|
|
||||||
|
Sequence:
|
||||||
|
step 0: warm dst with prompt P (cold prefill) — dst now has cache for [0, P]
|
||||||
|
step 1: cold prefill on src with prompt P+ext (src now has [0, P+ext])
|
||||||
|
step 2: migrate src→dst with prompt P+ext
|
||||||
|
- dst local cache should hit [0, P]
|
||||||
|
- only [P, P+ext] needs to come from src (~ext tokens)
|
||||||
|
- dst decode should be fast
|
||||||
|
step 3: control: another migrate src→dst_cold with prompt P+ext
|
||||||
|
- dst_cold has no cache, must pull all P+ext tokens
|
||||||
|
- compare with step 2
|
||||||
|
|
||||||
|
If step 2 is dramatically faster than step 3, partial transfer works
|
||||||
|
and Mechanism B is viable. If step 2 ~= step 3, partial transfer isn't
|
||||||
|
being exploited and we need to dig deeper.
|
||||||
|
"""
|
||||||
|
import asyncio, argparse, json, sys, uuid, time
|
||||||
|
import httpx
|
||||||
|
import random as _r
|
||||||
|
|
||||||
|
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
async def send(client, port, prompt, max_tokens, kv_xfer, stream):
    """POST one completion request to ``http://127.0.0.1:<port>/v1/completions``.

    Args:
        client: async HTTP client exposing ``post`` and ``stream`` (the callers
            in this file pass an ``httpx.AsyncClient``).
        port: local port of the target server.
        prompt: value sent as the ``prompt`` field (callers pass token-id lists).
        max_tokens: value sent as ``max_tokens``.
        kv_xfer: dict sent as ``kv_transfer_params``; omitted when ``None``.
        stream: when true, request a streamed response with usage included.

    Returns:
        Non-streaming: the decoded JSON body. Streaming: the last ``data:``
        event that carried a truthy ``usage``, else the last parsed event,
        else ``{}``.

    Raises:
        Whatever ``raise_for_status()`` raises on an error status.
    """
    import codecs  # only the streaming path needs it

    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        # Previously written as `max_tokens if max_tokens == 1 else 1`, which
        # is 1 for every input.
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if kv_xfer is not None:
        payload["kv_transfer_params"] = kv_xfer
    if stream:
        payload["stream_options"] = {"include_usage": True}
    url = f"http://127.0.0.1:{port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    last_w = None
    last_any = None

    def _consume(event):
        """Parse one blank-line-delimited event and remember it."""
        nonlocal last_w, last_any
        if not event.startswith("data: "):
            return
        data = event[6:].strip()
        if data == "[DONE]":
            return
        try:
            parsed = json.loads(data)
        except ValueError:
            # Malformed event: skip it, keep reading the stream.
            return
        last_any = parsed
        if isinstance(parsed, dict) and parsed.get("usage"):
            last_w = parsed

    # Incremental decoder: a multi-byte character split across two network
    # chunks must not be turned into replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        buf = ""
        async for chunk in resp.aiter_bytes():
            buf += decoder.decode(chunk)
            while "\n\n" in buf:
                event, buf = buf.split("\n\n", 1)
                _consume(event)
        # Flush: the final event may arrive without its trailing blank line.
        buf += decoder.decode(b"", final=True)
        if buf.strip():
            _consume(buf.strip())
    return last_w or last_any or {}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_engine_id(client, bp):
    """Fetch the engine id from the bootstrap server on local port ``bp``.

    GETs ``http://127.0.0.1:<bp>/query`` and returns ``["0"]["engine_id"]``
    from the JSON body.

    Raises:
        Whatever ``raise_for_status()`` raises on an error status, so a failed
        query is reported as such rather than as a ``KeyError`` on the body.
        ``KeyError`` if the body lacks the ``"0"`` / ``"engine_id"`` keys.
    """
    r = await client.get(f"http://127.0.0.1:{bp}/query")
    r.raise_for_status()
    return r.json()["0"]["engine_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def cached_of(d):
    """Return the cached-prompt-token count reported in a completion response.

    Reads ``usage.prompt_tokens_details.cached_tokens`` first and falls back to
    ``usage.cached_tokens``. Returns 0 when neither is present or both are
    ``None``.
    """
    usage = d.get("usage") or {}
    det = usage.get("prompt_tokens_details") or {}
    # Trailing `or 0`: a present-but-null counter must come back as a number.
    return det.get("cached_tokens") or usage.get("cached_tokens") or 0
|
||||||
|
|
||||||
|
|
||||||
|
async def do_migration(client, src_port, dst_port, src_bp, prompt, max_tokens, label):
    """Perform Mooncake-style PD-sep migration, returns (src_ms, dst_ms, response).

    Sequence:
      1. Look up src's engine id through its bootstrap port ``src_bp``.
      2. Send ``prompt`` to src with ``do_remote_decode=True`` and
         ``max_tokens=1`` under a fresh ``transfer_id``.
      3. Send the same ``prompt`` to dst (streamed) with
         ``do_remote_prefill=True``, the src bootstrap address, the src engine
         id and the same ``transfer_id``.

    Args:
        client: async HTTP client passed through to ``get_engine_id``/``send``.
        src_port, dst_port: HTTP ports of the source and destination servers.
        src_bp: bootstrap port of the source server.
        prompt: prompt sent to both servers.
        max_tokens: ``max_tokens`` for the destination request.
        label: tag printed in the timing line.

    Returns:
        ``(src_ms, dst_ms, dst_resp)``: integer milliseconds for the src
        request and for the dst request, plus the dst response dict. The
        engine-id lookup happens before timing starts.
    """
    src_id = await get_engine_id(client, src_bp)
    transfer_id = f"smoke-xfer-{uuid.uuid4().hex[:8]}"

    t0 = time.monotonic()
    src_resp = await send(
        client, src_port, prompt, max_tokens=1,
        kv_xfer={"do_remote_decode": True, "do_remote_prefill": False,
                 "transfer_id": transfer_id},
        stream=False,
    )
    t1 = time.monotonic()

    bootstrap_addr = f"http://127.0.0.1:{src_bp}"
    dst_resp = await send(
        client, dst_port, prompt, max_tokens=max_tokens,
        kv_xfer={"do_remote_decode": False, "do_remote_prefill": True,
                 "remote_bootstrap_addr": bootstrap_addr,
                 "remote_engine_id": src_id,
                 "transfer_id": transfer_id},
        stream=True,
    )
    t2 = time.monotonic()

    src_ms = int((t1 - t0) * 1000)
    dst_ms = int((t2 - t1) * 1000)
    cached = cached_of(dst_resp)
    print(f" [{label}] src_prefill={src_ms}ms dst_total={dst_ms}ms cached={cached}/{len(prompt)}")
    return src_ms, dst_ms, dst_resp
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Compare migration cost into a pre-warmed dst against a cold dst.

    Builds a base prompt of ``--prefix-tokens`` tokens and an extended prompt
    (base + ``--ext-tokens`` tokens), then:
      0. sends the base prompt to dst_warm twice (cold prefill, then a sanity
         repeat),
      1. sends the extended prompt to src,
      2. migrates the extended prompt src -> dst_warm,
      3. migrates the extended prompt src -> dst_cold as the control,
    and prints both dst timings plus a verdict: "WORKING" when the warm
    destination is more than 30% faster than the cold one.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--src-port", type=int, default=8000)
    p.add_argument("--src-bp", type=int, default=8998)
    p.add_argument("--dst-warm-port", type=int, default=8001)  # will be pre-warmed
    p.add_argument("--dst-warm-bp", type=int, default=8999)
    p.add_argument("--dst-cold-port", type=int, default=8002)  # cold control
    p.add_argument("--dst-cold-bp", type=int, default=9000)
    p.add_argument("--prefix-tokens", type=int, default=32768)
    p.add_argument("--ext-tokens", type=int, default=512)
    args = p.parse_args()

    rng = _r.Random("partial-1")
    prompt_base = [rng.randint(1024, 99_999) for _ in range(args.prefix_tokens)]
    # One generator for the whole extension. Re-seeding a new Random for every
    # element made all extension tokens the same value.
    ext_rng = _r.Random("ext-partial")
    ext_tokens = [ext_rng.randint(1024, 99_999) for _ in range(args.ext_tokens)]
    prompt_ext = prompt_base + ext_tokens

    async with httpx.AsyncClient() as client:
        print(f"\n=== Setup ===")
        print(f"Prompt prefix: {args.prefix_tokens} tokens, extension: {args.ext_tokens} tokens")

        # Step 0: warm dst_warm with prompt_base (normal request, no kv_transfer)
        print(f"\n=== Step 0: warm dst_warm (port {args.dst_warm_port}) with prompt_base ===")
        t0 = time.monotonic()
        r = await send(client, args.dst_warm_port, prompt_base, max_tokens=1,
                       kv_xfer=None, stream=False)
        t1 = time.monotonic()
        print(f" cold prefill on dst_warm: {int((t1-t0)*1000)}ms, cached={cached_of(r)}/{args.prefix_tokens}")

        # Sanity: 2nd request to dst_warm hits local cache
        print(f"\n=== Sanity: 2nd request to dst_warm with same prompt — should hit local cache ===")
        t0 = time.monotonic()
        r = await send(client, args.dst_warm_port, prompt_base, max_tokens=1,
                       kv_xfer=None, stream=False)
        t1 = time.monotonic()
        print(f" warm request on dst_warm: {int((t1-t0)*1000)}ms, cached={cached_of(r)}/{args.prefix_tokens}")

        # Step 1: cold prefill on src with prompt_ext (also caches at src)
        print(f"\n=== Step 1: cold prefill on src (port {args.src_port}) with prompt_ext ===")
        t0 = time.monotonic()
        r = await send(client, args.src_port, prompt_ext, max_tokens=1,
                       kv_xfer=None, stream=False)
        t1 = time.monotonic()
        print(f" cold prefill on src: {int((t1-t0)*1000)}ms, cached={cached_of(r)}/{len(prompt_ext)}")

        # Step 2: MIGRATE src -> dst_warm (which has cache for prompt_base)
        print(f"\n=== Step 2: MIGRATE src -> dst_warm (cache-rich) with prompt_ext ===")
        _, d_ms_warm, _ = await do_migration(
            client, args.src_port, args.dst_warm_port, args.src_bp,
            prompt_ext, max_tokens=4, label="cache-rich dst")

        # Step 3: MIGRATE src -> dst_cold (no cache)
        print(f"\n=== Step 3: MIGRATE src -> dst_cold (port {args.dst_cold_port}, cold) with prompt_ext ===")
        _, d_ms_cold, _ = await do_migration(
            client, args.src_port, args.dst_cold_port, args.src_bp,
            prompt_ext, max_tokens=4, label="cold dst (control)")

        # Verdict
        print(f"\n=== VERDICT ===")
        print(f"Cache-rich dst (Mechanism B): dst_total={d_ms_warm}ms")
        print(f"Cold dst (full transfer): dst_total={d_ms_cold}ms")
        # Relative saving of the warm path, in percent of the cold path.
        speedup = (d_ms_cold - d_ms_warm) / d_ms_cold * 100 if d_ms_cold > 0 else 0
        print(f"Δ = {d_ms_cold - d_ms_warm}ms ({speedup:+.1f}% faster with cache-rich dst)")
        print(f"Partial transfer {'WORKING' if speedup > 30 else 'NOT working / not exploited'}.")


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
255
microbench/connector_tax/cache_sweep/smoke_test_migrate_cache.py
Normal file
255
microbench/connector_tax/cache_sweep/smoke_test_migrate_cache.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Smoke test: does remote-prefill on dst leave its prefix cache discoverable
|
||||||
|
to a follow-up turn on the same instance?
|
||||||
|
|
||||||
|
Reproducer for the v3 rotation bug observed in unified_v3 (next-turn at
|
||||||
|
decode_target sees cached_tokens=0 despite migration's `cache_blocks`
|
||||||
|
supposedly running).
|
||||||
|
|
||||||
|
Expects 2 vLLM instances running on 127.0.0.1:8000 and 8001 with
|
||||||
|
--kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_both"}'
|
||||||
|
and Mooncake bootstrap servers on 8998 and 8999.
|
||||||
|
|
||||||
|
Run flow:
|
||||||
|
1. Query both bootstrap servers for engine_ids.
|
||||||
|
2. Send migration: src=8000 do_remote_decode (max_tokens=1), then
|
||||||
|
dst=8001 do_remote_prefill (pulls KV via Mooncake) with same prompt.
|
||||||
|
3. Send follow-up: same session prompt + tiny extension, hit 8001
|
||||||
|
directly (no kv_transfer_params), check cached_tokens.
|
||||||
|
|
||||||
|
A working migration with prefix-cache visibility would see ~100% cached
|
||||||
|
on the follow-up (full prefix hit). The v3 bug shows cached=0.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
async def get_engine_id(client: httpx.AsyncClient, port: int) -> str:
    """Fetch the engine id from the bootstrap server on local ``port``.

    GETs ``http://127.0.0.1:<port>/query`` and returns ``["0"]["engine_id"]``
    from the JSON body.

    Raises:
        Whatever ``raise_for_status()`` raises on an error status.
        ``KeyError`` if the body lacks the ``"0"`` / ``"engine_id"`` keys.
    """
    url = f"http://127.0.0.1:{port}/query"
    r = await client.get(url)
    r.raise_for_status()
    data = r.json()
    # data = {"0": {"engine_id": "..."}, ...}
    return data["0"]["engine_id"]
|
||||||
|
|
||||||
|
|
||||||
|
async def send_completion(
    client: httpx.AsyncClient,
    host_port: int,
    prompt: list[int],
    max_tokens: int,
    kv_transfer_params: dict | None = None,
    stream: bool = False,
) -> dict:
    """POST one completion request to ``http://127.0.0.1:<host_port>/v1/completions``.

    Args:
        client: async HTTP client used for the request.
        host_port: local port of the target server.
        prompt: token ids sent as the ``prompt`` field.
        max_tokens: value sent as ``max_tokens``.
        kv_transfer_params: sent as ``kv_transfer_params`` when truthy.
        stream: when true, request a streamed response with usage included.

    Returns:
        Non-streaming: the decoded JSON body. Streaming: the last ``data:``
        event that carried a truthy ``usage``, else the last parsed event,
        else ``{}``.

    Raises:
        Whatever ``raise_for_status()`` raises on an error status.
    """
    import codecs  # only the streaming path needs it

    payload = {
        "model": "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct",
        "prompt": prompt,
        "max_tokens": max_tokens,
        # Previously written as `max_tokens if max_tokens == 1 else 1`, which
        # is 1 for every input.
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if stream:
        payload["stream_options"] = {"include_usage": True}
    if kv_transfer_params:
        payload["kv_transfer_params"] = kv_transfer_params
    url = f"http://127.0.0.1:{host_port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    # Stream and collect chunks; keep the LAST chunk that contains
    # `usage` (with include_usage, the very last data: chunk has it).
    last_with_usage = None
    last_any = None

    def _consume(event: str) -> None:
        """Parse one blank-line-delimited event and remember it."""
        nonlocal last_with_usage, last_any
        if not event.startswith("data: "):
            return
        data_str = event[6:].strip()
        if data_str == "[DONE]":
            return
        try:
            parsed = json.loads(data_str)
        except ValueError:
            # Malformed event: skip it, keep reading the stream.
            return
        last_any = parsed
        if isinstance(parsed, dict) and parsed.get("usage"):
            last_with_usage = parsed

    # Incremental decoder: a multi-byte character split across two network
    # chunks must not be turned into replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        buffer = ""
        async for chunk in resp.aiter_bytes():
            buffer += decoder.decode(chunk)
            while "\n\n" in buffer:
                event, buffer = buffer.split("\n\n", 1)
                _consume(event)
        # Flush: the final event may arrive without its trailing blank line.
        buffer += decoder.decode(b"", final=True)
        if buffer.strip():
            _consume(buffer.strip())
    return last_with_usage or last_any or {}
|
||||||
|
|
||||||
|
|
||||||
|
def short(d: dict) -> str:
    """Format a completion response as a one-line summary string.

    Args:
        d: Decoded response dict; may be empty or None.

    Returns:
        ``"no_resp"`` when ``d`` is falsy, otherwise
        ``"cached=<cached>/<prompt_tokens> completion=<completion_tokens> id=<id[:24]>"``.
        ``cached`` is read from ``usage.prompt_tokens_details.cached_tokens``
        and falls back to ``usage.cached_tokens``; missing counters print as 0.
    """
    if not d:
        return "no_resp"
    usage = d.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    cached = details.get("cached_tokens", 0) or usage.get("cached_tokens", 0)
    # An explicit null or non-string id must not crash a diagnostic formatter:
    # slicing None (or an int) raises TypeError.
    raw_id = d.get("id", "?")
    if raw_id is None:
        raw_id = "?"
    request_id = str(raw_id)[:24]
    return (
        f"cached={cached}/{usage.get('prompt_tokens', 0)} "
        f"completion={usage.get('completion_tokens', 0)} "
        f"id={request_id}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Run the migrate-cache smoke test and exit with the verdict.

    Sequence (every request goes through send_completion):
      1.   Send the prompt to src with max_tokens=1 and do_remote_decode=True,
           then send the same prompt to dst with do_remote_prefill=True and
           the same transfer_id.
      1.5  Optionally send unrelated "noise" prompts to dst.
      2.   Send prompt + extension to dst with no kv_transfer_params.
      3.   Repeat the step-2 request.

    The verdict is PASS when the step-2 response reports at least
    95% of --n-prefix-tokens as cached. The process exits with status 0 on
    PASS and 1 on FAIL; invalid arguments exit via argparse (status 2).
    """
    import random as _r
    import time as _t

    p = argparse.ArgumentParser()
    p.add_argument("--src-port", type=int, default=8000)
    p.add_argument("--dst-port", type=int, default=8001)
    p.add_argument("--src-bp", type=int, default=8998)
    p.add_argument("--dst-bp", type=int, default=8999)
    p.add_argument("--n-prefix-tokens", type=int, default=8192,
                   help="Length of synthetic prompt prefix (tokens)")
    p.add_argument("--n-extension", type=int, default=32,
                   help="Tokens added in the follow-up request")
    p.add_argument("--noise-reqs", type=int, default=0,
                   help="Number of unrelated requests to send to dst between "
                        "migration and follow-up (eviction-pressure test)")
    p.add_argument("--noise-tokens", type=int, default=16384,
                   help="Tokens per noise request")
    p.add_argument("--noise-parallel", type=int, default=4,
                   help="How many noise requests in parallel")
    p.add_argument("--prefix-base", type=int, default=100,
                   help="Token-id base for the prompt prefix (use distinct value "
                        "across iterations to avoid prefix-cache pollution).")
    p.add_argument("--decode-tokens", type=int, default=16,
                   help="max_tokens on dst decode request")
    args = p.parse_args()

    # asyncio.Semaphore(0) would block every noise request forever, and a
    # negative value raises ValueError; reject both up front. Only enforced
    # when noise is actually requested so other invocations are unaffected.
    if args.noise_reqs > 0 and args.noise_parallel < 1:
        p.error("--noise-parallel must be >= 1 when --noise-reqs > 0")

    async with httpx.AsyncClient() as client:
        src_id = await get_engine_id(client, args.src_bp)
        dst_id = await get_engine_id(client, args.dst_bp)
        print(f"src engine_id = {src_id}")
        print(f"dst engine_id = {dst_id}")
        print()

        # Deterministic synthetic prompt of raw token ids (no tokenizer needed).
        # --prefix-base only seeds the RNG; ids are always drawn from
        # [1024, 99_999] so large prefix-base values cannot push ids outside
        # that range.
        rng = _r.Random(f"prefix-{args.prefix_base}")
        prompt = [rng.randint(1024, 99_999) for _ in range(args.n_prefix_tokens)]

        # ----- Step 1: Migration. src does prefill (max_tokens=1, no stream),
        # then dst pulls KV and decodes. -----
        transfer_id = f"smoke-xfer-{uuid.uuid4().hex[:8]}"
        print(f"[1] migration: transfer_id={transfer_id}")

        # src is sent do_remote_decode=True. The request is awaited to
        # completion before dst is contacted, i.e. src then dst, strictly
        # in sequence.
        t_src_start = _t.monotonic()
        src_resp = await send_completion(
            client, args.src_port, prompt, max_tokens=1,
            kv_transfer_params={
                "do_remote_decode": True,
                "do_remote_prefill": False,
                "transfer_id": transfer_id,
            },
            stream=False,
        )
        t_src_done = _t.monotonic()
        src_ms = (t_src_done - t_src_start) * 1000
        print(f" src prefill resp ({src_ms:.0f}ms): {short(src_resp)}")

        bootstrap_addr = f"http://127.0.0.1:{args.src_bp}"
        t_dst_start = _t.monotonic()
        dst_resp = await send_completion(
            client, args.dst_port, prompt, max_tokens=args.decode_tokens,
            kv_transfer_params={
                "do_remote_decode": False,
                "do_remote_prefill": True,
                "remote_bootstrap_addr": bootstrap_addr,
                "remote_engine_id": src_id,
                "transfer_id": transfer_id,
            },
            stream=True,
        )
        t_dst_done = _t.monotonic()
        # The dst wall time covers the whole request (KV transfer plus decode
        # of the completion tokens); it is reported as a single total.
        usage_d = dst_resp.get("usage") or {}
        n_completion = usage_d.get("completion_tokens", 0)
        dst_total_ms = (t_dst_done - t_dst_start) * 1000
        print(f" dst decode resp ({dst_total_ms:.0f}ms, {n_completion} completion tokens): {short(dst_resp)}")
        print(f" [TIMING] proto={args.n_prefix_tokens}p src_prefill={int(src_ms)}ms "
              f"dst_total={int(dst_total_ms)}ms (KV xfer + {n_completion}-token decode)")
        print()

        # ----- Step 1.5: Noise. Send unrelated requests to dst to test eviction. -----
        if args.noise_reqs > 0:
            print(f"[1.5] sending {args.noise_reqs} noise requests "
                  f"(tokens={args.noise_tokens}, parallel={args.noise_parallel}) to dst")

            async def noise(idx):
                # Each noise prompt has its own seed, so the prompts differ
                # from each other and from the main prompt.
                rng_n = _r.Random(f"noise-{args.prefix_base}-{idx}")
                p_n = [rng_n.randint(1024, 99_999) for _ in range(args.noise_tokens)]
                return await send_completion(
                    client, args.dst_port, p_n, max_tokens=1,
                    kv_transfer_params=None, stream=False,
                )

            # Cap the number of in-flight noise requests at --noise-parallel.
            sem = asyncio.Semaphore(args.noise_parallel)

            async def gated(idx):
                async with sem:
                    return await noise(idx)

            results = await asyncio.gather(*[gated(i) for i in range(args.noise_reqs)])
            done = sum(1 for r in results if r and r.get("usage"))
            print(f" noise: {done}/{args.noise_reqs} completed")
            print()

        # ----- Step 2: Follow-up. Same session, extended prompt, hit dst directly. -----
        rng_ext = _r.Random(f"ext-{args.prefix_base}")
        follow_prompt = prompt + [rng_ext.randint(1024, 99_999) for _ in range(args.n_extension)]
        print(f"[2] follow-up direct to dst (no kv_transfer_params): "
              f"prefix_len={args.n_prefix_tokens}, extended_len={len(follow_prompt)}")
        fu = await send_completion(
            client, args.dst_port, follow_prompt, max_tokens=4,
            kv_transfer_params=None,
            stream=False,
        )
        print(f" follow-up resp: {short(fu)}")
        print()

        # ----- Step 3: Same prompt twice on dst (sanity) -----
        print(f"[3] sanity: same prompt again to dst (should see local hit "
              f"from step 2's just-cached blocks)")
        sanity = await send_completion(
            client, args.dst_port, follow_prompt, max_tokens=4,
            kv_transfer_params=None,
            stream=False,
        )
        print(f" sanity resp: {short(sanity)}")
        print()

    # ----- Verdict -----
    usage_fu = fu.get("usage") or {}
    details_fu = usage_fu.get("prompt_tokens_details") or {}
    cached_fu = details_fu.get("cached_tokens", 0) or usage_fu.get("cached_tokens", 0)
    # Expect ~n_prefix_tokens (minus the last token + alignment), hence the
    # 5% tolerance.
    expected_min = int(args.n_prefix_tokens * 0.95)
    verdict = "PASS" if cached_fu >= expected_min else "FAIL"
    print(f"=== verdict: {verdict} (follow-up cached={cached_fu}, "
          f"expected >= {expected_min} of {args.n_prefix_tokens}) ===")
    sys.exit(0 if verdict == "PASS" else 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point. main() ends the process itself through sys.exit()
    # with the verdict status, so nothing follows this call.
    asyncio.run(main())
|
||||||
Reference in New Issue
Block a user