compare: add multi-candidate runner
.gitignore (vendored) | 2 ++
@@ -2,6 +2,8 @@
.aituner-smoke/
.aituner-decode/
.aituner-tight/
.aituner-prefill/
.aituner-compare/
.env
__pycache__/
*.pyc

@@ -0,0 +1,52 @@
{
  "compare_id": "dash1-qwen235b-prefill-thinking-7day-baseline-vs-0323-vs-0327",
  "study_spec_path": "dash0_qwen235b_prefill_thinking_run3_ttft_tight_0323.json",
  "output_root": "../../.aituner-compare/dash1-qwen235b-prefill-thinking-7day-baseline-vs-0323-vs-0327",
  "window_ids": [
    "thinking_w20260321_1000",
    "thinking_w20260322_1000",
    "thinking_w20260323_1000",
    "thinking_w20260324_1000",
    "thinking_w20260325_1000",
    "thinking_w20260326_1000",
    "thinking_w20260327_1000"
  ],
  "candidates": [
    {
      "name": "baseline",
      "phase": 1,
      "config_patch": {
        "env_patch": {},
        "flag_patch": {}
      },
      "runtime": {
        "cuda_visible_devices": "0,1,2,3",
        "port": 18141
      }
    },
    {
      "name": "tuned_0323",
      "phase": 1,
      "trial_ref": {
        "study_root": "../../.aituner-prefill/dash0-qwen235b-prefill-thinking-run3-ttft-tight-0323-topology",
        "trial_id": "trial-0006"
      },
      "runtime": {
        "cuda_visible_devices": "4,5,6,7",
        "port": 18142
      }
    },
    {
      "name": "tuned_0327",
      "phase": 2,
      "trial_ref": {
        "study_root": "../../.aituner-prefill/dash0-qwen235b-prefill-thinking-run2-ttft-tight-topology",
        "trial_id": "trial-0012"
      },
      "runtime": {
        "cuda_visible_devices": "0,1,2,3,4,5,6,7",
        "port": 18143
      }
    }
  ]
}
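
A spec like the one above drives the runner added below. Judging from the argparse interface in scripts/run_multi_compare.py, the invocation is simply (spec path hypothetical):

    python scripts/run_multi_compare.py --spec path/to/compare_spec.json

The runner resolves study_spec_path and each trial_ref study_root relative to the spec file's directory, and re-running the same spec skips any window/candidate pair whose result.json already exists.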

scripts/run_multi_compare.py (new file) | 563
@@ -0,0 +1,563 @@
#!/usr/bin/env python3
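"""Multi-candidate compare runner.

Runs every candidate in a compare spec against a list of trace windows,
phase by phase (candidates in the same phase run concurrently, each with
its own CUDA_VISIBLE_DEVICES and port), then writes summary.json and a
Markdown report under the compare root.
"""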
from __future__ import annotations

import argparse
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Any

from aituner.spec import (
    CompareCandidateSpec,
    ConfigPatch,
    SpecError,
    TrialSpec,
    load_study_spec,
    load_structured_file,
    to_jsonable,
)
from aituner.store import StudyStore
from aituner.worker import run_trial


@dataclass(frozen=True)
class RuntimeOverride:
    cuda_visible_devices: str
    port: int
    host: str = "127.0.0.1"

    @classmethod
    def from_dict(cls, data: dict[str, Any], *, context: str) -> "RuntimeOverride":
        cuda_visible_devices = str(data.get("cuda_visible_devices") or "").strip()
        if not cuda_visible_devices:
            raise SpecError(f"{context}.cuda_visible_devices must be a non-empty string.")
        port_value = data.get("port")
        if isinstance(port_value, bool) or not isinstance(port_value, int):
            raise SpecError(f"{context}.port must be an integer.")
        host = str(data.get("host") or "127.0.0.1").strip()
        if not host:
            raise SpecError(f"{context}.host must be a non-empty string.")
        return cls(
            cuda_visible_devices=cuda_visible_devices,
            port=port_value,
            host=host,
        )


@dataclass(frozen=True)
class MultiCompareCandidate:
    name: str
    phase: int
    candidate: CompareCandidateSpec
    runtime: RuntimeOverride

    @classmethod
    def from_dict(cls, data: dict[str, Any], *, context: str) -> "MultiCompareCandidate":
        name = str(data.get("name") or "").strip()
        if not name:
            raise SpecError(f"{context}.name must be a non-empty string.")
        phase_value = data.get("phase", 1)
        if isinstance(phase_value, bool) or not isinstance(phase_value, int) or phase_value < 1:
            raise SpecError(f"{context}.phase must be a positive integer.")
        candidate = CompareCandidateSpec.from_dict(data, context=context)
        runtime = RuntimeOverride.from_dict(
            dict(data.get("runtime") or {}),
            context=f"{context}.runtime",
        )
        return cls(name=name, phase=phase_value, candidate=candidate, runtime=runtime)


@dataclass(frozen=True)
class MultiCompareSpec:
    compare_id: str
    study_spec_path: str
    output_root: str | None
    window_ids: list[str]
    candidates: list[MultiCompareCandidate]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiCompareSpec":
        compare_id = str(data.get("compare_id") or "").strip()
        if not compare_id:
            raise SpecError("compare_id must be a non-empty string.")
        study_spec_path = str(data.get("study_spec_path") or "").strip()
        if not study_spec_path:
            raise SpecError("study_spec_path must be a non-empty string.")
        raw_window_ids = data.get("window_ids")
        if not isinstance(raw_window_ids, list) or not raw_window_ids:
            raise SpecError("window_ids must be a non-empty list.")
        window_ids = [str(item).strip() for item in raw_window_ids if str(item).strip()]
        if not window_ids:
            raise SpecError("window_ids must contain at least one non-empty string.")
        raw_candidates = data.get("candidates")
        if not isinstance(raw_candidates, list) or not raw_candidates:
            raise SpecError("candidates must be a non-empty list.")
        candidates = [
            MultiCompareCandidate.from_dict(dict(item), context=f"candidates[{idx}]")
            for idx, item in enumerate(raw_candidates)
            if isinstance(item, dict)
        ]
        if len(candidates) != len(raw_candidates):
            raise SpecError("Every candidates entry must be an object.")
        names = [item.name for item in candidates]
        if len(names) != len(set(names)):
            raise SpecError("Candidate names must be unique.")
        return cls(
            compare_id=compare_id,
            study_spec_path=study_spec_path,
            output_root=str(data.get("output_root") or "").strip() or None,
            window_ids=window_ids,
            candidates=candidates,
        )


def _resolve_path(raw_path: str, *, base_dir: Path) -> Path:
    path = Path(raw_path)
    if not path.is_absolute():
        path = (base_dir / path).resolve()
    return path


def _load_windows_payload(study: Any, *, study_spec_path: Path) -> list[dict[str, Any]]:
    windows_path = Path(study.trace.windows_path)
    if not windows_path.is_absolute():
        windows_path = (study_spec_path.parent / windows_path).resolve()
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    raw_windows = payload.get("windows") if isinstance(payload, dict) else payload
    if not isinstance(raw_windows, list):
        raise SpecError(f"windows payload must contain a list: {windows_path}")
    return [
        {str(key): value for key, value in item.items()}
        for item in raw_windows
        if isinstance(item, dict)
    ]


def _select_windows(spec: MultiCompareSpec, *, study: Any, study_spec_path: Path) -> list[dict[str, Any]]:
    windows = _load_windows_payload(study, study_spec_path=study_spec_path)
    indexed = {str(item.get("window_id") or "").strip(): item for item in windows}
    selected: list[dict[str, Any]] = []
    for window_id in spec.window_ids:
        item = indexed.get(window_id)
        if item is None:
            raise SpecError(f"window_id not found in windows payload: {window_id}")
        selected.append(item)
    return selected


def _load_config_patch(
    candidate: MultiCompareCandidate,
    *,
    spec_path: Path,
) -> tuple[ConfigPatch, dict[str, Any]]:
    if candidate.candidate.config_patch is not None:
        config_patch = candidate.candidate.config_patch
        return config_patch, {
            "kind": "config_patch",
            "config_patch": {
                "env_patch": dict(config_patch.env_patch),
                "flag_patch": dict(config_patch.flag_patch),
            },
        }
    assert candidate.candidate.trial_ref is not None
    study_root = _resolve_path(candidate.candidate.trial_ref.study_root, base_dir=spec_path.parent)
    trial_spec_path = study_root / "trials" / candidate.candidate.trial_ref.trial_id / "trial_spec.json"
    if not trial_spec_path.exists():
        raise SpecError(f"trial_ref target not found: {trial_spec_path}")
    payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
    config_patch = ConfigPatch.from_dict(payload.get("config_patch") or {})
    return config_patch, {
        "kind": "trial_ref",
        "study_root": str(study_root),
        "trial_id": candidate.candidate.trial_ref.trial_id,
        "config_patch": {
            "env_patch": dict(config_patch.env_patch),
            "flag_patch": dict(config_patch.flag_patch),
        },
    }
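

# Resolution sketch (illustrative): for the "tuned_0323" candidate in the spec
# above, the patch is read from
#   <study_root>/trials/trial-0006/trial_spec.json
# under its "config_patch" key; a missing file raises SpecError.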


def _parse_int_like(value: Any, *, default: int = 1) -> int:
    if value is None:
        return default
    if isinstance(value, bool):
        raise SpecError("Topology values must be integers.")
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str) and value.strip():
        return int(value.strip())
    raise SpecError(f"Unable to parse integer topology value: {value!r}")


def _parallel_size_for_candidate(*, study: Any, patch: ConfigPatch) -> int:
    flags = dict(study.engine.base_flags)
    flags.update(patch.flag_patch)
    tp = _parse_int_like(flags.get("tensor-parallel-size"), default=1)
    dp = _parse_int_like(flags.get("data-parallel-size"), default=1)
    return tp * dp
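

# Worked example (hypothetical flag values): base_flags {"tensor-parallel-size": "4"}
# patched with {"data-parallel-size": 2} yields tp=4, dp=2, parallel size 4 * 2 = 8;
# a best_request_rate of 12.0 then reports as 12.0 / 8 = 1.5 req/s per GPU.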


def _trial_snapshot(trial: TrialSpec) -> dict[str, Any]:
    return to_jsonable(trial)


def _study_snapshot(study: Any) -> dict[str, Any]:
    return to_jsonable(study)


def _run_candidate_for_window(
    *,
    compare_id: str,
    compare_root: Path,
    study: Any,
    study_spec_path: Path,
    window_id: str,
    candidate: MultiCompareCandidate,
    config_patch: ConfigPatch,
    source: dict[str, Any],
) -> dict[str, Any]:
    run_root = compare_root / "runs" / window_id / candidate.name
    run_root.mkdir(parents=True, exist_ok=True)
    result_path = run_root / "result.json"
    if result_path.exists():
        result = json.loads(result_path.read_text(encoding="utf-8"))
        parallel_size = _parallel_size_for_candidate(study=study, patch=config_patch)
        best_rate = result.get("best_request_rate")
        best_rate_per_gpu = (
            float(best_rate) / float(parallel_size)
            if isinstance(best_rate, (int, float)) and parallel_size > 0
            else None
        )
        return {
            "candidate": candidate.name,
            "source": source,
            "parallel_size": parallel_size,
            "runtime": {
                "cuda_visible_devices": candidate.runtime.cuda_visible_devices,
                "port": candidate.runtime.port,
                "host": candidate.runtime.host,
            },
            "config_patch": {
                "env_patch": dict(config_patch.env_patch),
                "flag_patch": dict(config_patch.flag_patch),
            },
            "status": result.get("status"),
            "best_sampling_u": result.get("best_sampling_u"),
            "best_request_rate": best_rate,
            "best_request_rate_per_gpu": best_rate_per_gpu,
            "best_pass_rate": result.get("best_pass_rate"),
            "best_request_count": result.get("best_request_count"),
            "failure_stage": result.get("failure_stage"),
            "failure_reason": result.get("failure_reason"),
            "artifact_dir": str(run_root),
            "result_path": str(result_path),
            "probe_log_path": str(run_root / "probe_history.json"),
            "engine_log_path": str(run_root / "engine.log"),
            "resumed": True,
        }

    engine_envs = dict(study.engine.base_envs)
    engine_envs["CUDA_VISIBLE_DEVICES"] = candidate.runtime.cuda_visible_devices
    engine_flags = dict(study.engine.base_flags)
    engine_flags["port"] = candidate.runtime.port
    runtime_study = replace(
        study,
        trace=replace(study.trace, window_id=window_id),
        engine=replace(
            study.engine,
            host=candidate.runtime.host,
            port=candidate.runtime.port,
            base_envs=engine_envs,
            base_flags=engine_flags,
        ),
    )
    actual_study_path = run_root / "study_spec.json"
    source_path = run_root / "study_spec.source"
    actual_study_path.write_text(
        json.dumps(_study_snapshot(runtime_study), ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    source_path.write_text(str(actual_study_path) + "\n", encoding="utf-8")
    trial = TrialSpec(
        study_id=compare_id,
        trial_id=candidate.name,
        config_patch=config_patch,
        search=runtime_study.search,
        study_spec_path=str(source_path),
        artifact_dir=str(run_root),
        probe_log_path=str(run_root / "probe_history.json"),
        engine_log_path=str(run_root / "engine.log"),
        result_path=str(result_path),
    )
    StudyStore.write_json(run_root / "trial_spec.json", _trial_snapshot(trial))
    result = run_trial(run_root / "trial_spec.json")
    parallel_size = _parallel_size_for_candidate(study=runtime_study, patch=config_patch)
    best_rate = result.get("best_request_rate")
    best_rate_per_gpu = (
        float(best_rate) / float(parallel_size)
        if isinstance(best_rate, (int, float)) and parallel_size > 0
        else None
    )
    return {
        "candidate": candidate.name,
        "source": source,
        "parallel_size": parallel_size,
        "runtime": {
            "cuda_visible_devices": candidate.runtime.cuda_visible_devices,
            "port": candidate.runtime.port,
            "host": candidate.runtime.host,
        },
        "config_patch": {
            "env_patch": dict(config_patch.env_patch),
            "flag_patch": dict(config_patch.flag_patch),
        },
        "status": result.get("status"),
        "best_sampling_u": result.get("best_sampling_u"),
        "best_request_rate": best_rate,
        "best_request_rate_per_gpu": best_rate_per_gpu,
        "best_pass_rate": result.get("best_pass_rate"),
        "best_request_count": result.get("best_request_count"),
        "failure_stage": result.get("failure_stage"),
        "failure_reason": result.get("failure_reason"),
        "artifact_dir": str(run_root),
        "result_path": str(result_path),
        "probe_log_path": str(run_root / "probe_history.json"),
        "engine_log_path": str(run_root / "engine.log"),
        "resumed": False,
    }


def _winner(candidates: dict[str, dict[str, Any]]) -> str:
    scored = [
        (name, float(result["best_request_rate_per_gpu"]))
        for name, result in candidates.items()
        if isinstance(result.get("best_request_rate_per_gpu"), (int, float))
    ]
    if not scored:
        return "incomparable"
    scored.sort(key=lambda item: item[1], reverse=True)
    if len(scored) > 1 and scored[0][1] == scored[1][1]:
        return "tie"
    return scored[0][0]
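

# Semantics sketch (hypothetical inputs): if no candidate has a numeric
# best_request_rate_per_gpu the window is "incomparable"; if the top two
# per-GPU rates are exactly equal it is a "tie"; otherwise the candidate
# with the highest per-GPU rate wins.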


def _aggregate(rows: list[dict[str, Any]], candidates: list[MultiCompareCandidate]) -> dict[str, Any]:
    candidate_names = [item.name for item in candidates]
    wins = {name: 0 for name in candidate_names}
    wins["tie"] = 0
    wins["incomparable"] = 0
    means: dict[str, dict[str, Any]] = {}
    for name in candidate_names:
        rates = [
            float(row["candidates"][name]["best_request_rate"])
            for row in rows
            if isinstance(row["candidates"][name].get("best_request_rate"), (int, float))
        ]
        rates_per_gpu = [
            float(row["candidates"][name]["best_request_rate_per_gpu"])
            for row in rows
            if isinstance(row["candidates"][name].get("best_request_rate_per_gpu"), (int, float))
        ]
        pass_rates = [
            float(row["candidates"][name]["best_pass_rate"])
            for row in rows
            if isinstance(row["candidates"][name].get("best_pass_rate"), (int, float))
        ]
        means[name] = {
            "mean_request_rate": (sum(rates) / len(rates)) if rates else None,
            "mean_request_rate_per_gpu": (sum(rates_per_gpu) / len(rates_per_gpu))
            if rates_per_gpu
            else None,
            "mean_pass_rate": (sum(pass_rates) / len(pass_rates)) if pass_rates else None,
        }
    for row in rows:
        wins[row["winner"]] = wins.get(row["winner"], 0) + 1
    return {
        "window_count": len(rows),
        "wins": wins,
        "candidates": means,
    }
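

# Shape sketch (counts invented for illustration):
#   {"window_count": 7,
#    "wins": {"baseline": 2, "tuned_0323": 4, "tuned_0327": 0, "tie": 1, "incomparable": 0},
#    "candidates": {"baseline": {"mean_request_rate": ..., ...}, ...}}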


def _render_report(summary: dict[str, Any], candidates: list[MultiCompareCandidate]) -> str:
    candidate_names = [item.name for item in candidates]
    lines = [
        f"# {summary['compare_id']}",
        "",
        "## Setup",
        "",
        f"- Study spec: `{summary['study_spec_path']}`",
        f"- Compare root: `{summary['compare_root']}`",
        f"- Windows: `{len(summary['windows'])}`",
        "",
        "## Candidates",
        "",
    ]
    for item in candidates:
        lines.append(
            f"- `{item.name}`: phase=`{item.phase}`, gpus=`{item.runtime.cuda_visible_devices}`, port=`{item.runtime.port}`"
        )
    lines.extend(
        [
            "",
            "## Aggregate",
            "",
            f"- Wins: `{json.dumps(summary['aggregate']['wins'], ensure_ascii=False)}`",
        ]
    )
    for name in candidate_names:
        aggregate = summary["aggregate"]["candidates"][name]
        lines.append(
            f"- `{name}` mean req/s=`{aggregate['mean_request_rate']}`, mean req/s/gpu=`{aggregate['mean_request_rate_per_gpu']}`, mean pass_rate=`{aggregate['mean_pass_rate']}`"
        )
    header = ["Window", "Date"]
    for name in candidate_names:
        header.extend([f"{name} req/s", f"{name} req/s/gpu"])
    header.append("Winner")
    lines.extend(
        [
            "",
            "## Per Window",
            "",
            "| " + " | ".join(header) + " |",
            "| " + " | ".join(["---"] * len(header)) + " |",
        ]
    )
    for row in summary["windows"]:
        cells = [f"`{row['window_id']}`", f"`{row.get('date') or ''}`"]
        for name in candidate_names:
            candidate = row["candidates"][name]
            cells.append(f"`{candidate.get('best_request_rate')}`")
            cells.append(f"`{candidate.get('best_request_rate_per_gpu')}`")
        cells.append(f"`{row['winner']}`")
        lines.append("| " + " | ".join(cells) + " |")
    lines.append("")
    return "\n".join(lines)


def run_multi_compare(spec_path: Path) -> dict[str, Any]:
    spec_path = spec_path.resolve()
    spec = MultiCompareSpec.from_dict(dict(load_structured_file(spec_path)))
    study_spec_path = _resolve_path(spec.study_spec_path, base_dir=spec_path.parent)
    study = load_study_spec(study_spec_path)
    compare_root = (
        _resolve_path(spec.output_root, base_dir=spec_path.parent)
        if spec.output_root
        else (Path(".aituner-compare") / spec.compare_id).resolve()
    )
    compare_root.mkdir(parents=True, exist_ok=True)
    windows = _select_windows(spec, study=study, study_spec_path=study_spec_path)
    candidate_payloads = []
    resolved_candidates: dict[str, tuple[MultiCompareCandidate, ConfigPatch, dict[str, Any]]] = {}
    for candidate in spec.candidates:
        config_patch, source = _load_config_patch(candidate, spec_path=spec_path)
        resolved_candidates[candidate.name] = (candidate, config_patch, source)
        candidate_payloads.append(
            {
                "name": candidate.name,
                "phase": candidate.phase,
                "runtime": {
                    "cuda_visible_devices": candidate.runtime.cuda_visible_devices,
                    "port": candidate.runtime.port,
                    "host": candidate.runtime.host,
                },
                "source": source,
            }
        )
    snapshot = {
        "compare_id": spec.compare_id,
        "study_spec_path": str(study_spec_path),
        "window_ids": spec.window_ids,
        "candidates": candidate_payloads,
    }
    StudyStore.write_json(compare_root / "compare_spec.snapshot.json", snapshot)
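
    # Phases run sequentially; candidates that share a phase run concurrently,
    # each submitted to its own thread with its own GPU set and engine port.
    # In the spec above, "baseline" (GPUs 0-3) and "tuned_0323" (GPUs 4-7) run
    # side by side in phase 1, then "tuned_0327" takes all eight GPUs in phase 2.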
    phases = sorted({item.phase for item in spec.candidates})
    per_window: list[dict[str, Any]] = []
    for window in windows:
        window_id = str(window["window_id"])
        row = {
            "window_id": window_id,
            "trace_type": window.get("trace_type"),
            "date": window.get("date"),
            "slot_token": window.get("slot_token"),
            "slot_label": window.get("slot_label"),
            "window_start": window.get("window_start"),
            "window_end": window.get("window_end"),
            "candidates": {},
        }
        for phase in phases:
            phase_candidates = [item for item in spec.candidates if item.phase == phase]
            with ThreadPoolExecutor(max_workers=len(phase_candidates)) as executor:
                future_map = {
                    executor.submit(
                        _run_candidate_for_window,
                        compare_id=spec.compare_id,
                        compare_root=compare_root,
                        study=study,
                        study_spec_path=study_spec_path,
                        window_id=window_id,
                        candidate=item,
                        config_patch=resolved_candidates[item.name][1],
                        source=resolved_candidates[item.name][2],
                    ): item.name
                    for item in phase_candidates
                }
                for future in as_completed(future_map):
                    result = future.result()
                    row["candidates"][result["candidate"]] = result
        row["winner"] = _winner(row["candidates"])
        per_window.append(row)
        partial_summary = {
            "compare_id": spec.compare_id,
            "study_spec_path": str(study_spec_path),
            "compare_root": str(compare_root),
            "windows": per_window,
            "aggregate": _aggregate(per_window, spec.candidates),
        }
        StudyStore.write_json(compare_root / "summary.json", partial_summary)
        (compare_root / "report.md").write_text(
            _render_report(partial_summary, spec.candidates),
            encoding="utf-8",
        )

    summary = {
        "compare_id": spec.compare_id,
        "study_spec_path": str(study_spec_path),
        "compare_root": str(compare_root),
        "windows": per_window,
        "aggregate": _aggregate(per_window, spec.candidates),
    }
    StudyStore.write_json(compare_root / "summary.json", summary)
    (compare_root / "report.md").write_text(
        _render_report(summary, spec.candidates),
        encoding="utf-8",
    )
    return summary


def main() -> int:
    parser = argparse.ArgumentParser(description="Run a multi-candidate compare over trace windows.")
    parser.add_argument("--spec", required=True)
    args = parser.parse_args()
    summary = run_multi_compare(Path(args.spec))
    print(
        json.dumps(
            {
                "compare_id": summary["compare_id"],
                "compare_root": summary["compare_root"],
                "window_count": summary["aggregate"]["window_count"],
                "wins": summary["aggregate"]["wins"],
            },
            ensure_ascii=False,
            indent=2,
        )
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
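
On completion, main() prints a compact digest of summary.json; for the three-candidate, seven-window spec above the output would look roughly like this (all counts invented for illustration):

    {
      "compare_id": "dash1-qwen235b-prefill-thinking-7day-baseline-vs-0323-vs-0327",
      "compare_root": "/.../.aituner-compare/dash1-qwen235b-prefill-thinking-7day-baseline-vs-0323-vs-0327",
      "window_count": 7,
      "wins": {
        "baseline": 2,
        "tuned_0323": 4,
        "tuned_0327": 1,
        "tie": 0,
        "incomparable": 0
      }
    }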