Add streaming LLM proposal support
@@ -181,6 +181,60 @@ def chat_completion(
        raise HttpClientError(f"llm_completion failed: {exc.code} {detail}") from exc


def stream_text_completion(
    *,
    base_url: str,
    api_key_env: str | None,
    provider: str = "custom",
    wire_api: str = "chat.completions",
    model: str,
    messages: list[dict[str, Any]],
    timeout_s: float,
    system_prompt: str = "",
    reasoning_effort: str | None = None,
) -> str:
    if wire_api != "chat.completions":
        raise HttpClientError("stream_text_completion currently supports only chat.completions")
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    if system_prompt:
        payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
    if reasoning_effort:
        payload["reasoning_effort"] = reasoning_effort
    data = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        url=_openai_url(base_url, "/v1/chat/completions"),
        headers=_auth_headers(api_key_env, provider),
        data=data,
        method="POST",
    )
    parts: list[str] = []
    try:
        with _urlopen(request, timeout=timeout_s) as response:
            for raw in _iter_sse_lines(response):
                if raw == "[DONE]":
                    break
                payload = json.loads(raw)
                if not isinstance(payload, dict):
                    continue
                choices = payload.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                delta = choices[0].get("delta", {})
                if not isinstance(delta, dict):
                    continue
                content = delta.get("content")
                if isinstance(content, str):
                    parts.append(content)
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise HttpClientError(f"stream_text_completion failed: {exc.code} {detail}") from exc
    return "".join(parts)


@dataclass(frozen=True)
class StreamMetrics:
    ttft_ms: float | None
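For context, a minimal sketch of the chunk shape the streaming loop above consumes, assuming _iter_sse_lines (an existing helper, not part of this diff) yields the data payload of each server-sent event as a raw JSON string; the event contents here are hypothetical:

import json

# Hypothetical chat.completions streaming events; "[DONE]" terminates the stream.
example_events = [
    '{"choices": [{"delta": {"role": "assistant"}}]}',
    '{"choices": [{"delta": {"content": "Hel"}}]}',
    '{"choices": [{"delta": {"content": "lo"}}]}',
    "[DONE]",
]

parts: list[str] = []
for raw in example_events:
    if raw == "[DONE]":
        break
    chunk = json.loads(raw)
    # Same extraction path as the function above: choices[0].delta.content.
    delta = chunk["choices"][0].get("delta", {})
    content = delta.get("content")
    if isinstance(content, str):
        parts.append(content)

assert "".join(parts) == "Hello"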
@@ -4,7 +4,7 @@ import json
from pathlib import Path
from typing import Any

-from .http_client import chat_completion
+from .http_client import chat_completion, stream_text_completion
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
@@ -229,6 +229,18 @@ def call_llm_for_proposal(
) -> str:
    if policy.endpoint is None:
        raise RuntimeError("study.llm.endpoint is not configured")
    if policy.endpoint.stream:
        return stream_text_completion(
            base_url=policy.endpoint.base_url,
            api_key_env=policy.endpoint.api_key_env,
            provider=policy.endpoint.provider,
            wire_api=policy.endpoint.wire_api,
            model=policy.endpoint.model,
            messages=[{"role": "user", "content": prompt}],
            timeout_s=policy.endpoint.timeout_s,
            system_prompt=policy.system_prompt,
            reasoning_effort=policy.endpoint.reasoning_effort,
        )
    response = chat_completion(
        base_url=policy.endpoint.base_url,
        api_key_env=policy.endpoint.api_key_env,
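To make the streaming branch concrete: call_llm_for_proposal passes a single user message plus the policy's system prompt, and stream_text_completion prepends the system message before posting. A small illustration of the resulting wire payload (the prompt strings are placeholders, not values from this change):

messages = [{"role": "user", "content": "Propose the next trial."}]    # built by call_llm_for_proposal
system_prompt = "You are the study's proposal assistant."              # from policy.system_prompt
wire_messages = [{"role": "system", "content": system_prompt}, *messages]
# stream_text_completion then POSTs {"model": ..., "messages": wire_messages, "stream": True}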
@@ -36,6 +36,12 @@ def _require_int(value: Any, *, context: str) -> int:
    return value


def _require_bool(value: Any, *, context: str) -> bool:
    if not isinstance(value, bool):
        raise SpecError(f"{context} must be a boolean.")
    return value


def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]:
    mapping = _require_mapping(value or {}, context=context)
    return {str(key): str(item) for key, item in mapping.items()}
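Note that the new _require_bool validator is strict: string values such as "true" and integers such as 1 are rejected rather than coerced. A standalone sketch of that behaviour (this mirrors the helper above; the SpecError class here is a stand-in for the project's own exception):

class SpecError(ValueError):
    pass


def require_bool(value, *, context: str) -> bool:
    # Same check as _require_bool: only real booleans pass.
    if not isinstance(value, bool):
        raise SpecError(f"{context} must be a boolean.")
    return value


require_bool(True, context="llm.endpoint.stream")          # -> True
try:
    require_bool("true", context="llm.endpoint.stream")    # strings are not coerced
except SpecError as exc:
    print(exc)                                              # llm.endpoint.stream must be a boolean.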
@@ -393,6 +399,7 @@ class LLMEndpointSpec:
    model: str
    provider: str = "custom"
    wire_api: str = "chat.completions"
    stream: bool = False
    reasoning_effort: str | None = None
    api_key_env: str = "OPENAI_API_KEY"
    timeout_s: float = 120.0
@@ -402,6 +409,7 @@ class LLMEndpointSpec:
        provider = str(data.get("provider") or "custom").strip().lower()
        base_url = str(data.get("base_url") or "").strip()
        wire_api = str(data.get("wire_api") or "").strip()
        stream = data.get("stream")
        reasoning_effort = str(data.get("reasoning_effort") or "").strip()
        api_key_env = str(data.get("api_key_env") or "").strip()
        if provider == "codex":
@@ -438,6 +446,7 @@ class LLMEndpointSpec:
            model=_require_str(data.get("model"), context="llm.endpoint.model"),
            provider=provider,
            wire_api=_require_str(wire_api, context="llm.endpoint.wire_api"),
            stream=(_require_bool(stream, context="llm.endpoint.stream") if stream is not None else False),
            reasoning_effort=reasoning_effort or None,
            api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
            timeout_s=_require_float(
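Since "stream" is optional in the endpoint mapping, the `if stream is not None` guard means an absent key falls back to the dataclass default of False instead of handing None to _require_bool. A hedged sketch of that default path (field values are placeholders, mirroring the test added below):

endpoint = LLMEndpointSpec.from_dict(
    {
        "provider": "custom",
        "base_url": "http://example/v1",
        "wire_api": "chat.completions",
        "model": "x",
        "api_key_env": "OPENAI_API_KEY",
        # no "stream" key -> endpoint.stream defaults to False
    }
)
assert endpoint.stream is False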
@@ -253,9 +253,23 @@ class CoreFlowTests(unittest.TestCase):
        self.assertEqual(endpoint.provider, "codex")
        self.assertEqual(endpoint.base_url, "http://codex.example/v1")
        self.assertEqual(endpoint.wire_api, "responses")
        self.assertFalse(endpoint.stream)
        self.assertEqual(endpoint.reasoning_effort, "high")
        self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")

    def test_endpoint_stream_flag(self) -> None:
        endpoint = LLMEndpointSpec.from_dict(
            {
                "provider": "custom",
                "base_url": "http://example/v1",
                "wire_api": "chat.completions",
                "stream": True,
                "model": "x",
                "api_key_env": "OPENAI_API_KEY",
            }
        )
        self.assertTrue(endpoint.stream)

    def test_extract_response_text_supports_responses_api_output(self) -> None:
        text = _extract_response_text(
            {