"""HTTP client for OpenAI-compatible /v1/chat/completions. Records per-request: TTFT (time to first content token), TPOT (mean inter-token latency over the decode phase), and end-to-end throughput. We don't care about parsing exact OpenAI envelope semantics, just enough to get the deltas + finish_reason + token counts. """ from __future__ import annotations import asyncio import json import time from dataclasses import dataclass, field from typing import Any import httpx @dataclass class StreamResult: text: str = "" completion_tokens: int = 0 prompt_tokens: int = 0 finish_reason: str | None = None # Timings (seconds; -1 means not measured) ttft_s: float = -1.0 e2e_s: float = -1.0 chunk_times: list[float] = field(default_factory=list) # absolute monotonic times of content chunks error: str | None = None @property def tpot_s(self) -> float: """Mean inter-content-chunk latency after the first chunk (seconds/token).""" if len(self.chunk_times) < 2: return -1.0 deltas = [self.chunk_times[i] - self.chunk_times[i - 1] for i in range(1, len(self.chunk_times))] return sum(deltas) / len(deltas) @property def throughput_tok_s(self) -> float: if self.e2e_s <= 0 or self.completion_tokens <= 0: return -1.0 return self.completion_tokens / self.e2e_s async def chat_stream( client: httpx.AsyncClient, base_url: str, model: str, messages: list[dict[str, str]], *, max_tokens: int, temperature: float = 0.0, api_key: str | None = None, timeout: float = 1800.0, ) -> StreamResult: payload: dict[str, Any] = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, "stream": True, } # llama-server returns usage in the final stream chunk when this is set; # xserv ignores unknown fields, so this is harmless there. payload["stream_options"] = {"include_usage": True} headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" url = base_url.rstrip("/") + "/v1/chat/completions" res = StreamResult() t_start = time.perf_counter() try: async with client.stream( "POST", url, json=payload, headers=headers, timeout=timeout, ) as resp: if resp.status_code != 200: body = await resp.aread() res.error = f"HTTP {resp.status_code}: {body.decode(errors='replace')[:400]}" res.e2e_s = time.perf_counter() - t_start return res async for line in resp.aiter_lines(): if not line or not line.startswith("data:"): continue data = line[len("data:"):].strip() if data == "[DONE]": break try: chunk = json.loads(data) except json.JSONDecodeError: continue if "usage" in chunk and chunk["usage"]: usage = chunk["usage"] res.prompt_tokens = usage.get("prompt_tokens", res.prompt_tokens) res.completion_tokens = usage.get("completion_tokens", res.completion_tokens) choices = chunk.get("choices") or [] if not choices: continue choice = choices[0] delta = choice.get("delta") or {} content = delta.get("content") if content: now = time.perf_counter() if res.ttft_s < 0: res.ttft_s = now - t_start res.chunk_times.append(now) res.text += content if choice.get("finish_reason"): res.finish_reason = choice["finish_reason"] except Exception as e: # noqa: BLE001 — surface any failure to the report res.error = f"{type(e).__name__}: {e}" res.e2e_s = time.perf_counter() - t_start # Fall back to chunk count when server doesn't report usage (xserv stream path). if res.completion_tokens == 0: res.completion_tokens = len(res.chunk_times) return res async def chat_concurrent( base_url: str, model: str, prompts: list[list[dict[str, str]]], *, max_tokens: int, temperature: float = 0.0, api_key: str | None = None, timeout: float = 1800.0, concurrency: int, ) -> tuple[list[StreamResult], float]: """Fire `concurrency` requests in parallel waves. Returns per-request results plus wall-clock elapsed time of the entire batch.""" sem = asyncio.Semaphore(concurrency) limits = httpx.Limits(max_connections=concurrency * 2, max_keepalive_connections=concurrency) async with httpx.AsyncClient(timeout=timeout, limits=limits) as client: async def one(messages: list[dict[str, str]]) -> StreamResult: async with sem: return await chat_stream( client, base_url, model, messages, max_tokens=max_tokens, temperature=temperature, api_key=api_key, timeout=timeout, ) t0 = time.perf_counter() results = await asyncio.gather(*(one(p) for p in prompts)) wall = time.perf_counter() - t0 return results, wall