#!/usr/bin/env python3 from __future__ import annotations import argparse import json from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, replace from pathlib import Path from typing import Any from aituner.spec import ( CompareCandidateSpec, ConfigPatch, SpecError, TrialSpec, load_study_spec, load_structured_file, to_jsonable, ) from aituner.store import StudyStore from aituner.worker import run_trial @dataclass(frozen=True) class RuntimeOverride: cuda_visible_devices: str port: int host: str = "127.0.0.1" @classmethod def from_dict(cls, data: dict[str, Any], *, context: str) -> "RuntimeOverride": cuda_visible_devices = str(data.get("cuda_visible_devices") or "").strip() if not cuda_visible_devices: raise SpecError(f"{context}.cuda_visible_devices must be a non-empty string.") port_value = data.get("port") if isinstance(port_value, bool) or not isinstance(port_value, int): raise SpecError(f"{context}.port must be an integer.") host = str(data.get("host") or "127.0.0.1").strip() if not host: raise SpecError(f"{context}.host must be a non-empty string.") return cls( cuda_visible_devices=cuda_visible_devices, port=port_value, host=host, ) @dataclass(frozen=True) class MultiCompareCandidate: name: str phase: int candidate: CompareCandidateSpec runtime: RuntimeOverride @classmethod def from_dict(cls, data: dict[str, Any], *, context: str) -> "MultiCompareCandidate": name = str(data.get("name") or "").strip() if not name: raise SpecError(f"{context}.name must be a non-empty string.") phase_value = data.get("phase", 1) if isinstance(phase_value, bool) or not isinstance(phase_value, int) or phase_value < 1: raise SpecError(f"{context}.phase must be a positive integer.") candidate = CompareCandidateSpec.from_dict(data, context=context) runtime = RuntimeOverride.from_dict( dict(data.get("runtime") or {}), context=f"{context}.runtime", ) return cls(name=name, phase=phase_value, candidate=candidate, runtime=runtime) 
@dataclass(frozen=True)
class MultiCompareSpec:
    """Validated top-level compare spec.

    ``study_spec_path`` and ``output_root`` are kept as the raw strings from
    the spec file; ``run_multi_compare`` resolves them relative to the spec
    file's directory.
    """

    compare_id: str
    study_spec_path: str
    output_root: str | None  # None -> ``.aituner-compare/<compare_id>`` under the CWD
    window_ids: list[str]
    candidates: list[MultiCompareCandidate]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiCompareSpec":
        """Validate a raw spec mapping and build a ``MultiCompareSpec``.

        Raises:
            SpecError: on missing/empty required fields, non-object candidate
                entries, duplicate candidate names, or two candidates in the
                same phase (which run concurrently) sharing one host:port.
        """
        compare_id = str(data.get("compare_id") or "").strip()
        if not compare_id:
            raise SpecError("compare_id must be a non-empty string.")
        study_spec_path = str(data.get("study_spec_path") or "").strip()
        if not study_spec_path:
            raise SpecError("study_spec_path must be a non-empty string.")

        raw_window_ids = data.get("window_ids")
        if not isinstance(raw_window_ids, list) or not raw_window_ids:
            raise SpecError("window_ids must be a non-empty list.")
        window_ids = [str(item).strip() for item in raw_window_ids if str(item).strip()]
        if not window_ids:
            raise SpecError("window_ids must contain at least one non-empty string.")

        raw_candidates = data.get("candidates")
        if not isinstance(raw_candidates, list) or not raw_candidates:
            raise SpecError("candidates must be a non-empty list.")
        candidates = [
            MultiCompareCandidate.from_dict(dict(item), context=f"candidates[{idx}]")
            for idx, item in enumerate(raw_candidates)
            if isinstance(item, dict)
        ]
        if len(candidates) != len(raw_candidates):
            raise SpecError("Every candidates entry must be an object.")
        names = [item.name for item in candidates]
        if len(names) != len(set(names)):
            raise SpecError("candidates names must be unique.")
        _check_phase_endpoints(candidates)

        return cls(
            compare_id=compare_id,
            study_spec_path=study_spec_path,
            output_root=str(data.get("output_root") or "").strip() or None,
            window_ids=window_ids,
            candidates=candidates,
        )


def _check_phase_endpoints(candidates: list[MultiCompareCandidate]) -> None:
    """Reject same-phase candidates that are configured with the same host:port.

    Candidates sharing a phase are submitted concurrently for every window in
    ``run_multi_compare``, and each one's engine is given its runtime port, so a
    shared endpoint would collide.

    Raises:
        SpecError: naming the two conflicting candidates.
    """
    owners: dict[tuple[int, str, int], str] = {}
    for item in candidates:
        key = (item.phase, item.runtime.host, item.runtime.port)
        if key in owners:
            raise SpecError(
                f"candidates {owners[key]!r} and {item.name!r} are both in phase "
                f"{item.phase} and use {item.runtime.host}:{item.runtime.port}; "
                "candidates in the same phase run concurrently and need distinct ports."
            )
        owners[key] = item.name


def _resolve_path(raw_path: str, *, base_dir: Path) -> Path:
    """Return ``raw_path`` as a Path, anchoring relative paths at ``base_dir``.

    Relative paths are joined to ``base_dir`` and resolved; absolute paths are
    returned unchanged (not resolved).
    """
    path = Path(raw_path)
    if not path.is_absolute():
        path = (base_dir / path).resolve()
    return path


def _load_windows_payload(study: Any, *, study_spec_path: Path) -> list[dict[str, Any]]:
    """Load the study's windows JSON file and return its window objects.

    The file may be a JSON list of windows or an object with a ``windows``
    list. A relative ``study.trace.windows_path`` is resolved against the study
    spec's directory. Non-object entries are dropped and keys coerced to str.

    Raises:
        SpecError: if no window list is found in the payload.
    """
    windows_path = _resolve_path(str(study.trace.windows_path), base_dir=study_spec_path.parent)
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    raw_windows = payload.get("windows") if isinstance(payload, dict) else payload
    if not isinstance(raw_windows, list):
        raise SpecError(f"windows payload must contain a list: {windows_path}")
    return [
        {str(key): value for key, value in item.items()}
        for item in raw_windows
        if isinstance(item, dict)
    ]


def _select_windows(spec: MultiCompareSpec, *, study: Any, study_spec_path: Path) -> list[dict[str, Any]]:
    """Return the windows named by ``spec.window_ids``, in spec order.

    Raises:
        SpecError: if a requested window_id is absent from the windows payload.
    """
    windows = _load_windows_payload(study, study_spec_path=study_spec_path)
    indexed = {str(item.get("window_id") or "").strip(): item for item in windows}
    selected: list[dict[str, Any]] = []
    for window_id in spec.window_ids:
        item = indexed.get(window_id)
        if item is None:
            raise SpecError(f"window_id not found in windows payload: {window_id}")
        selected.append(item)
    return selected


def _patch_payload(config_patch: ConfigPatch) -> dict[str, Any]:
    """Return a JSON-friendly copy of a config patch's env/flag overrides."""
    return {
        "env_patch": dict(config_patch.env_patch),
        "flag_patch": dict(config_patch.flag_patch),
    }


def _runtime_payload(runtime: RuntimeOverride) -> dict[str, Any]:
    """Return a JSON-friendly copy of a candidate's runtime override."""
    return {
        "cuda_visible_devices": runtime.cuda_visible_devices,
        "port": runtime.port,
        "host": runtime.host,
    }


def _load_config_patch(
    candidate: MultiCompareCandidate,
    *,
    spec_path: Path,
) -> tuple[ConfigPatch, dict[str, Any]]:
    """Resolve a candidate's config patch and describe where it came from.

    Either the inline ``config_patch`` is used, or the ``config_patch`` stored in
    ``<study_root>/trials/<trial_id>/trial_spec.json`` for a ``trial_ref``
    (``study_root`` resolved relative to the compare spec's directory).

    Returns:
        ``(config_patch, source)`` where ``source`` is a JSON-friendly record.

    Raises:
        SpecError: if neither source is set or the referenced trial spec is missing.
    """
    compare_candidate = candidate.candidate
    if compare_candidate.config_patch is not None:
        config_patch = compare_candidate.config_patch
        return config_patch, {
            "kind": "config_patch",
            "config_patch": _patch_payload(config_patch),
        }
    trial_ref = compare_candidate.trial_ref
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if trial_ref is None:
        raise SpecError(f"candidate {candidate.name!r} must define config_patch or trial_ref.")
    study_root = _resolve_path(trial_ref.study_root, base_dir=spec_path.parent)
    trial_spec_path = study_root / "trials" / trial_ref.trial_id / "trial_spec.json"
    if not trial_spec_path.exists():
        raise SpecError(f"trial_ref target not found: {trial_spec_path}")
    payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
    config_patch = ConfigPatch.from_dict(payload.get("config_patch") or {})
    return config_patch, {
        "kind": "trial_ref",
        "study_root": str(study_root),
        "trial_id": trial_ref.trial_id,
        "config_patch": _patch_payload(config_patch),
    }


def _parse_int_like(value: Any, *, default: int = 1) -> int:
    """Parse a topology flag value (None, int, integral float or numeric string).

    Raises:
        SpecError: for bools, non-integral floats, unparsable strings or other types.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        raise SpecError("Topology values must be integers.")
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str) and value.strip():
        try:
            return int(value.strip())
        except ValueError as exc:
            raise SpecError(f"Unable to parse integer topology value: {value!r}") from exc
    raise SpecError(f"Unable to parse integer topology value: {value!r}")


def _parallel_size_for_candidate(*, study: Any, patch: ConfigPatch) -> int:
    """Return the number of GPUs a candidate's engine spans.

    Computed as tensor x pipeline x data parallel size from the study's base
    flags overlaid with the patch's ``flag_patch``; absent flags count as 1.
    """
    flags = dict(study.engine.base_flags)
    flags.update(patch.flag_patch)
    tp = _parse_int_like(flags.get("tensor-parallel-size"), default=1)
    pp = _parse_int_like(flags.get("pipeline-parallel-size"), default=1)
    dp = _parse_int_like(flags.get("data-parallel-size"), default=1)
    return tp * pp * dp


def _trial_snapshot(trial: TrialSpec) -> dict[str, Any]:
    """Return a JSON-serializable snapshot of a trial spec."""
    return to_jsonable(trial)


def _study_snapshot(study: Any) -> dict[str, Any]:
    """Return a JSON-serializable snapshot of a study spec."""
    return to_jsonable(study)


def _summarize_run(
    *,
    run_root: Path,
    result_path: Path,
    study: Any,
    candidate: MultiCompareCandidate,
    config_patch: ConfigPatch,
    source: dict[str, Any],
    result: dict[str, Any],
    resumed: bool,
) -> dict[str, Any]:
    """Build the per-(window, candidate) summary row from a trial result dict."""
    parallel_size = _parallel_size_for_candidate(study=study, patch=config_patch)
    best_rate = result.get("best_request_rate")
    best_rate_per_gpu = (
        float(best_rate) / float(parallel_size)
        if isinstance(best_rate, (int, float)) and parallel_size > 0
        else None
    )
    return {
        "candidate": candidate.name,
        "source": source,
        "parallel_size": parallel_size,
        "runtime": _runtime_payload(candidate.runtime),
        "config_patch": _patch_payload(config_patch),
        "status": result.get("status"),
        "best_sampling_u": result.get("best_sampling_u"),
        "best_request_rate": best_rate,
        "best_request_rate_per_gpu": best_rate_per_gpu,
        "best_pass_rate": result.get("best_pass_rate"),
        "best_request_count": result.get("best_request_count"),
        "failure_stage": result.get("failure_stage"),
        "failure_reason": result.get("failure_reason"),
        "artifact_dir": str(run_root),
        "result_path": str(result_path),
        "probe_log_path": str(run_root / "probe_history.json"),
        "engine_log_path": str(run_root / "engine.log"),
        "resumed": resumed,
    }


def _run_candidate_for_window(
    *,
    compare_id: str,
    compare_root: Path,
    study: Any,
    study_spec_path: Path,
    window_id: str,
    candidate: MultiCompareCandidate,
    config_patch: ConfigPatch,
    source: dict[str, Any],
) -> dict[str, Any]:
    """Run (or resume) one candidate on one window and return its summary row.

    Artifacts go under ``<compare_root>/runs/<window_id>/<candidate.name>``. If
    ``result.json`` already exists there, it is reused and the trial is not run
    again (``resumed=True``). ``study_spec_path`` is accepted but unused.
    """
    run_root = compare_root / "runs" / window_id / candidate.name
    run_root.mkdir(parents=True, exist_ok=True)
    result_path = run_root / "result.json"
    if result_path.exists():
        result = json.loads(result_path.read_text(encoding="utf-8"))
        return _summarize_run(
            run_root=run_root,
            result_path=result_path,
            study=study,
            candidate=candidate,
            config_patch=config_patch,
            source=source,
            result=result,
            resumed=True,
        )

    # Pin this candidate's GPUs, endpoint and window onto a copy of the study.
    engine_envs = dict(study.engine.base_envs)
    engine_envs["CUDA_VISIBLE_DEVICES"] = candidate.runtime.cuda_visible_devices
    engine_flags = dict(study.engine.base_flags)
    engine_flags["port"] = candidate.runtime.port
    runtime_study = replace(
        study,
        trace=replace(study.trace, window_id=window_id),
        engine=replace(
            study.engine,
            host=candidate.runtime.host,
            port=candidate.runtime.port,
            base_envs=engine_envs,
            base_flags=engine_flags,
        ),
    )
    actual_study_path = run_root / "study_spec.json"
    source_path = run_root / "study_spec.source"
    actual_study_path.write_text(
        json.dumps(_study_snapshot(runtime_study), ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    # The trial's study_spec_path points at this one-line file holding the path of
    # the materialized study spec above (how run_trial consumes it is not shown here).
    source_path.write_text(str(actual_study_path) + "\n", encoding="utf-8")
    trial = TrialSpec(
        study_id=compare_id,
        trial_id=candidate.name,
        config_patch=config_patch,
        search=runtime_study.search,
        study_spec_path=str(source_path),
        artifact_dir=str(run_root),
        probe_log_path=str(run_root / "probe_history.json"),
        engine_log_path=str(run_root / "engine.log"),
        result_path=str(result_path),
    )
    trial_spec_path = run_root / "trial_spec.json"
    StudyStore.write_json(trial_spec_path, _trial_snapshot(trial))
    result = run_trial(trial_spec_path)
    return _summarize_run(
        run_root=run_root,
        result_path=result_path,
        study=runtime_study,
        candidate=candidate,
        config_patch=config_patch,
        source=source,
        result=result,
        resumed=False,
    )


def _winner(candidates: dict[str, dict[str, Any]]) -> str:
    """Return the candidate with the highest req/s/gpu, ``"tie"`` or ``"incomparable"``.

    ``"incomparable"`` means no candidate has a numeric req/s/gpu; ``"tie"``
    means the top two scores are exactly equal.
    """
    scored = [
        (name, float(result["best_request_rate_per_gpu"]))
        for name, result in candidates.items()
        if isinstance(result.get("best_request_rate_per_gpu"), (int, float))
    ]
    if not scored:
        return "incomparable"
    scored.sort(key=lambda item: item[1], reverse=True)
    if len(scored) > 1 and scored[0][1] == scored[1][1]:
        return "tie"
    return scored[0][0]


def _mean_of(rows: list[dict[str, Any]], name: str, key: str) -> float | None:
    """Mean of numeric ``row["candidates"][name][key]`` values; None if there are none."""
    values = [
        float(row["candidates"][name][key])
        for row in rows
        if isinstance(row["candidates"][name].get(key), (int, float))
    ]
    return (sum(values) / len(values)) if values else None


def _aggregate(rows: list[dict[str, Any]], candidates: list[MultiCompareCandidate]) -> dict[str, Any]:
    """Aggregate per-window rows into win counts and per-candidate means/counts."""
    candidate_names = [item.name for item in candidates]
    wins = {name: 0 for name in candidate_names}
    wins["tie"] = 0
    wins["incomparable"] = 0
    means: dict[str, dict[str, Any]] = {}
    for name in candidate_names:
        means[name] = {
            "mean_request_rate": _mean_of(rows, name, "best_request_rate"),
            "mean_request_rate_per_gpu": _mean_of(rows, name, "best_request_rate_per_gpu"),
            "mean_pass_rate": _mean_of(rows, name, "best_pass_rate"),
            **_candidate_result_counts(rows, name),
        }
    for row in rows:
        wins[row["winner"]] = wins.get(row["winner"], 0) + 1
    return {
        "window_count": len(rows),
        "wins": wins,
        "candidates": means,
    }


def _candidate_result_counts(rows: list[dict[str, Any]], name: str) -> dict[str, int]:
    """Count a candidate's completed, failed and no-feasible windows.

    "no_feasible" counts every window without a numeric req/s/gpu, regardless
    of status, so it can overlap with the failed count.
    """
    counts = {
        "completed_window_count": 0,
        "failed_window_count": 0,
        "no_feasible_window_count": 0,
    }
    for row in rows:
        result = row.get("candidates", {}).get(name)
        if not isinstance(result, dict):
            continue
        status = str(result.get("status") or "")
        if status == "completed":
            counts["completed_window_count"] += 1
        elif status == "failed":
            counts["failed_window_count"] += 1
        if not isinstance(result.get("best_request_rate_per_gpu"), (int, float)):
            counts["no_feasible_window_count"] += 1
    return counts


def _render_report(summary: dict[str, Any], candidates: list[MultiCompareCandidate]) -> str:
    """Render the compare summary as a Markdown report."""
    candidate_names = [item.name for item in candidates]
    lines = [
        f"# {summary['compare_id']}",
        "",
        "## Setup",
        "",
        f"- Study spec: `{summary['study_spec_path']}`",
        f"- Compare root: `{summary['compare_root']}`",
        f"- Windows: `{len(summary['windows'])}`",
        "",
        "## Candidates",
        "",
    ]
    for item in candidates:
        lines.append(
            f"- `{item.name}`: phase=`{item.phase}`, gpus=`{item.runtime.cuda_visible_devices}`, port=`{item.runtime.port}`"
        )
    lines.extend(
        [
            "",
            "## Aggregate",
            "",
            f"- Wins: `{json.dumps(summary['aggregate']['wins'], ensure_ascii=False)}`",
        ]
    )
    for name in candidate_names:
        aggregate = summary["aggregate"]["candidates"][name]
        lines.append(
            f"- `{name}` mean req/s=`{aggregate['mean_request_rate']}`, mean req/s/gpu=`{aggregate['mean_request_rate_per_gpu']}`, mean pass_rate=`{aggregate['mean_pass_rate']}`"
        )
        lines.append(
            f" completed/failed/no-feasible windows=`{aggregate['completed_window_count']}`/`{aggregate['failed_window_count']}`/`{aggregate['no_feasible_window_count']}`"
        )
    header = ["Window", "Date"]
    for name in candidate_names:
        header.extend([f"{name} req/s", f"{name} req/s/gpu"])
    header.append("Winner")
    lines.extend(
        [
            "",
            "## Per Window",
            "",
            "| " + " | ".join(header) + " |",
            "| " + " | ".join(["---"] * len(header)) + " |",
        ]
    )
    for row in summary["windows"]:
        cells = [f"`{row['window_id']}`", f"`{row.get('date') or ''}`"]
        for name in candidate_names:
            candidate = row["candidates"][name]
            cells.append(f"`{candidate.get('best_request_rate')}`")
            cells.append(f"`{candidate.get('best_request_rate_per_gpu')}`")
        cells.append(f"`{row['winner']}`")
        lines.append("| " + " | ".join(cells) + " |")
    lines.append("")
    return "\n".join(lines)


def _run_phase(
    *,
    spec: MultiCompareSpec,
    compare_root: Path,
    study: Any,
    study_spec_path: Path,
    window_id: str,
    phase_candidates: list[MultiCompareCandidate],
    resolved_candidates: dict[str, tuple[ConfigPatch, dict[str, Any]]],
) -> dict[str, dict[str, Any]]:
    """Run one phase's candidates concurrently on a window; results keyed by name in spec order."""
    results: dict[str, dict[str, Any]] = {}
    with ThreadPoolExecutor(max_workers=len(phase_candidates)) as executor:
        future_map = {
            executor.submit(
                _run_candidate_for_window,
                compare_id=spec.compare_id,
                compare_root=compare_root,
                study=study,
                study_spec_path=study_spec_path,
                window_id=window_id,
                candidate=item,
                config_patch=resolved_candidates[item.name][0],
                source=resolved_candidates[item.name][1],
            ): item.name
            for item in phase_candidates
        }
        for future in as_completed(future_map):
            result = future.result()
            results[result["candidate"]] = result
    # Re-order by spec order so summary.json/report.md are deterministic.
    return {item.name: results[item.name] for item in phase_candidates}


def _build_summary(
    spec: MultiCompareSpec,
    *,
    study_spec_path: Path,
    compare_root: Path,
    per_window: list[dict[str, Any]],
) -> dict[str, Any]:
    """Assemble the summary dict for the windows processed so far."""
    return {
        "compare_id": spec.compare_id,
        "study_spec_path": str(study_spec_path),
        "compare_root": str(compare_root),
        "windows": per_window,
        "aggregate": _aggregate(per_window, spec.candidates),
    }


def _write_summary(
    compare_root: Path,
    summary: dict[str, Any],
    candidates: list[MultiCompareCandidate],
) -> None:
    """Write ``summary.json`` and ``report.md`` under ``compare_root``."""
    StudyStore.write_json(compare_root / "summary.json", summary)
    (compare_root / "report.md").write_text(
        _render_report(summary, candidates),
        encoding="utf-8",
    )


def run_multi_compare(spec_path: Path) -> dict[str, Any]:
    """Run every candidate over every selected window and return the summary.

    Candidates are grouped by phase. Phases run in ascending order, and the
    candidates within a phase run concurrently. ``summary.json`` and
    ``report.md`` are rewritten after each window so partial progress is kept.
    """
    spec_path = spec_path.resolve()
    spec = MultiCompareSpec.from_dict(dict(load_structured_file(spec_path)))
    study_spec_path = _resolve_path(spec.study_spec_path, base_dir=spec_path.parent)
    study = load_study_spec(study_spec_path)
    compare_root = (
        _resolve_path(spec.output_root, base_dir=spec_path.parent)
        if spec.output_root
        else (Path(".aituner-compare") / spec.compare_id).resolve()
    )
    compare_root.mkdir(parents=True, exist_ok=True)
    windows = _select_windows(spec, study=study, study_spec_path=study_spec_path)

    resolved_candidates: dict[str, tuple[ConfigPatch, dict[str, Any]]] = {}
    candidate_payloads: list[dict[str, Any]] = []
    for candidate in spec.candidates:
        config_patch, source = _load_config_patch(candidate, spec_path=spec_path)
        resolved_candidates[candidate.name] = (config_patch, source)
        candidate_payloads.append(
            {
                "name": candidate.name,
                "phase": candidate.phase,
                "runtime": _runtime_payload(candidate.runtime),
                "source": source,
            }
        )
    snapshot = {
        "compare_id": spec.compare_id,
        "study_spec_path": str(study_spec_path),
        "window_ids": spec.window_ids,
        "candidates": candidate_payloads,
    }
    StudyStore.write_json(compare_root / "compare_spec.snapshot.json", snapshot)

    phases = sorted({item.phase for item in spec.candidates})
    candidates_by_phase = {
        phase: [item for item in spec.candidates if item.phase == phase] for phase in phases
    }
    per_window: list[dict[str, Any]] = []
    for window in windows:
        window_id = str(window["window_id"])
        row: dict[str, Any] = {
            "window_id": window_id,
            "trace_type": window.get("trace_type"),
            "date": window.get("date"),
            "slot_token": window.get("slot_token"),
            "slot_label": window.get("slot_label"),
            "window_start": window.get("window_start"),
            "window_end": window.get("window_end"),
            "candidates": {},
        }
        for phase in phases:
            row["candidates"].update(
                _run_phase(
                    spec=spec,
                    compare_root=compare_root,
                    study=study,
                    study_spec_path=study_spec_path,
                    window_id=window_id,
                    phase_candidates=candidates_by_phase[phase],
                    resolved_candidates=resolved_candidates,
                )
            )
        row["winner"] = _winner(row["candidates"])
        per_window.append(row)
        # Checkpoint after every window so an interrupted run keeps its results.
        partial_summary = _build_summary(
            spec,
            study_spec_path=study_spec_path,
            compare_root=compare_root,
            per_window=per_window,
        )
        _write_summary(compare_root, partial_summary, spec.candidates)

    summary = _build_summary(
        spec,
        study_spec_path=study_spec_path,
        compare_root=compare_root,
        per_window=per_window,
    )
    _write_summary(compare_root, summary, spec.candidates)
    return summary


def main() -> int:
    """CLI entry point: ``--spec PATH``; prints a short JSON summary and returns 0."""
    parser = argparse.ArgumentParser(description="Run a multi-candidate compare over trace windows.")
    parser.add_argument("--spec", required=True)
    args = parser.parse_args()
    summary = run_multi_compare(Path(args.spec))
    print(
        json.dumps(
            {
                "compare_id": summary["compare_id"],
                "compare_root": summary["compare_root"],
                "window_count": summary["aggregate"]["window_count"],
                "wins": summary["aggregate"]["wins"],
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())