Files
aituner/scripts/ablation_trajectory.py

92 lines
3.2 KiB
Python

#!/usr/bin/env python3
"""Extract a per-iteration trajectory table from an ablation study store.
Usage: python3 ablation_trajectory.py <study_store_dir>
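Prints the study summary, then one row per trial: iter, trial id, status,
per_gpu, the running incumbent per_gpu, and a config_patch summary. Also lists
each proposal file with its should_stop flag and config summary. Read-only.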
"""
import json
import sys
from pathlib import Path
# Topology flag names paired with the short label used for each one in the
# one-line config summary produced by topo().
TOPOLOGY_KEYS = (
    ("tensor-parallel-size", "TP"),
    ("data-parallel-size", "DP"),
    ("expert-parallel-size", "EP"),
)
# Runtime flags that are always shown (in this order) when they are present in
# the effective flags, i.e. the base flags overlaid with the patch.
RUNTIME_KEYS = (
    "gpu-memory-utilization",
    "enable-chunked-prefill",
    "max-num-batched-tokens",
    "max-num-seqs",
)
# Flag names from TOPOLOGY_KEYS, built once so topo() does not rebuild the set
# for every patched flag it inspects.
_TOPOLOGY_FLAGS = frozenset(flag for flag, _ in TOPOLOGY_KEYS)


def topo(patch, base_flags=None):
    """Summarize a config patch as a one-line topology/runtime string.

    Args:
        patch: Mapping that may carry ``flag_patch`` and ``env_patch``
            sub-mappings. ``None`` and missing or null sub-mappings are
            treated as empty.
        base_flags: Optional baseline flags. ``flag_patch`` is overlaid on a
            copy of them to obtain the effective flags.

    Returns:
        The topology labels joined by ``+`` (or ``"baseline-topo"`` when no
        topology flag is set), followed by ``" | key=value, ..."`` when there
        is anything else to show. That tail lists the RUNTIME_KEYS present in
        the effective flags, then any other patched non-topology flags, then
        the environment overrides prefixed with ``env:``.
    """
    patch = patch or {}
    flag_patch = patch.get("flag_patch") or {}
    env_patch = patch.get("env_patch") or {}

    # Effective flags = baseline with the patch applied on top.
    effective = dict(base_flags or {})
    effective.update(flag_patch)

    topology = [
        f"{label}{effective[flag]}"
        for flag, label in TOPOLOGY_KEYS
        if flag in effective
    ]

    # Well-known runtime flags first (taken from the effective flags, so
    # baseline values are visible too), then whatever else the patch touches.
    shown = {flag: effective[flag] for flag in RUNTIME_KEYS if flag in effective}
    for flag, value in flag_patch.items():
        if flag not in _TOPOLOGY_FLAGS and flag not in shown:
            shown[flag] = value
    for name, value in env_patch.items():
        shown[f"env:{name}"] = value

    summary = "+".join(topology) if topology else "baseline-topo"
    if shown:
        summary += " | " + ", ".join(f"{k}={v}" for k, v in shown.items())
    return summary
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def main():
    """Print the trial trajectory and proposal list for a study store.

    Reads ``state.json`` (required) and ``study_spec.snapshot.json`` (optional,
    used only for ``engine.base_flags``) from the directory given as the first
    command-line argument, then prints a study summary, one table row per
    trial with the running best per_gpu, and one line per file under
    ``proposals/``.

    Raises:
        SystemExit: If no study store directory is given on the command line.
    """
    if len(sys.argv) < 2:
        raise SystemExit("usage: ablation_trajectory.py <study_store_dir>")
    store = Path(sys.argv[1])
    state = _load_json(store / "state.json")

    snapshot_path = store / "study_spec.snapshot.json"
    base_flags = {}
    if snapshot_path.exists():
        snapshot = _load_json(snapshot_path)
        base_flags = (snapshot.get("engine") or {}).get("base_flags") or {}

    print(f"study_id: {state.get('study_id')}")
    print(f"best_trial: {state.get('best_trial_id')} best_per_gpu: {state.get('best_request_rate_per_gpu')}")
    print(f"stop_reason: {state.get('tuning_stop_reason')!r}")
    print(f"stop_diagnosis: {state.get('tuning_stop_diagnosis')!r}")
    print(f"stop_details: {json.dumps(state.get('tuning_stop_details'), ensure_ascii=False)}")
    print()

    incumbent = None
    hdr = f"{'iter':<5}{'trial':<11}{'status':<14}{'per_gpu':<10}{'incumbent':<11}config"
    print(hdr)
    print("-" * len(hdr))
    # `or []` also covers an explicit null "trials" entry.
    for i, t in enumerate(state.get("trials") or [], 1):
        pg = t.get("best_request_rate_per_gpu")
        # Only numeric rates may move the incumbent; anything else is shown
        # verbatim but never compared, which would raise TypeError.
        numeric = isinstance(pg, (int, float))
        if numeric and (incumbent is None or pg > incumbent):
            incumbent = pg
        pgs = f"{pg:.4f}" if numeric else str(pg)
        incs = f"{incumbent:.4f}" if incumbent is not None else str(incumbent)
        # str() on trial_id: a null id cannot take the "<11" format spec.
        print(
            f"{i:<5}{str(t.get('trial_id', '')):<11}{str(t.get('status', '')):<14}"
            f"{pgs:<10}{incs:<11}{topo(t.get('config_patch'), base_flags)}"
        )

    # also dump proposals dir to see what was *proposed* (incl. vetoed/failed)
    pdir = store / "proposals"
    if pdir.exists():
        print("\n-- proposal files (chronological) --")
        # Sorted by file name; presumably the names sort chronologically.
        for p in sorted(pdir.glob("*.json")):
            try:
                pr = _load_json(p)
            except (OSError, ValueError) as err:
                # Best effort: report the unreadable file and keep going.
                print(f" {p.stem}: skipped ({err})", file=sys.stderr)
                continue
            if not isinstance(pr, dict):
                print(f" {p.stem}: skipped (not a JSON object)", file=sys.stderr)
                continue
            print(f" {p.stem}: should_stop={pr.get('should_stop')} | {topo(pr.get('config_patch'), base_flags)}")


if __name__ == "__main__":
    main()