Add streaming LLM proposal support

2026-04-09 01:06:45 +08:00
parent 46151512cd
commit 96140b79bb
4 changed files with 90 additions and 1 deletion


@@ -181,6 +181,60 @@ def chat_completion(
        raise HttpClientError(f"llm_completion failed: {exc.code} {detail}") from exc

def stream_text_completion(
    *,
    base_url: str,
    api_key_env: str | None,
    provider: str = "custom",
    wire_api: str = "chat.completions",
    model: str,
    messages: list[dict[str, Any]],
    timeout_s: float,
    system_prompt: str = "",
    reasoning_effort: str | None = None,
) -> str:
    if wire_api != "chat.completions":
        raise HttpClientError("stream_text_completion currently supports only chat.completions")
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    if system_prompt:
        payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
    if reasoning_effort:
        payload["reasoning_effort"] = reasoning_effort
    data = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        url=_openai_url(base_url, "/v1/chat/completions"),
        headers=_auth_headers(api_key_env, provider),
        data=data,
        method="POST",
    )
    parts: list[str] = []
    try:
        with _urlopen(request, timeout=timeout_s) as response:
            for raw in _iter_sse_lines(response):
                if raw == "[DONE]":
                    break
                payload = json.loads(raw)
                if not isinstance(payload, dict):
                    continue
                choices = payload.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                delta = choices[0].get("delta", {})
                if not isinstance(delta, dict):
                    continue
                content = delta.get("content")
                if isinstance(content, str):
                    parts.append(content)
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise HttpClientError(f"stream_text_completion failed: {exc.code} {detail}") from exc
    return "".join(parts)
@dataclass(frozen=True)
class StreamMetrics:
    ttft_ms: float | None


@@ -4,7 +4,7 @@ import json
from pathlib import Path
from typing import Any

from .http_client import chat_completion, stream_text_completion
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
@@ -229,6 +229,18 @@ def call_llm_for_proposal(
) -> str:
    if policy.endpoint is None:
        raise RuntimeError("study.llm.endpoint is not configured")
    if policy.endpoint.stream:
        return stream_text_completion(
            base_url=policy.endpoint.base_url,
            api_key_env=policy.endpoint.api_key_env,
            provider=policy.endpoint.provider,
            wire_api=policy.endpoint.wire_api,
            model=policy.endpoint.model,
            messages=[{"role": "user", "content": prompt}],
            timeout_s=policy.endpoint.timeout_s,
            system_prompt=policy.system_prompt,
            reasoning_effort=policy.endpoint.reasoning_effort,
        )
    response = chat_completion(
        base_url=policy.endpoint.base_url,
        api_key_env=policy.endpoint.api_key_env,
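
For reference, when the flag is set the streaming branch ends up posting a body equivalent to the sketch below to /v1/chat/completions (model name is a placeholder; the system message is prepended only when policy.system_prompt is non-empty, and reasoning_effort appears only when configured on the endpoint):

payload = {
    "model": "my-model",  # placeholder
    "stream": True,
    "messages": [
        {"role": "system", "content": policy.system_prompt},  # omitted if empty
        {"role": "user", "content": prompt},
    ],
    # "reasoning_effort": "high",  # present only when set on the endpoint
}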


@@ -36,6 +36,12 @@ def _require_int(value: Any, *, context: str) -> int:
    return value

def _require_bool(value: Any, *, context: str) -> bool:
    if not isinstance(value, bool):
        raise SpecError(f"{context} must be a boolean.")
    return value

def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]:
    mapping = _require_mapping(value or {}, context=context)
    return {str(key): str(item) for key, item in mapping.items()}
@@ -393,6 +399,7 @@ class LLMEndpointSpec:
    model: str
    provider: str = "custom"
    wire_api: str = "chat.completions"
    stream: bool = False
    reasoning_effort: str | None = None
    api_key_env: str = "OPENAI_API_KEY"
    timeout_s: float = 120.0
@@ -402,6 +409,7 @@ class LLMEndpointSpec:
        provider = str(data.get("provider") or "custom").strip().lower()
        base_url = str(data.get("base_url") or "").strip()
        wire_api = str(data.get("wire_api") or "").strip()
        stream = data.get("stream")
        reasoning_effort = str(data.get("reasoning_effort") or "").strip()
        api_key_env = str(data.get("api_key_env") or "").strip()
        if provider == "codex":
@@ -438,6 +446,7 @@ class LLMEndpointSpec:
            model=_require_str(data.get("model"), context="llm.endpoint.model"),
            provider=provider,
            wire_api=_require_str(wire_api, context="llm.endpoint.wire_api"),
            stream=(_require_bool(stream, context="llm.endpoint.stream") if stream is not None else False),
            reasoning_effort=reasoning_effort or None,
            api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
            timeout_s=_require_float(
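
Note that _require_bool is strict: truthy non-booleans such as 1 or "true" raise SpecError rather than being coerced, and an absent stream key falls back to False. For example:

_require_bool(True, context="llm.endpoint.stream")  # returns True
_require_bool(1, context="llm.endpoint.stream")     # raises SpecError: "llm.endpoint.stream must be a boolean."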


@@ -253,9 +253,23 @@ class CoreFlowTests(unittest.TestCase):
        self.assertEqual(endpoint.provider, "codex")
        self.assertEqual(endpoint.base_url, "http://codex.example/v1")
        self.assertEqual(endpoint.wire_api, "responses")
        self.assertFalse(endpoint.stream)
        self.assertEqual(endpoint.reasoning_effort, "high")
        self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")

    def test_endpoint_stream_flag(self) -> None:
        endpoint = LLMEndpointSpec.from_dict(
            {
                "provider": "custom",
                "base_url": "http://example/v1",
                "wire_api": "chat.completions",
                "stream": True,
                "model": "x",
                "api_key_env": "OPENAI_API_KEY",
            }
        )
        self.assertTrue(endpoint.stream)
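
A complementary negative case (not part of this commit; a sketch in the same test style) would pin down the strict boolean validation:

    def test_endpoint_stream_flag_rejects_non_bool(self) -> None:
        with self.assertRaises(SpecError):
            LLMEndpointSpec.from_dict(
                {
                    "provider": "custom",
                    "base_url": "http://example/v1",
                    "wire_api": "chat.completions",
                    "stream": "yes",  # strings are not coerced to booleans
                    "model": "x",
                    "api_key_env": "OPENAI_API_KEY",
                }
            )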
    def test_extract_response_text_supports_responses_api_output(self) -> None:
        text = _extract_response_text(
            {