diff --git a/scripts/ablation_trajectory.py b/scripts/ablation_trajectory.py index 8e76f89..0442d32 100644 --- a/scripts/ablation_trajectory.py +++ b/scripts/ablation_trajectory.py @@ -10,22 +10,37 @@ import sys from pathlib import Path -def topo(patch): +TOPOLOGY_KEYS = ( + ("tensor-parallel-size", "TP"), + ("data-parallel-size", "DP"), + ("expert-parallel-size", "EP"), +) + +RUNTIME_KEYS = ( + "gpu-memory-utilization", + "enable-chunked-prefill", + "max-num-batched-tokens", + "max-num-seqs", +) + + +def topo(patch, base_flags=None): fp = (patch or {}).get("flag_patch", {}) or {} ep = (patch or {}).get("env_patch", {}) or {} + effective = dict(base_flags or {}) + effective.update(fp) parts = [] - for k, label in ( - ("tensor-parallel-size", "TP"), - ("data-parallel-size", "DP"), - ("expert-parallel-size", "EP"), - ): - if k in fp: - parts.append(f"{label}{fp[k]}") - runtime = { - k: v - for k, v in fp.items() - if k not in ("tensor-parallel-size", "data-parallel-size", "expert-parallel-size") - } + for k, label in TOPOLOGY_KEYS: + if k in effective: + parts.append(f"{label}{effective[k]}") + runtime = {k: effective[k] for k in RUNTIME_KEYS if k in effective} + runtime.update( + { + k: v + for k, v in fp.items() + if k not in {key for key, _ in TOPOLOGY_KEYS} and k not in runtime + } + ) runtime.update({f"env:{k}": v for k, v in ep.items()}) base = "+".join(parts) if parts else "baseline-topo" if runtime: @@ -36,6 +51,11 @@ def topo(patch): def main(): store = Path(sys.argv[1]) state = json.load(open(store / "state.json")) + snapshot_path = store / "study_spec.snapshot.json" + base_flags = {} + if snapshot_path.exists(): + snapshot = json.load(open(snapshot_path)) + base_flags = ((snapshot.get("engine") or {}).get("base_flags") or {}) print(f"study_id: {state.get('study_id')}") print(f"best_trial: {state.get('best_trial_id')} best_per_gpu: {state.get('best_request_rate_per_gpu')}") print(f"stop_reason: {state.get('tuning_stop_reason')!r}") @@ -53,7 +73,7 @@ def main(): pgs = f"{pg:.4f}" if isinstance(pg, (int, float)) else str(pg) incs = f"{incumbent:.4f}" if isinstance(incumbent, (int, float)) else str(incumbent) print( - f"{i:<5}{t.get('trial_id',''):<11}{str(t.get('status','')):<14}{pgs:<10}{incs:<11}{topo(t.get('config_patch'))}" + f"{i:<5}{t.get('trial_id',''):<11}{str(t.get('status','')):<14}{pgs:<10}{incs:<11}{topo(t.get('config_patch'), base_flags)}" ) # also dump proposals dir to see what was *proposed* (incl. vetoed/failed) pdir = store / "proposals" @@ -64,7 +84,7 @@ def main(): pr = json.load(open(p)) except Exception: continue - print(f" {p.stem}: should_stop={pr.get('should_stop')} | {topo(pr.get('config_patch'))}") + print(f" {p.stem}: should_stop={pr.get('should_stop')} | {topo(pr.get('config_patch'), base_flags)}") if __name__ == "__main__":