Files
ali-trace-tools/trace_analyzer/preparation.py
2026-04-21 15:44:47 +00:00

222 lines
9.7 KiB
Python

from __future__ import annotations
import csv
import json
import os
from pathlib import Path
from trace_analyzer.helpers import percentile
from trace_analyzer.parser import get_adapter
from tqdm.auto import tqdm
def stream_prepare(input_path: str | Path, output_dir: str | Path, *, show_progress: bool = False) -> dict:
    """Stream a JSONL trace file and write one feature row per record.

    Reads *input_path* line by line (so arbitrarily large traces never have
    to fit in memory), parses each non-blank line with its matching adapter,
    and appends the derived feature row to ``<output_dir>/features.csv``.
    A second pass then fills in the ``slow_request`` flag, which needs the
    full latency distribution (p90) and cannot be computed while streaming.

    Args:
        input_path: JSONL trace file; one JSON object per line.
        output_dir: Output directory; created if missing.
        show_progress: Render a tqdm byte-progress bar and status line.

    Returns:
        Dict with ``features_path`` pointing at the written CSV.

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    input_file = Path(input_path)
    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    features_path = output_root / "features.csv"
    total_bytes = os.path.getsize(input_file) if show_progress and input_file.exists() else 0
    progress = tqdm(
        total=total_bytes,
        desc="Prepare features",
        unit="B",
        unit_scale=True,
        dynamic_ncols=True,
        disable=not show_progress,
    )
    try:
        with input_file.open("r", encoding="utf-8") as input_handle, features_path.open(
            "w", encoding="utf-8", newline=""
        ) as features_handle:
            writer: csv.DictWriter | None = None
            kept_rows = 0
            for line_number, line in enumerate(input_handle, start=1):
                stripped = line.strip()
                if stripped:
                    raw = json.loads(stripped)
                    record = get_adapter(raw).parse_line(raw, line_number=line_number)
                    feature_row = _feature_row(record)
                    if writer is None:
                        # The first record fixes the CSV column order.
                        writer = csv.DictWriter(features_handle, fieldnames=list(feature_row.keys()))
                        writer.writeheader()
                    writer.writerow(feature_row)
                    kept_rows += 1
                if show_progress:
                    # NOTE(review): text mode translates newlines, so this byte
                    # count slightly undercounts CRLF input; acceptable for a bar.
                    progress.update(len(line.encode("utf-8")))
                    if stripped:
                        progress.set_postfix(
                            rows=kept_rows,
                            features=features_path.name,
                        )
    finally:
        if show_progress:
            progress.close()
    if show_progress:
        tqdm.write("Finalize features.csv: apply slow_request p90 latency threshold")
    # slow_request depends on the whole-file latency distribution, so it is
    # applied in a second pass once every row has been written.
    _apply_slow_request_threshold(features_path)
    return {
        "features_path": str(features_path),
    }


def _feature_row(record) -> dict:
    """Derive one flat feature dict from a parsed trace record."""
    role_sequence = record.role_sequence
    role_pairs = list(zip(role_sequence, role_sequence[1:]))

    def transitions(src: str, dst: str) -> int:
        # Count adjacent (src -> dst) role transitions in the conversation.
        return sum(1 for current, nxt in role_pairs if current == src and nxt == dst)

    tool_bursts = _tool_bursts(role_sequence)
    max_tool_burst = max(tool_bursts) if tool_bursts else 0
    avg_tool_burst = _safe_div(sum(tool_bursts), len(tool_bursts)) if tool_bursts else 0.0
    tool_to_tool_count = transitions("tool", "tool")
    tool_msg_count = sum(message.role == "tool" for message in record.messages)
    assistant_msg_count = sum(message.role == "assistant" for message in record.messages)
    cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens)
    return {
        "request_id": record.meta.request_id,
        "session_id": record.meta.session_id,
        "model": record.meta.request_model,
        "status_code": record.meta.status_code,
        "time": record.meta.time,
        "message_count": len(record.messages),
        # NOTE(review): duplicates message_count — confirm downstream
        # consumers rely on both columns before removing either.
        "conversation_depth": len(record.messages),
        "declared_tool_count": len(record.declared_tools),
        "assistant_msg_count": assistant_msg_count,
        "tool_msg_count": tool_msg_count,
        "user_msg_count": sum(message.role == "user" for message in record.messages),
        "system_msg_count": sum(message.role == "system" for message in record.messages),
        "assistant_to_tool_count": transitions("assistant", "tool"),
        "tool_to_assistant_count": transitions("tool", "assistant"),
        "tool_to_tool_count": tool_to_tool_count,
        "assistant_to_user_count": transitions("assistant", "user"),
        "user_to_assistant_count": transitions("user", "assistant"),
        "max_consecutive_tool_msgs": max_tool_burst,
        "avg_tool_burst_len": avg_tool_burst,
        "has_tool_loop": 1 if tool_to_tool_count > 0 else 0,
        "input_tokens": record.usage.input_tokens,
        "output_tokens": record.usage.output_tokens,
        "total_tokens": record.usage.total_tokens,
        "reasoning_tokens": record.usage.reasoning_tokens,
        "cached_tokens": record.usage.cached_tokens,
        "cache_hit_ratio": cache_hit_ratio,
        "uncached_prompt_tokens": max(record.usage.input_tokens - record.usage.cached_tokens, 0),
        "output_input_ratio": _safe_div(record.usage.output_tokens, record.usage.input_tokens),
        "latency_ms": record.meta.total_cost_time_ms,
        "ms_per_input_token": _safe_div(record.meta.total_cost_time_ms, record.usage.input_tokens),
        "ms_per_output_token": _safe_div(record.meta.total_cost_time_ms, record.usage.output_tokens),
        "long_context": 1 if record.usage.input_tokens >= 32000 else 0,
        "high_cache": 1 if cache_hit_ratio >= 0.8 else 0,
        "tool_burst_alert": 1 if max_tool_burst >= 4 else 0,
        "tool_loop_alert": 1 if tool_to_tool_count >= 3 else 0,
        # Placeholder; the real flag needs the whole-file p90 latency and is
        # filled in by _apply_slow_request_threshold.
        "slow_request": 0,
        "pattern_labels": _pattern_labels(
            record,
            cache_hit_ratio=cache_hit_ratio,
            tool_msg_count=tool_msg_count,
            assistant_msg_count=assistant_msg_count,
            max_tool_burst=max_tool_burst,
        ),
    }
def _safe_div(numerator: float, denominator: float) -> float:
return (numerator / denominator) if denominator else 0.0
def _tool_bursts(role_sequence: list[str]) -> list[int]:
bursts: list[int] = []
current = 0
for role in role_sequence:
if role == "tool":
current += 1
elif current:
bursts.append(current)
current = 0
if current:
bursts.append(current)
return bursts
def _max_tool_burst(role_sequence: list[str]) -> int:
    """Length of the longest run of consecutive "tool" roles (0 if none)."""
    return max(_tool_bursts(role_sequence), default=0)
def _avg_tool_burst(role_sequence: list[str]) -> float:
    """Mean run length of consecutive "tool" roles (0.0 if there are none).

    No explicit empty-check is needed: _safe_div already returns 0.0 when
    the burst list (and hence its length) is empty.
    """
    bursts = _tool_bursts(role_sequence)
    return _safe_div(sum(bursts), len(bursts))
def _pattern_labels(
record,
*,
cache_hit_ratio: float | None = None,
tool_msg_count: int | None = None,
assistant_msg_count: int | None = None,
max_tool_burst: int | None = None,
) -> str:
labels: list[str] = []
if tool_msg_count is None:
tool_msg_count = sum(message.role == "tool" for message in record.messages)
if assistant_msg_count is None:
assistant_msg_count = sum(message.role == "assistant" for message in record.messages)
if cache_hit_ratio is None:
cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens)
if max_tool_burst is None:
max_tool_burst = _max_tool_burst(record.role_sequence)
if tool_msg_count == 0 and len(record.declared_tools) == 0:
labels.append("single-shot")
if tool_msg_count > 0 and tool_msg_count >= assistant_msg_count:
labels.append("tool-heavy")
if max_tool_burst >= 4:
labels.append("tool-burst")
if cache_hit_ratio >= 0.8:
labels.append("cache-efficient")
if cache_hit_ratio <= 0.1:
labels.append("cache-cold")
if record.usage.input_tokens >= 32000 and cache_hit_ratio <= 0.1:
labels.append("long-context-no-cache")
return ";".join(sorted(set(labels)))
def _apply_slow_request_threshold(features_path: Path) -> None:
    """Second pass over features.csv: flag rows at or above the p90 latency.

    Sets ``slow_request`` to 1 on qualifying rows, adds the
    "slow-despite-cache" pattern label to rows that are both slow and
    high-cache, and replaces the file atomically via a temp file.
    No-op when the file holds no data rows.
    """
    # Pass 1: collect all latencies to locate the p90 threshold.
    with features_path.open("r", encoding="utf-8") as handle:
        latencies = sorted(int(row["latency_ms"]) for row in csv.DictReader(handle))
    if not latencies:
        return
    threshold = percentile(latencies, 0.9)

    # Pass 2: rewrite every row with the final slow_request value.
    temp_path = features_path.with_suffix(features_path.suffix + ".tmp")
    with features_path.open("r", encoding="utf-8") as source_handle, temp_path.open(
        "w", encoding="utf-8", newline=""
    ) as target_handle:
        reader = csv.DictReader(source_handle)
        # latencies was non-empty, so the file has a header and >= 1 data row;
        # reader.fieldnames matches the keys of every row below.
        writer = csv.DictWriter(target_handle, fieldnames=reader.fieldnames)
        writer.writeheader()
        for row in reader:
            is_slow = int(row["latency_ms"]) >= threshold
            labels = {part for part in row.get("pattern_labels", "").split(";") if part}
            row["slow_request"] = "1" if is_slow else "0"
            if is_slow and row.get("high_cache") == "1":
                labels.add("slow-despite-cache")
            row["pattern_labels"] = ";".join(sorted(labels))
            writer.writerow(row)
    temp_path.replace(features_path)