Add probe early stop guards

This commit is contained in:
2026-04-04 22:58:33 +08:00
parent 56fa6747d2
commit 7e8523fdaa
4 changed files with 180 additions and 21 deletions

View File

@@ -55,7 +55,9 @@
"timestamp_field": "timestamp",
"max_concurrency": 2,
"max_requests_per_probe": 24,
"replay_time_scale": 0.02
"replay_time_scale": 0.02,
"early_stop_max_lag_s": 5.0,
"early_stop_max_elapsed_s": 60.0
},
"slo": {
"target_pass_rate": 0.95,

View File

@@ -153,6 +153,8 @@ class TraceSpec:
max_requests_per_probe: int | None = None
synthetic_prompt_cap_tokens: int | None = None
replay_time_scale: float = 1.0
early_stop_max_lag_s: float | None = None
early_stop_max_elapsed_s: float | None = None
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -176,6 +178,21 @@ class TraceSpec:
replay_time_scale=_require_float(
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
),
early_stop_max_lag_s=(
_require_float(
data.get("early_stop_max_lag_s"), context="trace.early_stop_max_lag_s"
)
if data.get("early_stop_max_lag_s") is not None
else None
),
early_stop_max_elapsed_s=(
_require_float(
data.get("early_stop_max_elapsed_s"),
context="trace.early_stop_max_elapsed_s",
)
if data.get("early_stop_max_elapsed_s") is not None
else None
),
)

View File

@@ -1,18 +1,19 @@
from __future__ import annotations
import json
import math
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, Callable
from .engine import build_launch_recipe
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
from .search import ThresholdProbe, binary_search_max_feasible
from .slo import RequestOutcome, summarize_evaluations
from .slo import RequestOutcome, evaluate_request, summarize_evaluations
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold
@@ -25,6 +26,8 @@ class ProbePayload:
request_rate: float
feasible: bool
outcomes: list[dict[str, Any]]
early_stopped: bool = False
early_stop_reason: str = ""
def _trial_spec_from_json(path: Path) -> TrialSpec:
payload = json.loads(path.read_text(encoding="utf-8"))
@@ -75,31 +78,109 @@ def _replay_requests(
base_url: str,
timeout_s: float,
max_concurrency: int,
) -> list[RequestOutcome]:
target_pass_rate: float,
max_lag_s: float | None,
max_elapsed_s: float | None,
evaluate_outcome: Callable[[RequestOutcome], Any],
) -> tuple[list[RequestOutcome], bool, str]:
outcomes_by_id: dict[str, RequestOutcome] = {}
lock = threading.Lock()
start = time.monotonic()
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
futures = []
for request in requests:
delay = max(0.0, request.arrival_s)
allowed_failures = max(0, len(requests) - math.ceil(len(requests) * target_pass_rate))
failed_evaluations = 0
early_stopped = False
early_stop_reason = ""
next_index = 0
futures_by_request: dict[Future[RequestOutcome], TraceRequest] = {}
submitted_ids: set[str] = set()
pool = ThreadPoolExecutor(max_workers=max_concurrency)
try:
while next_index < len(requests) or futures_by_request:
now = time.monotonic()
sleep_for = (start + delay) - now
if sleep_for > 0:
time.sleep(sleep_for)
futures.append(
pool.submit(
elapsed = now - start
if max_elapsed_s is not None and elapsed > max_elapsed_s:
early_stopped = True
early_stop_reason = f"probe_elapsed_s>{max_elapsed_s}"
break
while next_index < len(requests):
request = requests[next_index]
lag_s = elapsed - request.arrival_s
if max_lag_s is not None and lag_s > max_lag_s:
early_stopped = True
early_stop_reason = f"arrival_lag_s>{max_lag_s}"
break
if request.arrival_s > elapsed:
break
future = pool.submit(
_run_one_request,
request,
base_url=base_url,
timeout_s=timeout_s,
)
futures_by_request[future] = request
submitted_ids.add(request.row_id)
next_index += 1
if len(futures_by_request) >= max_concurrency:
break
if early_stopped:
break
if futures_by_request:
timeout = None
if next_index < len(requests):
timeout = max(0.0, requests[next_index].arrival_s - elapsed)
done, _ = wait(
list(futures_by_request),
timeout=timeout,
return_when=FIRST_COMPLETED,
)
for future in as_completed(futures):
for future in done:
request = futures_by_request.pop(future)
outcome = future.result()
with lock:
outcomes_by_id[outcome.request_id] = outcome
return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
if not evaluate_outcome(outcome).passed:
failed_evaluations += 1
if failed_evaluations > allowed_failures:
early_stopped = True
early_stop_reason = "slo_pass_rate_unrecoverable"
break
if early_stopped:
break
elif next_index < len(requests):
sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
if sleep_for > 0:
time.sleep(min(sleep_for, 0.1))
for future, request in list(futures_by_request.items()):
try:
outcome = future.result(timeout=timeout_s)
except Exception: # noqa: BLE001
outcome = RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=None,
tpot_ms=None,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
error="probe_early_stopped",
)
outcomes_by_id[outcome.request_id] = outcome
if early_stopped:
for request in requests:
if request.row_id in submitted_ids:
continue
outcomes_by_id[request.row_id] = RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=None,
tpot_ms=None,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
error=early_stop_reason or "probe_early_stopped",
)
finally:
pool.shutdown(wait=False, cancel_futures=True)
ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
return ordered, early_stopped, early_stop_reason
def run_trial(trial_spec_path: Path) -> dict[str, Any]:
@@ -128,11 +209,15 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
selected = select_requests_for_threshold(requests, threshold=threshold)
outcomes = _replay_requests(
outcomes, early_stopped, early_stop_reason = _replay_requests(
selected,
base_url=recipe.base_url,
timeout_s=recipe.request_timeout_s,
max_concurrency=study.trace.max_concurrency,
target_pass_rate=study.slo.target_pass_rate,
max_lag_s=study.trace.early_stop_max_lag_s,
max_elapsed_s=study.trace.early_stop_max_elapsed_s,
evaluate_outcome=lambda outcome: evaluate_request(outcome, study.slo),
)
evaluations, summary = summarize_evaluations(outcomes, study.slo)
request_rate = (
@@ -146,6 +231,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
pass_rate=float(summary["slo_pass_rate"]),
request_rate=request_rate,
feasible=bool(summary["feasible"]),
early_stopped=early_stopped,
early_stop_reason=early_stop_reason,
outcomes=[
{
"request_id": outcome.request_id,
@@ -166,6 +253,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
"pass_rate": payload.pass_rate,
"request_rate": payload.request_rate,
"feasible": payload.feasible,
"early_stopped": payload.early_stopped,
"early_stop_reason": payload.early_stop_reason,
}
probe_history.append(probe_record)
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
@@ -199,6 +288,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
"request_count": probe.payload.request_count,
"pass_rate": probe.payload.pass_rate,
"request_rate": probe.payload.request_rate,
"early_stopped": probe.payload.early_stopped,
"early_stop_reason": probe.payload.early_stop_reason,
},
}
for probe in search.probes

View File

@@ -15,6 +15,8 @@ from aituner.slo import RequestOutcome, summarize_evaluations
from aituner.spec import Proposal, load_study_spec
from aituner.store import StudyStore
from aituner.trace import load_trace_requests, summarize_window
from aituner.worker import _replay_requests
from aituner.trace import TraceRequest
def _write_study_assets(tmp_path: Path) -> Path:
@@ -668,6 +670,53 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(state.best_request_rate, 2.0)
self.assertEqual(state.next_trial_index, 3)
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
    """One SLO failure that exhausts the failure budget must abort the probe.

    Three requests arrive at t=0 with ``target_pass_rate=0.95`` and
    ``max_concurrency=1``. The replay loop computes the failure budget as
    ``len(requests) - ceil(len(requests) * target_pass_rate)``, which is 0
    here, so the first failed evaluation makes the target unreachable.

    The test checks that:
    * the replay reports an early stop with the SLO reason,
    * every input request still gets an outcome, in input order,
    * the request that actually ran keeps its real outcome,
    * every request that was never dispatched is marked with the
      early-stop reason, and no further request was sent to the backend.
    """
    # Local import keeps this test self-contained; only a ``passed``
    # attribute is read from the evaluation object by the replay loop.
    from types import SimpleNamespace

    requests = [
        TraceRequest(
            row_id=f"r{i}",
            arrival_s=0.0,
            sampling_u=0.1 * i,
            body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
            prompt_tokens_hint=8,
            completion_tokens_hint=4,
        )
        for i in range(3)
    ]
    failed_outcome = RequestOutcome(
        request_id="r0",
        success=False,
        ttft_ms=None,
        tpot_ms=None,
        prompt_tokens=8,
        completion_tokens=4,
        error="request_failed",
    )

    def fake_evaluate(outcome: RequestOutcome) -> SimpleNamespace:
        # Pass/fail mirrors transport success; no real SLO thresholds needed.
        return SimpleNamespace(passed=outcome.success)

    # A one-element side_effect means any unexpected second dispatch raises
    # inside the worker instead of silently returning a canned outcome.
    with mock.patch(
        "aituner.worker._run_one_request",
        side_effect=[failed_outcome],
    ) as run_one:
        replayed, early_stopped, reason = _replay_requests(
            requests,
            base_url="http://127.0.0.1:8000",
            timeout_s=1.0,
            max_concurrency=1,
            target_pass_rate=0.95,
            max_lag_s=None,
            max_elapsed_s=None,
            evaluate_outcome=fake_evaluate,
        )

    self.assertTrue(early_stopped)
    self.assertEqual(reason, "slo_pass_rate_unrecoverable")

    # Outcomes are reported for every request, in the original order.
    self.assertEqual(len(replayed), 3)
    self.assertEqual([item.request_id for item in replayed], ["r0", "r1", "r2"])

    # The request that really ran keeps the outcome produced by the backend.
    self.assertEqual(replayed[0].error, "request_failed")

    # Requests that were never dispatched are synthesized as failures
    # carrying the early-stop reason.
    for skipped in replayed[1:]:
        self.assertFalse(skipped.success)
        self.assertEqual(skipped.error, "slo_pass_rate_unrecoverable")

    # Early stop must prevent any further dispatch to the backend.
    self.assertEqual(run_one.call_count, 1)
if __name__ == "__main__":
unittest.main()