from __future__ import annotations import os import signal import subprocess import time from dataclasses import dataclass from pathlib import Path import httpx from agentic_pd_hybrid.launcher import build_launch_plan from agentic_pd_hybrid.topology import SingleNodeTopology @dataclass class ManagedProcess: name: str command: tuple[str, ...] process: subprocess.Popen[bytes] log_path: Path @dataclass class ManagedPdStack: topology: SingleNodeTopology run_dir: Path prefill_processes: list[ManagedProcess] decode_processes: list[ManagedProcess] direct_processes: list[ManagedProcess] router_process: ManagedProcess | None @property def router_url(self) -> str: return self.topology.router_url def stop(self) -> None: processes = ( ([self.router_process] if self.router_process is not None else []) + self.direct_processes + self.decode_processes + self.prefill_processes ) for managed in processes: if managed.process.poll() is None: os.killpg(os.getpgid(managed.process.pid), signal.SIGTERM) deadline = time.time() + 20 for managed in processes: if managed.process.poll() is not None: continue remaining = max(0.0, deadline - time.time()) try: managed.process.wait(timeout=remaining) except subprocess.TimeoutExpired: if managed.process.poll() is None: os.killpg(os.getpgid(managed.process.pid), signal.SIGKILL) managed.process.wait(timeout=5) def launch_pd_stack( *, topology: SingleNodeTopology, run_dir: Path, prefill_policy: str, decode_policy: str, timeout_s: float = 1200.0, include_router: bool = True, ) -> ManagedPdStack: run_dir.mkdir(parents=True, exist_ok=True) logs_dir = run_dir / "logs" logs_dir.mkdir(parents=True, exist_ok=True) plan = build_launch_plan( topology, prefill_policy=prefill_policy, decode_policy=decode_policy, include_router=include_router, ) prefill_processes = [ _spawn_process( name=f"prefill-{idx}", command=command, log_path=logs_dir / f"prefill-{idx}.log", topology=topology, ) for idx, command in enumerate(plan.prefill_commands) ] decode_processes = [ _spawn_process( name=f"decode-{idx}", command=command, log_path=logs_dir / f"decode-{idx}.log", topology=topology, ) for idx, command in enumerate(plan.decode_commands) ] direct_processes = [ _spawn_process( name=f"direct-{idx}", command=command, log_path=logs_dir / f"direct-{idx}.log", topology=topology, ) for idx, command in enumerate(plan.direct_commands) ] router_process: ManagedProcess | None = None try: for worker in topology.prefill_workers: _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s) for worker in topology.decode_workers: _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s) for worker in topology.direct_workers: _wait_for_ready_endpoint(f"{worker.url}/v1/models", timeout_s=timeout_s) if plan.router_command is not None: router_process = _spawn_process( name="router", command=plan.router_command, log_path=logs_dir / "router.log", topology=topology, ) _wait_for_ready_endpoint(f"{topology.router_url}/health", timeout_s=timeout_s) except Exception: stack = ManagedPdStack( topology=topology, run_dir=run_dir, prefill_processes=prefill_processes, decode_processes=decode_processes, direct_processes=direct_processes, router_process=router_process, ) stack.stop() raise return ManagedPdStack( topology=topology, run_dir=run_dir, prefill_processes=prefill_processes, decode_processes=decode_processes, direct_processes=direct_processes, router_process=router_process, ) def _spawn_process( *, name: str, command: tuple[str, ...], log_path: Path, topology: SingleNodeTopology, ) -> ManagedProcess: log_handle = log_path.open("wb") env = _build_process_env(topology) process = subprocess.Popen( command, stdout=log_handle, stderr=subprocess.STDOUT, env=env, preexec_fn=os.setsid, ) return ManagedProcess( name=name, command=command, process=process, log_path=log_path, ) def _build_process_env(topology: SingleNodeTopology) -> dict[str, str]: env = os.environ.copy() env["PYTHONDONTWRITEBYTECODE"] = "1" env["PYTHONUNBUFFERED"] = "1" # SGLang's PD bootstrap path uses `requests`; force localhost traffic to stay local. for key in ( "http_proxy", "https_proxy", "all_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", ): env.pop(key, None) env["NO_PROXY"] = "*" env["no_proxy"] = "*" env.setdefault("SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", "600") env.setdefault("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", "60") env.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1") if topology.force_rdma: env["MOONCAKE_PROTOCOL"] = "rdma" env["MC_MS_AUTO_DISC"] = "0" if topology.ib_device: env["MOONCAKE_DEVICE"] = topology.ib_device repo_root = Path(__file__).resolve().parents[2] python_paths = [ str(repo_root / "src"), str(repo_root / "third_party" / "sglang" / "python"), ] existing_pythonpath = env.get("PYTHONPATH") if existing_pythonpath: python_paths.append(existing_pythonpath) env["PYTHONPATH"] = os.pathsep.join(python_paths) return env def _wait_for_ready_endpoint(url: str, *, timeout_s: float) -> None: start = time.perf_counter() last_error: str | None = None with httpx.Client(timeout=5.0, trust_env=False) as client: while time.perf_counter() - start < timeout_s: try: response = client.get(url) if response.status_code == 200: return last_error = f"status={response.status_code}" except Exception as exc: # pragma: no cover last_error = f"{type(exc).__name__}: {exc}" time.sleep(1.0) raise TimeoutError(f"Timed out waiting for {url} ({last_error})")