"""Loading and flattening of formatter-generated trace JSONL files.

This module adapts ``*-raw.jsonl`` records produced by the formatter into
``TraceRecord`` objects and provides helpers to load, flatten, and name
analysis datasets derived from them.
"""

import json
import os
from dataclasses import asdict
from pathlib import Path

import psutil
from tqdm.auto import tqdm

from .helpers import safe_int
from .models import MessageEvent, RequestMeta, ToolSpec, TraceRecord, UsageStats


class FormattedAliTraceAdapter:
    """Adapter for records emitted by the formatter (``*-raw.jsonl``)."""

    name = "formatted"

    def detect(self, raw):
        """Return True if ``raw`` looks like a formatter-generated record.

        A record qualifies when it has a dict ``meta``, all of the required
        top-level keys, and either a non-empty ``schema_version`` or a
        ``request_id`` inside ``meta``.
        """
        if not isinstance(raw.get("meta"), dict):
            return False
        required_keys = {
            "canonical_prompt",
            "usage",
            "message_events",
            "declared_tools",
            "role_sequence",
        }
        if not required_keys.issubset(raw.keys()):
            return False
        schema_version = str(raw.get("schema_version", "")).strip()
        return bool(schema_version) or "request_id" in raw["meta"]

    def parse_line(self, raw, line_number=0):
        """Convert one decoded JSON object into a ``TraceRecord``.

        Non-dict ``meta``/``usage`` payloads and non-dict entries inside
        ``message_events``/``declared_tools``/``raw_messages`` are tolerated
        and replaced with empty defaults rather than raising.
        """
        # Fetch once and validate the shape; the original double-called
        # raw.get() for the isinstance check and the value.
        meta_payload = raw.get("meta")
        if not isinstance(meta_payload, dict):
            meta_payload = {}
        usage_payload = raw.get("usage")
        if not isinstance(usage_payload, dict):
            usage_payload = {}
        message_events_payload = raw.get("message_events", [])
        declared_tools_payload = raw.get("declared_tools", [])

        usage = UsageStats(
            input_tokens=safe_int(usage_payload.get("input_tokens")),
            output_tokens=safe_int(usage_payload.get("output_tokens")),
            total_tokens=safe_int(usage_payload.get("total_tokens")),
            reasoning_tokens=safe_int(usage_payload.get("reasoning_tokens")),
            cached_tokens=safe_int(usage_payload.get("cached_tokens")),
        )
        messages = [
            MessageEvent(
                role=str(message.get("role", "unknown")),
                content_type=str(message.get("content_type", "unknown")),
                text_len=safe_int(message.get("text_len")),
                has_cache_control=bool(message.get("has_cache_control")),
                item_count=safe_int(message.get("item_count")),
            )
            for message in message_events_payload
            if isinstance(message, dict)
        ]
        declared_tools = [
            ToolSpec(
                name=str(tool.get("name", "")),
                tool_type=str(tool.get("tool_type", "function")),
            )
            for tool in declared_tools_payload
            if isinstance(tool, dict)
        ]

        # Provider falls back to the model family, then to the adapter name.
        inferred_family = str(meta_payload.get("model_family", "")).strip()
        inferred_provider = str(meta_payload.get("provider", "")).strip()
        if not inferred_provider:
            inferred_provider = inferred_family or self.name

        meta = RequestMeta(
            provider=inferred_provider,
            line_number=line_number,
            request_id=str(meta_payload.get("request_id", "")),
            session_id=str(meta_payload.get("session_id", "")),
            request_model=str(meta_payload.get("request_model", "")),
            time=str(meta_payload.get("time", "")),
            status_code=str(meta_payload.get("status_code", "")),
            status_name=str(meta_payload.get("status_name", "")),
            request_ready_time_ms=safe_int(meta_payload.get("request_ready_time_ms")),
            request_end_time_ms=safe_int(meta_payload.get("request_end_time_ms")),
            total_cost_time_ms=safe_int(meta_payload.get("total_cost_time_ms")),
            backend_first_request_time_ms=safe_int(
                meta_payload.get("backend_first_request_time_ms")
            ),
            backend_first_response_time_ms=safe_int(
                meta_payload.get("backend_first_response_time_ms")
            ),
        )

        # Only build the fallback role list when the key is actually absent
        # (the original built it eagerly as the .get() default).
        role_sequence_payload = raw.get("role_sequence")
        if role_sequence_payload is None:
            role_sequence_payload = [message.role for message in messages]

        return TraceRecord(
            meta=meta,
            canonical_prompt=str(raw.get("canonical_prompt", "")),
            messages=messages,
            role_sequence=[str(role) for role in role_sequence_payload],
            declared_tools=declared_tools,
            usage=usage,
            raw_messages=[
                message
                for message in raw.get("raw_messages", [])
                if isinstance(message, dict)
            ],
        )


def _looks_like_release_trace(raw):
    """Return True if ``raw`` has the key set of a release hash-id trace."""
    expected_keys = {
        "chat_id",
        "parent_chat_id",
        "timestamp",
        "input_length",
        "output_length",
        "turn",
        "hash_ids",
    }
    return expected_keys.issubset(raw.keys())


def path_looks_like_release_trace(path):
    """Return True if the first non-blank line of ``path`` parses as a release trace.

    Best-effort: any I/O or JSON error yields False rather than raising.
    """
    path = Path(path)
    if not path.exists():
        return False
    try:
        with path.open("r", encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                return _looks_like_release_trace(json.loads(line))
    except Exception:
        # Deliberate best-effort probe: unreadable/garbled files are "not a trace".
        return False
    return False


def get_adapter(raw):
    """Return the adapter for ``raw`` or raise ``ValueError`` if unsupported."""
    adapter = FormattedAliTraceAdapter()
    if adapter.detect(raw):
        return adapter
    if _looks_like_release_trace(raw):
        raise ValueError(
            "trace_analyzer currently analyzes formatter-generated *-raw.jsonl, not release hash-id traces."
        )
    raise ValueError("trace_analyzer only accepts formatter-generated *-raw.jsonl inputs.")


def _estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):
    """Project the peak RSS (MB) for a load that is ``fraction_done`` complete.

    Uses the larger of current/peak RSS and adds up to 25% headroom,
    shrinking linearly to zero as the load approaches completion.
    """
    baseline = max(current_rss_mb, peak_rss_mb)
    headroom = 1.0 + 0.25 * max(0.0, 1.0 - fraction_done)
    return baseline * headroom


def load_records(path, limit=None, show_progress=False, progress_desc="Load trace"):
    """Load ``TraceRecord`` objects from a formatter-generated JSONL file.

    Args:
        path: Path to the ``*-raw.jsonl`` input file.
        limit: Optional maximum number of records to load.
        show_progress: When True, display a tqdm byte-progress bar with
            live RSS / estimated-peak memory postfix.
        progress_desc: Label for the progress bar.

    Raises:
        ValueError: If any line fails to decode or parse; the message
            includes the 1-based line number and file path.
    """
    records = []
    path = str(path)
    progress = None
    process = psutil.Process(os.getpid()) if show_progress else None
    peak_rss_mb = 0.0
    total_bytes = os.path.getsize(path) if show_progress else 0
    if show_progress:
        progress = tqdm(
            total=total_bytes,
            desc=progress_desc,
            unit="B",
            unit_scale=True,
            dynamic_ncols=True,
        )
    try:
        with open(path, "r", encoding="utf-8") as handle:
            for line_number, line in enumerate(handle, start=1):
                if limit is not None and len(records) >= limit:
                    break
                raw_line = line
                line = line.strip()
                if not line:
                    if progress is not None:
                        progress.update(len(raw_line.encode("utf-8")))
                    continue
                # JSON decoding and adapter detection live inside the try so
                # a malformed line also reports its line number (the original
                # only wrapped parse_line, letting json errors escape with no
                # context and without closing the progress bar).
                try:
                    raw = json.loads(line)
                    adapter = get_adapter(raw)
                    record = adapter.parse_line(raw, line_number=line_number)
                except Exception as exc:
                    raise ValueError(
                        f"Failed to parse line {line_number} in {path}: {exc}"
                    ) from exc
                records.append(record)
                if progress is not None:
                    progress.update(len(raw_line.encode("utf-8")))
                    current_rss_mb = process.memory_info().rss / (1024 * 1024)
                    peak_rss_mb = max(peak_rss_mb, current_rss_mb)
                    fraction_done = progress.n / progress.total if progress.total else 0.0
                    progress.set_postfix(
                        records=len(records),
                        rss_mb=f"{current_rss_mb:.0f}",
                        est_peak_mb=(
                            f"{_estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):.0f}"
                        ),
                    )
    finally:
        # Always release the progress bar, even on unexpected errors.
        if progress is not None:
            progress.close()
    return records


def flatten_record(record):
    """Flatten a ``TraceRecord`` into a single scalar-valued dict (CSV-friendly)."""
    return {
        "provider": record.meta.provider,
        "line_number": record.meta.line_number,
        "request_id": record.meta.request_id,
        "session_id": record.meta.session_id,
        "request_model": record.meta.request_model,
        "time": record.meta.time,
        "status_code": record.meta.status_code,
        "status_name": record.meta.status_name,
        "request_ready_time_ms": record.meta.request_ready_time_ms,
        "request_end_time_ms": record.meta.request_end_time_ms,
        "total_cost_time_ms": record.meta.total_cost_time_ms,
        "backend_first_request_time_ms": record.meta.backend_first_request_time_ms,
        "backend_first_response_time_ms": record.meta.backend_first_response_time_ms,
        "message_count": len(record.messages),
        "role_sequence": ";".join(record.role_sequence),
        "declared_tool_count": len(record.declared_tools),
        "declared_tool_names": ";".join(
            tool.name for tool in record.declared_tools if tool.name
        ),
        "input_tokens": record.usage.input_tokens,
        "output_tokens": record.usage.output_tokens,
        "total_tokens": record.usage.total_tokens,
        "reasoning_tokens": record.usage.reasoning_tokens,
        "cached_tokens": record.usage.cached_tokens,
    }


def record_to_dict(record):
    """Return the full nested dict form of a ``TraceRecord`` dataclass."""
    return asdict(record)


def infer_analysis_dataset_name(input_path):
    """Derive a dataset name from an input path.

    Strips a trailing ``-raw`` from the file stem and, when the parent
    directory follows the ``trace-<model>-formatted`` convention, prefixes
    the model slug (unless the stem already carries it).
    """
    resolved = Path(input_path)
    stem = resolved.stem
    if stem.endswith("-raw"):
        stem = stem[:-4]
    parent_name = resolved.parent.name
    model_slug = ""
    if parent_name.startswith("trace-") and parent_name.endswith("-formatted"):
        model_slug = parent_name[len("trace-") : -len("-formatted")]
    if model_slug and not stem.startswith(f"{model_slug}-"):
        return f"{model_slug}-{stem}"
    return stem


def default_output_dir(input_path):
    """Return the default analysis output directory for ``input_path``."""
    return Path("outputs") / "analysis" / infer_analysis_dataset_name(input_path)