Stop tuning when baseline is infeasible
This commit is contained in:
@@ -19,6 +19,43 @@ from .trace import load_trace_requests, summarize_window
|
|||||||
from .worker import run_trial
|
from .worker import run_trial
|
||||||
|
|
||||||
|
|
||||||
|
def _is_empty_config_patch(proposal: Proposal) -> bool:
    """Report whether *proposal* carries neither env nor flag overrides.

    Returns True only when both ``config_patch.env_patch`` and
    ``config_patch.flag_patch`` are falsy (e.g. empty mappings).
    """
    patch = proposal.config_patch
    # Any env override already makes the patch non-empty; skip the flag check.
    if patch.env_patch:
        return False
    return not patch.flag_patch
|
||||||
|
|
||||||
|
|
||||||
|
def _is_real_number(value: object) -> bool:
    """Return True for int/float values, excluding bool (an int subclass)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _baseline_all_infeasible_diagnosis(result: dict[str, object]) -> str | None:
    """Build a stop diagnosis when a completed trial had no feasible probe.

    Args:
        result: Trial result mapping. The keys read here are ``status``,
            ``best_request_rate``, ``probes`` and
            ``all_infeasible_diagnostics``.

    Returns:
        ``None`` when the result does not describe an all-infeasible run:
        status is not ``"completed"``, a numeric ``best_request_rate`` is
        present, ``probes`` is missing/empty/not a list, or any probe dict is
        marked feasible. Otherwise a single space-joined message made of two
        fixed sentences followed by whichever ``key=value`` details are
        available in ``all_infeasible_diagnostics``.
    """
    if result.get("status") != "completed":
        return None
    # A real numeric best rate means at least one rate was feasible. Booleans
    # are rejected explicitly: bool subclasses int, so a stray ``True`` would
    # otherwise be mistaken for a rate and suppress the stop.
    if _is_real_number(result.get("best_request_rate")):
        return None
    probes = result.get("probes")
    if not isinstance(probes, list) or not probes:
        return None
    if any(isinstance(probe, dict) and probe.get("feasible") for probe in probes):
        return None

    diagnostics = result.get("all_infeasible_diagnostics")
    if not isinstance(diagnostics, dict):
        # Diagnostics are optional; fall back to the bare message.
        diagnostics = {}

    pieces = [
        "Baseline configuration has no feasible probe under the current SLO.",
        "Stopping tuning because even the lowest sampled request rate did not meet the target pass rate.",
    ]
    # (label emitted in the message, key read from the diagnostics mapping)
    numeric_details = (
        ("lowest_sampled_request_rate", "request_rate"),
        ("lowest_sampling_u", "threshold"),
        ("lowest_probe_pass_rate", "pass_rate"),
    )
    for label, key in numeric_details:
        value = diagnostics.get(key)
        if _is_real_number(value):
            pieces.append(f"{label}={float(value):.6g}")

    early_stop_reason = str(diagnostics.get("early_stop_reason") or "").strip()
    if early_stop_reason:
        pieces.append(f"early_stop_reason={early_stop_reason}")
    return " ".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
def _study_source_path(study_root: Path) -> Path:
    """Return the path recorded in ``study_spec.source`` under *study_root*.

    The file is read as UTF-8 and surrounding whitespace is stripped before
    the text is wrapped in a ``Path``.
    """
    pointer_file = study_root / "study_spec.source"
    recorded = pointer_file.read_text(encoding="utf-8")
    return Path(recorded.strip())
|
||||||
|
|
||||||
@@ -126,6 +163,18 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
executed: list[dict[str, object]] = []
|
executed: list[dict[str, object]] = []
|
||||||
for idx in range(max_trials):
|
for idx in range(max_trials):
|
||||||
state = store.load_state(study.study_id)
|
state = store.load_state(study.study_id)
|
||||||
|
if state.tuning_stop_reason:
|
||||||
|
executed.append(
|
||||||
|
{
|
||||||
|
"trial_id": None,
|
||||||
|
"stopped": True,
|
||||||
|
"reason": state.tuning_stop_reason,
|
||||||
|
"diagnosis": state.tuning_stop_diagnosis,
|
||||||
|
"state_best_trial_id": state.best_trial_id,
|
||||||
|
"state_best_request_rate": state.best_request_rate,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
break
|
||||||
if state.next_trial_index > max_trials:
|
if state.next_trial_index > max_trials:
|
||||||
break
|
break
|
||||||
window, requests = load_trace_requests(study, study_spec_path=spec_path)
|
window, requests = load_trace_requests(study, study_spec_path=spec_path)
|
||||||
@@ -228,6 +277,13 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
is_auto_baseline = (
|
||||||
|
not proposal_files
|
||||||
|
and not args.skip_baseline
|
||||||
|
and state.next_trial_index == 1
|
||||||
|
and not state.trials
|
||||||
|
and _is_empty_config_patch(proposal)
|
||||||
|
)
|
||||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||||
trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
|
trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
|
||||||
result = run_trial(trial_spec_path)
|
result = run_trial(trial_spec_path)
|
||||||
@@ -248,6 +304,23 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
"state_best_request_rate": state.best_request_rate,
|
"state_best_request_rate": state.best_request_rate,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if is_auto_baseline:
|
||||||
|
diagnosis = _baseline_all_infeasible_diagnosis(result)
|
||||||
|
if diagnosis is not None:
|
||||||
|
state.tuning_stop_reason = "baseline_all_infeasible"
|
||||||
|
state.tuning_stop_diagnosis = diagnosis
|
||||||
|
store.save_state(state)
|
||||||
|
executed.append(
|
||||||
|
{
|
||||||
|
"trial_id": None,
|
||||||
|
"stopped": True,
|
||||||
|
"reason": state.tuning_stop_reason,
|
||||||
|
"diagnosis": diagnosis,
|
||||||
|
"state_best_trial_id": state.best_trial_id,
|
||||||
|
"state_best_request_rate": state.best_request_rate,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
final_state = store.load_state(study.study_id)
|
final_state = store.load_state(study.study_id)
|
||||||
print(
|
print(
|
||||||
@@ -257,6 +330,8 @@ def cmd_study_tune(args: argparse.Namespace) -> int:
|
|||||||
"executed_trials": executed,
|
"executed_trials": executed,
|
||||||
"best_trial_id": final_state.best_trial_id,
|
"best_trial_id": final_state.best_trial_id,
|
||||||
"best_request_rate": final_state.best_request_rate,
|
"best_request_rate": final_state.best_request_rate,
|
||||||
|
"tuning_stop_reason": final_state.tuning_stop_reason,
|
||||||
|
"tuning_stop_diagnosis": final_state.tuning_stop_diagnosis,
|
||||||
},
|
},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -764,6 +764,8 @@ class StudyState:
|
|||||||
best_request_rate: float | None = None
|
best_request_rate: float | None = None
|
||||||
best_request_rate_per_gpu: float | None = None
|
best_request_rate_per_gpu: float | None = None
|
||||||
next_trial_index: int = 1
|
next_trial_index: int = 1
|
||||||
|
tuning_stop_reason: str = ""
|
||||||
|
tuning_stop_diagnosis: str = ""
|
||||||
best_by_parallel_size: dict[str, dict[str, Any]] = field(default_factory=dict)
|
best_by_parallel_size: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
trials: list[TrialSummary] = field(default_factory=list)
|
trials: list[TrialSummary] = field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ class StudyStore:
|
|||||||
best_request_rate=payload.get("best_request_rate"),
|
best_request_rate=payload.get("best_request_rate"),
|
||||||
best_request_rate_per_gpu=payload.get("best_request_rate_per_gpu"),
|
best_request_rate_per_gpu=payload.get("best_request_rate_per_gpu"),
|
||||||
next_trial_index=int(payload.get("next_trial_index", 1)),
|
next_trial_index=int(payload.get("next_trial_index", 1)),
|
||||||
|
tuning_stop_reason=str(payload.get("tuning_stop_reason") or ""),
|
||||||
|
tuning_stop_diagnosis=str(payload.get("tuning_stop_diagnosis") or ""),
|
||||||
best_by_parallel_size={
|
best_by_parallel_size={
|
||||||
str(key): value
|
str(key): value
|
||||||
for key, value in (payload.get("best_by_parallel_size") or {}).items()
|
for key, value in (payload.get("best_by_parallel_size") or {}).items()
|
||||||
|
|||||||
@@ -2997,6 +2997,97 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
||||||
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
|
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
|
||||||
|
|
||||||
|
def test_cli_tune_stops_when_baseline_is_all_infeasible(self) -> None:
    """Tuning halts after an all-infeasible first trial and stays halted.

    First invocation: the stubbed trial reports no feasible probe, so the
    run must exit 0 without consulting the LLM and persist the stop reason
    and diagnosis. Second invocation against the same store: neither the
    trial runner nor the LLM may be called.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        study_path = _write_study_assets(tmp_path)

        # Point the study at a custom LLM endpoint before running the CLI.
        spec_payload = json.loads(study_path.read_text(encoding="utf-8"))
        spec_payload["llm"]["endpoint"] = {
            "provider": "custom",
            "base_url": "http://llm.example/v1",
            "wire_api": "chat.completions",
            "model": "test-model",
            "api_key_env": "OPENAI_API_KEY",
        }
        study_path.write_text(json.dumps(spec_payload), encoding="utf-8")
        store_root = tmp_path / "store"

        # Same command line for both invocations.
        tune_argv = [
            "study",
            "tune",
            "--spec",
            str(study_path),
            "--store-root",
            str(store_root),
            "--max-trials",
            "3",
        ]

        def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
            # Stand-in for the worker: every probe is infeasible and no best
            # rate is reported; the result is also written to result.json.
            trial_spec = json.loads(trial_spec_path.read_text(encoding="utf-8"))
            trial_root = Path(trial_spec["artifact_dir"])
            result = {
                "study_id": trial_spec["study_id"],
                "trial_id": trial_spec["trial_id"],
                "status": "completed",
                "best_sampling_u": None,
                "best_request_rate": None,
                "best_pass_rate": None,
                "best_request_count": None,
                "probes": [
                    {
                        "threshold": 0.5,
                        "feasible": False,
                        "payload": {"pass_rate": 0.0, "request_rate": 2.0},
                    },
                    {
                        "threshold": 0.25,
                        "feasible": False,
                        "payload": {"pass_rate": 0.5, "request_rate": 1.0},
                    },
                ],
                "all_infeasible_diagnostics": {
                    "threshold": 0.25,
                    "request_rate": 1.0,
                    "pass_rate": 0.5,
                    "early_stop_reason": "slo_pass_rate_unrecoverable",
                },
            }
            (trial_root / "result.json").write_text(json.dumps(result), encoding="utf-8")
            return result

        # --- First run: the first trial executes and triggers the stop. ---
        trial_patch = mock.patch("aituner.cli.run_trial", side_effect=fake_run_trial)
        llm_patch = mock.patch("aituner.cli.call_llm_for_proposal")
        with trial_patch, llm_patch as llm_mock:
            exit_code = cli_main(tune_argv)

        self.assertEqual(exit_code, 0)
        llm_mock.assert_not_called()
        store = StudyStore(store_root)
        state = store.load_state("study-1")
        self.assertEqual(state.next_trial_index, 2)
        self.assertEqual(len(state.trials), 1)
        self.assertEqual(state.tuning_stop_reason, "baseline_all_infeasible")
        self.assertIn("lowest_sampled_request_rate=1", state.tuning_stop_diagnosis)

        # --- Second run: the persisted stop must short-circuit everything. ---
        trial_patch = mock.patch("aituner.cli.run_trial")
        llm_patch = mock.patch("aituner.cli.call_llm_for_proposal")
        with trial_patch as run_trial_mock, llm_patch as llm_mock:
            exit_code = cli_main(tune_argv)

        self.assertEqual(exit_code, 0)
        run_trial_mock.assert_not_called()
        llm_mock.assert_not_called()
|
||||||
|
|
||||||
def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
|
def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
|
|||||||
Reference in New Issue
Block a user