Files
replaysim/tools/run_frontier_sweep.py

535 lines
19 KiB
Python

#!/usr/bin/env python3
"""Run a small Frontier sweep from a ReplayServe JSON config."""
from __future__ import annotations
import argparse
import json
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Any
# Base directory for everything this script locates on disk: the parent of the
# directory containing this file (``parents[1]`` of the resolved script path).
REPLAYSERVE_ROOT = Path(__file__).resolve().parents[1]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the sweep runner.

    Returns:
        The populated namespace with ``config``, ``suite_id``, ``run_root``,
        ``only_config``, ``only_fixture``, ``dry_run`` and ``force``.
    """
    cli = argparse.ArgumentParser(description="Run Frontier configs from JSON.")

    default_config = REPLAYSERVE_ROOT / "configs" / "rs3_tiny_sweep.json"
    cli.add_argument(
        "--config", type=Path, default=default_config, help="Sweep JSON config."
    )
    cli.add_argument("--suite-id", help="Override suite_id from the config.")
    cli.add_argument(
        "--run-root",
        type=Path,
        help="Override run root. Defaults to runs/<suite_id>.",
    )

    # Repeatable selectors: every occurrence appends to a list.
    repeatable = (
        ("--only-config", "Run only a config id. Can be repeated."),
        ("--only-fixture", "Run only a fixture. Can be repeated."),
    )
    for flag, help_text in repeatable:
        cli.add_argument(flag, action="append", default=[], help=help_text)

    # Plain on/off switches.
    switches = (
        ("--dry-run", "Write manifests and commands, but do not execute Frontier."),
        ("--force", "Replace existing run dirs selected by this invocation."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
def load_json(path: Path) -> dict[str, Any]:
    """Read a UTF-8 JSON file whose top-level value must be an object.

    Args:
        path: File to read.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: If the document decodes to anything other than a dict
            (``json.JSONDecodeError`` for malformed JSON is a subclass).
        OSError: If the file cannot be read.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        return payload
    raise ValueError(f"{path}: top-level JSON must be an object")
def git_head(path: Path) -> str | None:
    """Return the commit hash of ``HEAD`` for the git checkout at *path*.

    This is a best-effort lookup used only for provenance records.

    Args:
        path: Directory passed to ``git -C``.

    Returns:
        The stripped output of ``git rev-parse HEAD``, or ``None`` when git
        exits with a non-zero status or cannot be started at all.
    """
    try:
        result = subprocess.run(
            ["git", "-C", str(path), "rev-parse", "HEAD"],
            check=True,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (e.g. not a repository).
        # OSError: the git executable is missing or could not be launched.
        return None
    return result.stdout.strip()
def git_status(path: Path) -> str | None:
    """Return ``git status --short`` output for the checkout at *path*.

    This is a best-effort lookup used only for provenance records.

    Args:
        path: Directory passed to ``git -C``.

    Returns:
        The raw (unstripped) stdout of ``git status --short``, or ``None``
        when git exits with a non-zero status or cannot be started at all.
    """
    try:
        result = subprocess.run(
            ["git", "-C", str(path), "status", "--short"],
            check=True,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (e.g. not a repository).
        # OSError: the git executable is missing or could not be launched.
        return None
    return result.stdout
def shell_join(argv: list[str]) -> str:
    """Render *argv* as one shell-safe command line.

    Each argument is quoted individually and the results are joined with
    single spaces.
    """
    from shlex import quote

    return " ".join(map(quote, argv))
def merge_config(defaults: dict[str, Any], item: dict[str, Any]) -> dict[str, Any]:
    """Combine the sweep defaults with one config entry's overrides.

    Args:
        defaults: Knob values shared by every config; not modified.
        item: A config entry; its optional ``overrides`` mapping wins over
            *defaults*. A missing or ``null`` ``overrides`` means no overrides.

    Returns:
        A new dict of merged knobs. Unless the entry overrides
        ``batch_size_cap`` itself, ``batch_size_cap`` is set to the merged
        ``max_num_seqs`` whenever the latter is present.

    Raises:
        ValueError: If ``overrides`` is present but is not a dict.
    """
    overrides = item.get("overrides")
    if overrides is None:
        overrides = {}
    elif not isinstance(overrides, dict):
        raise ValueError(f"config {item.get('id')}: overrides must be an object")

    knobs = {**defaults, **overrides}

    # An explicit batch_size_cap override is respected; otherwise the cap
    # follows max_num_seqs (replacing any default cap).
    cap_is_explicit = "batch_size_cap" in overrides
    if not cap_is_explicit and "max_num_seqs" in knobs:
        knobs["batch_size_cap"] = knobs["max_num_seqs"]
    return knobs
def build_frontier_command(
    *,
    python_bin: str,
    trace_file: Path,
    metrics_root: Path,
    run_id: str,
    knobs: dict[str, Any],
) -> list[str]:
    """Assemble the argv that runs ``python -m frontier.main`` for one config.

    Args:
        python_bin: Interpreter placed at the front of the command.
        trace_file: Trace passed as the trace-replay request source.
        metrics_root: Directory passed as the metrics output dir.
        run_id: Identifier passed as the metrics run id.
        knobs: Merged knob values. Most are required; optional ones are read
            with ``.get`` and either defaulted or skipped.

    Returns:
        The full command as a list of strings, in a fixed option order.

    Raises:
        KeyError: If a required knob is absent from *knobs*.
    """
    replica_prefix = "--replica_config_"
    scheduler_prefix = "--vllm_v1_scheduler_config_"
    predictor_prefix = "--random_forrest_execution_time_predictor_config_"

    argv: list[str] = [python_bin, "-m", "frontier.main"]

    def option(flag: str, value: Any) -> None:
        # Append one "--flag value" pair, stringifying the value.
        argv.extend((flag, str(value)))

    option("--simulation_mode", knobs["simulation_mode"])
    option("--sys_arch", knobs["sys_arch"])
    option("--cc_backend_config_type", "analytical")
    option("--cluster_config_num_replicas", knobs["num_replicas"])
    option("--cluster_scheduler_config_type", knobs["cluster_scheduler"])

    # Replica options whose CLI suffix equals the knob name.
    for knob in (
        "model_name",
        "device",
        "network_device",
        "attn_tensor_parallel_size",
        "attn_data_parallel_size",
        "moe_tensor_parallel_size",
        "moe_expert_parallel_size",
        "num_pipeline_stages",
    ):
        option(replica_prefix + knob, knobs[knob])

    option("--replica_scheduler_config_type", knobs["replica_scheduler"])
    option(
        "--decode_cuda_graph_mode",
        knobs.get("decode_cuda_graph_mode", "full_decode_only"),
    )

    # Required scheduler options whose CLI suffix equals the knob name.
    for knob in (
        "batch_size_cap",
        "max_tokens_in_batch",
        "long_prefill_token_threshold",
        "block_size",
        "num_blocks_mode",
        "gpu_memory_utilization",
        "non_kv_cache_overhead_bytes",
    ):
        option(scheduler_prefix + knob, knobs[knob])

    option("--request_generator_config_type", "trace_replay")
    option("--trace_request_generator_config_trace_file", trace_file)
    option("--trace_request_generator_config_max_tokens", knobs["trace_max_tokens"])
    option("--metrics_config_output_dir", metrics_root)
    argv += ["--metrics_config_run_id", run_id]

    # Fixed metrics switches: tabular metrics on, plots and traces off.
    argv += [
        "--metrics_config_write_metrics",
        "--metrics_config_store_request_metrics",
        "--metrics_config_store_batch_metrics",
        "--metrics_config_store_token_completion_metrics",
        "--metrics_config_store_utilization_metrics",
        "--no-metrics_config_store_plots",
        "--no-metrics_config_enable_chrome_trace",
        "--no-metrics_config_write_json_trace",
        "--no-metrics_config_store_frontier_stage_batch_ledger",
    ]

    # Dummy mode is on unless explicitly disabled, and then needs its time.
    if knobs.get("enable_dummy_mode", True):
        argv.append(predictor_prefix + "enable_dummy_mode")
        option(
            predictor_prefix + "dummy_execution_time_ms",
            knobs["dummy_execution_time_ms"],
        )
    else:
        argv.append(
            "--no-random_forrest_execution_time_predictor_config_enable_dummy_mode"
        )

    # Profile input files: forwarded only when set to a truthy value.
    for knob in (
        "linear_op_input_file",
        "atten_input_file",
        "moe_input_file",
        "linear_op_kernel_only_input_file",
        "atten_kernel_only_input_file",
        "moe_kernel_only_input_file",
    ):
        profile_path = knobs.get(knob)
        if profile_path:
            option(predictor_prefix + knob, profile_path)

    # Prediction limits: forwarded whenever present, including zero.
    for knob in (
        "prediction_max_prefill_chunk_size",
        "prediction_max_batch_size",
        "prediction_max_tokens_per_request",
    ):
        limit = knobs.get(knob)
        if limit is not None:
            option(predictor_prefix + knob, limit)

    if knobs.get("no_cache", False):
        argv.append(predictor_prefix + "no_cache")
    if knobs.get("skip_cpu_overhead_modeling", True):
        argv.append(predictor_prefix + "skip_cpu_overhead_modeling")

    block_count = knobs.get("num_blocks")
    if block_count is not None:
        option(scheduler_prefix + "num_blocks", block_count)

    if knobs["enable_prefix_caching"]:
        argv.append(scheduler_prefix + "enable_prefix_caching")
    if knobs["enable_chunked_prefill"]:
        argv.append(scheduler_prefix + "enable_chunked_prefill")
    return argv
def write_text(path: Path, text: str) -> None:
    """Create or overwrite *path* with *text*, encoded as UTF-8."""
    with path.open("w", encoding="utf-8") as handle:
        handle.write(text)
def _write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented, key-sorted JSON plus a newline."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, sort_keys=True)
        handle.write("\n")


def run_one(
    *,
    suite_id: str,
    sim: str,
    frontier_info: dict[str, Any],
    frontier_root: Path,
    fixture: str,
    config_item: dict[str, Any],
    knobs: dict[str, Any],
    run_root: Path,
    python_bin: str,
    python_deps_dir: Path,
    dry_run: bool,
    force: bool,
) -> dict[str, Any]:
    """Run (or dry-run) one fixture/config pair and record its artifacts.

    The run directory is ``<run_root>/<sim>/<fixture>/<config id>``. It always
    receives ``run_manifest.json``, ``command.txt``, ``env.txt``,
    ``exit_code.txt`` and ``run_status.json``; a real run additionally gets
    stdout/stderr logs, timing files and, when Frontier exits 0, the
    post-processing logs.

    Args:
        suite_id: Suite name, used in the run id and the records.
        sim: Simulator name, used as a path component under *run_root*.
        frontier_info: Extra Frontier metadata copied into the manifest.
        frontier_root: Frontier checkout; working directory of the run and
            an entry on ``PYTHONPATH``.
        fixture: Fixture directory name under ``traces/fixtures``.
        config_item: Config entry; supplies ``id`` and ``description``.
        knobs: Merged knob values forwarded to ``build_frontier_command``.
        run_root: Parent directory for all run directories.
        python_bin: Interpreter for Frontier and the post-processing script.
        python_deps_dir: Prepended to ``PYTHONPATH`` when it is a directory.
        dry_run: Write the records but do not execute anything.
        force: Delete an existing run directory instead of failing.

    Returns:
        A dict with ``status`` (``dry_run``, ``pass``, ``fail``,
        ``postprocess_fail`` or ``incomplete``), ``exit_code``,
        ``runtime_seconds`` and ``postprocess_exit_code``.

    Raises:
        FileNotFoundError: If the fixture's trace or sidecar file is missing.
        FileExistsError: If the run directory exists and *force* is false.
    """
    config_id = str(config_item["id"])
    fixture_dir = REPLAYSERVE_ROOT / "traces" / "fixtures" / fixture
    trace_file = fixture_dir / "frontier.csv"
    sidecar_file = fixture_dir / "sidecar.jsonl"
    if not trace_file.exists():
        raise FileNotFoundError(f"missing trace file: {trace_file}")
    if not sidecar_file.exists():
        raise FileNotFoundError(f"missing sidecar file: {sidecar_file}")

    run_dir = (run_root / sim / fixture / config_id).resolve()
    metrics_root = (run_dir / "frontier_metrics").resolve()
    if run_dir.exists():
        if not force:
            raise FileExistsError(f"run dir exists, use --force to replace: {run_dir}")
        shutil.rmtree(run_dir)
    run_dir.mkdir(parents=True)
    metrics_root.mkdir(parents=True)

    run_id = f"{suite_id}_{fixture}_{config_id}"
    cmd = build_frontier_command(
        python_bin=python_bin,
        trace_file=trace_file,
        metrics_root=metrics_root,
        run_id=run_id,
        knobs=knobs,
    )

    # Search order: bundled deps (if present), Frontier, then whatever the
    # caller already had on PYTHONPATH.
    pythonpath_parts = []
    if python_deps_dir.is_dir():
        pythonpath_parts.append(str(python_deps_dir))
    pythonpath_parts.append(str(frontier_root))
    existing_pythonpath = os.environ.get("PYTHONPATH")
    if existing_pythonpath:
        pythonpath_parts.append(existing_pythonpath)

    env = os.environ.copy()
    env.update(
        {
            # os.pathsep is ":" on POSIX and the right separator elsewhere.
            "PYTHONPATH": os.pathsep.join(pythonpath_parts),
            "WANDB_DISABLED": "true",
            "VIDUR_DISABLE_WANDB": "1",
            "FRONTIER_LOG_LEVEL": env.get("FRONTIER_LOG_LEVEL", "info"),
            "PYTHONDONTWRITEBYTECODE": "1",
        }
    )

    frontier_head = git_head(frontier_root)
    frontier_status = git_status(frontier_root)
    manifest = {
        "suite_id": suite_id,
        "sim": sim,
        "fixture": fixture,
        "config_id": config_id,
        "description": config_item.get("description", ""),
        "run_dir": str(run_dir),
        "metrics_root": str(metrics_root),
        "run_id": run_id,
        "frontier": {
            **frontier_info,
            "root": str(frontier_root),
            "head": frontier_head,
            "status_short": frontier_status,
        },
        "fixture_dir": str(fixture_dir),
        "trace_file": str(trace_file),
        "sidecar_file": str(sidecar_file),
        "knobs": knobs,
        "command": cmd,
    }
    _write_json(run_dir / "run_manifest.json", manifest)

    # Values are shell-quoted so the recorded lines stay usable when a path
    # contains spaces; ordinary values are written unchanged.
    exported_names = (
        "PYTHONPATH",
        "WANDB_DISABLED",
        "VIDUR_DISABLE_WANDB",
        "FRONTIER_LOG_LEVEL",
        "PYTHONDONTWRITEBYTECODE",
    )
    command_lines = [f"cd {shell_join([str(frontier_root)])}"]
    command_lines.extend(
        f"export {name}={shell_join([env[name]])}" for name in exported_names
    )
    command_lines.append(f"command={shell_join(cmd)}")
    command_lines.append("")
    write_text(run_dir / "command.txt", "\n".join(command_lines))

    write_text(
        run_dir / "env.txt",
        "\n".join(
            [
                f"suite_id={suite_id}",
                f"sim={sim}",
                f"fixture={fixture}",
                f"config_id={config_id}",
                f"replayserve_root={REPLAYSERVE_ROOT}",
                f"frontier_root={frontier_root}",
                f"frontier_head={frontier_head}",
                f"python_deps_dir={python_deps_dir}",
                f"trace_file={trace_file}",
                f"sidecar_file={sidecar_file}",
                f"run_dir={run_dir}",
                f"metrics_root={metrics_root}",
                f"run_id={run_id}",
                "",
            ]
        ),
    )

    if dry_run:
        write_text(run_dir / "exit_code.txt", "0\n")
        status = {
            "status": "dry_run",
            "exit_code": 0,
            "runtime_seconds": 0,
            "postprocess_exit_code": None,
        }
        _write_json(run_dir / "run_status.json", status)
        return status

    start_epoch = int(time.time())
    write_text(run_dir / "start_epoch.txt", f"{start_epoch}\n")
    with (run_dir / "stdout.log").open("w", encoding="utf-8") as stdout, (
        run_dir / "stderr.log"
    ).open("w", encoding="utf-8") as stderr:
        proc = subprocess.run(
            cmd, cwd=frontier_root, env=env, stdout=stdout, stderr=stderr
        )
    end_epoch = int(time.time())
    runtime_seconds = end_epoch - start_epoch
    write_text(run_dir / "end_epoch.txt", f"{end_epoch}\n")
    write_text(run_dir / "exit_code.txt", f"{proc.returncode}\n")
    write_text(run_dir / "runtime_seconds.txt", f"{runtime_seconds}\n")

    # Post-processing only makes sense for a simulation that exited cleanly.
    postprocess_exit_code: int | None = None
    if proc.returncode == 0:
        postprocess_cmd = [
            python_bin,
            str(REPLAYSERVE_ROOT / "tools" / "postprocess_frontier_smoke.py"),
            "--run-dir",
            str(run_dir),
            "--fixture-dir",
            str(fixture_dir),
        ]
        with (run_dir / "postprocess.stdout.log").open(
            "w", encoding="utf-8"
        ) as stdout, (run_dir / "postprocess.stderr.log").open(
            "w", encoding="utf-8"
        ) as stderr:
            post = subprocess.run(
                postprocess_cmd,
                cwd=REPLAYSERVE_ROOT,
                env=env,
                stdout=stdout,
                stderr=stderr,
            )
        postprocess_exit_code = post.returncode

    if proc.returncode != 0:
        status_name = "fail"
    elif postprocess_exit_code != 0:
        status_name = "postprocess_fail"
    else:
        status_name = "pass"

    if status_name == "pass":
        summary_path = run_dir / "postprocess_summary.json"
        if summary_path.exists():
            try:
                summary = load_json(summary_path)
            except (OSError, ValueError):
                # Unreadable or malformed summary (JSON decode errors are
                # ValueError subclasses): treat as a post-processing failure.
                status_name = "postprocess_fail"
            else:
                completion = summary.get("completion", {})
                if isinstance(completion, dict) and not completion.get(
                    "is_complete", True
                ):
                    status_name = "incomplete"

    status = {
        "status": status_name,
        "exit_code": proc.returncode,
        "runtime_seconds": runtime_seconds,
        "postprocess_exit_code": postprocess_exit_code,
    }
    _write_json(run_dir / "run_status.json", status)
    return status
def main() -> int:
    """Run every selected fixture/config combination of a sweep.

    Loads the sweep JSON, applies the command-line overrides and filters,
    validates and merges all selected config entries up front, then calls
    ``run_one`` for each fixture/config pair and prints a one-line result.

    Returns:
        ``1`` if any run ended with a status other than ``pass`` or
        ``dry_run``, otherwise ``0``.

    Raises:
        ValueError: If the sweep config is malformed or nothing is selected.
        FileNotFoundError: If the Frontier root directory does not exist.
    """
    args = parse_args()
    config_path = args.config.resolve()
    config = load_json(config_path)

    suite_id = args.suite_id or str(config.get("suite_id") or "rs3_sweep")
    run_root = args.run_root or (REPLAYSERVE_ROOT / "runs" / suite_id)
    sim = str(config.get("sim") or "frontier")

    frontier_info = config.get("frontier", {})
    if not isinstance(frontier_info, dict):
        raise ValueError("frontier must be an object")
    frontier_root = Path(
        str(frontier_info.get("root") or "/tmp/toc-llm-sim-research/Frontier")
    )
    if not frontier_root.is_dir():
        raise FileNotFoundError(f"Frontier root does not exist: {frontier_root}")

    fixtures = [str(value) for value in config.get("fixtures", [])]
    if args.only_fixture:
        selected = set(args.only_fixture)
        fixtures = [value for value in fixtures if value in selected]
    if not fixtures:
        raise ValueError("no fixtures selected")

    defaults = config.get("defaults", {})
    if not isinstance(defaults, dict):
        raise ValueError("defaults must be an object")

    config_items = config.get("configs", [])
    if not isinstance(config_items, list) or not config_items:
        raise ValueError("configs must be a non-empty list")
    if args.only_config:
        selected_configs = set(args.only_config)
        config_items = [
            item
            for item in config_items
            if isinstance(item, dict) and str(item.get("id")) in selected_configs
        ]
    if not config_items:
        raise ValueError("no configs selected")

    # Validate and merge every selected entry before launching anything, so a
    # malformed entry aborts the sweep immediately rather than after earlier
    # (possibly long) runs have already executed.
    prepared: list[tuple[dict[str, Any], dict[str, Any]]] = []
    for item in config_items:
        if not isinstance(item, dict):
            raise ValueError("each configs entry must be an object")
        if "id" not in item:
            raise ValueError("each configs entry needs id")
        prepared.append((item, merge_config(defaults, item)))

    # A checkout-local virtualenv takes precedence over PYTHON_BIN.
    venv_python = REPLAYSERVE_ROOT / ".venv" / "bin" / "python"
    if venv_python.is_file():
        python_bin = str(venv_python)
    else:
        python_bin = os.environ.get("PYTHON_BIN", sys.executable or "python3")
    python_deps_dir = Path(
        os.environ.get("PYTHON_DEPS_DIR", str(REPLAYSERVE_ROOT / ".deps" / "python"))
    )

    results: list[dict[str, Any]] = []
    for fixture in fixtures:
        for item, knobs in prepared:
            status = run_one(
                suite_id=suite_id,
                sim=sim,
                frontier_info=frontier_info,
                frontier_root=frontier_root,
                fixture=fixture,
                config_item=item,
                # Each run gets its own copy of the merged knobs.
                knobs=dict(knobs),
                run_root=run_root,
                python_bin=python_bin,
                python_deps_dir=python_deps_dir,
                dry_run=args.dry_run,
                force=args.force,
            )
            results.append(
                {
                    "fixture": fixture,
                    "config_id": item["id"],
                    **status,
                }
            )
            print(
                f"{fixture}/{item['id']}: {status['status']} "
                f"exit={status['exit_code']} runtime={status['runtime_seconds']}s"
            )

    failures = [row for row in results if row["status"] not in {"pass", "dry_run"}]
    return 1 if failures else 0
# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())