Add probe early stop guards
This commit is contained in:
@@ -15,6 +15,8 @@ from aituner.slo import RequestOutcome, summarize_evaluations
|
||||
from aituner.spec import Proposal, load_study_spec
|
||||
from aituner.store import StudyStore
|
||||
from aituner.trace import load_trace_requests, summarize_window
|
||||
from aituner.worker import _replay_requests
|
||||
from aituner.trace import TraceRequest
|
||||
|
||||
|
||||
def _write_study_assets(tmp_path: Path) -> Path:
|
||||
@@ -668,6 +670,53 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.best_request_rate, 2.0)
|
||||
self.assertEqual(state.next_trial_index, 3)
|
||||
|
||||
def test_replay_requests_early_stops_when_slo_is_unrecoverable(self) -> None:
    """Check that replay reports an early stop after one failed request.

    Three trace requests are queued, but the patched ``_run_one_request``
    is scripted with a single failing outcome. The test then asserts that
    ``_replay_requests`` flags an early stop with the reason
    ``"slo_pass_rate_unrecoverable"``, still returns three outcomes, and
    that the second outcome carries that reason as its error.
    """
    # Three requests that all arrive at t=0; each gets its own body dict.
    requests = []
    for idx in range(3):
        requests.append(
            TraceRequest(
                row_id=f"r{idx}",
                arrival_s=0.0,
                sampling_u=0.1 * idx,
                body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
                prompt_tokens_hint=8,
                completion_tokens_hint=4,
            )
        )

    # Only one outcome is scripted: a second call to the patched runner
    # would pop from an empty list and raise IndexError.
    failed_first = RequestOutcome(
        request_id="r0",
        success=False,
        ttft_ms=None,
        tpot_ms=None,
        prompt_tokens=8,
        completion_tokens=4,
        error="request_failed",
    )
    scripted_outcomes = [failed_first]

    def pop_scripted_outcome(*_args, **_kwargs):
        return scripted_outcomes.pop(0)

    def verdict_for(outcome: RequestOutcome):
        # Minimal stand-in evaluation object exposing only ``passed``.
        verdict_cls = type("Eval", (), {"passed": outcome.success})
        return verdict_cls()

    replay_options = {
        "base_url": "http://127.0.0.1:8000",
        "timeout_s": 1.0,
        "max_concurrency": 1,
        "target_pass_rate": 0.95,
        "max_lag_s": None,
        "max_elapsed_s": None,
        "evaluate_outcome": verdict_for,
    }

    runner_patch = mock.patch(
        "aituner.worker._run_one_request",
        side_effect=pop_scripted_outcome,
    )
    with runner_patch:
        replayed, early_stopped, reason = _replay_requests(requests, **replay_options)

    expected_reason = "slo_pass_rate_unrecoverable"
    self.assertTrue(early_stopped)
    self.assertEqual(reason, expected_reason)
    self.assertEqual(len(replayed), 3)
    self.assertEqual(replayed[1].error, expected_reason)
|
||||
|
||||
|
||||
# Invoke unittest's command-line runner when this module is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user