Stop waiting on in-flight requests after early stop
This commit is contained in:
@@ -150,21 +150,19 @@ def _replay_requests(
|
|||||||
sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
|
sleep_for = max(0.0, requests[next_index].arrival_s - elapsed)
|
||||||
if sleep_for > 0:
|
if sleep_for > 0:
|
||||||
time.sleep(min(sleep_for, 0.1))
|
time.sleep(min(sleep_for, 0.1))
|
||||||
for future, request in list(futures_by_request.items()):
|
if early_stopped:
|
||||||
try:
|
for future in list(futures_by_request):
|
||||||
outcome = future.result(timeout=timeout_s)
|
future.cancel()
|
||||||
except Exception: # noqa: BLE001
|
for request in futures_by_request.values():
|
||||||
outcome = RequestOutcome(
|
outcomes_by_id[request.row_id] = RequestOutcome(
|
||||||
request_id=request.row_id,
|
request_id=request.row_id,
|
||||||
success=False,
|
success=False,
|
||||||
ttft_ms=None,
|
ttft_ms=None,
|
||||||
tpot_ms=None,
|
tpot_ms=None,
|
||||||
prompt_tokens=request.prompt_tokens_hint,
|
prompt_tokens=request.prompt_tokens_hint,
|
||||||
completion_tokens=request.completion_tokens_hint,
|
completion_tokens=request.completion_tokens_hint,
|
||||||
error="probe_early_stopped",
|
error=early_stop_reason or "probe_early_stopped",
|
||||||
)
|
)
|
||||||
outcomes_by_id[outcome.request_id] = outcome
|
|
||||||
if early_stopped:
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
if request.row_id in submitted_ids:
|
if request.row_id in submitted_ids:
|
||||||
continue
|
continue
|
||||||
@@ -177,6 +175,21 @@ def _replay_requests(
|
|||||||
completion_tokens=request.completion_tokens_hint,
|
completion_tokens=request.completion_tokens_hint,
|
||||||
error=early_stop_reason or "probe_early_stopped",
|
error=early_stop_reason or "probe_early_stopped",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
for future, request in list(futures_by_request.items()):
|
||||||
|
try:
|
||||||
|
outcome = future.result(timeout=timeout_s)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
outcome = RequestOutcome(
|
||||||
|
request_id=request.row_id,
|
||||||
|
success=False,
|
||||||
|
ttft_ms=None,
|
||||||
|
tpot_ms=None,
|
||||||
|
prompt_tokens=request.prompt_tokens_hint,
|
||||||
|
completion_tokens=request.completion_tokens_hint,
|
||||||
|
error="probe_early_stopped",
|
||||||
|
)
|
||||||
|
outcomes_by_id[outcome.request_id] = outcome
|
||||||
finally:
|
finally:
|
||||||
pool.shutdown(wait=False, cancel_futures=True)
|
pool.shutdown(wait=False, cancel_futures=True)
|
||||||
ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
ordered = [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
||||||
|
|||||||
@@ -812,6 +812,93 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(len(replayed), 3)
|
self.assertEqual(len(replayed), 3)
|
||||||
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||||
|
|
||||||
|
def test_replay_requests_does_not_wait_for_inflight_after_early_stop(self) -> None:
|
||||||
|
requests = [
|
||||||
|
TraceRequest(
|
||||||
|
row_id="r0",
|
||||||
|
arrival_s=0.0,
|
||||||
|
sampling_u=0.1,
|
||||||
|
body={"model": "m", "messages": [{"role": "user", "content": "x"}]},
|
||||||
|
prompt_tokens_hint=8,
|
||||||
|
completion_tokens_hint=4,
|
||||||
|
),
|
||||||
|
TraceRequest(
|
||||||
|
row_id="r1",
|
||||||
|
arrival_s=0.0,
|
||||||
|
sampling_u=0.2,
|
||||||
|
body={"model": "m", "messages": [{"role": "user", "content": "y"}]},
|
||||||
|
prompt_tokens_hint=8,
|
||||||
|
completion_tokens_hint=4,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
class FakeFuture:
|
||||||
|
def __init__(self, outcome=None, *, should_fail_if_waited=False):
|
||||||
|
self._outcome = outcome
|
||||||
|
self._should_fail_if_waited = should_fail_if_waited
|
||||||
|
|
||||||
|
def result(self, timeout=None):
|
||||||
|
if self._should_fail_if_waited:
|
||||||
|
raise AssertionError("in-flight future should not be awaited after early stop")
|
||||||
|
return self._outcome
|
||||||
|
|
||||||
|
def cancel(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
done_future = FakeFuture(
|
||||||
|
RequestOutcome(
|
||||||
|
request_id="r0",
|
||||||
|
success=False,
|
||||||
|
ttft_ms=None,
|
||||||
|
tpot_ms=None,
|
||||||
|
prompt_tokens=8,
|
||||||
|
completion_tokens=4,
|
||||||
|
error="request_failed",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
inflight_future = FakeFuture(should_fail_if_waited=True)
|
||||||
|
|
||||||
|
submitted = []
|
||||||
|
|
||||||
|
class FakeExecutor:
|
||||||
|
def __init__(self, max_workers):
|
||||||
|
self.max_workers = max_workers
|
||||||
|
|
||||||
|
def submit(self, fn, request, **kwargs):
|
||||||
|
submitted.append(request.row_id)
|
||||||
|
if request.row_id == "r0":
|
||||||
|
return done_future
|
||||||
|
return inflight_future
|
||||||
|
|
||||||
|
def shutdown(self, wait=False, cancel_futures=True):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def fake_wait(futures, timeout=None, return_when=None):
|
||||||
|
self.assertEqual(len(futures), 2)
|
||||||
|
return {done_future}, {inflight_future}
|
||||||
|
|
||||||
|
def fake_evaluate(outcome: RequestOutcome):
|
||||||
|
return type("Eval", (), {"passed": outcome.success})()
|
||||||
|
|
||||||
|
with mock.patch("aituner.worker.ThreadPoolExecutor", FakeExecutor):
|
||||||
|
with mock.patch("aituner.worker.wait", side_effect=fake_wait):
|
||||||
|
replayed, early_stopped, reason = _replay_requests(
|
||||||
|
requests,
|
||||||
|
base_url="http://127.0.0.1:8000",
|
||||||
|
timeout_s=30.0,
|
||||||
|
max_concurrency=2,
|
||||||
|
target_pass_rate=0.95,
|
||||||
|
max_lag_s=None,
|
||||||
|
max_elapsed_s=None,
|
||||||
|
evaluate_outcome=fake_evaluate,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(submitted, ["r0", "r1"])
|
||||||
|
self.assertTrue(early_stopped)
|
||||||
|
self.assertEqual(reason, "slo_pass_rate_unrecoverable")
|
||||||
|
self.assertEqual(len(replayed), 2)
|
||||||
|
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||||
|
|
||||||
def test_wait_for_server_or_exit_fails_fast_when_process_exits(self) -> None:
|
def test_wait_for_server_or_exit_fails_fast_when_process_exits(self) -> None:
|
||||||
process = mock.Mock()
|
process = mock.Mock()
|
||||||
process.poll.return_value = 17
|
process.poll.return_value = 17
|
||||||
|
|||||||
Reference in New Issue
Block a user