commit cdcca1d9d7c72edc98e82ab04f4f5fcdefb3d8b5 Author: gahow Date: Sat Apr 4 21:26:37 2026 +0800 Initial AITuner study orchestrator diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..81a4b55 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.aituner/ +__pycache__/ +*.pyc +infra/gpu_fleet/config/fleet.toml +infra/gpu_fleet/config/jobs.toml diff --git a/configs/examples/capability.example.json b/configs/examples/capability.example.json new file mode 100644 index 0000000..769e931 --- /dev/null +++ b/configs/examples/capability.example.json @@ -0,0 +1,14 @@ +{ + "prefill_service_by_bucket": { + "4k": { + "tp4_ms": 320, + "tp8_ms": 240 + } + }, + "queueing_knee_by_bucket": { + "4k": { + "tp4_tok_s_per_gpu": 1000, + "tp8_tok_s_per_gpu": 1100 + } + } +} diff --git a/configs/examples/study.example.json b/configs/examples/study.example.json new file mode 100644 index 0000000..2b0b243 --- /dev/null +++ b/configs/examples/study.example.json @@ -0,0 +1,96 @@ +{ + "study_id": "example-chat-window", + "hardware": { + "gpu_count": 8, + "gpu_model": "H20", + "host_candidates": ["dash0", "dash1"] + }, + "model": { + "model_id": "qwen3-30b", + "served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507" + }, + "engine": { + "engine_name": "vllm", + "engine_version": "0.x", + "exec_path": "/usr/local/bin/vllm", + "cwd": ".", + "host": "127.0.0.1", + "port": 8000, + "healthcheck_path": "/v1/models", + "ready_timeout_s": 600, + "request_timeout_s": 600, + "launch_args": [ + "serve", + "/path/to/model" + ], + "base_envs": {}, + "base_flags": { + "host": "127.0.0.1", + "port": 8000, + "served-model-name": "Qwen/Qwen3-30B-A3B-Instruct-2507" + }, + "tunable_envs": [ + "VLLM_ATTENTION_BACKEND", + "CUDA_GRAPH_MAX_BATCH_SIZE" + ], + "tunable_flags": [ + "tensor-parallel-size", + "data-parallel-size", + "pipeline-parallel-size", + "max-num-seqs", + "max-num-batched-tokens", + "gpu-memory-utilization", + "enable-prefix-caching", + "block-size" + ], + "python_executable": "python3" + }, + "trace": { + "windows_path": "trace_windows/windows.json", + "window_id": "chat_w_example_peak_0001", + "u_field": "sampling_u", + "timestamp_field": "timestamp", + "max_concurrency": 64 + }, + "slo": { + "target_pass_rate": 0.95, + "ttft_rule": { + "kind": "step_ms", + "buckets": [ + { + "max_input_tokens": 4096, + "threshold_ms": 2000 + }, + { + "max_input_tokens": 16384, + "threshold_ms": 4000 + }, + { + "threshold_ms": 8000 + } + ] + }, + "tpot_rule": { + "kind": "fixed_ms", + "threshold_ms": 120 + } + }, + "search": { + "low": 0.0, + "high": 1.0, + "tolerance": 0.01, + "max_probes": 8, + "sample_seed": 20260325 + }, + "llm": { + "system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.", + "max_history_trials": 8, + "endpoint": { + "base_url": "https://example-openai-compatible-endpoint", + "model": "gpt-4.1-mini", + "api_key_env": "OPENAI_API_KEY", + "timeout_s": 120 + } + }, + "capability_profile_path": "capability.example.json" +} diff --git a/configs/examples/trace_windows/traces/chat_w_example_peak_0001.jsonl b/configs/examples/trace_windows/traces/chat_w_example_peak_0001.jsonl new file mode 100644 index 0000000..7730547 --- /dev/null +++ b/configs/examples/trace_windows/traces/chat_w_example_peak_0001.jsonl @@ -0,0 +1,3 @@ +{"request_id":"example-1","timestamp":0.0,"sampling_u":0.10,"messages":[{"role":"user","content":"hello"}],"input_length":512,"output_length":16} 
+{"request_id":"example-2","timestamp":1.0,"sampling_u":0.50,"messages":[{"role":"user","content":"summarize this file"}],"input_length":2048,"output_length":64} +{"request_id":"example-3","timestamp":2.5,"sampling_u":0.90,"messages":[{"role":"user","content":"write a longer answer"}],"input_length":8192,"output_length":128} diff --git a/configs/examples/trace_windows/windows.json b/configs/examples/trace_windows/windows.json new file mode 100644 index 0000000..44398c7 --- /dev/null +++ b/configs/examples/trace_windows/windows.json @@ -0,0 +1,15 @@ +{ + "sample_seed": 20260325, + "u_field": "sampling_u", + "window_duration_seconds": 10.0, + "windows": [ + { + "window_id": "chat_w_example_peak_0001", + "trace_type": "chat", + "trace_file": "traces/chat_w_example_peak_0001.jsonl", + "window_start": 0.0, + "window_end": 10.0, + "num_requests": 3 + } + ] +} diff --git a/infra/gpu_fleet/config/fleet.example.toml b/infra/gpu_fleet/config/fleet.example.toml new file mode 100644 index 0000000..ad41108 --- /dev/null +++ b/infra/gpu_fleet/config/fleet.example.toml @@ -0,0 +1,59 @@ +version = 1 + +[paths] +state_dir = ".aituner/gpu_fleet/state" +artifacts_dir = ".aituner/gpu_fleet/artifacts" + +[ssh] +connect_timeout_sec = 10 + +[scheduler] +gpu_free_memory_mb = 1024 +gpu_free_utilization_pct = 10 +prefer_pack = true + +[sync] +mode = "rsync" +local_path = "." +exclude = [ + ".git/", + ".venv/", + ".aituner/", + "__pycache__/", + "*.pyc", +] + +[[hosts]] +name = "dash0" +ssh_alias = "dash0" +enabled = true +sync_remote_path = "~/workspace/aituner" +fleet_root = "~/.aituner_gpu_fleet" + +[[hosts]] +name = "dash1" +ssh_alias = "dash1" +enabled = true +sync_remote_path = "~/workspace/aituner" +fleet_root = "~/.aituner_gpu_fleet" + +[[hosts]] +name = "dash2" +ssh_alias = "dash2" +enabled = true +sync_remote_path = "~/workspace/aituner" +fleet_root = "~/.aituner_gpu_fleet" + +[[hosts]] +name = "dash3" +ssh_alias = "dash3" +enabled = true +sync_remote_path = "~/aituner" +fleet_root = "~/.aituner_gpu_fleet" + +[[hosts]] +name = "dash5" +ssh_alias = "dash5" +enabled = true +sync_remote_path = "~/workspace/aituner" +fleet_root = "~/.aituner_gpu_fleet" diff --git a/infra/gpu_fleet/config/jobs.example.toml b/infra/gpu_fleet/config/jobs.example.toml new file mode 100644 index 0000000..589ab66 --- /dev/null +++ b/infra/gpu_fleet/config/jobs.example.toml @@ -0,0 +1,27 @@ + # This file is an append-only queue source for the monitor. + # Each job name must stay unique and immutable once appended. +version = 1 + +[[jobs]] +name = "smoke-train-h20-1gpu" +gpus = 1 +gpu_model = "H20" +hosts = ["dash0", "dash1", "dash2"] +command = "python train.py --config configs/smoke.toml" +artifacts = ["outputs/smoke-train-h20-1gpu"] +env = { WANDB_MODE = "offline" } + +[[jobs]] +name = "eval-5090-4gpu" +gpus = 4 +gpu_model = "5090" +hosts = ["dash5"] +command = "python eval.py --config configs/eval.toml" +artifacts = ["outputs/eval-5090-4gpu", "logs/eval-5090-4gpu.log"] + +[[jobs]] +name = "special-dash3-run" +gpus = 2 +hosts = ["dash3"] +command = "python benchmark.py --suite long-context" +artifacts = ["outputs/special-dash3-run"] diff --git a/infra/gpu_fleet/config/ssh_aliases.example.txt b/infra/gpu_fleet/config/ssh_aliases.example.txt new file mode 100644 index 0000000..9c851ad --- /dev/null +++ b/infra/gpu_fleet/config/ssh_aliases.example.txt @@ -0,0 +1,8 @@ +# One SSH alias per line. +# Lines starting with "#" are ignored. 
+dash0 +dash1 +dash2 +dash3 +dash5 + diff --git a/infra/gpu_fleet/gpu_fleet.py b/infra/gpu_fleet/gpu_fleet.py new file mode 100755 index 0000000..854acc3 --- /dev/null +++ b/infra/gpu_fleet/gpu_fleet.py @@ -0,0 +1,1132 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import csv +import fcntl +import hashlib +import json +import shlex +import subprocess +import sys +import textwrap +import time +import tomllib +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +SCRIPT_PATH = Path(__file__).resolve() +REPO_ROOT = SCRIPT_PATH.parents[2] +DEFAULT_CONFIG = SCRIPT_PATH.parent / "config" / "fleet.toml" +DEFAULT_STATE_REL = Path(".aituner/gpu_fleet/state") +DEFAULT_ARTIFACT_REL = Path(".aituner/gpu_fleet/artifacts") +PROBE_CACHE_NAME = "probe_cache.json" +QUEUE_STATE_NAME = "jobs_state.json" +MONITOR_LOCK_NAME = "monitor.lock" + + +class FleetError(RuntimeError): + pass + + +@dataclass +class HostSpec: + name: str + ssh_alias: str + enabled: bool = True + sync_remote_path: str = "~/workspace/aituner" + fleet_root: str = "~/.aituner_gpu_fleet" + + +@dataclass +class SyncSpec: + mode: str = "rsync" + local_path: Path = REPO_ROOT + exclude: list[str] = field(default_factory=list) + + +@dataclass +class SchedulerSpec: + gpu_free_memory_mb: int = 1024 + gpu_free_utilization_pct: int = 10 + prefer_pack: bool = True + + +@dataclass +class FleetConfig: + config_path: Path + project_root: Path + state_dir: Path + artifacts_dir: Path + ssh_timeout_sec: int + scheduler: SchedulerSpec + sync: SyncSpec + hosts: dict[str, HostSpec] + + +@dataclass +class JobSpec: + name: str + command: str + gpus: int + gpu_model: str | None = None + hosts: list[str] = field(default_factory=list) + artifacts: list[str] = field(default_factory=list) + env: dict[str, str] = field(default_factory=dict) + + +def utc_now() -> str: + return datetime.now(timezone.utc).isoformat(timespec="seconds") + + +def relative_to_root(root: Path, value: str | None, default: Path) -> Path: + if not value: + return default + candidate = Path(value).expanduser() + if candidate.is_absolute(): + return candidate + return (root / candidate).resolve() + + +def load_toml(path: Path) -> dict[str, Any]: + with path.open("rb") as fh: + return tomllib.load(fh) + + +def ensure_dir(path: Path) -> Path: + path.mkdir(parents=True, exist_ok=True) + return path + + +def dump_json(path: Path, payload: Any) -> None: + ensure_dir(path.parent) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + +def load_json(path: Path) -> Any: + return json.loads(path.read_text(encoding="utf-8")) + + +def render_command(template: str, **kwargs: str) -> str: + try: + return template.format(**kwargs) + except KeyError as exc: + raise FleetError(f"template is missing placeholder {exc!s}: {template}") from exc + + +def run_local( + argv: list[str], + *, + cwd: Path | None = None, + capture_output: bool = True, + check: bool = True, +) -> subprocess.CompletedProcess[str]: + completed = subprocess.run( + argv, + cwd=str(cwd) if cwd else None, + check=False, + text=True, + capture_output=capture_output, + ) + if check and completed.returncode != 0: + stderr = completed.stderr.strip() + stdout = completed.stdout.strip() + detail = stderr or stdout or f"exit code {completed.returncode}" + raise FleetError(f"command failed: {' '.join(shlex.quote(x) for x in argv)}\n{detail}") + return completed + + +def ssh_base(host: HostSpec, 
timeout_sec: int) -> list[str]: + return [ + "ssh", + "-o", + f"ConnectTimeout={timeout_sec}", + host.ssh_alias, + ] + + +def run_ssh( + config: FleetConfig, + host: HostSpec, + remote_command: str, + *, + capture_output: bool = True, + check: bool = True, +) -> subprocess.CompletedProcess[str]: + argv = ssh_base(host, config.ssh_timeout_sec) + [remote_command] + return run_local(argv, cwd=config.project_root, capture_output=capture_output, check=check) + + +def rsync_push(config: FleetConfig, host: HostSpec) -> None: + ensure_remote_dir(config, host, host.sync_remote_path) + argv = ["rsync", "-az"] + for pattern in config.sync.exclude: + argv.extend(["--exclude", pattern]) + local_src = str(config.sync.local_path.resolve()) + "/" + remote_dst = f"{host.ssh_alias}:{host.sync_remote_path.rstrip('/')}/" + argv.extend([local_src, remote_dst]) + run_local(argv, cwd=config.project_root, capture_output=True, check=True) + + +def rsync_pull(config: FleetConfig, host: HostSpec, remote_path: str, local_path: Path) -> None: + ensure_dir(local_path.parent) + argv = [ + "rsync", + "-az", + f"{host.ssh_alias}:{remote_path}", + str(local_path), + ] + run_local(argv, cwd=config.project_root, capture_output=True, check=True) + + +def ensure_remote_dir(config: FleetConfig, host: HostSpec, remote_path: str) -> None: + run_ssh(config, host, f"mkdir -p {shlex.quote(remote_path)}", capture_output=True, check=True) + + +def load_config(path: Path) -> FleetConfig: + raw = load_toml(path) + if int(raw.get("version", 1)) != 1: + raise FleetError("only config version 1 is supported") + + project_root = REPO_ROOT + paths_raw = raw.get("paths", {}) + state_dir = relative_to_root(project_root, paths_raw.get("state_dir"), project_root / DEFAULT_STATE_REL) + artifacts_dir = relative_to_root( + project_root, + paths_raw.get("artifacts_dir"), + project_root / DEFAULT_ARTIFACT_REL, + ) + + sync_raw = raw.get("sync", {}) + sync = SyncSpec( + mode=str(sync_raw.get("mode", "rsync")), + local_path=relative_to_root(project_root, sync_raw.get("local_path"), project_root), + exclude=[str(item) for item in sync_raw.get("exclude", [])], + ) + if sync.mode != "rsync": + raise FleetError(f"unsupported sync.mode: {sync.mode}") + + scheduler_raw = raw.get("scheduler", {}) + scheduler = SchedulerSpec( + gpu_free_memory_mb=int(scheduler_raw.get("gpu_free_memory_mb", 1024)), + gpu_free_utilization_pct=int(scheduler_raw.get("gpu_free_utilization_pct", 10)), + prefer_pack=bool(scheduler_raw.get("prefer_pack", True)), + ) + + ssh_raw = raw.get("ssh", {}) + hosts: dict[str, HostSpec] = {} + for raw_host in raw.get("hosts", []): + name = str(raw_host.get("name", "")).strip() + ssh_alias = str(raw_host.get("ssh_alias", "")).strip() + if not name or not ssh_alias: + raise FleetError("every host must define both name and ssh_alias") + host = HostSpec( + name=name, + ssh_alias=ssh_alias, + enabled=bool(raw_host.get("enabled", True)), + sync_remote_path=str(raw_host.get("sync_remote_path", "~/workspace/aituner")), + fleet_root=str(raw_host.get("fleet_root", "~/.aituner_gpu_fleet")), + ) + hosts[name] = host + + if not hosts: + raise FleetError("no hosts are configured") + + return FleetConfig( + config_path=path, + project_root=project_root, + state_dir=state_dir, + artifacts_dir=artifacts_dir, + ssh_timeout_sec=int(ssh_raw.get("connect_timeout_sec", 10)), + scheduler=scheduler, + sync=sync, + hosts=hosts, + ) + + +def load_jobs(path: Path) -> list[JobSpec]: + raw = load_toml(path) + if int(raw.get("version", 1)) != 1: + raise FleetError("only jobs 
version 1 is supported") + jobs: list[JobSpec] = [] + seen_names: set[str] = set() + for item in raw.get("jobs", []): + name = str(item.get("name", "")).strip() + if not name: + raise FleetError("every job must define name") + if name in seen_names: + raise FleetError(f"duplicated job name: {name}") + command = str(item.get("command", "")).strip() + if not command: + raise FleetError(f"job {name} must define command") + gpus = int(item.get("gpus", 0)) + if gpus <= 0: + raise FleetError(f"job {name} must request a positive gpu count") + jobs.append( + JobSpec( + name=name, + command=command, + gpus=gpus, + gpu_model=str(item["gpu_model"]) if "gpu_model" in item else None, + hosts=[str(entry) for entry in item.get("hosts", [])], + artifacts=[str(entry) for entry in item.get("artifacts", [])], + env={str(k): str(v) for k, v in item.get("env", {}).items()}, + ) + ) + seen_names.add(name) + if not jobs: + raise FleetError("no jobs are defined") + return jobs + + +def validate_jobs_against_config(config: FleetConfig, jobs: list[JobSpec]) -> None: + for job in jobs: + if job.hosts: + unknown_hosts = sorted( + requested for requested in job.hosts if requested not in config.hosts and requested not in {host.ssh_alias for host in config.hosts.values()} + ) + if unknown_hosts: + raise FleetError(f"job {job.name} references unknown hosts: {', '.join(unknown_hosts)}") + + candidate_hosts = [ + host for host in config.hosts.values() + if matches_host_constraint(host, job.hosts) + ] + if not candidate_hosts: + raise FleetError(f"job {job.name} has no candidate hosts after host filtering") + + +def sync_jobs_to_queue_state(queue_state: dict[str, Any], jobs: list[JobSpec]) -> None: + now = utc_now() + state_jobs = queue_state.setdefault("jobs", {}) + for job in jobs: + signature = job_signature(job) + record = state_jobs.get(job.name) + payload = { + "name": job.name, + "signature": signature, + "command": job.command, + "gpus": job.gpus, + "gpu_model": job.gpu_model, + "hosts": job.hosts, + "artifacts": job.artifacts, + "env": job.env, + "last_seen_at": now, + } + if record is None: + state_jobs[job.name] = { + **payload, + "status": "pending", + "created_at": now, + "submitted_at": None, + "completed_at": None, + "run_id": None, + "attempts": 0, + } + continue + + if record.get("signature") != signature and record.get("status") != "pending": + raise FleetError(f"job {job.name} changed after it entered the queue; use a new job name instead") + + record.update(payload) + if record.get("status") == "unknown": + record["status"] = "pending" + + +def pending_jobs_from_queue(queue_state: dict[str, Any], jobs: list[JobSpec]) -> list[JobSpec]: + state_jobs = queue_state.get("jobs", {}) + pending: list[JobSpec] = [] + for job in jobs: + record = state_jobs.get(job.name, {}) + if record.get("status", "pending") == "pending": + pending.append(job) + return pending + + +def state_subdir(config: FleetConfig, *parts: str) -> Path: + return ensure_dir(config.state_dir.joinpath(*parts)) + + +def run_state_dir(config: FleetConfig) -> Path: + return state_subdir(config, "runs") + + +def probe_cache_path(config: FleetConfig) -> Path: + return state_subdir(config, "probe") / PROBE_CACHE_NAME + + +def queue_state_path(config: FleetConfig) -> Path: + return state_subdir(config, "queue") / QUEUE_STATE_NAME + + +def monitor_lock_path(config: FleetConfig) -> Path: + return state_subdir(config, "queue") / MONITOR_LOCK_NAME + + +def load_queue_state(config: FleetConfig) -> dict[str, Any]: + path = queue_state_path(config) + if 
path.exists(): + payload = load_json(path) + payload.setdefault("jobs", {}) + return payload + return { + "generated_at": utc_now(), + "updated_at": utc_now(), + "jobs": {}, + } + + +def save_queue_state(config: FleetConfig, payload: dict[str, Any]) -> None: + payload["updated_at"] = utc_now() + dump_json(queue_state_path(config), payload) + + +def job_signature(job: JobSpec) -> str: + payload = { + "name": job.name, + "command": job.command, + "gpus": job.gpus, + "gpu_model": job.gpu_model, + "hosts": job.hosts, + "artifacts": job.artifacts, + "env": job.env, + } + raw = json.dumps(payload, sort_keys=True, ensure_ascii=True) + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + +def parse_gpu_query(output: str) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + reader = csv.reader(line for line in output.splitlines() if line.strip()) + for row in reader: + if len(row) != 4: + continue + index, name, memory_total, memory_used = [cell.strip() for cell in row] + rows.append( + { + "index": int(index), + "name": name, + "memory_total_mb": int(memory_total), + "memory_used_mb": int(memory_used), + } + ) + return rows + + +def parse_gpu_status(output: str) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + reader = csv.reader(line for line in output.splitlines() if line.strip()) + for row in reader: + if len(row) != 5: + continue + index, name, memory_total, memory_used, util = [cell.strip() for cell in row] + rows.append( + { + "index": int(index), + "name": name, + "memory_total_mb": int(memory_total), + "memory_used_mb": int(memory_used), + "utilization_gpu_pct": int(util), + } + ) + return rows + + +def probe_single_host(config: FleetConfig, host: HostSpec) -> dict[str, Any]: + gpu_cmd = textwrap.dedent( + """ + bash -lc ' + set -euo pipefail + nvidia-smi --query-gpu=index,name,memory.total,memory.used --format=csv,noheader,nounits + ' + """ + ).strip() + completed = run_ssh(config, host, gpu_cmd, capture_output=True, check=False) + payload: dict[str, Any] = { + "host": host.name, + "ssh_alias": host.ssh_alias, + "enabled": host.enabled, + "probed_at": utc_now(), + } + if completed.returncode != 0: + payload["status"] = "unreachable" + payload["error"] = (completed.stderr or completed.stdout).strip() + return payload + + gpus = parse_gpu_query(completed.stdout) + names = sorted({gpu["name"] for gpu in gpus}) + payload["status"] = "ok" + payload["gpus"] = gpus + payload["gpu_count"] = len(gpus) + payload["gpu_models"] = names + payload["gpu_summary"] = ", ".join(f"{sum(1 for gpu in gpus if gpu['name'] == name)}x {name}" for name in names) + return payload + + +def load_probe_cache(config: FleetConfig) -> dict[str, Any]: + path = probe_cache_path(config) + if not path.exists(): + raise FleetError( + f"probe cache is missing: {path}. Run 'python3 infra/gpu_fleet/gpu_fleet.py probe --config {config.config_path}' first." 
+ ) + return load_json(path) + + +def query_host_gpu_status(config: FleetConfig, host: HostSpec) -> list[dict[str, Any]]: + cmd = textwrap.dedent( + """ + bash -lc ' + set -euo pipefail + nvidia-smi --query-gpu=index,name,memory.total,memory.used,utilization.gpu --format=csv,noheader,nounits + ' + """ + ).strip() + completed = run_ssh(config, host, cmd, capture_output=True, check=True) + return parse_gpu_status(completed.stdout) + + +def free_gpu_indices(config: FleetConfig, host: HostSpec) -> list[int]: + status = query_host_gpu_status(config, host) + free: list[int] = [] + for gpu in status: + if gpu["memory_used_mb"] <= config.scheduler.gpu_free_memory_mb and gpu["utilization_gpu_pct"] <= config.scheduler.gpu_free_utilization_pct: + free.append(gpu["index"]) + return free + + +def matches_gpu_model(host_probe: dict[str, Any], host: HostSpec, requested: str | None) -> bool: + if not requested: + return True + needle = requested.lower() + for name in host_probe.get("gpu_models", []): + if needle in name.lower(): + return True + summary = str(host_probe.get("gpu_summary", "")) + return needle in summary.lower() + + +def matches_host_constraint(host: HostSpec, requested_hosts: list[str]) -> bool: + if not requested_hosts: + return True + requested = {item.strip() for item in requested_hosts if item.strip()} + return host.name in requested or host.ssh_alias in requested + + +def build_payload(job: JobSpec, host: HostSpec, run_id: str, remote_run_dir: str, gpu_ids: list[int]) -> str: + env_exports = "\n".join( + f"export {key}={shlex.quote(value)}" + for key, value in sorted(job.env.items()) + ) + log_file = f"{remote_run_dir}/stdout.log" + exit_file = f"{remote_run_dir}/exit_code" + started_file = f"{remote_run_dir}/started_at" + finished_file = f"{remote_run_dir}/finished_at" + payload = textwrap.dedent( + f""" + set -euo pipefail + mkdir -p {shlex.quote(remote_run_dir)} + exec > >(tee -a {shlex.quote(log_file)}) 2>&1 + printf '%s\\n' "$(date -Iseconds)" > {shlex.quote(started_file)} + on_exit() {{ + status=$? 
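# Persist the exit status and the finish timestamp; refresh_manifest_status reads the
# exit_code file later to classify this run as completed or failed.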
+ printf '%s\\n' "$status" > {shlex.quote(exit_file)} + printf '%s\\n' "$(date -Iseconds)" > {shlex.quote(finished_file)} + }} + trap on_exit EXIT + cd {shlex.quote(host.sync_remote_path)} + export CUDA_VISIBLE_DEVICES={shlex.quote(','.join(str(idx) for idx in gpu_ids))} + export AITUNER_RUN_ID={shlex.quote(run_id)} + export AITUNER_REMOTE_RUN_DIR={shlex.quote(remote_run_dir)} + export AITUNER_GPU_COUNT={shlex.quote(str(job.gpus))} + {env_exports} + {job.command} + """ + ).strip() + return payload + + +def write_run_manifest(config: FleetConfig, manifest: dict[str, Any]) -> Path: + path = run_state_dir(config) / f"{manifest['run_id']}.json" + dump_json(path, manifest) + return path + + +def run_manifest_path(config: FleetConfig, run_id: str) -> Path: + return run_state_dir(config) / f"{run_id}.json" + + +def load_run_manifests(config: FleetConfig) -> list[dict[str, Any]]: + runs_dir = run_state_dir(config) + manifests: list[dict[str, Any]] = [] + for path in sorted(runs_dir.glob("*.json")): + manifests.append(load_json(path)) + return manifests + + +def launch_job(config: FleetConfig, host: HostSpec, job: JobSpec, gpu_ids: list[int]) -> dict[str, Any]: + sync_remote_path = host.sync_remote_path.rstrip("/") + fleet_root = host.fleet_root.rstrip("/") + run_id = f"{job.name}-{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S%fZ')}" + session_name = f"aituner_{run_id.replace('.', '_')}" + remote_run_dir = f"{fleet_root}/runs/{run_id}" + payload = build_payload(job, host, run_id, remote_run_dir, gpu_ids) + remote_launch = ( + f"mkdir -p {shlex.quote(remote_run_dir)} " + f"&& tmux new-session -d -s {shlex.quote(session_name)} bash -lc {shlex.quote(payload)}" + ) + run_ssh(config, host, remote_launch, capture_output=True, check=True) + manifest = { + "run_id": run_id, + "job_name": job.name, + "host": host.name, + "ssh_alias": host.ssh_alias, + "session_name": session_name, + "gpu_ids": gpu_ids, + "gpu_model": job.gpu_model, + "gpus": job.gpus, + "command": job.command, + "env": job.env, + "artifacts": job.artifacts, + "remote_sync_path": sync_remote_path, + "remote_run_dir": remote_run_dir, + "status": "running", + "submitted_at": utc_now(), + "harvested_at": None, + } + write_run_manifest(config, manifest) + return manifest + + +def refresh_manifest_status(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]: + host = config.hosts[manifest["host"]] + remote_cmd = textwrap.dedent( + f""" + bash -lc ' + if [ -f {shlex.quote(manifest["remote_run_dir"] + "/exit_code")} ]; then + code=$(cat {shlex.quote(manifest["remote_run_dir"] + "/exit_code")}) + printf "done:%s\\n" "$code" + elif tmux has-session -t {shlex.quote(manifest["session_name"])} 2>/dev/null; then + printf "running\\n" + else + printf "unknown\\n" + fi + ' + """ + ).strip() + completed = run_ssh(config, host, remote_cmd, capture_output=True, check=False) + status_line = (completed.stdout or completed.stderr).strip() + new_manifest = dict(manifest) + if status_line.startswith("done:"): + code = int(status_line.split(":", 1)[1]) + new_manifest["status"] = "completed" if code == 0 else "failed" + new_manifest["exit_code"] = code + elif status_line == "running": + new_manifest["status"] = "running" + else: + new_manifest["status"] = "unknown" + write_run_manifest(config, new_manifest) + return new_manifest + + +def harvest_run(config: FleetConfig, manifest: dict[str, Any]) -> dict[str, Any]: + host = config.hosts[manifest["host"]] + refreshed = refresh_manifest_status(config, manifest) + if refreshed["status"] not in 
{"completed", "failed"}: + return refreshed + + local_base = ensure_dir(config.artifacts_dir / refreshed["run_id"]) + remote_run_dir = refreshed["remote_run_dir"].rstrip("/") + rsync_pull(config, host, f"{remote_run_dir}/", local_base / "remote_run") + + for artifact in refreshed.get("artifacts", []): + artifact_remote = f"{refreshed['remote_sync_path'].rstrip('/')}/{artifact}" + target = local_base / "artifacts" / Path(artifact.lstrip("/")) + check = run_ssh( + config, + host, + f"test -e {shlex.quote(artifact_remote)}", + capture_output=True, + check=False, + ) + if check.returncode == 0: + rsync_pull(config, host, artifact_remote, target) + + refreshed["harvested_at"] = utc_now() + write_run_manifest(config, refreshed) + return refreshed + + +def reconcile_queue_state_with_runs(config: FleetConfig, queue_state: dict[str, Any]) -> None: + state_jobs = queue_state.get("jobs", {}) + for name, record in state_jobs.items(): + if record.get("status") != "running": + continue + run_id = record.get("run_id") + if not run_id: + record["status"] = "unknown" + record["updated_at"] = utc_now() + continue + manifest_path = run_manifest_path(config, run_id) + if not manifest_path.exists(): + record["status"] = "unknown" + record["updated_at"] = utc_now() + continue + + manifest = load_json(manifest_path) + refreshed = refresh_manifest_status(config, manifest) + record["updated_at"] = utc_now() + if refreshed["status"] == "running": + continue + if refreshed["status"] in {"completed", "failed"}: + harvested = harvest_run(config, refreshed) + record["status"] = harvested["status"] + record["completed_at"] = utc_now() + record["exit_code"] = harvested.get("exit_code") + record["harvested_at"] = harvested.get("harvested_at") + continue + record["status"] = "unknown" + + +def print_probe_table(probe: dict[str, Any]) -> None: + for name, payload in sorted(probe["hosts"].items()): + if payload["status"] != "ok": + print(f"{name:12} status={payload['status']} error={payload.get('error', '')}") + continue + print(f"{name:12} status=ok gpus={payload['gpu_summary']}") + + +class MonitorLock: + def __init__(self, path: Path): + self.path = path + self.handle: Any | None = None + + def __enter__(self) -> "MonitorLock": + ensure_dir(self.path.parent) + self.handle = self.path.open("w", encoding="utf-8") + try: + fcntl.flock(self.handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError as exc: + raise FleetError(f"monitor lock is busy: {self.path}") from exc + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + if self.handle is not None: + fcntl.flock(self.handle.fileno(), fcntl.LOCK_UN) + self.handle.close() + + +def cmd_validate(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + print(f"config ok: {config.config_path}") + if args.jobs: + jobs = load_jobs(Path(args.jobs).resolve()) + validate_jobs_against_config(config, jobs) + print(f"jobs ok: {len(jobs)} definitions") + return 0 + + +def cmd_bootstrap_hosts(args: argparse.Namespace) -> int: + aliases_path = Path(args.aliases_file).expanduser().resolve() + aliases = [ + line.strip() + for line in aliases_path.read_text(encoding="utf-8").splitlines() + if line.strip() and not line.lstrip().startswith("#") + ] + if not aliases: + raise FleetError(f"no aliases found in {aliases_path}") + sections = [ + 'version = 1', + '', + '[paths]', + 'state_dir = ".aituner/gpu_fleet/state"', + 'artifacts_dir = ".aituner/gpu_fleet/artifacts"', + '', + '[sync]', + 'mode = "rsync"', + 'local_path = 
"."', + '', + ] + for alias in aliases: + sections.extend( + [ + '[[hosts]]', + f'name = "{alias}"', + f'ssh_alias = "{alias}"', + 'enabled = true', + 'sync_remote_path = "~/workspace/aituner"', + 'fleet_root = "~/.aituner_gpu_fleet"', + '', + ] + ) + payload = "\n".join(sections) + if args.output: + output = Path(args.output).expanduser().resolve() + ensure_dir(output.parent) + output.write_text(payload + "\n", encoding="utf-8") + print(f"wrote skeleton config: {output}") + else: + print(payload) + return 0 + + +def cmd_probe(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + results = { + "generated_at": utc_now(), + "config_path": str(config.config_path), + "hosts": {}, + } + for host in config.hosts.values(): + if args.host and host.name != args.host: + continue + if not host.enabled and not args.include_disabled: + continue + results["hosts"][host.name] = probe_single_host(config, host) + dump_json(probe_cache_path(config), results) + print_probe_table(results) + return 0 + + +def snapshot_free_gpus(config: FleetConfig, probe: dict[str, Any]) -> dict[str, list[int]]: + availability: dict[str, list[int]] = {} + for host in config.hosts.values(): + host_probe = probe["hosts"].get(host.name) + if not host.enabled or not host_probe or host_probe.get("status") != "ok": + continue + availability[host.name] = free_gpu_indices(config, host) + return availability + + +def select_host_for_job( + config: FleetConfig, + probe: dict[str, Any], + availability: dict[str, list[int]], + job: JobSpec, +) -> tuple[HostSpec, list[int]] | None: + candidates: list[tuple[int, int, str, HostSpec, list[int]]] = [] + for host in config.hosts.values(): + if not host.enabled: + continue + host_probe = probe["hosts"].get(host.name) + if not host_probe or host_probe.get("status") != "ok": + continue + if not matches_host_constraint(host, job.hosts): + continue + if not matches_gpu_model(host_probe, host, job.gpu_model): + continue + free = availability.get(host.name, []) + if len(free) < job.gpus: + continue + chosen = sorted(free)[: job.gpus] + leftover = len(free) - job.gpus + exact_model = 0 if job.gpu_model and matches_gpu_model(host_probe, host, job.gpu_model) else 1 + candidates.append((leftover if config.scheduler.prefer_pack else -len(free), exact_model, host.name, host, chosen)) + if not candidates: + return None + candidates.sort(key=lambda item: (item[0], item[1], item[2])) + _, _, _, host, gpu_ids = candidates[0] + return host, gpu_ids + + +def dispatch_jobs( + config: FleetConfig, + jobs: list[JobSpec], + *, + dry_run: bool, + max_launch: int | None = None, +) -> tuple[list[dict[str, Any]], list[str]]: + probe = load_probe_cache(config) + availability = snapshot_free_gpus(config, probe) + manifests: list[dict[str, Any]] = [] + skipped: list[str] = [] + synced_hosts: set[str] = set() + launch_budget = max_launch if max_launch is not None else len(jobs) + + for job in jobs: + if len(manifests) >= launch_budget: + skipped.append(f"{job.name}: launch budget exhausted") + continue + selected = select_host_for_job(config, probe, availability, job) + if not selected: + skipped.append(f"{job.name}: no host has matching host/model/free GPUs") + continue + host, gpu_ids = selected + availability[host.name] = [gpu for gpu in availability.get(host.name, []) if gpu not in set(gpu_ids)] + if dry_run: + manifests.append( + { + "job_name": job.name, + "host": host.name, + "gpu_ids": gpu_ids, + "dry_run": True, + } + ) + continue + if host.name not in synced_hosts: + 
rsync_push(config, host) + synced_hosts.add(host.name) + manifest = launch_job(config, host, job, gpu_ids) + manifests.append(manifest) + return manifests, skipped + + +def cmd_dispatch(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + jobs = load_jobs(Path(args.jobs).resolve()) + validate_jobs_against_config(config, jobs) + manifests, skipped = dispatch_jobs( + config, + jobs, + dry_run=args.dry_run, + max_launch=args.max_launch, + ) + + for manifest in manifests: + if manifest.get("dry_run"): + print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}") + continue + print( + f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}" + ) + for line in skipped: + print(f"skipped {line}") + return 0 + + +def print_queue_summary(queue_state: dict[str, Any]) -> None: + counts: dict[str, int] = {} + for record in queue_state.get("jobs", {}).values(): + status = str(record.get("status", "pending")) + counts[status] = counts.get(status, 0) + 1 + summary = " ".join(f"{key}={counts[key]}" for key in sorted(counts)) + print(summary or "pending=0") + + +def print_queue_jobs(queue_state: dict[str, Any]) -> None: + for name, record in sorted(queue_state.get("jobs", {}).items()): + host = record.get("host") or "-" + run_id = record.get("run_id") or "-" + print( + f"{name:32} status={record.get('status','pending'):10} gpus={record.get('gpus','-')} host={host:8} run_id={run_id}" + ) + + +def monitor_cycle( + config: FleetConfig, + jobs_path: Path, + *, + dry_run: bool, + max_launch_per_cycle: int | None, + verbose: bool, +) -> tuple[int, int]: + queue_state = load_queue_state(config) + reconcile_queue_state_with_runs(config, queue_state) + + jobs: list[JobSpec] = [] + try: + jobs = load_jobs(jobs_path) + validate_jobs_against_config(config, jobs) + sync_jobs_to_queue_state(queue_state, jobs) + except Exception as exc: + if verbose: + print(f"queue read warning: {exc}") + save_queue_state(config, queue_state) + return 0, 0 + + pending = pending_jobs_from_queue(queue_state, jobs) + try: + manifests, skipped = dispatch_jobs( + config, + pending, + dry_run=dry_run, + max_launch=max_launch_per_cycle, + ) + except Exception as exc: + if verbose: + print(f"dispatch warning: {exc}") + save_queue_state(config, queue_state) + return 0, 0 + + now = utc_now() + for manifest in manifests: + record = queue_state["jobs"][manifest["job_name"]] + if manifest.get("dry_run"): + continue + record["status"] = "running" + record["submitted_at"] = now + record["updated_at"] = now + record["run_id"] = manifest["run_id"] + record["host"] = manifest["host"] + record["gpu_ids"] = manifest["gpu_ids"] + record["attempts"] = int(record.get("attempts", 0)) + 1 + + save_queue_state(config, queue_state) + + if verbose: + for manifest in manifests: + if manifest.get("dry_run"): + print(f"[dry-run] {manifest['job_name']} -> {manifest['host']} gpus={manifest['gpu_ids']}") + else: + print(f"launched {manifest['run_id']} host={manifest['host']} gpus={manifest['gpu_ids']}") + for line in skipped: + print(f"deferred {line}") + print_queue_summary(queue_state) + return len(manifests), len(skipped) + + +def cmd_monitor(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + jobs_path = Path(args.jobs).resolve() + with MonitorLock(monitor_lock_path(config)): + if args.once: + monitor_cycle( + config, + jobs_path, + dry_run=args.dry_run, + max_launch_per_cycle=args.max_launch_per_cycle, + verbose=True, + ) + return 0 + + 
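# Continuous mode: hold the monitor lock and repeat the reconcile/dispatch cycle on a
# fixed interval until interrupted. A typical invocation might look like the sketch
# below (the jobs path is whichever append-only TOML you maintain, e.g. the gitignored
# infra/gpu_fleet/config/jobs.toml):
#   python3 infra/gpu_fleet/gpu_fleet.py monitor \
#     --config infra/gpu_fleet/config/fleet.toml \
#     --jobs infra/gpu_fleet/config/jobs.toml --interval-sec 30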
print(f"monitoring {jobs_path} every {args.interval_sec}s") + try: + while True: + monitor_cycle( + config, + jobs_path, + dry_run=args.dry_run, + max_launch_per_cycle=args.max_launch_per_cycle, + verbose=not args.quiet, + ) + time.sleep(args.interval_sec) + except KeyboardInterrupt: + print("monitor stopped") + return 0 + + +def cmd_queue_status(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + queue_state = load_queue_state(config) + print_queue_summary(queue_state) + print_queue_jobs(queue_state) + return 0 + + +def cmd_status(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + manifests = load_run_manifests(config) + if not manifests: + print("no runs recorded") + return 0 + for manifest in manifests: + refreshed = refresh_manifest_status(config, manifest) + exit_code = refreshed.get("exit_code") + suffix = f" exit_code={exit_code}" if exit_code is not None else "" + print( + f"{refreshed['run_id']} job={refreshed['job_name']} host={refreshed['host']} status={refreshed['status']}{suffix}" + ) + return 0 + + +def cmd_harvest(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + manifests = load_run_manifests(config) + if args.run_id: + manifests = [manifest for manifest in manifests if manifest["run_id"] == args.run_id] + harvested_any = False + for manifest in manifests: + refreshed = harvest_run(config, manifest) + if refreshed.get("harvested_at"): + harvested_any = True + print( + f"harvested {refreshed['run_id']} status={refreshed['status']} into {config.artifacts_dir / refreshed['run_id']}" + ) + else: + print(f"pending {refreshed['run_id']} status={refreshed['status']}") + if not harvested_any and not manifests: + print("no matching runs") + return 0 + + +def cmd_show_hosts(args: argparse.Namespace) -> int: + config = load_config(Path(args.config).resolve()) + probe = load_probe_cache(config) + for host in config.hosts.values(): + host_probe = probe["hosts"].get(host.name, {}) + summary = host_probe.get("gpu_summary", "unprobed") + print(f"{host.name:12} alias={host.ssh_alias:12} gpus={summary:20}") + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Dispatch heterogeneous GPU jobs over SSH/tmux.") + subparsers = parser.add_subparsers(dest="command", required=True) + + validate = subparsers.add_parser("validate", help="Validate fleet and optional job configs.") + validate.add_argument("--config", default=str(DEFAULT_CONFIG)) + validate.add_argument("--jobs") + validate.set_defaults(func=cmd_validate) + + bootstrap = subparsers.add_parser("bootstrap-hosts", help="Generate a fleet config skeleton from a plain alias list.") + bootstrap.add_argument("--aliases-file", required=True) + bootstrap.add_argument("--output") + bootstrap.set_defaults(func=cmd_bootstrap_hosts) + + probe = subparsers.add_parser("probe", help="Probe remote hosts for GPU information.") + probe.add_argument("--config", default=str(DEFAULT_CONFIG)) + probe.add_argument("--host") + probe.add_argument("--include-disabled", action="store_true") + probe.set_defaults(func=cmd_probe) + + dispatch = subparsers.add_parser("dispatch", help="Dispatch a batch of jobs to matching hosts.") + dispatch.add_argument("--config", default=str(DEFAULT_CONFIG)) + dispatch.add_argument("--jobs", required=True) + dispatch.add_argument("--dry-run", action="store_true") + dispatch.add_argument("--max-launch", type=int) + dispatch.set_defaults(func=cmd_dispatch) + + monitor = 
subparsers.add_parser("monitor", help="Continuously consume jobs.toml as a local queue.") + monitor.add_argument("--config", default=str(DEFAULT_CONFIG)) + monitor.add_argument("--jobs", required=True) + monitor.add_argument("--interval-sec", type=int, default=15) + monitor.add_argument("--max-launch-per-cycle", type=int) + monitor.add_argument("--dry-run", action="store_true") + monitor.add_argument("--once", action="store_true") + monitor.add_argument("--quiet", action="store_true") + monitor.set_defaults(func=cmd_monitor) + + status = subparsers.add_parser("status", help="Refresh and print known run status.") + status.add_argument("--config", default=str(DEFAULT_CONFIG)) + status.set_defaults(func=cmd_status) + + harvest = subparsers.add_parser("harvest", help="Pull finished run logs and artifacts back to local state.") + harvest.add_argument("--config", default=str(DEFAULT_CONFIG)) + harvest.add_argument("--run-id") + harvest.set_defaults(func=cmd_harvest) + + show_hosts = subparsers.add_parser("show-hosts", help="Show configured hosts together with the latest probe summary.") + show_hosts.add_argument("--config", default=str(DEFAULT_CONFIG)) + show_hosts.set_defaults(func=cmd_show_hosts) + + queue_status = subparsers.add_parser("queue-status", help="Show durable queue state tracked by the monitor.") + queue_status.add_argument("--config", default=str(DEFAULT_CONFIG)) + queue_status.set_defaults(func=cmd_queue_status) + + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + try: + return args.func(args) + except FleetError as exc: + print(f"error: {exc}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..557483e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[project] +name = "aituner" +version = "0.1.0" +description = "AITuner study orchestrator for OpenAI-compatible serving engines" +requires-python = ">=3.11" +dependencies = [] + +[project.scripts] +aituner = "aituner.cli:main" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/src/aituner/__init__.py b/src/aituner/__init__.py new file mode 100644 index 0000000..8fcf64f --- /dev/null +++ b/src/aituner/__init__.py @@ -0,0 +1,5 @@ +"""AITuner package.""" + +__all__ = [ + "cli", +] diff --git a/src/aituner/cli.py b/src/aituner/cli.py new file mode 100644 index 0000000..bb4831d --- /dev/null +++ b/src/aituner/cli.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +from .job import append_job, build_trial_job +from .llm import build_prompt, call_llm_for_proposal, load_capability_profile, parse_proposal_text +from .spec import Proposal, SpecError, load_study_spec +from .store import StudyStore +from .trace import load_trace_requests, summarize_window +from .worker import run_trial + + +def _study_source_path(study_root: Path) -> Path: + return Path((study_root / "study_spec.source").read_text(encoding="utf-8").strip()) + + +def cmd_study_init(args: argparse.Namespace) -> int: + spec_path = Path(args.spec).resolve() + study = load_study_spec(spec_path) + store = StudyStore(Path(args.store_root) if args.store_root else None) + root = store.init_study(spec_path=spec_path, study=study) + print(root) + return 0 + + +def 
cmd_study_prompt(args: argparse.Namespace) -> int: + store = StudyStore(Path(args.store_root) if args.store_root else None) + study_root = Path(args.study_root).resolve() + source_path = _study_source_path(study_root) + study = load_study_spec(source_path) + state = store.load_state(study.study_id) + capability_profile = load_capability_profile(study, study_spec_path=source_path) + window, requests = load_trace_requests(study, study_spec_path=source_path) + prompt = build_prompt( + study=study, + window_summary=summarize_window(requests, window), + state=state, + capability_profile=capability_profile, + ) + prompt_name = args.prompt_name or f"prompt-{state.next_trial_index:04d}" + path = store.write_prompt(study.study_id, prompt_name, prompt) + print(path) + return 0 + + +def cmd_study_llm_propose(args: argparse.Namespace) -> int: + store = StudyStore(Path(args.store_root) if args.store_root else None) + study_root = Path(args.study_root).resolve() + source_path = _study_source_path(study_root) + study = load_study_spec(source_path) + state = store.load_state(study.study_id) + capability_profile = load_capability_profile(study, study_spec_path=source_path) + window, requests = load_trace_requests(study, study_spec_path=source_path) + prompt = build_prompt( + study=study, + window_summary=summarize_window(requests, window), + state=state, + capability_profile=capability_profile, + ) + proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt) + proposal = parse_proposal_text(proposal_text, study) + name = args.proposal_name or f"proposal-{state.next_trial_index:04d}" + path = store.write_proposal(study.study_id, name, proposal) + print(path) + return 0 + + +def cmd_study_register_proposal(args: argparse.Namespace) -> int: + store = StudyStore(Path(args.store_root) if args.store_root else None) + study_root = Path(args.study_root).resolve() + source_path = _study_source_path(study_root) + study = load_study_spec(source_path) + proposal = parse_proposal_text(Path(args.proposal_file).read_text(encoding="utf-8"), study) + name = args.proposal_name or Path(args.proposal_file).stem + path = store.write_proposal(study.study_id, name, proposal) + print(path) + return 0 + + +def cmd_study_emit_job(args: argparse.Namespace) -> int: + store = StudyStore(Path(args.store_root) if args.store_root else None) + study_root = Path(args.study_root).resolve() + source_path = _study_source_path(study_root) + study = load_study_spec(source_path) + state = store.load_state(study.study_id) + proposal_text = Path(args.proposal_file).read_text(encoding="utf-8") + proposal = parse_proposal_text(proposal_text, study) + trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) + repo_root = Path(__file__).resolve().parents[2] + job = build_trial_job(study=study, trial=trial, repo_root=repo_root) + append_job(Path(args.jobs_file).resolve(), job) + print(trial.trial_id) + return 0 + + +def cmd_study_ingest(args: argparse.Namespace) -> int: + store = StudyStore(Path(args.store_root) if args.store_root else None) + study_root = Path(args.study_root).resolve() + study_id = study_root.name + state = store.ingest_trial_results(study_id) + print(json.dumps({"best_trial_id": state.best_trial_id, "best_request_rate": state.best_request_rate})) + return 0 + + +def cmd_worker_run_trial(args: argparse.Namespace) -> int: + result = run_trial(Path(args.trial_spec).resolve()) + print(json.dumps(result)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = 
argparse.ArgumentParser(description="AITuner CLI") + subparsers = parser.add_subparsers(dest="command", required=True) + + study = subparsers.add_parser("study") + study_sub = study.add_subparsers(dest="study_command", required=True) + + init = study_sub.add_parser("init") + init.add_argument("--spec", required=True) + init.add_argument("--store-root") + init.set_defaults(func=cmd_study_init) + + prompt = study_sub.add_parser("prompt") + prompt.add_argument("--study-root", required=True) + prompt.add_argument("--store-root") + prompt.add_argument("--prompt-name") + prompt.set_defaults(func=cmd_study_prompt) + + llm_propose = study_sub.add_parser("llm-propose") + llm_propose.add_argument("--study-root", required=True) + llm_propose.add_argument("--store-root") + llm_propose.add_argument("--proposal-name") + llm_propose.set_defaults(func=cmd_study_llm_propose) + + register = study_sub.add_parser("register-proposal") + register.add_argument("--study-root", required=True) + register.add_argument("--store-root") + register.add_argument("--proposal-file", required=True) + register.add_argument("--proposal-name") + register.set_defaults(func=cmd_study_register_proposal) + + emit = study_sub.add_parser("emit-job") + emit.add_argument("--study-root", required=True) + emit.add_argument("--store-root") + emit.add_argument("--proposal-file", required=True) + emit.add_argument("--jobs-file", required=True) + emit.set_defaults(func=cmd_study_emit_job) + + ingest = study_sub.add_parser("ingest") + ingest.add_argument("--study-root", required=True) + ingest.add_argument("--store-root") + ingest.set_defaults(func=cmd_study_ingest) + + worker = subparsers.add_parser("worker") + worker_sub = worker.add_subparsers(dest="worker_command", required=True) + run = worker_sub.add_parser("run-trial") + run.add_argument("--trial-spec", required=True) + run.set_defaults(func=cmd_worker_run_trial) + + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + try: + return int(args.func(args)) + except SpecError as exc: + print(str(exc), file=sys.stderr) + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/aituner/engine.py b/src/aituner/engine.py new file mode 100644 index 0000000..aa18895 --- /dev/null +++ b/src/aituner/engine.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import os +import shlex +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .spec import ConfigPatch, EngineLaunchSpec + + +@dataclass(frozen=True) +class LaunchRecipe: + argv: list[str] + env: dict[str, str] + cwd: str | None + base_url: str + healthcheck_path: str + ready_timeout_s: float + request_timeout_s: float + + +def _normalize_flag_name(name: str) -> str: + return str(name).strip().replace("_", "-") + + +def _serialize_flag_parts(name: str, value: Any) -> list[str]: + flag = f"--{_normalize_flag_name(name)}" + if value is None: + return [] + if isinstance(value, bool): + return [flag] if value else [f"--no-{_normalize_flag_name(name)}"] + if isinstance(value, list): + parts: list[str] = [] + for item in value: + parts.extend([flag, str(item)]) + return parts + return [flag, str(value)] + + +def build_launch_recipe(spec: EngineLaunchSpec, patch: ConfigPatch) -> LaunchRecipe: + env = dict(os.environ) + env.update(spec.base_envs) + env.update(patch.env_patch) + flags = dict(spec.base_flags) + flags.update(patch.flag_patch) + argv = [spec.exec_path, *spec.launch_args] + for key, value in 
flags.items(): + argv.extend(_serialize_flag_parts(key, value)) + cwd = None + if spec.cwd: + cwd = str(Path(spec.cwd).expanduser()) + return LaunchRecipe( + argv=argv, + env={str(key): str(value) for key, value in env.items()}, + cwd=cwd, + base_url=spec.base_url, + healthcheck_path=spec.healthcheck_path, + ready_timeout_s=spec.ready_timeout_s, + request_timeout_s=spec.request_timeout_s, + ) + + +def shell_join(argv: list[str]) -> str: + return " ".join(shlex.quote(item) for item in argv) diff --git a/src/aituner/http_client.py b/src/aituner/http_client.py new file mode 100644 index 0000000..85dd108 --- /dev/null +++ b/src/aituner/http_client.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import json +import os +import time +import urllib.error +import urllib.request +from dataclasses import dataclass +from typing import Any, Iterable + + +class HttpClientError(RuntimeError): + """Raised for HTTP client failures.""" + + +def _auth_headers(api_key_env: str | None) -> dict[str, str]: + headers = {"Content-Type": "application/json"} + if api_key_env: + api_key = os.environ.get(api_key_env) + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +def wait_for_server(base_url: str, path: str, timeout_s: float) -> None: + deadline = time.monotonic() + timeout_s + url = f"{base_url.rstrip('/')}{path}" + last_error = "server_not_ready" + while time.monotonic() < deadline: + try: + request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET") + with urllib.request.urlopen(request, timeout=5) as response: + if 200 <= response.status < 500: + return + except Exception as exc: # noqa: BLE001 + last_error = str(exc) + time.sleep(1.0) + raise HttpClientError(f"Timed out waiting for {url}: {last_error}") + + +def chat_completion( + *, + base_url: str, + api_key_env: str | None, + model: str, + messages: list[dict[str, Any]], + timeout_s: float, + system_prompt: str = "", +) -> dict[str, Any]: + payload: dict[str, Any] = {"model": model, "messages": messages} + if system_prompt: + payload["messages"] = [{"role": "system", "content": system_prompt}, *messages] + data = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url=f"{base_url.rstrip('/')}/v1/chat/completions", + headers=_auth_headers(api_key_env), + data=data, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=timeout_s) as response: + return json.loads(response.read().decode("utf-8")) + except urllib.error.HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise HttpClientError(f"chat_completion failed: {exc.code} {detail}") from exc + + +@dataclass(frozen=True) +class StreamMetrics: + ttft_ms: float | None + tpot_ms: float | None + completion_tokens: int | None + + +def stream_chat_completion( + *, + base_url: str, + body: dict[str, Any], + timeout_s: float, +) -> StreamMetrics: + data = json.dumps(body).encode("utf-8") + request = urllib.request.Request( + url=f"{base_url.rstrip('/')}/v1/chat/completions", + headers=_auth_headers(None), + data=data, + method="POST", + ) + start = time.monotonic() + first_token_at: float | None = None + last_token_at: float | None = None + chunk_token_count = 0 + completion_tokens: int | None = None + try: + with urllib.request.urlopen(request, timeout=timeout_s) as response: + for raw in _iter_sse_lines(response): + if raw == "[DONE]": + break + payload = json.loads(raw) + if not isinstance(payload, dict): + continue + usage = payload.get("usage") + if isinstance(usage, dict): 
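# Capture the server-reported usage.completion_tokens; it takes precedence over the
# streamed chunk count when TPOT is derived below.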
+ comp = usage.get("completion_tokens") + if isinstance(comp, int) and comp >= 0: + completion_tokens = comp + choices = payload.get("choices") + if not isinstance(choices, list) or not choices: + continue + delta = choices[0].get("delta", {}) + if not isinstance(delta, dict): + continue + content = delta.get("content") + if isinstance(content, str) and content: + now = time.monotonic() + if first_token_at is None: + first_token_at = now + last_token_at = now + chunk_token_count += 1 + except urllib.error.HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise HttpClientError(f"stream_chat_completion failed: {exc.code} {detail}") from exc + ttft_ms = None if first_token_at is None else (first_token_at - start) * 1000.0 + used_tokens = completion_tokens if completion_tokens is not None else chunk_token_count + if ( + first_token_at is None + or last_token_at is None + or used_tokens is None + or used_tokens <= 1 + ): + tpot_ms = None + else: + tpot_ms = ((last_token_at - first_token_at) / max(used_tokens - 1, 1)) * 1000.0 + return StreamMetrics( + ttft_ms=ttft_ms, + tpot_ms=tpot_ms, + completion_tokens=used_tokens if used_tokens > 0 else None, + ) + + +def _iter_sse_lines(response: Any) -> Iterable[str]: + for raw in response: + line = raw.decode("utf-8", errors="replace").strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if payload: + yield payload diff --git a/src/aituner/job.py b/src/aituner/job.py new file mode 100644 index 0000000..d2e3af7 --- /dev/null +++ b/src/aituner/job.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .spec import StudySpec, TrialSpec + + +@dataclass(frozen=True) +class InfraJob: + name: str + gpus: int + gpu_model: str | None + hosts: list[str] + command: str + artifacts: list[str] + env: dict[str, str] + + +def _toml_scalar(value: Any) -> str: + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, int): + return str(value) + text = str(value).replace("\\", "\\\\").replace('"', '\\"') + return f'"{text}"' + + +def _toml_list(values: list[Any]) -> str: + return "[" + ", ".join(_toml_scalar(item) for item in values) + "]" + + +def _toml_inline_table(mapping: dict[str, str]) -> str: + parts = [f"{key} = {_toml_scalar(value)}" for key, value in sorted(mapping.items())] + return "{ " + ", ".join(parts) + " }" + + +def build_trial_job(*, study: StudySpec, trial: TrialSpec, repo_root: Path) -> InfraJob: + trial_path = Path(trial.artifact_dir) / "trial_spec.json" + rel_trial_path = trial_path.resolve().relative_to(repo_root.resolve()) + rel_trial_dir = Path(trial.artifact_dir).resolve().relative_to(repo_root.resolve()) + command = ( + f"{study.engine.python_executable} -m aituner.cli worker run-trial " + f"--trial-spec {rel_trial_path}" + ) + env = {"PYTHONPATH": "src"} + return InfraJob( + name=f"{study.study_id}-{trial.trial_id}", + gpus=study.hardware.gpu_count, + gpu_model=study.hardware.gpu_model, + hosts=list(study.hardware.host_candidates), + command=command, + artifacts=[str(rel_trial_dir)], + env=env, + ) + + +def append_job(jobs_path: Path, job: InfraJob) -> None: + jobs_path.parent.mkdir(parents=True, exist_ok=True) + with jobs_path.open("a", encoding="utf-8") as handle: + if jobs_path.stat().st_size == 0: + handle.write("version = 1\n") + handle.write("\n[[jobs]]\n") + handle.write(f"name = {_toml_scalar(job.name)}\n") + handle.write(f"gpus = 
{job.gpus}\n") + if job.gpu_model: + handle.write(f"gpu_model = {_toml_scalar(job.gpu_model)}\n") + if job.hosts: + handle.write(f"hosts = {_toml_list(job.hosts)}\n") + handle.write(f"command = {_toml_scalar(job.command)}\n") + if job.artifacts: + handle.write(f"artifacts = {_toml_list(job.artifacts)}\n") + if job.env: + handle.write(f"env = {_toml_inline_table(job.env)}\n") diff --git a/src/aituner/llm.py b/src/aituner/llm.py new file mode 100644 index 0000000..7664092 --- /dev/null +++ b/src/aituner/llm.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from .http_client import chat_completion +from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState + + +def build_prompt( + *, + study: StudySpec, + window_summary: dict[str, Any], + state: StudyState, + capability_profile: dict[str, Any] | None, +) -> str: + history = [] + for trial in state.trials[-study.llm.max_history_trials :]: + history.append( + { + "trial_id": trial.trial_id, + "status": trial.status, + "best_sampling_u": trial.best_sampling_u, + "best_request_rate": trial.best_request_rate, + "best_pass_rate": trial.best_pass_rate, + "diagnosis": trial.diagnosis, + } + ) + sections = [ + "You are tuning an OpenAI-compatible serving engine.", + "Return exactly one JSON object with keys: observation, diagnosis, config_patch, expected_effects, why_not_previous_failures.", + "config_patch must contain env_patch and flag_patch.", + "Only use allowed tunable env keys and allowed tunable flag keys.", + "", + "Study stack:", + json.dumps( + { + "study_id": study.study_id, + "hardware": { + "gpu_count": study.hardware.gpu_count, + "gpu_model": study.hardware.gpu_model, + }, + "model": { + "model_id": study.model.model_id, + "served_model_name": study.model.served_model_name, + }, + "engine": { + "engine_name": study.engine.engine_name, + "engine_version": study.engine.engine_version, + "base_flags": study.engine.base_flags, + "base_envs": study.engine.base_envs, + "allowed_flag_keys": study.engine.tunable_flags, + "allowed_env_keys": study.engine.tunable_envs, + }, + }, + ensure_ascii=False, + indent=2, + ), + "", + "Window summary:", + json.dumps(window_summary, ensure_ascii=False, indent=2), + "", + "SLO:", + json.dumps( + { + "target_pass_rate": study.slo.target_pass_rate, + "ttft_rule": study.slo.ttft_rule, + "tpot_rule": study.slo.tpot_rule, + }, + default=lambda value: value.__dict__, + ensure_ascii=False, + indent=2, + ), + "", + "Capability profile:", + json.dumps(capability_profile or {}, ensure_ascii=False, indent=2), + "", + "Trial history:", + json.dumps(history, ensure_ascii=False, indent=2), + "", + "The proposal should improve the maximum feasible sampling_u under the 95%+ SLO target.", + ] + return "\n".join(sections) + + +def load_capability_profile(study: StudySpec, *, study_spec_path: Path) -> dict[str, Any] | None: + if not study.capability_profile_path: + return None + path = Path(study.capability_profile_path) + if not path.is_absolute(): + path = (study_spec_path.parent / path).resolve() + return json.loads(path.read_text(encoding="utf-8")) + + +def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal: + unknown_envs = sorted(set(proposal.config_patch.env_patch) - set(study.engine.tunable_envs)) + unknown_flags = sorted( + set(proposal.config_patch.flag_patch) - set(study.engine.tunable_flags) + ) + if unknown_envs: + raise SpecError(f"Proposal uses unsupported env keys: {', '.join(unknown_envs)}") + if 
unknown_flags: + raise SpecError(f"Proposal uses unsupported flag keys: {', '.join(unknown_flags)}") + return proposal + + +def parse_proposal_text(text: str, study: StudySpec) -> Proposal: + payload = json.loads(text) + proposal = Proposal.from_dict(payload) + return validate_proposal(proposal, study) + + +def call_llm_for_proposal( + *, + policy: LLMPolicySpec, + prompt: str, +) -> str: + if policy.endpoint is None: + raise RuntimeError("study.llm.endpoint is not configured") + response = chat_completion( + base_url=policy.endpoint.base_url, + api_key_env=policy.endpoint.api_key_env, + model=policy.endpoint.model, + messages=[{"role": "user", "content": prompt}], + timeout_s=policy.endpoint.timeout_s, + system_prompt=policy.system_prompt, + ) + choices = response.get("choices") + if not isinstance(choices, list) or not choices: + raise RuntimeError("LLM response does not contain choices") + message = choices[0].get("message", {}) + if not isinstance(message, dict): + raise RuntimeError("LLM response does not contain a valid message") + content = message.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + return "".join( + item.get("text", "") + for item in content + if isinstance(item, dict) and isinstance(item.get("text"), str) + ) + raise RuntimeError("LLM response content is empty") diff --git a/src/aituner/search.py b/src/aituner/search.py new file mode 100644 index 0000000..5f8e5cb --- /dev/null +++ b/src/aituner/search.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Generic, TypeVar + + +T = TypeVar("T") + + +@dataclass(frozen=True) +class ThresholdProbe(Generic[T]): + threshold: float + feasible: bool + payload: T + + +@dataclass(frozen=True) +class ThresholdSearchResult(Generic[T]): + best_threshold: float + best_feasible_payload: T | None + probes: list[ThresholdProbe[T]] + + +def binary_search_max_feasible( + *, + low: float, + high: float, + tolerance: float, + max_probes: int, + evaluator: Callable[[float], ThresholdProbe[T]], +) -> ThresholdSearchResult[T]: + probes: list[ThresholdProbe[T]] = [] + cache: dict[float, ThresholdProbe[T]] = {} + best_threshold = low + best_payload: T | None = None + cur_low = low + cur_high = high + for _ in range(max_probes): + if cur_high - cur_low <= tolerance: + break + threshold = round((cur_low + cur_high) / 2.0, 12) + probe = cache.get(threshold) + if probe is None: + probe = evaluator(threshold) + cache[threshold] = probe + probes.append(probe) + if probe.feasible: + if threshold >= best_threshold: + best_threshold = threshold + best_payload = probe.payload + cur_low = threshold + else: + cur_high = threshold + return ThresholdSearchResult( + best_threshold=best_threshold, + best_feasible_payload=best_payload, + probes=probes, + ) diff --git a/src/aituner/slo.py b/src/aituner/slo.py new file mode 100644 index 0000000..0689172 --- /dev/null +++ b/src/aituner/slo.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from .spec import SloSpec, ThresholdRule + + +@dataclass(frozen=True) +class RequestOutcome: + request_id: str + success: bool + ttft_ms: float | None + tpot_ms: float | None + prompt_tokens: int | None + completion_tokens: int | None + error: str = "" + + +@dataclass(frozen=True) +class RequestEvaluation: + request_id: str + passed: bool + reasons: list[str] + + +def _rule_threshold_ms(rule: ThresholdRule, prompt_tokens: int | None) -> float: + 
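+    # fixed_ms rules return a single threshold; step_ms rules pick the first bucket whose
+    # max_input_tokens covers the prompt and fall back to the last bucket otherwise.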
if rule.kind == "fixed_ms": + assert rule.threshold_ms is not None + return rule.threshold_ms + if rule.kind != "step_ms": + raise ValueError(f"Unsupported threshold rule: {rule.kind}") + prompt = float(prompt_tokens or 0) + for bucket in rule.buckets: + ceiling = bucket.get("max_input_tokens") + if ceiling is None or prompt <= ceiling: + return float(bucket["threshold_ms"]) + return float(rule.buckets[-1]["threshold_ms"]) + + +def evaluate_request(outcome: RequestOutcome, slo: SloSpec) -> RequestEvaluation: + reasons: list[str] = [] + if not outcome.success: + reasons.append(outcome.error or "request_failed") + return RequestEvaluation(request_id=outcome.request_id, passed=False, reasons=reasons) + if slo.ttft_rule is not None: + if outcome.ttft_ms is None: + reasons.append("ttft_missing") + else: + threshold = _rule_threshold_ms(slo.ttft_rule, outcome.prompt_tokens) + if outcome.ttft_ms > threshold: + reasons.append(f"ttft_ms>{threshold}") + if slo.tpot_rule is not None: + if outcome.tpot_ms is None: + reasons.append("tpot_missing") + else: + threshold = _rule_threshold_ms(slo.tpot_rule, outcome.prompt_tokens) + if outcome.tpot_ms > threshold: + reasons.append(f"tpot_ms>{threshold}") + return RequestEvaluation( + request_id=outcome.request_id, + passed=not reasons, + reasons=reasons, + ) + + +def summarize_evaluations( + outcomes: list[RequestOutcome], slo: SloSpec +) -> tuple[list[RequestEvaluation], dict[str, Any]]: + evaluations = [evaluate_request(item, slo) for item in outcomes] + total = len(evaluations) + passed = sum(1 for item in evaluations if item.passed) + pass_rate = (passed / total) if total else 0.0 + return evaluations, { + "request_count": total, + "slo_pass_count": passed, + "slo_pass_rate": pass_rate, + "target_pass_rate": slo.target_pass_rate, + "feasible": pass_rate >= slo.target_pass_rate, + } diff --git a/src/aituner/spec.py b/src/aituner/spec.py new file mode 100644 index 0000000..3b7fbca --- /dev/null +++ b/src/aituner/spec.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import json +import tomllib +from dataclasses import asdict, dataclass, field, is_dataclass +from pathlib import Path +from typing import Any, Mapping + + +class SpecError(ValueError): + """Raised when a structured spec is invalid.""" + + +def _require_mapping(value: Any, *, context: str) -> Mapping[str, Any]: + if not isinstance(value, Mapping): + raise SpecError(f"{context} must be an object.") + return value + + +def _require_str(value: Any, *, context: str) -> str: + if not isinstance(value, str) or not value.strip(): + raise SpecError(f"{context} must be a non-empty string.") + return value.strip() + + +def _require_float(value: Any, *, context: str) -> float: + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise SpecError(f"{context} must be numeric.") + return float(value) + + +def _require_int(value: Any, *, context: str) -> int: + if isinstance(value, bool) or not isinstance(value, int): + raise SpecError(f"{context} must be an integer.") + return value + + +def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]: + mapping = _require_mapping(value or {}, context=context) + return {str(key): str(item) for key, item in mapping.items()} + + +def _coerce_any_map(value: Any, *, context: str) -> dict[str, Any]: + mapping = _require_mapping(value or {}, context=context) + return {str(key): item for key, item in mapping.items()} + + +def _coerce_str_list(value: Any, *, context: str) -> list[str]: + if value is None: + return [] + if not 
isinstance(value, list): + raise SpecError(f"{context} must be a list.") + result: list[str] = [] + for item in value: + result.append(_require_str(item, context=context)) + return result + + +@dataclass(frozen=True) +class HardwareSpec: + gpu_count: int + gpu_model: str | None = None + host_candidates: list[str] = field(default_factory=list) + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "HardwareSpec": + return cls( + gpu_count=_require_int(data.get("gpu_count"), context="hardware.gpu_count"), + gpu_model=str(data["gpu_model"]).strip() if data.get("gpu_model") else None, + host_candidates=_coerce_str_list( + data.get("host_candidates"), context="hardware.host_candidates" + ), + ) + + +@dataclass(frozen=True) +class ModelSpec: + model_id: str + served_model_name: str + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ModelSpec": + return cls( + model_id=_require_str(data.get("model_id"), context="model.model_id"), + served_model_name=_require_str( + data.get("served_model_name"), context="model.served_model_name" + ), + ) + + +@dataclass(frozen=True) +class EngineLaunchSpec: + engine_name: str + engine_version: str | None + exec_path: str + cwd: str | None + host: str + port: int + ready_timeout_s: float + request_timeout_s: float + healthcheck_path: str + launch_args: list[str] + base_envs: dict[str, str] + base_flags: dict[str, Any] + tunable_envs: list[str] + tunable_flags: list[str] + python_executable: str = "python3" + + @property + def base_url(self) -> str: + return f"http://{self.host}:{self.port}" + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "EngineLaunchSpec": + return cls( + engine_name=_require_str(data.get("engine_name"), context="engine.engine_name"), + engine_version=str(data["engine_version"]).strip() + if data.get("engine_version") + else None, + exec_path=_require_str(data.get("exec_path"), context="engine.exec_path"), + cwd=str(data["cwd"]).strip() if data.get("cwd") else None, + host=str(data.get("host") or "127.0.0.1").strip(), + port=_require_int(data.get("port", 8000), context="engine.port"), + ready_timeout_s=_require_float( + data.get("ready_timeout_s", 600.0), context="engine.ready_timeout_s" + ), + request_timeout_s=_require_float( + data.get("request_timeout_s", 600.0), + context="engine.request_timeout_s", + ), + healthcheck_path=str(data.get("healthcheck_path") or "/v1/models").strip(), + launch_args=_coerce_str_list(data.get("launch_args"), context="engine.launch_args"), + base_envs=_coerce_str_map(data.get("base_envs"), context="engine.base_envs"), + base_flags=_coerce_any_map(data.get("base_flags"), context="engine.base_flags"), + tunable_envs=_coerce_str_list( + data.get("tunable_envs"), context="engine.tunable_envs" + ), + tunable_flags=_coerce_str_list( + data.get("tunable_flags"), context="engine.tunable_flags" + ), + python_executable=str(data.get("python_executable") or "python3").strip(), + ) + + +@dataclass(frozen=True) +class TraceSpec: + windows_path: str + window_id: str + trace_file_override: str | None + u_field: str + timestamp_field: str + max_concurrency: int + max_requests_per_probe: int | None = None + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec": + max_requests = data.get("max_requests_per_probe") + return cls( + windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"), + window_id=_require_str(data.get("window_id"), context="trace.window_id"), + trace_file_override=str(data["trace_file_override"]).strip() + if 
data.get("trace_file_override") + else None, + u_field=str(data.get("u_field") or "sampling_u").strip(), + timestamp_field=str(data.get("timestamp_field") or "timestamp").strip(), + max_concurrency=_require_int( + data.get("max_concurrency", 64), context="trace.max_concurrency" + ), + max_requests_per_probe=int(max_requests) if max_requests is not None else None, + ) + + +@dataclass(frozen=True) +class ThresholdRule: + kind: str + threshold_ms: float | None = None + buckets: list[dict[str, float]] = field(default_factory=list) + + @classmethod + def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "ThresholdRule": + kind = _require_str(data.get("kind"), context=f"{context}.kind") + if kind == "fixed_ms": + return cls( + kind=kind, + threshold_ms=_require_float( + data.get("threshold_ms"), context=f"{context}.threshold_ms" + ), + ) + if kind == "step_ms": + raw = data.get("buckets") + if not isinstance(raw, list) or not raw: + raise SpecError(f"{context}.buckets must be a non-empty list.") + buckets: list[dict[str, float]] = [] + for idx, item in enumerate(raw): + mapping = _require_mapping(item, context=f"{context}.buckets[{idx}]") + bucket: dict[str, float] = { + "threshold_ms": _require_float( + mapping.get("threshold_ms"), + context=f"{context}.buckets[{idx}].threshold_ms", + ) + } + if "max_input_tokens" in mapping and mapping["max_input_tokens"] is not None: + bucket["max_input_tokens"] = _require_float( + mapping["max_input_tokens"], + context=f"{context}.buckets[{idx}].max_input_tokens", + ) + buckets.append(bucket) + return cls(kind=kind, buckets=buckets) + raise SpecError(f"Unsupported threshold rule kind: {kind}") + + +@dataclass(frozen=True) +class SloSpec: + target_pass_rate: float + ttft_rule: ThresholdRule | None + tpot_rule: ThresholdRule | None + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "SloSpec": + ttft_rule = ( + ThresholdRule.from_dict( + _require_mapping(data["ttft_rule"], context="slo.ttft_rule"), + context="slo.ttft_rule", + ) + if data.get("ttft_rule") + else None + ) + tpot_rule = ( + ThresholdRule.from_dict( + _require_mapping(data["tpot_rule"], context="slo.tpot_rule"), + context="slo.tpot_rule", + ) + if data.get("tpot_rule") + else None + ) + return cls( + target_pass_rate=_require_float( + data.get("target_pass_rate", 0.95), context="slo.target_pass_rate" + ), + ttft_rule=ttft_rule, + tpot_rule=tpot_rule, + ) + + +@dataclass(frozen=True) +class SamplingSearchSpec: + low: float + high: float + tolerance: float + max_probes: int + sample_seed: int + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "SamplingSearchSpec": + return cls( + low=_require_float(data.get("low", 0.0), context="search.low"), + high=_require_float(data.get("high", 1.0), context="search.high"), + tolerance=_require_float( + data.get("tolerance", 0.01), context="search.tolerance" + ), + max_probes=_require_int(data.get("max_probes", 8), context="search.max_probes"), + sample_seed=_require_int( + data.get("sample_seed", 20260325), context="search.sample_seed" + ), + ) + + +@dataclass(frozen=True) +class LLMEndpointSpec: + base_url: str + model: str + api_key_env: str = "OPENAI_API_KEY" + timeout_s: float = 120.0 + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec": + return cls( + base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"), + model=_require_str(data.get("model"), context="llm.endpoint.model"), + api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(), + 
timeout_s=_require_float( + data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s" + ), + ) + + +@dataclass(frozen=True) +class LLMPolicySpec: + endpoint: LLMEndpointSpec | None + system_prompt: str + max_history_trials: int + + @classmethod + def from_dict(cls, data: Mapping[str, Any] | None) -> "LLMPolicySpec": + payload = _require_mapping(data or {}, context="llm") + endpoint = ( + LLMEndpointSpec.from_dict( + _require_mapping(payload["endpoint"], context="llm.endpoint") + ) + if payload.get("endpoint") + else None + ) + return cls( + endpoint=endpoint, + system_prompt=str(payload.get("system_prompt") or "").strip(), + max_history_trials=_require_int( + payload.get("max_history_trials", 8), context="llm.max_history_trials" + ), + ) + + +@dataclass(frozen=True) +class StudySpec: + study_id: str + hardware: HardwareSpec + model: ModelSpec + engine: EngineLaunchSpec + trace: TraceSpec + slo: SloSpec + search: SamplingSearchSpec + llm: LLMPolicySpec + capability_profile_path: str | None = None + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec": + return cls( + study_id=_require_str(data.get("study_id"), context="study_id"), + hardware=HardwareSpec.from_dict( + _require_mapping(data.get("hardware"), context="hardware") + ), + model=ModelSpec.from_dict(_require_mapping(data.get("model"), context="model")), + engine=EngineLaunchSpec.from_dict( + _require_mapping(data.get("engine"), context="engine") + ), + trace=TraceSpec.from_dict(_require_mapping(data.get("trace"), context="trace")), + slo=SloSpec.from_dict(_require_mapping(data.get("slo"), context="slo")), + search=SamplingSearchSpec.from_dict( + _require_mapping(data.get("search"), context="search") + ), + llm=LLMPolicySpec.from_dict(data.get("llm")), + capability_profile_path=str(data["capability_profile_path"]).strip() + if data.get("capability_profile_path") + else None, + ) + + +@dataclass(frozen=True) +class ConfigPatch: + env_patch: dict[str, str] = field(default_factory=dict) + flag_patch: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "ConfigPatch": + return cls( + env_patch=_coerce_str_map(data.get("env_patch"), context="config_patch.env_patch"), + flag_patch=_coerce_any_map( + data.get("flag_patch"), context="config_patch.flag_patch" + ), + ) + + +@dataclass(frozen=True) +class Proposal: + observation: str + diagnosis: str + config_patch: ConfigPatch + expected_effects: list[str] + why_not_previous_failures: str = "" + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "Proposal": + return cls( + observation=_require_str(data.get("observation"), context="proposal.observation"), + diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"), + config_patch=ConfigPatch.from_dict( + _require_mapping(data.get("config_patch"), context="proposal.config_patch") + ), + expected_effects=_coerce_str_list( + data.get("expected_effects"), context="proposal.expected_effects" + ), + why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(), + ) + + +@dataclass(frozen=True) +class TrialSpec: + study_id: str + trial_id: str + config_patch: ConfigPatch + search: SamplingSearchSpec + study_spec_path: str + artifact_dir: str + probe_log_path: str + engine_log_path: str + result_path: str + + +@dataclass +class TrialSummary: + trial_id: str + status: str + best_sampling_u: float | None = None + best_request_rate: float | None = None + best_pass_rate: float | None = None + result_path: str | None = 
None + diagnosis: str = "" + + +@dataclass +class StudyState: + study_id: str + best_trial_id: str | None = None + best_request_rate: float | None = None + next_trial_index: int = 1 + trials: list[TrialSummary] = field(default_factory=list) + + +def to_jsonable(value: Any) -> Any: + if is_dataclass(value): + return {key: to_jsonable(item) for key, item in asdict(value).items()} + if isinstance(value, dict): + return {str(key): to_jsonable(item) for key, item in value.items()} + if isinstance(value, list): + return [to_jsonable(item) for item in value] + return value + + +def load_structured_file(path: Path) -> Mapping[str, Any]: + suffix = path.suffix.lower() + if suffix == ".json": + payload = json.loads(path.read_text(encoding="utf-8")) + elif suffix in {".toml", ".tml"}: + payload = tomllib.loads(path.read_text(encoding="utf-8")) + else: + raise SpecError(f"Unsupported spec file type: {path}") + return _require_mapping(payload, context=str(path)) + + +def load_study_spec(path: Path) -> StudySpec: + return StudySpec.from_dict(load_structured_file(path)) diff --git a/src/aituner/store.py b/src/aituner/store.py new file mode 100644 index 0000000..247c472 --- /dev/null +++ b/src/aituner/store.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import json +from dataclasses import replace +from pathlib import Path +from typing import Any + +from .spec import Proposal, StudySpec, StudyState, TrialSpec, TrialSummary, to_jsonable + + +class StudyStore: + def __init__(self, root: Path | None = None): + base = root or Path(".aituner") / "studies" + self.root = base.resolve() + + def study_root(self, study_id: str) -> Path: + return self.root / study_id + + def init_study(self, *, spec_path: Path, study: StudySpec) -> Path: + root = self.study_root(study.study_id) + for rel in ("prompts", "proposals", "trials", "results"): + (root / rel).mkdir(parents=True, exist_ok=True) + (root / "study_spec.source").write_text(str(spec_path.resolve()) + "\n", encoding="utf-8") + self.write_json(root / "study_spec.snapshot.json", to_jsonable(study)) + if not (root / "state.json").exists(): + self.write_json(root / "state.json", to_jsonable(StudyState(study_id=study.study_id))) + return root + + def load_state(self, study_id: str) -> StudyState: + payload = json.loads((self.study_root(study_id) / "state.json").read_text(encoding="utf-8")) + trials = [TrialSummary(**item) for item in payload.get("trials", [])] + return StudyState( + study_id=str(payload["study_id"]), + best_trial_id=payload.get("best_trial_id"), + best_request_rate=payload.get("best_request_rate"), + next_trial_index=int(payload.get("next_trial_index", 1)), + trials=trials, + ) + + def save_state(self, state: StudyState) -> None: + self.write_json(self.study_root(state.study_id) / "state.json", to_jsonable(state)) + + def write_prompt(self, study_id: str, prompt_name: str, prompt_text: str) -> Path: + path = self.study_root(study_id) / "prompts" / f"{prompt_name}.txt" + path.write_text(prompt_text, encoding="utf-8") + return path + + def write_proposal(self, study_id: str, proposal_name: str, proposal: Proposal) -> Path: + path = self.study_root(study_id) / "proposals" / f"{proposal_name}.json" + self.write_json(path, to_jsonable(proposal)) + return path + + def materialize_trial( + self, + *, + study: StudySpec, + state: StudyState, + proposal: Proposal, + ) -> tuple[TrialSpec, StudyState]: + trial_id = f"trial-{state.next_trial_index:04d}" + trial_root = self.study_root(study.study_id) / "trials" / trial_id + 
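+        # Materialize the trial under the study's trials/ directory so a worker can
+        # run it from the serialized trial_spec.json and artifact paths alone.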
trial_root.mkdir(parents=True, exist_ok=True) + spec = TrialSpec( + study_id=study.study_id, + trial_id=trial_id, + config_patch=proposal.config_patch, + search=study.search, + study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()), + artifact_dir=str(trial_root), + probe_log_path=str(trial_root / "probe_history.json"), + engine_log_path=str(trial_root / "engine.log"), + result_path=str(trial_root / "result.json"), + ) + self.write_json(trial_root / "trial_spec.json", to_jsonable(spec)) + next_state = replace(state, next_trial_index=state.next_trial_index + 1) + next_state.trials.append( + TrialSummary(trial_id=trial_id, status="queued", diagnosis=proposal.diagnosis) + ) + self.save_state(next_state) + return spec, next_state + + def ingest_trial_results(self, study_id: str) -> StudyState: + state = self.load_state(study_id) + by_id = {item.trial_id: item for item in state.trials} + trials_dir = self.study_root(study_id) / "trials" + best_trial_id = state.best_trial_id + best_rate = state.best_request_rate + for trial_dir in sorted(trials_dir.glob("trial-*")): + result_path = trial_dir / "result.json" + if not result_path.exists(): + continue + payload = json.loads(result_path.read_text(encoding="utf-8")) + trial_id = str(payload["trial_id"]) + summary = by_id.get(trial_id) + if summary is None: + summary = TrialSummary(trial_id=trial_id, status="unknown") + state.trials.append(summary) + by_id[trial_id] = summary + summary.status = str(payload.get("status") or "completed") + summary.best_sampling_u = payload.get("best_sampling_u") + summary.best_request_rate = payload.get("best_request_rate") + summary.best_pass_rate = payload.get("best_pass_rate") + summary.result_path = str(result_path) + if ( + isinstance(summary.best_request_rate, (int, float)) + and (best_rate is None or summary.best_request_rate > best_rate) + ): + best_rate = float(summary.best_request_rate) + best_trial_id = trial_id + state.best_request_rate = best_rate + state.best_trial_id = best_trial_id + self.save_state(state) + return state + + @staticmethod + def write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") diff --git a/src/aituner/trace.py b/src/aituner/trace.py new file mode 100644 index 0000000..8eec760 --- /dev/null +++ b/src/aituner/trace.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import json +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Mapping + +from .spec import StudySpec + + +class TraceError(ValueError): + """Raised when trace assets are invalid.""" + + +def _percentile(values: list[float], p: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + idx = min(len(ordered) - 1, max(0, math.ceil((p / 100.0) * len(ordered)) - 1)) + return ordered[idx] + + +@dataclass(frozen=True) +class WindowRecord: + window_id: str + trace_path: Path + trace_type: str + window_start: float + window_end: float + source_payload: dict[str, Any] + + +@dataclass(frozen=True) +class TraceRequest: + row_id: str + arrival_s: float + sampling_u: float + body: dict[str, Any] + prompt_tokens_hint: int | None + completion_tokens_hint: int | None + + +def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowRecord: + windows_path = Path(study.trace.windows_path) + if not windows_path.is_absolute(): + windows_path = (study_spec_path.parent / 
windows_path).resolve() + payload = json.loads(windows_path.read_text(encoding="utf-8")) + windows = payload["windows"] if isinstance(payload, Mapping) and "windows" in payload else payload + if not isinstance(windows, list): + raise TraceError(f"windows payload must contain a list: {windows_path}") + for item in windows: + if not isinstance(item, Mapping): + continue + if str(item.get("window_id") or "").strip() != study.trace.window_id: + continue + trace_file = study.trace.trace_file_override or str(item.get("trace_file") or "").strip() + if not trace_file: + raise TraceError(f"window {study.trace.window_id} does not define trace_file") + trace_path = Path(trace_file) + if not trace_path.is_absolute(): + trace_path = (windows_path.parent / trace_path).resolve() + return WindowRecord( + window_id=study.trace.window_id, + trace_path=trace_path, + trace_type=str(item.get("trace_type") or "chat").strip(), + window_start=float(item.get("window_start") or 0.0), + window_end=float(item.get("window_end") or 0.0), + source_payload={str(key): value for key, value in item.items()}, + ) + raise TraceError(f"window_id not found: {study.trace.window_id}") + + +def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]: + messages = row.get("messages") + if isinstance(messages, list) and messages: + return [dict(item) for item in messages if isinstance(item, Mapping)] + prompt = row.get("prompt") or row.get("input") or row.get("text") + if isinstance(prompt, str) and prompt.strip(): + return [{"role": "user", "content": prompt}] + raise TraceError("trace row is missing chat messages/prompt text") + + +def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None: + for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"): + value = row.get(key) + if isinstance(value, bool): + continue + if isinstance(value, int) and value >= 0: + return value + if isinstance(value, float) and value >= 0: + return int(value) + return None + + +def _coerce_prompt_tokens(row: Mapping[str, Any]) -> int | None: + for key in ("input_length", "prompt_length", "prompt_len", "input_tokens"): + value = row.get(key) + if isinstance(value, bool): + continue + if isinstance(value, int) and value >= 0: + return value + if isinstance(value, float) and value >= 0: + return int(value) + return None + + +def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]: + window = resolve_window_record(study, study_spec_path=study_spec_path) + requests: list[TraceRequest] = [] + with window.trace_path.open("r", encoding="utf-8") as handle: + for idx, raw in enumerate(handle): + if not raw.strip(): + continue + row = json.loads(raw) + if not isinstance(row, Mapping): + continue + timestamp = row.get(study.trace.timestamp_field) + if timestamp is None: + timestamp = row.get("arrival_time", row.get("timestamp")) + if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)): + raise TraceError(f"trace row {idx} is missing numeric timestamp") + sampling_u = row.get(study.trace.u_field, 1.0) + if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)): + raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}") + body: dict[str, Any] = { + "model": study.model.served_model_name, + "messages": _coerce_messages(row), + "stream": True, + "stream_options": {"include_usage": True}, + } + completion_tokens = _coerce_completion_tokens(row) + if completion_tokens is not None: + body["max_tokens"] = 
completion_tokens + temperature = row.get("temperature") + if isinstance(temperature, (int, float)) and not isinstance(temperature, bool): + body["temperature"] = temperature + requests.append( + TraceRequest( + row_id=str(row.get("request_id") or row.get("id") or idx), + arrival_s=float(timestamp), + sampling_u=float(sampling_u), + body=body, + prompt_tokens_hint=_coerce_prompt_tokens(row), + completion_tokens_hint=completion_tokens, + ) + ) + requests.sort(key=lambda item: item.arrival_s) + if study.trace.max_requests_per_probe is not None: + requests = requests[: study.trace.max_requests_per_probe] + return window, requests + + +def summarize_window(requests: list[TraceRequest], window: WindowRecord) -> dict[str, Any]: + prompt_tokens = [float(item.prompt_tokens_hint or 0) for item in requests] + completion_tokens = [float(item.completion_tokens_hint or 0) for item in requests] + duration = max(window.window_end - window.window_start, 0.0) or ( + requests[-1].arrival_s - requests[0].arrival_s if len(requests) >= 2 else 0.0 + ) + qps = (len(requests) / duration) if duration > 0 else 0.0 + return { + "window_id": window.window_id, + "trace_path": str(window.trace_path), + "trace_type": window.trace_type, + "request_count": len(requests), + "duration_s": duration, + "request_rate": qps, + "prompt_tokens_p50": _percentile(prompt_tokens, 50.0), + "prompt_tokens_p95": _percentile(prompt_tokens, 95.0), + "completion_tokens_p50": _percentile(completion_tokens, 50.0), + "completion_tokens_p95": _percentile(completion_tokens, 95.0), + } + + +def select_requests_for_threshold( + requests: list[TraceRequest], *, threshold: float +) -> list[TraceRequest]: + return [item for item in requests if item.sampling_u <= threshold] diff --git a/src/aituner/worker.py b/src/aituner/worker.py new file mode 100644 index 0000000..d02470f --- /dev/null +++ b/src/aituner/worker.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +import json +import subprocess +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .engine import build_launch_recipe +from .http_client import HttpClientError, stream_chat_completion, wait_for_server +from .search import ThresholdProbe, binary_search_max_feasible +from .slo import RequestOutcome, summarize_evaluations +from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec +from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold + + +@dataclass(frozen=True) +class ProbePayload: + threshold: float + request_count: int + pass_rate: float + request_rate: float + feasible: bool + outcomes: list[dict[str, Any]] + +def _trial_spec_from_json(path: Path) -> TrialSpec: + payload = json.loads(path.read_text(encoding="utf-8")) + return TrialSpec( + study_id=str(payload["study_id"]), + trial_id=str(payload["trial_id"]), + config_patch=ConfigPatch.from_dict(payload["config_patch"]), + search=SamplingSearchSpec.from_dict(payload["search"]), + study_spec_path=str(payload["study_spec_path"]), + artifact_dir=str(payload["artifact_dir"]), + probe_log_path=str(payload["probe_log_path"]), + engine_log_path=str(payload["engine_log_path"]), + result_path=str(payload["result_path"]), + ) + + +def _run_one_request( + request: TraceRequest, + *, + base_url: str, + timeout_s: float, +) -> RequestOutcome: + try: + metrics = stream_chat_completion(base_url=base_url, body=request.body, timeout_s=timeout_s) + return 
RequestOutcome( + request_id=request.row_id, + success=True, + ttft_ms=metrics.ttft_ms, + tpot_ms=metrics.tpot_ms, + prompt_tokens=request.prompt_tokens_hint, + completion_tokens=metrics.completion_tokens or request.completion_tokens_hint, + ) + except HttpClientError as exc: + return RequestOutcome( + request_id=request.row_id, + success=False, + ttft_ms=None, + tpot_ms=None, + prompt_tokens=request.prompt_tokens_hint, + completion_tokens=request.completion_tokens_hint, + error=str(exc), + ) + + +def _replay_requests( + requests: list[TraceRequest], + *, + base_url: str, + timeout_s: float, + max_concurrency: int, +) -> list[RequestOutcome]: + outcomes_by_id: dict[str, RequestOutcome] = {} + lock = threading.Lock() + start = time.monotonic() + with ThreadPoolExecutor(max_workers=max_concurrency) as pool: + futures = [] + for request in requests: + delay = max(0.0, request.arrival_s) + now = time.monotonic() + sleep_for = (start + delay) - now + if sleep_for > 0: + time.sleep(sleep_for) + futures.append( + pool.submit( + _run_one_request, + request, + base_url=base_url, + timeout_s=timeout_s, + ) + ) + for future in as_completed(futures): + outcome = future.result() + with lock: + outcomes_by_id[outcome.request_id] = outcome + return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id] + + +def run_trial(trial_spec_path: Path) -> dict[str, Any]: + from .store import StudyStore + + trial = _trial_spec_from_json(trial_spec_path) + study_spec_path = Path(Path(trial.study_spec_path).read_text(encoding="utf-8").strip()) + study = load_study_spec(study_spec_path) + window, requests = load_trace_requests(study, study_spec_path=study_spec_path) + recipe = build_launch_recipe(study.engine, trial.config_patch) + artifact_dir = Path(trial.artifact_dir) + artifact_dir.mkdir(parents=True, exist_ok=True) + engine_log_path = Path(trial.engine_log_path) + with engine_log_path.open("w", encoding="utf-8") as engine_log: + process = subprocess.Popen( # noqa: S603 + recipe.argv, + cwd=recipe.cwd, + env=recipe.env, + stdout=engine_log, + stderr=subprocess.STDOUT, + text=True, + ) + try: + wait_for_server(recipe.base_url, recipe.healthcheck_path, recipe.ready_timeout_s) + probe_history: list[dict[str, Any]] = [] + + def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]: + selected = select_requests_for_threshold(requests, threshold=threshold) + outcomes = _replay_requests( + selected, + base_url=recipe.base_url, + timeout_s=recipe.request_timeout_s, + max_concurrency=study.trace.max_concurrency, + ) + evaluations, summary = summarize_evaluations(outcomes, study.slo) + request_rate = ( + len(selected) / max(window.window_end - window.window_start, 1e-9) + if selected + else 0.0 + ) + payload = ProbePayload( + threshold=threshold, + request_count=len(selected), + pass_rate=float(summary["slo_pass_rate"]), + request_rate=request_rate, + feasible=bool(summary["feasible"]), + outcomes=[ + { + "request_id": outcome.request_id, + "success": outcome.success, + "ttft_ms": outcome.ttft_ms, + "tpot_ms": outcome.tpot_ms, + "prompt_tokens": outcome.prompt_tokens, + "completion_tokens": outcome.completion_tokens, + "evaluation": evaluation.passed, + "reasons": evaluation.reasons, + } + for outcome, evaluation in zip(outcomes, evaluations) + ], + ) + probe_record = { + "threshold": threshold, + "request_count": payload.request_count, + "pass_rate": payload.pass_rate, + "request_rate": payload.request_rate, + "feasible": payload.feasible, + } + probe_history.append(probe_record) + 
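+                # Flush probe history after every probe so partial progress survives an aborted trial.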
StudyStore.write_json(Path(trial.probe_log_path), probe_history) + return ThresholdProbe( + threshold=threshold, + feasible=payload.feasible, + payload=payload, + ) + + search = binary_search_max_feasible( + low=trial.search.low, + high=trial.search.high, + tolerance=trial.search.tolerance, + max_probes=trial.search.max_probes, + evaluator=evaluator, + ) + best = search.best_feasible_payload + result = { + "study_id": trial.study_id, + "trial_id": trial.trial_id, + "status": "completed", + "best_sampling_u": search.best_threshold if best is not None else None, + "best_request_rate": best.request_rate if best is not None else None, + "best_pass_rate": best.pass_rate if best is not None else None, + "best_request_count": best.request_count if best is not None else None, + "probes": [ + { + "threshold": probe.threshold, + "feasible": probe.feasible, + "payload": { + "request_count": probe.payload.request_count, + "pass_rate": probe.payload.pass_rate, + "request_rate": probe.payload.request_rate, + }, + } + for probe in search.probes + ], + } + StudyStore.write_json(Path(trial.result_path), result) + return result + finally: + process.terminate() + try: + process.wait(timeout=30) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=30) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1dfc780 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py new file mode 100644 index 0000000..9734ee5 --- /dev/null +++ b/tests/test_core_flow.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import json +import tempfile +import unittest +from pathlib import Path + +from aituner.job import append_job, build_trial_job +from aituner.llm import build_prompt, parse_proposal_text +from aituner.search import ThresholdProbe, binary_search_max_feasible +from aituner.slo import RequestOutcome, summarize_evaluations +from aituner.spec import Proposal, load_study_spec +from aituner.store import StudyStore +from aituner.trace import load_trace_requests, summarize_window + + +def _write_study_assets(tmp_path: Path) -> Path: + trace_dir = tmp_path / "trace_windows" / "traces" + trace_dir.mkdir(parents=True) + trace_path = trace_dir / "chat_w1.jsonl" + rows = [ + { + "request_id": "r1", + "timestamp": 0.0, + "sampling_u": 0.10, + "messages": [{"role": "user", "content": "hello"}], + "input_length": 1000, + "output_length": 16 + }, + { + "request_id": "r2", + "timestamp": 1.0, + "sampling_u": 0.50, + "messages": [{"role": "user", "content": "world"}], + "input_length": 5000, + "output_length": 32 + }, + { + "request_id": "r3", + "timestamp": 2.0, + "sampling_u": 0.90, + "messages": [{"role": "user", "content": "!"}], + "input_length": 20000, + "output_length": 64 + } + ] + with trace_path.open("w", encoding="utf-8") as handle: + for row in rows: + handle.write(json.dumps(row) + "\n") + + windows_path = tmp_path / "trace_windows" / "windows.json" + windows_payload = { + "u_field": "sampling_u", + "windows": [ + { + "window_id": "chat_w1", + "trace_type": "chat", + "trace_file": "traces/chat_w1.jsonl", + "window_start": 0.0, + "window_end": 10.0 + } + ] + } + windows_path.write_text(json.dumps(windows_payload), encoding="utf-8") + + capability_path = tmp_path / "capability.json" + 
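+    # Minimal capability profile file so capability_profile_path in the study spec resolves.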
capability_path.write_text( + json.dumps({"prefill_service_by_bucket": {"4k": {"tp4_ms": 320, "tp8_ms": 240}}}), + encoding="utf-8", + ) + + study_path = tmp_path / "study.json" + study_payload = { + "study_id": "study-1", + "hardware": {"gpu_count": 8, "gpu_model": "H20", "host_candidates": ["dash0"]}, + "model": { + "model_id": "qwen", + "served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507" + }, + "engine": { + "engine_name": "vllm", + "engine_version": "0.1", + "exec_path": "/usr/local/bin/vllm", + "cwd": str(tmp_path), + "host": "127.0.0.1", + "port": 8000, + "healthcheck_path": "/v1/models", + "ready_timeout_s": 30, + "request_timeout_s": 30, + "launch_args": ["serve", "/models/qwen"], + "base_envs": {"BASE_ENV": "1"}, + "base_flags": {"host": "127.0.0.1", "port": 8000}, + "tunable_envs": ["VLLM_ATTENTION_BACKEND"], + "tunable_flags": ["tensor-parallel-size", "max-num-seqs"], + "python_executable": "python3" + }, + "trace": { + "windows_path": str(windows_path), + "window_id": "chat_w1", + "u_field": "sampling_u", + "timestamp_field": "timestamp", + "max_concurrency": 4 + }, + "slo": { + "target_pass_rate": 0.95, + "ttft_rule": { + "kind": "step_ms", + "buckets": [ + {"max_input_tokens": 4096, "threshold_ms": 2000}, + {"max_input_tokens": 16384, "threshold_ms": 5000}, + {"threshold_ms": 9000} + ] + }, + "tpot_rule": {"kind": "fixed_ms", "threshold_ms": 120} + }, + "search": { + "low": 0.0, + "high": 1.0, + "tolerance": 0.01, + "max_probes": 8, + "sample_seed": 20260325 + }, + "llm": {"system_prompt": "Tune it.", "max_history_trials": 8}, + "capability_profile_path": str(capability_path) + } + study_path.write_text(json.dumps(study_payload), encoding="utf-8") + return study_path + + +class CoreFlowTests(unittest.TestCase): + def test_trace_and_prompt_flow(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + study_path = _write_study_assets(tmp_path) + study = load_study_spec(study_path) + store = StudyStore(tmp_path / ".aituner" / "studies") + study_root = store.init_study(spec_path=study_path, study=study) + state = store.load_state(study.study_id) + + window, requests = load_trace_requests(study, study_spec_path=study_path) + summary = summarize_window(requests, window) + self.assertEqual(summary["request_count"], 3) + self.assertEqual(summary["request_rate"], 0.3) + + prompt = build_prompt( + study=study, + window_summary=summary, + state=state, + capability_profile={"queueing_knee_by_bucket": {"4k": 1000}}, + ) + self.assertIn("allowed_flag_keys", prompt) + self.assertIn("study-1", prompt) + self.assertIn("queueing_knee_by_bucket", prompt) + self.assertTrue(study_root.exists()) + + def test_slo_evaluation_step_and_fixed_rules(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + study = load_study_spec(_write_study_assets(Path(tmp))) + outcomes = [ + RequestOutcome( + request_id="r1", + success=True, + ttft_ms=1000, + tpot_ms=100, + prompt_tokens=1000, + completion_tokens=16, + ), + RequestOutcome( + request_id="r2", + success=True, + ttft_ms=6000, + tpot_ms=100, + prompt_tokens=5000, + completion_tokens=16, + ), + ] + evaluations, summary = summarize_evaluations(outcomes, study.slo) + self.assertTrue(evaluations[0].passed) + self.assertFalse(evaluations[1].passed) + self.assertEqual(summary["slo_pass_rate"], 0.5) + self.assertFalse(summary["feasible"]) + + def test_binary_search_max_feasible(self) -> None: + result = binary_search_max_feasible( + low=0.0, + high=1.0, + tolerance=0.01, + max_probes=8, + evaluator=lambda threshold: 
ThresholdProbe( + threshold=threshold, + feasible=threshold <= 0.625, + payload={"threshold": threshold}, + ), + ) + self.assertLessEqual(result.best_threshold, 0.625) + self.assertGreaterEqual(result.best_threshold, 0.5) + self.assertIsNotNone(result.best_feasible_payload) + + def test_proposal_validation_and_job_emission(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + study_path = _write_study_assets(tmp_path) + study = load_study_spec(study_path) + store = StudyStore(tmp_path / ".aituner" / "studies") + store.init_study(spec_path=study_path, study=study) + state = store.load_state(study.study_id) + + proposal_text = json.dumps( + { + "observation": "Current TTFT fails before TPOT.", + "diagnosis": "Prefill pressure dominates.", + "config_patch": { + "env_patch": {"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + "flag_patch": {"tensor-parallel-size": 4, "max-num-seqs": 64} + }, + "expected_effects": ["lower TTFT", "raise feasible sampling_u"], + "why_not_previous_failures": "Avoids changing unsupported envs." + } + ) + proposal = parse_proposal_text(proposal_text, study) + trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) + + job = build_trial_job(study=study, trial=trial, repo_root=tmp_path) + jobs_path = tmp_path / "jobs.toml" + append_job(jobs_path, job) + rendered = jobs_path.read_text(encoding="utf-8") + self.assertIn('name = "study-1-trial-0001"', rendered) + self.assertIn('command = "python3 -m aituner.cli worker run-trial', rendered) + self.assertIn('PYTHONPATH = "src"', rendered) + + def test_ingest_trial_results_updates_best(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + study_path = _write_study_assets(tmp_path) + study = load_study_spec(study_path) + store = StudyStore(tmp_path / ".aituner" / "studies") + store.init_study(spec_path=study_path, study=study) + state = store.load_state(study.study_id) + proposal = Proposal.from_dict( + { + "observation": "Obs", + "diagnosis": "Diag", + "config_patch": {"env_patch": {}, "flag_patch": {"tensor-parallel-size": 4}}, + "expected_effects": ["raise rate"] + } + ) + trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) + Path(trial.result_path).write_text( + json.dumps( + { + "study_id": study.study_id, + "trial_id": trial.trial_id, + "status": "completed", + "best_sampling_u": 0.75, + "best_request_rate": 12.5, + "best_pass_rate": 0.97 + } + ), + encoding="utf-8", + ) + next_state = store.ingest_trial_results(study.study_id) + self.assertEqual(next_state.best_trial_id, trial.trial_id) + self.assertEqual(next_state.best_request_rate, 12.5) + + +if __name__ == "__main__": + unittest.main()