231 lines
8.9 KiB
Python
231 lines
8.9 KiB
Python
import json
|
|
import os
|
|
from dataclasses import asdict
|
|
from pathlib import Path
|
|
|
|
import psutil
|
|
from tqdm.auto import tqdm
|
|
|
|
from .helpers import safe_int
|
|
from .models import MessageEvent, RequestMeta, ToolSpec, TraceRecord, UsageStats
|
|
|
|
|
|
class FormattedAliTraceAdapter:
    """Adapter for formatter-generated trace lines (the "formatted" schema).

    ``detect`` decides whether a decoded JSON object matches this schema;
    ``parse_line`` converts a matching object into a ``TraceRecord``.
    """

    name = "formatted"

    def detect(self, raw):
        """Return True when *raw* looks like a formatter-generated record.

        A record qualifies when its "meta" payload is a dict, all required
        top-level keys are present, and it carries either a non-empty
        "schema_version" or a "request_id" inside "meta".
        """
        if not isinstance(raw.get("meta"), dict):
            return False
        required_keys = {
            "canonical_prompt",
            "usage",
            "message_events",
            "declared_tools",
            "role_sequence",
        }
        if not required_keys.issubset(raw.keys()):
            return False
        schema_version = str(raw.get("schema_version", "")).strip()
        return bool(schema_version) or "request_id" in raw["meta"]

    @staticmethod
    def _dict_or_empty(value):
        """Return *value* if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _list_or_empty(value):
        """Return *value* if it is a list, otherwise an empty list."""
        return value if isinstance(value, list) else []

    def parse_line(self, raw, line_number=0):
        """Convert one decoded JSON object into a ``TraceRecord``.

        Parameters:
            raw: the decoded JSON object for one trace line.
            line_number: 1-based source line, recorded on the meta payload.

        Malformed sub-payloads degrade to empty values instead of raising:
        non-dict "meta"/"usage" become ``{}`` (as before), and — now
        consistently — non-list "message_events", "declared_tools",
        "role_sequence", and "raw_messages" become empty/fallback lists
        rather than iterating characters or raising ``TypeError``.
        """
        meta_payload = self._dict_or_empty(raw.get("meta"))
        usage_payload = self._dict_or_empty(raw.get("usage"))
        message_events_payload = self._list_or_empty(raw.get("message_events"))
        declared_tools_payload = self._list_or_empty(raw.get("declared_tools"))

        usage = UsageStats(
            input_tokens=safe_int(usage_payload.get("input_tokens")),
            output_tokens=safe_int(usage_payload.get("output_tokens")),
            total_tokens=safe_int(usage_payload.get("total_tokens")),
            reasoning_tokens=safe_int(usage_payload.get("reasoning_tokens")),
            cached_tokens=safe_int(usage_payload.get("cached_tokens")),
        )

        messages = [
            MessageEvent(
                role=str(message.get("role", "unknown")),
                content_type=str(message.get("content_type", "unknown")),
                text_len=safe_int(message.get("text_len")),
                has_cache_control=bool(message.get("has_cache_control")),
                item_count=safe_int(message.get("item_count")),
            )
            for message in message_events_payload
            if isinstance(message, dict)
        ]
        declared_tools = [
            ToolSpec(
                name=str(tool.get("name", "")),
                tool_type=str(tool.get("tool_type", "function")),
            )
            for tool in declared_tools_payload
            if isinstance(tool, dict)
        ]

        # Provider falls back to model family, then to this adapter's name.
        inferred_family = str(meta_payload.get("model_family", "")).strip()
        inferred_provider = str(meta_payload.get("provider", "")).strip()
        if not inferred_provider:
            inferred_provider = inferred_family or self.name

        meta = RequestMeta(
            provider=inferred_provider,
            line_number=line_number,
            request_id=str(meta_payload.get("request_id", "")),
            session_id=str(meta_payload.get("session_id", "")),
            request_model=str(meta_payload.get("request_model", "")),
            time=str(meta_payload.get("time", "")),
            status_code=str(meta_payload.get("status_code", "")),
            status_name=str(meta_payload.get("status_name", "")),
            request_ready_time_ms=safe_int(meta_payload.get("request_ready_time_ms")),
            request_end_time_ms=safe_int(meta_payload.get("request_end_time_ms")),
            total_cost_time_ms=safe_int(meta_payload.get("total_cost_time_ms")),
            backend_first_request_time_ms=safe_int(meta_payload.get("backend_first_request_time_ms")),
            backend_first_response_time_ms=safe_int(meta_payload.get("backend_first_response_time_ms")),
        )

        # Role sequence: use the payload when it is a proper sequence,
        # otherwise reconstruct it from the parsed message events.
        role_sequence_payload = raw.get("role_sequence")
        if not isinstance(role_sequence_payload, (list, tuple)):
            role_sequence_payload = [message.role for message in messages]

        return TraceRecord(
            meta=meta,
            canonical_prompt=str(raw.get("canonical_prompt", "")),
            messages=messages,
            role_sequence=[str(role) for role in role_sequence_payload],
            declared_tools=declared_tools,
            usage=usage,
            raw_messages=[
                message
                for message in self._list_or_empty(raw.get("raw_messages"))
                if isinstance(message, dict)
            ],
        )
|
|
|
|
|
|
def _looks_like_release_trace(raw):
|
|
expected_keys = {"chat_id", "parent_chat_id", "timestamp", "input_length", "output_length", "turn", "hash_ids"}
|
|
return expected_keys.issubset(raw.keys())
|
|
|
|
|
|
def path_looks_like_release_trace(path):
    """Return True when the first non-blank line of *path* parses as a release trace.

    Missing files, unreadable files, invalid JSON, and files with no
    non-blank lines all yield False rather than raising.
    """
    candidate = Path(path)
    if not candidate.exists():
        return False
    try:
        with candidate.open("r", encoding="utf-8") as handle:
            for raw_line in handle:
                stripped = raw_line.strip()
                if stripped:
                    # Only the first non-blank line is inspected.
                    return _looks_like_release_trace(json.loads(stripped))
    except Exception:
        # Best-effort probe: any I/O or decode failure means "not a release trace".
        return False
    return False
|
|
|
|
|
|
def get_adapter(raw):
    """Return an adapter capable of parsing *raw*.

    Raises ValueError when *raw* is a release hash-id trace (unsupported)
    or does not match any known formatter-generated schema.
    """
    candidate = FormattedAliTraceAdapter()
    if candidate.detect(raw):
        return candidate
    if _looks_like_release_trace(raw):
        raise ValueError("trace_analyzer currently analyzes formatter-generated *-raw.jsonl, not release hash-id traces.")
    raise ValueError("trace_analyzer only accepts formatter-generated *-raw.jsonl inputs.")
|
|
|
|
|
|
def _estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):
|
|
baseline = max(current_rss_mb, peak_rss_mb)
|
|
headroom = 1.0 + 0.25 * max(0.0, 1.0 - fraction_done)
|
|
return baseline * headroom
|
|
|
|
|
|
def load_records(path, limit=None, show_progress=False, progress_desc="Load trace"):
    """Load TraceRecord objects from a formatter-generated JSONL file.

    Parameters:
        path: JSONL file path (str or Path).
        limit: stop after this many parsed records (None = no limit).
        show_progress: when True, render a tqdm byte-progress bar with an
            RSS / estimated-peak-memory postfix (uses psutil).
        progress_desc: label shown on the progress bar.

    Returns:
        list of TraceRecord, one per non-blank line.

    Raises:
        ValueError naming the offending line number when a line fails to
        decode as JSON, matches no adapter, or cannot be parsed.
        (Previously only adapter parse failures carried the line number,
        and decode/dispatch failures could leave the progress bar open.)
    """
    records = []
    path = str(path)
    process = psutil.Process(os.getpid()) if show_progress else None
    peak_rss_mb = 0.0
    progress = None
    if show_progress:
        progress = tqdm(
            total=os.path.getsize(path),
            desc=progress_desc,
            unit="B",
            unit_scale=True,
            dynamic_ncols=True,
        )
    try:
        with open(path, "r", encoding="utf-8") as handle:
            for line_number, line in enumerate(handle, start=1):
                if limit is not None and len(records) >= limit:
                    break
                # Advance progress by the raw (pre-strip) byte length so the
                # bar tracks the actual file position.
                line_bytes = len(line.encode("utf-8"))
                stripped = line.strip()
                if not stripped:
                    if progress is not None:
                        progress.update(line_bytes)
                    continue
                try:
                    raw = json.loads(stripped)
                    record = get_adapter(raw).parse_line(raw, line_number=line_number)
                except Exception as exc:
                    raise ValueError(f"Failed to parse line {line_number} in {path}: {exc}") from exc
                records.append(record)
                if progress is not None:
                    progress.update(line_bytes)
                    current_rss_mb = process.memory_info().rss / (1024 * 1024)
                    peak_rss_mb = max(peak_rss_mb, current_rss_mb)
                    fraction_done = progress.n / progress.total if progress.total else 0.0
                    progress.set_postfix(
                        records=len(records),
                        rss_mb=f"{current_rss_mb:.0f}",
                        est_peak_mb=f"{_estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):.0f}",
                    )
    finally:
        # Always release the progress bar, even when parsing fails mid-file.
        if progress is not None:
            progress.close()
    return records
|
|
|
|
|
|
def flatten_record(record):
    """Project a TraceRecord onto a flat dict of scalar columns.

    List-valued fields are reduced to counts and ";"-joined strings so the
    result is directly usable as a tabular row. Key order matches the
    original column layout.
    """
    meta = record.meta
    flat = {}
    for column in (
        "provider",
        "line_number",
        "request_id",
        "session_id",
        "request_model",
        "time",
        "status_code",
        "status_name",
        "request_ready_time_ms",
        "request_end_time_ms",
        "total_cost_time_ms",
        "backend_first_request_time_ms",
        "backend_first_response_time_ms",
    ):
        flat[column] = getattr(meta, column)

    flat["message_count"] = len(record.messages)
    flat["role_sequence"] = ";".join(record.role_sequence)
    flat["declared_tool_count"] = len(record.declared_tools)
    flat["declared_tool_names"] = ";".join(
        tool.name for tool in record.declared_tools if tool.name
    )

    usage = record.usage
    for column in (
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "reasoning_tokens",
        "cached_tokens",
    ):
        flat[column] = getattr(usage, column)
    return flat
|
|
|
|
|
|
def record_to_dict(record):
    """Recursively convert a TraceRecord dataclass into built-in containers."""
    serialized = asdict(record)
    return serialized
|
|
|
|
|
|
def infer_analysis_dataset_name(input_path):
    """Derive a dataset name from a ``*-raw.jsonl`` input path.

    Strips a trailing "-raw" from the file stem and, when the parent
    directory follows the "trace-<model>-formatted" convention, prefixes
    the model slug (unless the stem already carries it).
    """
    resolved = Path(input_path)
    stem = resolved.stem
    if stem.endswith("-raw"):
        stem = stem[: -len("-raw")]

    dir_prefix, dir_suffix = "trace-", "-formatted"
    folder = resolved.parent.name
    model_slug = ""
    if folder.startswith(dir_prefix) and folder.endswith(dir_suffix):
        model_slug = folder[len(dir_prefix) : -len(dir_suffix)]

    if model_slug and not stem.startswith(model_slug + "-"):
        return f"{model_slug}-{stem}"
    return stem
|
|
|
|
|
|
def default_output_dir(input_path):
    """Return the default analysis output directory for *input_path*."""
    dataset_name = infer_analysis_dataset_name(input_path)
    return Path("outputs") / "analysis" / dataset_name
|