#!/usr/bin/env python3 from __future__ import annotations import argparse import csv import fcntl import hashlib import json import shlex import subprocess import sys import textwrap import time import tomllib from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any SCRIPT_PATH = Path(__file__).resolve() REPO_ROOT = SCRIPT_PATH.parents[2] DEFAULT_CONFIG = SCRIPT_PATH.parent / "config" / "fleet.toml" DEFAULT_STATE_REL = Path(".aituner/gpu_fleet/state") DEFAULT_ARTIFACT_REL = Path(".aituner/gpu_fleet/artifacts") PROBE_CACHE_NAME = "probe_cache.json" QUEUE_STATE_NAME = "jobs_state.json" MONITOR_LOCK_NAME = "monitor.lock" class FleetError(RuntimeError): pass @dataclass class HostSpec: name: str ssh_alias: str enabled: bool = True sync_remote_path: str = "~/workspace/aituner" fleet_root: str = "~/.aituner_gpu_fleet" @dataclass class SyncSpec: mode: str = "rsync" local_path: Path = REPO_ROOT exclude: list[str] = field(default_factory=list) @dataclass class SchedulerSpec: gpu_free_memory_mb: int = 1024 gpu_free_utilization_pct: int = 10 prefer_pack: bool = True @dataclass class FleetConfig: config_path: Path project_root: Path state_dir: Path artifacts_dir: Path ssh_timeout_sec: int scheduler: SchedulerSpec sync: SyncSpec hosts: dict[str, HostSpec] @dataclass class JobSpec: name: str command: str gpus: int gpu_model: str | None = None hosts: list[str] = field(default_factory=list) artifacts: list[str] = field(default_factory=list) env: dict[str, str] = field(default_factory=dict) def utc_now() -> str: return datetime.now(timezone.utc).isoformat(timespec="seconds") def relative_to_root(root: Path, value: str | None, default: Path) -> Path: if not value: return default candidate = Path(value).expanduser() if candidate.is_absolute(): return candidate return (root / candidate).resolve() def load_toml(path: Path) -> dict[str, Any]: with path.open("rb") as fh: return tomllib.load(fh) def ensure_dir(path: Path) -> Path: path.mkdir(parents=True, exist_ok=True) return path def dump_json(path: Path, payload: Any) -> None: ensure_dir(path.parent) path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") def load_json(path: Path) -> Any: return json.loads(path.read_text(encoding="utf-8")) def render_command(template: str, **kwargs: str) -> str: try: return template.format(**kwargs) except KeyError as exc: raise FleetError(f"template is missing placeholder {exc!s}: {template}") from exc def run_local( argv: list[str], *, cwd: Path | None = None, capture_output: bool = True, check: bool = True, ) -> subprocess.CompletedProcess[str]: completed = subprocess.run( argv, cwd=str(cwd) if cwd else None, check=False, text=True, capture_output=capture_output, ) if check and completed.returncode != 0: stderr = completed.stderr.strip() stdout = completed.stdout.strip() detail = stderr or stdout or f"exit code {completed.returncode}" raise FleetError(f"command failed: {' '.join(shlex.quote(x) for x in argv)}\n{detail}") return completed def ssh_base(host: HostSpec, timeout_sec: int) -> list[str]: return [ "ssh", "-o", f"ConnectTimeout={timeout_sec}", host.ssh_alias, ] def run_ssh( config: FleetConfig, host: HostSpec, remote_command: str, *, capture_output: bool = True, check: bool = True, ) -> subprocess.CompletedProcess[str]: argv = ssh_base(host, config.ssh_timeout_sec) + [remote_command] return run_local(argv, cwd=config.project_root, capture_output=capture_output, check=check) def rsync_push(config: FleetConfig, host: HostSpec) -> None: ensure_remote_dir(config, host, host.sync_remote_path) argv = ["rsync", "-az"] for pattern in config.sync.exclude: argv.extend(["--exclude", pattern]) local_src = str(config.sync.local_path.resolve()) + "/" remote_dst = f"{host.ssh_alias}:{host.sync_remote_path.rstrip('/')}/" argv.extend([local_src, remote_dst]) run_local(argv, cwd=config.project_root, capture_output=True, check=True) def rsync_pull(config: FleetConfig, host: HostSpec, remote_path: str, local_path: Path) -> None: ensure_dir(local_path.parent) argv = [ "rsync", "-az", f"{host.ssh_alias}:{remote_path}", str(local_path), ] run_local(argv, cwd=config.project_root, capture_output=True, check=True) def ensure_remote_dir(config: FleetConfig, host: HostSpec, remote_path: str) -> None: run_ssh(config, host, f"mkdir -p {shlex.quote(remote_path)}", capture_output=True, check=True) def load_config(path: Path) -> FleetConfig: raw = load_toml(path) if int(raw.get("version", 1)) != 1: raise FleetError("only config version 1 is supported") project_root = REPO_ROOT paths_raw = raw.get("paths", {}) state_dir = relative_to_root(project_root, paths_raw.get("state_dir"), project_root / DEFAULT_STATE_REL) artifacts_dir = relative_to_root( project_root, paths_raw.get("artifacts_dir"), project_root / DEFAULT_ARTIFACT_REL, ) sync_raw = raw.get("sync", {}) sync = SyncSpec( mode=str(sync_raw.get("mode", "rsync")), local_path=relative_to_root(project_root, sync_raw.get("local_path"), project_root), exclude=[str(item) for item in sync_raw.get("exclude", [])], ) if sync.mode != "rsync": raise FleetError(f"unsupported sync.mode: {sync.mode}") scheduler_raw = raw.get("scheduler", {}) scheduler = SchedulerSpec( gpu_free_memory_mb=int(scheduler_raw.get("gpu_free_memory_mb", 1024)), gpu_free_utilization_pct=int(scheduler_raw.get("gpu_free_utilization_pct", 10)), prefer_pack=bool(scheduler_raw.get("prefer_pack", True)), ) ssh_raw = raw.get("ssh", {}) hosts: dict[str, HostSpec] = {} for raw_host in raw.get("hosts", []): name = str(raw_host.get("name", "")).strip() ssh_alias = str(raw_host.get("ssh_alias", "")).strip() if not name or not ssh_alias: raise FleetError("every host must define both name and ssh_alias") host = HostSpec( name=name, ssh_alias=ssh_alias, enabled=bool(raw_host.get("enabled", True)), sync_remote_path=str(raw_host.get("sync_remote_path", "~/workspace/aituner")), fleet_root=str(raw_host.get("fleet_root", "~/.aituner_gpu_fleet")), ) hosts[name] = host if not hosts: raise FleetError("no hosts are configured") return FleetConfig( config_path=path, project_root=project_root, state_dir=state_dir, artifacts_dir=artifacts_dir, ssh_timeout_sec=int(ssh_raw.get("connect_timeout_sec", 10)), scheduler=scheduler, sync=sync, hosts=hosts, ) def load_jobs(path: Path) -> list[JobSpec]: raw = load_toml(path) if int(raw.get("version", 1)) != 1: raise FleetError("only jobs version 1 is supported") jobs: list[JobSpec] = [] seen_names: set[str] = set() for item in raw.get("jobs", []): name = str(item.get("name", "")).strip() if not name: raise FleetError("every job must define name") if name in seen_names: raise FleetError(f"duplicated job name: {name}") command = str(item.get("command", "")).strip() if not command: raise FleetError(f"job {name} must define command") gpus = int(item.get("gpus", 0)) if gpus <= 0: raise FleetError(f"job {name} must request a positive gpu count") jobs.append( JobSpec( name=name, command=command, gpus=gpus, gpu_model=str(item["gpu_model"]) if "gpu_model" in item else None, hosts=[str(entry) for entry in item.get("hosts", [])], artifacts=[str(entry) for entry in item.get("artifacts", [])], env={str(k): str(v) for k, v in item.get("env", {}).items()}, ) ) seen_names.add(name) if not jobs: raise FleetError("no jobs are defined") return jobs def validate_jobs_against_config(config: FleetConfig, jobs: list[JobSpec]) -> None: for job in jobs: if job.hosts: unknown_hosts = sorted( requested for requested in job.hosts if requested not in config.hosts and requested not in {host.ssh_alias for host in config.hosts.values()} ) if unknown_hosts: raise FleetError(f"job {job.name} references unknown hosts: {', '.join(unknown_hosts)}") candidate_hosts = [ host for host in config.hosts.values() if matches_host_constraint(host, job.hosts) ] if not candidate_hosts: raise FleetError(f"job {job.name} has no candidate hosts after host filtering") def sync_jobs_to_queue_state(queue_state: dict[str, Any], jobs: list[JobSpec]) -> None: now = utc_now() state_jobs = queue_state.setdefault("jobs", {}) for job in jobs: signature = job_signature(job) record = state_jobs.get(job.name) payload = { "name": job.name, "signature": signature, "command": job.command, "gpus": job.gpus, "gpu_model": job.gpu_model, "hosts": job.hosts, "artifacts": job.artifacts, "env": job.env, "last_seen_at": now, } if record is None: state_jobs[job.name] = { **payload, "status": "pending", "created_at": now, "submitted_at": None, "completed_at": None, "run_id": None, "attempts": 0, } continue if record.get("signature") != signature and record.get("status") != "pending": raise FleetError(f"job {job.name} changed after it entered the queue; use a new job name instead") record.update(payload) if record.get("status") == "unknown": record["status"] = "pending" def pending_jobs_from_queue(queue_state: dict[str, Any], jobs: list[JobSpec]) -> list[JobSpec]: state_jobs = queue_state.get("jobs", {}) pending: list[JobSpec] = [] for job in jobs: record = state_jobs.get(job.name, {}) if record.get("status", "pending") == "pending": pending.append(job) return pending def state_subdir(config: FleetConfig, *parts: str) -> Path: return ensure_dir(config.state_dir.joinpath(*parts)) def run_state_dir(config: FleetConfig) -> Path: return state_subdir(config, "runs") def probe_cache_path(config: FleetConfig) -> Path: return state_subdir(config, "probe") / PROBE_CACHE_NAME def queue_state_path(config: FleetConfig) -> Path: return state_subdir(config, "queue") / QUEUE_STATE_NAME def monitor_lock_path(config: FleetConfig) -> Path: return state_subdir(config, "queue") / MONITOR_LOCK_NAME def load_queue_state(config: FleetConfig) -> dict[str, Any]: path = queue_state_path(config) if path.exists(): payload = load_json(path) payload.setdefault("jobs", {}) return payload return { "generated_at": utc_now(), "updated_at": utc_now(), "jobs": {}, } def save_queue_state(config: FleetConfig, payload: dict[str, Any]) -> None: payload["updated_at"] = utc_now() dump_json(queue_state_path(config), payload) def job_signature(job: JobSpec) -> str: payload = { "name": job.name, "command": job.command, "gpus": job.gpus, "gpu_model": job.gpu_model, "hosts": job.hosts, "artifacts": job.artifacts, "env": job.env, } raw = json.dumps(payload, sort_keys=True, ensure_ascii=True) return hashlib.sha256(raw.encode("utf-8")).hexdigest() def parse_gpu_query(output: str) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] reader = csv.reader(line for line in output.splitlines() if line.strip()) for row in reader: if len(row) != 4: continue index, name, memory_total, memory_used = [cell.strip() for cell in row] rows.append( { "index": int(index), "name": name, "memory_total_mb": int(memory_total), "memory_used_mb": int(memory_used), } ) return rows def parse_gpu_status(output: str) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] reader = csv.reader(line for line in output.splitlines() if line.strip()) for row in reader: if len(row) != 5: continue index, name, memory_total, memory_used, util = [cell.strip() for cell in row] rows.append( { "index": int(index), "name": name, "memory_total_mb": int(memory_total), "memory_used_mb": int(memory_used), "utilization_gpu_pct": int(util), } ) return rows def probe_single_host(config: FleetConfig, host: HostSpec) -> dict[str, Any]: gpu_cmd = textwrap.dedent( """ bash -lc ' set -euo pipefail nvidia-smi --query-gpu=index,name,memory.total,memory.used --format=csv,noheader,nounits ' """ ).strip() completed = run_ssh(config, host, gpu_cmd, capture_output=True, check=False) payload: dict[str, Any] = { "host": host.name, "ssh_alias": host.ssh_alias, "enabled": host.enabled, "probed_at": utc_now(), } if completed.returncode != 0: payload["status"] = "unreachable" payload["error"] = (completed.stderr or completed.stdout).strip() return payload gpus = parse_gpu_query(completed.stdout) names = sorted({gpu["name"] for gpu in gpus}) payload["status"] = "ok" payload["gpus"] = gpus payload["gpu_count"] = len(gpus) payload["gpu_models"] = names payload["gpu_summary"] = ", ".join(f"{sum(1 for gpu in gpus if gpu['name'] == name)}x {name}" for name in names) return payload def load_probe_cache(config: FleetConfig) -> dict[str, Any]: path = probe_cache_path(config) if not path.exists(): raise FleetError( f"probe cache is missing: {path}. Run 'python3 infra/gpu_fleet/gpu_fleet.py probe --config {config.config_path}' first." ) return load_json(path) def query_host_gpu_status(config: FleetConfig, host: HostSpec) -> list[dict[str, Any]]: cmd = textwrap.dedent( """ bash -lc ' set -euo pipefail nvidia-smi --query-gpu=index,name,memory.total,memory.used,utilization.gpu --format=csv,noheader,nounits ' """ ).strip() completed = run_ssh(config, host, cmd, capture_output=True, check=True) return parse_gpu_status(completed.stdout) def free_gpu_indices(config: FleetConfig, host: HostSpec) -> list[int]: status = query_host_gpu_status(config, host) free: list[int] = [] for gpu in status: if gpu["memory_used_mb"] <= config.scheduler.gpu_free_memory_mb and gpu["utilization_gpu_pct"] <= config.scheduler.gpu_free_utilization_pct: free.append(gpu["index"]) return free def matches_gpu_model(host_probe: dict[str, Any], host: HostSpec, requested: str | None) -> bool: if not requested: return True needle = requested.lower() for name in host_probe.get("gpu_models", []): if needle in name.lower(): return True summary = str(host_probe.get("gpu_summary", "")) return needle in summary.lower() def matches_host_constraint(host: HostSpec, requested_hosts: list[str]) -> bool: if not requested_hosts: return True requested = {item.strip() for item in requested_hosts if item.strip()} return host.name in requested or host.ssh_alias in requested def build_payload(job: JobSpec, host: HostSpec, run_id: str, remote_run_dir: str, gpu_ids: list[int]) -> str: env_exports = "\n".join( f"export {key}={shlex.quote(value)}" for key, value in sorted(job.env.items()) ) log_file = f"{remote_run_dir}/stdout.log" exit_file = f"{remote_run_dir}/exit_code" started_file = f"{remote_run_dir}/started_at" finished_file = f"{remote_run_dir}/finished_at" payload = textwrap.dedent( f""" set -euo pipefail mkdir -p {shlex.quote(remote_run_dir)} exec > >(tee -a {shlex.quote(log_file)}) 2>&1 printf '%s\\n' "$(date -Iseconds)" > {shlex.quote(started_file)} on_exit() {{ status=$? printf '%s\\n' "$status" > {shlex.quote(exit_file)} printf '%s\\n' "$(date -Iseconds)" > {shlex.quote(finished_file)} }} trap on_exit EXIT cd {shlex.quote(host.sync_remote_path)} export CUDA_VISIBLE_DEVICES={shlex.quote(','.join(str(idx) for idx in gpu_ids))} export AITUNER_RUN_ID={shlex.quote(run_id)} export AITUNER_REMOTE_RUN_DIR={shlex.quote(remote_run_dir)} export AITUNER_GPU_COUNT={shlex.quote(str(job.gpus))} {env_exports} {job.command} """ ).strip() return payload def write_run_manifest(config: FleetConfig, manifest: dict[str, Any]) -> Path: path = run_state_dir(config) / f"{manifest['run_id']}.json" dump_json(path, manifest) return path def run_manifest_path(config: FleetConfig, run_id: str) -> Path: return run_state_dir(config) / f"{run_id}.json" def load_run_manifests(config: FleetConfig) -> list[dict[str, Any]]: runs_dir = run_state_dir(config) manifests: list[dict[str, Any]] = [] for path in sorted(runs_dir.glob("*.json")): manifests.append(load_json(path)) return manifests def launch_job(config: FleetConfig, host: HostSpec, job: JobSpec, gpu_ids: list[int]) -> dict[str, Any]: sync_remote_path = host.sync_remote_path.rstrip("/") fleet_root = host.fleet_root.rstrip("/") run_id = f"{job.name}-{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%fZ')}" session_name = f"aituner_{run_id.replace('.', '_')}" remote_run_dir = f"{fleet_root}/runs/{run_id}" payload = build_payload(job, host, run_id, remote_run_dir, gpu_ids) remote_launch = ( f"mkdir -p {shlex.quote(remote_run_dir)} " f"&& tmux new-session -d -s {shlex.quote(session_name)} bash -lc {shlex.quote(payload)}" ) run_ssh(config, host, remote_launch, capture_output=True, check=True) manifest = { "run_id": run_id, "job_name": job.name, "host": host.name, "ssh_alias": host.ssh_alias, "session_name": session_name, "gpu_ids": gpu_ids, "gpu_model": job.gpu_model, "gpus": job.gpus, "command": job.command, "env": job.env, "artifacts": job.artifacts, "remote_sync_path": sync_remote_path, "remote_run_dir": remote_run_dir, "status": "running", "submitted_at": utc_now(), "harvested_at": None, } write_run_manifest(config, manifest) return manifest def refresh_manifest_status(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]: host = config.hosts[manifest["host"]] remote_cmd = textwrap.dedent( f""" bash -lc ' if [ -f {shlex.quote(manifest["remote_run_dir"] + "/exit_code")} ]; then code=$(cat {shlex.quote(manifest["remote_run_dir"] + "/exit_code")}) printf "done:%s\\n" "$code" elif tmux has-session -t {shlex.quote(manifest["session_name"])} 2>/dev/null; then printf "running\\n" else printf "unknown\\n" fi ' """ ).strip() completed = run_ssh(config, host, remote_cmd, capture_output=True, check=False) status_line = (completed.stdout or completed.stderr).strip() new_manifest = dict(manifest) if status_line.startswith("done:"): code = int(status_line.split(":", 1)[1]) new_manifest["status"] = "completed" if code == 0 else "failed" new_manifest["exit_code"] = code elif status_line == "running": new_manifest["status"] = "running" else: new_manifest["status"] = "unknown" write_run_manifest(config, new_manifest) return new_manifest def harvest_run(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]: host = config.hosts[manifest["host"]] refreshed = refresh_manifest_status(config, manifest) if refreshed["status"] not in {"completed", "failed"}: return refreshed local_base = ensure_dir(config.artifacts_dir / refreshed["run_id"]) remote_run_dir = refreshed["remote_run_dir"].rstrip("/") rsync_pull(config, host, f"{remote_run_dir}/", local_base / "remote_run") for artifact in refreshed.get("artifacts", []): artifact_remote = f"{refreshed['remote_sync_path'].rstrip('/')}/{artifact}" target = local_base / "artifacts" / Path(artifact.lstrip("/")) check = run_ssh( config, host, f"test -e {shlex.quote(artifact_remote)}", capture_output=True, check=False, ) if check.returncode == 0: rsync_pull(config, host, artifact_remote, target) refreshed["harvested_at"] = utc_now() write_run_manifest(config, refreshed) return refreshed def reconcile_queue_state_with_runs(config: FleetConfig, queue_state: dict[str, Any]) -> None: state_jobs = queue_state.get("jobs", {}) for name, record in state_jobs.items(): if record.get("status") != "running": continue run_id = record.get("run_id") if not run_id: record["status"] = "unknown" record["updated_at"] = utc_now() continue manifest_path = run_manifest_path(config, run_id) if not manifest_path.exists(): record["status"] = "unknown" record["updated_at"] = utc_now() continue manifest = load_json(manifest_path) refreshed = refresh_manifest_status(config, manifest) record["updated_at"] = utc_now() if refreshed["status"] == "running": continue if refreshed["status"] in {"completed", "failed"}: harvested = harvest_run(config, refreshed) record["status"] = harvested["status"] record["completed_at"] = utc_now() record["exit_code"] = harvested.get("exit_code") record["harvested_at"] = harvested.get("harvested_at") continue record["status"] = "unknown" def print_probe_table(probe: dict[str, Any]) -> None: for name, payload in sorted(probe["hosts"].items()): if payload["status"] != "ok": print(f"{name:12} status={payload['status']} error={payload.get('error', '')}") continue print(f"{name:12} status=ok gpus={payload['gpu_summary']}") class MonitorLock: def __init__(self, path: Path): self.path = path self.handle: Any | None = None def __enter__(self) -> "MonitorLock": ensure_dir(self.path.parent) self.handle = self.path.open("w", encoding="utf-8") try: fcntl.flock(self.handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) except BlockingIOError as exc: raise FleetError(f"monitor lock is busy: {self.path}") from exc return self def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: if self.handle is not None: fcntl.flock(self.handle.fileno(), fcntl.LOCK_UN) self.handle.close() def cmd_validate(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) print(f"config ok: {config.config_path}") if args.jobs: jobs = load_jobs(Path(args.jobs).resolve()) validate_jobs_against_config(config, jobs) print(f"jobs ok: {len(jobs)} definitions") return 0 def cmd_bootstrap_hosts(args: argparse.Namespace) -> int: aliases_path = Path(args.aliases_file).expanduser().resolve() aliases = [ line.strip() for line in aliases_path.read_text(encoding="utf-8").splitlines() if line.strip() and not line.lstrip().startswith("#") ] if not aliases: raise FleetError(f"no aliases found in {aliases_path}") sections = [ 'version = 1', '', '[paths]', 'state_dir = ".aituner/gpu_fleet/state"', 'artifacts_dir = ".aituner/gpu_fleet/artifacts"', '', '[sync]', 'mode = "rsync"', 'local_path = "."', '', ] for alias in aliases: sections.extend( [ '[[hosts]]', f'name = "{alias}"', f'ssh_alias = "{alias}"', 'enabled = true', 'sync_remote_path = "~/workspace/aituner"', 'fleet_root = "~/.aituner_gpu_fleet"', '', ] ) payload = "\n".join(sections) if args.output: output = Path(args.output).expanduser().resolve() ensure_dir(output.parent) output.write_text(payload + "\n", encoding="utf-8") print(f"wrote skeleton config: {output}") else: print(payload) return 0 def cmd_probe(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) results = { "generated_at": utc_now(), "config_path": str(config.config_path), "hosts": {}, } for host in config.hosts.values(): if args.host and host.name != args.host: continue if not host.enabled and not args.include_disabled: continue results["hosts"][host.name] = probe_single_host(config, host) dump_json(probe_cache_path(config), results) print_probe_table(results) return 0 def snapshot_free_gpus(config: FleetConfig, probe: dict[str, Any]) -> dict[str, list[int]]: availability: dict[str, list[int]] = {} for host in config.hosts.values(): host_probe = probe["hosts"].get(host.name) if not host.enabled or not host_probe or host_probe.get("status") != "ok": continue availability[host.name] = free_gpu_indices(config, host) return availability def select_host_for_job( config: FleetConfig, probe: dict[str, Any], availability: dict[str, list[int]], job: JobSpec, ) -> tuple[HostSpec, list[int]] | None: candidates: list[tuple[int, int, str, HostSpec, list[int]]] = [] for host in config.hosts.values(): if not host.enabled: continue host_probe = probe["hosts"].get(host.name) if not host_probe or host_probe.get("status") != "ok": continue if not matches_host_constraint(host, job.hosts): continue if not matches_gpu_model(host_probe, host, job.gpu_model): continue free = availability.get(host.name, []) if len(free) < job.gpus: continue chosen = sorted(free)[: job.gpus] leftover = len(free) - job.gpus exact_model = 0 if job.gpu_model and matches_gpu_model(host_probe, host, job.gpu_model) else 1 candidates.append((leftover if config.scheduler.prefer_pack else -len(free), exact_model, host.name, host, chosen)) if not candidates: return None candidates.sort(key=lambda item: (item[0], item[1], item[2])) _, _, _, host, gpu_ids = candidates[0] return host, gpu_ids def dispatch_jobs( config: FleetConfig, jobs: list[JobSpec], *, dry_run: bool, max_launch: int | None = None, ) -> tuple[list[dict[str, Any]], list[str]]: probe = load_probe_cache(config) availability = snapshot_free_gpus(config, probe) manifests: list[dict[str, Any]] = [] skipped: list[str] = [] synced_hosts: set[str] = set() launch_budget = max_launch if max_launch is not None else len(jobs) for job in jobs: if len(manifests) >= launch_budget: skipped.append(f"{job.name}: launch budget exhausted") continue selected = select_host_for_job(config, probe, availability, job) if not selected: skipped.append(f"{job.name}: no host has matching host/model/free GPUs") continue host, gpu_ids = selected availability[host.name] = [gpu for gpu in availability.get(host.name, []) if gpu not in set(gpu_ids)] if dry_run: manifests.append( { "job_name": job.name, "host": host.name, "gpu_ids": gpu_ids, "dry_run": True, } ) continue if host.name not in synced_hosts: rsync_push(config, host) synced_hosts.add(host.name) manifest = launch_job(config, host, job, gpu_ids) manifests.append(manifest) return manifests, skipped def cmd_dispatch(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) jobs = load_jobs(Path(args.jobs).resolve()) validate_jobs_against_config(config, jobs) manifests, skipped = dispatch_jobs( config, jobs, dry_run=args.dry_run, max_launch=args.max_launch, ) for manifest in manifests: if manifest.get("dry_run"): print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}") continue print( f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}" ) for line in skipped: print(f"skipped {line}") return 0 def print_queue_summary(queue_state: dict[str, Any]) -> None: counts: dict[str, int] = {} for record in queue_state.get("jobs", {}).values(): status = str(record.get("status", "pending")) counts[status] = counts.get(status, 0) + 1 summary = " ".join(f"{key}={counts[key]}" for key in sorted(counts)) print(summary or "pending=0") def print_queue_jobs(queue_state: dict[str, Any]) -> None: for name, record in sorted(queue_state.get("jobs", {}).items()): host = record.get("host") or "-" run_id = record.get("run_id") or "-" print( f"{name:32} status={record.get('status','pending'):10} gpus={record.get('gpus','-')} host={host:8} run_id={run_id}" ) def monitor_cycle( config: FleetConfig, jobs_path: Path, *, dry_run: bool, max_launch_per_cycle: int | None, verbose: bool, ) -> tuple[int, int]: queue_state = load_queue_state(config) reconcile_queue_state_with_runs(config, queue_state) jobs: list[JobSpec] = [] try: jobs = load_jobs(jobs_path) validate_jobs_against_config(config, jobs) sync_jobs_to_queue_state(queue_state, jobs) except Exception as exc: if verbose: print(f"queue read warning: {exc}") save_queue_state(config, queue_state) return 0, 0 pending = pending_jobs_from_queue(queue_state, jobs) try: manifests, skipped = dispatch_jobs( config, pending, dry_run=dry_run, max_launch=max_launch_per_cycle, ) except Exception as exc: if verbose: print(f"dispatch warning: {exc}") save_queue_state(config, queue_state) return 0, 0 now = utc_now() for manifest in manifests: record = queue_state["jobs"][manifest["job_name"]] if manifest.get("dry_run"): continue record["status"] = "running" record["submitted_at"] = now record["updated_at"] = now record["run_id"] = manifest["run_id"] record["host"] = manifest["host"] record["gpu_ids"] = manifest["gpu_ids"] record["attempts"] = int(record.get("attempts", 0)) + 1 save_queue_state(config, queue_state) if verbose: for manifest in manifests: if manifest.get("dry_run"): print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}") else: print(f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}") for line in skipped: print(f"deferred {line}") print_queue_summary(queue_state) return len(manifests), len(skipped) def cmd_monitor(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) jobs_path = Path(args.jobs).resolve() with MonitorLock(monitor_lock_path(config)): if args.once: monitor_cycle( config, jobs_path, dry_run=args.dry_run, max_launch_per_cycle=args.max_launch_per_cycle, verbose=True, ) return 0 print(f"monitoring {jobs_path} every {args.interval_sec}s") try: while True: monitor_cycle( config, jobs_path, dry_run=args.dry_run, max_launch_per_cycle=args.max_launch_per_cycle, verbose=not args.quiet, ) time.sleep(args.interval_sec) except KeyboardInterrupt: print("monitor stopped") return 0 def cmd_queue_status(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) queue_state = load_queue_state(config) print_queue_summary(queue_state) print_queue_jobs(queue_state) return 0 def cmd_status(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) manifests = load_run_manifests(config) if not manifests: print("no runs recorded") return 0 for manifest in manifests: refreshed = refresh_manifest_status(config, manifest) exit_code = refreshed.get("exit_code") suffix = f" exit_code={exit_code}" if exit_code is not None else "" print( f"{refreshed['run_id']} job={refreshed['job_name']} host={refreshed['host']} status={refreshed['status']}{suffix}" ) return 0 def cmd_harvest(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) manifests = load_run_manifests(config) if args.run_id: manifests = [manifest for manifest in manifests if manifest["run_id"] == args.run_id] harvested_any = False for manifest in manifests: refreshed = harvest_run(config, manifest) if refreshed.get("harvested_at"): harvested_any = True print( f"harvested {refreshed['run_id']} status={refreshed['status']} into {config.artifacts_dir / refreshed['run_id']}" ) else: print(f"pending {refreshed['run_id']} status={refreshed['status']}") if not harvested_any and not manifests: print("no matching runs") return 0 def cmd_show_hosts(args: argparse.Namespace) -> int: config = load_config(Path(args.config).resolve()) probe = load_probe_cache(config) for host in config.hosts.values(): host_probe = probe["hosts"].get(host.name, {}) summary = host_probe.get("gpu_summary", "unprobed") print(f"{host.name:12} alias={host.ssh_alias:12} gpus={summary:20}") return 0 def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Dispatch heterogeneous GPU jobs over SSH/tmux.") subparsers = parser.add_subparsers(dest="command", required=True) validate = subparsers.add_parser("validate", help="Validate fleet and optional job configs.") validate.add_argument("--config", default=str(DEFAULT_CONFIG)) validate.add_argument("--jobs") validate.set_defaults(func=cmd_validate) bootstrap = subparsers.add_parser("bootstrap-hosts", help="Generate a fleet config skeleton from a plain alias list.") bootstrap.add_argument("--aliases-file", required=True) bootstrap.add_argument("--output") bootstrap.set_defaults(func=cmd_bootstrap_hosts) probe = subparsers.add_parser("probe", help="Probe remote hosts for GPU information.") probe.add_argument("--config", default=str(DEFAULT_CONFIG)) probe.add_argument("--host") probe.add_argument("--include-disabled", action="store_true") probe.set_defaults(func=cmd_probe) dispatch = subparsers.add_parser("dispatch", help="Dispatch a batch of jobs to matching hosts.") dispatch.add_argument("--config", default=str(DEFAULT_CONFIG)) dispatch.add_argument("--jobs", required=True) dispatch.add_argument("--dry-run", action="store_true") dispatch.add_argument("--max-launch", type=int) dispatch.set_defaults(func=cmd_dispatch) monitor = subparsers.add_parser("monitor", help="Continuously consume jobs.toml as a local queue.") monitor.add_argument("--config", default=str(DEFAULT_CONFIG)) monitor.add_argument("--jobs", required=True) monitor.add_argument("--interval-sec", type=int, default=15) monitor.add_argument("--max-launch-per-cycle", type=int) monitor.add_argument("--dry-run", action="store_true") monitor.add_argument("--once", action="store_true") monitor.add_argument("--quiet", action="store_true") monitor.set_defaults(func=cmd_monitor) status = subparsers.add_parser("status", help="Refresh and print known run status.") status.add_argument("--config", default=str(DEFAULT_CONFIG)) status.set_defaults(func=cmd_status) harvest = subparsers.add_parser("harvest", help="Pull finished run logs and artifacts back to local state.") harvest.add_argument("--config", default=str(DEFAULT_CONFIG)) harvest.add_argument("--run-id") harvest.set_defaults(func=cmd_harvest) show_hosts = subparsers.add_parser("show-hosts", help="Show configured hosts together with the latest probe summary.") show_hosts.add_argument("--config", default=str(DEFAULT_CONFIG)) show_hosts.set_defaults(func=cmd_show_hosts) queue_status = subparsers.add_parser("queue-status", help="Show durable queue state tracked by the monitor.") queue_status.add_argument("--config", default=str(DEFAULT_CONFIG)) queue_status.set_defaults(func=cmd_queue_status) return parser def main(argv: list[str] | None = None) -> int: parser = build_parser() args = parser.parse_args(argv) try: return args.func(args) except FleetError as exc: print(f"error: {exc}", file=sys.stderr) return 1 if __name__ == "__main__": raise SystemExit(main())