118 lines
4.8 KiB
Python
118 lines
4.8 KiB
Python
from collections import Counter
|
|
from dataclasses import asdict
|
|
|
|
from .helpers import percentile, safe_div
|
|
from .models import TraceFeatures
|
|
|
|
# Prompt size (input tokens) at or above which a request counts as long-context.
LONG_CONTEXT_THRESHOLD = 32000
# Cache-hit ratio (cached_tokens / input_tokens) at or above which a request
# counts as cache-efficient / high-cache.
HIGH_CACHE_THRESHOLD = 0.8
# A run of this many consecutive "tool" messages triggers the tool-burst alert.
TOOL_BURST_THRESHOLD = 4
# This many tool->tool transitions in one trace triggers the tool-loop alert.
TOOL_LOOP_THRESHOLD = 3
|
|
|
|
|
|
def _transition_count(roles, left, right):
|
|
return sum(1 for current, nxt in zip(roles, roles[1:]) if current == left and nxt == right)
|
|
|
|
|
|
def _tool_bursts(roles):
|
|
bursts = []
|
|
current = 0
|
|
for role in roles:
|
|
if role == "tool":
|
|
current += 1
|
|
elif current:
|
|
bursts.append(current)
|
|
current = 0
|
|
if current:
|
|
bursts.append(current)
|
|
return bursts
|
|
|
|
|
|
def compute_features(records):
    """Turn parsed trace records into a list of TraceFeatures.

    Per-record features are derived from the role sequence, token usage and
    timing metadata; batch-relative fields (e.g. slow_request) are filled in
    afterwards by apply_batch_thresholds.
    """
    results = []
    for rec in records:
        roles = rec.role_sequence
        usage = rec.usage
        meta = rec.meta

        role_totals = Counter(roles)
        bursts = _tool_bursts(roles)
        longest_burst = max(bursts) if bursts else 0

        prompt_tokens = usage.input_tokens
        completion_tokens = usage.output_tokens
        cached = usage.cached_tokens
        elapsed_ms = meta.total_cost_time_ms
        hit_ratio = safe_div(cached, prompt_tokens)
        tool_tool = _transition_count(roles, "tool", "tool")

        item = TraceFeatures(
            request_id=meta.request_id,
            session_id=meta.session_id,
            model=meta.request_model,
            status_code=meta.status_code,
            time=meta.time,
            message_count=len(rec.messages),
            conversation_depth=len(rec.messages),
            declared_tool_count=len(rec.declared_tools),
            assistant_msg_count=role_totals.get("assistant", 0),
            tool_msg_count=role_totals.get("tool", 0),
            user_msg_count=role_totals.get("user", 0),
            system_msg_count=role_totals.get("system", 0),
            assistant_to_tool_count=_transition_count(roles, "assistant", "tool"),
            tool_to_assistant_count=_transition_count(roles, "tool", "assistant"),
            tool_to_tool_count=tool_tool,
            assistant_to_user_count=_transition_count(roles, "assistant", "user"),
            user_to_assistant_count=_transition_count(roles, "user", "assistant"),
            max_consecutive_tool_msgs=longest_burst,
            avg_tool_burst_len=safe_div(sum(bursts), len(bursts)) if bursts else 0.0,
            has_tool_loop=1 if tool_tool > 0 else 0,
            input_tokens=prompt_tokens,
            output_tokens=completion_tokens,
            total_tokens=usage.total_tokens,
            reasoning_tokens=usage.reasoning_tokens,
            cached_tokens=cached,
            cache_hit_ratio=hit_ratio,
            # Clamp at 0 in case cached_tokens ever exceeds input_tokens.
            uncached_prompt_tokens=max(prompt_tokens - cached, 0),
            output_input_ratio=safe_div(completion_tokens, prompt_tokens),
            latency_ms=elapsed_ms,
            ms_per_input_token=safe_div(elapsed_ms, prompt_tokens),
            ms_per_output_token=safe_div(elapsed_ms, completion_tokens),
            long_context=1 if prompt_tokens >= LONG_CONTEXT_THRESHOLD else 0,
            high_cache=1 if hit_ratio >= HIGH_CACHE_THRESHOLD else 0,
            tool_burst_alert=1 if longest_burst >= TOOL_BURST_THRESHOLD else 0,
            tool_loop_alert=1 if tool_tool >= TOOL_LOOP_THRESHOLD else 0,
        )
        item.pattern_labels = base_pattern_labels(item)
        results.append(item)

    apply_batch_thresholds(results)
    return results
|
|
|
|
|
|
def base_pattern_labels(feature):
    """Derive per-trace pattern labels from a single TraceFeatures row.

    Only uses fields of *feature* itself; batch-relative labels are added
    later by apply_batch_thresholds.
    """
    labels = []
    uses_tools = feature.tool_msg_count > 0 or feature.declared_tool_count > 0
    if not uses_tools:
        labels.append("single-shot")
    if feature.tool_msg_count > 0 and feature.assistant_msg_count <= feature.tool_msg_count:
        labels.append("tool-heavy")
    if feature.max_consecutive_tool_msgs >= TOOL_BURST_THRESHOLD:
        labels.append("tool-burst")
    ratio = feature.cache_hit_ratio
    if ratio >= HIGH_CACHE_THRESHOLD:
        labels.append("cache-efficient")
    if ratio <= 0.1:
        labels.append("cache-cold")
    return labels
|
|
|
|
|
|
def apply_batch_thresholds(features):
    """Fill in batch-relative fields on each feature, in place.

    Marks the slowest requests (latency at or above the batch p90), appends
    the derived labels that depend on those flags, then de-duplicates and
    sorts each feature's label list.
    """
    if not features:
        return
    slow_cutoff = percentile([f.latency_ms for f in features], 0.9)
    for f in features:
        is_slow = f.latency_ms >= slow_cutoff
        f.slow_request = int(is_slow)
        if is_slow and f.high_cache:
            f.pattern_labels.append("slow-despite-cache")
        if f.cache_hit_ratio <= 0.1 and f.input_tokens >= LONG_CONTEXT_THRESHOLD:
            f.pattern_labels.append("long-context-no-cache")
        f.pattern_labels = sorted(set(f.pattern_labels))
|
|
|
|
|
|
def feature_to_row(feature):
|
|
row = asdict(feature)
|
|
row["pattern_labels"] = ";".join(feature.pattern_labels)
|
|
return row
|