Files
agentic-pd-hybrid/src/agentic_pd_hybrid/stack.py

223 lines
6.7 KiB
Python

from __future__ import annotations
import os
import signal
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
import httpx
from agentic_pd_hybrid.launcher import build_launch_plan
from agentic_pd_hybrid.topology import SingleNodeTopology
@dataclass
class ManagedProcess:
    """Record describing one child process started by this module."""

    # Short label for the process, e.g. "prefill-0", "decode-1" or "router".
    name: str
    # Argument vector that was handed to subprocess.Popen.
    command: tuple[str, ...]
    # Live Popen handle; used to poll, wait on and signal the child.
    process: subprocess.Popen[bytes]
    # File that receives the child's combined stdout and stderr.
    log_path: Path
@dataclass
class ManagedPdStack:
    """Handle to every process launched for one stack, with bulk shutdown.

    The process lists hold the prefill, decode and direct workers;
    ``router_process`` is ``None`` when no router was started.
    """

    topology: SingleNodeTopology
    run_dir: Path
    prefill_processes: list[ManagedProcess]
    decode_processes: list[ManagedProcess]
    direct_processes: list[ManagedProcess]
    router_process: ManagedProcess | None

    @property
    def router_url(self) -> str:
        """Return the router URL recorded on the topology."""
        return self.topology.router_url

    @staticmethod
    def _signal_group(managed: ManagedProcess, sig: int) -> None:
        """Send *sig* to the process group of *managed* if it is still running."""
        if managed.process.poll() is not None:
            return
        try:
            os.killpg(os.getpgid(managed.process.pid), sig)
        except ProcessLookupError:
            # The process exited between poll() and the signal call. It is
            # already gone, so there is nothing left to signal.
            pass

    def stop(self) -> None:
        """Terminate all managed processes, escalating to SIGKILL if needed.

        SIGTERM is sent to every running process group first (router, then
        direct, decode and prefill workers). All processes then share one
        20 second grace period; any process still alive afterwards has its
        group killed with SIGKILL and is given 5 more seconds to be reaped.

        A process that disappears on its own while shutdown is in progress
        is ignored, so one early exit never prevents the remaining processes
        from being stopped.

        Raises:
            subprocess.TimeoutExpired: If a process could not be reaped even
                after SIGKILL. The error is raised only after every other
                process has been handled.
        """
        processes = [
            *([self.router_process] if self.router_process is not None else []),
            *self.direct_processes,
            *self.decode_processes,
            *self.prefill_processes,
        ]
        for managed in processes:
            self._signal_group(managed, signal.SIGTERM)

        # Monotonic clock: the grace period must not stretch or shrink when
        # the wall clock is adjusted.
        deadline = time.monotonic() + 20
        unreaped: subprocess.TimeoutExpired | None = None
        for managed in processes:
            if managed.process.poll() is not None:
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                managed.process.wait(timeout=remaining)
                continue
            except subprocess.TimeoutExpired:
                pass
            self._signal_group(managed, signal.SIGKILL)
            try:
                managed.process.wait(timeout=5)
            except subprocess.TimeoutExpired as exc:
                # Keep going so the other processes are still cleaned up;
                # report the first failure once the loop is finished.
                if unreaped is None:
                    unreaped = exc
        if unreaped is not None:
            raise unreaped
def launch_pd_stack(
    *,
    topology: SingleNodeTopology,
    run_dir: Path,
    prefill_policy: str,
    decode_policy: str,
    timeout_s: float = 1200.0,
    include_router: bool = True,
) -> ManagedPdStack:
    """Start every process in the launch plan and wait until all are ready.

    Worker processes (prefill, decode, direct) are spawned first, each
    logging to ``<run_dir>/logs/<role>-<idx>.log``. Every worker's
    ``/v1/models`` endpoint is then polled until it answers. If the plan
    contains a router command, the router is started last and its
    ``/health`` endpoint is polled.

    If anything fails along the way, including a spawn error or an
    interrupt such as Ctrl-C during the readiness wait, every process
    started so far is stopped before the exception propagates.

    Args:
        topology: Topology describing the workers and the router URL.
        run_dir: Directory for run artifacts; created if missing.
        prefill_policy: Forwarded to ``build_launch_plan``.
        decode_policy: Forwarded to ``build_launch_plan``.
        timeout_s: Readiness timeout applied to each endpoint individually,
            so the total wait can exceed this value.
        include_router: Forwarded to ``build_launch_plan``.

    Returns:
        A ``ManagedPdStack`` owning all started processes.

    Raises:
        TimeoutError: If an endpoint does not become ready in time.
    """
    run_dir.mkdir(parents=True, exist_ok=True)
    logs_dir = run_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    plan = build_launch_plan(
        topology,
        prefill_policy=prefill_policy,
        decode_policy=decode_policy,
        include_router=include_router,
    )

    # Build the handle up front. Its lists are filled in place as processes
    # are spawned, so the cleanup path always sees exactly what was started.
    stack = ManagedPdStack(
        topology=topology,
        run_dir=run_dir,
        prefill_processes=[],
        decode_processes=[],
        direct_processes=[],
        router_process=None,
    )
    worker_groups = (
        ("prefill", plan.prefill_commands, stack.prefill_processes),
        ("decode", plan.decode_commands, stack.decode_processes),
        ("direct", plan.direct_commands, stack.direct_processes),
    )
    try:
        for role, commands, spawned in worker_groups:
            for idx, command in enumerate(commands):
                spawned.append(
                    _spawn_process(
                        name=f"{role}-{idx}",
                        command=command,
                        log_path=logs_dir / f"{role}-{idx}.log",
                        topology=topology,
                    )
                )
        for worker in (
            *topology.prefill_workers,
            *topology.decode_workers,
            *topology.direct_workers,
        ):
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)
        if plan.router_command is not None:
            stack.router_process = _spawn_process(
                name="router",
                command=plan.router_command,
                log_path=logs_dir / "router.log",
                topology=topology,
            )
            _wait_for_ready_endpoint(f"{topology.router_url}/health", timeout_s=timeout_s)
    except BaseException:
        # BaseException on purpose: KeyboardInterrupt / SystemExit during the
        # (potentially very long) readiness wait must not orphan the servers.
        stack.stop()
        raise
    return stack
def _spawn_process(
    *,
    name: str,
    command: tuple[str, ...],
    log_path: Path,
    topology: SingleNodeTopology,
) -> ManagedProcess:
    """Start *command* in its own session with output redirected to a log.

    Args:
        name: Label stored on the returned record.
        command: Argument vector to execute.
        log_path: File that receives the child's stdout and stderr; it is
            truncated if it already exists.
        topology: Used to derive the child's environment.

    Returns:
        A ``ManagedProcess`` wrapping the started child.
    """
    env = _build_process_env(topology)
    # The child gets its own copy of the log descriptor, so the parent's
    # handle can be closed as soon as Popen returns (or raises). Keeping it
    # open would leak one descriptor per spawned process.
    with log_path.open("wb") as log_handle:
        process = subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            env=env,
            # Same effect as preexec_fn=os.setsid (the child becomes a
            # session/process-group leader, so the whole group can be
            # signalled at once) without preexec_fn's thread-safety caveats.
            start_new_session=True,
        )
    return ManagedProcess(
        name=name,
        command=command,
        process=process,
        log_path=log_path,
    )
def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
    """Return the environment mapping used for spawned child processes.

    Starts from a copy of the current environment, then:

    * turns off bytecode writing and output buffering for Python children,
    * removes proxy variables and disables proxying for every host,
    * supplies defaults for a few SGLang / FlashInfer settings without
      overriding values that are already present,
    * sets the Mooncake variables when the topology asks for RDMA and/or
      names an InfiniBand device,
    * prepends the repository's ``src`` and bundled SGLang ``python``
      directories to ``PYTHONPATH``.
    """
    child_env = dict(os.environ)
    child_env.update(PYTHONDONTWRITEBYTECODE="1", PYTHONUNBUFFERED="1")

    # SGLang's PD bootstrap path goes through `requests`, which honours proxy
    # settings; drop them (both spellings) so localhost traffic stays local.
    for proxy_var in ("http_proxy", "https_proxy", "all_proxy"):
        child_env.pop(proxy_var, None)
        child_env.pop(proxy_var.upper(), None)
    child_env.update(NO_PROXY="*", no_proxy="*")

    # Defaults only: anything the caller exported already takes precedence.
    for setting, fallback in (
        ("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600"),
        ("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60"),
        ("FLASHINFER_DISABLE_VERSION_CHECK", "1"),
    ):
        child_env.setdefault(setting, fallback)

    if topology.force_rdma:
        child_env.update(MOONCAKE_PROTOCOL="rdma", MC_MS_AUTO_DISC="0")
    if topology.ib_device:
        child_env["MOONCAKE_DEVICE"] = topology.ib_device

    # This file is expected to sit two directories below the repository root.
    project_root = Path(__file__).resolve().parents[2]
    search_path = [
        str(project_root / "src"),
        str(project_root.joinpath("third_party", "sglang", "python")),
    ]
    inherited = child_env.get("PYTHONPATH")
    if inherited:
        search_path.append(inherited)
    child_env["PYTHONPATH"] = os.pathsep.join(search_path)
    return child_env
def _wait_for_ready_endpoint(url: str, *, timeout_s: float) -> None:
    """Poll *url* about once per second until it responds with HTTP 200.

    Args:
        url: Endpoint to issue GET requests against.
        timeout_s: How long to keep polling, in seconds.

    Raises:
        TimeoutError: If no 200 response arrived within ``timeout_s``. The
            message includes the last observed status or exception.
    """
    began = time.perf_counter()
    failure: str | None = None
    # trust_env=False: ignore proxy settings from the environment.
    with httpx.Client(timeout=5.0, trust_env=False) as http:
        while time.perf_counter() - began < timeout_s:
            try:
                status = http.get(url).status_code
            except Exception as exc:  # pragma: no cover
                # Any failure just means "not ready yet"; remember it for the
                # final error message and try again.
                failure = f"{type(exc).__name__}: {exc}"
            else:
                if status == 200:
                    return
                failure = f"status={status}"
            time.sleep(1.0)
    raise TimeoutError(f"Timed out waiting for {url} ({failure})")