diff --git a/src/aituner/http_client.py b/src/aituner/http_client.py index e840731..746bba4 100644 --- a/src/aituner/http_client.py +++ b/src/aituner/http_client.py @@ -179,6 +179,9 @@ def chat_completion( except urllib.error.HTTPError as exc: detail = exc.read().decode("utf-8", errors="replace") raise HttpClientError(f"llm_completion failed: {exc.code} {detail}") from exc + except OSError as exc: + # TimeoutError (socket.timeout), URLError, ConnectionError all subclass OSError. + raise HttpClientError(f"llm_completion failed: {exc}") from exc def stream_text_completion( @@ -232,6 +235,8 @@ def stream_text_completion( except urllib.error.HTTPError as exc: detail = exc.read().decode("utf-8", errors="replace") raise HttpClientError(f"stream_text_completion failed: {exc.code} {detail}") from exc + except OSError as exc: + raise HttpClientError(f"stream_text_completion failed: {exc}") from exc return "".join(parts) @@ -293,6 +298,10 @@ def stream_chat_completion( except urllib.error.HTTPError as exc: detail = exc.read().decode("utf-8", errors="replace") raise HttpClientError(f"stream_chat_completion failed: {exc.code} {detail}") from exc + except OSError as exc: + # A request that exceeds request_timeout_s raises TimeoutError mid-stream; + # treat it as a failed request (SLO miss), not a crashed trial. + raise HttpClientError(f"stream_chat_completion failed: {exc}") from exc ttft_ms = None if first_token_at is None else (first_token_at - start) * 1000.0 if completion_tokens is None and chunk_token_count > 0: completion_tokens = chunk_token_count diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index b1cbfe7..2772159 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -16,6 +16,7 @@ from aituner.cli import main as cli_main from aituner.compare import _aggregate_summary, load_compare_spec, run_compare from aituner.engine import build_launch_recipe from aituner.http_client import ( + HttpClientError, StreamMetrics, _auth_headers, _openai_url, @@ -560,6 +561,34 @@ class CoreFlowTests(unittest.TestCase): self.assertTrue(ev(0, 3900).passed) self.assertFalse(ev(0, 4100).passed) + def test_streaming_socket_timeout_is_a_failed_request_not_a_crash(self) -> None: + # A request that exceeds request_timeout_s raises TimeoutError mid-stream; + # it must surface as HttpClientError (a failed request), never escape to + # crash the trial. + with mock.patch( + "aituner.http_client._urlopen", side_effect=TimeoutError("timed out") + ): + with self.assertRaises(HttpClientError): + stream_chat_completion( + base_url="http://127.0.0.1:1/v1", + body={"messages": [{"role": "user", "content": "hi"}], "stream": True}, + timeout_s=0.5, + ) + outcome = _run_one_request( + TraceRequest( + row_id="r", + arrival_s=0.0, + sampling_u=1.0, + body={"messages": [{"role": "user", "content": "hi"}], "stream": True}, + prompt_tokens_hint=10, + completion_tokens_hint=None, + ), + base_url="http://127.0.0.1:1/v1", + timeout_s=0.5, + ) + self.assertFalse(outcome.success) + self.assertIn("timed out", outcome.error) + def test_lca_similarity_matrix_separates_different_profiles(self) -> None: window = WindowRecord( window_id="base",