from __future__ import annotations import csv import json import os from pathlib import Path from trace_analyzer.helpers import percentile from trace_analyzer.parser import get_adapter from tqdm.auto import tqdm def stream_prepare(input_path: str | Path, output_dir: str | Path, *, show_progress: bool = False) -> dict: input_file = Path(input_path) output_root = Path(output_dir) output_root.mkdir(parents=True, exist_ok=True) features_path = output_root / "features.csv" total_bytes = os.path.getsize(input_file) if show_progress and input_file.exists() else 0 progress = tqdm( total=total_bytes, desc="Prepare features", unit="B", unit_scale=True, dynamic_ncols=True, disable=not show_progress, ) try: with input_file.open("r", encoding="utf-8") as input_handle, features_path.open( "w", encoding="utf-8", newline="" ) as features_handle: writer: csv.DictWriter | None = None kept_rows = 0 for line_number, line in enumerate(input_handle, start=1): stripped = line.strip() if not stripped: if show_progress: progress.update(len(line.encode("utf-8"))) continue raw = json.loads(stripped) adapter = get_adapter(raw) record = adapter.parse_line(raw, line_number=line_number) role_sequence = record.role_sequence role_pairs = list(zip(role_sequence, role_sequence[1:])) tool_bursts = _tool_bursts(role_sequence) max_tool_burst = max(tool_bursts) if tool_bursts else 0 avg_tool_burst = _safe_div(sum(tool_bursts), len(tool_bursts)) if tool_bursts else 0.0 tool_to_tool_count = sum(1 for current, nxt in role_pairs if current == "tool" and nxt == "tool") tool_msg_count = sum(message.role == "tool" for message in record.messages) assistant_msg_count = sum(message.role == "assistant" for message in record.messages) cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens) feature_row = { "request_id": record.meta.request_id, "session_id": record.meta.session_id, "model": record.meta.request_model, "status_code": record.meta.status_code, "time": record.meta.time, "message_count": len(record.messages), "conversation_depth": len(record.messages), "declared_tool_count": len(record.declared_tools), "assistant_msg_count": assistant_msg_count, "tool_msg_count": tool_msg_count, "user_msg_count": sum(message.role == "user" for message in record.messages), "system_msg_count": sum(message.role == "system" for message in record.messages), "assistant_to_tool_count": sum( 1 for current, nxt in role_pairs if current == "assistant" and nxt == "tool" ), "tool_to_assistant_count": sum( 1 for current, nxt in role_pairs if current == "tool" and nxt == "assistant" ), "tool_to_tool_count": tool_to_tool_count, "assistant_to_user_count": sum( 1 for current, nxt in role_pairs if current == "assistant" and nxt == "user" ), "user_to_assistant_count": sum( 1 for current, nxt in role_pairs if current == "user" and nxt == "assistant" ), "max_consecutive_tool_msgs": max_tool_burst, "avg_tool_burst_len": avg_tool_burst, "has_tool_loop": 1 if tool_to_tool_count > 0 else 0, "input_tokens": record.usage.input_tokens, "output_tokens": record.usage.output_tokens, "total_tokens": record.usage.total_tokens, "reasoning_tokens": record.usage.reasoning_tokens, "cached_tokens": record.usage.cached_tokens, "cache_hit_ratio": cache_hit_ratio, "uncached_prompt_tokens": max(record.usage.input_tokens - record.usage.cached_tokens, 0), "output_input_ratio": _safe_div(record.usage.output_tokens, record.usage.input_tokens), "latency_ms": record.meta.total_cost_time_ms, "ms_per_input_token": _safe_div(record.meta.total_cost_time_ms, record.usage.input_tokens), "ms_per_output_token": _safe_div(record.meta.total_cost_time_ms, record.usage.output_tokens), "long_context": 1 if record.usage.input_tokens >= 32000 else 0, "high_cache": 1 if cache_hit_ratio >= 0.8 else 0, "tool_burst_alert": 1 if max_tool_burst >= 4 else 0, "tool_loop_alert": 1 if tool_to_tool_count >= 3 else 0, "slow_request": 0, "pattern_labels": _pattern_labels( record, cache_hit_ratio=cache_hit_ratio, tool_msg_count=tool_msg_count, assistant_msg_count=assistant_msg_count, max_tool_burst=max_tool_burst, ), } if writer is None: writer = csv.DictWriter(features_handle, fieldnames=list(feature_row.keys())) writer.writeheader() writer.writerow(feature_row) kept_rows += 1 if show_progress: progress.update(len(line.encode("utf-8"))) progress.set_postfix( rows=kept_rows, features=features_path.name, ) finally: if show_progress: progress.close() if show_progress: tqdm.write("Finalize features.csv: apply slow_request p90 latency threshold") _apply_slow_request_threshold(features_path) return { "features_path": str(features_path), } def _safe_div(numerator: float, denominator: float) -> float: return (numerator / denominator) if denominator else 0.0 def _tool_bursts(role_sequence: list[str]) -> list[int]: bursts: list[int] = [] current = 0 for role in role_sequence: if role == "tool": current += 1 elif current: bursts.append(current) current = 0 if current: bursts.append(current) return bursts def _max_tool_burst(role_sequence: list[str]) -> int: bursts = _tool_bursts(role_sequence) return max(bursts) if bursts else 0 def _avg_tool_burst(role_sequence: list[str]) -> float: bursts = _tool_bursts(role_sequence) return _safe_div(sum(bursts), len(bursts)) if bursts else 0.0 def _pattern_labels( record, *, cache_hit_ratio: float | None = None, tool_msg_count: int | None = None, assistant_msg_count: int | None = None, max_tool_burst: int | None = None, ) -> str: labels: list[str] = [] if tool_msg_count is None: tool_msg_count = sum(message.role == "tool" for message in record.messages) if assistant_msg_count is None: assistant_msg_count = sum(message.role == "assistant" for message in record.messages) if cache_hit_ratio is None: cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens) if max_tool_burst is None: max_tool_burst = _max_tool_burst(record.role_sequence) if tool_msg_count == 0 and len(record.declared_tools) == 0: labels.append("single-shot") if tool_msg_count > 0 and tool_msg_count >= assistant_msg_count: labels.append("tool-heavy") if max_tool_burst >= 4: labels.append("tool-burst") if cache_hit_ratio >= 0.8: labels.append("cache-efficient") if cache_hit_ratio <= 0.1: labels.append("cache-cold") if record.usage.input_tokens >= 32000 and cache_hit_ratio <= 0.1: labels.append("long-context-no-cache") return ";".join(sorted(set(labels))) def _apply_slow_request_threshold(features_path: Path) -> None: with features_path.open("r", encoding="utf-8") as handle: latencies = [int(row["latency_ms"]) for row in csv.DictReader(handle)] if not latencies: return latencies.sort() p90_latency = percentile(latencies, 0.9) temp_path = features_path.with_suffix(features_path.suffix + ".tmp") with features_path.open("r", encoding="utf-8") as input_handle, temp_path.open( "w", encoding="utf-8", newline="" ) as output_handle: reader = csv.DictReader(input_handle) writer = None for row in reader: slow_request = 1 if int(row["latency_ms"]) >= p90_latency else 0 pattern_labels = {label for label in row.get("pattern_labels", "").split(";") if label} row["slow_request"] = str(slow_request) if slow_request and row.get("high_cache") == "1": pattern_labels.add("slow-despite-cache") row["pattern_labels"] = ";".join(sorted(pattern_labels)) if writer is None: writer = csv.DictWriter(output_handle, fieldnames=list(row.keys())) writer.writeheader() writer.writerow(row) temp_path.replace(features_path)