#!/usr/bin/env python3
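"""Dispatch heterogeneous GPU jobs over SSH/tmux and harvest their results.

Typical invocations (illustrative; subcommands and flags are defined in
build_parser() below, paths assume the default repo layout):

    python3 infra/gpu_fleet/gpu_fleet.py probe --config infra/gpu_fleet/config/fleet.toml
    python3 infra/gpu_fleet/gpu_fleet.py dispatch --jobs jobs.toml --dry-run
    python3 infra/gpu_fleet/gpu_fleet.py monitor --jobs jobs.toml --interval-sec 15
    python3 infra/gpu_fleet/gpu_fleet.py harvest

Local state defaults to .aituner/gpu_fleet/state and pulled artifacts to
.aituner/gpu_fleet/artifacts; both can be overridden via [paths] in fleet.toml.
"""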
from __future__ import annotations
import argparse
import csv
import fcntl
import hashlib
import json
import shlex
import subprocess
import sys
import textwrap
import time
import tomllib
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
SCRIPT_PATH = Path(__file__).resolve()
REPO_ROOT = SCRIPT_PATH.parents[2]  # <repo>/infra/gpu_fleet/gpu_fleet.py -> <repo>
DEFAULT_CONFIG = SCRIPT_PATH.parent / "config" / "fleet.toml"
DEFAULT_STATE_REL = Path(".aituner/gpu_fleet/state")
DEFAULT_ARTIFACT_REL = Path(".aituner/gpu_fleet/artifacts")
PROBE_CACHE_NAME = "probe_cache.json"
QUEUE_STATE_NAME = "jobs_state.json"
MONITOR_LOCK_NAME = "monitor.lock"
class FleetError(RuntimeError):
    """Raised for recoverable fleet/CLI failures; main() turns it into exit code 1."""
@dataclass
class HostSpec:
name: str
ssh_alias: str
enabled: bool = True
sync_remote_path: str = "~/workspace/aituner"
fleet_root: str = "~/.aituner_gpu_fleet"
@dataclass
class SyncSpec:
mode: str = "rsync"
local_path: Path = REPO_ROOT
exclude: list[str] = field(default_factory=list)
@dataclass
class SchedulerSpec:
gpu_free_memory_mb: int = 1024
gpu_free_utilization_pct: int = 10
prefer_pack: bool = True
@dataclass
class FleetConfig:
config_path: Path
project_root: Path
state_dir: Path
artifacts_dir: Path
ssh_timeout_sec: int
scheduler: SchedulerSpec
sync: SyncSpec
hosts: dict[str, HostSpec]
@dataclass
class JobSpec:
name: str
command: str
gpus: int
gpu_model: str | None = None
hosts: list[str] = field(default_factory=list)
artifacts: list[str] = field(default_factory=list)
env: dict[str, str] = field(default_factory=dict)
def utc_now() -> str:
return datetime.now(timezone.utc).isoformat(timespec="seconds")
def relative_to_root(root: Path, value: str | None, default: Path) -> Path:
if not value:
return default
candidate = Path(value).expanduser()
if candidate.is_absolute():
return candidate
return (root / candidate).resolve()
def load_toml(path: Path) -> dict[str, Any]:
with path.open("rb") as fh:
return tomllib.load(fh)
def ensure_dir(path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path
def dump_json(path: Path, payload: Any) -> None:
ensure_dir(path.parent)
path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
def load_json(path: Path) -> Any:
return json.loads(path.read_text(encoding="utf-8"))
def render_command(template: str, **kwargs: str) -> str:
try:
return template.format(**kwargs)
except KeyError as exc:
raise FleetError(f"template is missing placeholder {exc!s}: {template}") from exc
def run_local(
argv: list[str],
*,
cwd: Path | None = None,
capture_output: bool = True,
check: bool = True,
) -> subprocess.CompletedProcess[str]:
completed = subprocess.run(
argv,
cwd=str(cwd) if cwd else None,
check=False,
text=True,
capture_output=capture_output,
)
if check and completed.returncode != 0:
stderr = completed.stderr.strip()
stdout = completed.stdout.strip()
detail = stderr or stdout or f"exit code {completed.returncode}"
raise FleetError(f"command failed: {' '.join(shlex.quote(x) for x in argv)}\n{detail}")
return completed
def ssh_base(host: HostSpec, timeout_sec: int) -> list[str]:
return [
"ssh",
"-o",
f"ConnectTimeout={timeout_sec}",
host.ssh_alias,
]
def run_ssh(
config: FleetConfig,
host: HostSpec,
remote_command: str,
*,
capture_output: bool = True,
check: bool = True,
) -> subprocess.CompletedProcess[str]:
argv = ssh_base(host, config.ssh_timeout_sec) + [remote_command]
return run_local(argv, cwd=config.project_root, capture_output=capture_output, check=check)
def rsync_push(config: FleetConfig, host: HostSpec) -> None:
ensure_remote_dir(config, host, host.sync_remote_path)
argv = ["rsync", "-az"]
for pattern in config.sync.exclude:
argv.extend(["--exclude", pattern])
local_src = str(config.sync.local_path.resolve()) + "/"
remote_dst = f"{host.ssh_alias}:{host.sync_remote_path.rstrip('/')}/"
argv.extend([local_src, remote_dst])
run_local(argv, cwd=config.project_root, capture_output=True, check=True)
def rsync_pull(config: FleetConfig, host: HostSpec, remote_path: str, local_path: Path) -> None:
ensure_dir(local_path.parent)
argv = [
"rsync",
"-az",
f"{host.ssh_alias}:{remote_path}",
str(local_path),
]
run_local(argv, cwd=config.project_root, capture_output=True, check=True)
def quote_remote_path(path: str) -> str:
    # shlex.quote() would also quote a leading "~" (it is not in shlex's safe
    # character set), which stops the remote shell from expanding it to $HOME
    # and makes the default home-relative paths point at a literal "~"
    # directory; keep the "~/" prefix outside the quotes so it still expands.
    if path == "~":
        return "~"
    if path.startswith("~/"):
        return "~/" + shlex.quote(path[2:])
    return shlex.quote(path)
def ensure_remote_dir(config: FleetConfig, host: HostSpec, remote_path: str) -> None:
    run_ssh(config, host, f"mkdir -p {quote_remote_path(remote_path)}", capture_output=True, check=True)
def load_config(path: Path) -> FleetConfig:
raw = load_toml(path)
if int(raw.get("version", 1)) != 1:
raise FleetError("only config version 1 is supported")
project_root = REPO_ROOT
paths_raw = raw.get("paths", {})
state_dir = relative_to_root(project_root, paths_raw.get("state_dir"), project_root / DEFAULT_STATE_REL)
artifacts_dir = relative_to_root(
project_root,
paths_raw.get("artifacts_dir"),
project_root / DEFAULT_ARTIFACT_REL,
)
sync_raw = raw.get("sync", {})
sync = SyncSpec(
mode=str(sync_raw.get("mode", "rsync")),
local_path=relative_to_root(project_root, sync_raw.get("local_path"), project_root),
exclude=[str(item) for item in sync_raw.get("exclude", [])],
)
if sync.mode != "rsync":
raise FleetError(f"unsupported sync.mode: {sync.mode}")
scheduler_raw = raw.get("scheduler", {})
scheduler = SchedulerSpec(
gpu_free_memory_mb=int(scheduler_raw.get("gpu_free_memory_mb", 1024)),
gpu_free_utilization_pct=int(scheduler_raw.get("gpu_free_utilization_pct", 10)),
prefer_pack=bool(scheduler_raw.get("prefer_pack", True)),
)
ssh_raw = raw.get("ssh", {})
hosts: dict[str, HostSpec] = {}
for raw_host in raw.get("hosts", []):
name = str(raw_host.get("name", "")).strip()
ssh_alias = str(raw_host.get("ssh_alias", "")).strip()
if not name or not ssh_alias:
raise FleetError("every host must define both name and ssh_alias")
host = HostSpec(
name=name,
ssh_alias=ssh_alias,
enabled=bool(raw_host.get("enabled", True)),
sync_remote_path=str(raw_host.get("sync_remote_path", "~/workspace/aituner")),
fleet_root=str(raw_host.get("fleet_root", "~/.aituner_gpu_fleet")),
)
hosts[name] = host
if not hosts:
raise FleetError("no hosts are configured")
return FleetConfig(
config_path=path,
project_root=project_root,
state_dir=state_dir,
artifacts_dir=artifacts_dir,
ssh_timeout_sec=int(ssh_raw.get("connect_timeout_sec", 10)),
scheduler=scheduler,
sync=sync,
hosts=hosts,
)
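# A minimal fleet.toml accepted by load_config(); every key maps to a field
# read above, and all values here are illustrative:
#
#   version = 1
#
#   [paths]
#   state_dir = ".aituner/gpu_fleet/state"
#   artifacts_dir = ".aituner/gpu_fleet/artifacts"
#
#   [sync]
#   mode = "rsync"
#   local_path = "."
#   exclude = [".git/", "__pycache__/"]
#
#   [scheduler]
#   gpu_free_memory_mb = 1024
#   gpu_free_utilization_pct = 10
#   prefer_pack = true
#
#   [ssh]
#   connect_timeout_sec = 10
#
#   [[hosts]]
#   name = "node-a"
#   ssh_alias = "node-a"
#   enabled = true
#   sync_remote_path = "~/workspace/aituner"
#   fleet_root = "~/.aituner_gpu_fleet"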
def load_jobs(path: Path) -> list[JobSpec]:
raw = load_toml(path)
if int(raw.get("version", 1)) != 1:
raise FleetError("only jobs version 1 is supported")
jobs: list[JobSpec] = []
seen_names: set[str] = set()
for item in raw.get("jobs", []):
name = str(item.get("name", "")).strip()
if not name:
raise FleetError("every job must define name")
if name in seen_names:
raise FleetError(f"duplicated job name: {name}")
command = str(item.get("command", "")).strip()
if not command:
raise FleetError(f"job {name} must define command")
gpus = int(item.get("gpus", 0))
if gpus <= 0:
raise FleetError(f"job {name} must request a positive gpu count")
jobs.append(
JobSpec(
name=name,
command=command,
gpus=gpus,
gpu_model=str(item["gpu_model"]) if "gpu_model" in item else None,
hosts=[str(entry) for entry in item.get("hosts", [])],
artifacts=[str(entry) for entry in item.get("artifacts", [])],
env={str(k): str(v) for k, v in item.get("env", {}).items()},
)
)
seen_names.add(name)
if not jobs:
raise FleetError("no jobs are defined")
return jobs
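# A minimal jobs.toml accepted by load_jobs(); name, command, and gpus are
# required, everything else is optional (values illustrative):
#
#   version = 1
#
#   [[jobs]]
#   name = "train-baseline"
#   command = "python train.py --epochs 10"
#   gpus = 2
#   gpu_model = "A100"            # substring match against probed GPU names
#   hosts = ["node-a"]            # restrict placement; empty means any host
#   artifacts = ["outputs/"]      # pulled back by harvest once the run ends
#   env = { WANDB_MODE = "offline" }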
def validate_jobs_against_config(config: FleetConfig, jobs: list[JobSpec]) -> None:
    known = set(config.hosts) | {host.ssh_alias for host in config.hosts.values()}
    for job in jobs:
        if job.hosts:
            unknown_hosts = sorted(requested for requested in job.hosts if requested not in known)
if unknown_hosts:
raise FleetError(f"job {job.name} references unknown hosts: {', '.join(unknown_hosts)}")
candidate_hosts = [
host for host in config.hosts.values()
if matches_host_constraint(host, job.hosts)
]
if not candidate_hosts:
raise FleetError(f"job {job.name} has no candidate hosts after host filtering")
def sync_jobs_to_queue_state(queue_state: dict[str, Any], jobs: list[JobSpec]) -> None:
now = utc_now()
state_jobs = queue_state.setdefault("jobs", {})
for job in jobs:
signature = job_signature(job)
record = state_jobs.get(job.name)
payload = {
"name": job.name,
"signature": signature,
"command": job.command,
"gpus": job.gpus,
"gpu_model": job.gpu_model,
"hosts": job.hosts,
"artifacts": job.artifacts,
"env": job.env,
"last_seen_at": now,
}
if record is None:
state_jobs[job.name] = {
**payload,
"status": "pending",
"created_at": now,
"submitted_at": None,
"completed_at": None,
"run_id": None,
"attempts": 0,
}
continue
if record.get("signature") != signature and record.get("status") != "pending":
raise FleetError(f"job {job.name} changed after it entered the queue; use a new job name instead")
record.update(payload)
if record.get("status") == "unknown":
record["status"] = "pending"
def pending_jobs_from_queue(queue_state: dict[str, Any], jobs: list[JobSpec]) -> list[JobSpec]:
state_jobs = queue_state.get("jobs", {})
pending: list[JobSpec] = []
for job in jobs:
record = state_jobs.get(job.name, {})
if record.get("status", "pending") == "pending":
pending.append(job)
return pending
def state_subdir(config: FleetConfig, *parts: str) -> Path:
return ensure_dir(config.state_dir.joinpath(*parts))
def run_state_dir(config: FleetConfig) -> Path:
return state_subdir(config, "runs")
def probe_cache_path(config: FleetConfig) -> Path:
return state_subdir(config, "probe") / PROBE_CACHE_NAME
def queue_state_path(config: FleetConfig) -> Path:
return state_subdir(config, "queue") / QUEUE_STATE_NAME
def monitor_lock_path(config: FleetConfig) -> Path:
return state_subdir(config, "queue") / MONITOR_LOCK_NAME
def load_queue_state(config: FleetConfig) -> dict[str, Any]:
path = queue_state_path(config)
if path.exists():
payload = load_json(path)
payload.setdefault("jobs", {})
return payload
return {
"generated_at": utc_now(),
"updated_at": utc_now(),
"jobs": {},
}
def save_queue_state(config: FleetConfig, payload: dict[str, Any]) -> None:
payload["updated_at"] = utc_now()
dump_json(queue_state_path(config), payload)
def job_signature(job: JobSpec) -> str:
payload = {
"name": job.name,
"command": job.command,
"gpus": job.gpus,
"gpu_model": job.gpu_model,
"hosts": job.hosts,
"artifacts": job.artifacts,
"env": job.env,
}
raw = json.dumps(payload, sort_keys=True, ensure_ascii=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
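# Because json.dumps(..., sort_keys=True) canonicalizes the payload, two specs
# differing only in dict ordering (e.g. env {"A": "1", "B": "2"} vs
# {"B": "2", "A": "1"}) hash identically, while any change to command, gpus,
# hosts, artifacts, or env values yields a new signature.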
def parse_gpu_query(output: str) -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
reader = csv.reader(line for line in output.splitlines() if line.strip())
for row in reader:
if len(row) != 4:
continue
index, name, memory_total, memory_used = [cell.strip() for cell in row]
rows.append(
{
"index": int(index),
"name": name,
"memory_total_mb": int(memory_total),
"memory_used_mb": int(memory_used),
}
)
return rows
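# Example (values illustrative): a csv,noheader,nounits line such as
#   0, NVIDIA A100-SXM4-40GB, 40960, 1024
# parses to
#   {"index": 0, "name": "NVIDIA A100-SXM4-40GB",
#    "memory_total_mb": 40960, "memory_used_mb": 1024}
# Rows with the wrong column count are skipped instead of raising.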
def parse_gpu_status(output: str) -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
reader = csv.reader(line for line in output.splitlines() if line.strip())
for row in reader:
if len(row) != 5:
continue
index, name, memory_total, memory_used, util = [cell.strip() for cell in row]
rows.append(
{
"index": int(index),
"name": name,
"memory_total_mb": int(memory_total),
"memory_used_mb": int(memory_used),
"utilization_gpu_pct": int(util),
}
)
return rows
def probe_single_host(config: FleetConfig, host: HostSpec) -> dict[str, Any]:
gpu_cmd = textwrap.dedent(
"""
bash -lc '
set -euo pipefail
nvidia-smi --query-gpu=index,name,memory.total,memory.used --format=csv,noheader,nounits
'
"""
).strip()
completed = run_ssh(config, host, gpu_cmd, capture_output=True, check=False)
payload: dict[str, Any] = {
"host": host.name,
"ssh_alias": host.ssh_alias,
"enabled": host.enabled,
"probed_at": utc_now(),
}
if completed.returncode != 0:
payload["status"] = "unreachable"
payload["error"] = (completed.stderr or completed.stdout).strip()
return payload
gpus = parse_gpu_query(completed.stdout)
names = sorted({gpu["name"] for gpu in gpus})
payload["status"] = "ok"
payload["gpus"] = gpus
payload["gpu_count"] = len(gpus)
payload["gpu_models"] = names
payload["gpu_summary"] = ", ".join(f"{sum(1 for gpu in gpus if gpu['name'] == name)}x {name}" for name in names)
return payload
def load_probe_cache(config: FleetConfig) -> dict[str, Any]:
path = probe_cache_path(config)
if not path.exists():
raise FleetError(
f"probe cache is missing: {path}. Run 'python3 infra/gpu_fleet/gpu_fleet.py probe --config {config.config_path}' first."
)
return load_json(path)
def query_host_gpu_status(config: FleetConfig, host: HostSpec) -> list[dict[str, Any]]:
cmd = textwrap.dedent(
"""
bash -lc '
set -euo pipefail
nvidia-smi --query-gpu=index,name,memory.total,memory.used,utilization.gpu --format=csv,noheader,nounits
'
"""
).strip()
completed = run_ssh(config, host, cmd, capture_output=True, check=True)
return parse_gpu_status(completed.stdout)
def free_gpu_indices(config: FleetConfig, host: HostSpec) -> list[int]:
status = query_host_gpu_status(config, host)
free: list[int] = []
for gpu in status:
if gpu["memory_used_mb"] <= config.scheduler.gpu_free_memory_mb and gpu["utilization_gpu_pct"] <= config.scheduler.gpu_free_utilization_pct:
free.append(gpu["index"])
return free
def matches_gpu_model(host_probe: dict[str, Any], requested: str | None) -> bool:
if not requested:
return True
needle = requested.lower()
for name in host_probe.get("gpu_models", []):
if needle in name.lower():
return True
summary = str(host_probe.get("gpu_summary", ""))
return needle in summary.lower()
def matches_host_constraint(host: HostSpec, requested_hosts: list[str]) -> bool:
if not requested_hosts:
return True
requested = {item.strip() for item in requested_hosts if item.strip()}
return host.name in requested or host.ssh_alias in requested
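# e.g. hosts = ["node-a"] (name illustrative) admits the host whose name or
# ssh_alias equals "node-a"; an empty list admits every configured host.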
def build_payload(job: JobSpec, host: HostSpec, run_id: str, remote_run_dir: str, gpu_ids: list[int]) -> str:
env_exports = "\n".join(
f"export {key}={shlex.quote(value)}"
for key, value in sorted(job.env.items())
)
log_file = f"{remote_run_dir}/stdout.log"
exit_file = f"{remote_run_dir}/exit_code"
started_file = f"{remote_run_dir}/started_at"
finished_file = f"{remote_run_dir}/finished_at"
    payload = textwrap.dedent(
        f"""
        set -euo pipefail
        mkdir -p {quote_remote_path(remote_run_dir)}
        exec > >(tee -a {quote_remote_path(log_file)}) 2>&1
        printf '%s\\n' "$(date -Iseconds)" > {quote_remote_path(started_file)}
        on_exit() {{
          status=$?
          printf '%s\\n' "$status" > {quote_remote_path(exit_file)}
          printf '%s\\n' "$(date -Iseconds)" > {quote_remote_path(finished_file)}
        }}
        trap on_exit EXIT
        cd {quote_remote_path(host.sync_remote_path)}
        export CUDA_VISIBLE_DEVICES={shlex.quote(','.join(str(idx) for idx in gpu_ids))}
        export AITUNER_RUN_ID={shlex.quote(run_id)}
        export AITUNER_REMOTE_RUN_DIR={quote_remote_path(remote_run_dir)}
        export AITUNER_GPU_COUNT={shlex.quote(str(job.gpus))}
        {env_exports}
        {job.command}
        """
    ).strip()
return payload
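# The rendered payload is a plain bash script shaped roughly like this
# (paths illustrative):
#   set -euo pipefail
#   mkdir -p ~/.aituner_gpu_fleet/runs/<run_id>
#   exec > >(tee -a .../stdout.log) 2>&1
#   ...record started_at, trap EXIT to write exit_code and finished_at,
#   cd into the synced checkout, export CUDA_VISIBLE_DEVICES plus the
#   AITUNER_*/job env vars, then run job.command.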
def write_run_manifest(config: FleetConfig, manifest: dict[str, Any]) -> Path:
path = run_state_dir(config) / f"{manifest['run_id']}.json"
dump_json(path, manifest)
return path
def run_manifest_path(config: FleetConfig, run_id: str) -> Path:
return run_state_dir(config) / f"{run_id}.json"
def load_run_manifests(config: FleetConfig) -> list[dict[str, Any]]:
runs_dir = run_state_dir(config)
manifests: list[dict[str, Any]] = []
for path in sorted(runs_dir.glob("*.json")):
manifests.append(load_json(path))
return manifests
def launch_job(config: FleetConfig, host: HostSpec, job: JobSpec, gpu_ids: list[int]) -> dict[str, Any]:
sync_remote_path = host.sync_remote_path.rstrip("/")
fleet_root = host.fleet_root.rstrip("/")
run_id = f"{job.name}-{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%fZ')}"
session_name = f"aituner_{run_id.replace('.', '_')}"
remote_run_dir = f"{fleet_root}/runs/{run_id}"
payload = build_payload(job, host, run_id, remote_run_dir, gpu_ids)
    # tmux may join multiple trailing shell-command arguments into one command
    # string with spaces, which would shred the payload's quoting; passing the
    # whole "bash -lc ..." wrapper as a single argument is safe either way.
    inner_shell = f"bash -lc {shlex.quote(payload)}"
    remote_launch = (
        f"mkdir -p {quote_remote_path(remote_run_dir)} "
        f"&& tmux new-session -d -s {shlex.quote(session_name)} {shlex.quote(inner_shell)}"
    )
run_ssh(config, host, remote_launch, capture_output=True, check=True)
manifest = {
"run_id": run_id,
"job_name": job.name,
"host": host.name,
"ssh_alias": host.ssh_alias,
"session_name": session_name,
"gpu_ids": gpu_ids,
"gpu_model": job.gpu_model,
"gpus": job.gpus,
"command": job.command,
"env": job.env,
"artifacts": job.artifacts,
"remote_sync_path": sync_remote_path,
"remote_run_dir": remote_run_dir,
"status": "running",
"submitted_at": utc_now(),
"harvested_at": None,
}
write_run_manifest(config, manifest)
return manifest
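# A live run can be inspected directly on the host, e.g. (alias and session
# name illustrative):
#   ssh -t <ssh_alias> tmux attach -t aituner_<run_id>
# or by tailing <remote_run_dir>/stdout.log there.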
def refresh_manifest_status(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]:
host = config.hosts[manifest["host"]]
    exit_path = manifest["remote_run_dir"] + "/exit_code"
    session = manifest["session_name"]
    script = textwrap.dedent(
        f"""
        if [ -f {quote_remote_path(exit_path)} ]; then
          code=$(cat {quote_remote_path(exit_path)})
          printf "done:%s\\n" "$code"
        elif tmux has-session -t {shlex.quote(session)} 2>/dev/null; then
          printf "running\\n"
        else
          printf "unknown\\n"
        fi
        """
    ).strip()
    # Quote the whole script once rather than splicing quoted fragments into a
    # single-quoted bash -lc block, which only held together for benign paths.
    remote_cmd = f"bash -lc {shlex.quote(script)}"
completed = run_ssh(config, host, remote_cmd, capture_output=True, check=False)
status_line = (completed.stdout or completed.stderr).strip()
new_manifest = dict(manifest)
if status_line.startswith("done:"):
code = int(status_line.split(":", 1)[1])
new_manifest["status"] = "completed" if code == 0 else "failed"
new_manifest["exit_code"] = code
elif status_line == "running":
new_manifest["status"] = "running"
else:
new_manifest["status"] = "unknown"
write_run_manifest(config, new_manifest)
return new_manifest
def harvest_run(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]:
host = config.hosts[manifest["host"]]
refreshed = refresh_manifest_status(config, manifest)
if refreshed["status"] not in {"completed", "failed"}:
return refreshed
local_base = ensure_dir(config.artifacts_dir / refreshed["run_id"])
remote_run_dir = refreshed["remote_run_dir"].rstrip("/")
rsync_pull(config, host, f"{remote_run_dir}/", local_base / "remote_run")
for artifact in refreshed.get("artifacts", []):
artifact_remote = f"{refreshed['remote_sync_path'].rstrip('/')}/{artifact}"
target = local_base / "artifacts" / Path(artifact.lstrip("/"))
check = run_ssh(
config,
host,
f"test -e {shlex.quote(artifact_remote)}",
capture_output=True,
check=False,
)
if check.returncode == 0:
rsync_pull(config, host, artifact_remote, target)
refreshed["harvested_at"] = utc_now()
write_run_manifest(config, refreshed)
return refreshed
def reconcile_queue_state_with_runs(config: FleetConfig, queue_state: dict[str, Any]) -> None:
state_jobs = queue_state.get("jobs", {})
for name, record in state_jobs.items():
if record.get("status") != "running":
continue
run_id = record.get("run_id")
if not run_id:
record["status"] = "unknown"
record["updated_at"] = utc_now()
continue
manifest_path = run_manifest_path(config, run_id)
if not manifest_path.exists():
record["status"] = "unknown"
record["updated_at"] = utc_now()
continue
manifest = load_json(manifest_path)
refreshed = refresh_manifest_status(config, manifest)
record["updated_at"] = utc_now()
if refreshed["status"] == "running":
continue
if refreshed["status"] in {"completed", "failed"}:
harvested = harvest_run(config, refreshed)
record["status"] = harvested["status"]
record["completed_at"] = utc_now()
record["exit_code"] = harvested.get("exit_code")
record["harvested_at"] = harvested.get("harvested_at")
continue
record["status"] = "unknown"
def print_probe_table(probe: dict[str, Any]) -> None:
for name, payload in sorted(probe["hosts"].items()):
if payload["status"] != "ok":
print(f"{name:12} status={payload['status']} error={payload.get('error', '')}")
continue
print(f"{name:12} status=ok gpus={payload['gpu_summary']}")
class MonitorLock:
def __init__(self, path: Path):
self.path = path
self.handle: Any | None = None
def __enter__(self) -> "MonitorLock":
ensure_dir(self.path.parent)
self.handle = self.path.open("w", encoding="utf-8")
        try:
            fcntl.flock(self.handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError as exc:
            # Close the file so a failed acquisition does not leak the handle.
            self.handle.close()
            self.handle = None
            raise FleetError(f"monitor lock is busy: {self.path}") from exc
return self
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
if self.handle is not None:
fcntl.flock(self.handle.fileno(), fcntl.LOCK_UN)
self.handle.close()
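# LOCK_EX | LOCK_NB makes a second concurrent monitor fail fast with a
# FleetError instead of blocking, and the kernel releases the flock
# automatically if the holder dies, so a leftover lock file is harmless.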
def cmd_validate(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
print(f"config ok: {config.config_path}")
if args.jobs:
jobs = load_jobs(Path(args.jobs).resolve())
validate_jobs_against_config(config, jobs)
print(f"jobs ok: {len(jobs)} definitions")
return 0
def cmd_bootstrap_hosts(args: argparse.Namespace) -> int:
aliases_path = Path(args.aliases_file).expanduser().resolve()
aliases = [
line.strip()
for line in aliases_path.read_text(encoding="utf-8").splitlines()
if line.strip() and not line.lstrip().startswith("#")
]
if not aliases:
raise FleetError(f"no aliases found in {aliases_path}")
sections = [
'version = 1',
'',
'[paths]',
'state_dir = ".aituner/gpu_fleet/state"',
'artifacts_dir = ".aituner/gpu_fleet/artifacts"',
'',
'[sync]',
'mode = "rsync"',
'local_path = "."',
'',
]
for alias in aliases:
sections.extend(
[
'[[hosts]]',
f'name = "{alias}"',
f'ssh_alias = "{alias}"',
'enabled = true',
'sync_remote_path = "~/workspace/aituner"',
'fleet_root = "~/.aituner_gpu_fleet"',
'',
]
)
payload = "\n".join(sections)
if args.output:
output = Path(args.output).expanduser().resolve()
ensure_dir(output.parent)
output.write_text(payload + "\n", encoding="utf-8")
print(f"wrote skeleton config: {output}")
else:
print(payload)
return 0
def cmd_probe(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    if args.host and args.host not in config.hosts:
        raise FleetError(f"unknown host: {args.host}")
    cache_path = probe_cache_path(config)
    results = {
        "generated_at": utc_now(),
        "config_path": str(config.config_path),
        "hosts": {},
    }
    if args.host and cache_path.exists():
        # A single-host probe refreshes one entry; seed from the existing
        # cache so results for the rest of the fleet are not discarded.
        results["hosts"] = dict(load_json(cache_path).get("hosts", {}))
    for host in config.hosts.values():
        if args.host and host.name != args.host:
            continue
        if not host.enabled and not args.include_disabled:
            continue
        results["hosts"][host.name] = probe_single_host(config, host)
    dump_json(cache_path, results)
print_probe_table(results)
return 0
def snapshot_free_gpus(config: FleetConfig, probe: dict[str, Any]) -> dict[str, list[int]]:
availability: dict[str, list[int]] = {}
for host in config.hosts.values():
host_probe = probe["hosts"].get(host.name)
if not host.enabled or not host_probe or host_probe.get("status") != "ok":
continue
availability[host.name] = free_gpu_indices(config, host)
return availability
def select_host_for_job(
config: FleetConfig,
probe: dict[str, Any],
availability: dict[str, list[int]],
job: JobSpec,
) -> tuple[HostSpec, list[int]] | None:
    candidates: list[tuple[int, str, HostSpec, list[int]]] = []
    for host in config.hosts.values():
        if not host.enabled:
            continue
        host_probe = probe["hosts"].get(host.name)
        if not host_probe or host_probe.get("status") != "ok":
            continue
        if not matches_host_constraint(host, job.hosts):
            continue
        if not matches_gpu_model(host_probe, job.gpu_model):
            continue
        free = availability.get(host.name, [])
        if len(free) < job.gpus:
            continue
        chosen = sorted(free)[: job.gpus]
        # prefer_pack favors the host left with the fewest idle GPUs after
        # placement (bin packing); otherwise the emptiest host wins (spreading).
        # Host name is the deterministic tiebreaker.
        score = (len(free) - job.gpus) if config.scheduler.prefer_pack else -len(free)
        candidates.append((score, host.name, host, chosen))
    if not candidates:
        return None
    candidates.sort(key=lambda item: (item[0], item[1]))
    _, _, host, gpu_ids = candidates[0]
    return host, gpu_ids
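# Packing example (hosts/values illustrative): with free GPUs
# {"a": [0, 1], "b": [0, 1, 2, 3]} and a 2-GPU job, prefer_pack=True scores
# host "a" at 0 leftover GPUs and "b" at 2, so the job packs onto "a";
# prefer_pack=False scores by -len(free) and spreads it onto "b".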
def dispatch_jobs(
config: FleetConfig,
jobs: list[JobSpec],
*,
dry_run: bool,
max_launch: int | None = None,
) -> tuple[list[dict[str, Any]], list[str]]:
probe = load_probe_cache(config)
availability = snapshot_free_gpus(config, probe)
manifests: list[dict[str, Any]] = []
skipped: list[str] = []
synced_hosts: set[str] = set()
launch_budget = max_launch if max_launch is not None else len(jobs)
for job in jobs:
if len(manifests) >= launch_budget:
skipped.append(f"{job.name}: launch budget exhausted")
continue
selected = select_host_for_job(config, probe, availability, job)
if not selected:
skipped.append(f"{job.name}: no host has matching host/model/free GPUs")
continue
host, gpu_ids = selected
availability[host.name] = [gpu for gpu in availability.get(host.name, []) if gpu not in set(gpu_ids)]
if dry_run:
manifests.append(
{
"job_name": job.name,
"host": host.name,
"gpu_ids": gpu_ids,
"dry_run": True,
}
)
continue
if host.name not in synced_hosts:
rsync_push(config, host)
synced_hosts.add(host.name)
manifest = launch_job(config, host, job, gpu_ids)
manifests.append(manifest)
return manifests, skipped
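# Each cycle syncs the repo to a given host at most once (synced_hosts), and a
# dry run consumes the launch budget without rsync or SSH side effects, so
# `dispatch --dry-run` previews exactly what a real cycle would place.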
def cmd_dispatch(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
jobs = load_jobs(Path(args.jobs).resolve())
validate_jobs_against_config(config, jobs)
manifests, skipped = dispatch_jobs(
config,
jobs,
dry_run=args.dry_run,
max_launch=args.max_launch,
)
for manifest in manifests:
if manifest.get("dry_run"):
print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}")
continue
print(
f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}"
)
for line in skipped:
print(f"skipped {line}")
return 0
def print_queue_summary(queue_state: dict[str, Any]) -> None:
counts: dict[str, int] = {}
for record in queue_state.get("jobs", {}).values():
status = str(record.get("status", "pending"))
counts[status] = counts.get(status, 0) + 1
summary = " ".join(f"{key}={counts[key]}" for key in sorted(counts))
print(summary or "pending=0")
def print_queue_jobs(queue_state: dict[str, Any]) -> None:
for name, record in sorted(queue_state.get("jobs", {}).items()):
host = record.get("host") or "-"
run_id = record.get("run_id") or "-"
print(
f"{name:32} status={record.get('status','pending'):10} gpus={record.get('gpus','-')} host={host:8} run_id={run_id}"
)
def monitor_cycle(
config: FleetConfig,
jobs_path: Path,
*,
dry_run: bool,
max_launch_per_cycle: int | None,
verbose: bool,
) -> tuple[int, int]:
queue_state = load_queue_state(config)
reconcile_queue_state_with_runs(config, queue_state)
jobs: list[JobSpec] = []
try:
jobs = load_jobs(jobs_path)
validate_jobs_against_config(config, jobs)
sync_jobs_to_queue_state(queue_state, jobs)
except Exception as exc:
if verbose:
print(f"queue read warning: {exc}")
save_queue_state(config, queue_state)
return 0, 0
pending = pending_jobs_from_queue(queue_state, jobs)
try:
manifests, skipped = dispatch_jobs(
config,
pending,
dry_run=dry_run,
max_launch=max_launch_per_cycle,
)
except Exception as exc:
if verbose:
print(f"dispatch warning: {exc}")
save_queue_state(config, queue_state)
return 0, 0
now = utc_now()
    for manifest in manifests:
        if manifest.get("dry_run"):
            continue
        record = queue_state["jobs"][manifest["job_name"]]
record["status"] = "running"
record["submitted_at"] = now
record["updated_at"] = now
record["run_id"] = manifest["run_id"]
record["host"] = manifest["host"]
record["gpu_ids"] = manifest["gpu_ids"]
record["attempts"] = int(record.get("attempts", 0)) + 1
save_queue_state(config, queue_state)
if verbose:
for manifest in manifests:
if manifest.get("dry_run"):
print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}")
else:
print(f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}")
for line in skipped:
print(f"deferred {line}")
print_queue_summary(queue_state)
return len(manifests), len(skipped)
def cmd_monitor(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
jobs_path = Path(args.jobs).resolve()
with MonitorLock(monitor_lock_path(config)):
if args.once:
monitor_cycle(
config,
jobs_path,
dry_run=args.dry_run,
max_launch_per_cycle=args.max_launch_per_cycle,
verbose=True,
)
return 0
print(f"monitoring {jobs_path} every {args.interval_sec}s")
try:
while True:
monitor_cycle(
config,
jobs_path,
dry_run=args.dry_run,
max_launch_per_cycle=args.max_launch_per_cycle,
verbose=not args.quiet,
)
time.sleep(args.interval_sec)
except KeyboardInterrupt:
print("monitor stopped")
return 0
def cmd_queue_status(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
queue_state = load_queue_state(config)
print_queue_summary(queue_state)
print_queue_jobs(queue_state)
return 0
def cmd_status(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
manifests = load_run_manifests(config)
if not manifests:
print("no runs recorded")
return 0
for manifest in manifests:
refreshed = refresh_manifest_status(config, manifest)
exit_code = refreshed.get("exit_code")
suffix = f" exit_code={exit_code}" if exit_code is not None else ""
print(
f"{refreshed['run_id']} job={refreshed['job_name']} host={refreshed['host']} status={refreshed['status']}{suffix}"
)
return 0
def cmd_harvest(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
manifests = load_run_manifests(config)
    if args.run_id:
        manifests = [manifest for manifest in manifests if manifest["run_id"] == args.run_id]
    if not manifests:
        print("no matching runs")
        return 0
    for manifest in manifests:
        refreshed = harvest_run(config, manifest)
        if refreshed.get("harvested_at"):
            print(
                f"harvested {refreshed['run_id']} status={refreshed['status']} into {config.artifacts_dir / refreshed['run_id']}"
            )
        else:
            print(f"pending {refreshed['run_id']} status={refreshed['status']}")
    return 0
def cmd_show_hosts(args: argparse.Namespace) -> int:
config = load_config(Path(args.config).resolve())
probe = load_probe_cache(config)
for host in config.hosts.values():
host_probe = probe["hosts"].get(host.name, {})
summary = host_probe.get("gpu_summary", "unprobed")
print(f"{host.name:12} alias={host.ssh_alias:12} gpus={summary:20}")
return 0
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Dispatch heterogeneous GPU jobs over SSH/tmux.")
subparsers = parser.add_subparsers(dest="command", required=True)
validate = subparsers.add_parser("validate", help="Validate fleet and optional job configs.")
validate.add_argument("--config", default=str(DEFAULT_CONFIG))
validate.add_argument("--jobs")
validate.set_defaults(func=cmd_validate)
bootstrap = subparsers.add_parser("bootstrap-hosts", help="Generate a fleet config skeleton from a plain alias list.")
bootstrap.add_argument("--aliases-file", required=True)
bootstrap.add_argument("--output")
bootstrap.set_defaults(func=cmd_bootstrap_hosts)
probe = subparsers.add_parser("probe", help="Probe remote hosts for GPU information.")
probe.add_argument("--config", default=str(DEFAULT_CONFIG))
probe.add_argument("--host")
probe.add_argument("--include-disabled", action="store_true")
probe.set_defaults(func=cmd_probe)
dispatch = subparsers.add_parser("dispatch", help="Dispatch a batch of jobs to matching hosts.")
dispatch.add_argument("--config", default=str(DEFAULT_CONFIG))
dispatch.add_argument("--jobs", required=True)
dispatch.add_argument("--dry-run", action="store_true")
dispatch.add_argument("--max-launch", type=int)
dispatch.set_defaults(func=cmd_dispatch)
monitor = subparsers.add_parser("monitor", help="Continuously consume jobs.toml as a local queue.")
monitor.add_argument("--config", default=str(DEFAULT_CONFIG))
monitor.add_argument("--jobs", required=True)
monitor.add_argument("--interval-sec", type=int, default=15)
monitor.add_argument("--max-launch-per-cycle", type=int)
monitor.add_argument("--dry-run", action="store_true")
monitor.add_argument("--once", action="store_true")
monitor.add_argument("--quiet", action="store_true")
monitor.set_defaults(func=cmd_monitor)
status = subparsers.add_parser("status", help="Refresh and print known run status.")
status.add_argument("--config", default=str(DEFAULT_CONFIG))
status.set_defaults(func=cmd_status)
harvest = subparsers.add_parser("harvest", help="Pull finished run logs and artifacts back to local state.")
harvest.add_argument("--config", default=str(DEFAULT_CONFIG))
harvest.add_argument("--run-id")
harvest.set_defaults(func=cmd_harvest)
show_hosts = subparsers.add_parser("show-hosts", help="Show configured hosts together with the latest probe summary.")
show_hosts.add_argument("--config", default=str(DEFAULT_CONFIG))
show_hosts.set_defaults(func=cmd_show_hosts)
queue_status = subparsers.add_parser("queue-status", help="Show durable queue state tracked by the monitor.")
queue_status.add_argument("--config", default=str(DEFAULT_CONFIG))
queue_status.set_defaults(func=cmd_queue_status)
return parser
def main(argv: list[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
try:
return args.func(args)
except FleetError as exc:
print(f"error: {exc}", file=sys.stderr)
return 1
if __name__ == "__main__":
raise SystemExit(main())