Add llm-first tuning proposal policy

This commit is contained in:
2026-06-27 12:21:51 +08:00
parent 9accf2575e
commit 7ad439730e
2 changed files with 104 additions and 1 deletions

View File

@@ -288,6 +288,7 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
capability_profile = load_capability_profile(study, study_spec_path=spec_path)
proposal_files = [Path(item).resolve() for item in (args.proposal_file or [])]
max_trials = args.max_trials or (len(proposal_files) if proposal_files else 2)
proposal_policy = args.proposal_policy
if max_trials <= 0:
raise SpecError("max_trials must be positive")
if proposal_files and max_trials > len(proposal_files):
@@ -387,7 +388,7 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
else:
guided_proposal = (
build_harness_guided_proposal(harness_context)
if harness_context is not None
if harness_context is not None and proposal_policy == "harness-first"
else None
)
if guided_proposal is not None:
@@ -782,6 +783,15 @@ def build_parser() -> argparse.ArgumentParser:
tune.add_argument("--store-root")
tune.add_argument("--proposal-file", action="append")
tune.add_argument("--max-trials", type=int)
tune.add_argument(
"--proposal-policy",
choices=("harness-first", "llm-first"),
default="harness-first",
help=(
"Choose whether deterministic harness proposals are tried before the LLM "
"or whether the LLM proposes directly from the harness prompt/context."
),
)
tune.add_argument(
"--skip-baseline",
action="store_true",

View File

@@ -6323,6 +6323,99 @@ class CoreFlowTests(unittest.TestCase):
)
self.assertTrue(state.tuning_stop_diagnosis)
def test_cli_tune_llm_first_skips_deterministic_harness_proposal(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
payload = json.loads(study_path.read_text(encoding="utf-8"))
payload["llm"]["endpoint"] = {
"provider": "custom",
"base_url": "http://llm.example/v1",
"wire_api": "chat.completions",
"model": "test-model",
"api_key_env": "OPENAI_API_KEY",
}
study_path.write_text(json.dumps(payload), encoding="utf-8")
study = load_study_spec(study_path)
store_root = tmp_path / "store"
store = StudyStore(store_root)
store.init_study(spec_path=study_path, study=study)
store.save_state(
StudyState(
study_id=study.study_id,
best_trial_id="trial-0001",
best_parallel_size=8,
best_sampling_u=0.25,
best_request_rate=1.0,
best_request_rate_per_gpu=0.125,
next_trial_index=2,
trials=[
TrialSummary(
trial_id="trial-0001",
status="completed",
parallel_size=8,
best_request_rate=1.0,
best_request_rate_per_gpu=0.125,
config_patch={"env_patch": {}, "flag_patch": {}},
)
],
)
)
llm_payload = json.dumps(
{
"observation": "Use harness evidence but let the LLM choose.",
"diagnosis": "Try higher admission concurrency.",
"config_patch": {"env_patch": {}, "flag_patch": {"max-num-seqs": 64}},
"expected_effects": ["measure admission concurrency"],
"why_not_previous_failures": "does not repeat a prior full config",
"should_stop": False,
}
)
def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
trial_root = Path(payload["artifact_dir"])
result = {
"study_id": payload["study_id"],
"trial_id": payload["trial_id"],
"status": "completed",
"best_sampling_u": 0.5,
"best_request_rate": 2.0,
"best_pass_rate": 1.0,
"best_request_count": 2,
"probes": [],
}
(trial_root / "result.json").write_text(json.dumps(result), encoding="utf-8")
return result
with mock.patch("aituner.cli.call_llm_for_proposal", return_value=llm_payload) as llm_mock:
with mock.patch("aituner.cli.run_trial", side_effect=fake_run_trial):
exit_code = cli_main(
[
"study",
"tune",
"--spec",
str(study_path),
"--store-root",
str(store_root),
"--skip-baseline",
"--max-trials",
"2",
"--proposal-policy",
"llm-first",
]
)
self.assertEqual(exit_code, 0)
llm_mock.assert_called_once()
proposal_root = store.study_root(study.study_id) / "proposals"
self.assertTrue((proposal_root / "proposal-0002.json").exists())
self.assertFalse((proposal_root / "harness-proposal-0002.json").exists())
self.assertTrue(
(store.study_root(study.study_id) / "harness" / "candidate-set-0002.json").exists()
)
def test_cli_tune_evaluates_baseline_before_llm_proposal(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)