Files
agentic-pd-hybrid/src/agentic_pd_hybrid/launcher.py

141 lines
4.3 KiB
Python

from __future__ import annotations
import shlex
import sys
from dataclasses import dataclass
from agentic_pd_hybrid.topology import SingleNodeTopology, WorkerSpec
@dataclass(frozen=True)
class LaunchPlan:
    """Immutable bundle of the argv tuples needed to bring up one topology.

    Every command is stored as a tuple of strings (executable first).
    ``render`` produces a human-readable, shell-quoted listing of the plan.
    """

    # One argv per prefill worker, in topology order.
    prefill_commands: tuple[tuple[str, ...], ...]
    # One argv per decode worker, in topology order.
    decode_commands: tuple[tuple[str, ...], ...]
    # One argv per direct worker, in topology order.
    direct_commands: tuple[tuple[str, ...], ...]
    # Router argv, or None when the plan has no router.
    router_command: tuple[str, ...] | None

    def render(self) -> str:
        """Return the plan as text, one ``# <name>`` section per command.

        Sections are named ``prefill-N``, ``decode-N`` and ``direct-N``
        (N is the 0-based index within its group), followed by ``router``
        when a router command is present. Sections are separated by a blank
        line; an empty plan renders as ``""``.
        """
        labelled: list[tuple[str, tuple[str, ...]]] = []
        groups = (
            ("prefill", self.prefill_commands),
            ("decode", self.decode_commands),
            ("direct", self.direct_commands),
        )
        for prefix, group in groups:
            labelled.extend(
                (f"{prefix}-{position}", argv) for position, argv in enumerate(group)
            )
        if self.router_command is not None:
            labelled.append(("router", self.router_command))
        return "\n\n".join(
            _render_named_command(label, argv) for label, argv in labelled
        )
def build_launch_plan(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str = "round_robin",
    decode_policy: str = "manual",
    include_router: bool = True,
) -> LaunchPlan:
    """Turn ``topology`` into the complete set of launch commands.

    One server command is generated per prefill, decode and direct worker,
    preserving topology order. A router command is added only when
    ``include_router`` is true AND the topology has at least one prefill
    worker and at least one decode worker; otherwise ``router_command`` is
    ``None``.

    Args:
        topology: Worker layout to launch.
        prefill_policy: Forwarded to the router command as ``--prefill-policy``.
        decode_policy: Forwarded to the router command as ``--decode-policy``.
        include_router: Pass ``False`` to leave the router out of the plan.

    Returns:
        A frozen ``LaunchPlan`` holding every command.
    """

    def server_commands(workers):
        return tuple(_build_server_command(topology, spec) for spec in workers)

    prefill = server_commands(topology.prefill_workers)
    decode = server_commands(topology.decode_workers)
    direct = server_commands(topology.direct_workers)

    # A PD router only makes sense when both sides of the disaggregation exist.
    router = None
    if include_router and topology.prefill_workers and topology.decode_workers:
        router = _build_router_command(
            topology,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    return LaunchPlan(
        prefill_commands=prefill,
        decode_commands=decode,
        direct_commands=direct,
        router_command=router,
    )
def _build_server_command(
    topology: SingleNodeTopology,
    worker: WorkerSpec,
) -> tuple[str, ...]:
    """Build the ``sglang.launch_server`` argv for a single worker.

    Flags are emitted in a fixed order:

    1. core options every worker gets (model path, host, port, base GPU id,
       disaggregation mode, transfer backend);
    2. conditional options: ``--tp-size`` only when ``tp_size > 1``,
       ``--trust-remote-code`` when the topology enables it,
       ``--enable-cache-report`` always, the bootstrap port when the worker
       sets one, and the IB device when the topology sets one;
    3. ``topology.extra_server_args``;
    4. the role-specific extras. Any role other than ``"prefill"`` or
       ``"decode"`` receives ``direct_extra_server_args``.

    Args:
        topology: Supplies the model path, transfer backend, IB device and
            extra-argument lists.
        worker: Supplies the host, port, GPU id, TP size, bootstrap port and role.

    Returns:
        The full argv as a tuple, starting with the current interpreter.
    """
    argv: list[str] = [sys.executable, "-B", "-u", "-m", "sglang.launch_server"]

    core_options = (
        ("--model-path", topology.model_path),
        ("--host", worker.host),
        ("--port", str(worker.port)),
        ("--base-gpu-id", str(worker.gpu_id)),
        ("--disaggregation-mode", _disaggregation_mode_for(worker)),
        ("--disaggregation-transfer-backend", topology.transfer_backend),
    )
    for flag, value in core_options:
        argv += (flag, value)

    if worker.tp_size > 1:
        argv += ("--tp-size", str(worker.tp_size))
    if topology.trust_remote_code:
        argv.append("--trust-remote-code")
    argv.append("--enable-cache-report")
    # Only passed when the worker pins an explicit bootstrap port.
    if worker.bootstrap_port is not None:
        argv += ("--disaggregation-bootstrap-port", str(worker.bootstrap_port))
    if topology.ib_device:
        argv += ("--disaggregation-ib-device", topology.ib_device)

    argv += topology.extra_server_args

    role = worker.role
    if role == "prefill":
        role_extras = topology.prefill_extra_server_args
    elif role == "decode":
        role_extras = topology.decode_extra_server_args
    else:
        # Everything that is not prefill/decode is treated as a direct worker.
        role_extras = topology.direct_extra_server_args
    argv += role_extras

    return tuple(argv)
def _build_router_command(
topology: SingleNodeTopology,
*,
prefill_policy: str,
decode_policy: str,
) -> tuple[str, ...]:
command: list[str] = [
sys.executable,
"-B",
"-u",
"-m",
"agentic_pd_hybrid.pd_router",
"--host",
topology.router_host,
"--port",
str(topology.router_port),
"--prefill-policy",
prefill_policy,
"--decode-policy",
decode_policy,
]
for worker in topology.prefill_workers:
command.extend(
["--prefill", worker.url, str(worker.bootstrap_port or topology.router_port)]
)
for worker in topology.decode_workers:
command.extend(["--decode", worker.url])
return tuple(command)
def _render_named_command(name: str, command: tuple[str, ...]) -> str:
return f"# {name}\n" + " ".join(shlex.quote(part) for part in command)
def _disaggregation_mode_for(worker: WorkerSpec) -> str:
if worker.role == "direct":
return "null"
return worker.role