Files
aituner/src/aituner/trace.py

258 lines
9.9 KiB
Python

from __future__ import annotations
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping
from .spec import StudySpec
class TraceError(ValueError):
    """Raised when trace assets (the windows index or trace rows) are invalid or missing."""
def _percentile(values: list[float], p: float) -> float:
if not values:
return 0.0
ordered = sorted(values)
idx = min(len(ordered) - 1, max(0, math.ceil((p / 100.0) * len(ordered)) - 1))
return ordered[idx]
@dataclass(frozen=True)
class WindowRecord:
    """One window entry resolved from the trace windows index."""

    # Identifier matched against study.trace.window_id.
    window_id: str
    # Path to the window's JSONL trace file (resolved when relative).
    trace_path: Path
    # Trace flavor from the index entry; defaults to "chat" when absent.
    trace_type: str
    # Window bounds from the index (rescaled by load_trace_requests when
    # replay_time_scale != 1.0).
    window_start: float
    window_end: float
    # String-keyed copy of the raw index entry this record was built from.
    source_payload: dict[str, Any]
@dataclass(frozen=True)
class TraceRequest:
    """A single replayable request parsed from one trace row."""

    # Stable identifier: request_id/id from the row, else the row index.
    row_id: str
    # Arrival time in seconds (already multiplied by replay_time_scale).
    arrival_s: float
    # Per-row uniform draw read from the configured u_field (default 1.0);
    # consumed by select_requests_for_threshold for subsampling.
    sampling_u: float
    # Ready-to-send OpenAI-style chat-completions request payload.
    body: dict[str, Any]
    # Token-count hints coerced from the row, when present.
    prompt_tokens_hint: int | None
    completion_tokens_hint: int | None
def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowRecord:
    """Find the study's window entry in the windows index and resolve its paths.

    A relative ``windows_path`` is interpreted against the study-spec
    directory; a relative ``trace_file`` against the index's directory.

    Raises:
        TraceError: when the index payload is not a list, the matched window
            defines no trace file, or no entry matches the window_id.
    """
    windows_path = Path(study.trace.windows_path)
    if not windows_path.is_absolute():
        windows_path = (study_spec_path.parent / windows_path).resolve()
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    # The index is either a bare list or an object wrapping a "windows" list.
    if isinstance(payload, Mapping) and "windows" in payload:
        entries = payload["windows"]
    else:
        entries = payload
    if not isinstance(entries, list):
        raise TraceError(f"windows payload must contain a list: {windows_path}")
    for entry in entries:
        if not isinstance(entry, Mapping):
            continue
        if str(entry.get("window_id") or "").strip() != study.trace.window_id:
            continue
        trace_file = study.trace.trace_file_override or str(entry.get("trace_file") or "").strip()
        if not trace_file:
            raise TraceError(f"window {study.trace.window_id} does not define trace_file")
        trace_path = Path(trace_file)
        if not trace_path.is_absolute():
            resolved = (windows_path.parent / trace_path).resolve()
            if not resolved.exists():
                pieces = trace_path.parts
                # Tolerate entries that prefix the relative path with
                # "trace_windows/" (presumably the index's own directory):
                # retry with that leading component stripped.
                if pieces and pieces[0] == "trace_windows":
                    resolved = (windows_path.parent / Path(*pieces[1:])).resolve()
            trace_path = resolved
        return WindowRecord(
            window_id=study.trace.window_id,
            trace_path=trace_path,
            trace_type=str(entry.get("trace_type") or "chat").strip(),
            window_start=float(entry.get("window_start") or 0.0),
            window_end=float(entry.get("window_end") or 0.0),
            source_payload={str(key): value for key, value in entry.items()},
        )
    raise TraceError(f"window_id not found: {study.trace.window_id}")
def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
messages = row.get("messages")
if isinstance(messages, list) and messages:
return [dict(item) for item in messages if isinstance(item, Mapping)]
prompt = row.get("prompt") or row.get("input") or row.get("text")
if isinstance(prompt, str) and prompt.strip():
return [{"role": "user", "content": prompt}]
raise TraceError("trace row is missing chat messages/prompt text")
def _synthetic_prompt_from_tokens(token_count: int) -> str:
if token_count <= 0:
return "hello"
# Keep it ASCII and structurally simple so the same trace can be replayed
# on any OpenAI-compatible engine without extra tokenizer assets.
return " ".join(["token"] * token_count)
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
value = row.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, float) and value >= 0:
return int(value)
return None
def _coerce_prompt_tokens(row: Mapping[str, Any]) -> int | None:
for key in ("input_length", "prompt_length", "prompt_len", "input_tokens"):
value = row.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, float) and value >= 0:
return int(value)
return None
def _downsample_requests(
requests: list[TraceRequest], *, limit: int
) -> list[TraceRequest]:
if limit <= 0:
return []
if len(requests) <= limit:
return requests
indexes = sorted({(i * len(requests)) // limit for i in range(limit)})
return [requests[idx] for idx in indexes]
def _matches_input_length_filter(study: StudySpec, *, prompt_tokens_hint: int | None) -> bool:
length_filter = study.trace.input_length_filter
if length_filter is None:
return True
if prompt_tokens_hint is None:
return False
if (
length_filter.min_input_tokens is not None
and prompt_tokens_hint < length_filter.min_input_tokens
):
return False
if (
length_filter.max_input_tokens is not None
and prompt_tokens_hint > length_filter.max_input_tokens
):
return False
return True
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
    """Load the study's trace window and parse its rows into replayable requests.

    Resolves the window record, rescales window bounds and arrival times by
    ``trace.replay_time_scale``, builds an OpenAI-style chat request body per
    JSONL row (synthesizing a placeholder prompt when a row carries no chat
    text), filters by input length, sorts by arrival time, and downsamples
    to ``trace.max_requests_per_probe`` when configured.

    Raises:
        TraceError: for a non-positive time scale or rows missing numeric
            timestamp / sampling fields.
    """
    window = resolve_window_record(study, study_spec_path=study_spec_path)
    time_scale = float(study.trace.replay_time_scale)
    if time_scale <= 0:
        raise TraceError("trace.replay_time_scale must be > 0")
    if time_scale != 1.0:
        # WindowRecord is frozen, so rebuild it with scaled bounds; arrival
        # timestamps below are scaled by the same factor so durations match.
        window = WindowRecord(
            window_id=window.window_id,
            trace_path=window.trace_path,
            trace_type=window.trace_type,
            window_start=window.window_start * time_scale,
            window_end=window.window_end * time_scale,
            source_payload=dict(window.source_payload),
        )
    requests: list[TraceRequest] = []
    with window.trace_path.open("r", encoding="utf-8") as handle:
        for idx, raw in enumerate(handle):
            # Skip blank lines and rows that are not JSON objects.
            if not raw.strip():
                continue
            row = json.loads(raw)
            if not isinstance(row, Mapping):
                continue
            # Prefer the configured timestamp field, then common fallbacks.
            timestamp = row.get(study.trace.timestamp_field)
            if timestamp is None:
                timestamp = row.get("arrival_time", row.get("timestamp"))
            # bool is an int subclass; reject it explicitly.
            if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric timestamp")
            # Per-row uniform draw used later for threshold-based subsampling.
            sampling_u = row.get(study.trace.u_field, 1.0)
            if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
            prompt_tokens_hint = _coerce_prompt_tokens(row)
            if not _matches_input_length_filter(study, prompt_tokens_hint=prompt_tokens_hint):
                continue
            try:
                messages = _coerce_messages(row)
            except TraceError:
                # Row carries no usable text: synthesize a prompt sized by the
                # token hint, capped when the study configures a cap.
                capped_prompt_tokens = prompt_tokens_hint or 0
                if study.trace.synthetic_prompt_cap_tokens is not None:
                    capped_prompt_tokens = min(
                        capped_prompt_tokens, study.trace.synthetic_prompt_cap_tokens
                    )
                messages = [
                    {
                        "role": "user",
                        "content": _synthetic_prompt_from_tokens(capped_prompt_tokens),
                    }
                ]
            body: dict[str, Any] = {
                "model": study.model.served_model_name,
                "messages": messages,
                "stream": True,
                # Ask the server to report token usage in the stream.
                "stream_options": {"include_usage": True},
            }
            completion_tokens = _coerce_completion_tokens(row)
            if completion_tokens is not None:
                # Pin min == max so the decode length follows the trace hint.
                body["min_tokens"] = completion_tokens
                body["max_tokens"] = completion_tokens
            temperature = row.get("temperature")
            if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):
                body["temperature"] = temperature
            requests.append(
                TraceRequest(
                    row_id=str(row.get("request_id") or row.get("id") or idx),
                    arrival_s=float(timestamp) * time_scale,
                    sampling_u=float(sampling_u),
                    body=body,
                    prompt_tokens_hint=prompt_tokens_hint,
                    completion_tokens_hint=completion_tokens,
                )
            )
    requests.sort(key=lambda item: item.arrival_s)
    if study.trace.max_requests_per_probe is not None:
        requests = _downsample_requests(
            requests,
            limit=study.trace.max_requests_per_probe,
        )
    return window, requests
def summarize_window(requests: list[TraceRequest], window: WindowRecord) -> dict[str, Any]:
    """Aggregate replay statistics (rate and token percentiles) for one window."""
    prompt_sizes = [float(req.prompt_tokens_hint or 0) for req in requests]
    completion_sizes = [float(req.completion_tokens_hint or 0) for req in requests]
    duration = max(window.window_end - window.window_start, 0.0)
    if not duration and len(requests) >= 2:
        # The index gave no usable extent; fall back to the observed span.
        duration = requests[-1].arrival_s - requests[0].arrival_s
    rate = len(requests) / duration if duration > 0 else 0.0
    return {
        "window_id": window.window_id,
        "trace_path": str(window.trace_path),
        "trace_type": window.trace_type,
        "request_count": len(requests),
        "duration_s": duration,
        "request_rate": rate,
        "prompt_tokens_p50": _percentile(prompt_sizes, 50.0),
        "prompt_tokens_p95": _percentile(prompt_sizes, 95.0),
        "completion_tokens_p50": _percentile(completion_sizes, 50.0),
        "completion_tokens_p95": _percentile(completion_sizes, 95.0),
    }
def select_requests_for_threshold(
    requests: list[TraceRequest], *, threshold: float
) -> list[TraceRequest]:
    """Keep the requests whose sampling draw is at or below *threshold*."""
    return list(filter(lambda request: request.sampling_u <= threshold, requests))