diff --git a/src/aituner/llm.py b/src/aituner/llm.py
index 6411972..2ef6814 100644
--- a/src/aituner/llm.py
+++ b/src/aituner/llm.py
@@ -578,27 +578,39 @@ def call_llm_for_proposal(
 ) -> str:
     if policy.endpoint is None:
         raise RuntimeError("study.llm.endpoint is not configured")
-    if policy.endpoint.stream:
-        return stream_text_completion(
-            base_url=policy.endpoint.base_url,
-            api_key_env=policy.endpoint.api_key_env,
-            provider=policy.endpoint.provider,
-            wire_api=policy.endpoint.wire_api,
-            model=policy.endpoint.model,
-            messages=[{"role": "user", "content": prompt}],
-            timeout_s=policy.endpoint.timeout_s,
-            system_prompt=policy.system_prompt,
-            reasoning_effort=policy.endpoint.reasoning_effort,
-        )
-    response = chat_completion(
-        base_url=policy.endpoint.base_url,
-        api_key_env=policy.endpoint.api_key_env,
-        provider=policy.endpoint.provider,
-        wire_api=policy.endpoint.wire_api,
-        model=policy.endpoint.model,
-        messages=[{"role": "user", "content": prompt}],
-        timeout_s=policy.endpoint.timeout_s,
-        system_prompt=policy.system_prompt,
-        reasoning_effort=policy.endpoint.reasoning_effort,
-    )
-    return _extract_response_text(response)
+    last_error: Exception | None = None
+    for attempt in range(2):
+        try:
+            if policy.endpoint.stream:
+                text = stream_text_completion(
+                    base_url=policy.endpoint.base_url,
+                    api_key_env=policy.endpoint.api_key_env,
+                    provider=policy.endpoint.provider,
+                    wire_api=policy.endpoint.wire_api,
+                    model=policy.endpoint.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    timeout_s=policy.endpoint.timeout_s,
+                    system_prompt=policy.system_prompt,
+                    reasoning_effort=policy.endpoint.reasoning_effort,
+                )
+            else:
+                response = chat_completion(
+                    base_url=policy.endpoint.base_url,
+                    api_key_env=policy.endpoint.api_key_env,
+                    provider=policy.endpoint.provider,
+                    wire_api=policy.endpoint.wire_api,
+                    model=policy.endpoint.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    timeout_s=policy.endpoint.timeout_s,
+                    system_prompt=policy.system_prompt,
+                    reasoning_effort=policy.endpoint.reasoning_effort,
+                )
+                text = _extract_response_text(response)
+            if text.strip():
+                return text
+            last_error = RuntimeError("LLM response content is empty")
+        except Exception as exc:  # noqa: BLE001
+            last_error = exc
+            if attempt == 0:
+                continue
+    raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error
diff --git a/src/aituner/spec.py b/src/aituner/spec.py
index 1f0199e..45f986d 100644
--- a/src/aituner/spec.py
+++ b/src/aituner/spec.py
@@ -327,6 +327,7 @@ class TraceSpec:
     replay_time_scale: float = 1.0
     early_stop_max_lag_s: float | None = None
     early_stop_max_elapsed_s: float | None = None
+    restart_engine_after_early_stop: bool = False
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -389,6 +390,14 @@
                 if data.get("early_stop_max_elapsed_s") is not None
                 else None
             ),
+            restart_engine_after_early_stop=(
+                _require_bool(
+                    data.get("restart_engine_after_early_stop"),
+                    context="trace.restart_engine_after_early_stop",
+                )
+                if data.get("restart_engine_after_early_stop") is not None
+                else False
+            ),
         )
diff --git a/src/aituner/worker.py b/src/aituner/worker.py
index d54bd9b..63b8a78 100644
--- a/src/aituner/worker.py
+++ b/src/aituner/worker.py
@@ -367,6 +367,7 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
     def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
         nonlocal process
         selected = select_requests_for_threshold(requests, threshold=threshold)
+        restart_after_early_stop = study.trace.restart_engine_after_early_stop
         outcomes, early_stopped, early_stop_reason = _replay_requests(
             selected,
             base_url=recipe.base_url,
@@ -376,7 +377,7 @@
             max_lag_s=study.trace.early_stop_max_lag_s,
             max_elapsed_s=study.trace.early_stop_max_elapsed_s,
             evaluate_outcome=lambda outcome: evaluate_request(outcome, study.slo),
-            drain_inflight_on_early_stop=False,
+            drain_inflight_on_early_stop=not restart_after_early_stop,
         )
         evaluations, summary = summarize_evaluations(outcomes, study.slo)
         request_rate = (
@@ -423,7 +424,7 @@
         }
         probe_history.append(probe_record)
         StudyStore.write_json(Path(trial.probe_log_path), probe_history)
-        if early_stopped:
+        if early_stopped and restart_after_early_stop:
             _terminate_process_tree(process, timeout_s=30.0)
             process = launch_process()
             _wait_for_server_or_exit(