from __future__ import annotations

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping

from .spec import StudySpec


class TraceError(ValueError):
    """Signals invalid or unusable trace assets (windows manifests, trace rows)."""


def _percentile(values: list[float], p: float) -> float:
|
|
if not values:
|
|
return 0.0
|
|
ordered = sorted(values)
|
|
idx = min(len(ordered) - 1, max(0, math.ceil((p / 100.0) * len(ordered)) - 1))
|
|
return ordered[idx]
|
|
|
|
|
|
@dataclass(frozen=True)
class WindowRecord:
    """Resolved metadata for one replayable trace window.

    Built by `resolve_window_record`; frozen so instances can be shared
    safely across the loading pipeline.
    """

    # Identifier matching `study.trace.window_id` in the windows manifest.
    window_id: str
    # Path to the JSONL trace file backing this window (resolved to absolute
    # when the manifest entry was relative).
    trace_path: Path
    # Trace flavor from the manifest entry; "chat" when unspecified.
    trace_type: str
    # Window bounds in seconds as recorded in the manifest;
    # `load_trace_requests` rescales them by replay_time_scale when != 1.
    window_start: float
    window_end: float
    # Raw manifest entry this record was built from (keys coerced to str).
    source_payload: dict[str, Any]


@dataclass(frozen=True)
class TraceRequest:
    """One replayable request parsed from a single trace row."""

    # Stable identifier: the row's request_id/id, or its line index as fallback.
    row_id: str
    # Arrival offset in seconds, already multiplied by replay_time_scale.
    arrival_s: float
    # Per-row sampling value compared against a probe threshold in
    # `select_requests_for_threshold`; defaults to 1.0 when the row lacks one.
    sampling_u: float
    # Chat-completions request body (model, messages, streaming options, ...).
    body: dict[str, Any]
    # Token-count hints coerced from the row; None when the row had none.
    prompt_tokens_hint: int | None
    completion_tokens_hint: int | None


def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowRecord:
    """Find the study's window entry in the windows manifest and resolve paths.

    Args:
        study: Study spec whose `trace.windows_path` / `trace.window_id`
            select the manifest entry.
        study_spec_path: Path of the study spec file; relative manifest
            paths are resolved against its parent directory.

    Returns:
        A WindowRecord whose `trace_path` has been resolved.

    Raises:
        TraceError: if the manifest payload is not a list, the matched
            window defines no trace_file, or no entry matches the window_id.
    """
    windows_path = Path(study.trace.windows_path)
    if not windows_path.is_absolute():
        # Relative manifest paths are anchored at the study spec's directory.
        windows_path = (study_spec_path.parent / windows_path).resolve()
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    # Accept either {"windows": [...]} or a bare list of window entries.
    windows = payload["windows"] if isinstance(payload, Mapping) and "windows" in payload else payload
    if not isinstance(windows, list):
        raise TraceError(f"windows payload must contain a list: {windows_path}")
    for item in windows:
        if not isinstance(item, Mapping):
            continue
        if str(item.get("window_id") or "").strip() != study.trace.window_id:
            continue
        # An explicit override in the study spec wins over the manifest entry.
        trace_file = study.trace.trace_file_override or str(item.get("trace_file") or "").strip()
        if not trace_file:
            raise TraceError(f"window {study.trace.window_id} does not define trace_file")
        trace_path = Path(trace_file)
        if not trace_path.is_absolute():
            candidate = (windows_path.parent / trace_path).resolve()
            if candidate.exists():
                trace_path = candidate
            else:
                # NOTE(review): presumably handles manifest entries whose
                # paths carry a redundant "trace_windows/" prefix while the
                # manifest already lives inside that directory — confirm
                # against the manifest layout.
                parts = trace_path.parts
                if parts and parts[0] == "trace_windows":
                    trace_path = (windows_path.parent / Path(*parts[1:])).resolve()
                else:
                    # Keep the (possibly nonexistent) candidate; existence is
                    # only checked later when the trace file is opened.
                    trace_path = candidate
        return WindowRecord(
            window_id=study.trace.window_id,
            trace_path=trace_path,
            trace_type=str(item.get("trace_type") or "chat").strip(),
            window_start=float(item.get("window_start") or 0.0),
            window_end=float(item.get("window_end") or 0.0),
            source_payload={str(key): value for key, value in item.items()},
        )
    raise TraceError(f"window_id not found: {study.trace.window_id}")


def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
|
|
messages = row.get("messages")
|
|
if isinstance(messages, list) and messages:
|
|
return [dict(item) for item in messages if isinstance(item, Mapping)]
|
|
prompt = row.get("prompt") or row.get("input") or row.get("text")
|
|
if isinstance(prompt, str) and prompt.strip():
|
|
return [{"role": "user", "content": prompt}]
|
|
raise TraceError("trace row is missing chat messages/prompt text")
|
|
|
|
|
|
def _synthetic_prompt_from_tokens(token_count: int) -> str:
|
|
if token_count <= 0:
|
|
return "hello"
|
|
# Keep it ASCII and structurally simple so the same trace can be replayed
|
|
# on any OpenAI-compatible engine without extra tokenizer assets.
|
|
return " ".join(["token"] * token_count)
|
|
|
|
|
|
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
|
|
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
|
|
value = row.get(key)
|
|
if isinstance(value, bool):
|
|
continue
|
|
if isinstance(value, int) and value >= 0:
|
|
return value
|
|
if isinstance(value, float) and value >= 0:
|
|
return int(value)
|
|
return None
|
|
|
|
|
|
def _coerce_prompt_tokens(row: Mapping[str, Any]) -> int | None:
|
|
for key in ("input_length", "prompt_length", "prompt_len", "input_tokens"):
|
|
value = row.get(key)
|
|
if isinstance(value, bool):
|
|
continue
|
|
if isinstance(value, int) and value >= 0:
|
|
return value
|
|
if isinstance(value, float) and value >= 0:
|
|
return int(value)
|
|
return None
|
|
|
|
|
|
def _downsample_requests(
|
|
requests: list[TraceRequest], *, limit: int
|
|
) -> list[TraceRequest]:
|
|
if limit <= 0:
|
|
return []
|
|
if len(requests) <= limit:
|
|
return requests
|
|
indexes = sorted({(i * len(requests)) // limit for i in range(limit)})
|
|
return [requests[idx] for idx in indexes]
|
|
|
|
|
|
def _matches_input_length_filter(study: StudySpec, *, prompt_tokens_hint: int | None) -> bool:
|
|
length_filter = study.trace.input_length_filter
|
|
if length_filter is None:
|
|
return True
|
|
if prompt_tokens_hint is None:
|
|
return False
|
|
if (
|
|
length_filter.min_input_tokens is not None
|
|
and prompt_tokens_hint < length_filter.min_input_tokens
|
|
):
|
|
return False
|
|
if (
|
|
length_filter.max_input_tokens is not None
|
|
and prompt_tokens_hint > length_filter.max_input_tokens
|
|
):
|
|
return False
|
|
return True
|
|
|
|
|
|
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
    """Load, filter, and time-scale the trace rows for the study's window.

    Returns:
        The resolved (and, when replay_time_scale != 1, rescaled)
        WindowRecord plus the parsed requests sorted by arrival time and
        optionally downsampled to `trace.max_requests_per_probe`.

    Raises:
        TraceError: on a non-positive replay_time_scale, or a row lacking a
            numeric timestamp / sampling value.
    """
    window = resolve_window_record(study, study_spec_path=study_spec_path)
    time_scale = float(study.trace.replay_time_scale)
    if time_scale <= 0:
        raise TraceError("trace.replay_time_scale must be > 0")
    if time_scale != 1.0:
        # Rebuild the frozen record with rescaled bounds so the window span
        # stays consistent with the rescaled arrival times below.
        window = WindowRecord(
            window_id=window.window_id,
            trace_path=window.trace_path,
            trace_type=window.trace_type,
            window_start=window.window_start * time_scale,
            window_end=window.window_end * time_scale,
            source_payload=dict(window.source_payload),
        )
    requests: list[TraceRequest] = []
    with window.trace_path.open("r", encoding="utf-8") as handle:
        for idx, raw in enumerate(handle):
            if not raw.strip():
                # Skip blank JSONL lines.
                continue
            row = json.loads(raw)
            if not isinstance(row, Mapping):
                # Tolerate non-object rows instead of failing the whole load.
                continue
            timestamp = row.get(study.trace.timestamp_field)
            if timestamp is None:
                # Fall back to common field names when the configured one is absent.
                timestamp = row.get("arrival_time", row.get("timestamp"))
            # bool is an int subclass, so exclude it explicitly.
            if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric timestamp")
            sampling_u = row.get(study.trace.u_field, 1.0)
            if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
            prompt_tokens_hint = _coerce_prompt_tokens(row)
            if not _matches_input_length_filter(study, prompt_tokens_hint=prompt_tokens_hint):
                continue
            try:
                messages = _coerce_messages(row)
            except TraceError:
                # Row has no usable text: synthesize a prompt sized to the
                # token hint, capped when the study configures a cap.
                capped_prompt_tokens = prompt_tokens_hint or 0
                if study.trace.synthetic_prompt_cap_tokens is not None:
                    capped_prompt_tokens = min(
                        capped_prompt_tokens, study.trace.synthetic_prompt_cap_tokens
                    )
                messages = [
                    {
                        "role": "user",
                        "content": _synthetic_prompt_from_tokens(capped_prompt_tokens),
                    }
                ]
            body: dict[str, Any] = {
                "model": study.model.served_model_name,
                "messages": messages,
                "stream": True,
                "stream_options": {"include_usage": True},
            }
            completion_tokens = _coerce_completion_tokens(row)
            if completion_tokens is not None:
                # Pin the output length exactly by setting min == max.
                # NOTE(review): min_tokens is an engine-specific extension
                # (e.g. vLLM) — confirm the target engine accepts it.
                body["min_tokens"] = completion_tokens
                body["max_tokens"] = completion_tokens
            temperature = row.get("temperature")
            if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):
                body["temperature"] = temperature
            requests.append(
                TraceRequest(
                    row_id=str(row.get("request_id") or row.get("id") or idx),
                    arrival_s=float(timestamp) * time_scale,
                    sampling_u=float(sampling_u),
                    body=body,
                    prompt_tokens_hint=prompt_tokens_hint,
                    completion_tokens_hint=completion_tokens,
                )
            )
    requests.sort(key=lambda item: item.arrival_s)
    if study.trace.max_requests_per_probe is not None:
        requests = _downsample_requests(
            requests,
            limit=study.trace.max_requests_per_probe,
        )
    return window, requests


def summarize_window(requests: list[TraceRequest], window: WindowRecord) -> dict[str, Any]:
    """Aggregate replay statistics (duration, rate, token percentiles) for a window."""
    prompt_sizes = [float(req.prompt_tokens_hint or 0) for req in requests]
    completion_sizes = [float(req.completion_tokens_hint or 0) for req in requests]
    # Prefer the declared window span; when it is zero, fall back to the
    # observed arrival span (which needs at least two requests).
    span = max(window.window_end - window.window_start, 0.0)
    if not span:
        span = requests[-1].arrival_s - requests[0].arrival_s if len(requests) >= 2 else 0.0
    rate = len(requests) / span if span > 0 else 0.0
    return {
        "window_id": window.window_id,
        "trace_path": str(window.trace_path),
        "trace_type": window.trace_type,
        "request_count": len(requests),
        "duration_s": span,
        "request_rate": rate,
        "prompt_tokens_p50": _percentile(prompt_sizes, 50.0),
        "prompt_tokens_p95": _percentile(prompt_sizes, 95.0),
        "completion_tokens_p50": _percentile(completion_sizes, 50.0),
        "completion_tokens_p95": _percentile(completion_sizes, 95.0),
    }


def select_requests_for_threshold(
    requests: list[TraceRequest], *, threshold: float
) -> list[TraceRequest]:
    """Keep only the requests whose sampling value is at or below *threshold*."""
    return list(filter(lambda candidate: candidate.sampling_u <= threshold, requests))