Record failed trial context

2026-04-04 23:35:07 +08:00
parent 8b024c72f1
commit 7632de8dad
5 changed files with 141 additions and 11 deletions

View File

@@ -25,6 +25,8 @@ def build_prompt(
                 "best_request_rate": trial.best_request_rate,
                 "best_pass_rate": trial.best_pass_rate,
                 "diagnosis": trial.diagnosis,
+                "config_patch": trial.config_patch,
+                "failure_reason": trial.failure_reason,
             }
         )
     sections = [
@@ -34,6 +36,7 @@ def build_prompt(
         "expected_effects must be a JSON array of short strings, not an object.",
         "Only use allowed tunable env keys and allowed tunable flag keys.",
         "Do not wrap the JSON in markdown fences or any extra text.",
+        "Do not repeat a config that previously failed to launch unless the new patch explicitly removes the failing knob.",
        "",
        "Study stack:",
        json.dumps(

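Note: with the two new fields, the per-trial payload that build_prompt feeds to json.dumps for a failed trial looks roughly like this (a sketch; the concrete values mirror the test fixture added below, and the surrounding fields are illustrative):

    {
        "trial_id": "trial-0001",
        "status": "failed",
        "best_request_rate": None,
        "best_pass_rate": None,
        "diagnosis": "flashinfer looked promising",
        "config_patch": {
            "env_patch": {"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
            "flag_patch": {"tensor-parallel-size": 4},
        },
        "failure_reason": "engine_process_exited_before_ready exit_code=1",
    }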
View File

@@ -446,6 +446,8 @@ class TrialSummary:
     best_pass_rate: float | None = None
     result_path: str | None = None
     diagnosis: str = ""
+    config_patch: dict[str, Any] | None = None
+    failure_reason: str = ""


 @dataclass

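Note: both new fields carry defaults, so state files persisted before this commit should still deserialize; a minimal sketch of the assumed round-trip:

    # Legacy summaries lack the new keys, so the dataclass defaults apply.
    legacy = TrialSummary(trial_id="trial-0000", status="succeeded")
    assert legacy.config_patch is None
    assert legacy.failure_reason == ""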
View File

@@ -74,7 +74,12 @@ class StudyStore:
self.write_json(trial_root / "trial_spec.json", to_jsonable(spec)) self.write_json(trial_root / "trial_spec.json", to_jsonable(spec))
next_state = replace(state, next_trial_index=state.next_trial_index + 1) next_state = replace(state, next_trial_index=state.next_trial_index + 1)
next_state.trials.append( next_state.trials.append(
TrialSummary(trial_id=trial_id, status="queued", diagnosis=proposal.diagnosis) TrialSummary(
trial_id=trial_id,
status="queued",
diagnosis=proposal.diagnosis,
config_patch=to_jsonable(proposal.config_patch),
)
) )
self.save_state(next_state) self.save_state(next_state)
return spec, next_state return spec, next_state
@@ -101,6 +106,7 @@ class StudyStore:
summary.best_request_rate = payload.get("best_request_rate") summary.best_request_rate = payload.get("best_request_rate")
summary.best_pass_rate = payload.get("best_pass_rate") summary.best_pass_rate = payload.get("best_pass_rate")
summary.result_path = str(result_path) summary.result_path = str(result_path)
summary.failure_reason = str(payload.get("failure_reason") or "").strip()
if ( if (
isinstance(summary.best_request_rate, (int, float)) isinstance(summary.best_request_rate, (int, float))
and (best_rate is None or summary.best_request_rate > best_rate) and (best_rate is None or summary.best_request_rate > best_rate)

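Note: the 'or ""' guard plus strip() normalizes absent, null, and padded values alike; a quick sketch of the expression's behavior:

    str({}.get("failure_reason") or "").strip()                             # missing key -> ""
    str({"failure_reason": None}.get("failure_reason") or "").strip()       # null -> ""
    str({"failure_reason": " boom "}.get("failure_reason") or "").strip()   # -> "boom"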
View File

@@ -183,6 +183,28 @@ def _replay_requests(
     return ordered, early_stopped, early_stop_reason


+def _wait_for_server_or_exit(
+    process: subprocess.Popen[str],
+    *,
+    base_url: str,
+    healthcheck_path: str,
+    ready_timeout_s: float,
+) -> None:
+    deadline = time.monotonic() + ready_timeout_s
+    last_error = "server_not_ready"
+    while time.monotonic() < deadline:
+        exit_code = process.poll()
+        if exit_code is not None:
+            raise RuntimeError(f"engine_process_exited_before_ready exit_code={exit_code}")
+        try:
+            wait_for_server(base_url, healthcheck_path, timeout_s=1.0)
+            return
+        except HttpClientError as exc:
+            last_error = str(exc)
+        time.sleep(1.0)
+    raise HttpClientError(f"Timed out waiting for {base_url}{healthcheck_path}: {last_error}")
+
+
 def run_trial(trial_spec_path: Path) -> dict[str, Any]:
     from .store import StudyStore
@@ -203,9 +225,14 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
         stderr=subprocess.STDOUT,
         text=True,
     )
-    try:
-        wait_for_server(recipe.base_url, recipe.healthcheck_path, recipe.ready_timeout_s)
     probe_history: list[dict[str, Any]] = []
+    try:
+        _wait_for_server_or_exit(
+            process,
+            base_url=recipe.base_url,
+            healthcheck_path=recipe.healthcheck_path,
+            ready_timeout_s=recipe.ready_timeout_s,
+        )

         def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
             selected = select_requests_for_threshold(requests, threshold=threshold)
@@ -297,7 +324,22 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
         }
         StudyStore.write_json(Path(trial.result_path), result)
         return result
+    except Exception as exc:  # noqa: BLE001
+        result = {
+            "study_id": trial.study_id,
+            "trial_id": trial.trial_id,
+            "status": "failed",
+            "best_sampling_u": None,
+            "best_request_rate": None,
+            "best_pass_rate": None,
+            "best_request_count": None,
+            "failure_reason": str(exc),
+            "probes": probe_history,
+        }
+        StudyStore.write_json(Path(trial.result_path), result)
+        return result
     finally:
-        process.terminate()
+        if process.poll() is None:
+            process.terminate()
         try:
             process.wait(timeout=30)

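Note: unlike a plain wait_for_server call, the helper polls the engine process between one-second health probes, so a crashed launch surfaces immediately instead of consuming the full ready timeout. A minimal usage sketch, assuming nothing is serving on 127.0.0.1:8000 and using "false" as an illustrative command that exits with code 1:

    proc = subprocess.Popen(["false"], stdout=subprocess.PIPE, text=True)
    _wait_for_server_or_exit(
        proc,
        base_url="http://127.0.0.1:8000",
        healthcheck_path="/v1/models",
        ready_timeout_s=10.0,
    )  # raises RuntimeError("engine_process_exited_before_ready exit_code=1")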
View File

@@ -13,10 +13,10 @@ from aituner.job import append_job, build_trial_job
 from aituner.llm import build_prompt, parse_proposal_text
 from aituner.search import ThresholdProbe, binary_search_max_feasible
 from aituner.slo import RequestOutcome, summarize_evaluations
-from aituner.spec import Proposal, load_study_spec
+from aituner.spec import Proposal, StudyState, TrialSummary, load_study_spec
 from aituner.store import StudyStore
 from aituner.trace import load_trace_requests, summarize_window
-from aituner.worker import _replay_requests
+from aituner.worker import _replay_requests, _wait_for_server_or_exit
 from aituner.trace import TraceRequest
@@ -159,6 +159,36 @@ class CoreFlowTests(unittest.TestCase):
         self.assertIn("queueing_knee_by_bucket", prompt)
         self.assertTrue(study_root.exists())

+    def test_prompt_includes_failed_trial_context(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            study_path = _write_study_assets(tmp_path)
+            study = load_study_spec(study_path)
+            window, requests = load_trace_requests(study, study_spec_path=study_path)
+            prompt = build_prompt(
+                study=study,
+                window_summary=summarize_window(requests, window),
+                state=StudyState(
+                    study_id=study.study_id,
+                    trials=[
+                        TrialSummary(
+                            trial_id="trial-0001",
+                            status="failed",
+                            diagnosis="flashinfer looked promising",
+                            config_patch={
+                                "env_patch": {"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                                "flag_patch": {"tensor-parallel-size": 4},
+                            },
+                            failure_reason="engine_process_exited_before_ready exit_code=1",
+                        )
+                    ],
+                ),
+                capability_profile=None,
+            )
+            self.assertIn('"status": "failed"', prompt)
+            self.assertIn('"failure_reason": "engine_process_exited_before_ready exit_code=1"', prompt)
+            self.assertIn('"VLLM_ATTENTION_BACKEND": "FLASHINFER"', prompt)
+
     def test_length_only_trace_rows_are_synthesized(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)
@@ -594,6 +624,42 @@ class CoreFlowTests(unittest.TestCase):
         self.assertEqual(next_state.best_trial_id, trial.trial_id)
         self.assertEqual(next_state.best_request_rate, 12.5)

+    def test_ingest_trial_results_records_failure_reason(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            study_path = _write_study_assets(tmp_path)
+            study = load_study_spec(study_path)
+            store = StudyStore(tmp_path / ".aituner" / "studies")
+            store.init_study(spec_path=study_path, study=study)
+            state = store.load_state(study.study_id)
+            proposal = Proposal.from_dict(
+                {
+                    "observation": "Obs",
+                    "diagnosis": "Diag",
+                    "config_patch": {"env_patch": {}, "flag_patch": {"tensor-parallel-size": 4}},
+                    "expected_effects": ["raise rate"],
+                }
+            )
+            trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
+            Path(trial.result_path).write_text(
+                json.dumps(
+                    {
+                        "study_id": study.study_id,
+                        "trial_id": trial.trial_id,
+                        "status": "failed",
+                        "failure_reason": "engine_process_exited_before_ready exit_code=1",
+                        "probes": [],
+                    }
+                ),
+                encoding="utf-8",
+            )
+            next_state = store.ingest_trial_results(study.study_id)
+            self.assertEqual(next_state.trials[0].status, "failed")
+            self.assertEqual(
+                next_state.trials[0].failure_reason,
+                "engine_process_exited_before_ready exit_code=1",
+            )
+
     def test_cli_tune_runs_multiple_manual_proposals(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)
@@ -746,6 +812,17 @@ class CoreFlowTests(unittest.TestCase):
         self.assertEqual(len(replayed), 3)
         self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")

+    def test_wait_for_server_or_exit_fails_fast_when_process_exits(self) -> None:
+        process = mock.Mock()
+        process.poll.return_value = 17
+        with self.assertRaisesRegex(RuntimeError, "engine_process_exited_before_ready exit_code=17"):
+            _wait_for_server_or_exit(
+                process,
+                base_url="http://127.0.0.1:8000",
+                healthcheck_path="/v1/models",
+                ready_timeout_s=10.0,
+            )
+
     def test_openai_url_avoids_double_v1(self) -> None:
         self.assertEqual(
             _openai_url("http://example.com", "/v1/chat/completions"),