Make early-stop engine relaunch opt-in

2026-04-26 01:26:26 +08:00
parent d76ac49198
commit a53445868e
3 changed files with 48 additions and 26 deletions

@@ -578,27 +578,39 @@ def call_llm_for_proposal(
 ) -> str:
     if policy.endpoint is None:
         raise RuntimeError("study.llm.endpoint is not configured")
-    if policy.endpoint.stream:
-        return stream_text_completion(
-            base_url=policy.endpoint.base_url,
-            api_key_env=policy.endpoint.api_key_env,
-            provider=policy.endpoint.provider,
-            wire_api=policy.endpoint.wire_api,
-            model=policy.endpoint.model,
-            messages=[{"role": "user", "content": prompt}],
-            timeout_s=policy.endpoint.timeout_s,
-            system_prompt=policy.system_prompt,
-            reasoning_effort=policy.endpoint.reasoning_effort,
-        )
-    response = chat_completion(
-        base_url=policy.endpoint.base_url,
-        api_key_env=policy.endpoint.api_key_env,
-        provider=policy.endpoint.provider,
-        wire_api=policy.endpoint.wire_api,
-        model=policy.endpoint.model,
-        messages=[{"role": "user", "content": prompt}],
-        timeout_s=policy.endpoint.timeout_s,
-        system_prompt=policy.system_prompt,
-        reasoning_effort=policy.endpoint.reasoning_effort,
-    )
-    return _extract_response_text(response)
+    last_error: Exception | None = None
+    for attempt in range(2):
+        try:
+            if policy.endpoint.stream:
+                text = stream_text_completion(
+                    base_url=policy.endpoint.base_url,
+                    api_key_env=policy.endpoint.api_key_env,
+                    provider=policy.endpoint.provider,
+                    wire_api=policy.endpoint.wire_api,
+                    model=policy.endpoint.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    timeout_s=policy.endpoint.timeout_s,
+                    system_prompt=policy.system_prompt,
+                    reasoning_effort=policy.endpoint.reasoning_effort,
+                )
+            else:
+                response = chat_completion(
+                    base_url=policy.endpoint.base_url,
+                    api_key_env=policy.endpoint.api_key_env,
+                    provider=policy.endpoint.provider,
+                    wire_api=policy.endpoint.wire_api,
+                    model=policy.endpoint.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    timeout_s=policy.endpoint.timeout_s,
+                    system_prompt=policy.system_prompt,
+                    reasoning_effort=policy.endpoint.reasoning_effort,
+                )
+                text = _extract_response_text(response)
+            if text.strip():
+                return text
+            last_error = RuntimeError("LLM response content is empty")
+        except Exception as exc:  # noqa: BLE001
+            last_error = exc
+        if attempt == 0:
+            continue
+        raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error
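
For readability, here is the retry shape of this hunk in isolation, as a minimal sketch rather than code from this repo: fetch is a hypothetical stand-in for the stream/chat call, and the loop mirrors the hunk (one retry, with an empty body treated the same as a raised exception).

from collections.abc import Callable

def call_with_one_retry(fetch: Callable[[], str]) -> str:
    # Sketch only: fetch stands in for stream_text_completion/chat_completion.
    last_error: Exception | None = None
    for _attempt in range(2):  # the initial call plus at most one retry
        try:
            text = fetch()
            if text.strip():
                return text
            # An empty body counts as a failure and triggers the retry.
            last_error = RuntimeError("LLM response content is empty")
        except Exception as exc:  # noqa: BLE001
            last_error = exc
    raise RuntimeError(f"LLM proposal failed after retry: {last_error}") from last_error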

@@ -327,6 +327,7 @@ class TraceSpec:
     replay_time_scale: float = 1.0
     early_stop_max_lag_s: float | None = None
     early_stop_max_elapsed_s: float | None = None
+    restart_engine_after_early_stop: bool = False
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -389,6 +390,14 @@
                 if data.get("early_stop_max_elapsed_s") is not None
                 else None
             ),
+            restart_engine_after_early_stop=(
+                _require_bool(
+                    data.get("restart_engine_after_early_stop"),
+                    context="trace.restart_engine_after_early_stop",
+                )
+                if data.get("restart_engine_after_early_stop") is not None
+                else False
+            ),
         )
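
A sketch of opting in at the config layer, under the assumption that the trace section is handed to TraceSpec.from_dict as a plain mapping; the keys shown are real fields from the hunks above, but the values, and the idea that nothing else is required, are illustrative only.

trace_data = {
    "replay_time_scale": 1.0,                 # existing field, illustrative value
    "early_stop_max_elapsed_s": 120.0,        # existing field, illustrative value
    "restart_engine_after_early_stop": True,  # new flag; omit to keep the False default
}
spec = TraceSpec.from_dict(trace_data)
assert spec.restart_engine_after_early_stop is True
# A non-boolean value (e.g. "yes") would be rejected by _require_bool with
# context "trace.restart_engine_after_early_stop".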

@@ -367,6 +367,7 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
     def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
         nonlocal process
         selected = select_requests_for_threshold(requests, threshold=threshold)
+        restart_after_early_stop = study.trace.restart_engine_after_early_stop
         outcomes, early_stopped, early_stop_reason = _replay_requests(
             selected,
             base_url=recipe.base_url,
@@ -376,7 +377,7 @@
             max_lag_s=study.trace.early_stop_max_lag_s,
             max_elapsed_s=study.trace.early_stop_max_elapsed_s,
             evaluate_outcome=lambda outcome: evaluate_request(outcome, study.slo),
-            drain_inflight_on_early_stop=False,
+            drain_inflight_on_early_stop=not restart_after_early_stop,
         )
         evaluations, summary = summarize_evaluations(outcomes, study.slo)
         request_rate = (
@@ -423,7 +424,7 @@
         }
         probe_history.append(probe_record)
         StudyStore.write_json(Path(trial.probe_log_path), probe_history)
-        if early_stopped:
+        if early_stopped and restart_after_early_stop:
             _terminate_process_tree(process, timeout_s=30.0)
             process = launch_process()
             _wait_for_server_or_exit(
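
Net effect: with the flag left at its default, drain_inflight_on_early_stop evaluates to True, so an early-stopped probe drains its in-flight requests and the next probe reuses the running engine; enabling restart_engine_after_early_stop restores the previous behavior of abandoning in-flight work, terminating the process tree, and waiting for a freshly launched engine.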