"""Assemble the command lines used to launch a single-node worker topology."""

from __future__ import annotations

import shlex
import sys
from dataclasses import dataclass

from agentic_pd_hybrid.topology import SingleNodeTopology, WorkerSpec


@dataclass(frozen=True)
class LaunchPlan:
    """Immutable collection of argv tuples, one per process to start."""

    # One argv tuple per prefill worker.
    prefill_commands: tuple[tuple[str, ...], ...]
    # One argv tuple per decode worker.
    decode_commands: tuple[tuple[str, ...], ...]
    # One argv tuple per direct worker.
    direct_commands: tuple[tuple[str, ...], ...]
    # Argv for the router process, or None when no router is part of the plan.
    router_command: tuple[str, ...] | None

    def render(self) -> str:
        """Return the whole plan as shell-quoted text.

        Each command is preceded by a ``# <name>`` header line. Sections are
        emitted in the order prefill, decode, direct, router, with workers
        numbered from zero inside their group, and are separated by one
        blank line.
        """
        groups = (
            ("prefill", self.prefill_commands),
            ("decode", self.decode_commands),
            ("direct", self.direct_commands),
        )
        rendered: list[str] = []
        for label, commands in groups:
            rendered.extend(
                _render_named_command(f"{label}-{position}", argv)
                for position, argv in enumerate(commands)
            )
        if self.router_command is not None:
            rendered.append(_render_named_command("router", self.router_command))
        return "\n\n".join(rendered)


def build_launch_plan(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str = "round_robin",
    decode_policy: str = "manual",
    include_router: bool = True,
) -> LaunchPlan:
    """Build the server and router commands described by ``topology``.

    Args:
        topology: Source of the worker lists and shared launch settings.
        prefill_policy: Value passed to the router's ``--prefill-policy``.
        decode_policy: Value passed to the router's ``--decode-policy``.
        include_router: When false, no router command is produced.

    Returns:
        A ``LaunchPlan`` with one server command per prefill, decode and
        direct worker. ``router_command`` is ``None`` unless
        ``include_router`` is true and the topology has at least one prefill
        worker and at least one decode worker.
    """
    prefill_commands = tuple(
        _build_server_command(topology, spec) for spec in topology.prefill_workers
    )
    decode_commands = tuple(
        _build_server_command(topology, spec) for spec in topology.decode_workers
    )
    direct_commands = tuple(
        _build_server_command(topology, spec) for spec in topology.direct_workers
    )

    # A router is only emitted when both sides it connects are present.
    router_command: tuple[str, ...] | None = None
    if include_router and topology.prefill_workers and topology.decode_workers:
        router_command = _build_router_command(
            topology,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    return LaunchPlan(
        prefill_commands=prefill_commands,
        decode_commands=decode_commands,
        direct_commands=direct_commands,
        router_command=router_command,
    )


def _build_server_command(
    topology: SingleNodeTopology,
    worker: WorkerSpec,
) -> tuple[str, ...]:
    """Return the ``sglang.launch_server`` argv for a single worker.

    Optional flags are appended only when the matching setting is present;
    the topology-wide extra arguments come next, followed by the extra
    arguments specific to the worker's role.
    """
    argv: list[str] = [
        sys.executable,
        "-B",
        "-u",
        "-m",
        "sglang.launch_server",
        "--model-path",
        topology.model_path,
        "--host",
        worker.host,
        "--port",
        str(worker.port),
        "--base-gpu-id",
        str(worker.gpu_id),
        "--disaggregation-mode",
        _disaggregation_mode_for(worker),
        "--disaggregation-transfer-backend",
        topology.transfer_backend,
    ]

    if worker.tp_size > 1:
        argv += ["--tp-size", str(worker.tp_size)]
    if topology.trust_remote_code:
        argv.append("--trust-remote-code")
    # The cache-report flag is passed to every worker unconditionally.
    argv.append("--enable-cache-report")
    if worker.bootstrap_port is not None:
        argv += ["--disaggregation-bootstrap-port", str(worker.bootstrap_port)]
    if topology.ib_device:
        argv += ["--disaggregation-ib-device", topology.ib_device]

    argv.extend(topology.extra_server_args)

    # Any role other than "prefill" or "decode" receives the direct extras.
    if worker.role == "prefill":
        role_specific_args = topology.prefill_extra_server_args
    elif worker.role == "decode":
        role_specific_args = topology.decode_extra_server_args
    else:
        role_specific_args = topology.direct_extra_server_args
    argv.extend(role_specific_args)

    return tuple(argv)


def _build_router_command(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str,
    decode_policy: str,
) -> tuple[str, ...]:
    """Return the ``agentic_pd_hybrid.pd_router`` argv for ``topology``.

    Each prefill worker contributes ``--prefill <url> <port>`` and each
    decode worker contributes ``--decode <url>``.
    """
    argv: list[str] = [
        sys.executable,
        "-B",
        "-u",
        "-m",
        "agentic_pd_hybrid.pd_router",
        "--host",
        topology.router_host,
        "--port",
        str(topology.router_port),
        "--prefill-policy",
        prefill_policy,
        "--decode-policy",
        decode_policy,
    ]

    for prefill_worker in topology.prefill_workers:
        argv.append("--prefill")
        argv.append(prefill_worker.url)
        # A falsy bootstrap port (None or 0) falls back to the router port.
        argv.append(str(prefill_worker.bootstrap_port or topology.router_port))

    for decode_worker in topology.decode_workers:
        argv += ["--decode", decode_worker.url]

    return tuple(argv)


def _render_named_command(name: str, command: tuple[str, ...]) -> str:
    """Return ``command`` shell-quoted on one line under a ``# name`` header."""
    return f"# {name}\n{shlex.join(command)}"


def _disaggregation_mode_for(worker: WorkerSpec) -> str:
    """Return the ``--disaggregation-mode`` value for ``worker``.

    The "direct" role maps to the literal string "null"; every other role is
    returned unchanged.
    """
    return "null" if worker.role == "direct" else worker.role