#!/usr/bin/env python3
"""Run a small Frontier sweep from a ReplayServe JSON config.

For every selected (fixture, config) pair this script creates the run
directory ``<run_root>/<sim>/<fixture>/<config_id>/``, records a manifest plus
the exact command and environment, runs ``frontier.main`` (skipped with
``--dry-run``), runs ``tools/postprocess_frontier_smoke.py`` when Frontier
exits 0, and writes the outcome to ``run_status.json``.

The process exits 1 if any selected run ends with a status other than
``pass`` or ``dry_run``.
"""

from __future__ import annotations

import argparse
import json
import os
import shlex
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Any

# ReplayServe checkout root: the parent of the directory containing this script.
REPLAYSERVE_ROOT = Path(__file__).resolve().parents[1]

# Shared prefix of the execution-time-predictor flags passed to Frontier.
# The "forrest" spelling is kept exactly as the flags were originally written.
PREDICTOR_FLAG_PREFIX = "--random_forrest_execution_time_predictor_config_"

# Knobs naming optional profiling input files; each is forwarded to Frontier
# as ``<PREDICTOR_FLAG_PREFIX><knob>`` only when set to a truthy value.
PROFILE_INPUT_KNOBS = (
    "linear_op_input_file",
    "atten_input_file",
    "moe_input_file",
    "linear_op_kernel_only_input_file",
    "atten_kernel_only_input_file",
    "moe_kernel_only_input_file",
)

# Optional numeric predictor limits; forwarded whenever not None (0 included).
PREDICTION_LIMIT_KNOBS = (
    "prediction_max_prefill_chunk_size",
    "prediction_max_batch_size",
    "prediction_max_tokens_per_request",
)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for one sweep invocation."""
    parser = argparse.ArgumentParser(description="Run Frontier configs from JSON.")
    parser.add_argument(
        "--config",
        type=Path,
        default=REPLAYSERVE_ROOT / "configs" / "rs3_tiny_sweep.json",
        help="Sweep JSON config.",
    )
    parser.add_argument("--suite-id", help="Override suite_id from the config.")
    parser.add_argument(
        "--run-root",
        type=Path,
        help="Override run root. Defaults to runs/<suite_id>/.",
    )
    parser.add_argument(
        "--only-config",
        action="append",
        default=[],
        help="Run only a config id. Can be repeated.",
    )
    parser.add_argument(
        "--only-fixture",
        action="append",
        default=[],
        help="Run only a fixture. Can be repeated.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Write manifests and commands, but do not execute Frontier.",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Replace existing run dirs selected by this invocation.",
    )
    return parser.parse_args()


def load_json(path: Path) -> dict[str, Any]:
    """Load *path* as UTF-8 JSON and require the top level to be an object.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the content is not valid JSON (``json.JSONDecodeError``)
            or its top level is not an object.
    """
    with path.open("r", encoding="utf-8") as handle:
        data = json.load(handle)
    if not isinstance(data, dict):
        raise ValueError(f"{path}: top-level JSON must be an object")
    return data


def git_output(path: Path, *git_args: str) -> str | None:
    """Return stdout of ``git -C <path> <git_args...>``, or None on failure.

    A non-zero git exit (for example, *path* is not a repository) and git not
    being runnable at all (``OSError``, for example when it is not installed)
    are both treated as "no information", so provenance capture never aborts
    a sweep.
    """
    try:
        result = subprocess.run(
            ["git", "-C", str(path), *git_args],
            check=True,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout


def git_head(path: Path) -> str | None:
    """Return the HEAD commit hash of the repo at *path*, or None if unavailable."""
    output = git_output(path, "rev-parse", "HEAD")
    return None if output is None else output.strip()


def git_status(path: Path) -> str | None:
    """Return ``git status --short`` output for *path* (unstripped), or None."""
    return git_output(path, "status", "--short")


def shell_join(argv: list[str]) -> str:
    """Join *argv* into a single POSIX-shell-quoted command string."""
    return " ".join(shlex.quote(part) for part in argv)


def merge_config(defaults: dict[str, Any], item: dict[str, Any]) -> dict[str, Any]:
    """Return a new knob dict: *defaults* shallow-copied, then ``item["overrides"]``.

    ``batch_size_cap`` follows ``max_num_seqs`` (whether that comes from the
    defaults or the overrides) unless the overrides set ``batch_size_cap``
    explicitly.

    Raises:
        ValueError: if ``overrides`` is present, not None, and not an object.
    """
    merged = dict(defaults)
    overrides = item.get("overrides", {})
    if overrides is None:
        overrides = {}
    if not isinstance(overrides, dict):
        raise ValueError(f"config {item.get('id')}: overrides must be an object")
    merged.update(overrides)
    if "max_num_seqs" in merged and "batch_size_cap" not in overrides:
        merged["batch_size_cap"] = merged["max_num_seqs"]
    return merged


def build_frontier_command(
    *,
    python_bin: str,
    trace_file: Path,
    metrics_root: Path,
    run_id: str,
    knobs: dict[str, Any],
) -> list[str]:
    """Build the argv for one ``python -m frontier.main`` trace-replay run.

    Required knobs are read with ``knobs[...]`` and raise ``KeyError`` when
    absent. Optional knobs are appended only when present (see the module-level
    ``PROFILE_INPUT_KNOBS`` / ``PREDICTION_LIMIT_KNOBS`` and the branches below).
    """
    cmd = [
        python_bin,
        "-m",
        "frontier.main",
        "--simulation_mode",
        str(knobs["simulation_mode"]),
        "--sys_arch",
        str(knobs["sys_arch"]),
        "--cc_backend_config_type",
        "analytical",
        "--cluster_config_num_replicas",
        str(knobs["num_replicas"]),
        "--cluster_scheduler_config_type",
        str(knobs["cluster_scheduler"]),
        "--replica_config_model_name",
        str(knobs["model_name"]),
        "--replica_config_device",
        str(knobs["device"]),
        "--replica_config_network_device",
        str(knobs["network_device"]),
        "--replica_config_attn_tensor_parallel_size",
        str(knobs["attn_tensor_parallel_size"]),
        "--replica_config_attn_data_parallel_size",
        str(knobs["attn_data_parallel_size"]),
        "--replica_config_moe_tensor_parallel_size",
        str(knobs["moe_tensor_parallel_size"]),
        "--replica_config_moe_expert_parallel_size",
        str(knobs["moe_expert_parallel_size"]),
        "--replica_config_num_pipeline_stages",
        str(knobs["num_pipeline_stages"]),
        "--replica_scheduler_config_type",
        str(knobs["replica_scheduler"]),
        "--decode_cuda_graph_mode",
        str(knobs.get("decode_cuda_graph_mode", "full_decode_only")),
        "--vllm_v1_scheduler_config_batch_size_cap",
        str(knobs["batch_size_cap"]),
        "--vllm_v1_scheduler_config_max_tokens_in_batch",
        str(knobs["max_tokens_in_batch"]),
        "--vllm_v1_scheduler_config_long_prefill_token_threshold",
        str(knobs["long_prefill_token_threshold"]),
        "--vllm_v1_scheduler_config_block_size",
        str(knobs["block_size"]),
        "--vllm_v1_scheduler_config_num_blocks_mode",
        str(knobs["num_blocks_mode"]),
        "--vllm_v1_scheduler_config_gpu_memory_utilization",
        str(knobs["gpu_memory_utilization"]),
        "--vllm_v1_scheduler_config_non_kv_cache_overhead_bytes",
        str(knobs["non_kv_cache_overhead_bytes"]),
        "--request_generator_config_type",
        "trace_replay",
        "--trace_request_generator_config_trace_file",
        str(trace_file),
        "--trace_request_generator_config_max_tokens",
        str(knobs["trace_max_tokens"]),
        "--metrics_config_output_dir",
        str(metrics_root),
        "--metrics_config_run_id",
        run_id,
        "--metrics_config_write_metrics",
        "--metrics_config_store_request_metrics",
        "--metrics_config_store_batch_metrics",
        "--metrics_config_store_token_completion_metrics",
        "--metrics_config_store_utilization_metrics",
        "--no-metrics_config_store_plots",
        "--no-metrics_config_enable_chrome_trace",
        "--no-metrics_config_write_json_trace",
        "--no-metrics_config_store_frontier_stage_batch_ledger",
    ]

    # Dummy mode (default on) requires dummy_execution_time_ms.
    if bool(knobs.get("enable_dummy_mode", True)):
        cmd.extend(
            [
                f"{PREDICTOR_FLAG_PREFIX}enable_dummy_mode",
                f"{PREDICTOR_FLAG_PREFIX}dummy_execution_time_ms",
                str(knobs["dummy_execution_time_ms"]),
            ]
        )
    else:
        cmd.append("--no-random_forrest_execution_time_predictor_config_enable_dummy_mode")

    # Profile files are skipped when unset or empty.
    for knob_name in PROFILE_INPUT_KNOBS:
        value = knobs.get(knob_name)
        if value:
            cmd.extend([f"{PREDICTOR_FLAG_PREFIX}{knob_name}", str(value)])

    # Numeric limits: 0 is a meaningful value, so only None is skipped.
    for knob_name in PREDICTION_LIMIT_KNOBS:
        value = knobs.get(knob_name)
        if value is not None:
            cmd.extend([f"{PREDICTOR_FLAG_PREFIX}{knob_name}", str(value)])

    if bool(knobs.get("no_cache", False)):
        cmd.append(f"{PREDICTOR_FLAG_PREFIX}no_cache")
    if bool(knobs.get("skip_cpu_overhead_modeling", True)):
        cmd.append(f"{PREDICTOR_FLAG_PREFIX}skip_cpu_overhead_modeling")
    if knobs.get("num_blocks") is not None:
        cmd.extend(["--vllm_v1_scheduler_config_num_blocks", str(knobs["num_blocks"])])
    if bool(knobs["enable_prefix_caching"]):
        cmd.append("--vllm_v1_scheduler_config_enable_prefix_caching")
    if bool(knobs["enable_chunked_prefill"]):
        cmd.append("--vllm_v1_scheduler_config_enable_chunked_prefill")
    return cmd


def write_text(path: Path, text: str) -> None:
    """Write *text* to *path* as UTF-8, replacing any existing file."""
    path.write_text(text, encoding="utf-8")


def write_json(path: Path, data: Any) -> None:
    """Write *data* to *path* as indented, key-sorted JSON with a trailing newline."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=2, sort_keys=True)
        handle.write("\n")


def fixture_paths(fixture: str) -> tuple[Path, Path, Path]:
    """Return ``(fixture_dir, trace_file, sidecar_file)`` for *fixture*.

    Raises:
        FileNotFoundError: if ``frontier.csv`` or ``sidecar.jsonl`` is missing.
    """
    fixture_dir = REPLAYSERVE_ROOT / "traces" / "fixtures" / fixture
    trace_file = fixture_dir / "frontier.csv"
    sidecar_file = fixture_dir / "sidecar.jsonl"
    if not trace_file.exists():
        raise FileNotFoundError(f"missing trace file: {trace_file}")
    if not sidecar_file.exists():
        raise FileNotFoundError(f"missing sidecar file: {sidecar_file}")
    return fixture_dir, trace_file, sidecar_file


def validate_config_items(config_items: list[Any]) -> list[dict[str, Any]]:
    """Check that every entry is an object with a unique ``id``; return them.

    Ids are compared after ``str()``, matching how run directories are named,
    so ``1`` and ``"1"`` collide. A duplicate would otherwise make two runs
    share one directory, and with ``--force`` the second run deletes the
    first run's results.

    Raises:
        ValueError: on a non-object entry, a missing id, or a duplicate id.
    """
    seen_ids: set[str] = set()
    validated: list[dict[str, Any]] = []
    for item in config_items:
        if not isinstance(item, dict):
            raise ValueError("each configs entry must be an object")
        if "id" not in item:
            raise ValueError("each configs entry needs id")
        config_id = str(item["id"])
        if config_id in seen_ids:
            raise ValueError(f"duplicate config id: {config_id}")
        seen_ids.add(config_id)
        validated.append(item)
    return validated


def completion_status(run_dir: Path) -> str:
    """Classify a run whose Frontier and postprocess steps both exited 0.

    Returns ``"pass"`` when ``postprocess_summary.json`` is absent or does not
    report ``completion.is_complete`` as false, ``"incomplete"`` when it does,
    and ``"postprocess_fail"`` when the summary cannot be read or parsed.
    """
    summary_path = run_dir / "postprocess_summary.json"
    if not summary_path.exists():
        return "pass"
    try:
        summary = load_json(summary_path)
    except (OSError, ValueError):  # unreadable, invalid JSON, or not an object
        return "postprocess_fail"
    completion = summary.get("completion", {})
    if isinstance(completion, dict) and not completion.get("is_complete", True):
        return "incomplete"
    return "pass"


def run_one(
    *,
    suite_id: str,
    sim: str,
    frontier_info: dict[str, Any],
    frontier_root: Path,
    fixture: str,
    config_item: dict[str, Any],
    knobs: dict[str, Any],
    run_root: Path,
    python_bin: str,
    python_deps_dir: Path,
    dry_run: bool,
    force: bool,
) -> dict[str, Any]:
    """Prepare, execute, and classify a single (fixture, config) run.

    Creates ``<run_root>/<sim>/<fixture>/<config_id>/`` and writes into it
    ``run_manifest.json``, ``command.txt``, ``env.txt``, ``exit_code.txt`` and
    ``run_status.json``. Unless *dry_run*, it also writes stdout/stderr logs,
    epoch/runtime files and, on success, postprocess logs.

    Returns:
        The status dict written to ``run_status.json``, with keys ``status``
        (``dry_run``/``pass``/``fail``/``postprocess_fail``/``incomplete``),
        ``exit_code``, ``runtime_seconds`` and ``postprocess_exit_code``.

    Raises:
        FileNotFoundError: if the fixture's trace or sidecar file is missing.
        FileExistsError: if the run dir exists and *force* is false.
    """
    config_id = str(config_item["id"])
    fixture_dir, trace_file, sidecar_file = fixture_paths(fixture)

    run_dir = (run_root / sim / fixture / config_id).resolve()
    metrics_root = (run_dir / "frontier_metrics").resolve()
    if run_dir.exists():
        if not force:
            raise FileExistsError(f"run dir exists, use --force to replace: {run_dir}")
        shutil.rmtree(run_dir)
    run_dir.mkdir(parents=True)
    metrics_root.mkdir(parents=True)

    run_id = f"{suite_id}_{fixture}_{config_id}"
    cmd = build_frontier_command(
        python_bin=python_bin,
        trace_file=trace_file,
        metrics_root=metrics_root,
        run_id=run_id,
        knobs=knobs,
    )

    # PYTHONPATH search order: vendored deps (if present), the Frontier
    # checkout, then whatever the caller already had.
    pythonpath_parts = []
    if python_deps_dir.is_dir():
        pythonpath_parts.append(str(python_deps_dir))
    pythonpath_parts.append(str(frontier_root))
    existing_pythonpath = os.environ.get("PYTHONPATH")
    if existing_pythonpath:
        pythonpath_parts.append(existing_pythonpath)
    env = os.environ.copy()
    env.update(
        {
            "PYTHONPATH": os.pathsep.join(pythonpath_parts),
            "WANDB_DISABLED": "true",
            "VIDUR_DISABLE_WANDB": "1",
            "FRONTIER_LOG_LEVEL": env.get("FRONTIER_LOG_LEVEL", "info"),
            "PYTHONDONTWRITEBYTECODE": "1",
        }
    )

    frontier_head = git_head(frontier_root)
    frontier_status = git_status(frontier_root)
    manifest = {
        "suite_id": suite_id,
        "sim": sim,
        "fixture": fixture,
        "config_id": config_id,
        "description": config_item.get("description", ""),
        "run_dir": str(run_dir),
        "metrics_root": str(metrics_root),
        "run_id": run_id,
        "frontier": {
            **frontier_info,
            "root": str(frontier_root),
            "head": frontier_head,
            "status_short": frontier_status,
        },
        "fixture_dir": str(fixture_dir),
        "trace_file": str(trace_file),
        "sidecar_file": str(sidecar_file),
        "knobs": knobs,
        "command": cmd,
    }
    write_json(run_dir / "run_manifest.json", manifest)

    # Values are shell-quoted so paths containing spaces still reproduce;
    # shlex.quote leaves ordinary paths unchanged.
    write_text(
        run_dir / "command.txt",
        "\n".join(
            [
                f"cd {shlex.quote(str(frontier_root))}",
                f"export PYTHONPATH={shlex.quote(env['PYTHONPATH'])}",
                f"export WANDB_DISABLED={shlex.quote(env['WANDB_DISABLED'])}",
                f"export VIDUR_DISABLE_WANDB={shlex.quote(env['VIDUR_DISABLE_WANDB'])}",
                f"export FRONTIER_LOG_LEVEL={shlex.quote(env['FRONTIER_LOG_LEVEL'])}",
                f"export PYTHONDONTWRITEBYTECODE={shlex.quote(env['PYTHONDONTWRITEBYTECODE'])}",
                f"command={shell_join(cmd)}",
                "",
            ]
        ),
    )
    write_text(
        run_dir / "env.txt",
        "\n".join(
            [
                f"suite_id={suite_id}",
                f"sim={sim}",
                f"fixture={fixture}",
                f"config_id={config_id}",
                f"replayserve_root={REPLAYSERVE_ROOT}",
                f"frontier_root={frontier_root}",
                f"frontier_head={frontier_head}",
                f"python_deps_dir={python_deps_dir}",
                f"trace_file={trace_file}",
                f"sidecar_file={sidecar_file}",
                f"run_dir={run_dir}",
                f"metrics_root={metrics_root}",
                f"run_id={run_id}",
                "",
            ]
        ),
    )

    if dry_run:
        write_text(run_dir / "exit_code.txt", "0\n")
        status = {
            "status": "dry_run",
            "exit_code": 0,
            "runtime_seconds": 0,
            "postprocess_exit_code": None,
        }
        write_json(run_dir / "run_status.json", status)
        return status

    start_epoch = int(time.time())
    write_text(run_dir / "start_epoch.txt", f"{start_epoch}\n")
    with (run_dir / "stdout.log").open("w", encoding="utf-8") as stdout, (
        run_dir / "stderr.log"
    ).open("w", encoding="utf-8") as stderr:
        proc = subprocess.run(cmd, cwd=frontier_root, env=env, stdout=stdout, stderr=stderr)
    end_epoch = int(time.time())
    runtime_seconds = end_epoch - start_epoch
    write_text(run_dir / "end_epoch.txt", f"{end_epoch}\n")
    write_text(run_dir / "exit_code.txt", f"{proc.returncode}\n")
    write_text(run_dir / "runtime_seconds.txt", f"{runtime_seconds}\n")

    # Postprocessing only makes sense when Frontier itself succeeded.
    postprocess_exit_code: int | None = None
    if proc.returncode == 0:
        postprocess_cmd = [
            python_bin,
            str(REPLAYSERVE_ROOT / "tools" / "postprocess_frontier_smoke.py"),
            "--run-dir",
            str(run_dir),
            "--fixture-dir",
            str(fixture_dir),
        ]
        with (run_dir / "postprocess.stdout.log").open("w", encoding="utf-8") as stdout, (
            run_dir / "postprocess.stderr.log"
        ).open("w", encoding="utf-8") as stderr:
            post = subprocess.run(
                postprocess_cmd,
                cwd=REPLAYSERVE_ROOT,
                env=env,
                stdout=stdout,
                stderr=stderr,
            )
        postprocess_exit_code = post.returncode

    if proc.returncode != 0:
        status_name = "fail"
    elif postprocess_exit_code != 0:
        status_name = "postprocess_fail"
    else:
        status_name = completion_status(run_dir)

    status = {
        "status": status_name,
        "exit_code": proc.returncode,
        "runtime_seconds": runtime_seconds,
        "postprocess_exit_code": postprocess_exit_code,
    }
    write_json(run_dir / "run_status.json", status)
    return status


def main() -> int:
    """Run every selected (fixture, config) pair and return the exit code.

    All selection and validation (fixtures, config entries, overrides, and
    fixture files) happens before the first run. A malformed config then
    fails fast instead of aborting partway through a sweep.
    """
    args = parse_args()
    config_path = args.config.resolve()
    config = load_json(config_path)
    suite_id = args.suite_id or str(config.get("suite_id") or "rs3_sweep")
    run_root = args.run_root or (REPLAYSERVE_ROOT / "runs" / suite_id)
    sim = str(config.get("sim") or "frontier")

    frontier_info = config.get("frontier", {})
    if not isinstance(frontier_info, dict):
        raise ValueError("frontier must be an object")
    frontier_root = Path(
        str(frontier_info.get("root") or "/tmp/toc-llm-sim-research/Frontier")
    )
    if not frontier_root.is_dir():
        raise FileNotFoundError(f"Frontier root does not exist: {frontier_root}")

    raw_fixtures = config.get("fixtures", [])
    if not isinstance(raw_fixtures, list):
        raise ValueError("fixtures must be a list")
    # dict.fromkeys de-duplicates while keeping config order; a repeated
    # fixture would otherwise target the same run dirs twice.
    fixtures = list(dict.fromkeys(str(value) for value in raw_fixtures))
    if args.only_fixture:
        selected = set(args.only_fixture)
        fixtures = [value for value in fixtures if value in selected]
    if not fixtures:
        raise ValueError("no fixtures selected")

    defaults = config.get("defaults", {})
    if not isinstance(defaults, dict):
        raise ValueError("defaults must be an object")
    config_items = config.get("configs", [])
    if not isinstance(config_items, list) or not config_items:
        raise ValueError("configs must be a non-empty list")
    if args.only_config:
        selected_configs = set(args.only_config)
        config_items = [
            item
            for item in config_items
            if isinstance(item, dict) and str(item.get("id")) in selected_configs
        ]
    if not config_items:
        raise ValueError("no configs selected")

    # Fail fast: validate entries, merge knobs, and check fixture files before
    # any run directory is created or (with --force) deleted.
    config_items = validate_config_items(config_items)
    plans = [(item, merge_config(defaults, item)) for item in config_items]
    for fixture in fixtures:
        fixture_paths(fixture)

    if (REPLAYSERVE_ROOT / ".venv" / "bin" / "python").is_file():
        python_bin = str(REPLAYSERVE_ROOT / ".venv" / "bin" / "python")
    else:
        python_bin = os.environ.get("PYTHON_BIN", sys.executable or "python3")
    python_deps_dir = Path(
        os.environ.get("PYTHON_DEPS_DIR", str(REPLAYSERVE_ROOT / ".deps" / "python"))
    )

    results: list[dict[str, Any]] = []
    for fixture in fixtures:
        for item, knobs in plans:
            status = run_one(
                suite_id=suite_id,
                sim=sim,
                frontier_info=frontier_info,
                frontier_root=frontier_root,
                fixture=fixture,
                config_item=item,
                knobs=knobs,
                run_root=run_root,
                python_bin=python_bin,
                python_deps_dir=python_deps_dir,
                dry_run=args.dry_run,
                force=args.force,
            )
            results.append({"fixture": fixture, "config_id": item["id"], **status})
            print(
                f"{fixture}/{item['id']}: {status['status']} "
                f"exit={status['exit_code']} runtime={status['runtime_seconds']}s"
            )

    failures = [row for row in results if row["status"] not in {"pass", "dry_run"}]
    return 1 if failures else 0


if __name__ == "__main__":
    raise SystemExit(main())