Files
ali-trace-tools/trace_analyzer/parser.py
2026-04-21 15:44:47 +00:00

231 lines
8.9 KiB
Python

import json
import os
from dataclasses import asdict
from pathlib import Path
import psutil
from tqdm.auto import tqdm
from .helpers import safe_int
from .models import MessageEvent, RequestMeta, ToolSpec, TraceRecord, UsageStats
class FormattedAliTraceAdapter:
    """Adapter that recognizes and parses formatter-generated ("formatted") trace rows.

    A formatted row carries a ``meta`` dict plus the top-level keys
    ``canonical_prompt``, ``usage``, ``message_events``, ``declared_tools``
    and ``role_sequence``.
    """

    name = "formatted"

    def detect(self, raw):
        """Return True when *raw* has the formatted-trace shape."""
        if not isinstance(raw.get("meta"), dict):
            return False
        needed = {"canonical_prompt", "usage", "message_events", "declared_tools", "role_sequence"}
        if not needed <= raw.keys():
            return False
        version_tag = str(raw.get("schema_version", "")).strip()
        # Either an explicit schema version or a request_id inside meta qualifies.
        if version_tag:
            return True
        return "request_id" in raw["meta"]

    def parse_line(self, raw, line_number=0):
        """Convert one decoded JSONL row into a TraceRecord.

        Non-dict ``meta``/``usage`` payloads degrade to empty dicts; non-dict
        entries in the message/tool lists are skipped.
        """
        meta_src = raw.get("meta")
        if not isinstance(meta_src, dict):
            meta_src = {}
        usage_src = raw.get("usage")
        if not isinstance(usage_src, dict):
            usage_src = {}

        usage = UsageStats(
            input_tokens=safe_int(usage_src.get("input_tokens")),
            output_tokens=safe_int(usage_src.get("output_tokens")),
            total_tokens=safe_int(usage_src.get("total_tokens")),
            reasoning_tokens=safe_int(usage_src.get("reasoning_tokens")),
            cached_tokens=safe_int(usage_src.get("cached_tokens")),
        )

        messages = []
        for event in raw.get("message_events", []):
            if not isinstance(event, dict):
                continue
            messages.append(
                MessageEvent(
                    role=str(event.get("role", "unknown")),
                    content_type=str(event.get("content_type", "unknown")),
                    text_len=safe_int(event.get("text_len")),
                    has_cache_control=bool(event.get("has_cache_control")),
                    item_count=safe_int(event.get("item_count")),
                )
            )

        declared_tools = []
        for tool in raw.get("declared_tools", []):
            if isinstance(tool, dict):
                declared_tools.append(
                    ToolSpec(
                        name=str(tool.get("name", "")),
                        tool_type=str(tool.get("tool_type", "function")),
                    )
                )

        # Provider falls back to model family, then to the adapter name.
        family = str(meta_src.get("model_family", "")).strip()
        provider = str(meta_src.get("provider", "")).strip() or family or self.name

        meta = RequestMeta(
            provider=provider,
            line_number=line_number,
            request_id=str(meta_src.get("request_id", "")),
            session_id=str(meta_src.get("session_id", "")),
            request_model=str(meta_src.get("request_model", "")),
            time=str(meta_src.get("time", "")),
            status_code=str(meta_src.get("status_code", "")),
            status_name=str(meta_src.get("status_name", "")),
            request_ready_time_ms=safe_int(meta_src.get("request_ready_time_ms")),
            request_end_time_ms=safe_int(meta_src.get("request_end_time_ms")),
            total_cost_time_ms=safe_int(meta_src.get("total_cost_time_ms")),
            backend_first_request_time_ms=safe_int(meta_src.get("backend_first_request_time_ms")),
            backend_first_response_time_ms=safe_int(meta_src.get("backend_first_response_time_ms")),
        )

        # When the row has no role_sequence key, derive it from the parsed events.
        fallback_roles = [event.role for event in messages]
        role_sequence = [str(role) for role in raw.get("role_sequence", fallback_roles)]
        raw_messages = [item for item in raw.get("raw_messages", []) if isinstance(item, dict)]

        return TraceRecord(
            meta=meta,
            canonical_prompt=str(raw.get("canonical_prompt", "")),
            messages=messages,
            role_sequence=role_sequence,
            declared_tools=declared_tools,
            usage=usage,
            raw_messages=raw_messages,
        )
def _looks_like_release_trace(raw):
expected_keys = {"chat_id", "parent_chat_id", "timestamp", "input_length", "output_length", "turn", "hash_ids"}
return expected_keys.issubset(raw.keys())
def path_looks_like_release_trace(path):
    """Best-effort check: does the first non-empty line of *path* decode as a
    release hash-id trace row?

    Returns False for missing files and for any read/decode failure — callers
    use this purely as a format sniff, never as validation.
    """
    candidate = Path(path)
    if not candidate.exists():
        return False
    try:
        with candidate.open("r", encoding="utf-8") as handle:
            for raw_line in handle:
                stripped = raw_line.strip()
                if stripped:
                    # Only the first non-blank line is sampled.
                    return _looks_like_release_trace(json.loads(stripped))
    except Exception:
        return False
    return False
def get_adapter(raw):
    """Return an adapter able to parse *raw*, or raise ValueError.

    Release hash-id rows get a dedicated error message so users know the
    input is a recognized-but-unsupported format.
    """
    candidate = FormattedAliTraceAdapter()
    if candidate.detect(raw):
        return candidate
    message = (
        "trace_analyzer currently analyzes formatter-generated *-raw.jsonl, not release hash-id traces."
        if _looks_like_release_trace(raw)
        else "trace_analyzer only accepts formatter-generated *-raw.jsonl inputs."
    )
    raise ValueError(message)
def _estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):
baseline = max(current_rss_mb, peak_rss_mb)
headroom = 1.0 + 0.25 * max(0.0, 1.0 - fraction_done)
return baseline * headroom
def load_records(path, limit=None, show_progress=False, progress_desc="Load trace"):
    """Load TraceRecords from a formatter-generated *-raw.jsonl file.

    Args:
        path: path to the JSONL trace file (str or Path).
        limit: stop after this many records (None = no limit).
        show_progress: display a tqdm byte-progress bar with RSS estimates.
        progress_desc: label shown on the progress bar.

    Returns:
        list of TraceRecord in file order.

    Raises:
        ValueError: when a line cannot be decoded, matched to an adapter, or
            parsed; the message includes the 1-based line number and path.
            (Previously JSON-decode and adapter-selection errors escaped with
            no line context and left the progress bar open.)
    """
    records = []
    path = str(path)
    # Only sample RSS when a progress bar is requested; the psutil handle is
    # unused otherwise.
    process = psutil.Process(os.getpid()) if show_progress else None
    peak_rss_mb = 0.0
    progress = None
    if show_progress:
        progress = tqdm(
            total=os.path.getsize(path),
            desc=progress_desc,
            unit="B",
            unit_scale=True,
            dynamic_ncols=True,
        )
    try:
        with open(path, "r", encoding="utf-8") as handle:
            for line_number, line in enumerate(handle, start=1):
                if limit is not None and len(records) >= limit:
                    break
                raw_line = line
                line = line.strip()
                if not line:
                    # Blank lines still advance the byte-based progress bar.
                    if progress is not None:
                        progress.update(len(raw_line.encode("utf-8")))
                    continue
                # Decode, adapter lookup, and parsing all report the offending
                # line on failure; the finally below closes the progress bar.
                try:
                    raw = json.loads(line)
                    adapter = get_adapter(raw)
                    record = adapter.parse_line(raw, line_number=line_number)
                except Exception as exc:
                    raise ValueError(f"Failed to parse line {line_number} in {path}: {exc}") from exc
                records.append(record)
                if progress is not None:
                    progress.update(len(raw_line.encode("utf-8")))
                    current_rss_mb = process.memory_info().rss / (1024 * 1024)
                    peak_rss_mb = max(peak_rss_mb, current_rss_mb)
                    fraction_done = progress.n / progress.total if progress.total else 0.0
                    progress.set_postfix(
                        records=len(records),
                        rss_mb=f"{current_rss_mb:.0f}",
                        est_peak_mb=f"{_estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):.0f}",
                    )
    finally:
        if progress is not None:
            progress.close()
    return records
def flatten_record(record):
    """Flatten one TraceRecord into a single dict of scalar columns.

    List-valued fields are reduced to counts and ";"-joined strings so the
    result is suitable for CSV/DataFrame rows.
    """
    meta = record.meta
    usage = record.usage
    row = {
        "provider": meta.provider,
        "line_number": meta.line_number,
        "request_id": meta.request_id,
        "session_id": meta.session_id,
        "request_model": meta.request_model,
        "time": meta.time,
        "status_code": meta.status_code,
        "status_name": meta.status_name,
        "request_ready_time_ms": meta.request_ready_time_ms,
        "request_end_time_ms": meta.request_end_time_ms,
        "total_cost_time_ms": meta.total_cost_time_ms,
        "backend_first_request_time_ms": meta.backend_first_request_time_ms,
        "backend_first_response_time_ms": meta.backend_first_response_time_ms,
    }
    named_tools = [tool.name for tool in record.declared_tools if tool.name]
    row["message_count"] = len(record.messages)
    row["role_sequence"] = ";".join(record.role_sequence)
    row["declared_tool_count"] = len(record.declared_tools)
    row["declared_tool_names"] = ";".join(named_tools)
    row["input_tokens"] = usage.input_tokens
    row["output_tokens"] = usage.output_tokens
    row["total_tokens"] = usage.total_tokens
    row["reasoning_tokens"] = usage.reasoning_tokens
    row["cached_tokens"] = usage.cached_tokens
    return row
def record_to_dict(record):
    """Recursively convert a TraceRecord dataclass into plain dicts/lists."""
    return asdict(record)
def infer_analysis_dataset_name(input_path):
    """Derive a dataset name for analysis outputs from a trace file path.

    Strips a trailing ``-raw`` from the file stem and, when the parent folder
    matches ``trace-<model>-formatted``, prefixes the stem with the model slug
    (unless it already has that prefix).
    """
    resolved = Path(input_path)
    stem = resolved.stem
    raw_suffix = "-raw"
    if stem.endswith(raw_suffix):
        stem = stem[: -len(raw_suffix)]
    folder = resolved.parent.name
    prefix, suffix = "trace-", "-formatted"
    slug = ""
    if folder.startswith(prefix) and folder.endswith(suffix):
        slug = folder[len(prefix) : -len(suffix)]
    if slug and not stem.startswith(slug + "-"):
        return "-".join((slug, stem))
    return stem
def default_output_dir(input_path):
    """Return the default analysis directory: outputs/analysis/<dataset-name>."""
    dataset = infer_analysis_dataset_name(input_path)
    return Path("outputs", "analysis", dataset)