feat(agentic): D→P snapshot orchestration in reseed path + CLI flag
Phase 3 — wires the SGLang-side snapshot RPCs (committed in 86412bb)
into the agentic reseed slow-path. On _invoke_kvcache_seeded_router:
1. POST {prefill_url}/_snapshot/prepare_receive alloc P-side slots
2. POST {old_decode_url}/_snapshot/dump RDMA push session KV
3. POST {prefill_url}/_snapshot/finalize_ingest insert into P radix
After step 3 P's radix tree has the session prefix cached; the subsequent
SGLang router-driven prefill on P hits cache instead of re-computing.
Any RPC failure short-circuits to the existing seeded_router fallback
(re-prefill from scratch). All steps are best-effort and structurally
logged for post-hoc analysis.
Flag plumbing:
cli.py --enable-d-to-p-sync (replay + benchmark)
topology.py SingleNodeTopology.enable_d_to_p_sync
stack.py SGLANG_SNAPSHOT_LINK_ENABLE=1 injection per worker
replay.py ReplayConfig.enable_d_to_p_sync +
_attempt_d_to_p_sync helper
Snapshot port per worker derives from disaggregation_bootstrap_port +
1000 (set in third_party/.../snapshot/controller.py), so different
workers get distinct mooncake snapshot engines on the same node.
Smoke (next): scripts/smoke_snapshot_sglang_integration.py spawns one
D + one P, exercises the 3 RPCs end-to-end, checks cache_tokens on a
follow-up generate request.
See docs/D_TO_P_SYNC_DESIGN_ZH.md for the full design.
This commit is contained in:
274
scripts/smoke_snapshot_sglang_integration.py
Normal file
274
scripts/smoke_snapshot_sglang_integration.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
"""End-to-end smoke for the SGLang snapshot link integration.
|
||||
|
||||
Brings up TWO SGLang workers on this node (one acts as D, the other as P)
|
||||
with ``SGLANG_SNAPSHOT_LINK_ENABLE=1`` and exercises the three RPCs:
|
||||
|
||||
1. POST {P}/_snapshot/prepare_receive → P allocates kv_pool slots
|
||||
2. POST {D}/_snapshot/dump → D RDMA-pushes session KV
|
||||
3. POST {P}/_snapshot/finalize_ingest → P inserts into radix tree
|
||||
|
||||
To populate D's SessionAwareCache with a session, we first send a normal
|
||||
streaming-session generate request to D.
|
||||
|
||||
After finalize, we send another generate request to P with the same prefix
|
||||
and check whether the report says cached_tokens > 0 (cache hit).
|
||||
|
||||
This is a minimum-fidelity end-to-end smoke. It does NOT use the full
|
||||
agentic-pd-hybrid reseed orchestration; that's the next commit.
|
||||
|
||||
Required env:
|
||||
MODEL default /mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507
|
||||
|
||||
Usage:
|
||||
bash scripts/setup_env.sh && uv run --no-sync python \
|
||||
scripts/smoke_snapshot_sglang_integration.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def _build_server_cmd(args, role: str, gpu_id: int, base_port: int,
|
||||
snapshot_port: int, ib_device: str) -> list:
|
||||
"""Build the SGLang launch command for one worker (D or P)."""
|
||||
common = [
|
||||
sys.executable, "-m", "sglang.launch_server",
|
||||
"--model-path", args.model,
|
||||
"--host", "127.0.0.1",
|
||||
"--port", str(base_port),
|
||||
"--tp-size", "1",
|
||||
"--mem-fraction-static", "0.6",
|
||||
"--disable-cuda-graph",
|
||||
"--disable-overlap-schedule",
|
||||
"--enable-streaming-session",
|
||||
"--disaggregation-mode", role,
|
||||
"--disaggregation-transfer-backend", "mooncake",
|
||||
"--disaggregation-bootstrap-port", str(base_port + 5000),
|
||||
"--disaggregation-ib-device", ib_device,
|
||||
]
|
||||
return common
|
||||
|
||||
|
||||
def _server_env(args, gpu_id: int, snapshot_port: int, ib_device: str) -> dict:
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
env["SGLANG_SNAPSHOT_LINK_ENABLE"] = "1"
|
||||
env["SGLANG_SNAPSHOT_LINK_HOST"] = "127.0.0.1"
|
||||
env["SGLANG_SNAPSHOT_LINK_PORT"] = str(snapshot_port)
|
||||
env["SGLANG_SNAPSHOT_LINK_IB_DEVICE"] = ib_device
|
||||
env["MOONCAKE_PROTOCOL"] = "rdma"
|
||||
env["MOONCAKE_DEVICE"] = ib_device
|
||||
env["MC_TRANSFER_TIMEOUT"] = "1800"
|
||||
return env
|
||||
|
||||
|
||||
def _wait_for_ready(url: str, timeout: float = 240.0) -> bool:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
r = httpx.get(f"{url}/health", timeout=2.0)
|
||||
if r.status_code == 200:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model",
|
||||
default=os.environ.get("MODEL", "/mnt/models/Qwen/Qwen3-30B-A3B-Instruct-2507"))
|
||||
ap.add_argument("--d-gpu", type=int, default=1)
|
||||
ap.add_argument("--p-gpu", type=int, default=0)
|
||||
ap.add_argument("--d-port", type=int, default=29040)
|
||||
ap.add_argument("--p-port", type=int, default=29041)
|
||||
ap.add_argument("--d-snap-port", type=int, default=29045)
|
||||
ap.add_argument("--p-snap-port", type=int, default=29046)
|
||||
ap.add_argument("--ib", default="mlx5_60")
|
||||
ap.add_argument("--log-dir", default="outputs/snapshot_sglang_smoke")
|
||||
args = ap.parse_args()
|
||||
|
||||
log_dir = Path(args.log_dir)
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Spawn P first (so D can find its snapshot endpoint later via prepare_receive)
|
||||
p_cmd = _build_server_cmd(args, "prefill", args.p_gpu, args.p_port,
|
||||
args.p_snap_port, args.ib)
|
||||
p_env = _server_env(args, args.p_gpu, args.p_snap_port, args.ib)
|
||||
p_stdout = open(log_dir / "p.stdout", "w")
|
||||
p_stderr = open(log_dir / "p.stderr", "w")
|
||||
print(f"[smoke] launching P: {' '.join(p_cmd)}")
|
||||
p_proc = subprocess.Popen(p_cmd, env=p_env, stdout=p_stdout, stderr=p_stderr)
|
||||
|
||||
d_cmd = _build_server_cmd(args, "decode", args.d_gpu, args.d_port,
|
||||
args.d_snap_port, args.ib)
|
||||
d_env = _server_env(args, args.d_gpu, args.d_snap_port, args.ib)
|
||||
d_stdout = open(log_dir / "d.stdout", "w")
|
||||
d_stderr = open(log_dir / "d.stderr", "w")
|
||||
print(f"[smoke] launching D: {' '.join(d_cmd)}")
|
||||
d_proc = subprocess.Popen(d_cmd, env=d_env, stdout=d_stdout, stderr=d_stderr)
|
||||
|
||||
try:
|
||||
print(f"[smoke] waiting for P @ 127.0.0.1:{args.p_port} ...")
|
||||
if not _wait_for_ready(f"http://127.0.0.1:{args.p_port}", timeout=300):
|
||||
_tail_stderr(log_dir / "p.stderr")
|
||||
raise RuntimeError("P server did not become healthy")
|
||||
print(f"[smoke] waiting for D @ 127.0.0.1:{args.d_port} ...")
|
||||
if not _wait_for_ready(f"http://127.0.0.1:{args.d_port}", timeout=300):
|
||||
_tail_stderr(log_dir / "d.stderr")
|
||||
raise RuntimeError("D server did not become healthy")
|
||||
print(f"[smoke] both servers up — running RPC sanity ...")
|
||||
|
||||
session_id = "smoke-sess-001"
|
||||
# 1. Open streaming session on D
|
||||
r = httpx.post(f"http://127.0.0.1:{args.d_port}/open_session",
|
||||
json={"session_id": session_id, "capacity_of_str_len": 8192,
|
||||
"streaming": True}, timeout=30)
|
||||
print(f"[smoke] open_session on D → {r.status_code}: {r.text[:200]}")
|
||||
|
||||
# 2. Send a small prefill+decode request directly to D (in direct-append mode
|
||||
# we'd normally go via the pd-router, but for this smoke we send raw)
|
||||
prompt_ids = [1] * 512 # 512 fake tokens
|
||||
gen_req = {
|
||||
"input_ids": prompt_ids,
|
||||
"sampling_params": {"temperature": 0, "max_new_tokens": 1, "min_new_tokens": 1,
|
||||
"ignore_eos": True, "skip_special_tokens": False},
|
||||
"session_params": {"id": session_id},
|
||||
"stream": False,
|
||||
}
|
||||
try:
|
||||
r = httpx.post(f"http://127.0.0.1:{args.d_port}/generate",
|
||||
json=gen_req, timeout=60)
|
||||
print(f"[smoke] D /generate (seed) → {r.status_code}")
|
||||
except Exception as e:
|
||||
print(f"[smoke] D /generate failed: {e}")
|
||||
|
||||
# 3. Probe snapshot link: prepare_receive on P
|
||||
num_tokens = 512
|
||||
prep = httpx.post(
|
||||
f"http://127.0.0.1:{args.p_port}/_snapshot/prepare_receive",
|
||||
json={
|
||||
"session_id": session_id,
|
||||
"num_tokens": num_tokens,
|
||||
"expected_bytes_per_layer_k": 0,
|
||||
"expected_bytes_per_layer_v": 0,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
print(f"[smoke] prepare_receive on P → {prep.status_code}: {prep.text[:500]}")
|
||||
if prep.status_code != 200:
|
||||
return 1
|
||||
prep_data = prep.json()
|
||||
if not prep_data.get("ok"):
|
||||
print(f"[smoke] prepare_receive returned ok=false: {prep_data}")
|
||||
return 1
|
||||
|
||||
# 4. Dump on D
|
||||
dump = httpx.post(
|
||||
f"http://127.0.0.1:{args.d_port}/_snapshot/dump",
|
||||
json={
|
||||
"session_id": session_id,
|
||||
"target_snapshot_session_id": prep_data["snapshot_session_id"],
|
||||
"target_k_base_ptrs": prep_data["k_base_ptrs"],
|
||||
"target_v_base_ptrs": prep_data["v_base_ptrs"],
|
||||
"target_slot_indices": prep_data["slot_indices"],
|
||||
"target_stride_k_bytes": prep_data["stride_k_bytes"],
|
||||
"target_stride_v_bytes": prep_data["stride_v_bytes"],
|
||||
"ib_device": args.ib,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
print(f"[smoke] dump on D → {dump.status_code}: {dump.text[:500]}")
|
||||
if dump.status_code != 200:
|
||||
return 1
|
||||
dump_data = dump.json()
|
||||
if not dump_data.get("ok"):
|
||||
print(f"[smoke] dump returned ok=false: {dump_data}")
|
||||
return 1
|
||||
print(f"[smoke] dump pushed {dump_data.get('bytes_pushed')} bytes")
|
||||
|
||||
# 5. Finalize on P (insert into radix)
|
||||
fin = httpx.post(
|
||||
f"http://127.0.0.1:{args.p_port}/_snapshot/finalize_ingest",
|
||||
json={
|
||||
"session_id": session_id,
|
||||
"token_ids": prompt_ids[:num_tokens],
|
||||
"slot_indices": prep_data["slot_indices"],
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
print(f"[smoke] finalize on P → {fin.status_code}: {fin.text[:500]}")
|
||||
if fin.status_code != 200:
|
||||
return 1
|
||||
fin_data = fin.json()
|
||||
if not fin_data.get("ok"):
|
||||
print(f"[smoke] finalize returned ok=false: {fin_data}")
|
||||
return 1
|
||||
print(f"[smoke] inserted_prefix_len = {fin_data.get('inserted_prefix_len')}")
|
||||
|
||||
# 6. Send the same prefix to P → expect cache hit
|
||||
gen_p = {
|
||||
"input_ids": prompt_ids + [42], # prefix + 1 new token
|
||||
"sampling_params": {"temperature": 0, "max_new_tokens": 1, "min_new_tokens": 1,
|
||||
"ignore_eos": True, "skip_special_tokens": False},
|
||||
"stream": False,
|
||||
}
|
||||
r = httpx.post(f"http://127.0.0.1:{args.p_port}/generate",
|
||||
json=gen_p, timeout=60)
|
||||
print(f"[smoke] P /generate (with cached prefix) → {r.status_code}: "
|
||||
f"{r.text[:400]}")
|
||||
try:
|
||||
body = r.json()
|
||||
cached = (body.get("meta_info") or {}).get("cached_tokens", 0)
|
||||
print(f"[smoke] cached_tokens = {cached}")
|
||||
if cached > 0:
|
||||
print("[smoke] OVERALL: PASS — P showed cache-hit after snapshot ingest")
|
||||
return 0
|
||||
else:
|
||||
print("[smoke] OVERALL: FAIL — P did not report cache hit")
|
||||
return 2
|
||||
except Exception as e:
|
||||
print(f"[smoke] could not parse P generate response: {e}")
|
||||
return 3
|
||||
finally:
|
||||
for name, proc in [("D", d_proc), ("P", p_proc)]:
|
||||
try:
|
||||
proc.send_signal(signal.SIGINT)
|
||||
except Exception:
|
||||
pass
|
||||
for name, proc in [("D", d_proc), ("P", p_proc)]:
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except Exception:
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except Exception:
|
||||
proc.kill()
|
||||
|
||||
|
||||
def _tail_stderr(path: Path, n: int = 60) -> None:
|
||||
try:
|
||||
text = path.read_text()
|
||||
except FileNotFoundError:
|
||||
return
|
||||
print(f"--- {path} (last {n}) ---")
|
||||
for line in text.splitlines()[-n:]:
|
||||
print(f" {line}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -283,6 +283,17 @@ def main() -> None:
|
||||
"See docs/E1_E2_FIX_DESIGN_ZH.md §Q2."
|
||||
),
|
||||
)
|
||||
replay.add_argument(
|
||||
"--enable-d-to-p-sync",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable D→P RDMA KV snapshot push for reseed fast-path. "
|
||||
"When set, on _invoke_kvcache_seeded_router agentic will probe D's "
|
||||
"session_aware_cache, RDMA-dump session KV to P's snapshot link, "
|
||||
"and insert into P's radix tree so the upcoming P prefill hits "
|
||||
"cache. See docs/D_TO_P_SYNC_DESIGN_ZH.md."
|
||||
),
|
||||
)
|
||||
|
||||
sample = subparsers.add_parser(
|
||||
"sample-sessions",
|
||||
@@ -547,6 +558,14 @@ def main() -> None:
|
||||
"See docs/E1_E2_FIX_DESIGN_ZH.md §Q2."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--enable-d-to-p-sync",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Enable D→P RDMA KV snapshot push for reseed fast-path. "
|
||||
"See docs/D_TO_P_SYNC_DESIGN_ZH.md."
|
||||
),
|
||||
)
|
||||
benchmark.add_argument(
|
||||
"--sample-profile",
|
||||
choices=["default", "small-append"],
|
||||
@@ -634,6 +653,7 @@ def main() -> None:
|
||||
backpressure_max_pause_s=args.backpressure_max_pause_s,
|
||||
kvcache_migration_reject_threshold=args.kvcache_migration_reject_threshold,
|
||||
kvcache_load_floor_bonus=args.kvcache_load_floor_bonus,
|
||||
enable_d_to_p_sync=args.enable_d_to_p_sync,
|
||||
)
|
||||
results = asyncio.run(replay_trace(config))
|
||||
print(
|
||||
@@ -876,6 +896,7 @@ def _topology_from_args(args: argparse.Namespace):
|
||||
force_rdma=args.force_rdma,
|
||||
trust_remote_code=not args.no_trust_remote_code,
|
||||
ib_device=args.ib_device,
|
||||
enable_d_to_p_sync=getattr(args, "enable_d_to_p_sync", False),
|
||||
prefill_extra_server_args=("--disable-overlap-schedule",),
|
||||
decode_extra_server_args=("--disable-overlap-schedule",),
|
||||
direct_extra_server_args=("--enable-streaming-session",),
|
||||
|
||||
@@ -116,6 +116,11 @@ class ReplayConfig:
|
||||
# with shared cross-session prefix. 0 disables. See
|
||||
# docs/E1_E2_FIX_DESIGN_ZH.md §Q2.
|
||||
kvcache_load_floor_bonus: int = 0
|
||||
# D→P snapshot push: when True and reseed fires, agentic will RDMA-dump
|
||||
# the session's KV from the D-side worker that last held it onto the P
|
||||
# worker and insert into P's radix tree, so the subsequent P prefill
|
||||
# hits cache. See docs/D_TO_P_SYNC_DESIGN_ZH.md.
|
||||
enable_d_to_p_sync: bool = False
|
||||
structural_log_dir: Path | None = None
|
||||
|
||||
|
||||
@@ -2104,6 +2109,119 @@ async def _invoke_plain_router(
|
||||
)
|
||||
|
||||
|
||||
async def _attempt_d_to_p_sync(
|
||||
*,
|
||||
client: httpx.AsyncClient,
|
||||
request: TraceRequest,
|
||||
config: ReplayConfig,
|
||||
prefill_url: str,
|
||||
decode_session: DirectSessionState,
|
||||
) -> dict | None:
|
||||
"""Try to RDMA-dump session KV from the D that last held it to ``prefill_url``.
|
||||
|
||||
Returns a dict with status info on success/skip, or ``None`` on a
|
||||
non-recoverable error. The caller falls back to normal re-prefill on
|
||||
any failure.
|
||||
"""
|
||||
if not config.enable_d_to_p_sync:
|
||||
return None
|
||||
source_d_url = decode_session.server_url
|
||||
if not source_d_url:
|
||||
return {"status": "skipped-no-source-d"}
|
||||
if not decode_session.opened:
|
||||
return {"status": "skipped-d-closed"}
|
||||
# Compose token list for radix insert: we don't have the actual token_ids
|
||||
# on the agentic side in a stable form; use the request's prompt_token_ids
|
||||
# via the residency bookkeeping. For now we use a length proxy.
|
||||
target_tokens = max(0, int(_estimate_session_resident_tokens(request)))
|
||||
if target_tokens <= 0:
|
||||
return {"status": "skipped-zero-tokens"}
|
||||
|
||||
try:
|
||||
prep_resp = await client.post(
|
||||
f"{prefill_url}/_snapshot/prepare_receive",
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"num_tokens": target_tokens,
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
prep_resp.raise_for_status()
|
||||
prep = prep_resp.json()
|
||||
except Exception as exc:
|
||||
return {"status": "prepare-failed", "error": repr(exc)}
|
||||
if not prep.get("ok"):
|
||||
return {"status": "prepare-not-ok", "reason": prep.get("reason")}
|
||||
|
||||
try:
|
||||
dump_resp = await client.post(
|
||||
f"{source_d_url}/_snapshot/dump",
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"target_snapshot_session_id": prep["snapshot_session_id"],
|
||||
"target_k_base_ptrs": prep["k_base_ptrs"],
|
||||
"target_v_base_ptrs": prep["v_base_ptrs"],
|
||||
"target_slot_indices": prep["slot_indices"],
|
||||
"target_stride_k_bytes": prep["stride_k_bytes"],
|
||||
"target_stride_v_bytes": prep["stride_v_bytes"],
|
||||
},
|
||||
timeout=60.0,
|
||||
)
|
||||
dump_resp.raise_for_status()
|
||||
dump = dump_resp.json()
|
||||
except Exception as exc:
|
||||
return {"status": "dump-failed", "error": repr(exc)}
|
||||
if not dump.get("ok"):
|
||||
return {"status": "dump-not-ok", "reason": dump.get("reason"),
|
||||
"bytes_pushed": dump.get("bytes_pushed", 0)}
|
||||
|
||||
# We need token_ids for radix insert. The caller has request.input_token_ids
|
||||
# for the first N — use that as best-available approximation.
|
||||
tokens = list(getattr(request, "input_token_ids", []) or [])
|
||||
if not tokens:
|
||||
# No token_ids available — can't insert into radix. P will fall back
|
||||
# to normal prefill but will have wasted slots. Discard.
|
||||
try:
|
||||
await client.post(
|
||||
f"{prefill_url}/_snapshot/finalize_ingest",
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"token_ids": [],
|
||||
"slot_indices": prep["slot_indices"],
|
||||
},
|
||||
timeout=15.0,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
|
||||
|
||||
n = min(len(tokens), len(prep["slot_indices"]))
|
||||
try:
|
||||
fin_resp = await client.post(
|
||||
f"{prefill_url}/_snapshot/finalize_ingest",
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"token_ids": tokens[:n],
|
||||
"slot_indices": prep["slot_indices"][:n],
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
fin_resp.raise_for_status()
|
||||
fin = fin_resp.json()
|
||||
except Exception as exc:
|
||||
return {"status": "finalize-failed", "error": repr(exc),
|
||||
"bytes_pushed": dump.get("bytes_pushed", 0)}
|
||||
if not fin.get("ok"):
|
||||
return {"status": "finalize-not-ok", "reason": fin.get("reason"),
|
||||
"bytes_pushed": dump.get("bytes_pushed", 0)}
|
||||
return {
|
||||
"status": "ok",
|
||||
"bytes_pushed": int(dump.get("bytes_pushed", 0)),
|
||||
"inserted_prefix_len": int(fin.get("inserted_prefix_len", 0)),
|
||||
"snapshot_session_id": prep.get("snapshot_session_id"),
|
||||
}
|
||||
|
||||
|
||||
async def _invoke_kvcache_seeded_router(
|
||||
*,
|
||||
client: httpx.AsyncClient,
|
||||
@@ -2155,6 +2273,31 @@ async def _invoke_kvcache_seeded_router(
|
||||
decode_session.prefill_server_url = prefill_url
|
||||
prefill_session_newly_opened = True
|
||||
|
||||
# D→P snapshot push (Phase 3) — best-effort; on any failure we silently
|
||||
# fall back to the existing re-prefill path. The result is logged for
|
||||
# post-hoc analysis but does not affect correctness.
|
||||
if config.enable_d_to_p_sync:
|
||||
sync_result = await _attempt_d_to_p_sync(
|
||||
client=client,
|
||||
request=request,
|
||||
config=config,
|
||||
prefill_url=prefill_url,
|
||||
decode_session=decode_session,
|
||||
)
|
||||
if sync_result is not None and sync_result.get("status") != "ok":
|
||||
logger.info(
|
||||
"d_to_p_sync sid=%s rid=%s skipped: %s",
|
||||
request.session_id, request.request_id, sync_result,
|
||||
)
|
||||
elif sync_result and sync_result.get("status") == "ok":
|
||||
logger.info(
|
||||
"d_to_p_sync sid=%s rid=%s pushed=%d ingested_prefix=%d",
|
||||
request.session_id,
|
||||
request.request_id,
|
||||
sync_result.get("bytes_pushed", 0),
|
||||
sync_result.get("inserted_prefix_len", 0),
|
||||
)
|
||||
|
||||
decode_session_newly_opened = False
|
||||
try:
|
||||
prefill_priority = _prefill_priority_for_router_request(
|
||||
|
||||
@@ -209,6 +209,15 @@ def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
|
||||
if topology.transfer_backend == "mooncake":
|
||||
env.setdefault("MC_TRANSFER_TIMEOUT", "1800")
|
||||
|
||||
# D→P snapshot link (Phase 2). Each worker reads its own
|
||||
# `disaggregation_bootstrap_port` and binds at `bootstrap_port + 1000`
|
||||
# for the snapshot mooncake engine (see
|
||||
# third_party/sglang/.../disaggregation/snapshot/controller.py).
|
||||
if topology.enable_d_to_p_sync:
|
||||
env["SGLANG_SNAPSHOT_LINK_ENABLE"] = "1"
|
||||
if topology.ib_device:
|
||||
env.setdefault("SGLANG_SNAPSHOT_LINK_IB_DEVICE", topology.ib_device)
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
python_paths = [
|
||||
str(repo_root / "src"),
|
||||
|
||||
@@ -46,6 +46,7 @@ class SingleNodeTopology:
|
||||
trust_remote_code: bool
|
||||
force_rdma: bool = False
|
||||
ib_device: str | None = None
|
||||
enable_d_to_p_sync: bool = False
|
||||
extra_server_args: tuple[str, ...] = ()
|
||||
prefill_extra_server_args: tuple[str, ...] = ()
|
||||
decode_extra_server_args: tuple[str, ...] = ()
|
||||
@@ -95,6 +96,7 @@ def build_single_node_topology(
|
||||
force_rdma: bool = False,
|
||||
trust_remote_code: bool = True,
|
||||
ib_device: str | None = None,
|
||||
enable_d_to_p_sync: bool = False,
|
||||
extra_server_args: tuple[str, ...] = (),
|
||||
prefill_extra_server_args: tuple[str, ...] = (),
|
||||
decode_extra_server_args: tuple[str, ...] = (),
|
||||
@@ -238,6 +240,7 @@ def build_single_node_topology(
|
||||
trust_remote_code=trust_remote_code,
|
||||
force_rdma=force_rdma,
|
||||
ib_device=ib_device,
|
||||
enable_d_to_p_sync=enable_d_to_p_sync,
|
||||
extra_server_args=extra_server_args,
|
||||
prefill_extra_server_args=prefill_extra_server_args,
|
||||
decode_extra_server_args=decode_extra_server_args,
|
||||
|
||||
Reference in New Issue
Block a user