From 96140b79bb9ae147b3a1e3eb5d1e9559e3cf8abb Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Thu, 9 Apr 2026 01:06:45 +0800
Subject: [PATCH] Add streaming LLM proposal support

---
 src/aituner/http_client.py | 54 ++++++++++++++++++++++++++++++++++++++
 src/aituner/llm.py         | 14 +++++++++-
 src/aituner/spec.py        |  9 +++++++
 tests/test_core_flow.py    | 14 ++++++++++
 4 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/src/aituner/http_client.py b/src/aituner/http_client.py
index 248877e..77c03db 100644
--- a/src/aituner/http_client.py
+++ b/src/aituner/http_client.py
@@ -181,6 +181,60 @@ def chat_completion(
         raise HttpClientError(f"llm_completion failed: {exc.code} {detail}") from exc
 
 
+def stream_text_completion(
+    *,
+    base_url: str,
+    api_key_env: str | None,
+    provider: str = "custom",
+    wire_api: str = "chat.completions",
+    model: str,
+    messages: list[dict[str, Any]],
+    timeout_s: float,
+    system_prompt: str = "",
+    reasoning_effort: str | None = None,
+) -> str:
+    if wire_api != "chat.completions":
+        raise HttpClientError("stream_text_completion currently supports only chat.completions")
+    payload: dict[str, Any] = {
+        "model": model,
+        "messages": messages,
+        "stream": True,
+    }
+    if system_prompt:
+        payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
+    if reasoning_effort:
+        payload["reasoning_effort"] = reasoning_effort
+    data = json.dumps(payload).encode("utf-8")
+    request = urllib.request.Request(
+        url=_openai_url(base_url, "/v1/chat/completions"),
+        headers=_auth_headers(api_key_env, provider),
+        data=data,
+        method="POST",
+    )
+    parts: list[str] = []
+    try:
+        with _urlopen(request, timeout=timeout_s) as response:
+            for raw in _iter_sse_lines(response):
+                if raw == "[DONE]":
+                    break
+                payload = json.loads(raw)
+                if not isinstance(payload, dict):
+                    continue
+                choices = payload.get("choices")
+                if not isinstance(choices, list) or not choices:
+                    continue
+                delta = choices[0].get("delta", {})
+                if not isinstance(delta, dict):
+                    continue
+                content = delta.get("content")
+                if isinstance(content, str):
+                    parts.append(content)
+    except urllib.error.HTTPError as exc:
+        detail = exc.read().decode("utf-8", errors="replace")
+        raise HttpClientError(f"stream_text_completion failed: {exc.code} {detail}") from exc
+    return "".join(parts)
+
+
 @dataclass(frozen=True)
 class StreamMetrics:
     ttft_ms: float | None
diff --git a/src/aituner/llm.py b/src/aituner/llm.py
index ff39434..2603690 100644
--- a/src/aituner/llm.py
+++ b/src/aituner/llm.py
@@ -4,7 +4,7 @@
 import json
 from pathlib import Path
 from typing import Any
 
-from .http_client import chat_completion
+from .http_client import chat_completion, stream_text_completion
 from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
 
@@ -229,6 +229,18 @@ def call_llm_for_proposal(
 ) -> str:
     if policy.endpoint is None:
         raise RuntimeError("study.llm.endpoint is not configured")
+    if policy.endpoint.stream:
+        return stream_text_completion(
+            base_url=policy.endpoint.base_url,
+            api_key_env=policy.endpoint.api_key_env,
+            provider=policy.endpoint.provider,
+            wire_api=policy.endpoint.wire_api,
+            model=policy.endpoint.model,
+            messages=[{"role": "user", "content": prompt}],
+            timeout_s=policy.endpoint.timeout_s,
+            system_prompt=policy.system_prompt,
+            reasoning_effort=policy.endpoint.reasoning_effort,
+        )
     response = chat_completion(
         base_url=policy.endpoint.base_url,
         api_key_env=policy.endpoint.api_key_env,
diff --git a/src/aituner/spec.py b/src/aituner/spec.py
index 511615a..7182e4f 100644
--- a/src/aituner/spec.py
+++ b/src/aituner/spec.py
@@ -36,6 +36,12 @@ def _require_int(value: Any, *, context: str) -> int:
     return value
 
 
+def _require_bool(value: Any, *, context: str) -> bool:
+    if not isinstance(value, bool):
+        raise SpecError(f"{context} must be a boolean.")
+    return value
+
+
 def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]:
     mapping = _require_mapping(value or {}, context=context)
     return {str(key): str(item) for key, item in mapping.items()}
@@ -393,6 +399,7 @@ class LLMEndpointSpec:
     model: str
     provider: str = "custom"
     wire_api: str = "chat.completions"
+    stream: bool = False
     reasoning_effort: str | None = None
     api_key_env: str = "OPENAI_API_KEY"
     timeout_s: float = 120.0
@@ -402,6 +409,7 @@
         provider = str(data.get("provider") or "custom").strip().lower()
         base_url = str(data.get("base_url") or "").strip()
         wire_api = str(data.get("wire_api") or "").strip()
+        stream = data.get("stream")
         reasoning_effort = str(data.get("reasoning_effort") or "").strip()
         api_key_env = str(data.get("api_key_env") or "").strip()
         if provider == "codex":
@@ -438,6 +446,7 @@
             model=_require_str(data.get("model"), context="llm.endpoint.model"),
             provider=provider,
             wire_api=_require_str(wire_api, context="llm.endpoint.wire_api"),
+            stream=(_require_bool(stream, context="llm.endpoint.stream") if stream is not None else False),
             reasoning_effort=reasoning_effort or None,
             api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
             timeout_s=_require_float(
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index 9083b94..28472e8 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -253,9 +253,23 @@ class CoreFlowTests(unittest.TestCase):
         self.assertEqual(endpoint.provider, "codex")
         self.assertEqual(endpoint.base_url, "http://codex.example/v1")
         self.assertEqual(endpoint.wire_api, "responses")
+        self.assertFalse(endpoint.stream)
         self.assertEqual(endpoint.reasoning_effort, "high")
         self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
 
+    def test_endpoint_stream_flag(self) -> None:
+        endpoint = LLMEndpointSpec.from_dict(
+            {
+                "provider": "custom",
+                "base_url": "http://example/v1",
+                "wire_api": "chat.completions",
+                "stream": True,
+                "model": "x",
+                "api_key_env": "OPENAI_API_KEY",
+            }
+        )
+        self.assertTrue(endpoint.stream)
+
     def test_extract_response_text_supports_responses_api_output(self) -> None:
         text = _extract_response_text(
             {
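
Usage sketch (illustrative, not part of the diff): with this change, an endpoint
spec carrying "stream": true is parsed by LLMEndpointSpec.from_dict and routes
call_llm_for_proposal through stream_text_completion instead of chat_completion.
A minimal sketch, assuming the aituner package is importable, the base_url and
model values below are placeholders, and the variable named by api_key_env is
set in the environment:

    from aituner.http_client import stream_text_completion
    from aituner.spec import LLMEndpointSpec

    # Build an endpoint spec with streaming enabled (same shape as the new test).
    endpoint = LLMEndpointSpec.from_dict(
        {
            "provider": "custom",
            "base_url": "http://example/v1",
            "wire_api": "chat.completions",
            "stream": True,
            "model": "x",
            "api_key_env": "OPENAI_API_KEY",
        }
    )
    assert endpoint.stream is True

    # Equivalent to the streaming branch added to call_llm_for_proposal for a
    # single user message with no system prompt: the helper concatenates the
    # delta.content chunks it reads from the SSE stream and returns the result.
    text = stream_text_completion(
        base_url=endpoint.base_url,
        api_key_env=endpoint.api_key_env,
        provider=endpoint.provider,
        wire_api=endpoint.wire_api,
        model=endpoint.model,
        messages=[{"role": "user", "content": "propose the next trial"}],
        timeout_s=endpoint.timeout_s,
    )
    print(text)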