Show effective flags in ablation trajectory
This commit is contained in:
@@ -10,22 +10,37 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def topo(patch):
|
TOPOLOGY_KEYS = (
|
||||||
|
("tensor-parallel-size", "TP"),
|
||||||
|
("data-parallel-size", "DP"),
|
||||||
|
("expert-parallel-size", "EP"),
|
||||||
|
)
|
||||||
|
|
||||||
|
RUNTIME_KEYS = (
|
||||||
|
"gpu-memory-utilization",
|
||||||
|
"enable-chunked-prefill",
|
||||||
|
"max-num-batched-tokens",
|
||||||
|
"max-num-seqs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def topo(patch, base_flags=None):
|
||||||
fp = (patch or {}).get("flag_patch", {}) or {}
|
fp = (patch or {}).get("flag_patch", {}) or {}
|
||||||
ep = (patch or {}).get("env_patch", {}) or {}
|
ep = (patch or {}).get("env_patch", {}) or {}
|
||||||
|
effective = dict(base_flags or {})
|
||||||
|
effective.update(fp)
|
||||||
parts = []
|
parts = []
|
||||||
for k, label in (
|
for k, label in TOPOLOGY_KEYS:
|
||||||
("tensor-parallel-size", "TP"),
|
if k in effective:
|
||||||
("data-parallel-size", "DP"),
|
parts.append(f"{label}{effective[k]}")
|
||||||
("expert-parallel-size", "EP"),
|
runtime = {k: effective[k] for k in RUNTIME_KEYS if k in effective}
|
||||||
):
|
runtime.update(
|
||||||
if k in fp:
|
{
|
||||||
parts.append(f"{label}{fp[k]}")
|
k: v
|
||||||
runtime = {
|
for k, v in fp.items()
|
||||||
k: v
|
if k not in {key for key, _ in TOPOLOGY_KEYS} and k not in runtime
|
||||||
for k, v in fp.items()
|
}
|
||||||
if k not in ("tensor-parallel-size", "data-parallel-size", "expert-parallel-size")
|
)
|
||||||
}
|
|
||||||
runtime.update({f"env:{k}": v for k, v in ep.items()})
|
runtime.update({f"env:{k}": v for k, v in ep.items()})
|
||||||
base = "+".join(parts) if parts else "baseline-topo"
|
base = "+".join(parts) if parts else "baseline-topo"
|
||||||
if runtime:
|
if runtime:
|
||||||
@@ -36,6 +51,11 @@ def topo(patch):
|
|||||||
def main():
|
def main():
|
||||||
store = Path(sys.argv[1])
|
store = Path(sys.argv[1])
|
||||||
state = json.load(open(store / "state.json"))
|
state = json.load(open(store / "state.json"))
|
||||||
|
snapshot_path = store / "study_spec.snapshot.json"
|
||||||
|
base_flags = {}
|
||||||
|
if snapshot_path.exists():
|
||||||
|
snapshot = json.load(open(snapshot_path))
|
||||||
|
base_flags = ((snapshot.get("engine") or {}).get("base_flags") or {})
|
||||||
print(f"study_id: {state.get('study_id')}")
|
print(f"study_id: {state.get('study_id')}")
|
||||||
print(f"best_trial: {state.get('best_trial_id')} best_per_gpu: {state.get('best_request_rate_per_gpu')}")
|
print(f"best_trial: {state.get('best_trial_id')} best_per_gpu: {state.get('best_request_rate_per_gpu')}")
|
||||||
print(f"stop_reason: {state.get('tuning_stop_reason')!r}")
|
print(f"stop_reason: {state.get('tuning_stop_reason')!r}")
|
||||||
@@ -53,7 +73,7 @@ def main():
|
|||||||
pgs = f"{pg:.4f}" if isinstance(pg, (int, float)) else str(pg)
|
pgs = f"{pg:.4f}" if isinstance(pg, (int, float)) else str(pg)
|
||||||
incs = f"{incumbent:.4f}" if isinstance(incumbent, (int, float)) else str(incumbent)
|
incs = f"{incumbent:.4f}" if isinstance(incumbent, (int, float)) else str(incumbent)
|
||||||
print(
|
print(
|
||||||
f"{i:<5}{t.get('trial_id',''):<11}{str(t.get('status','')):<14}{pgs:<10}{incs:<11}{topo(t.get('config_patch'))}"
|
f"{i:<5}{t.get('trial_id',''):<11}{str(t.get('status','')):<14}{pgs:<10}{incs:<11}{topo(t.get('config_patch'), base_flags)}"
|
||||||
)
|
)
|
||||||
# also dump proposals dir to see what was *proposed* (incl. vetoed/failed)
|
# also dump proposals dir to see what was *proposed* (incl. vetoed/failed)
|
||||||
pdir = store / "proposals"
|
pdir = store / "proposals"
|
||||||
@@ -64,7 +84,7 @@ def main():
|
|||||||
pr = json.load(open(p))
|
pr = json.load(open(p))
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
print(f" {p.stem}: should_stop={pr.get('should_stop')} | {topo(pr.get('config_patch'))}")
|
print(f" {p.stem}: should_stop={pr.get('should_stop')} | {topo(pr.get('config_patch'), base_flags)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user