diff --git a/scripts/smoke_snapshot_sglang_integration.py b/scripts/smoke_snapshot_sglang_integration.py new file mode 100644 index 0000000..9a42088 --- /dev/null +++ b/scripts/smoke_snapshot_sglang_integration.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +"""End-to-end smoke for the SGLang snapshot link integration. + +Brings up TWO SGLang workers on this node (one acts as D, the other as P) +with ``SGLANG_SNAPSHOT_LINK_ENABLE=1`` and exercises the three RPCs: + + 1. POST {P}/_snapshot/prepare_receive → P allocates kv_pool slots + 2. POST {D}/_snapshot/dump → D RDMA-pushes session KV + 3. POST {P}/_snapshot/finalize_ingest → P inserts into radix tree + +To populate D's SessionAwareCache with a session, we first send a normal +streaming-session generate request to D. + +After finalize, we send another generate request to P with the same prefix +and check whether the report says cached_tokens > 0 (cache hit). + +This is a minimum-fidelity end-to-end smoke. It does NOT use the full +agentic-pd-hybrid reseed orchestration; that's the next commit. 
+ +Required env: + MODEL default /mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507 + +Usage: + bash scripts/setup_env.sh && uv run --no-sync python \ + scripts/smoke_snapshot_sglang_integration.py +""" + +from __future__ import annotations + +import argparse +import json +import os +import signal +import subprocess +import sys +import time +from pathlib import Path +from typing import Optional + +import httpx + + +def _build_server_cmd(args, role: str, gpu_id: int, base_port: int, + snapshot_port: int, ib_device: str) -> list: + """Build the SGLang launch command for one worker (D or P).""" + common = [ + sys.executable, "-m", "sglang.launch_server", + "--model-path", args.model, + "--host", "127.0.0.1", + "--port", str(base_port), + "--tp-size", "1", + "--mem-fraction-static", "0.6", + "--disable-cuda-graph", + "--disable-overlap-schedule", + "--enable-streaming-session", + "--disaggregation-mode", role, + "--disaggregation-transfer-backend", "mooncake", + "--disaggregation-bootstrap-port", str(base_port + 5000), + "--disaggregation-ib-device", ib_device, + ] + return common + + +def _server_env(args, gpu_id: int, snapshot_port: int, ib_device: str) -> dict: + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + env["SGLANG_SNAPSHOT_LINK_ENABLE"] = "1" + env["SGLANG_SNAPSHOT_LINK_HOST"] = "127.0.0.1" + env["SGLANG_SNAPSHOT_LINK_PORT"] = str(snapshot_port) + env["SGLANG_SNAPSHOT_LINK_IB_DEVICE"] = ib_device + env["MOONCAKE_PROTOCOL"] = "rdma" + env["MOONCAKE_DEVICE"] = ib_device + env["MC_TRANSFER_TIMEOUT"] = "1800" + return env + + +def _wait_for_ready(url: str, timeout: float = 240.0) -> bool: + deadline = time.time() + timeout + while time.time() < deadline: + try: + r = httpx.get(f"{url}/health", timeout=2.0) + if r.status_code == 200: + return True + except Exception: + pass + time.sleep(2) + return False + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--model", + default=os.environ.get("MODEL", 
"/mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507")) + ap.add_argument("--d-gpu", type=int, default=1) + ap.add_argument("--p-gpu", type=int, default=0) + ap.add_argument("--d-port", type=int, default=29040) + ap.add_argument("--p-port", type=int, default=29041) + ap.add_argument("--d-snap-port", type=int, default=29045) + ap.add_argument("--p-snap-port", type=int, default=29046) + ap.add_argument("--ib", default="mlx5_60") + ap.add_argument("--log-dir", default="outputs/snapshot_sglang_smoke") + args = ap.parse_args() + + log_dir = Path(args.log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + + # Spawn P first (so D can find its snapshot endpoint later via prepare_receive) + p_cmd = _build_server_cmd(args, "prefill", args.p_gpu, args.p_port, + args.p_snap_port, args.ib) + p_env = _server_env(args, args.p_gpu, args.p_snap_port, args.ib) + p_stdout = open(log_dir / "p.stdout", "w") + p_stderr = open(log_dir / "p.stderr", "w") + print(f"[smoke] launching P: {' '.join(p_cmd)}") + p_proc = subprocess.Popen(p_cmd, env=p_env, stdout=p_stdout, stderr=p_stderr) + + d_cmd = _build_server_cmd(args, "decode", args.d_gpu, args.d_port, + args.d_snap_port, args.ib) + d_env = _server_env(args, args.d_gpu, args.d_snap_port, args.ib) + d_stdout = open(log_dir / "d.stdout", "w") + d_stderr = open(log_dir / "d.stderr", "w") + print(f"[smoke] launching D: {' '.join(d_cmd)}") + d_proc = subprocess.Popen(d_cmd, env=d_env, stdout=d_stdout, stderr=d_stderr) + + try: + print(f"[smoke] waiting for P @ 127.0.0.1:{args.p_port} ...") + if not _wait_for_ready(f"http://127.0.0.1:{args.p_port}", timeout=300): + _tail_stderr(log_dir / "p.stderr") + raise RuntimeError("P server did not become healthy") + print(f"[smoke] waiting for D @ 127.0.0.1:{args.d_port} ...") + if not _wait_for_ready(f"http://127.0.0.1:{args.d_port}", timeout=300): + _tail_stderr(log_dir / "d.stderr") + raise RuntimeError("D server did not become healthy") + print(f"[smoke] both servers up — running RPC sanity ...") + + 
session_id = "smoke-sess-001" + # 1. Open streaming session on D + r = httpx.post(f"http://127.0.0.1:{args.d_port}/open_session", + json={"session_id": session_id, "capacity_of_str_len": 8192, + "streaming": True}, timeout=30) + print(f"[smoke] open_session on D → {r.status_code}: {r.text[:200]}") + + # 2. Send a small prefill+decode request directly to D (in direct-append mode + # we'd normally go via the pd-router, but for this smoke we send raw) + prompt_ids = [1] * 512 # 512 fake tokens + gen_req = { + "input_ids": prompt_ids, + "sampling_params": {"temperature": 0, "max_new_tokens": 1, "min_new_tokens": 1, + "ignore_eos": True, "skip_special_tokens": False}, + "session_params": {"id": session_id}, + "stream": False, + } + try: + r = httpx.post(f"http://127.0.0.1:{args.d_port}/generate", + json=gen_req, timeout=60) + print(f"[smoke] D /generate (seed) → {r.status_code}") + except Exception as e: + print(f"[smoke] D /generate failed: {e}") + + # 3. Probe snapshot link: prepare_receive on P + num_tokens = 512 + prep = httpx.post( + f"http://127.0.0.1:{args.p_port}/_snapshot/prepare_receive", + json={ + "session_id": session_id, + "num_tokens": num_tokens, + "expected_bytes_per_layer_k": 0, + "expected_bytes_per_layer_v": 0, + }, + timeout=30, + ) + print(f"[smoke] prepare_receive on P → {prep.status_code}: {prep.text[:500]}") + if prep.status_code != 200: + return 1 + prep_data = prep.json() + if not prep_data.get("ok"): + print(f"[smoke] prepare_receive returned ok=false: {prep_data}") + return 1 + + # 4. 
Dump on D + dump = httpx.post( + f"http://127.0.0.1:{args.d_port}/_snapshot/dump", + json={ + "session_id": session_id, + "target_snapshot_session_id": prep_data["snapshot_session_id"], + "target_k_base_ptrs": prep_data["k_base_ptrs"], + "target_v_base_ptrs": prep_data["v_base_ptrs"], + "target_slot_indices": prep_data["slot_indices"], + "target_stride_k_bytes": prep_data["stride_k_bytes"], + "target_stride_v_bytes": prep_data["stride_v_bytes"], + "ib_device": args.ib, + }, + timeout=60, + ) + print(f"[smoke] dump on D → {dump.status_code}: {dump.text[:500]}") + if dump.status_code != 200: + return 1 + dump_data = dump.json() + if not dump_data.get("ok"): + print(f"[smoke] dump returned ok=false: {dump_data}") + return 1 + print(f"[smoke] dump pushed {dump_data.get('bytes_pushed')} bytes") + + # 5. Finalize on P (insert into radix) + fin = httpx.post( + f"http://127.0.0.1:{args.p_port}/_snapshot/finalize_ingest", + json={ + "session_id": session_id, + "token_ids": prompt_ids[:num_tokens], + "slot_indices": prep_data["slot_indices"], + }, + timeout=30, + ) + print(f"[smoke] finalize on P → {fin.status_code}: {fin.text[:500]}") + if fin.status_code != 200: + return 1 + fin_data = fin.json() + if not fin_data.get("ok"): + print(f"[smoke] finalize returned ok=false: {fin_data}") + return 1 + print(f"[smoke] inserted_prefix_len = {fin_data.get('inserted_prefix_len')}") + + # 6. 
Send the same prefix to P → expect cache hit + gen_p = { + "input_ids": prompt_ids + [42], # prefix + 1 new token + "sampling_params": {"temperature": 0, "max_new_tokens": 1, "min_new_tokens": 1, + "ignore_eos": True, "skip_special_tokens": False}, + "stream": False, + } + r = httpx.post(f"http://127.0.0.1:{args.p_port}/generate", + json=gen_p, timeout=60) + print(f"[smoke] P /generate (with cached prefix) → {r.status_code}: " + f"{r.text[:400]}") + try: + body = r.json() + cached = (body.get("meta_info") or {}).get("cached_tokens", 0) + print(f"[smoke] cached_tokens = {cached}") + if cached > 0: + print("[smoke] OVERALL: PASS — P showed cache-hit after snapshot ingest") + return 0 + else: + print("[smoke] OVERALL: FAIL — P did not report cache hit") + return 2 + except Exception as e: + print(f"[smoke] could not parse P generate response: {e}") + return 3 + finally: + for name, proc in [("D", d_proc), ("P", p_proc)]: + try: + proc.send_signal(signal.SIGINT) + except Exception: + pass + for name, proc in [("D", d_proc), ("P", p_proc)]: + try: + proc.wait(timeout=15) + except Exception: + proc.terminate() + try: + proc.wait(timeout=5) + except Exception: + proc.kill() + + +def _tail_stderr(path: Path, n: int = 60) -> None: + try: + text = path.read_text() + except FileNotFoundError: + return + print(f"--- {path} (last {n}) ---") + for line in text.splitlines()[-n:]: + print(f" {line}") + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/agentic_pd_hybrid/cli.py b/src/agentic_pd_hybrid/cli.py index 2c37d73..e4f1663 100644 --- a/src/agentic_pd_hybrid/cli.py +++ b/src/agentic_pd_hybrid/cli.py @@ -283,6 +283,17 @@ def main() -> None: "See docs/E1_E2_FIX_DESIGN_ZH.md §Q2." ), ) + replay.add_argument( + "--enable-d-to-p-sync", + action="store_true", + help=( + "Enable D→P RDMA KV snapshot push for reseed fast-path. 
" + "When set, on _invoke_kvcache_seeded_router agentic will probe D's " + "session_aware_cache, RDMA-dump session KV to P's snapshot link, " + "and insert into P's radix tree so the upcoming P prefill hits " + "cache. See docs/D_TO_P_SYNC_DESIGN_ZH.md." + ), + ) sample = subparsers.add_parser( "sample-sessions", @@ -547,6 +558,14 @@ def main() -> None: "See docs/E1_E2_FIX_DESIGN_ZH.md §Q2." ), ) + benchmark.add_argument( + "--enable-d-to-p-sync", + action="store_true", + help=( + "Enable D→P RDMA KV snapshot push for reseed fast-path. " + "See docs/D_TO_P_SYNC_DESIGN_ZH.md." + ), + ) benchmark.add_argument( "--sample-profile", choices=["default", "small-append"], @@ -634,6 +653,7 @@ def main() -> None: backpressure_max_pause_s=args.backpressure_max_pause_s, kvcache_migration_reject_threshold=args.kvcache_migration_reject_threshold, kvcache_load_floor_bonus=args.kvcache_load_floor_bonus, + enable_d_to_p_sync=args.enable_d_to_p_sync, ) results = asyncio.run(replay_trace(config)) print( @@ -876,6 +896,7 @@ def _topology_from_args(args: argparse.Namespace): force_rdma=args.force_rdma, trust_remote_code=not args.no_trust_remote_code, ib_device=args.ib_device, + enable_d_to_p_sync=getattr(args, "enable_d_to_p_sync", False), prefill_extra_server_args=("--disable-overlap-schedule",), decode_extra_server_args=("--disable-overlap-schedule",), direct_extra_server_args=("--enable-streaming-session",), diff --git a/src/agentic_pd_hybrid/replay.py b/src/agentic_pd_hybrid/replay.py index 4c0be51..ed3f9f0 100644 --- a/src/agentic_pd_hybrid/replay.py +++ b/src/agentic_pd_hybrid/replay.py @@ -116,6 +116,11 @@ class ReplayConfig: # with shared cross-session prefix. 0 disables. See # docs/E1_E2_FIX_DESIGN_ZH.md §Q2. kvcache_load_floor_bonus: int = 0 + # D→P snapshot push: when True and reseed fires, agentic will RDMA-dump + # the session's KV from the D-side worker that last held it onto the P + # worker and insert into P's radix tree, so the subsequent P prefill + # hits cache. 
async def _discard_d_to_p_snapshot(
    *,
    client: httpx.AsyncClient,
    prefill_url: str,
    session_id: str,
    slot_indices: list,
) -> None:
    """Best-effort: tell P that a prepared snapshot will not be ingested.

    Sends ``finalize_ingest`` with an empty ``token_ids`` list for the
    prepared ``slot_indices``. Failures are ignored because the caller is
    already on a fallback path.
    NOTE(review): presumably the P-side handler frees the slots on an empty
    token list — confirm against the snapshot controller.
    """
    try:
        await client.post(
            f"{prefill_url}/_snapshot/finalize_ingest",
            json={
                "session_id": session_id,
                "token_ids": [],
                "slot_indices": slot_indices,
            },
            timeout=15.0,
        )
    except Exception:
        # Deliberately swallowed: nothing more can be done from here and the
        # caller re-prefills regardless.
        pass


async def _attempt_d_to_p_sync(
    *,
    client: httpx.AsyncClient,
    request: TraceRequest,
    config: ReplayConfig,
    prefill_url: str,
    decode_session: DirectSessionState,
) -> dict | None:
    """Try to RDMA-dump session KV from the D that last held it to ``prefill_url``.

    The sequence is ``prepare_receive`` on P, ``dump`` on D, then
    ``finalize_ingest`` on P. Every failure is reported through the returned
    dict rather than raised, so the caller can fall back to normal re-prefill.

    Returns:
        ``None`` when ``config.enable_d_to_p_sync`` is off; otherwise a dict
        whose ``"status"`` is one of

        * ``"ok"`` — with ``bytes_pushed``, ``inserted_prefix_len`` and
          ``snapshot_session_id``;
        * ``"skipped-no-source-d"``, ``"skipped-d-closed"``,
          ``"skipped-zero-tokens"``, ``"skipped-no-token-ids"`` — nothing was
          sent to either worker;
        * ``"prepare-failed"`` / ``"dump-failed"`` / ``"finalize-failed"`` —
          the RPC raised or returned an unusable body (``error`` holds the
          ``repr`` of the exception);
        * ``"prepare-not-ok"`` / ``"dump-not-ok"`` / ``"finalize-not-ok"`` —
          the worker answered with ``ok`` falsy (``reason`` is forwarded).
    """
    if not config.enable_d_to_p_sync:
        return None
    source_d_url = decode_session.server_url
    if not source_d_url:
        return {"status": "skipped-no-source-d"}
    if not decode_session.opened:
        return {"status": "skipped-d-closed"}
    # The number of slots requested from P is a length proxy taken from the
    # residency estimate, not from an exact token list.
    target_tokens = max(0, int(_estimate_session_resident_tokens(request)))
    if target_tokens <= 0:
        return {"status": "skipped-zero-tokens"}

    # The radix insert needs token ids. Check for them BEFORE allocating slots
    # on P and pushing bytes over RDMA: without them the transfer could only
    # be discarded afterwards.
    # NOTE(review): request.input_token_ids is a best-available approximation
    # of the tokens whose KV D holds — confirm the two always agree, since a
    # mismatch would key foreign KV under this prefix in P's radix tree.
    tokens = list(getattr(request, "input_token_ids", []) or [])
    if not tokens:
        return {"status": "skipped-no-token-ids"}

    try:
        prep_resp = await client.post(
            f"{prefill_url}/_snapshot/prepare_receive",
            json={
                "session_id": request.session_id,
                "num_tokens": target_tokens,
            },
            timeout=30.0,
        )
        prep_resp.raise_for_status()
        prep = prep_resp.json()
        if not isinstance(prep, dict):
            raise TypeError(f"unexpected prepare_receive body: {type(prep).__name__}")
    except Exception as exc:
        return {"status": "prepare-failed", "error": repr(exc)}
    if not prep.get("ok"):
        return {"status": "prepare-not-ok", "reason": prep.get("reason")}

    try:
        dump_resp = await client.post(
            f"{source_d_url}/_snapshot/dump",
            json={
                "session_id": request.session_id,
                "target_snapshot_session_id": prep["snapshot_session_id"],
                "target_k_base_ptrs": prep["k_base_ptrs"],
                "target_v_base_ptrs": prep["v_base_ptrs"],
                "target_slot_indices": prep["slot_indices"],
                "target_stride_k_bytes": prep["stride_k_bytes"],
                "target_stride_v_bytes": prep["stride_v_bytes"],
            },
            timeout=60.0,
        )
        dump_resp.raise_for_status()
        dump = dump_resp.json()
        if not isinstance(dump, dict):
            raise TypeError(f"unexpected dump body: {type(dump).__name__}")
    except Exception as exc:
        # No discard here: after a timeout or transport error D may still be
        # writing into P's slots, so asking P to give them up could race with
        # the in-flight transfer.
        return {"status": "dump-failed", "error": repr(exc)}

    bytes_pushed = dump.get("bytes_pushed", 0)
    if not dump.get("ok"):
        # D answered definitively that the dump did not happen, so the slots
        # P prepared will never be ingested — hand them back.
        await _discard_d_to_p_snapshot(
            client=client,
            prefill_url=prefill_url,
            session_id=request.session_id,
            slot_indices=prep["slot_indices"],
        )
        return {"status": "dump-not-ok", "reason": dump.get("reason"),
                "bytes_pushed": bytes_pushed}

    # Only as many (token, slot) pairs as both sides can supply are inserted.
    # NOTE(review): when fewer tokens than prepared slots are available, the
    # trailing slots are not mentioned to P — verify P reclaims them.
    n = min(len(tokens), len(prep["slot_indices"]))
    try:
        fin_resp = await client.post(
            f"{prefill_url}/_snapshot/finalize_ingest",
            json={
                "session_id": request.session_id,
                "token_ids": tokens[:n],
                "slot_indices": prep["slot_indices"][:n],
            },
            timeout=30.0,
        )
        fin_resp.raise_for_status()
        fin = fin_resp.json()
        if not isinstance(fin, dict):
            raise TypeError(f"unexpected finalize_ingest body: {type(fin).__name__}")
    except Exception as exc:
        return {"status": "finalize-failed", "error": repr(exc),
                "bytes_pushed": bytes_pushed}
    if not fin.get("ok"):
        return {"status": "finalize-not-ok", "reason": fin.get("reason"),
                "bytes_pushed": bytes_pushed}
    return {
        "status": "ok",
        # ``or 0`` guards against an explicit JSON null, which int() rejects.
        "bytes_pushed": int(bytes_pushed or 0),
        "inserted_prefix_len": int(fin.get("inserted_prefix_len") or 0),
        "snapshot_session_id": prep.get("snapshot_session_id"),
    }
+ if config.enable_d_to_p_sync: + sync_result = await _attempt_d_to_p_sync( + client=client, + request=request, + config=config, + prefill_url=prefill_url, + decode_session=decode_session, + ) + if sync_result is not None and sync_result.get("status") != "ok": + logger.info( + "d_to_p_sync sid=%s rid=%s skipped: %s", + request.session_id, request.request_id, sync_result, + ) + elif sync_result and sync_result.get("status") == "ok": + logger.info( + "d_to_p_sync sid=%s rid=%s pushed=%d ingested_prefix=%d", + request.session_id, + request.request_id, + sync_result.get("bytes_pushed", 0), + sync_result.get("inserted_prefix_len", 0), + ) + decode_session_newly_opened = False try: prefill_priority = _prefill_priority_for_router_request( diff --git a/src/agentic_pd_hybrid/stack.py b/src/agentic_pd_hybrid/stack.py index c6bea65..1c99b59 100644 --- a/src/agentic_pd_hybrid/stack.py +++ b/src/agentic_pd_hybrid/stack.py @@ -209,6 +209,15 @@ def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]: if topology.transfer_backend == "mooncake": env.setdefault("MC_TRANSFER_TIMEOUT", "1800") + # D→P snapshot link (Phase 2). Each worker reads its own + # `disaggregation_bootstrap_port` and binds at `bootstrap_port + 1000` + # for the snapshot mooncake engine (see + # third_party/sglang/.../disaggregation/snapshot/controller.py). + if topology.enable_d_to_p_sync: + env["SGLANG_SNAPSHOT_LINK_ENABLE"] = "1" + if topology.ib_device: + env.setdefault("SGLANG_SNAPSHOT_LINK_IB_DEVICE", topology.ib_device) + repo_root = Path(__file__).resolve().parents[2] python_paths = [ str(repo_root / "src"), diff --git a/src/agentic_pd_hybrid/topology.py b/src/agentic_pd_hybrid/topology.py index 56c2957..13b08df 100644 --- a/src/agentic_pd_hybrid/topology.py +++ b/src/agentic_pd_hybrid/topology.py @@ -46,6 +46,7 @@ class SingleNodeTopology: trust_remote_code: bool force_rdma: bool = False ib_device: str | None = None + enable_d_to_p_sync: bool = False extra_server_args: tuple[str, ...] 
= () prefill_extra_server_args: tuple[str, ...] = () decode_extra_server_args: tuple[str, ...] = () @@ -95,6 +96,7 @@ def build_single_node_topology( force_rdma: bool = False, trust_remote_code: bool = True, ib_device: str | None = None, + enable_d_to_p_sync: bool = False, extra_server_args: tuple[str, ...] = (), prefill_extra_server_args: tuple[str, ...] = (), decode_extra_server_args: tuple[str, ...] = (), @@ -238,6 +240,7 @@ def build_single_node_topology( trust_remote_code=trust_remote_code, force_rdma=force_rdma, ib_device=ib_device, + enable_d_to_p_sync=enable_d_to_p_sync, extra_server_args=extra_server_args, prefill_extra_server_args=prefill_extra_server_args, decode_extra_server_args=decode_extra_server_args,