From e00bedb4665819d411228671c949f78c4e831ac0 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sun, 5 Apr 2026 00:56:26 +0800 Subject: [PATCH] Stop waiting on in-flight requests after early stop --- src/aituner/worker.py | 29 ++++++++++---- tests/test_core_flow.py | 87 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 8 deletions(-) diff --git a/src/aituner/worker.py b/src/aituner/worker.py index 6fd73c1..fef9901 100644 --- a/src/aituner/worker.py +++ b/src/aituner/worker.py @@ -150,21 +150,19 @@ def _replay_requests( sleep_for = max(0.0, requests[next_index].arrival_s - elapsed) if sleep_for > 0: time.sleep(min(sleep_for, 0.1)) - for future, request in list(futures_by_request.items()): - try: - outcome = future.result(timeout=timeout_s) - except Exception: # noqa: BLE001 - outcome = RequestOutcome( + if early_stopped: + for future in list(futures_by_request): + future.cancel() + for request in futures_by_request.values(): + outcomes_by_id[request.row_id] = RequestOutcome( request_id=request.row_id, success=False, ttft_ms=None, tpot_ms=None, prompt_tokens=request.prompt_tokens_hint, completion_tokens=request.completion_tokens_hint, - error="probe_early_stopped", + error=early_stop_reason or "probe_early_stopped", ) - outcomes_by_id[outcome.request_id] = outcome - if early_stopped: for request in requests: if request.row_id in submitted_ids: continue @@ -177,6 +175,21 @@ def _replay_requests( completion_tokens=request.completion_tokens_hint, error=early_stop_reason or "probe_early_stopped", ) + else: + for future, request in list(futures_by_request.items()): + try: + outcome = future.result(timeout=timeout_s) + except Exception: # noqa: BLE001 + outcome = RequestOutcome( + request_id=request.row_id, + success=False, + ttft_ms=None, + tpot_ms=None, + prompt_tokens=request.prompt_tokens_hint, + completion_tokens=request.completion_tokens_hint, + error="probe_early_stopped", + ) + outcomes_by_id[outcome.request_id] = outcome finally: pool.shutdown(wait=False, cancel_futures=True) ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id] diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index 4b7fa7b..9e45ede 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -812,6 +812,93 @@ class CoreFlowTests(unittest.TestCase): self.assertEqual(len(replayed), 3) self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable") + def test_replay_requests_does_not_wait_for_inflight_after_early_stop(self) -> None: + requests = [ + TraceRequest( + row_id="r0", + arrival_s=0.0, + sampling_u=0.1, + body={"model": "m", "messages": [{"role": "user", "content": "x"}]}, + prompt_tokens_hint=8, + completion_tokens_hint=4, + ), + TraceRequest( + row_id="r1", + arrival_s=0.0, + sampling_u=0.2, + body={"model": "m", "messages": [{"role": "user", "content": "y"}]}, + prompt_tokens_hint=8, + completion_tokens_hint=4, + ), + ] + + class FakeFuture: + def __init__(self, outcome=None, *, should_fail_if_waited=False): + self._outcome = outcome + self._should_fail_if_waited = should_fail_if_waited + + def result(self, timeout=None): + if self._should_fail_if_waited: + raise AssertionError("in-flight future should not be awaited after early stop") + return self._outcome + + def cancel(self): + return True + + done_future = FakeFuture( + RequestOutcome( + request_id="r0", + success=False, + ttft_ms=None, + tpot_ms=None, + prompt_tokens=8, + completion_tokens=4, + error="request_failed", + ) + ) + inflight_future = FakeFuture(should_fail_if_waited=True) + + submitted = [] + + class FakeExecutor: + def __init__(self, max_workers): + self.max_workers = max_workers + + def submit(self, fn, request, **kwargs): + submitted.append(request.row_id) + if request.row_id == "r0": + return done_future + return inflight_future + + def shutdown(self, wait=False, cancel_futures=True): + return None + + def fake_wait(futures, timeout=None, return_when=None): + self.assertEqual(len(futures), 2) + return {done_future}, {inflight_future} + + def fake_evaluate(outcome: RequestOutcome): + return type("Eval", (), {"passed": outcome.success})() + + with mock.patch("aituner.worker.ThreadPoolExecutor", FakeExecutor): + with mock.patch("aituner.worker.wait", side_effect=fake_wait): + replayed, early_stopped, reason = _replay_requests( + requests, + base_url="http://127.0.0.1:8000", + timeout_s=30.0, + max_concurrency=2, + target_pass_rate=0.95, + max_lag_s=None, + max_elapsed_s=None, + evaluate_outcome=fake_evaluate, + ) + + self.assertEqual(submitted, ["r0", "r1"]) + self.assertTrue(early_stopped) + self.assertEqual(reason, "slo_pass_rate_unrecoverable") + self.assertEqual(len(replayed), 2) + self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable") + def test_wait_for_server_or_exit_fails_fast_when_process_exits(self) -> None: process = mock.Mock() process.poll.return_value = 17