225 lines
6.8 KiB
Python
225 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import signal
|
|
import subprocess
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
|
|
from agentic_pd_hybrid.launcher import build_launch_plan
|
|
from agentic_pd_hybrid.topology import SingleNodeTopology
|
|
|
|
|
|
@dataclass
class ManagedProcess:
    """A single launched child process plus the details of how it was started.

    Built by ``_spawn_process``; ``ManagedPdStack.stop`` uses ``process`` to
    signal and reap the child.
    """

    # Role label such as "prefill-0" or "router".
    name: str
    # The exact argv handed to ``subprocess.Popen``.
    command: tuple[str, ...]
    # Handle to the running child (stdout and stderr both go to ``log_path``).
    process: subprocess.Popen[bytes]
    # File receiving the child's combined stdout/stderr.
    log_path: Path
|
|
|
|
|
|
@dataclass
class ManagedPdStack:
    """A launched prefill/decode serving stack and the processes backing it.

    Every managed process is started as its own session / process-group leader
    by ``_spawn_process``, so :meth:`stop` signals whole process groups rather
    than only the leader pid.
    """

    topology: SingleNodeTopology
    run_dir: Path
    # Worker processes per role, in spawn order.
    prefill_processes: list[ManagedProcess]
    decode_processes: list[ManagedProcess]
    direct_processes: list[ManagedProcess]
    # None when no router command was launched.
    router_process: ManagedProcess | None

    @property
    def router_url(self) -> str:
        """The router URL as reported by ``topology``."""
        return self.topology.router_url

    def stop(self, *, grace_period_s: float = 20.0, kill_timeout_s: float = 5.0) -> None:
        """Terminate every managed process group, escalating to SIGKILL.

        Shutdown order is router, then direct, decode and prefill workers.
        Every still-running group receives SIGTERM up front. Processes still
        alive once the shared ``grace_period_s`` budget is spent are sent
        SIGKILL and given ``kill_timeout_s`` more seconds to exit.

        Args:
            grace_period_s: Total seconds to wait for SIGTERM to take effect
                across all processes.
            kill_timeout_s: Seconds to wait for each process after SIGKILL.

        Raises:
            subprocess.TimeoutExpired: if a process is still running after
                SIGKILL. This is raised only after every other process has been
                handled, so one stuck process does not block cleanup of the rest.
        """
        processes = [self.router_process] if self.router_process is not None else []
        processes += [*self.direct_processes, *self.decode_processes, *self.prefill_processes]

        for managed in processes:
            if managed.process.poll() is None:
                self._signal_group(managed, signal.SIGTERM)

        # Monotonic clock: a wall-clock adjustment must not stretch or skip the grace period.
        deadline = time.monotonic() + grace_period_s
        stuck: list[subprocess.TimeoutExpired] = []
        for managed in processes:
            if managed.process.poll() is not None:
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                managed.process.wait(timeout=remaining)
            except subprocess.TimeoutExpired:
                if managed.process.poll() is None:
                    self._signal_group(managed, signal.SIGKILL)
                try:
                    managed.process.wait(timeout=kill_timeout_s)
                except subprocess.TimeoutExpired as exc:
                    # Keep escalating the remaining processes; report afterwards.
                    stuck.append(exc)
        if stuck:
            raise stuck[0]

    @staticmethod
    def _signal_group(managed: ManagedProcess, sig: int) -> None:
        """Send ``sig`` to the process group of ``managed``; no-op if it is gone."""
        try:
            os.killpg(os.getpgid(managed.process.pid), sig)
        except ProcessLookupError:
            # The child exited between the poll() check and here; nothing to signal,
            # and this must not abort cleanup of the remaining processes.
            pass
|
|
|
|
|
|
def launch_pd_stack(
    *,
    topology: SingleNodeTopology,
    run_dir: Path,
    prefill_policy: str,
    decode_policy: str,
    timeout_s: float = 1200.0,
    router_request_timeout_s: float | None = None,
    include_router: bool = True,
) -> ManagedPdStack:
    """Launch a prefill/decode serving stack and block until it is ready.

    Builds the launch plan for ``topology`` and spawns every prefill, decode
    and direct worker. Each worker logs to ``run_dir/logs/<role>-<idx>.log``.
    The function then waits for each worker's ``/v1/models`` endpoint. If the
    plan contains a router command, it finally starts the router (logging to
    ``router.log``) and waits for its ``/health`` endpoint.

    Args:
        topology: Worker and router layout of the stack.
        run_dir: Directory for this run; created (with ``logs/``) if missing.
        prefill_policy: Forwarded to ``build_launch_plan``.
        decode_policy: Forwarded to ``build_launch_plan``.
        timeout_s: Readiness timeout, applied to each endpoint separately.
        router_request_timeout_s: Forwarded to ``build_launch_plan``.
        include_router: Forwarded to ``build_launch_plan``.

    Returns:
        The running stack; call its ``stop()`` method to shut it down.

    Raises:
        TimeoutError: if an endpoint does not become ready within ``timeout_s``.
        Any exception raised while building the plan or spawning a process.
        On any failure (including KeyboardInterrupt), every process started so
        far is stopped before the exception propagates.
    """
    logs_dir = run_dir / "logs"
    # parents=True also creates run_dir itself.
    logs_dir.mkdir(parents=True, exist_ok=True)

    plan = build_launch_plan(
        topology,
        prefill_policy=prefill_policy,
        decode_policy=decode_policy,
        include_router=include_router,
        router_request_timeout_s=router_request_timeout_s,
    )

    # Build the stack first and fill it in place. Then a failure at any point,
    # whether spawning or readiness, can tear down everything started so far.
    stack = ManagedPdStack(
        topology=topology,
        run_dir=run_dir,
        prefill_processes=[],
        decode_processes=[],
        direct_processes=[],
        router_process=None,
    )
    roles = (
        ("prefill", plan.prefill_commands, stack.prefill_processes),
        ("decode", plan.decode_commands, stack.decode_processes),
        ("direct", plan.direct_commands, stack.direct_processes),
    )

    # BaseException: the children run in their own sessions, so a Ctrl-C during
    # the (potentially long) readiness wait would otherwise orphan them.
    try:
        for role, commands, spawned in roles:
            for idx, command in enumerate(commands):
                name = f"{role}-{idx}"
                spawned.append(
                    _spawn_process(
                        name=name,
                        command=command,
                        log_path=logs_dir / f"{name}.log",
                        topology=topology,
                    )
                )

        workers = (
            *topology.prefill_workers,
            *topology.decode_workers,
            *topology.direct_workers,
        )
        for worker in workers:
            _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s)

        if plan.router_command is not None:
            stack.router_process = _spawn_process(
                name="router",
                command=plan.router_command,
                log_path=logs_dir / "router.log",
                topology=topology,
            )
            _wait_for_ready_endpoint(f"{topology.router_url}/health", timeout_s=timeout_s)
    except BaseException:
        stack.stop()
        raise

    return stack
|
|
|
|
|
|
def _spawn_process(
    *,
    name: str,
    command: tuple[str, ...],
    log_path: Path,
    topology: SingleNodeTopology,
) -> ManagedProcess:
    """Start ``command`` in a new session, sending stdout and stderr to ``log_path``.

    Args:
        name: Role label stored on the returned ``ManagedProcess``.
        command: Argv to execute (no shell).
        log_path: File that is truncated and then receives the combined output.
        topology: Used to build the child's environment.

    Returns:
        The ``ManagedProcess`` wrapping the started child.

    Raises:
        OSError: if the log file cannot be opened or the command cannot be started.
    """
    env = _build_process_env(topology)
    # Popen gives the child its own duplicate of the log descriptor. The parent's
    # handle can therefore be closed as soon as Popen returns. The ``with`` also
    # closes it if Popen itself raises (e.g. a missing executable).
    with log_path.open("wb") as log_handle:
        process = subprocess.Popen(
            command,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
            env=env,
            # Same effect as calling os.setsid() in the child, without the
            # thread-safety caveats of preexec_fn. As the group leader, the child
            # can later be signalled together with its process group via killpg.
            start_new_session=True,
        )
    return ManagedProcess(
        name=name,
        command=command,
        process=process,
        log_path=log_path,
    )
|
|
|
|
|
|
def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]:
|
|
env = os.environ.copy()
|
|
env["PYTHONDONTWRITEBYTECODE"] = "1"
|
|
env["PYTHONUNBUFFERED"] = "1"
|
|
|
|
# SGLang's PD bootstrap path uses `requests`; force localhost traffic to stay local.
|
|
for key in (
|
|
"http_proxy",
|
|
"https_proxy",
|
|
"all_proxy",
|
|
"HTTP_PROXY",
|
|
"HTTPS_PROXY",
|
|
"ALL_PROXY",
|
|
):
|
|
env.pop(key, None)
|
|
env["NO_PROXY"] = "*"
|
|
env["no_proxy"] = "*"
|
|
env.setdefault("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600")
|
|
env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "600")
|
|
env.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1")
|
|
if topology.force_rdma:
|
|
env["MOONCAKE_PROTOCOL"] = "rdma"
|
|
env["MC_MS_AUTO_DISC"] = "0"
|
|
if topology.ib_device:
|
|
env["MOONCAKE_DEVICE"] = topology.ib_device
|
|
|
|
repo_root = Path(__file__).resolve().parents[2]
|
|
python_paths = [
|
|
str(repo_root / "src"),
|
|
str(repo_root / "third_party" / "sglang" / "python"),
|
|
]
|
|
existing_pythonpath = env.get("PYTHONPATH")
|
|
if existing_pythonpath:
|
|
python_paths.append(existing_pythonpath)
|
|
env["PYTHONPATH"] = os.pathsep.join(python_paths)
|
|
return env
|
|
|
|
|
|
def _wait_for_ready_endpoint(url: str, *, timeout_s: float) -> None:
    """Poll ``url`` with GET requests until it answers with HTTP 200.

    Proxy settings from the environment are ignored (``trust_env=False``).
    Each request is capped at 5 seconds, and failed attempts are retried after
    a 1-second pause.

    Raises:
        TimeoutError: if no 200 response arrives within ``timeout_s`` seconds.
            The message includes the last status code or exception observed.
    """
    started = time.perf_counter()
    last_error: str | None = None
    with httpx.Client(timeout=5.0, trust_env=False) as client:
        while time.perf_counter() - started < timeout_s:
            try:
                status = client.get(url).status_code
            except Exception as exc:  # pragma: no cover
                last_error = f"{type(exc).__name__}: {exc}"
            else:
                if status == 200:
                    return
                last_error = f"status={status}"
            time.sleep(1.0)
    raise TimeoutError(f"Timed out waiting for {url} ({last_error})")
|