# File: ali-trace-tools/trace_analyzer/features.py
# Snapshot metadata: 2026-04-21 15:44:47 +00:00 — 118 lines, 4.8 KiB, Python.
from collections import Counter
from dataclasses import asdict
from .helpers import percentile, safe_div
from .models import TraceFeatures
# Input-token count at or above which a request is flagged as "long context".
LONG_CONTEXT_THRESHOLD = 32000
# Cache-hit ratio (cached_tokens / input_tokens) at or above which a request
# counts as "high cache" / "cache-efficient".
HIGH_CACHE_THRESHOLD = 0.8
# Length of a consecutive run of "tool" messages that triggers the burst alert.
TOOL_BURST_THRESHOLD = 4
# Number of tool->tool transitions that triggers the loop alert.
TOOL_LOOP_THRESHOLD = 3
def _transition_count(roles, left, right):
    """Count adjacent (left, right) pairs in the role sequence *roles*."""
    total = 0
    for prev_role, next_role in zip(roles, roles[1:]):
        if prev_role == left and next_role == right:
            total += 1
    return total
def _tool_bursts(roles):
    """Return the length of each maximal run of consecutive "tool" roles."""
    runs = []
    streak = 0
    for role in roles:
        if role != "tool":
            # A non-tool role ends the current run (if one was open).
            if streak:
                runs.append(streak)
            streak = 0
        else:
            streak += 1
    # Flush a run that extends to the end of the sequence.
    if streak:
        runs.append(streak)
    return runs
def compute_features(records):
    """Build one TraceFeatures row per record, then apply batch-level thresholds.

    Per-record features cover message/role structure, tool-call transitions,
    token usage, cache behavior, latency, and simple alert flags; batch-relative
    flags (e.g. p90 latency) are filled in by apply_batch_thresholds.
    """
    rows = []
    for rec in records:
        roles = rec.role_sequence
        role_tally = Counter(roles)
        runs = _tool_bursts(roles)
        # Longest consecutive run of tool messages; reused for the burst alert.
        longest_run = max(runs) if runs else 0
        prompt_tokens = rec.usage.input_tokens
        completion_tokens = rec.usage.output_tokens
        cached = rec.usage.cached_tokens
        elapsed_ms = rec.meta.total_cost_time_ms
        hit_ratio = safe_div(cached, prompt_tokens)
        tool_chain_count = _transition_count(roles, "tool", "tool")
        row = TraceFeatures(
            request_id=rec.meta.request_id,
            session_id=rec.meta.session_id,
            model=rec.meta.request_model,
            status_code=rec.meta.status_code,
            time=rec.meta.time,
            message_count=len(rec.messages),
            conversation_depth=len(rec.messages),
            declared_tool_count=len(rec.declared_tools),
            assistant_msg_count=role_tally.get("assistant", 0),
            tool_msg_count=role_tally.get("tool", 0),
            user_msg_count=role_tally.get("user", 0),
            system_msg_count=role_tally.get("system", 0),
            assistant_to_tool_count=_transition_count(roles, "assistant", "tool"),
            tool_to_assistant_count=_transition_count(roles, "tool", "assistant"),
            tool_to_tool_count=tool_chain_count,
            assistant_to_user_count=_transition_count(roles, "assistant", "user"),
            user_to_assistant_count=_transition_count(roles, "user", "assistant"),
            max_consecutive_tool_msgs=longest_run,
            avg_tool_burst_len=safe_div(sum(runs), len(runs)) if runs else 0.0,
            has_tool_loop=1 if tool_chain_count > 0 else 0,
            input_tokens=prompt_tokens,
            output_tokens=completion_tokens,
            total_tokens=rec.usage.total_tokens,
            reasoning_tokens=rec.usage.reasoning_tokens,
            cached_tokens=cached,
            cache_hit_ratio=hit_ratio,
            uncached_prompt_tokens=max(prompt_tokens - cached, 0),
            output_input_ratio=safe_div(completion_tokens, prompt_tokens),
            latency_ms=elapsed_ms,
            ms_per_input_token=safe_div(elapsed_ms, prompt_tokens),
            ms_per_output_token=safe_div(elapsed_ms, completion_tokens),
            long_context=1 if prompt_tokens >= LONG_CONTEXT_THRESHOLD else 0,
            high_cache=1 if hit_ratio >= HIGH_CACHE_THRESHOLD else 0,
            tool_burst_alert=1 if longest_run >= TOOL_BURST_THRESHOLD else 0,
            tool_loop_alert=1 if tool_chain_count >= TOOL_LOOP_THRESHOLD else 0,
        )
        row.pattern_labels = base_pattern_labels(row)
        rows.append(row)
    apply_batch_thresholds(rows)
    return rows
def base_pattern_labels(feature):
    """Derive request-level pattern labels from a single feature row."""
    labels = []
    tools_used = feature.tool_msg_count
    # No tool messages and no tools even declared: plain single-shot request.
    if not tools_used and not feature.declared_tool_count:
        labels.append("single-shot")
    # At least one tool message, and tools dominate assistant turns.
    if tools_used and tools_used >= feature.assistant_msg_count:
        labels.append("tool-heavy")
    if feature.max_consecutive_tool_msgs >= TOOL_BURST_THRESHOLD:
        labels.append("tool-burst")
    ratio = feature.cache_hit_ratio
    if ratio >= HIGH_CACHE_THRESHOLD:
        labels.append("cache-efficient")
    if ratio <= 0.1:
        labels.append("cache-cold")
    return labels
def apply_batch_thresholds(features):
    """Set batch-relative flags and finalize pattern labels in place.

    Marks rows at or above the batch's p90 latency as slow, appends
    cross-signal labels, then dedupes and sorts each row's labels.
    No-op on an empty batch.
    """
    if not features:
        return
    cutoff = percentile([row.latency_ms for row in features], 0.9)
    for row in features:
        row.slow_request = 1 if row.latency_ms >= cutoff else 0
        if row.slow_request and row.high_cache:
            row.pattern_labels.append("slow-despite-cache")
        if row.input_tokens >= LONG_CONTEXT_THRESHOLD and row.cache_hit_ratio <= 0.1:
            row.pattern_labels.append("long-context-no-cache")
        row.pattern_labels = sorted(set(row.pattern_labels))
def feature_to_row(feature):
    """Flatten a TraceFeatures dataclass into a plain dict for tabular output.

    The pattern_labels list is collapsed into a single ';'-joined string.
    """
    flat = asdict(feature)
    flat.update(pattern_labels=";".join(feature.pattern_labels))
    return flat