Add probe early stop guards
This commit is contained in:
@@ -55,7 +55,9 @@
|
|||||||
"timestamp_field": "timestamp",
|
"timestamp_field": "timestamp",
|
||||||
"max_concurrency": 2,
|
"max_concurrency": 2,
|
||||||
"max_requests_per_probe": 24,
|
"max_requests_per_probe": 24,
|
||||||
"replay_time_scale": 0.02
|
"replay_time_scale": 0.02,
|
||||||
|
"early_stop_max_lag_s": 5.0,
|
||||||
|
"early_stop_max_elapsed_s": 60.0
|
||||||
},
|
},
|
||||||
"slo": {
|
"slo": {
|
||||||
"target_pass_rate": 0.95,
|
"target_pass_rate": 0.95,
|
||||||
|
|||||||
@@ -153,6 +153,8 @@ class TraceSpec:
|
|||||||
max_requests_per_probe: int | None = None
|
max_requests_per_probe: int | None = None
|
||||||
synthetic_prompt_cap_tokens: int | None = None
|
synthetic_prompt_cap_tokens: int | None = None
|
||||||
replay_time_scale: float = 1.0
|
replay_time_scale: float = 1.0
|
||||||
|
early_stop_max_lag_s: float | None = None
|
||||||
|
early_stop_max_elapsed_s: float | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
|
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
|
||||||
@@ -176,6 +178,21 @@ class TraceSpec:
|
|||||||
replay_time_scale=_require_float(
|
replay_time_scale=_require_float(
|
||||||
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
|
data.get("replay_time_scale", 1.0), context="trace.replay_time_scale"
|
||||||
),
|
),
|
||||||
|
early_stop_max_lag_s=(
|
||||||
|
_require_float(
|
||||||
|
data.get("early_stop_max_lag_s"), context="trace.early_stop_max_lag_s"
|
||||||
|
)
|
||||||
|
if data.get("early_stop_max_lag_s") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
early_stop_max_elapsed_s=(
|
||||||
|
_require_float(
|
||||||
|
data.get("early_stop_max_elapsed_s"),
|
||||||
|
context="trace.early_stop_max_elapsed_s",
|
||||||
|
)
|
||||||
|
if data.get("early_stop_max_elapsed_s") is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
from .engine import build_launch_recipe
|
from .engine import build_launch_recipe
|
||||||
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
|
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
|
||||||
from .search import ThresholdProbe, binary_search_max_feasible
|
from .search import ThresholdProbe, binary_search_max_feasible
|
||||||
from .slo import RequestOutcome, summarize_evaluations
|
from .slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||||
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
|
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
|
||||||
from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold
|
from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold
|
||||||
|
|
||||||
@@ -25,6 +26,8 @@ class ProbePayload:
|
|||||||
request_rate: float
|
request_rate: float
|
||||||
feasible: bool
|
feasible: bool
|
||||||
outcomes: list[dict[str, Any]]
|
outcomes: list[dict[str, Any]]
|
||||||
|
early_stopped: bool = False
|
||||||
|
early_stop_reason: str = ""
|
||||||
|
|
||||||
def _trial_spec_from_json(path: Path) -> TrialSpec:
|
def _trial_spec_from_json(path: Path) -> TrialSpec:
|
||||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||||
@@ -75,31 +78,109 @@ def _replay_requests(
|
|||||||
base_url: str,
|
base_url: str,
|
||||||
timeout_s: float,
|
timeout_s: float,
|
||||||
max_concurrency: int,
|
max_concurrency: int,
|
||||||
) -> list[RequestOutcome]:
|
target_pass_rate: float,
|
||||||
|
max_lag_s: float | None,
|
||||||
|
max_elapsed_s: float | None,
|
||||||
|
evaluate_outcome: Callable[[RequestOutcome], Any],
|
||||||
|
) -> tuple[list[RequestOutcome], bool, str]:
|
||||||
outcomes_by_id: dict[str, RequestOutcome] = {}
|
outcomes_by_id: dict[str, RequestOutcome] = {}
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
allowed_failures = max(0, len(requests) - math.ceil(len(requests) * target_pass_rate))
|
||||||
futures = []
|
failed_evaluations = 0
|
||||||
for request in requests:
|
early_stopped = False
|
||||||
delay = max(0.0, request.arrival_s)
|
early_stop_reason = ""
|
||||||
|
next_index = 0
|
||||||
|
futures_by_request: dict[Future[RequestOutcome], TraceRequest] = {}
|
||||||
|
submitted_ids: set[str] = set()
|
||||||
|
pool = ThreadPoolExecutor(max_workers=max_concurrency)
|
||||||
|
try:
|
||||||
|
while next_index < len(requests) or futures_by_request:
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
sleep_for = (start + delay) - now
|
elapsed = now - start
|
||||||
if sleep_for > 0:
|
if max_elapsed_s is not None and elapsed > max_elapsed_s:
|
||||||
time.sleep(sleep_for)
|
early_stopped = True
|
||||||
futures.append(
|
early_stop_reason = f"probe_elapsed_s>{max_elapsed_s}"
|
||||||
pool.submit(
|
break
|
||||||
|
while next_index < len(requests):
|
||||||
|
request = requests[next_index]
|
||||||
|
lag_s = elapsed - request.arrival_s
|
||||||
|
if max_lag_s is not None and lag_s > max_lag_s:
|
||||||
|
early_stopped = True
|
||||||
|
early_stop_reason = f"arrival_lag_s>{max_lag_s}"
|
||||||
|
break
|
||||||
|
if request.arrival_s > elapsed:
|
||||||
|
break
|
||||||
|
future = pool.submit(
|
||||||
_run_one_request,
|
_run_one_request,
|
||||||
request,
|
request,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
timeout_s=timeout_s,
|
timeout_s=timeout_s,
|
||||||
)
|
)
|
||||||
|
futures_by_request[future] = request
|
||||||
|
submitted_ids.add(request.row_id)
|
||||||
|
next_index += 1
|
||||||
|
if len(futures_by_request) >= max_concurrency:
|
||||||
|
break
|
||||||
|
if early_stopped:
|
||||||
|
break
|
||||||
|
if futures_by_request:
|
||||||
|
timeout = None
|
||||||
|
if next_index < len(requests):
|
||||||
|
timeout = max(0.0, requests[next_index].arrival_s - elapsed)
|
||||||
|
done, _ = wait(
|
||||||
|
list(futures_by_request),
|
||||||
|
timeout=timeout,
|
||||||
|
return_when=FIRST_COMPLETED,
|
||||||
)
|
)
|
||||||
for future in as_completed(futures):
|
for future in done:
|
||||||
|
request = futures_by_request.pop(future)
|
||||||
outcome = future.result()
|
outcome = future.result()
|
||||||
with lock:
|
with lock:
|
||||||
outcomes_by_id[outcome.request_id] = outcome
|
outcomes_by_id[outcome.request_id] = outcome
|
||||||
return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
if not evaluate_outcome(outcome).passed:
|
||||||
|
failed_evaluations += 1
|
||||||
|
if failed_evaluations > allowed_failures:
|
||||||
|
early_stopped = True
|
||||||
|
early_stop_reason = "slo_pass_rate_unrecoverable"
|
||||||
|
break
|
||||||
|
if early_stopped:
|
||||||
|
break
|
||||||
|
elif next_index < len(requests):
|
||||||
|
sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
|
||||||
|
if sleep_for > 0:
|
||||||
|
time.sleep(min(sleep_for, 0.1))
|
||||||
|
for future, request in list(futures_by_request.items()):
|
||||||
|
try:
|
||||||
|
outcome = future.result(timeout=timeout_s)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
outcome = RequestOutcome(
|
||||||
|
request_id=request.row_id,
|
||||||
|
success=False,
|
||||||
|
ttft_ms=None,
|
||||||
|
tpot_ms=None,
|
||||||
|
prompt_tokens=request.prompt_tokens_hint,
|
||||||
|
completion_tokens=request.completion_tokens_hint,
|
||||||
|
error="probe_early_stopped",
|
||||||
|
)
|
||||||
|
outcomes_by_id[outcome.request_id] = outcome
|
||||||
|
if early_stopped:
|
||||||
|
for request in requests:
|
||||||
|
if request.row_id in submitted_ids:
|
||||||
|
continue
|
||||||
|
outcomes_by_id[request.row_id] = RequestOutcome(
|
||||||
|
request_id=request.row_id,
|
||||||
|
success=False,
|
||||||
|
ttft_ms=None,
|
||||||
|
tpot_ms=None,
|
||||||
|
prompt_tokens=request.prompt_tokens_hint,
|
||||||
|
completion_tokens=request.completion_tokens_hint,
|
||||||
|
error=early_stop_reason or "probe_early_stopped",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
pool.shutdown(wait=False, cancel_futures=True)
|
||||||
|
ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
||||||
|
return ordered, early_stopped, early_stop_reason
|
||||||
|
|
||||||
|
|
||||||
def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||||
@@ -128,11 +209,15 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
|||||||
|
|
||||||
def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
|
def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
|
||||||
selected = select_requests_for_threshold(requests, threshold=threshold)
|
selected = select_requests_for_threshold(requests, threshold=threshold)
|
||||||
outcomes = _replay_requests(
|
outcomes, early_stopped, early_stop_reason = _replay_requests(
|
||||||
selected,
|
selected,
|
||||||
base_url=recipe.base_url,
|
base_url=recipe.base_url,
|
||||||
timeout_s=recipe.request_timeout_s,
|
timeout_s=recipe.request_timeout_s,
|
||||||
max_concurrency=study.trace.max_concurrency,
|
max_concurrency=study.trace.max_concurrency,
|
||||||
|
target_pass_rate=study.slo.target_pass_rate,
|
||||||
|
max_lag_s=study.trace.early_stop_max_lag_s,
|
||||||
|
max_elapsed_s=study.trace.early_stop_max_elapsed_s,
|
||||||
|
evaluate_outcome=lambda outcome: evaluate_request(outcome, study.slo),
|
||||||
)
|
)
|
||||||
evaluations, summary = summarize_evaluations(outcomes, study.slo)
|
evaluations, summary = summarize_evaluations(outcomes, study.slo)
|
||||||
request_rate = (
|
request_rate = (
|
||||||
@@ -146,6 +231,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
|||||||
pass_rate=float(summary["slo_pass_rate"]),
|
pass_rate=float(summary["slo_pass_rate"]),
|
||||||
request_rate=request_rate,
|
request_rate=request_rate,
|
||||||
feasible=bool(summary["feasible"]),
|
feasible=bool(summary["feasible"]),
|
||||||
|
early_stopped=early_stopped,
|
||||||
|
early_stop_reason=early_stop_reason,
|
||||||
outcomes=[
|
outcomes=[
|
||||||
{
|
{
|
||||||
"request_id": outcome.request_id,
|
"request_id": outcome.request_id,
|
||||||
@@ -166,6 +253,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
|||||||
"pass_rate": payload.pass_rate,
|
"pass_rate": payload.pass_rate,
|
||||||
"request_rate": payload.request_rate,
|
"request_rate": payload.request_rate,
|
||||||
"feasible": payload.feasible,
|
"feasible": payload.feasible,
|
||||||
|
"early_stopped": payload.early_stopped,
|
||||||
|
"early_stop_reason": payload.early_stop_reason,
|
||||||
}
|
}
|
||||||
probe_history.append(probe_record)
|
probe_history.append(probe_record)
|
||||||
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
|
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
|
||||||
@@ -199,6 +288,8 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
|||||||
"request_count": probe.payload.request_count,
|
"request_count": probe.payload.request_count,
|
||||||
"pass_rate": probe.payload.pass_rate,
|
"pass_rate": probe.payload.pass_rate,
|
||||||
"request_rate": probe.payload.request_rate,
|
"request_rate": probe.payload.request_rate,
|
||||||
|
"early_stopped": probe.payload.early_stopped,
|
||||||
|
"early_stop_reason": probe.payload.early_stop_reason,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for probe in search.probes
|
for probe in search.probes
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from aituner.slo import RequestOutcome, summarize_evaluations
|
|||||||
from aituner.spec import Proposal, load_study_spec
|
from aituner.spec import Proposal, load_study_spec
|
||||||
from aituner.store import StudyStore
|
from aituner.store import StudyStore
|
||||||
from aituner.trace import load_trace_requests, summarize_window
|
from aituner.trace import load_trace_requests, summarize_window
|
||||||
|
from aituner.worker import _replay_requests
|
||||||
|
from aituner.trace import TraceRequest
|
||||||
|
|
||||||
|
|
||||||
def _write_study_assets(tmp_path: Path) -> Path:
|
def _write_study_assets(tmp_path: Path) -> Path:
|
||||||
@@ -668,6 +670,53 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(state.best_request_rate, 2.0)
|
self.assertEqual(state.best_request_rate, 2.0)
|
||||||
self.assertEqual(state.next_trial_index, 3)
|
self.assertEqual(state.next_trial_index, 3)
|
||||||
|
|
||||||
|
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
|
||||||
|
requests = [
|
||||||
|
TraceRequest(
|
||||||
|
row_id=f"r{i}",
|
||||||
|
arrival_s=0.0,
|
||||||
|
sampling_u=0.1 * i,
|
||||||
|
body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
|
||||||
|
prompt_tokens_hint=8,
|
||||||
|
completion_tokens_hint=4,
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
outcomes = [
|
||||||
|
RequestOutcome(
|
||||||
|
request_id="r0",
|
||||||
|
success=False,
|
||||||
|
ttft_ms=None,
|
||||||
|
tpot_ms=None,
|
||||||
|
prompt_tokens=8,
|
||||||
|
completion_tokens=4,
|
||||||
|
error="request_failed",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def fake_run_one_request(*args, **kwargs):
|
||||||
|
return outcomes.pop(0)
|
||||||
|
|
||||||
|
def fake_evaluate(outcome: RequestOutcome):
|
||||||
|
return type("Eval", (), {"passed": outcome.success})()
|
||||||
|
|
||||||
|
with mock.patch("aituner.worker._run_one_request", side_effect=fake_run_one_request):
|
||||||
|
replayed, early_stopped, reason = _replay_requests(
|
||||||
|
requests,
|
||||||
|
base_url="http://127.0.0.1:8000",
|
||||||
|
timeout_s=1.0,
|
||||||
|
max_concurrency=1,
|
||||||
|
target_pass_rate=0.95,
|
||||||
|
max_lag_s=None,
|
||||||
|
max_elapsed_s=None,
|
||||||
|
evaluate_outcome=fake_evaluate,
|
||||||
|
)
|
||||||
|
self.assertTrue(early_stopped)
|
||||||
|
self.assertEqual(reason, "slo_pass_rate_unrecoverable")
|
||||||
|
self.assertEqual(len(replayed), 3)
|
||||||
|
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user