"""Run one baseline trial followed by LLM-proposed trials for an aituner study.

Progress is reported on stdout as one JSON object per line
(``study_initialized``, ``trial_completed`` per trial, ``study_finished``).
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

from aituner.llm import (
    build_prompt,
    call_llm_for_proposal,
    load_capability_profile,
    parse_proposal_text,
)
from aituner.spec import load_study_spec
from aituner.store import StudyStore
from aituner.trace import load_trace_requests, summarize_window
from aituner.worker import run_trial


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Options:
        --spec: path to the study spec file (required).
        --store-root: root directory handed to ``StudyStore`` (required).
        --baseline-proposal: path to the proposal text used for the first
            trial (required).
        --total-trials: total number of trials, including the baseline
            (default 12).
    """
    parser = argparse.ArgumentParser(
        description="Run one baseline trial followed by LLM-proposed trials."
    )
    parser.add_argument("--spec", required=True)
    parser.add_argument("--store-root", required=True)
    parser.add_argument("--baseline-proposal", required=True)
    parser.add_argument("--total-trials", type=int, default=12)
    return parser


def _emit(event: dict[str, object]) -> None:
    """Print *event* as a single JSON line, flushing so progress shows up immediately."""
    print(json.dumps(event, ensure_ascii=False), flush=True)


def _run_and_report(store: StudyStore, study, state, proposal, source: str) -> None:
    """Materialize *proposal* as a trial, run it, ingest results, and emit ``trial_completed``.

    Args:
        store: the study store the trial is materialized into.
        study: the loaded study spec.
        state: the study state the trial is materialized against.
        proposal: a parsed proposal (output of ``parse_proposal_text``).
        source: label reported in the event's ``source`` field.
    """
    trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
    result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
    # Re-ingest so the reported "state_best_*" fields reflect this trial.
    state = store.ingest_trial_results(study.study_id)
    _emit(
        {
            "event": "trial_completed",
            "trial_id": trial.trial_id,
            "source": source,
            "status": result.get("status"),
            "best_sampling_u": result.get("best_sampling_u"),
            "best_request_rate": result.get("best_request_rate"),
            "best_pass_rate": result.get("best_pass_rate"),
            "state_best_trial_id": state.best_trial_id,
            "state_best_request_rate": state.best_request_rate,
            "state_best_request_rate_per_gpu": state.best_request_rate_per_gpu,
        }
    )


def main() -> int:
    """Run the baseline trial, then LLM-proposed trials up to ``--total-trials``.

    Returns:
        Process exit status (0 on success).

    Raises:
        SystemExit: if ``--total-trials`` is not positive or the baseline
            proposal file does not exist.
    """
    args = build_parser().parse_args()
    spec_path = Path(args.spec).resolve()
    store_root = Path(args.store_root).resolve()
    baseline_path = Path(args.baseline_proposal).resolve()
    if args.total_trials <= 0:
        raise SystemExit("--total-trials must be positive")
    # Validate before store.init_study runs, so a bad path fails fast instead
    # of after the study has been initialized.
    if not baseline_path.is_file():
        raise SystemExit(f"--baseline-proposal not found: {baseline_path}")

    study = load_study_spec(spec_path)
    store = StudyStore(store_root)
    study_root = store.init_study(spec_path=spec_path, study=study)
    capability_profile = load_capability_profile(study, study_spec_path=spec_path)
    _emit(
        {
            "event": "study_initialized",
            "study_root": str(study_root),
            "study_id": study.study_id,
            "total_trials": args.total_trials,
        }
    )

    # Trial 1: the user-supplied baseline proposal.
    state = store.load_state(study.study_id)
    baseline = parse_proposal_text(baseline_path.read_text(encoding="utf-8"), study)
    baseline_name = baseline_path.stem
    store.write_proposal(study.study_id, baseline_name, baseline)
    _run_and_report(store, study, state, baseline, baseline_name)

    # Remaining trials are proposed by the LLM (total_trials >= 1 was checked above).
    remaining_trials = args.total_trials - 1
    if remaining_trials > 0:
        # load_trace_requests only receives the study and spec path, neither of
        # which changes between trials, so the window summary is computed once.
        # NOTE(review): assumes the trace data does not change during the run.
        window, requests = load_trace_requests(study, study_spec_path=spec_path)
        window_summary = summarize_window(requests, window)
        proposals_dir = store.study_root(study.study_id) / "proposals"
        proposals_dir.mkdir(parents=True, exist_ok=True)

        for _ in range(remaining_trials):
            # Reload each iteration: the previous trial advanced the state.
            state = store.load_state(study.study_id)
            index = state.next_trial_index
            prompt = build_prompt(
                study=study,
                window_summary=window_summary,
                state=state,
                capability_profile=capability_profile,
            )
            store.write_prompt(study.study_id, f"prompt-{index:04d}", prompt)

            proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
            proposal_name = f"proposal-{index:04d}"
            # Save the raw LLM reply before parsing so it is kept even if parsing fails.
            raw_proposal_path = proposals_dir / f"{proposal_name}.raw.txt"
            raw_proposal_path.write_text(proposal_text, encoding="utf-8")
            proposal = parse_proposal_text(proposal_text, study)
            store.write_proposal(study.study_id, proposal_name, proposal)
            _run_and_report(store, study, state, proposal, proposal_name)

    final_state = store.load_state(study.study_id)
    _emit(
        {
            "event": "study_finished",
            "study_root": str(study_root),
            "best_trial_id": final_state.best_trial_id,
            "best_request_rate": final_state.best_request_rate,
            "best_request_rate_per_gpu": final_state.best_request_rate_per_gpu,
            "trial_count": len(final_state.trials),
        }
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())