#!/usr/bin/env python3 """End-to-end smoke for the SGLang snapshot link integration. Brings up TWO SGLang workers on this node (one acts as D, the other as P) with ``SGLANG_SNAPSHOT_LINK_ENABLE=1`` and exercises the three RPCs: 1. POST {P}/_snapshot/prepare_receive → P allocates kv_pool slots 2. POST {D}/_snapshot/dump → D RDMA-pushes session KV 3. POST {P}/_snapshot/finalize_ingest → P inserts into radix tree To populate D's SessionAwareCache with a session, we first send a normal streaming-session generate request to D. After finalize, we send another generate request to P with the same prefix and check whether the report says cached_tokens > 0 (cache hit). This is a minimum-fidelity end-to-end smoke. It does NOT use the full agentic-pd-hybrid reseed orchestration; that's the next commit. Required env: MODEL default /mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507 Usage: bash scripts/setup_env.sh && uv run --no-sync python \ scripts/smoke_snapshot_sglang_integration.py """ from __future__ import annotations import argparse import json import os import signal import subprocess import sys import time from pathlib import Path from typing import Optional import httpx def _build_server_cmd(args, role: str, gpu_id: int, base_port: int, snapshot_port: int, ib_device: str) -> list: """Build the SGLang launch command for one worker (D or P).""" common = [ sys.executable, "-m", "sglang.launch_server", "--model-path", args.model, "--host", "127.0.0.1", "--port", str(base_port), "--tp-size", "1", "--mem-fraction-static", "0.6", "--disable-cuda-graph", "--disable-overlap-schedule", "--enable-streaming-session", "--disaggregation-mode", role, "--disaggregation-transfer-backend", "mooncake", "--disaggregation-bootstrap-port", str(base_port + 5000), "--disaggregation-ib-device", ib_device, ] return common def _server_env(args, gpu_id: int, snapshot_port: int, ib_device: str) -> dict: env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) env["SGLANG_SNAPSHOT_LINK_ENABLE"] = "1" env["SGLANG_SNAPSHOT_LINK_HOST"] = "127.0.0.1" env["SGLANG_SNAPSHOT_LINK_PORT"] = str(snapshot_port) env["SGLANG_SNAPSHOT_LINK_IB_DEVICE"] = ib_device env["MOONCAKE_PROTOCOL"] = "rdma" env["MOONCAKE_DEVICE"] = ib_device env["MC_TRANSFER_TIMEOUT"] = "1800" return env def _wait_for_ready(url: str, timeout: float = 240.0) -> bool: deadline = time.time() + timeout while time.time() < deadline: try: r = httpx.get(f"{url}/health", timeout=2.0) if r.status_code == 200: return True except Exception: pass time.sleep(2) return False def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", default=os.environ.get("MODEL", "/mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507")) ap.add_argument("--d-gpu", type=int, default=1) ap.add_argument("--p-gpu", type=int, default=0) ap.add_argument("--d-port", type=int, default=29040) ap.add_argument("--p-port", type=int, default=29041) ap.add_argument("--d-snap-port", type=int, default=29045) ap.add_argument("--p-snap-port", type=int, default=29046) ap.add_argument("--ib", default="mlx5_60") ap.add_argument("--log-dir", default="outputs/snapshot_sglang_smoke") args = ap.parse_args() log_dir = Path(args.log_dir) log_dir.mkdir(parents=True, exist_ok=True) # Spawn P first (so D can find its snapshot endpoint later via prepare_receive) p_cmd = _build_server_cmd(args, "prefill", args.p_gpu, args.p_port, args.p_snap_port, args.ib) p_env = _server_env(args, args.p_gpu, args.p_snap_port, args.ib) p_stdout = open(log_dir / "p.stdout", "w") p_stderr = open(log_dir / "p.stderr", "w") print(f"[smoke] launching P: {' '.join(p_cmd)}") p_proc = subprocess.Popen(p_cmd, env=p_env, stdout=p_stdout, stderr=p_stderr) d_cmd = _build_server_cmd(args, "decode", args.d_gpu, args.d_port, args.d_snap_port, args.ib) d_env = _server_env(args, args.d_gpu, args.d_snap_port, args.ib) d_stdout = open(log_dir / "d.stdout", "w") d_stderr = open(log_dir / "d.stderr", "w") print(f"[smoke] launching D: {' '.join(d_cmd)}") d_proc = subprocess.Popen(d_cmd, env=d_env, stdout=d_stdout, stderr=d_stderr) try: print(f"[smoke] waiting for P @ 127.0.0.1:{args.p_port} ...") if not _wait_for_ready(f"http://127.0.0.1:{args.p_port}", timeout=300): _tail_stderr(log_dir / "p.stderr") raise RuntimeError("P server did not become healthy") print(f"[smoke] waiting for D @ 127.0.0.1:{args.d_port} ...") if not _wait_for_ready(f"http://127.0.0.1:{args.d_port}", timeout=300): _tail_stderr(log_dir / "d.stderr") raise RuntimeError("D server did not become healthy") print(f"[smoke] both servers up — running RPC sanity ...") session_id = "smoke-sess-001" # NOTE: we deliberately skip seeding a session on D with a real # /generate call. Decode-mode workers crash on raw /generate without # PD-router-provided bootstrap_host (see decode.py:_bootstrap_addr). # The point of this smoke is to verify the 3 snapshot RPCs are # wired up correctly. KV correctness needs the full router stack # (covered by the end-to-end E4 sweep, not here). # 3. Probe snapshot link: prepare_receive on P num_tokens = 64 prep = httpx.post( f"http://127.0.0.1:{args.p_port}/_snapshot/prepare_receive", json={ "session_id": session_id, "num_tokens": num_tokens, "expected_bytes_per_layer_k": 0, "expected_bytes_per_layer_v": 0, }, timeout=30, ) print(f"[smoke] prepare_receive on P → {prep.status_code}: {prep.text[:500]}") if prep.status_code != 200: return 1 prep_data = prep.json() if not prep_data.get("ok"): print(f"[smoke] prepare_receive returned ok=false: {prep_data}") return 1 # 4. Dump on D — expect failure (session-not-resident), proves the # handler is reachable and exits the failure path cleanly. dump = httpx.post( f"http://127.0.0.1:{args.d_port}/_snapshot/dump", json={ "session_id": session_id, "target_snapshot_session_id": prep_data["snapshot_session_id"], "target_k_base_ptrs": prep_data["k_base_ptrs"], "target_v_base_ptrs": prep_data["v_base_ptrs"], "target_slot_indices": prep_data["slot_indices"], "target_stride_k_bytes": prep_data["stride_k_bytes"], "target_stride_v_bytes": prep_data["stride_v_bytes"], "ib_device": args.ib, }, timeout=60, ) print(f"[smoke] dump on D (expected fail) → {dump.status_code}: {dump.text[:500]}") if dump.status_code != 200: return 1 dump_data = dump.json() dump_reason = dump_data.get("reason", "") if dump_data.get("ok"): print("[smoke] unexpected dump success on a session that doesn't exist") elif dump_reason != "session-not-resident": print(f"[smoke] dump failed with wrong reason: {dump_reason}") return 1 # 5. Finalize on P with fake token_ids — radix insert should succeed prompt_ids = list(range(101, 101 + num_tokens)) # fake but unique ids fin = httpx.post( f"http://127.0.0.1:{args.p_port}/_snapshot/finalize_ingest", json={ "session_id": session_id, "token_ids": prompt_ids, "slot_indices": prep_data["slot_indices"], }, timeout=30, ) print(f"[smoke] finalize on P → {fin.status_code}: {fin.text[:500]}") if fin.status_code != 200: return 1 fin_data = fin.json() if not fin_data.get("ok"): print(f"[smoke] finalize returned ok=false: {fin_data}") return 1 print(f"[smoke] inserted_prefix_len = {fin_data.get('inserted_prefix_len')}") print("[smoke] OVERALL: PASS — all 3 RPCs reachable + handlers return expected schema") print(" (KV-correctness end-to-end check requires the full PD router stack;") print(" see scripts/sweep_e4_d_to_p_sync.sh for that)") return 0 finally: for name, proc in [("D", d_proc), ("P", p_proc)]: try: proc.send_signal(signal.SIGINT) except Exception: pass for name, proc in [("D", d_proc), ("P", p_proc)]: try: proc.wait(timeout=15) except Exception: proc.terminate() try: proc.wait(timeout=5) except Exception: proc.kill() def _tail_stderr(path: Path, n: int = 60) -> None: try: text = path.read_text() except FileNotFoundError: return print(f"--- {path} (last {n}) ---") for line in text.splitlines()[-n:]: print(f" {line}") if __name__ == "__main__": sys.exit(main())