from collections import Counter
from dataclasses import asdict

from .helpers import percentile, safe_div
from .models import TraceFeatures

# Requests with at least this many input tokens are flagged ``long_context``.
LONG_CONTEXT_THRESHOLD = 32000
# cached/input token ratio at or above which a request is "high cache".
HIGH_CACHE_THRESHOLD = 0.8
# cached/input token ratio at or below which a request is "cache-cold".
LOW_CACHE_THRESHOLD = 0.1
# A run of at least this many consecutive tool messages raises a burst alert.
TOOL_BURST_THRESHOLD = 4
# At least this many tool->tool transitions raise a loop alert.
TOOL_LOOP_THRESHOLD = 3


def _transition_counts(roles):
    """Return a Counter mapping each adjacent ``(role, next_role)`` pair to its frequency."""
    return Counter(zip(roles, roles[1:]))


def _transition_count(roles, left, right):
    """Count positions where role ``left`` is immediately followed by role ``right``."""
    return sum(1 for current, nxt in zip(roles, roles[1:]) if current == left and nxt == right)


def _tool_bursts(roles):
    """Return the lengths of each maximal run of consecutive ``"tool"`` roles, in order."""
    bursts = []
    current = 0
    for role in roles:
        if role == "tool":
            current += 1
        elif current:
            bursts.append(current)
            current = 0
    # Close a run that extends to the end of the sequence.
    if current:
        bursts.append(current)
    return bursts


def compute_features(records):
    """Build one ``TraceFeatures`` per record, then apply batch-relative thresholds.

    Per-record features (role counts, role transitions, tool bursts, token usage,
    latency ratios, threshold flags and base pattern labels) are computed first.
    ``apply_batch_thresholds`` then runs over the whole list, so ``slow_request``
    and some labels depend on the other records in the same batch.

    Args:
        records: iterable of trace records exposing ``role_sequence``,
            ``messages``, ``declared_tools``, ``usage`` and ``meta``.

    Returns:
        list of ``TraceFeatures`` in the same order as ``records``.
    """
    features = [_record_features(record) for record in records]
    apply_batch_thresholds(features)
    return features


def _record_features(record):
    """Compute the batch-independent features and base pattern labels for one record."""
    roles = record.role_sequence
    usage = record.usage
    meta = record.meta

    role_counts = Counter(roles)
    # One pass over adjacent role pairs instead of one scan per transition type.
    transitions = _transition_counts(roles)
    bursts = _tool_bursts(roles)
    max_burst = max(bursts, default=0)

    input_tokens = usage.input_tokens
    output_tokens = usage.output_tokens
    cached_tokens = usage.cached_tokens
    latency_ms = meta.total_cost_time_ms
    cache_hit_ratio = safe_div(cached_tokens, input_tokens)
    tool_to_tool_count = transitions[("tool", "tool")]

    feature = TraceFeatures(
        request_id=meta.request_id,
        session_id=meta.session_id,
        model=meta.request_model,
        status_code=meta.status_code,
        time=meta.time,
        message_count=len(record.messages),
        # NOTE(review): identical to message_count; confirm whether depth was
        # meant to be measured differently (e.g. number of turns).
        conversation_depth=len(record.messages),
        declared_tool_count=len(record.declared_tools),
        assistant_msg_count=role_counts.get("assistant", 0),
        tool_msg_count=role_counts.get("tool", 0),
        user_msg_count=role_counts.get("user", 0),
        system_msg_count=role_counts.get("system", 0),
        assistant_to_tool_count=transitions[("assistant", "tool")],
        tool_to_assistant_count=transitions[("tool", "assistant")],
        tool_to_tool_count=tool_to_tool_count,
        assistant_to_user_count=transitions[("assistant", "user")],
        user_to_assistant_count=transitions[("user", "assistant")],
        max_consecutive_tool_msgs=max_burst,
        avg_tool_burst_len=safe_div(sum(bursts), len(bursts)) if bursts else 0.0,
        has_tool_loop=1 if tool_to_tool_count > 0 else 0,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=usage.total_tokens,
        reasoning_tokens=usage.reasoning_tokens,
        cached_tokens=cached_tokens,
        cache_hit_ratio=cache_hit_ratio,
        # Clamp at zero in case cached_tokens exceeds input_tokens.
        uncached_prompt_tokens=max(input_tokens - cached_tokens, 0),
        output_input_ratio=safe_div(output_tokens, input_tokens),
        latency_ms=latency_ms,
        ms_per_input_token=safe_div(latency_ms, input_tokens),
        ms_per_output_token=safe_div(latency_ms, output_tokens),
        long_context=1 if input_tokens >= LONG_CONTEXT_THRESHOLD else 0,
        high_cache=1 if cache_hit_ratio >= HIGH_CACHE_THRESHOLD else 0,
        tool_burst_alert=1 if max_burst >= TOOL_BURST_THRESHOLD else 0,
        tool_loop_alert=1 if tool_to_tool_count >= TOOL_LOOP_THRESHOLD else 0,
    )
    feature.pattern_labels = base_pattern_labels(feature)
    return feature


def base_pattern_labels(feature):
    """Return the batch-independent pattern labels for a single feature.

    Labels:
        ``single-shot``: no tool messages and no declared tools.
        ``tool-heavy``: at least one tool message, and at least as many tool
            messages as assistant messages.
        ``tool-burst``: a run of consecutive tool messages reached
            ``TOOL_BURST_THRESHOLD``.
        ``cache-efficient``: cache hit ratio >= ``HIGH_CACHE_THRESHOLD``.
        ``cache-cold``: cache hit ratio <= ``LOW_CACHE_THRESHOLD``.
    """
    labels = []
    if feature.tool_msg_count == 0 and feature.declared_tool_count == 0:
        labels.append("single-shot")
    if feature.tool_msg_count > 0 and feature.tool_msg_count >= feature.assistant_msg_count:
        labels.append("tool-heavy")
    if feature.max_consecutive_tool_msgs >= TOOL_BURST_THRESHOLD:
        labels.append("tool-burst")
    if feature.cache_hit_ratio >= HIGH_CACHE_THRESHOLD:
        labels.append("cache-efficient")
    if feature.cache_hit_ratio <= LOW_CACHE_THRESHOLD:
        labels.append("cache-cold")
    return labels


def apply_batch_thresholds(features):
    """Set ``slow_request`` and add batch-dependent labels to each feature, in place.

    A request is marked slow when its latency is at or above
    ``percentile(latencies, 0.9)`` for the batch. Slow, high-cache requests get
    ``slow-despite-cache``. Long-context requests with a cache hit ratio at or
    below ``LOW_CACHE_THRESHOLD`` get ``long-context-no-cache``. Each feature's
    labels are then deduplicated and sorted. An empty batch is left untouched.
    """
    if not features:
        return
    latency_p90 = percentile([feature.latency_ms for feature in features], 0.9)
    for feature in features:
        # NOTE(review): with ">=" a single-request batch, or a batch whose
        # latencies are all equal, presumably marks every request slow — confirm.
        feature.slow_request = 1 if feature.latency_ms >= latency_p90 else 0
        labels = feature.pattern_labels
        if feature.slow_request and feature.high_cache:
            labels.append("slow-despite-cache")
        if (
            feature.input_tokens >= LONG_CONTEXT_THRESHOLD
            and feature.cache_hit_ratio <= LOW_CACHE_THRESHOLD
        ):
            labels.append("long-context-no-cache")
        feature.pattern_labels = sorted(set(labels))


def feature_to_row(feature):
    """Flatten a feature into a dict row, joining ``pattern_labels`` with ``";"``."""
    row = asdict(feature)
    row["pattern_labels"] = ";".join(feature.pattern_labels)
    return row