diff --git a/configs/examples/dash0_smoke_study.json b/configs/examples/dash0_smoke_study.json
index 868ae70..0f745a5 100644
--- a/configs/examples/dash0_smoke_study.json
+++ b/configs/examples/dash0_smoke_study.json
@@ -55,7 +55,9 @@
     "timestamp_field": "timestamp",
     "max_concurrency": 2,
     "max_requests_per_probe": 24,
-    "replay_time_scale": 0.02
+    "replay_time_scale": 0.02,
+    "early_stop_max_lag_s": 5.0,
+    "early_stop_max_elapsed_s": 60.0
   },
   "slo": {
     "target_pass_rate": 0.95,
diff --git a/src/aituner/spec.py b/src/aituner/spec.py
index 895aa91..dffe693 100644
--- a/src/aituner/spec.py
+++ b/src/aituner/spec.py
@@ -153,6 +153,8 @@ class TraceSpec:
     max_requests_per_probe: int | None = None
     synthetic_prompt_cap_tokens: int | None = None
     replay_time_scale: float = 1.0
+    early_stop_max_lag_s: float | None = None
+    early_stop_max_elapsed_s: float | None = None

     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
@@ -176,6 +178,21 @@ class TraceSpec:
             replay_time_scale=_require_float(
                 data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
             ),
+            early_stop_max_lag_s=(
+                _require_float(
+                    data.get("early_stop_max_lag_s"), context="trace.early_stop_max_lag_s"
+                )
+                if data.get("early_stop_max_lag_s") is not None
+                else None
+            ),
+            early_stop_max_elapsed_s=(
+                _require_float(
+                    data.get("early_stop_max_elapsed_s"),
+                    context="trace.early_stop_max_elapsed_s",
+                )
+                if data.get("early_stop_max_elapsed_s") is not None
+                else None
+            ),
         )
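Note: both new TraceSpec fields are optional floats defaulting to None, so existing study configs keep their behavior unless a knob is set. Below is a minimal standalone sketch of the parse-only-when-present pattern used in from_dict above; the optional_float helper is illustrative and not part of the codebase.

# --- sketch: optional-float parsing (illustrative, not in the diff) ---
from typing import Any, Mapping

def optional_float(data: Mapping[str, Any], key: str) -> float | None:
    # Coerce the key to float only when it is present; None leaves the
    # corresponding early-stop feature disabled.
    value = data.get(key)
    return float(value) if value is not None else None

assert optional_float({"early_stop_max_lag_s": 5.0}, "early_stop_max_lag_s") == 5.0
assert optional_float({}, "early_stop_max_elapsed_s") is None
# --- end sketch ---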
diff --git a/src/aituner/worker.py b/src/aituner/worker.py
index d02470f..33682dd 100644
--- a/src/aituner/worker.py
+++ b/src/aituner/worker.py
@@ -1,18 +1,19 @@
 from __future__ import annotations

 import json
+import math
 import subprocess
 import threading
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable

 from .engine import build_launch_recipe
 from .http_client import HttpClientError, stream_chat_completion, wait_for_server
 from .search import ThresholdProbe, binary_search_max_feasible
-from .slo import RequestOutcome, summarize_evaluations
+from .slo import RequestOutcome, evaluate_request, summarize_evaluations
 from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
 from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold

@@ -25,6 +26,8 @@ class ProbePayload:
     request_rate: float
     feasible: bool
     outcomes: list[dict[str, Any]]
+    early_stopped: bool = False
+    early_stop_reason: str = ""

 def _trial_spec_from_json(path: Path) -> TrialSpec:
     payload = json.loads(path.read_text(encoding="utf-8"))
@@ -75,31 +78,109 @@ def _replay_requests(
     base_url: str,
     timeout_s: float,
     max_concurrency: int,
-) -> list[RequestOutcome]:
+    target_pass_rate: float,
+    max_lag_s: float | None,
+    max_elapsed_s: float | None,
+    evaluate_outcome: Callable[[RequestOutcome], Any],
+) -> tuple[list[RequestOutcome], bool, str]:
     outcomes_by_id: dict[str, RequestOutcome] = {}
     lock = threading.Lock()
     start = time.monotonic()
-    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
-        futures = []
-        for request in requests:
-            delay = max(0.0, request.arrival_s)
+    allowed_failures = max(0, len(requests) - math.ceil(len(requests) * target_pass_rate))
+    failed_evaluations = 0
+    early_stopped = False
+    early_stop_reason = ""
+    next_index = 0
+    futures_by_request: dict[Future[RequestOutcome], TraceRequest] = {}
+    submitted_ids: set[str] = set()
+    pool = ThreadPoolExecutor(max_workers=max_concurrency)
+    try:
+        while next_index < len(requests) or futures_by_request:
             now = time.monotonic()
-            sleep_for = (start + delay) - now
-            if sleep_for > 0:
-                time.sleep(sleep_for)
-            futures.append(
-                pool.submit(
+            elapsed = now - start
+            if max_elapsed_s is not None and elapsed > max_elapsed_s:
+                early_stopped = True
+                early_stop_reason = f"probe_elapsed_s>{max_elapsed_s}"
+                break
+            while next_index < len(requests):
+                request = requests[next_index]
+                lag_s = elapsed - request.arrival_s
+                if max_lag_s is not None and lag_s > max_lag_s:
+                    early_stopped = True
+                    early_stop_reason = f"arrival_lag_s>{max_lag_s}"
+                    break
+                if request.arrival_s > elapsed:
+                    break
+                future = pool.submit(
                     _run_one_request,
                     request,
                     base_url=base_url,
                     timeout_s=timeout_s,
                 )
-            )
-        for future in as_completed(futures):
-            outcome = future.result()
-            with lock:
-                outcomes_by_id[outcome.request_id] = outcome
-    return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
+                futures_by_request[future] = request
+                submitted_ids.add(request.row_id)
+                next_index += 1
+                if len(futures_by_request) >= max_concurrency:
+                    break
+            if early_stopped:
+                break
+            if futures_by_request:
+                timeout = None
+                if next_index < len(requests):
+                    timeout = max(0.0, requests[next_index].arrival_s - elapsed)
+                done, _ = wait(
+                    list(futures_by_request),
+                    timeout=timeout,
+                    return_when=FIRST_COMPLETED,
+                )
+                for future in done:
+                    request = futures_by_request.pop(future)
+                    outcome = future.result()
+                    with lock:
+                        outcomes_by_id[outcome.request_id] = outcome
+                    if not evaluate_outcome(outcome).passed:
+                        failed_evaluations += 1
+                        if failed_evaluations > allowed_failures:
+                            early_stopped = True
+                            early_stop_reason = "slo_pass_rate_unrecoverable"
+                            break
+                if early_stopped:
+                    break
+            elif next_index < len(requests):
+                sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
+                if sleep_for > 0:
+                    time.sleep(min(sleep_for, 0.1))
+        for future, request in list(futures_by_request.items()):
+            try:
+                outcome = future.result(timeout=timeout_s)
+            except Exception:  # noqa: BLE001
+                outcome = RequestOutcome(
+                    request_id=request.row_id,
+                    success=False,
+                    ttft_ms=None,
+                    tpot_ms=None,
+                    prompt_tokens=request.prompt_tokens_hint,
+                    completion_tokens=request.completion_tokens_hint,
+                    error="probe_early_stopped",
+                )
+            outcomes_by_id[outcome.request_id] = outcome
+        if early_stopped:
+            for request in requests:
+                if request.row_id in submitted_ids:
+                    continue
+                outcomes_by_id[request.row_id] = RequestOutcome(
+                    request_id=request.row_id,
+                    success=False,
+                    ttft_ms=None,
+                    tpot_ms=None,
+                    prompt_tokens=request.prompt_tokens_hint,
+                    completion_tokens=request.completion_tokens_hint,
+                    error=early_stop_reason or "probe_early_stopped",
+                )
+    finally:
+        pool.shutdown(wait=False, cancel_futures=True)
+    ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
+    return ordered, early_stopped, early_stop_reason


 def run_trial(trial_spec_path: Path) -> dict[str, Any]:
@@ -128,11 +209,15 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:

     def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
         selected = select_requests_for_threshold(requests, threshold=threshold)
-        outcomes = _replay_requests(
+        outcomes, early_stopped, early_stop_reason = _replay_requests(
             selected,
             base_url=recipe.base_url,
             timeout_s=recipe.request_timeout_s,
             max_concurrency=study.trace.max_concurrency,
+            target_pass_rate=study.slo.target_pass_rate,
+            max_lag_s=study.trace.early_stop_max_lag_s,
+            max_elapsed_s=study.trace.early_stop_max_elapsed_s,
+            evaluate_outcome=lambda outcome: evaluate_request(outcome, study.slo),
         )
         evaluations, summary = summarize_evaluations(outcomes, study.slo)
         request_rate = (
@@ -146,6 +231,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
             pass_rate=float(summary["slo_pass_rate"]),
             request_rate=request_rate,
             feasible=bool(summary["feasible"]),
+            early_stopped=early_stopped,
+            early_stop_reason=early_stop_reason,
             outcomes=[
                 {
                     "request_id": outcome.request_id,
@@ -166,6 +253,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
            "pass_rate": payload.pass_rate,
            "request_rate": payload.request_rate,
            "feasible": payload.feasible,
+           "early_stopped": payload.early_stopped,
+           "early_stop_reason": payload.early_stop_reason,
        }
        probe_history.append(probe_record)
        StudyStore.write_json(Path(trial.probe_log_path), probe_history)
@@ -199,6 +288,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
                    "request_count": probe.payload.request_count,
                    "pass_rate": probe.payload.pass_rate,
                    "request_rate": probe.payload.request_rate,
+                   "early_stopped": probe.payload.early_stopped,
+                   "early_stop_reason": probe.payload.early_stop_reason,
                },
            }
            for probe in search.probes
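For context on the slo_pass_rate_unrecoverable branch: the probe stops as soon as observed failures exceed the budget len(requests) - ceil(len(requests) * target_pass_rate), since past that point the pass rate can no longer reach the target even if every remaining request succeeds. A standalone sketch of that arithmetic (the allowed_failures helper is illustrative, not part of the codebase):

# --- sketch: failure budget behind the early stop (illustrative) ---
import math

def allowed_failures(total: int, target_pass_rate: float) -> int:
    # Failures the probe can absorb while the SLO pass rate is still reachable.
    return max(0, total - math.ceil(total * target_pass_rate))

# 24 requests at a 0.95 target: ceil(24 * 0.95) = 23 must pass, so 1 may fail.
assert allowed_failures(24, 0.95) == 1
# 3 requests at 0.95: ceil(3 * 0.95) = 3 must pass, so the very first failure
# trips the early stop -- exactly the situation the unit test below sets up.
assert allowed_failures(3, 0.95) == 0
# --- end sketch ---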
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index b514ee6..ed81292 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -15,6 +15,7 @@ from aituner.slo import RequestOutcome, summarize_evaluations
 from aituner.spec import Proposal, load_study_spec
 from aituner.store import StudyStore
-from aituner.trace import load_trace_requests, summarize_window
+from aituner.trace import TraceRequest, load_trace_requests, summarize_window
+from aituner.worker import _replay_requests


 def _write_study_assets(tmp_path: Path) -> Path:
@@ -668,6 +670,53 @@ class CoreFlowTests(unittest.TestCase):
         self.assertEqual(state.best_request_rate, 2.0)
         self.assertEqual(state.next_trial_index, 3)

+    def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
+        requests = [
+            TraceRequest(
+                row_id=f"r{i}",
+                arrival_s=0.0,
+                sampling_u=0.1 * i,
+                body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
+                prompt_tokens_hint=8,
+                completion_tokens_hint=4,
+            )
+            for i in range(3)
+        ]
+
+        outcomes = [
+            RequestOutcome(
+                request_id="r0",
+                success=False,
+                ttft_ms=None,
+                tpot_ms=None,
+                prompt_tokens=8,
+                completion_tokens=4,
+                error="request_failed",
+            )
+        ]
+
+        def fake_run_one_request(*args, **kwargs):
+            return outcomes.pop(0)
+
+        def fake_evaluate(outcome: RequestOutcome):
+            return type("Eval", (), {"passed": outcome.success})()
+
+        with mock.patch("aituner.worker._run_one_request", side_effect=fake_run_one_request):
+            replayed, early_stopped, reason = _replay_requests(
+                requests,
+                base_url="http://127.0.0.1:8000",
+                timeout_s=1.0,
+                max_concurrency=1,
+                target_pass_rate=0.95,
+                max_lag_s=None,
+                max_elapsed_s=None,
+                evaluate_outcome=fake_evaluate,
+            )
+        self.assertTrue(early_stopped)
+        self.assertEqual(reason, "slo_pass_rate_unrecoverable")
+        self.assertEqual(len(replayed), 3)
+        self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")


 if __name__ == "__main__":
     unittest.main()
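The new test drives the unrecoverable-SLO path end to end through _replay_requests: three zero-arrival requests, a failure budget of zero (see the sketch above), one stubbed failure, and synthetic outcomes filled in for the never-submitted requests. Assuming tests/ is importable as a package from the repository root, it can be run in isolation with the standard unittest CLI:

python -m unittest tests.test_core_flow.CoreFlowTests.test_replay_requests_early_stops_when_slo_is_unrecoverable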