Stop waiting on in-flight requests after early stop

This commit is contained in:
2026-04-05 00:56:26 +08:00
parent 75a9842f1a
commit e00bedb466
2 changed files with 108 additions and 8 deletions

View File

@@ -150,21 +150,19 @@ def _replay_requests(
sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
if sleep_for > 0:
time.sleep(min(sleep_for, 0.1))
for future, request in list(futures_by_request.items()):
try:
outcome = future.result(timeout=timeout_s)
except Exception: # noqa: BLE001
outcome = RequestOutcome(
if early_stopped:
for future in list(futures_by_request):
future.cancel()
for request in futures_by_request.values():
outcomes_by_id[request.row_id] = RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=None,
tpot_ms=None,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
error="probe_early_stopped",
error=early_stop_reason or "probe_early_stopped",
)
outcomes_by_id[outcome.request_id] = outcome
if early_stopped:
for request in requests:
if request.row_id in submitted_ids:
continue
@@ -177,6 +175,21 @@ def _replay_requests(
completion_tokens=request.completion_tokens_hint,
error=early_stop_reason or "probe_early_stopped",
)
else:
for future, request in list(futures_by_request.items()):
try:
outcome = future.result(timeout=timeout_s)
except Exception: # noqa: BLE001
outcome = RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=None,
tpot_ms=None,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
error="probe_early_stopped",
)
outcomes_by_id[outcome.request_id] = outcome
finally:
pool.shutdown(wait=False, cancel_futures=True)
ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]

View File

@@ -812,6 +812,93 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(len(replayed), 3)
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
def test_replay_requests_does_not_wait_for_inflight_after_early_stop(self) -> None:
requests = [
TraceRequest(
row_id="r0",
arrival_s=0.0,
sampling_u=0.1,
body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
prompt_tokens_hint=8,
completion_tokens_hint=4,
),
TraceRequest(
row_id="r1",
arrival_s=0.0,
sampling_u=0.2,
body={"model": "m", "messages": [{"role": "user", "content": "y"}]},
prompt_tokens_hint=8,
completion_tokens_hint=4,
),
]
class FakeFuture:
def __init__(self, outcome=None, *, should_fail_if_waited=False):
self._outcome = outcome
self._should_fail_if_waited = should_fail_if_waited
def result(self, timeout=None):
if self._should_fail_if_waited:
raise AssertionError("in-flight future should not be awaited after early stop")
return self._outcome
def cancel(self):
return True
done_future = FakeFuture(
RequestOutcome(
request_id="r0",
success=False,
ttft_ms=None,
tpot_ms=None,
prompt_tokens=8,
completion_tokens=4,
error="request_failed",
)
)
inflight_future = FakeFuture(should_fail_if_waited=True)
submitted = []
class FakeExecutor:
def __init__(self, max_workers):
self.max_workers = max_workers
def submit(self, fn, request, **kwargs):
submitted.append(request.row_id)
if request.row_id == "r0":
return done_future
return inflight_future
def shutdown(self, wait=False, cancel_futures=True):
return None
def fake_wait(futures, timeout=None, return_when=None):
self.assertEqual(len(futures), 2)
return {done_future}, {inflight_future}
def fake_evaluate(outcome: RequestOutcome):
return type("Eval", (), {"passed": outcome.success})()
with mock.patch("aituner.worker.ThreadPoolExecutor", FakeExecutor):
with mock.patch("aituner.worker.wait", side_effect=fake_wait):
replayed, early_stopped, reason = _replay_requests(
requests,
base_url="http://127.0.0.1:8000",
timeout_s=30.0,
max_concurrency=2,
target_pass_rate=0.95,
max_lag_s=None,
max_elapsed_s=None,
evaluate_outcome=fake_evaluate,
)
self.assertEqual(submitted, ["r0", "r1"])
self.assertTrue(early_stopped)
self.assertEqual(reason, "slo_pass_rate_unrecoverable")
self.assertEqual(len(replayed), 2)
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
def test_wait_for_server_or_exit_fails_fast_when_process_exits(self) -> None:
process = mock.Mock()
process.poll.return_value = 17