141 lines
4.3 KiB
Python
141 lines
4.3 KiB
Python
from __future__ import annotations
|
|
|
|
import shlex
|
|
import sys
|
|
from dataclasses import dataclass
|
|
|
|
from agentic_pd_hybrid.topology import SingleNodeTopology, WorkerSpec
|
|
|
|
|
|
@dataclass(frozen=True)
class LaunchPlan:
    """Immutable set of argv tuples needed to bring up one topology.

    Every command is stored as a tuple of argv tokens (not a shell string).
    ``router_command`` is ``None`` when no router should be launched.
    """

    # One argv tuple per prefill worker, in topology order.
    prefill_commands: tuple[tuple[str, ...], ...]
    # One argv tuple per decode worker, in topology order.
    decode_commands: tuple[tuple[str, ...], ...]
    # One argv tuple per direct (non-disaggregated) worker, in topology order.
    direct_commands: tuple[tuple[str, ...], ...]
    # Router argv, or None when the plan has no router.
    router_command: tuple[str, ...] | None

    def render(self) -> str:
        """Render every command as a commented, shell-quoted section.

        Sections appear in the order prefill, decode, direct, router. Each
        worker section is labelled ``<role>-<index>``. Sections are separated
        by one blank line. An empty plan renders as the empty string.
        """
        worker_groups = (
            ("prefill", self.prefill_commands),
            ("decode", self.decode_commands),
            ("direct", self.direct_commands),
        )
        blocks = [
            _render_named_command(f"{label}-{position}", argv)
            for label, argv_group in worker_groups
            for position, argv in enumerate(argv_group)
        ]
        if self.router_command is not None:
            blocks.append(_render_named_command("router", self.router_command))
        return "\n\n".join(blocks)
|
|
|
|
|
|
def build_launch_plan(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str = "round_robin",
    decode_policy: str = "manual",
    include_router: bool = True,
) -> LaunchPlan:
    """Build the full set of launch commands for ``topology``.

    Args:
        topology: Topology whose prefill, decode and direct workers each get
            one server command.
        prefill_policy: Forwarded to the router as ``--prefill-policy``.
        decode_policy: Forwarded to the router as ``--decode-policy``.
        include_router: When true, a router command is added. It is only
            added if the topology has at least one prefill worker AND at
            least one decode worker.

    Returns:
        A ``LaunchPlan`` whose ``router_command`` is ``None`` whenever the
        router is not included.
    """

    def server_commands(workers) -> tuple[tuple[str, ...], ...]:
        # One launch_server argv per worker, preserving topology order.
        return tuple(_build_server_command(topology, member) for member in workers)

    prefill = server_commands(topology.prefill_workers)
    decode = server_commands(topology.decode_workers)
    direct = server_commands(topology.direct_workers)

    router = None
    # A PD router is only meaningful with both sides of the disaggregation present.
    if include_router and topology.prefill_workers and topology.decode_workers:
        router = _build_router_command(
            topology,
            prefill_policy=prefill_policy,
            decode_policy=decode_policy,
        )

    return LaunchPlan(
        prefill_commands=prefill,
        decode_commands=decode,
        direct_commands=direct,
        router_command=router,
    )
|
|
|
|
|
|
def _build_server_command(
    topology: SingleNodeTopology,
    worker: WorkerSpec,
) -> tuple[str, ...]:
    """Assemble the ``sglang.launch_server`` argv for a single worker.

    The command runs the current interpreter with ``-B -u``. It always passes
    model path, host/port, base GPU id, disaggregation mode/transfer backend
    and ``--enable-cache-report``. Optional flags are appended only when set:
    ``--tp-size`` when tp_size > 1, ``--trust-remote-code``, the bootstrap
    port and the IB device. Then ``topology.extra_server_args`` is appended,
    followed by the role-specific extra args. Any role other than "prefill"
    or "decode" receives ``direct_extra_server_args``.
    """
    argv: list[str] = [sys.executable, "-B", "-u", "-m", "sglang.launch_server"]
    argv += ["--model-path", topology.model_path]
    argv += ["--host", worker.host, "--port", str(worker.port)]
    argv += ["--base-gpu-id", str(worker.gpu_id)]
    argv += ["--disaggregation-mode", _disaggregation_mode_for(worker)]
    argv += ["--disaggregation-transfer-backend", topology.transfer_backend]

    # Single-GPU workers rely on the server's default tensor-parallel size.
    if worker.tp_size > 1:
        argv += ["--tp-size", str(worker.tp_size)]
    if topology.trust_remote_code:
        argv.append("--trust-remote-code")
    argv.append("--enable-cache-report")
    if worker.bootstrap_port is not None:
        argv += ["--disaggregation-bootstrap-port", str(worker.bootstrap_port)]
    if topology.ib_device:
        argv += ["--disaggregation-ib-device", topology.ib_device]

    # Shared extras first, then per-role extras, so role args come last.
    argv += topology.extra_server_args
    if worker.role == "prefill":
        role_extras = topology.prefill_extra_server_args
    elif worker.role == "decode":
        role_extras = topology.decode_extra_server_args
    else:
        role_extras = topology.direct_extra_server_args
    argv += role_extras

    return tuple(argv)
|
|
|
|
|
|
def _build_router_command(
    topology: SingleNodeTopology,
    *,
    prefill_policy: str,
    decode_policy: str,
) -> tuple[str, ...]:
    """Assemble the argv that starts ``agentic_pd_hybrid.pd_router``.

    Each prefill worker contributes ``--prefill URL PORT``, where PORT is the
    worker's bootstrap port. When that port is falsy (``None`` or 0),
    ``topology.router_port`` is used instead. Each decode worker contributes
    ``--decode URL``. All prefill entries precede all decode entries.
    """
    argv: list[str] = [
        sys.executable, "-B", "-u", "-m", "agentic_pd_hybrid.pd_router",
        "--host", topology.router_host,
        "--port", str(topology.router_port),
        "--prefill-policy", prefill_policy,
        "--decode-policy", decode_policy,
    ]

    for prefill_worker in topology.prefill_workers:
        # NOTE(review): falling back to the router's own port looks odd.
        # _build_server_command omits --disaggregation-bootstrap-port when it
        # is None, so the server would not listen on router_port. Confirm the
        # router expects this placeholder.
        advertised_port = prefill_worker.bootstrap_port or topology.router_port
        argv += ("--prefill", prefill_worker.url, str(advertised_port))

    argv.extend(
        token
        for decode_worker in topology.decode_workers
        for token in ("--decode", decode_worker.url)
    )
    return tuple(argv)
|
|
|
|
|
|
def _render_named_command(name: str, command: tuple[str, ...]) -> str:
    """Return ``# <name>`` followed by ``command`` as one shell-quoted line."""
    header = f"# {name}"
    # shlex.quote each token so arguments with spaces or metacharacters survive copy-paste.
    body = " ".join(map(shlex.quote, command))
    return f"{header}\n{body}"
|
|
|
|
|
|
def _disaggregation_mode_for(worker: WorkerSpec) -> str:
    """Map a worker role to its ``--disaggregation-mode`` value.

    Direct workers map to ``"null"``; every other role is passed through unchanged.
    """
    role = worker.role
    return "null" if role == "direct" else role
|