Initial commit
221
trace_analyzer/preparation.py
Normal file
@@ -0,0 +1,221 @@
"""Streaming feature preparation: derive per-request features from a JSONL trace."""

from __future__ import annotations

import csv
import json
import os
from pathlib import Path

from trace_analyzer.helpers import percentile
from trace_analyzer.parser import get_adapter
from tqdm.auto import tqdm


def stream_prepare(input_path: str | Path, output_dir: str | Path, *, show_progress: bool = False) -> dict:
    """Stream a JSONL trace file and write one feature row per request to features.csv."""
    input_file = Path(input_path)
    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    features_path = output_root / "features.csv"
    total_bytes = os.path.getsize(input_file) if show_progress and input_file.exists() else 0

    progress = tqdm(
        total=total_bytes,
        desc="Prepare features",
        unit="B",
        unit_scale=True,
        dynamic_ncols=True,
        disable=not show_progress,
    )
    try:
        with input_file.open("r", encoding="utf-8") as input_handle, features_path.open(
            "w", encoding="utf-8", newline=""
        ) as features_handle:
            writer: csv.DictWriter | None = None
            kept_rows = 0
            for line_number, line in enumerate(input_handle, start=1):
                stripped = line.strip()
                if not stripped:
                    if show_progress:
                        progress.update(len(line.encode("utf-8")))
                    continue
                raw = json.loads(stripped)
                adapter = get_adapter(raw)
                record = adapter.parse_line(raw, line_number=line_number)
                role_sequence = record.role_sequence
                role_pairs = list(zip(role_sequence, role_sequence[1:]))
                tool_bursts = _tool_bursts(role_sequence)
                max_tool_burst = max(tool_bursts) if tool_bursts else 0
                avg_tool_burst = _safe_div(sum(tool_bursts), len(tool_bursts)) if tool_bursts else 0.0
                tool_to_tool_count = sum(1 for current, nxt in role_pairs if current == "tool" and nxt == "tool")
                tool_msg_count = sum(message.role == "tool" for message in record.messages)
                assistant_msg_count = sum(message.role == "assistant" for message in record.messages)
                cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens)

                feature_row = {
                    "request_id": record.meta.request_id,
                    "session_id": record.meta.session_id,
                    "model": record.meta.request_model,
                    "status_code": record.meta.status_code,
                    "time": record.meta.time,
                    "message_count": len(record.messages),
                    "conversation_depth": len(record.messages),
                    "declared_tool_count": len(record.declared_tools),
                    "assistant_msg_count": assistant_msg_count,
                    "tool_msg_count": tool_msg_count,
                    "user_msg_count": sum(message.role == "user" for message in record.messages),
                    "system_msg_count": sum(message.role == "system" for message in record.messages),
                    "assistant_to_tool_count": sum(
                        1 for current, nxt in role_pairs if current == "assistant" and nxt == "tool"
                    ),
                    "tool_to_assistant_count": sum(
                        1 for current, nxt in role_pairs if current == "tool" and nxt == "assistant"
                    ),
                    "tool_to_tool_count": tool_to_tool_count,
                    "assistant_to_user_count": sum(
                        1 for current, nxt in role_pairs if current == "assistant" and nxt == "user"
                    ),
                    "user_to_assistant_count": sum(
                        1 for current, nxt in role_pairs if current == "user" and nxt == "assistant"
                    ),
                    "max_consecutive_tool_msgs": max_tool_burst,
                    "avg_tool_burst_len": avg_tool_burst,
                    "has_tool_loop": 1 if tool_to_tool_count > 0 else 0,
                    "input_tokens": record.usage.input_tokens,
                    "output_tokens": record.usage.output_tokens,
                    "total_tokens": record.usage.total_tokens,
                    "reasoning_tokens": record.usage.reasoning_tokens,
                    "cached_tokens": record.usage.cached_tokens,
                    "cache_hit_ratio": cache_hit_ratio,
                    "uncached_prompt_tokens": max(record.usage.input_tokens - record.usage.cached_tokens, 0),
                    "output_input_ratio": _safe_div(record.usage.output_tokens, record.usage.input_tokens),
                    "latency_ms": record.meta.total_cost_time_ms,
                    "ms_per_input_token": _safe_div(record.meta.total_cost_time_ms, record.usage.input_tokens),
                    "ms_per_output_token": _safe_div(record.meta.total_cost_time_ms, record.usage.output_tokens),
                    "long_context": 1 if record.usage.input_tokens >= 32000 else 0,
                    "high_cache": 1 if cache_hit_ratio >= 0.8 else 0,
                    "tool_burst_alert": 1 if max_tool_burst >= 4 else 0,
                    "tool_loop_alert": 1 if tool_to_tool_count >= 3 else 0,
                    # Placeholder; the real p90 threshold is applied in a second pass.
                    "slow_request": 0,
                    "pattern_labels": _pattern_labels(
                        record,
                        cache_hit_ratio=cache_hit_ratio,
                        tool_msg_count=tool_msg_count,
                        assistant_msg_count=assistant_msg_count,
                        max_tool_burst=max_tool_burst,
                    ),
                }
                if writer is None:
                    writer = csv.DictWriter(features_handle, fieldnames=list(feature_row.keys()))
                    writer.writeheader()
                writer.writerow(feature_row)
                kept_rows += 1
                if show_progress:
                    progress.update(len(line.encode("utf-8")))
                    progress.set_postfix(
                        rows=kept_rows,
                        features=features_path.name,
                    )
    finally:
        if show_progress:
            progress.close()

    if show_progress:
        tqdm.write("Finalize features.csv: apply slow_request p90 latency threshold")
    # Second pass: the p90 threshold needs the full latency distribution, so it
    # can only be applied after the streaming pass has written every row.
    _apply_slow_request_threshold(features_path)
    return {
        "features_path": str(features_path),
    }
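

# Usage sketch (illustrative, not part of the module). The trace path below is
# hypothetical; the JSONL layout is whatever get_adapter() recognizes.
#
#     summary = stream_prepare("traces/requests.jsonl", "out", show_progress=True)
#     print(summary["features_path"])  # e.g. "out/features.csv"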


def _safe_div(numerator: float, denominator: float) -> float:
    return (numerator / denominator) if denominator else 0.0
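
# e.g. _safe_div(3, 4) == 0.75 and _safe_div(1, 0) == 0.0; the zero-denominator
# guard keeps the ratio features above well-defined when usage counts are zero.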


def _tool_bursts(role_sequence: list[str]) -> list[int]:
    bursts: list[int] = []
    current = 0
    for role in role_sequence:
        if role == "tool":
            current += 1
        elif current:
            bursts.append(current)
            current = 0
    if current:
        bursts.append(current)
    return bursts
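
# Example (illustrative): each maximal run of consecutive "tool" roles is one burst.
#     _tool_bursts(["user", "assistant", "tool", "tool", "assistant", "tool"]) -> [2, 1]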


def _max_tool_burst(role_sequence: list[str]) -> int:
    bursts = _tool_bursts(role_sequence)
    return max(bursts) if bursts else 0


def _avg_tool_burst(role_sequence: list[str]) -> float:
    bursts = _tool_bursts(role_sequence)
    return _safe_div(sum(bursts), len(bursts)) if bursts else 0.0
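
# e.g. _avg_tool_burst(["tool", "tool", "user", "tool"]) -> 1.5 (bursts [2, 1]).
# stream_prepare computes these inline; _pattern_labels falls back to
# _max_tool_burst when no precomputed value is passed in.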


def _pattern_labels(
    record,
    *,
    cache_hit_ratio: float | None = None,
    tool_msg_count: int | None = None,
    assistant_msg_count: int | None = None,
    max_tool_burst: int | None = None,
) -> str:
    labels: list[str] = []
    if tool_msg_count is None:
        tool_msg_count = sum(message.role == "tool" for message in record.messages)
    if assistant_msg_count is None:
        assistant_msg_count = sum(message.role == "assistant" for message in record.messages)
    if cache_hit_ratio is None:
        cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens)
    if max_tool_burst is None:
        max_tool_burst = _max_tool_burst(record.role_sequence)
    if tool_msg_count == 0 and len(record.declared_tools) == 0:
        labels.append("single-shot")
    if tool_msg_count > 0 and tool_msg_count >= assistant_msg_count:
        labels.append("tool-heavy")
    if max_tool_burst >= 4:
        labels.append("tool-burst")
    if cache_hit_ratio >= 0.8:
        labels.append("cache-efficient")
    if cache_hit_ratio <= 0.1:
        labels.append("cache-cold")
    if record.usage.input_tokens >= 32000 and cache_hit_ratio <= 0.1:
        labels.append("long-context-no-cache")
    return ";".join(sorted(set(labels)))


def _apply_slow_request_threshold(features_path: Path) -> None:
    """Rewrite features.csv in place, flagging rows at or above the p90 latency."""
    with features_path.open("r", encoding="utf-8") as handle:
        latencies = [int(row["latency_ms"]) for row in csv.DictReader(handle)]
    if not latencies:
        return
    latencies.sort()
    p90_latency = percentile(latencies, 0.9)
    temp_path = features_path.with_suffix(features_path.suffix + ".tmp")
    with features_path.open("r", encoding="utf-8") as input_handle, temp_path.open(
        "w", encoding="utf-8", newline=""
    ) as output_handle:
        reader = csv.DictReader(input_handle)
        writer: csv.DictWriter | None = None
        for row in reader:
            slow_request = 1 if int(row["latency_ms"]) >= p90_latency else 0
            pattern_labels = {label for label in row.get("pattern_labels", "").split(";") if label}
            row["slow_request"] = str(slow_request)
            if slow_request and row.get("high_cache") == "1":
                pattern_labels.add("slow-despite-cache")
            row["pattern_labels"] = ";".join(sorted(pattern_labels))
            if writer is None:
                writer = csv.DictWriter(output_handle, fieldnames=list(row.keys()))
                writer.writeheader()
            writer.writerow(row)
    temp_path.replace(features_path)
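
# `percentile` is imported from trace_analyzer.helpers, which is not part of
# this commit. A minimal sketch consistent with the call above (pre-sorted
# values, fraction in [0, 1]), assuming linear interpolation between ranks:
#
#     def percentile(sorted_values: list[int], fraction: float) -> float:
#         rank = fraction * (len(sorted_values) - 1)
#         lower = int(rank)
#         upper = min(lower + 1, len(sorted_values) - 1)
#         weight = rank - lower
#         return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight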