#!/usr/bin/env python3
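"""Dispatch heterogeneous GPU jobs to remote hosts over SSH and tmux.

The tool reads a fleet config (hosts, sync, scheduler thresholds) and a job
list from TOML, probes hosts for free GPUs via nvidia-smi, launches each job
in a detached tmux session with CUDA_VISIBLE_DEVICES pinned, and tracks every
run through local JSON manifests so status/harvest/monitor can resume from
local state.
"""
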
from __future__ import annotations

import argparse
import csv
import fcntl
import hashlib
import json
import shlex
import subprocess
import sys
import textwrap
import time
import tomllib
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


SCRIPT_PATH = Path(__file__).resolve()
REPO_ROOT = SCRIPT_PATH.parents[2]
DEFAULT_CONFIG = SCRIPT_PATH.parent / "config" / "fleet.toml"
DEFAULT_STATE_REL = Path(".aituner/gpu_fleet/state")
DEFAULT_ARTIFACT_REL = Path(".aituner/gpu_fleet/artifacts")
PROBE_CACHE_NAME = "probe_cache.json"
QUEUE_STATE_NAME = "jobs_state.json"
MONITOR_LOCK_NAME = "monitor.lock"


class FleetError(RuntimeError):
    pass


@dataclass
class HostSpec:
    name: str
    ssh_alias: str
    enabled: bool = True
    sync_remote_path: str = "~/workspace/aituner"
    fleet_root: str = "~/.aituner_gpu_fleet"


@dataclass
class SyncSpec:
    mode: str = "rsync"
    local_path: Path = REPO_ROOT
    exclude: list[str] = field(default_factory=list)


@dataclass
class SchedulerSpec:
    gpu_free_memory_mb: int = 1024
    gpu_free_utilization_pct: int = 10
    prefer_pack: bool = True


@dataclass
class FleetConfig:
    config_path: Path
    project_root: Path
    state_dir: Path
    artifacts_dir: Path
    ssh_timeout_sec: int
    scheduler: SchedulerSpec
    sync: SyncSpec
    hosts: dict[str, HostSpec]


@dataclass
class JobSpec:
    name: str
    command: str
    gpus: int
    gpu_model: str | None = None
    hosts: list[str] = field(default_factory=list)
    artifacts: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)


def utc_now() -> str:
    return datetime.now(timezone.utc).isoformat(timespec="seconds")


def relative_to_root(root: Path, value: str | None, default: Path) -> Path:
    if not value:
        return default
    candidate = Path(value).expanduser()
    if candidate.is_absolute():
        return candidate
    return (root / candidate).resolve()


def load_toml(path: Path) -> dict[str, Any]:
    with path.open("rb") as fh:
        return tomllib.load(fh)


def ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path


def dump_json(path: Path, payload: Any) -> None:
    ensure_dir(path.parent)
    path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


def load_json(path: Path) -> Any:
    return json.loads(path.read_text(encoding="utf-8"))


def render_command(template: str, **kwargs: str) -> str:
    try:
        return template.format(**kwargs)
    except KeyError as exc:
        raise FleetError(f"template is missing placeholder {exc!s}: {template}") from exc


def run_local(
    argv: list[str],
    *,
    cwd: Path | None = None,
    capture_output: bool = True,
    check: bool = True,
) -> subprocess.CompletedProcess[str]:
    completed = subprocess.run(
        argv,
        cwd=str(cwd) if cwd else None,
        check=False,
        text=True,
        capture_output=capture_output,
    )
    if check and completed.returncode != 0:
        # stdout/stderr are None when capture_output=False; guard before strip().
        stderr = (completed.stderr or "").strip()
        stdout = (completed.stdout or "").strip()
        detail = stderr or stdout or f"exit code {completed.returncode}"
        raise FleetError(f"command failed: {' '.join(shlex.quote(x) for x in argv)}\n{detail}")
    return completed


def ssh_base(host: HostSpec, timeout_sec: int) -> list[str]:
    return [
        "ssh",
        "-o",
        f"ConnectTimeout={timeout_sec}",
        host.ssh_alias,
    ]


def run_ssh(
    config: FleetConfig,
    host: HostSpec,
    remote_command: str,
    *,
    capture_output: bool = True,
    check: bool = True,
) -> subprocess.CompletedProcess[str]:
    argv = ssh_base(host, config.ssh_timeout_sec) + [remote_command]
    return run_local(argv, cwd=config.project_root, capture_output=capture_output, check=check)


def rsync_push(config: FleetConfig, host: HostSpec) -> None:
    ensure_remote_dir(config, host, host.sync_remote_path)
    argv = ["rsync", "-az"]
    for pattern in config.sync.exclude:
        argv.extend(["--exclude", pattern])
    local_src = str(config.sync.local_path.resolve()) + "/"
    remote_dst = f"{host.ssh_alias}:{host.sync_remote_path.rstrip('/')}/"
    argv.extend([local_src, remote_dst])
    run_local(argv, cwd=config.project_root, capture_output=True, check=True)


def rsync_pull(config: FleetConfig, host: HostSpec, remote_path: str, local_path: Path) -> None:
    ensure_dir(local_path.parent)
    argv = [
        "rsync",
        "-az",
        f"{host.ssh_alias}:{remote_path}",
        str(local_path),
    ]
    run_local(argv, cwd=config.project_root, capture_output=True, check=True)


def quote_remote_path(path: str) -> str:
    # shlex.quote("~/x") yields '~/x', which the remote shell will NOT
    # tilde-expand (so the default ~-relative paths would break). Emit
    # "$HOME"/... instead, keeping quoting safe while a leading ~ still
    # resolves on the remote side.
    if path == "~":
        return '"$HOME"'
    if path.startswith("~/"):
        return '"$HOME"/' + shlex.quote(path[2:])
    return shlex.quote(path)


def ensure_remote_dir(config: FleetConfig, host: HostSpec, remote_path: str) -> None:
    run_ssh(config, host, f"mkdir -p {quote_remote_path(remote_path)}", capture_output=True, check=True)


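# Illustrative fleet.toml fragment. The keys mirror exactly what load_config()
# reads below; the host name and exclude patterns are only placeholders:
#
#   version = 1
#
#   [paths]
#   state_dir = ".aituner/gpu_fleet/state"
#   artifacts_dir = ".aituner/gpu_fleet/artifacts"
#
#   [sync]
#   mode = "rsync"
#   local_path = "."
#   exclude = [".git", "__pycache__"]
#
#   [scheduler]
#   gpu_free_memory_mb = 1024
#   gpu_free_utilization_pct = 10
#   prefer_pack = true
#
#   [ssh]
#   connect_timeout_sec = 10
#
#   [[hosts]]
#   name = "box-a"
#   ssh_alias = "box-a"
#   enabled = true
#   sync_remote_path = "~/workspace/aituner"
#   fleet_root = "~/.aituner_gpu_fleet"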
def load_config(path: Path) -> FleetConfig:
    raw = load_toml(path)
    if int(raw.get("version", 1)) != 1:
        raise FleetError("only config version 1 is supported")

    project_root = REPO_ROOT
    paths_raw = raw.get("paths", {})
    state_dir = relative_to_root(project_root, paths_raw.get("state_dir"), project_root / DEFAULT_STATE_REL)
    artifacts_dir = relative_to_root(
        project_root,
        paths_raw.get("artifacts_dir"),
        project_root / DEFAULT_ARTIFACT_REL,
    )

    sync_raw = raw.get("sync", {})
    sync = SyncSpec(
        mode=str(sync_raw.get("mode", "rsync")),
        local_path=relative_to_root(project_root, sync_raw.get("local_path"), project_root),
        exclude=[str(item) for item in sync_raw.get("exclude", [])],
    )
    if sync.mode != "rsync":
        raise FleetError(f"unsupported sync.mode: {sync.mode}")

    scheduler_raw = raw.get("scheduler", {})
    scheduler = SchedulerSpec(
        gpu_free_memory_mb=int(scheduler_raw.get("gpu_free_memory_mb", 1024)),
        gpu_free_utilization_pct=int(scheduler_raw.get("gpu_free_utilization_pct", 10)),
        prefer_pack=bool(scheduler_raw.get("prefer_pack", True)),
    )

    ssh_raw = raw.get("ssh", {})
    hosts: dict[str, HostSpec] = {}
    for raw_host in raw.get("hosts", []):
        name = str(raw_host.get("name", "")).strip()
        ssh_alias = str(raw_host.get("ssh_alias", "")).strip()
        if not name or not ssh_alias:
            raise FleetError("every host must define both name and ssh_alias")
        host = HostSpec(
            name=name,
            ssh_alias=ssh_alias,
            enabled=bool(raw_host.get("enabled", True)),
            sync_remote_path=str(raw_host.get("sync_remote_path", "~/workspace/aituner")),
            fleet_root=str(raw_host.get("fleet_root", "~/.aituner_gpu_fleet")),
        )
        hosts[name] = host

    if not hosts:
        raise FleetError("no hosts are configured")

    return FleetConfig(
        config_path=path,
        project_root=project_root,
        state_dir=state_dir,
        artifacts_dir=artifacts_dir,
        ssh_timeout_sec=int(ssh_raw.get("connect_timeout_sec", 10)),
        scheduler=scheduler,
        sync=sync,
        hosts=hosts,
    )


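# Illustrative jobs.toml fragment. The keys mirror exactly what load_jobs()
# reads below; the name, command, and values are placeholders:
#
#   version = 1
#
#   [[jobs]]
#   name = "train-baseline"
#   command = "python train.py --epochs 3"
#   gpus = 2
#   gpu_model = "A100"             # optional substring match against the probe
#   hosts = ["box-a"]              # optional host name/alias filter
#   artifacts = ["outputs/run1"]   # optional, pulled back on harvest
#   env = { LR = "3e-4" }          # optional extra environment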
def load_jobs(path: Path) -> list[JobSpec]:
    raw = load_toml(path)
    if int(raw.get("version", 1)) != 1:
        raise FleetError("only jobs version 1 is supported")
    jobs: list[JobSpec] = []
    seen_names: set[str] = set()
    for item in raw.get("jobs", []):
        name = str(item.get("name", "")).strip()
        if not name:
            raise FleetError("every job must define name")
        if name in seen_names:
            raise FleetError(f"duplicated job name: {name}")
        command = str(item.get("command", "")).strip()
        if not command:
            raise FleetError(f"job {name} must define command")
        gpus = int(item.get("gpus", 0))
        if gpus <= 0:
            raise FleetError(f"job {name} must request a positive gpu count")
        jobs.append(
            JobSpec(
                name=name,
                command=command,
                gpus=gpus,
                gpu_model=str(item["gpu_model"]) if "gpu_model" in item else None,
                hosts=[str(entry) for entry in item.get("hosts", [])],
                artifacts=[str(entry) for entry in item.get("artifacts", [])],
                env={str(k): str(v) for k, v in item.get("env", {}).items()},
            )
        )
        seen_names.add(name)
    if not jobs:
        raise FleetError("no jobs are defined")
    return jobs


def validate_jobs_against_config(config: FleetConfig, jobs: list[JobSpec]) -> None:
    for job in jobs:
        if job.hosts:
            known_aliases = {host.ssh_alias for host in config.hosts.values()}
            unknown_hosts = sorted(
                requested
                for requested in job.hosts
                if requested not in config.hosts and requested not in known_aliases
            )
            if unknown_hosts:
                raise FleetError(f"job {job.name} references unknown hosts: {', '.join(unknown_hosts)}")

        candidate_hosts = [
            host for host in config.hosts.values()
            if matches_host_constraint(host, job.hosts)
        ]
        if not candidate_hosts:
            raise FleetError(f"job {job.name} has no candidate hosts after host filtering")


def sync_jobs_to_queue_state(queue_state: dict[str, Any], jobs: list[JobSpec]) -> None:
    now = utc_now()
    state_jobs = queue_state.setdefault("jobs", {})
    for job in jobs:
        signature = job_signature(job)
        record = state_jobs.get(job.name)
        payload = {
            "name": job.name,
            "signature": signature,
            "command": job.command,
            "gpus": job.gpus,
            "gpu_model": job.gpu_model,
            "hosts": job.hosts,
            "artifacts": job.artifacts,
            "env": job.env,
            "last_seen_at": now,
        }
        if record is None:
            state_jobs[job.name] = {
                **payload,
                "status": "pending",
                "created_at": now,
                "submitted_at": None,
                "completed_at": None,
                "run_id": None,
                "attempts": 0,
            }
            continue

        if record.get("signature") != signature and record.get("status") != "pending":
            raise FleetError(f"job {job.name} changed after it entered the queue; use a new job name instead")

        record.update(payload)
        if record.get("status") == "unknown":
            record["status"] = "pending"


def pending_jobs_from_queue(queue_state: dict[str, Any], jobs: list[JobSpec]) -> list[JobSpec]:
    state_jobs = queue_state.get("jobs", {})
    pending: list[JobSpec] = []
    for job in jobs:
        record = state_jobs.get(job.name, {})
        if record.get("status", "pending") == "pending":
            pending.append(job)
    return pending


def state_subdir(config: FleetConfig, *parts: str) -> Path:
    return ensure_dir(config.state_dir.joinpath(*parts))


def run_state_dir(config: FleetConfig) -> Path:
    return state_subdir(config, "runs")


def probe_cache_path(config: FleetConfig) -> Path:
    return state_subdir(config, "probe") / PROBE_CACHE_NAME


def queue_state_path(config: FleetConfig) -> Path:
    return state_subdir(config, "queue") / QUEUE_STATE_NAME


def monitor_lock_path(config: FleetConfig) -> Path:
    return state_subdir(config, "queue") / MONITOR_LOCK_NAME


def load_queue_state(config: FleetConfig) -> dict[str, Any]:
    path = queue_state_path(config)
    if path.exists():
        payload = load_json(path)
        payload.setdefault("jobs", {})
        return payload
    return {
        "generated_at": utc_now(),
        "updated_at": utc_now(),
        "jobs": {},
    }


def save_queue_state(config: FleetConfig, payload: dict[str, Any]) -> None:
    payload["updated_at"] = utc_now()
    dump_json(queue_state_path(config), payload)


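# A job's signature is a stable hash of everything that affects how it runs.
# sync_jobs_to_queue_state() uses it to refuse silent edits to a job that has
# already been submitted: change the definition, and you must rename the job.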
def job_signature(job: JobSpec) -> str:
    payload = {
        "name": job.name,
        "command": job.command,
        "gpus": job.gpus,
        "gpu_model": job.gpu_model,
        "hosts": job.hosts,
        "artifacts": job.artifacts,
        "env": job.env,
    }
    raw = json.dumps(payload, sort_keys=True, ensure_ascii=True)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


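# Both parsers below consume `nvidia-smi --format=csv,noheader,nounits`
# output; a probe row looks roughly like (illustrative values):
#   0, NVIDIA A100-SXM4-80GB, 81920, 1024
# and a status row adds utilization.gpu as a fifth column. Rows with an
# unexpected column count are skipped rather than treated as fatal.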
def parse_gpu_query(output: str) -> list[dict[str, Any]]:
    rows: list[dict[str, Any]] = []
    reader = csv.reader(line for line in output.splitlines() if line.strip())
    for row in reader:
        if len(row) != 4:
            continue
        index, name, memory_total, memory_used = [cell.strip() for cell in row]
        rows.append(
            {
                "index": int(index),
                "name": name,
                "memory_total_mb": int(memory_total),
                "memory_used_mb": int(memory_used),
            }
        )
    return rows


def parse_gpu_status(output: str) -> list[dict[str, Any]]:
    rows: list[dict[str, Any]] = []
    reader = csv.reader(line for line in output.splitlines() if line.strip())
    for row in reader:
        if len(row) != 5:
            continue
        index, name, memory_total, memory_used, util = [cell.strip() for cell in row]
        rows.append(
            {
                "index": int(index),
                "name": name,
                "memory_total_mb": int(memory_total),
                "memory_used_mb": int(memory_used),
                "utilization_gpu_pct": int(util),
            }
        )
    return rows


def probe_single_host(config: FleetConfig, host: HostSpec) -> dict[str, Any]:
    gpu_cmd = textwrap.dedent(
        """
        bash -lc '
        set -euo pipefail
        nvidia-smi --query-gpu=index,name,memory.total,memory.used --format=csv,noheader,nounits
        '
        """
    ).strip()
    completed = run_ssh(config, host, gpu_cmd, capture_output=True, check=False)
    payload: dict[str, Any] = {
        "host": host.name,
        "ssh_alias": host.ssh_alias,
        "enabled": host.enabled,
        "probed_at": utc_now(),
    }
    if completed.returncode != 0:
        payload["status"] = "unreachable"
        payload["error"] = (completed.stderr or completed.stdout).strip()
        return payload

    gpus = parse_gpu_query(completed.stdout)
    names = sorted({gpu["name"] for gpu in gpus})
    payload["status"] = "ok"
    payload["gpus"] = gpus
    payload["gpu_count"] = len(gpus)
    payload["gpu_models"] = names
    payload["gpu_summary"] = ", ".join(
        f"{sum(1 for gpu in gpus if gpu['name'] == name)}x {name}" for name in names
    )
    return payload


def load_probe_cache(config: FleetConfig) -> dict[str, Any]:
    path = probe_cache_path(config)
    if not path.exists():
        raise FleetError(
            f"probe cache is missing: {path}. Run 'python3 infra/gpu_fleet/gpu_fleet.py probe --config {config.config_path}' first."
        )
    return load_json(path)


def query_host_gpu_status(config: FleetConfig, host: HostSpec) -> list[dict[str, Any]]:
    cmd = textwrap.dedent(
        """
        bash -lc '
        set -euo pipefail
        nvidia-smi --query-gpu=index,name,memory.total,memory.used,utilization.gpu --format=csv,noheader,nounits
        '
        """
    ).strip()
    completed = run_ssh(config, host, cmd, capture_output=True, check=True)
    return parse_gpu_status(completed.stdout)


def free_gpu_indices(config: FleetConfig, host: HostSpec) -> list[int]:
    status = query_host_gpu_status(config, host)
    free: list[int] = []
    for gpu in status:
        if (
            gpu["memory_used_mb"] <= config.scheduler.gpu_free_memory_mb
            and gpu["utilization_gpu_pct"] <= config.scheduler.gpu_free_utilization_pct
        ):
            free.append(gpu["index"])
    return free


def matches_gpu_model(host_probe: dict[str, Any], host: HostSpec, requested: str | None) -> bool:
    if not requested:
        return True
    needle = requested.lower()
    for name in host_probe.get("gpu_models", []):
        if needle in name.lower():
            return True
    summary = str(host_probe.get("gpu_summary", ""))
    return needle in summary.lower()


def matches_host_constraint(host: HostSpec, requested_hosts: list[str]) -> bool:
    if not requested_hosts:
        return True
    requested = {item.strip() for item in requested_hosts if item.strip()}
    return host.name in requested or host.ssh_alias in requested


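# build_payload() assembles the bash script that actually runs on the host.
# For a job pinned to GPUs 0 and 1 it expands to roughly (paths abbreviated):
#
#   set -euo pipefail
#   mkdir -p <run_dir>
#   exec > >(tee -a <run_dir>/stdout.log) 2>&1
#   ...write started_at; trap an EXIT handler recording exit_code/finished_at...
#   cd <sync_remote_path>
#   export CUDA_VISIBLE_DEVICES=0,1
#   <job env exports>
#   <job command>
#
# The exit_code file doubles as the completion marker that
# refresh_manifest_status() polls for later.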
def build_payload(job: JobSpec, host: HostSpec, run_id: str, remote_run_dir: str, gpu_ids: list[int]) -> str:
    env_exports = "\n".join(
        f"export {key}={shlex.quote(value)}"
        for key, value in sorted(job.env.items())
    )
    log_file = f"{remote_run_dir}/stdout.log"
    exit_file = f"{remote_run_dir}/exit_code"
    started_file = f"{remote_run_dir}/started_at"
    finished_file = f"{remote_run_dir}/finished_at"
    # quote_remote_path (not shlex.quote) keeps the leading ~ of the default
    # fleet paths expandable on the remote host while quoting the rest.
    payload = textwrap.dedent(
        f"""
        set -euo pipefail
        mkdir -p {quote_remote_path(remote_run_dir)}
        exec > >(tee -a {quote_remote_path(log_file)}) 2>&1
        printf '%s\\n' "$(date -Iseconds)" > {quote_remote_path(started_file)}
        on_exit() {{
            status=$?
            printf '%s\\n' "$status" > {quote_remote_path(exit_file)}
            printf '%s\\n' "$(date -Iseconds)" > {quote_remote_path(finished_file)}
        }}
        trap on_exit EXIT
        cd {quote_remote_path(host.sync_remote_path)}
        export CUDA_VISIBLE_DEVICES={shlex.quote(','.join(str(idx) for idx in gpu_ids))}
        export AITUNER_RUN_ID={shlex.quote(run_id)}
        export AITUNER_REMOTE_RUN_DIR={quote_remote_path(remote_run_dir)}
        export AITUNER_GPU_COUNT={shlex.quote(str(job.gpus))}
        {env_exports}
        {job.command}
        """
    ).strip()
    return payload


def write_run_manifest(config: FleetConfig, manifest: dict[str, Any]) -> Path:
    path = run_state_dir(config) / f"{manifest['run_id']}.json"
    dump_json(path, manifest)
    return path


def run_manifest_path(config: FleetConfig, run_id: str) -> Path:
    return run_state_dir(config) / f"{run_id}.json"


def load_run_manifests(config: FleetConfig) -> list[dict[str, Any]]:
    runs_dir = run_state_dir(config)
    manifests: list[dict[str, Any]] = []
    for path in sorted(runs_dir.glob("*.json")):
        manifests.append(load_json(path))
    return manifests


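# Launching is deliberately fire-and-forget: the job runs inside a detached
# tmux session on the remote host, so it survives this script (and the SSH
# connection) exiting. The local JSON manifest under state/runs/ is the only
# record linking run_id -> host/session/GPUs, and every later status or
# harvest operation starts from it.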
def launch_job(config: FleetConfig, host: HostSpec, job: JobSpec, gpu_ids: list[int]) -> dict[str, Any]:
    sync_remote_path = host.sync_remote_path.rstrip("/")
    fleet_root = host.fleet_root.rstrip("/")
    run_id = f"{job.name}-{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%fZ')}"
    session_name = f"aituner_{run_id.replace('.', '_')}"
    remote_run_dir = f"{fleet_root}/runs/{run_id}"
    payload = build_payload(job, host, run_id, remote_run_dir, gpu_ids)
    remote_launch = (
        f"mkdir -p {quote_remote_path(remote_run_dir)} "
        f"&& tmux new-session -d -s {shlex.quote(session_name)} bash -lc {shlex.quote(payload)}"
    )
    run_ssh(config, host, remote_launch, capture_output=True, check=True)
    manifest = {
        "run_id": run_id,
        "job_name": job.name,
        "host": host.name,
        "ssh_alias": host.ssh_alias,
        "session_name": session_name,
        "gpu_ids": gpu_ids,
        "gpu_model": job.gpu_model,
        "gpus": job.gpus,
        "command": job.command,
        "env": job.env,
        "artifacts": job.artifacts,
        "remote_sync_path": sync_remote_path,
        "remote_run_dir": remote_run_dir,
        "status": "running",
        "submitted_at": utc_now(),
        "harvested_at": None,
    }
    write_run_manifest(config, manifest)
    return manifest


def refresh_manifest_status(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]:
    host = config.hosts[manifest["host"]]
    exit_file = manifest["remote_run_dir"] + "/exit_code"
    # Build the probe script first and quote it once for bash -lc; nesting
    # quoted paths inside an already single-quoted string breaks the quoting.
    script = textwrap.dedent(
        f"""
        if [ -f {quote_remote_path(exit_file)} ]; then
            code=$(cat {quote_remote_path(exit_file)})
            printf "done:%s\\n" "$code"
        elif tmux has-session -t {shlex.quote(manifest["session_name"])} 2>/dev/null; then
            printf "running\\n"
        else
            printf "unknown\\n"
        fi
        """
    ).strip()
    completed = run_ssh(config, host, f"bash -lc {shlex.quote(script)}", capture_output=True, check=False)
    status_line = (completed.stdout or completed.stderr).strip()
    new_manifest = dict(manifest)
    if status_line.startswith("done:"):
        code = int(status_line.split(":", 1)[1])
        new_manifest["status"] = "completed" if code == 0 else "failed"
        new_manifest["exit_code"] = code
    elif status_line == "running":
        new_manifest["status"] = "running"
    else:
        new_manifest["status"] = "unknown"
    write_run_manifest(config, new_manifest)
    return new_manifest


def harvest_run(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]:
    host = config.hosts[manifest["host"]]
    refreshed = refresh_manifest_status(config, manifest)
    if refreshed["status"] not in {"completed", "failed"}:
        return refreshed

    local_base = ensure_dir(config.artifacts_dir / refreshed["run_id"])
    remote_run_dir = refreshed["remote_run_dir"].rstrip("/")
    rsync_pull(config, host, f"{remote_run_dir}/", local_base / "remote_run")

    for artifact in refreshed.get("artifacts", []):
        artifact_remote = f"{refreshed['remote_sync_path'].rstrip('/')}/{artifact}"
        target = local_base / "artifacts" / Path(artifact.lstrip("/"))
        check = run_ssh(
            config,
            host,
            f"test -e {quote_remote_path(artifact_remote)}",
            capture_output=True,
            check=False,
        )
        if check.returncode == 0:
            rsync_pull(config, host, artifact_remote, target)

    refreshed["harvested_at"] = utc_now()
    write_run_manifest(config, refreshed)
    return refreshed


def reconcile_queue_state_with_runs(config: FleetConfig, queue_state: dict[str, Any]) -> None:
    state_jobs = queue_state.get("jobs", {})
    for name, record in state_jobs.items():
        if record.get("status") != "running":
            continue
        run_id = record.get("run_id")
        if not run_id:
            record["status"] = "unknown"
            record["updated_at"] = utc_now()
            continue
        manifest_path = run_manifest_path(config, run_id)
        if not manifest_path.exists():
            record["status"] = "unknown"
            record["updated_at"] = utc_now()
            continue

        manifest = load_json(manifest_path)
        refreshed = refresh_manifest_status(config, manifest)
        record["updated_at"] = utc_now()
        if refreshed["status"] == "running":
            continue
        if refreshed["status"] in {"completed", "failed"}:
            harvested = harvest_run(config, refreshed)
            record["status"] = harvested["status"]
            record["completed_at"] = utc_now()
            record["exit_code"] = harvested.get("exit_code")
            record["harvested_at"] = harvested.get("harvested_at")
            continue
        record["status"] = "unknown"


def print_probe_table(probe: dict[str, Any]) -> None:
    for name, payload in sorted(probe["hosts"].items()):
        if payload["status"] != "ok":
            print(f"{name:12} status={payload['status']} error={payload.get('error', '')}")
            continue
        print(f"{name:12} status=ok gpus={payload['gpu_summary']}")


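# Advisory flock on a file under the state dir guarantees at most one monitor
# per checkout. LOCK_NB makes a second invocation fail fast with FleetError
# instead of silently queueing behind the first.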
class MonitorLock:
    def __init__(self, path: Path):
        self.path = path
        self.handle: Any | None = None

    def __enter__(self) -> "MonitorLock":
        ensure_dir(self.path.parent)
        self.handle = self.path.open("w", encoding="utf-8")
        try:
            fcntl.flock(self.handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError as exc:
            raise FleetError(f"monitor lock is busy: {self.path}") from exc
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        if self.handle is not None:
            fcntl.flock(self.handle.fileno(), fcntl.LOCK_UN)
            self.handle.close()


def cmd_validate(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    print(f"config ok: {config.config_path}")
    if args.jobs:
        jobs = load_jobs(Path(args.jobs).resolve())
        validate_jobs_against_config(config, jobs)
        print(f"jobs ok: {len(jobs)} definitions")
    return 0


def cmd_bootstrap_hosts(args: argparse.Namespace) -> int:
    aliases_path = Path(args.aliases_file).expanduser().resolve()
    aliases = [
        line.strip()
        for line in aliases_path.read_text(encoding="utf-8").splitlines()
        if line.strip() and not line.lstrip().startswith("#")
    ]
    if not aliases:
        raise FleetError(f"no aliases found in {aliases_path}")
    sections = [
        'version = 1',
        '',
        '[paths]',
        'state_dir = ".aituner/gpu_fleet/state"',
        'artifacts_dir = ".aituner/gpu_fleet/artifacts"',
        '',
        '[sync]',
        'mode = "rsync"',
        'local_path = "."',
        '',
    ]
    for alias in aliases:
        sections.extend(
            [
                '[[hosts]]',
                f'name = "{alias}"',
                f'ssh_alias = "{alias}"',
                'enabled = true',
                'sync_remote_path = "~/workspace/aituner"',
                'fleet_root = "~/.aituner_gpu_fleet"',
                '',
            ]
        )
    payload = "\n".join(sections)
    if args.output:
        output = Path(args.output).expanduser().resolve()
        ensure_dir(output.parent)
        output.write_text(payload + "\n", encoding="utf-8")
        print(f"wrote skeleton config: {output}")
    else:
        print(payload)
    return 0


def cmd_probe(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    results = {
        "generated_at": utc_now(),
        "config_path": str(config.config_path),
        "hosts": {},
    }
    for host in config.hosts.values():
        if args.host and host.name != args.host:
            continue
        if not host.enabled and not args.include_disabled:
            continue
        results["hosts"][host.name] = probe_single_host(config, host)
    dump_json(probe_cache_path(config), results)
    print_probe_table(results)
    return 0


def snapshot_free_gpus(config: FleetConfig, probe: dict[str, Any]) -> dict[str, list[int]]:
    availability: dict[str, list[int]] = {}
    for host in config.hosts.values():
        host_probe = probe["hosts"].get(host.name)
        if not host.enabled or not host_probe or host_probe.get("status") != "ok":
            continue
        availability[host.name] = free_gpu_indices(config, host)
    return availability


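# Host selection sorts every feasible (host, gpus) pair by a composite key:
#   1. packing score - with prefer_pack=true, fewer GPUs left over ranks
#      first, so small jobs fill already-busy hosts; otherwise the emptiest
#      host (most free GPUs) wins,
#   2. model score   - 0 when the job pinned a gpu_model, 1 otherwise,
#   3. host name     - a stable alphabetical tie-break.
# Example: a 1-GPU job choosing between hosts with 2 and 8 free GPUs picks
# the 2-GPU host when prefer_pack is set (leftover 1 < 7), keeping the 8-GPU
# host free for wide jobs.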
def select_host_for_job(
    config: FleetConfig,
    probe: dict[str, Any],
    availability: dict[str, list[int]],
    job: JobSpec,
) -> tuple[HostSpec, list[int]] | None:
    candidates: list[tuple[int, int, str, HostSpec, list[int]]] = []
    for host in config.hosts.values():
        if not host.enabled:
            continue
        host_probe = probe["hosts"].get(host.name)
        if not host_probe or host_probe.get("status") != "ok":
            continue
        if not matches_host_constraint(host, job.hosts):
            continue
        if not matches_gpu_model(host_probe, host, job.gpu_model):
            continue
        free = availability.get(host.name, [])
        if len(free) < job.gpus:
            continue
        chosen = sorted(free)[: job.gpus]
        leftover = len(free) - job.gpus
        exact_model = 0 if job.gpu_model and matches_gpu_model(host_probe, host, job.gpu_model) else 1
        candidates.append((leftover if config.scheduler.prefer_pack else -len(free), exact_model, host.name, host, chosen))
    if not candidates:
        return None
    candidates.sort(key=lambda item: (item[0], item[1], item[2]))
    _, _, _, host, gpu_ids = candidates[0]
    return host, gpu_ids


def dispatch_jobs(
    config: FleetConfig,
    jobs: list[JobSpec],
    *,
    dry_run: bool,
    max_launch: int | None = None,
) -> tuple[list[dict[str, Any]], list[str]]:
    probe = load_probe_cache(config)
    availability = snapshot_free_gpus(config, probe)
    manifests: list[dict[str, Any]] = []
    skipped: list[str] = []
    synced_hosts: set[str] = set()
    launch_budget = max_launch if max_launch is not None else len(jobs)

    for job in jobs:
        if len(manifests) >= launch_budget:
            skipped.append(f"{job.name}: launch budget exhausted")
            continue
        selected = select_host_for_job(config, probe, availability, job)
        if not selected:
            skipped.append(f"{job.name}: no host has matching host/model/free GPUs")
            continue
        host, gpu_ids = selected
        availability[host.name] = [gpu for gpu in availability.get(host.name, []) if gpu not in set(gpu_ids)]
        if dry_run:
            manifests.append(
                {
                    "job_name": job.name,
                    "host": host.name,
                    "gpu_ids": gpu_ids,
                    "dry_run": True,
                }
            )
            continue
        if host.name not in synced_hosts:
            rsync_push(config, host)
            synced_hosts.add(host.name)
        manifest = launch_job(config, host, job, gpu_ids)
        manifests.append(manifest)
    return manifests, skipped


def cmd_dispatch(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    jobs = load_jobs(Path(args.jobs).resolve())
    validate_jobs_against_config(config, jobs)
    manifests, skipped = dispatch_jobs(
        config,
        jobs,
        dry_run=args.dry_run,
        max_launch=args.max_launch,
    )

    for manifest in manifests:
        if manifest.get("dry_run"):
            print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}")
            continue
        print(f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}")
    for line in skipped:
        print(f"skipped {line}")
    return 0


def print_queue_summary(queue_state: dict[str, Any]) -> None:
    counts: dict[str, int] = {}
    for record in queue_state.get("jobs", {}).values():
        status = str(record.get("status", "pending"))
        counts[status] = counts.get(status, 0) + 1
    summary = " ".join(f"{key}={counts[key]}" for key in sorted(counts))
    print(summary or "pending=0")


def print_queue_jobs(queue_state: dict[str, Any]) -> None:
    for name, record in sorted(queue_state.get("jobs", {}).items()):
        host = record.get("host") or "-"
        run_id = record.get("run_id") or "-"
        print(
            f"{name:32} status={record.get('status', 'pending'):10} gpus={record.get('gpus', '-')} host={host:8} run_id={run_id}"
        )


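# One monitor cycle, in order:
#   1. reconcile: poll every queue entry marked "running" against its tmux
#      session, harvesting logs/artifacts for runs that finished,
#   2. re-read jobs.toml and fold new/changed definitions into the durable
#      queue state (read errors only warn; the queue file stays authoritative),
#   3. dispatch whatever is still pending, capped by --max-launch-per-cycle.
# Dispatch failures are logged and left pending, so they retry next cycle.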
def monitor_cycle(
    config: FleetConfig,
    jobs_path: Path,
    *,
    dry_run: bool,
    max_launch_per_cycle: int | None,
    verbose: bool,
) -> tuple[int, int]:
    queue_state = load_queue_state(config)
    reconcile_queue_state_with_runs(config, queue_state)

    jobs: list[JobSpec] = []
    try:
        jobs = load_jobs(jobs_path)
        validate_jobs_against_config(config, jobs)
        sync_jobs_to_queue_state(queue_state, jobs)
    except Exception as exc:
        if verbose:
            print(f"queue read warning: {exc}")
        save_queue_state(config, queue_state)
        return 0, 0

    pending = pending_jobs_from_queue(queue_state, jobs)
    try:
        manifests, skipped = dispatch_jobs(
            config,
            pending,
            dry_run=dry_run,
            max_launch=max_launch_per_cycle,
        )
    except Exception as exc:
        if verbose:
            print(f"dispatch warning: {exc}")
        save_queue_state(config, queue_state)
        return 0, 0

    now = utc_now()
    for manifest in manifests:
        record = queue_state["jobs"][manifest["job_name"]]
        if manifest.get("dry_run"):
            continue
        record["status"] = "running"
        record["submitted_at"] = now
        record["updated_at"] = now
        record["run_id"] = manifest["run_id"]
        record["host"] = manifest["host"]
        record["gpu_ids"] = manifest["gpu_ids"]
        record["attempts"] = int(record.get("attempts", 0)) + 1

    save_queue_state(config, queue_state)

    if verbose:
        for manifest in manifests:
            if manifest.get("dry_run"):
                print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}")
            else:
                print(f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}")
        for line in skipped:
            print(f"deferred {line}")
        print_queue_summary(queue_state)
    return len(manifests), len(skipped)


def cmd_monitor(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    jobs_path = Path(args.jobs).resolve()
    with MonitorLock(monitor_lock_path(config)):
        if args.once:
            monitor_cycle(
                config,
                jobs_path,
                dry_run=args.dry_run,
                max_launch_per_cycle=args.max_launch_per_cycle,
                verbose=True,
            )
            return 0

        print(f"monitoring {jobs_path} every {args.interval_sec}s")
        try:
            while True:
                monitor_cycle(
                    config,
                    jobs_path,
                    dry_run=args.dry_run,
                    max_launch_per_cycle=args.max_launch_per_cycle,
                    verbose=not args.quiet,
                )
                time.sleep(args.interval_sec)
        except KeyboardInterrupt:
            print("monitor stopped")
        return 0


def cmd_queue_status(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    queue_state = load_queue_state(config)
    print_queue_summary(queue_state)
    print_queue_jobs(queue_state)
    return 0


def cmd_status(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    manifests = load_run_manifests(config)
    if not manifests:
        print("no runs recorded")
        return 0
    for manifest in manifests:
        refreshed = refresh_manifest_status(config, manifest)
        exit_code = refreshed.get("exit_code")
        suffix = f" exit_code={exit_code}" if exit_code is not None else ""
        print(
            f"{refreshed['run_id']} job={refreshed['job_name']} host={refreshed['host']} status={refreshed['status']}{suffix}"
        )
    return 0


def cmd_harvest(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    manifests = load_run_manifests(config)
    if args.run_id:
        manifests = [manifest for manifest in manifests if manifest["run_id"] == args.run_id]
    harvested_any = False
    for manifest in manifests:
        refreshed = harvest_run(config, manifest)
        if refreshed.get("harvested_at"):
            harvested_any = True
            print(
                f"harvested {refreshed['run_id']} status={refreshed['status']} into {config.artifacts_dir / refreshed['run_id']}"
            )
        else:
            print(f"pending {refreshed['run_id']} status={refreshed['status']}")
    if not harvested_any and not manifests:
        print("no matching runs")
    return 0


def cmd_show_hosts(args: argparse.Namespace) -> int:
    config = load_config(Path(args.config).resolve())
    probe = load_probe_cache(config)
    for host in config.hosts.values():
        host_probe = probe["hosts"].get(host.name, {})
        summary = host_probe.get("gpu_summary", "unprobed")
        print(f"{host.name:12} alias={host.ssh_alias:12} gpus={summary:20}")
    return 0


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Dispatch heterogeneous GPU jobs over SSH/tmux.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    validate = subparsers.add_parser("validate", help="Validate fleet and optional job configs.")
    validate.add_argument("--config", default=str(DEFAULT_CONFIG))
    validate.add_argument("--jobs")
    validate.set_defaults(func=cmd_validate)

    bootstrap = subparsers.add_parser("bootstrap-hosts", help="Generate a fleet config skeleton from a plain alias list.")
    bootstrap.add_argument("--aliases-file", required=True)
    bootstrap.add_argument("--output")
    bootstrap.set_defaults(func=cmd_bootstrap_hosts)

    probe = subparsers.add_parser("probe", help="Probe remote hosts for GPU information.")
    probe.add_argument("--config", default=str(DEFAULT_CONFIG))
    probe.add_argument("--host")
    probe.add_argument("--include-disabled", action="store_true")
    probe.set_defaults(func=cmd_probe)

    dispatch = subparsers.add_parser("dispatch", help="Dispatch a batch of jobs to matching hosts.")
    dispatch.add_argument("--config", default=str(DEFAULT_CONFIG))
    dispatch.add_argument("--jobs", required=True)
    dispatch.add_argument("--dry-run", action="store_true")
    dispatch.add_argument("--max-launch", type=int)
    dispatch.set_defaults(func=cmd_dispatch)

    monitor = subparsers.add_parser("monitor", help="Continuously consume jobs.toml as a local queue.")
    monitor.add_argument("--config", default=str(DEFAULT_CONFIG))
    monitor.add_argument("--jobs", required=True)
    monitor.add_argument("--interval-sec", type=int, default=15)
    monitor.add_argument("--max-launch-per-cycle", type=int)
    monitor.add_argument("--dry-run", action="store_true")
    monitor.add_argument("--once", action="store_true")
    monitor.add_argument("--quiet", action="store_true")
    monitor.set_defaults(func=cmd_monitor)

    status = subparsers.add_parser("status", help="Refresh and print known run status.")
    status.add_argument("--config", default=str(DEFAULT_CONFIG))
    status.set_defaults(func=cmd_status)

    harvest = subparsers.add_parser("harvest", help="Pull finished run logs and artifacts back to local state.")
    harvest.add_argument("--config", default=str(DEFAULT_CONFIG))
    harvest.add_argument("--run-id")
    harvest.set_defaults(func=cmd_harvest)

    show_hosts = subparsers.add_parser("show-hosts", help="Show configured hosts together with the latest probe summary.")
    show_hosts.add_argument("--config", default=str(DEFAULT_CONFIG))
    show_hosts.set_defaults(func=cmd_show_hosts)

    queue_status = subparsers.add_parser("queue-status", help="Show durable queue state tracked by the monitor.")
    queue_status.add_argument("--config", default=str(DEFAULT_CONFIG))
    queue_status.set_defaults(func=cmd_queue_status)

    return parser


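# Typical session, run from the repo root (aliases.txt and jobs.toml are
# illustrative file names; the script path matches the one printed by
# load_probe_cache()):
#
#   python3 infra/gpu_fleet/gpu_fleet.py bootstrap-hosts --aliases-file aliases.txt --output infra/gpu_fleet/config/fleet.toml
#   python3 infra/gpu_fleet/gpu_fleet.py probe
#   python3 infra/gpu_fleet/gpu_fleet.py dispatch --jobs jobs.toml --dry-run
#   python3 infra/gpu_fleet/gpu_fleet.py monitor --jobs jobs.toml --interval-sec 30
#   python3 infra/gpu_fleet/gpu_fleet.py harvest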
def main(argv: list[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        return args.func(args)
    except FleetError as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    raise SystemExit(main())