# aituner/scripts/run_multi_compare.py
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Any
from aituner.spec import (
CompareCandidateSpec,
ConfigPatch,
SpecError,
TrialSpec,
load_study_spec,
load_structured_file,
to_jsonable,
)
from aituner.store import StudyStore
from aituner.worker import run_trial
@dataclass(frozen=True)
class RuntimeOverride:
    """Per-candidate runtime placement: GPU set and serving endpoint."""

    # Value assigned to CUDA_VISIBLE_DEVICES for this candidate, e.g. "0,1".
    cuda_visible_devices: str
    port: int
    host: str = "127.0.0.1"

    @classmethod
    def from_dict(cls, data: dict[str, Any], *, context: str) -> "RuntimeOverride":
        """Validate and build a RuntimeOverride from a raw mapping.

        Raises SpecError (prefixed with *context*) on any invalid field.
        """
        cuda_visible_devices = str(data.get("cuda_visible_devices") or "").strip()
        if not cuda_visible_devices:
            raise SpecError(f"{context}.cuda_visible_devices must be a non-empty string.")
        port_value = data.get("port")
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(port_value, bool) or not isinstance(port_value, int):
            raise SpecError(f"{context}.port must be an integer.")
        # Fix: also reject out-of-range values; valid TCP ports are 1-65535.
        if not 1 <= port_value <= 65535:
            raise SpecError(f"{context}.port must be in the range 1-65535.")
        host = str(data.get("host") or "127.0.0.1").strip()
        if not host:
            raise SpecError(f"{context}.host must be a non-empty string.")
        return cls(
            cuda_visible_devices=cuda_visible_devices,
            port=port_value,
            host=host,
        )
@dataclass(frozen=True)
class MultiCompareCandidate:
    """One named compare candidate: its spec, execution phase, and runtime slot."""

    name: str
    phase: int
    candidate: CompareCandidateSpec
    runtime: RuntimeOverride

    @classmethod
    def from_dict(cls, data: dict[str, Any], *, context: str) -> "MultiCompareCandidate":
        """Parse one candidates[] entry, raising SpecError on invalid fields."""
        candidate_name = str(data.get("name") or "").strip()
        if not candidate_name:
            raise SpecError(f"{context}.name must be a non-empty string.")
        phase = data.get("phase", 1)
        bad_phase = isinstance(phase, bool) or not isinstance(phase, int) or phase < 1
        if bad_phase:
            raise SpecError(f"{context}.phase must be a positive integer.")
        spec = CompareCandidateSpec.from_dict(data, context=context)
        runtime_data = dict(data.get("runtime") or {})
        runtime = RuntimeOverride.from_dict(runtime_data, context=f"{context}.runtime")
        return cls(name=candidate_name, phase=phase, candidate=spec, runtime=runtime)
@dataclass(frozen=True)
class MultiCompareSpec:
    """Top-level spec for a multi-candidate comparison over trace windows."""

    compare_id: str
    study_spec_path: str
    output_root: str | None
    window_ids: list[str]
    candidates: list[MultiCompareCandidate]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiCompareSpec":
        """Parse and validate the whole compare-spec mapping."""
        compare_id = str(data.get("compare_id") or "").strip()
        if not compare_id:
            raise SpecError("compare_id must be a non-empty string.")
        study_spec_path = str(data.get("study_spec_path") or "").strip()
        if not study_spec_path:
            raise SpecError("study_spec_path must be a non-empty string.")
        raw_window_ids = data.get("window_ids")
        if not (isinstance(raw_window_ids, list) and raw_window_ids):
            raise SpecError("window_ids must be a non-empty list.")
        window_ids: list[str] = []
        for item in raw_window_ids:
            text = str(item).strip()
            if text:
                window_ids.append(text)
        if not window_ids:
            raise SpecError("window_ids must contain at least one non-empty string.")
        raw_candidates = data.get("candidates")
        if not (isinstance(raw_candidates, list) and raw_candidates):
            raise SpecError("candidates must be a non-empty list.")
        candidates: list[MultiCompareCandidate] = []
        for idx, entry in enumerate(raw_candidates):
            if not isinstance(entry, dict):
                # Skipped entries make the length check below fail with a
                # clear message instead of a type error.
                continue
            candidates.append(
                MultiCompareCandidate.from_dict(dict(entry), context=f"candidates[{idx}]")
            )
        if len(candidates) != len(raw_candidates):
            raise SpecError("Every candidates entry must be an object.")
        unique_names = {item.name for item in candidates}
        if len(unique_names) != len(candidates):
            raise SpecError("candidates names must be unique.")
        output_root = str(data.get("output_root") or "").strip() or None
        return cls(
            compare_id=compare_id,
            study_spec_path=study_spec_path,
            output_root=output_root,
            window_ids=window_ids,
            candidates=candidates,
        )
def _resolve_path(raw_path: str, *, base_dir: Path) -> Path:
path = Path(raw_path)
if not path.is_absolute():
path = (base_dir / path).resolve()
return path
def _load_windows_payload(study: Any, *, study_spec_path: Path) -> list[dict[str, Any]]:
windows_path = Path(study.trace.windows_path)
if not windows_path.is_absolute():
windows_path = (study_spec_path.parent / windows_path).resolve()
payload = json.loads(windows_path.read_text(encoding="utf-8"))
raw_windows = payload.get("windows") if isinstance(payload, dict) else payload
if not isinstance(raw_windows, list):
raise SpecError(f"windows payload must contain a list: {windows_path}")
return [
{str(key): value for key, value in item.items()}
for item in raw_windows
if isinstance(item, dict)
]
def _select_windows(spec: MultiCompareSpec, *, study: Any, study_spec_path: Path) -> list[dict[str, Any]]:
    """Return the window entries matching spec.window_ids, preserving spec order.

    Raises SpecError when any requested window_id is missing from the payload.
    """
    available = _load_windows_payload(study, study_spec_path=study_spec_path)
    by_id = {str(entry.get("window_id") or "").strip(): entry for entry in available}
    chosen: list[dict[str, Any]] = []
    for window_id in spec.window_ids:
        entry = by_id.get(window_id)
        if entry is None:
            raise SpecError(f"window_id not found in windows payload: {window_id}")
        chosen.append(entry)
    return chosen
def _load_config_patch(
candidate: MultiCompareCandidate,
*,
spec_path: Path,
) -> tuple[ConfigPatch, dict[str, Any]]:
if candidate.candidate.config_patch is not None:
config_patch = candidate.candidate.config_patch
return config_patch, {
"kind": "config_patch",
"config_patch": {
"env_patch": dict(config_patch.env_patch),
"flag_patch": dict(config_patch.flag_patch),
},
}
assert candidate.candidate.trial_ref is not None
study_root = _resolve_path(candidate.candidate.trial_ref.study_root, base_dir=spec_path.parent)
trial_spec_path = study_root / "trials" / candidate.candidate.trial_ref.trial_id / "trial_spec.json"
if not trial_spec_path.exists():
raise SpecError(f"trial_ref target not found: {trial_spec_path}")
payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
config_patch = ConfigPatch.from_dict(payload.get("config_patch") or {})
return config_patch, {
"kind": "trial_ref",
"study_root": str(study_root),
"trial_id": candidate.candidate.trial_ref.trial_id,
"config_patch": {
"env_patch": dict(config_patch.env_patch),
"flag_patch": dict(config_patch.flag_patch),
},
}
def _parse_int_like(value: Any, *, default: int = 1) -> int:
if value is None:
return default
if isinstance(value, bool):
raise SpecError("Topology values must be integers.")
if isinstance(value, int):
return value
if isinstance(value, float) and value.is_integer():
return int(value)
if isinstance(value, str) and value.strip():
return int(value.strip())
raise SpecError(f"Unable to parse integer topology value: {value!r}")
def _parallel_size_for_candidate(*, study: Any, patch: ConfigPatch) -> int:
    """Total GPU count implied by the effective flags: TP size times DP size."""
    effective_flags = {**study.engine.base_flags, **patch.flag_patch}
    tensor_parallel = _parse_int_like(effective_flags.get("tensor-parallel-size"), default=1)
    data_parallel = _parse_int_like(effective_flags.get("data-parallel-size"), default=1)
    return tensor_parallel * data_parallel
def _trial_snapshot(trial: TrialSpec) -> dict[str, Any]:
    """Return a JSON-able snapshot of *trial* for writing to disk."""
    return to_jsonable(trial)
def _study_snapshot(study: Any) -> dict[str, Any]:
    """Return a JSON-able snapshot of *study* for writing to disk."""
    return to_jsonable(study)
def _candidate_result_payload(
    *,
    candidate: MultiCompareCandidate,
    source: dict[str, Any],
    config_patch: ConfigPatch,
    parallel_size: int,
    result: dict[str, Any],
    run_root: Path,
    result_path: Path,
    resumed: bool,
) -> dict[str, Any]:
    """Build the per-candidate result row (was duplicated between the resumed
    and fresh paths of _run_candidate_for_window)."""
    best_rate = result.get("best_request_rate")
    # Normalize throughput by GPU count so differently-sized candidates compare fairly.
    best_rate_per_gpu = (
        float(best_rate) / float(parallel_size)
        if isinstance(best_rate, (int, float)) and parallel_size > 0
        else None
    )
    return {
        "candidate": candidate.name,
        "source": source,
        "parallel_size": parallel_size,
        "runtime": {
            "cuda_visible_devices": candidate.runtime.cuda_visible_devices,
            "port": candidate.runtime.port,
            "host": candidate.runtime.host,
        },
        "config_patch": {
            "env_patch": dict(config_patch.env_patch),
            "flag_patch": dict(config_patch.flag_patch),
        },
        "status": result.get("status"),
        "best_sampling_u": result.get("best_sampling_u"),
        "best_request_rate": best_rate,
        "best_request_rate_per_gpu": best_rate_per_gpu,
        "best_pass_rate": result.get("best_pass_rate"),
        "best_request_count": result.get("best_request_count"),
        "failure_stage": result.get("failure_stage"),
        "failure_reason": result.get("failure_reason"),
        "artifact_dir": str(run_root),
        "result_path": str(result_path),
        "probe_log_path": str(run_root / "probe_history.json"),
        "engine_log_path": str(run_root / "engine.log"),
        "resumed": resumed,
    }


def _run_candidate_for_window(
    *,
    compare_id: str,
    compare_root: Path,
    study: Any,
    study_spec_path: Path,
    window_id: str,
    candidate: MultiCompareCandidate,
    config_patch: ConfigPatch,
    source: dict[str, Any],
) -> dict[str, Any]:
    """Run one candidate on one window (or resume a finished run) and return its row.

    *study_spec_path* is kept for caller compatibility; the effective study
    snapshot is written under the run directory instead.
    """
    run_root = compare_root / "runs" / window_id / candidate.name
    run_root.mkdir(parents=True, exist_ok=True)
    result_path = run_root / "result.json"
    if result_path.exists():
        # Resume: a prior run already produced a result; reuse it verbatim.
        result = json.loads(result_path.read_text(encoding="utf-8"))
        parallel_size = _parallel_size_for_candidate(study=study, patch=config_patch)
        return _candidate_result_payload(
            candidate=candidate,
            source=source,
            config_patch=config_patch,
            parallel_size=parallel_size,
            result=result,
            run_root=run_root,
            result_path=result_path,
            resumed=True,
        )
    # Pin the candidate to its GPU set and port via env/flag overrides on a
    # copy of the study (frozen dataclasses: use replace()).
    engine_envs = dict(study.engine.base_envs)
    engine_envs["CUDA_VISIBLE_DEVICES"] = candidate.runtime.cuda_visible_devices
    engine_flags = dict(study.engine.base_flags)
    engine_flags["port"] = candidate.runtime.port
    runtime_study = replace(
        study,
        trace=replace(study.trace, window_id=window_id),
        engine=replace(
            study.engine,
            host=candidate.runtime.host,
            port=candidate.runtime.port,
            base_envs=engine_envs,
            base_flags=engine_flags,
        ),
    )
    # Persist the effective study spec; the trial references it indirectly
    # through the .source pointer file.
    actual_study_path = run_root / "study_spec.json"
    source_path = run_root / "study_spec.source"
    actual_study_path.write_text(
        json.dumps(_study_snapshot(runtime_study), ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    source_path.write_text(str(actual_study_path) + "\n", encoding="utf-8")
    trial = TrialSpec(
        study_id=compare_id,
        trial_id=candidate.name,
        config_patch=config_patch,
        search=runtime_study.search,
        study_spec_path=str(source_path),
        artifact_dir=str(run_root),
        probe_log_path=str(run_root / "probe_history.json"),
        engine_log_path=str(run_root / "engine.log"),
        result_path=str(result_path),
    )
    StudyStore.write_json(run_root / "trial_spec.json", _trial_snapshot(trial))
    result = run_trial(run_root / "trial_spec.json")
    parallel_size = _parallel_size_for_candidate(study=runtime_study, patch=config_patch)
    return _candidate_result_payload(
        candidate=candidate,
        source=source,
        config_patch=config_patch,
        parallel_size=parallel_size,
        result=result,
        run_root=run_root,
        result_path=result_path,
        resumed=False,
    )
def _winner(candidates: dict[str, dict[str, Any]]) -> str:
scored = [
(name, float(result["best_request_rate_per_gpu"]))
for name, result in candidates.items()
if isinstance(result.get("best_request_rate_per_gpu"), (int, float))
]
if not scored:
return "incomparable"
scored.sort(key=lambda item: item[1], reverse=True)
if len(scored) > 1 and scored[0][1] == scored[1][1]:
return "tie"
return scored[0][0]
def _aggregate(rows: list[dict[str, Any]], candidates: list[MultiCompareCandidate]) -> dict[str, Any]:
    """Summarize per-window rows: win tallies plus per-candidate mean metrics."""

    def mean_of(name: str, key: str) -> float | None:
        # Mean over windows where the metric is numeric; None when no window has it.
        values = [
            float(row["candidates"][name][key])
            for row in rows
            if isinstance(row["candidates"][name].get(key), (int, float))
        ]
        return sum(values) / len(values) if values else None

    names = [item.name for item in candidates]
    per_candidate: dict[str, dict[str, Any]] = {}
    for name in names:
        per_candidate[name] = {
            "mean_request_rate": mean_of(name, "best_request_rate"),
            "mean_request_rate_per_gpu": mean_of(name, "best_request_rate_per_gpu"),
            "mean_pass_rate": mean_of(name, "best_pass_rate"),
            **_candidate_result_counts(rows, name),
        }
    wins: dict[str, int] = {name: 0 for name in names}
    wins["tie"] = 0
    wins["incomparable"] = 0
    for row in rows:
        winner = row["winner"]
        wins[winner] = wins.get(winner, 0) + 1
    return {
        "window_count": len(rows),
        "wins": wins,
        "candidates": per_candidate,
    }
def _candidate_result_counts(rows: list[dict[str, Any]], name: str) -> dict[str, int]:
counts = {
"completed_window_count": 0,
"failed_window_count": 0,
"no_feasible_window_count": 0,
}
for row in rows:
result = row.get("candidates", {}).get(name)
if not isinstance(result, dict):
continue
status = str(result.get("status") or "")
if status == "completed":
counts["completed_window_count"] += 1
elif status == "failed":
counts["failed_window_count"] += 1
if not isinstance(result.get("best_request_rate_per_gpu"), (int, float)):
counts["no_feasible_window_count"] += 1
return counts
def _render_report(summary: dict[str, Any], candidates: list[MultiCompareCandidate]) -> str:
candidate_names = [item.name for item in candidates]
lines = [
f"# {summary['compare_id']}",
"",
"## Setup",
"",
f"- Study spec: `{summary['study_spec_path']}`",
f"- Compare root: `{summary['compare_root']}`",
f"- Windows: `{len(summary['windows'])}`",
"",
"## Candidates",
"",
]
for item in candidates:
lines.append(
f"- `{item.name}`: phase=`{item.phase}`, gpus=`{item.runtime.cuda_visible_devices}`, port=`{item.runtime.port}`"
)
lines.extend(
[
"",
"## Aggregate",
"",
f"- Wins: `{json.dumps(summary['aggregate']['wins'], ensure_ascii=False)}`",
]
)
for name in candidate_names:
aggregate = summary["aggregate"]["candidates"][name]
lines.append(
f"- `{name}` mean req/s=`{aggregate['mean_request_rate']}`, mean req/s/gpu=`{aggregate['mean_request_rate_per_gpu']}`, mean pass_rate=`{aggregate['mean_pass_rate']}`"
)
lines.append(
f" completed/failed/no-feasible windows=`{aggregate['completed_window_count']}`/`{aggregate['failed_window_count']}`/`{aggregate['no_feasible_window_count']}`"
)
header = ["Window", "Date"]
for name in candidate_names:
header.extend([f"{name} req/s", f"{name} req/s/gpu"])
header.append("Winner")
lines.extend(
[
"",
"## Per Window",
"",
"| " + " | ".join(header) + " |",
"| " + " | ".join(["---"] * len(header)) + " |",
]
)
for row in summary["windows"]:
cells = [f"`{row['window_id']}`", f"`{row.get('date') or ''}`"]
for name in candidate_names:
candidate = row["candidates"][name]
cells.append(f"`{candidate.get('best_request_rate')}`")
cells.append(f"`{candidate.get('best_request_rate_per_gpu')}`")
cells.append(f"`{row['winner']}`")
lines.append("| " + " | ".join(cells) + " |")
lines.append("")
return "\n".join(lines)
def run_multi_compare(spec_path: Path) -> dict[str, Any]:
    """Execute a multi-candidate compare spec and return the final summary dict.

    Candidates are grouped into phases: candidates within a phase run
    concurrently (each pinned to its own GPU set/port), phases run
    sequentially. After every window the summary and report are rewritten so
    a partial run can be inspected or resumed.
    """
    spec_path = spec_path.resolve()
    spec = MultiCompareSpec.from_dict(dict(load_structured_file(spec_path)))
    study_spec_path = _resolve_path(spec.study_spec_path, base_dir=spec_path.parent)
    study = load_study_spec(study_spec_path)
    compare_root = (
        _resolve_path(spec.output_root, base_dir=spec_path.parent)
        if spec.output_root
        else (Path(".aituner-compare") / spec.compare_id).resolve()
    )
    compare_root.mkdir(parents=True, exist_ok=True)
    windows = _select_windows(spec, study=study, study_spec_path=study_spec_path)
    # Resolve every candidate's patch up-front so spec errors surface before
    # any trial starts.
    candidate_payloads = []
    resolved_candidates: dict[str, tuple[MultiCompareCandidate, ConfigPatch, dict[str, Any]]] = {}
    for candidate in spec.candidates:
        config_patch, source = _load_config_patch(candidate, spec_path=spec_path)
        resolved_candidates[candidate.name] = (candidate, config_patch, source)
        candidate_payloads.append(
            {
                "name": candidate.name,
                "phase": candidate.phase,
                "runtime": {
                    "cuda_visible_devices": candidate.runtime.cuda_visible_devices,
                    "port": candidate.runtime.port,
                    "host": candidate.runtime.host,
                },
                "source": source,
            }
        )
    StudyStore.write_json(
        compare_root / "compare_spec.snapshot.json",
        {
            "compare_id": spec.compare_id,
            "study_spec_path": str(study_spec_path),
            "window_ids": spec.window_ids,
            "candidates": candidate_payloads,
        },
    )

    def write_summary() -> dict[str, Any]:
        # Shared by the per-window checkpoint and the final emit (was
        # duplicated): build the summary from the rows accumulated so far and
        # persist both summary.json and report.md.
        summary = {
            "compare_id": spec.compare_id,
            "study_spec_path": str(study_spec_path),
            "compare_root": str(compare_root),
            "windows": per_window,
            "aggregate": _aggregate(per_window, spec.candidates),
        }
        StudyStore.write_json(compare_root / "summary.json", summary)
        (compare_root / "report.md").write_text(
            _render_report(summary, spec.candidates),
            encoding="utf-8",
        )
        return summary

    phases = sorted({item.phase for item in spec.candidates})
    per_window: list[dict[str, Any]] = []
    for window in windows:
        window_id = str(window["window_id"])
        row = {
            "window_id": window_id,
            "trace_type": window.get("trace_type"),
            "date": window.get("date"),
            "slot_token": window.get("slot_token"),
            "slot_label": window.get("slot_label"),
            "window_start": window.get("window_start"),
            "window_end": window.get("window_end"),
            "candidates": {},
        }
        for phase in phases:
            phase_candidates = [item for item in spec.candidates if item.phase == phase]
            with ThreadPoolExecutor(max_workers=len(phase_candidates)) as executor:
                future_map = {
                    executor.submit(
                        _run_candidate_for_window,
                        compare_id=spec.compare_id,
                        compare_root=compare_root,
                        study=study,
                        study_spec_path=study_spec_path,
                        window_id=window_id,
                        candidate=item,
                        config_patch=resolved_candidates[item.name][1],
                        source=resolved_candidates[item.name][2],
                    ): item.name
                    for item in phase_candidates
                }
                for future in as_completed(future_map):
                    result = future.result()
                    row["candidates"][result["candidate"]] = result
        row["winner"] = _winner(row["candidates"])
        per_window.append(row)
        # Checkpoint after each window so partial progress survives a crash.
        write_summary()
    return write_summary()
def main() -> int:
    """CLI entry point: run the compare described by --spec and print a digest."""
    parser = argparse.ArgumentParser(description="Run a multi-candidate compare over trace windows.")
    parser.add_argument("--spec", required=True)
    args = parser.parse_args()
    summary = run_multi_compare(Path(args.spec))
    digest = {
        "compare_id": summary["compare_id"],
        "compare_root": summary["compare_root"],
        "window_count": summary["aggregate"]["window_count"],
        "wins": summary["aggregate"]["wins"],
    }
    print(json.dumps(digest, ensure_ascii=False, indent=2))
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status as the process exit code.
    raise SystemExit(main())