Add llm-first tuning proposal policy
This commit is contained in:
@@ -288,6 +288,7 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
capability_profile = load_capability_profile(study, study_spec_path=spec_path)
|
capability_profile = load_capability_profile(study, study_spec_path=spec_path)
|
||||||
proposal_files = [Path(item).resolve() for item in (args.proposal_file or [])]
|
proposal_files = [Path(item).resolve() for item in (args.proposal_file or [])]
|
||||||
max_trials = args.max_trials or (len(proposal_files) if proposal_files else 2)
|
max_trials = args.max_trials or (len(proposal_files) if proposal_files else 2)
|
||||||
|
proposal_policy = args.proposal_policy
|
||||||
if max_trials <= 0:
|
if max_trials <= 0:
|
||||||
raise SpecError("max_trials must be positive")
|
raise SpecError("max_trials must be positive")
|
||||||
if proposal_files and max_trials > len(proposal_files):
|
if proposal_files and max_trials > len(proposal_files):
|
||||||
@@ -387,7 +388,7 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
else:
|
else:
|
||||||
guided_proposal = (
|
guided_proposal = (
|
||||||
build_harness_guided_proposal(harness_context)
|
build_harness_guided_proposal(harness_context)
|
||||||
if harness_context is not None
|
if harness_context is not None and proposal_policy == "harness-first"
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
if guided_proposal is not None:
|
if guided_proposal is not None:
|
||||||
@@ -782,6 +783,15 @@ def build_parser() -> argparse.ArgumentParser:
|
|||||||
tune.add_argument("--store-root")
|
tune.add_argument("--store-root")
|
||||||
tune.add_argument("--proposal-file", action="append")
|
tune.add_argument("--proposal-file", action="append")
|
||||||
tune.add_argument("--max-trials", type=int)
|
tune.add_argument("--max-trials", type=int)
|
||||||
|
tune.add_argument(
|
||||||
|
"--proposal-policy",
|
||||||
|
choices=("harness-first", "llm-first"),
|
||||||
|
default="harness-first",
|
||||||
|
help=(
|
||||||
|
"Choose whether deterministic harness proposals are tried before the LLM "
|
||||||
|
"or whether the LLM proposes directly from the harness prompt/context."
|
||||||
|
),
|
||||||
|
)
|
||||||
tune.add_argument(
|
tune.add_argument(
|
||||||
"--skip-baseline",
|
"--skip-baseline",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -6323,6 +6323,99 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(state.tuning_stop_diagnosis)
|
self.assertTrue(state.tuning_stop_diagnosis)
|
||||||
|
|
||||||
|
def test_cli_tune_llm_first_skips_deterministic_harness_proposal(self) -> None:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
tmp_path = Path(tmp)
|
||||||
|
study_path = _write_study_assets(tmp_path)
|
||||||
|
payload = json.loads(study_path.read_text(encoding="utf-8"))
|
||||||
|
payload["llm"]["endpoint"] = {
|
||||||
|
"provider": "custom",
|
||||||
|
"base_url": "http://llm.example/v1",
|
||||||
|
"wire_api": "chat.completions",
|
||||||
|
"model": "test-model",
|
||||||
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
|
}
|
||||||
|
study_path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
study = load_study_spec(study_path)
|
||||||
|
store_root = tmp_path / "store"
|
||||||
|
store = StudyStore(store_root)
|
||||||
|
store.init_study(spec_path=study_path, study=study)
|
||||||
|
store.save_state(
|
||||||
|
StudyState(
|
||||||
|
study_id=study.study_id,
|
||||||
|
best_trial_id="trial-0001",
|
||||||
|
best_parallel_size=8,
|
||||||
|
best_sampling_u=0.25,
|
||||||
|
best_request_rate=1.0,
|
||||||
|
best_request_rate_per_gpu=0.125,
|
||||||
|
next_trial_index=2,
|
||||||
|
trials=[
|
||||||
|
TrialSummary(
|
||||||
|
trial_id="trial-0001",
|
||||||
|
status="completed",
|
||||||
|
parallel_size=8,
|
||||||
|
best_request_rate=1.0,
|
||||||
|
best_request_rate_per_gpu=0.125,
|
||||||
|
config_patch={"env_patch": {}, "flag_patch": {}},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_payload = json.dumps(
|
||||||
|
{
|
||||||
|
"observation": "Use harness evidence but let the LLM choose.",
|
||||||
|
"diagnosis": "Try higher admission concurrency.",
|
||||||
|
"config_patch": {"env_patch": {}, "flag_patch": {"max-num-seqs": 64}},
|
||||||
|
"expected_effects": ["measure admission concurrency"],
|
||||||
|
"why_not_previous_failures": "does not repeat a prior full config",
|
||||||
|
"should_stop": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
|
||||||
|
payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
|
||||||
|
trial_root = Path(payload["artifact_dir"])
|
||||||
|
result = {
|
||||||
|
"study_id": payload["study_id"],
|
||||||
|
"trial_id": payload["trial_id"],
|
||||||
|
"status": "completed",
|
||||||
|
"best_sampling_u": 0.5,
|
||||||
|
"best_request_rate": 2.0,
|
||||||
|
"best_pass_rate": 1.0,
|
||||||
|
"best_request_count": 2,
|
||||||
|
"probes": [],
|
||||||
|
}
|
||||||
|
(trial_root / "result.json").write_text(json.dumps(result), encoding="utf-8")
|
||||||
|
return result
|
||||||
|
|
||||||
|
with mock.patch("aituner.cli.call_llm_for_proposal", return_value=llm_payload) as llm_mock:
|
||||||
|
with mock.patch("aituner.cli.run_trial", side_effect=fake_run_trial):
|
||||||
|
exit_code = cli_main(
|
||||||
|
[
|
||||||
|
"study",
|
||||||
|
"tune",
|
||||||
|
"--spec",
|
||||||
|
str(study_path),
|
||||||
|
"--store-root",
|
||||||
|
str(store_root),
|
||||||
|
"--skip-baseline",
|
||||||
|
"--max-trials",
|
||||||
|
"2",
|
||||||
|
"--proposal-policy",
|
||||||
|
"llm-first",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(exit_code, 0)
|
||||||
|
llm_mock.assert_called_once()
|
||||||
|
proposal_root = store.study_root(study.study_id) / "proposals"
|
||||||
|
self.assertTrue((proposal_root / "proposal-0002.json").exists())
|
||||||
|
self.assertFalse((proposal_root / "harness-proposal-0002.json").exists())
|
||||||
|
self.assertTrue(
|
||||||
|
(store.study_root(study.study_id) / "harness" / "candidate-set-0002.json").exists()
|
||||||
|
)
|
||||||
|
|
||||||
def test_cli_tune_evaluates_baseline_before_llm_proposal(self) -> None:
|
def test_cli_tune_evaluates_baseline_before_llm_proposal(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
|
|||||||
Reference in New Issue
Block a user