Stop waiting on in-flight requests after early stop
This commit is contained in:
@@ -150,6 +150,32 @@ def _replay_requests(
|
||||
sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
|
||||
if sleep_for > 0:
|
||||
time.sleep(min(sleep_for, 0.1))
|
||||
if early_stopped:
|
||||
for future in list(futures_by_request):
|
||||
future.cancel()
|
||||
for request in futures_by_request.values():
|
||||
outcomes_by_id[request.row_id] = RequestOutcome(
|
||||
request_id=request.row_id,
|
||||
success=False,
|
||||
ttft_ms=None,
|
||||
tpot_ms=None,
|
||||
prompt_tokens=request.prompt_tokens_hint,
|
||||
completion_tokens=request.completion_tokens_hint,
|
||||
error=early_stop_reason or "probe_early_stopped",
|
||||
)
|
||||
for request in requests:
|
||||
if request.row_id in submitted_ids:
|
||||
continue
|
||||
outcomes_by_id[request.row_id] = RequestOutcome(
|
||||
request_id=request.row_id,
|
||||
success=False,
|
||||
ttft_ms=None,
|
||||
tpot_ms=None,
|
||||
prompt_tokens=request.prompt_tokens_hint,
|
||||
completion_tokens=request.completion_tokens_hint,
|
||||
error=early_stop_reason or "probe_early_stopped",
|
||||
)
|
||||
else:
|
||||
for future, request in list(futures_by_request.items()):
|
||||
try:
|
||||
outcome = future.result(timeout=timeout_s)
|
||||
@@ -164,19 +190,6 @@ def _replay_requests(
|
||||
error="probe_early_stopped",
|
||||
)
|
||||
outcomes_by_id[outcome.request_id] = outcome
|
||||
if early_stopped:
|
||||
for request in requests:
|
||||
if request.row_id in submitted_ids:
|
||||
continue
|
||||
outcomes_by_id[request.row_id] = RequestOutcome(
|
||||
request_id=request.row_id,
|
||||
success=False,
|
||||
ttft_ms=None,
|
||||
tpot_ms=None,
|
||||
prompt_tokens=request.prompt_tokens_hint,
|
||||
completion_tokens=request.completion_tokens_hint,
|
||||
error=early_stop_reason or "probe_early_stopped",
|
||||
)
|
||||
finally:
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
||||
|
||||
@@ -812,6 +812,93 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(len(replayed), 3)
|
||||
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||
|
||||
def test_replay_requests_does_not_wait_for_inflight_after_early_stop(self) -> None:
|
||||
requests = [
|
||||
TraceRequest(
|
||||
row_id="r0",
|
||||
arrival_s=0.0,
|
||||
sampling_u=0.1,
|
||||
body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
|
||||
prompt_tokens_hint=8,
|
||||
completion_tokens_hint=4,
|
||||
),
|
||||
TraceRequest(
|
||||
row_id="r1",
|
||||
arrival_s=0.0,
|
||||
sampling_u=0.2,
|
||||
body={"model": "m", "messages": [{"role": "user", "content": "y"}]},
|
||||
prompt_tokens_hint=8,
|
||||
completion_tokens_hint=4,
|
||||
),
|
||||
]
|
||||
|
||||
class FakeFuture:
|
||||
def __init__(self, outcome=None, *, should_fail_if_waited=False):
|
||||
self._outcome = outcome
|
||||
self._should_fail_if_waited = should_fail_if_waited
|
||||
|
||||
def result(self, timeout=None):
|
||||
if self._should_fail_if_waited:
|
||||
raise AssertionError("in-flight future should not be awaited after early stop")
|
||||
return self._outcome
|
||||
|
||||
def cancel(self):
|
||||
return True
|
||||
|
||||
done_future = FakeFuture(
|
||||
RequestOutcome(
|
||||
request_id="r0",
|
||||
success=False,
|
||||
ttft_ms=None,
|
||||
tpot_ms=None,
|
||||
prompt_tokens=8,
|
||||
completion_tokens=4,
|
||||
error="request_failed",
|
||||
)
|
||||
)
|
||||
inflight_future = FakeFuture(should_fail_if_waited=True)
|
||||
|
||||
submitted = []
|
||||
|
||||
class FakeExecutor:
|
||||
def __init__(self, max_workers):
|
||||
self.max_workers = max_workers
|
||||
|
||||
def submit(self, fn, request, **kwargs):
|
||||
submitted.append(request.row_id)
|
||||
if request.row_id == "r0":
|
||||
return done_future
|
||||
return inflight_future
|
||||
|
||||
def shutdown(self, wait=False, cancel_futures=True):
|
||||
return None
|
||||
|
||||
def fake_wait(futures, timeout=None, return_when=None):
|
||||
self.assertEqual(len(futures), 2)
|
||||
return {done_future}, {inflight_future}
|
||||
|
||||
def fake_evaluate(outcome: RequestOutcome):
|
||||
return type("Eval", (), {"passed": outcome.success})()
|
||||
|
||||
with mock.patch("aituner.worker.ThreadPoolExecutor", FakeExecutor):
|
||||
with mock.patch("aituner.worker.wait", side_effect=fake_wait):
|
||||
replayed, early_stopped, reason = _replay_requests(
|
||||
requests,
|
||||
base_url="http://127.0.0.1:8000",
|
||||
timeout_s=30.0,
|
||||
max_concurrency=2,
|
||||
target_pass_rate=0.95,
|
||||
max_lag_s=None,
|
||||
max_elapsed_s=None,
|
||||
evaluate_outcome=fake_evaluate,
|
||||
)
|
||||
|
||||
self.assertEqual(submitted, ["r0", "r1"])
|
||||
self.assertTrue(early_stopped)
|
||||
self.assertEqual(reason, "slo_pass_rate_unrecoverable")
|
||||
self.assertEqual(len(replayed), 2)
|
||||
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||
|
||||
def test_wait_for_server_or_exit_fails_fast_when_process_exits(self) -> None:
|
||||
process = mock.Mock()
|
||||
process.poll.return_value = 17
|
||||
|
||||
Reference in New Issue
Block a user