Make tune trial budget resumable
This commit is contained in:
@@ -126,6 +126,8 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
executed: list[dict[str, object]] = []
|
executed: list[dict[str, object]] = []
|
||||||
for idx in range(max_trials):
|
for idx in range(max_trials):
|
||||||
state = store.load_state(study.study_id)
|
state = store.load_state(study.study_id)
|
||||||
|
if state.next_trial_index > max_trials:
|
||||||
|
break
|
||||||
window, requests = load_trace_requests(study, study_spec_path=spec_path)
|
window, requests = load_trace_requests(study, study_spec_path=spec_path)
|
||||||
window_summary = summarize_window(requests, window)
|
window_summary = summarize_window(requests, window)
|
||||||
harness_context = (
|
harness_context = (
|
||||||
@@ -169,7 +171,10 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
elif proposal_files:
|
elif proposal_files:
|
||||||
proposal_source = proposal_files[idx]
|
proposal_index = state.next_trial_index - 1
|
||||||
|
if proposal_index >= len(proposal_files):
|
||||||
|
break
|
||||||
|
proposal_source = proposal_files[proposal_index]
|
||||||
proposal_text = proposal_source.read_text(encoding="utf-8")
|
proposal_text = proposal_source.read_text(encoding="utf-8")
|
||||||
proposal_name = proposal_source.stem
|
proposal_name = proposal_source.stem
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -604,7 +605,8 @@ def call_llm_for_proposal(
|
|||||||
if policy.endpoint is None:
|
if policy.endpoint is None:
|
||||||
raise RuntimeError("study.llm.endpoint is not configured")
|
raise RuntimeError("study.llm.endpoint is not configured")
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
for attempt in range(2):
|
max_attempts = 4
|
||||||
|
for attempt in range(max_attempts):
|
||||||
try:
|
try:
|
||||||
if policy.endpoint.stream:
|
if policy.endpoint.stream:
|
||||||
text = stream_text_completion(
|
text = stream_text_completion(
|
||||||
@@ -636,6 +638,7 @@ def call_llm_for_proposal(
|
|||||||
last_error = RuntimeError("LLM response content is empty")
|
last_error = RuntimeError("LLM response content is empty")
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
last_error = exc
|
last_error = exc
|
||||||
if attempt == 0:
|
if attempt < max_attempts - 1:
|
||||||
|
time.sleep(min(30.0, 2.0 * (2**attempt)))
|
||||||
continue
|
continue
|
||||||
raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error
|
raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error
|
||||||
|
|||||||
@@ -2919,7 +2919,7 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
"--store-root",
|
"--store-root",
|
||||||
str(store_root),
|
str(store_root),
|
||||||
"--max-trials",
|
"--max-trials",
|
||||||
"1",
|
"5",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2997,6 +2997,53 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
||||||
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
|
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
|
||||||
|
|
||||||
|
def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
tmp_path = Path(tmp)
|
||||||
|
study_path = _write_study_assets(tmp_path)
|
||||||
|
payload = json.loads(study_path.read_text(encoding="utf-8"))
|
||||||
|
payload["llm"]["endpoint"] = {
|
||||||
|
"provider": "custom",
|
||||||
|
"base_url": "http://llm.example/v1",
|
||||||
|
"wire_api": "chat.completions",
|
||||||
|
"model": "test-model",
|
||||||
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
|
}
|
||||||
|
study_path.write_text(json.dumps(payload), encoding="utf-8")
|
||||||
|
store_root = tmp_path / "store"
|
||||||
|
study = load_study_spec(study_path)
|
||||||
|
store = StudyStore(store_root)
|
||||||
|
store.init_study(spec_path=study_path, study=study)
|
||||||
|
state = StudyState(
|
||||||
|
study_id=study.study_id,
|
||||||
|
next_trial_index=3,
|
||||||
|
trials=[
|
||||||
|
TrialSummary(trial_id="trial-0001", status="completed"),
|
||||||
|
TrialSummary(trial_id="trial-0002", status="completed"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
store.save_state(state)
|
||||||
|
|
||||||
|
with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
|
||||||
|
with mock.patch("aituner.cli.run_trial") as run_trial_mock:
|
||||||
|
exit_code = cli_main(
|
||||||
|
[
|
||||||
|
"study",
|
||||||
|
"tune",
|
||||||
|
"--spec",
|
||||||
|
str(study_path),
|
||||||
|
"--store-root",
|
||||||
|
str(store_root),
|
||||||
|
"--max-trials",
|
||||||
|
"2",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(exit_code, 0)
|
||||||
|
llm_mock.assert_not_called()
|
||||||
|
run_trial_mock.assert_not_called()
|
||||||
|
self.assertEqual(store.load_state(study.study_id).next_trial_index, 3)
|
||||||
|
|
||||||
def test_load_compare_spec_requires_window_selection(self) -> None:
|
def test_load_compare_spec_requires_window_selection(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
|
|||||||
Reference in New Issue
Block a user