# Files
# aituner/scripts/run_baseline_then_llm.py
# 143 lines, 5.3 KiB, Python
from __future__ import annotations
import argparse
import json
from pathlib import Path
from aituner.llm import (
build_prompt,
call_llm_for_proposal,
load_capability_profile,
parse_proposal_text,
)
from aituner.spec import load_study_spec
from aituner.store import StudyStore
from aituner.trace import load_trace_requests, summarize_window
from aituner.worker import run_trial
def build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser for this script.

    Required flags: --spec, --store-root, --baseline-proposal.
    Optional: --total-trials (int, default 12).
    """
    cli = argparse.ArgumentParser(
        description="Run one baseline trial followed by LLM-proposed trials."
    )
    for flag in ("--spec", "--store-root", "--baseline-proposal"):
        cli.add_argument(flag, required=True)
    cli.add_argument("--total-trials", type=int, default=12)
    return cli
def _emit_event(payload: dict) -> None:
    """Print one JSON event line to stdout and flush immediately."""
    print(json.dumps(payload, ensure_ascii=False), flush=True)


def _emit_trial_completed(trial, source: str, result: dict, state) -> None:
    """Emit the 'trial_completed' event shared by baseline and LLM trials.

    `trial`/`state` are project objects (StudyStore trial record and study
    state); only the attributes read below are required.
    """
    _emit_event(
        {
            "event": "trial_completed",
            "trial_id": trial.trial_id,
            "source": source,
            "status": result.get("status"),
            "best_sampling_u": result.get("best_sampling_u"),
            "best_request_rate": result.get("best_request_rate"),
            "best_pass_rate": result.get("best_pass_rate"),
            "state_best_trial_id": state.best_trial_id,
            "state_best_request_rate": state.best_request_rate,
            "state_best_request_rate_per_gpu": state.best_request_rate_per_gpu,
        }
    )


def main() -> int:
    """Run one baseline trial, then LLM-proposed trials up to --total-trials.

    Returns 0 on success. Raises SystemExit when --total-trials is not
    positive (argparse itself exits on missing required flags).
    """
    args = build_parser().parse_args()
    spec_path = Path(args.spec).resolve()
    store_root = Path(args.store_root).resolve()
    baseline_path = Path(args.baseline_proposal).resolve()
    if args.total_trials <= 0:
        raise SystemExit("--total-trials must be positive")

    study = load_study_spec(spec_path)
    store = StudyStore(store_root)
    study_root = store.init_study(spec_path=spec_path, study=study)
    capability_profile = load_capability_profile(study, study_spec_path=spec_path)
    _emit_event(
        {
            "event": "study_initialized",
            "study_root": str(study_root),
            "study_id": study.study_id,
            "total_trials": args.total_trials,
        }
    )

    # Baseline trial: the proposal comes from a user-supplied file, not the LLM.
    state = store.load_state(study.study_id)
    baseline_text = baseline_path.read_text(encoding="utf-8")
    baseline = parse_proposal_text(baseline_text, study)
    baseline_name = baseline_path.stem
    store.write_proposal(study.study_id, baseline_name, baseline)
    trial, _ = store.materialize_trial(study=study, state=state, proposal=baseline)
    result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
    state = store.ingest_trial_results(study.study_id)
    _emit_trial_completed(trial, baseline_name, result, state)

    # Remaining trials: each iteration re-reads persisted state, builds a
    # prompt from the trace window, and asks the LLM for the next proposal.
    # (total_trials >= 1 is guaranteed above, so this range is never negative.)
    for _ in range(args.total_trials - 1):
        state = store.load_state(study.study_id)
        window, requests = load_trace_requests(study, study_spec_path=spec_path)
        prompt = build_prompt(
            study=study,
            window_summary=summarize_window(requests, window),
            state=state,
            capability_profile=capability_profile,
        )
        prompt_name = f"prompt-{state.next_trial_index:04d}"
        store.write_prompt(study.study_id, prompt_name, prompt)
        proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
        proposal_name = f"proposal-{state.next_trial_index:04d}"
        # Persist the raw LLM output alongside the parsed proposal so a
        # parse failure (or a bad proposal) can be diagnosed afterwards.
        raw_proposal_path = (
            store.study_root(study.study_id) / "proposals" / f"{proposal_name}.raw.txt"
        )
        raw_proposal_path.write_text(proposal_text, encoding="utf-8")
        proposal = parse_proposal_text(proposal_text, study)
        store.write_proposal(study.study_id, proposal_name, proposal)
        trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
        result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
        state = store.ingest_trial_results(study.study_id)
        _emit_trial_completed(trial, proposal_name, result, state)

    final_state = store.load_state(study.study_id)
    _emit_event(
        {
            "event": "study_finished",
            "study_root": str(study_root),
            "best_trial_id": final_state.best_trial_id,
            "best_request_rate": final_state.best_request_rate,
            "best_request_rate_per_gpu": final_state.best_request_rate_per_gpu,
            "trial_count": len(final_state.trials),
        }
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())