Initial commit

trace_analyzer/figures.py (new file, 809 lines)

@@ -0,0 +1,809 @@
from __future__ import annotations

import csv
import json
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib
import numpy as np
from tqdm.auto import tqdm

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator

from trace_analyzer.helpers import percentile, safe_float, safe_int
from trace_analyzer.layout import resolve_details_dir


PALETTE = {
    "blue": "#2B6CB0",
    "orange": "#DD6B20",
    "green": "#2F855A",
    "red": "#C53030",
    "purple": "#6B46C1",
    "gray": "#4A5568",
    "teal": "#0F766E",
    "gold": "#B7791F",
    "pink": "#D53F8C",
    "grid": "#CBD5E0",
}

FIGURE_STEMS = [
    "01_input_output_length_cdf",
    "02_session_turns_cdf",
    "03_request_length_by_turn",
    "04_request_trigger_role_pie",
    "05_tool_call_output_length_cdf",
    "06_tool_call_latency_cdf",
    "07_consecutive_tool_call_count_cdf",
    "08_tool_call_added_context_cdf",
    "09_kvcache_block_reuse_time_cdf",
    "10_kvcache_block_lifecycle_cdf",
    "11_alive_kvcache_blocks_timeline",
    "12_bucket_kvcache_reuse_ratio",
    "13_session_cross_bucket_kvcache_miss",
]


def _ensure_dir(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)


def _clear_dir_files(path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)
    for child in path.iterdir():
        if child.is_file():
            child.unlink()


def _apply_style() -> None:
    plt.rcParams.update(
        {
            "figure.figsize": (8.0, 4.8),
            "figure.dpi": 600,
            "savefig.dpi": 600,
            "font.family": "DejaVu Serif",
            "font.size": 11,
            "axes.titlesize": 13,
            "axes.labelsize": 12,
            "axes.linewidth": 0.9,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "legend.fontsize": 10,
            "legend.frameon": False,
        }
    )


def _finalize_axes(ax: plt.Axes, *, grid_axis: str = "y") -> None:
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis=grid_axis, color=PALETTE["grid"], alpha=0.5, linewidth=0.8)
    ax.tick_params(axis="both", which="major", length=4, width=0.8)


def _save(fig: plt.Figure, fig_dir: Path, stem: str) -> None:
    fig.savefig(fig_dir / f"{stem}.png", bbox_inches="tight")
    plt.close(fig)


def _read_json(path: Path) -> dict:
    return json.loads(path.read_text(encoding="utf-8"))


def _read_csv_rows(path: Path) -> list[dict]:
    with path.open("r", encoding="utf-8") as handle:
        return list(csv.DictReader(handle))


def _load_request_metrics(path: Path) -> list[dict]:
    rows = []
    with path.open("r", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            rows.append(
                {
                    "request_id": row.get("request_id", ""),
                    "session_id": row.get("session_id", ""),
                    "turn": safe_int(row.get("turn")),
                    "trigger_group": row.get("trigger_group", "") or "unknown",
                    "input_tokens": safe_int(row.get("input_tokens")),
                    "output_tokens": safe_int(row.get("output_tokens")),
                    "request_ready_time_ms": safe_int(row.get("request_ready_time_ms")),
                    "request_end_time_ms": safe_int(row.get("request_end_time_ms")),
                    "input_length_bucket": row.get("input_length_bucket", ""),
                    "theoretical_prompt_unit_length": safe_int(row.get("theoretical_prompt_unit_length")),
                    "theoretical_prefix_hit_blocks": safe_int(row.get("theoretical_prefix_hit_blocks")),
                    "bucketed_theoretical_prefix_hit_blocks": safe_int(
                        row.get("bucketed_theoretical_prefix_hit_blocks")
                    ),
                }
            )
    return rows


def _sort_request_rows(rows: list[dict]) -> list[dict]:
    return sorted(
        rows,
        key=lambda row: (
            row["request_ready_time_ms"],
            row["turn"],
            row["request_id"],
        ),
    )


def _build_session_sequences(request_rows: list[dict]) -> dict[str, list[dict]]:
    sessions = defaultdict(list)
    for row in request_rows:
        sessions[row["session_id"]].append(row)
    for session_rows in sessions.values():
        session_rows.sort(
            key=lambda row: (
                row["request_ready_time_ms"],
                row["turn"],
                row["request_id"],
            )
        )
    return sessions


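# Derived per-session "tool round" edges: each edge pairs a tool-triggered request with
# the request immediately before it in the same session. The previous request's output
# tokens stand in for the tool-call output length, the gap between the previous end time
# and the next ready time for the tool-call latency, and the increase of the next input
# over the previous output (floored at zero) for the added context.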
def _build_tool_round_edges(session_rows_by_id: dict[str, list[dict]]) -> list[dict]:
    edges = []
    for session_id, session_rows in session_rows_by_id.items():
        for previous, current in zip(session_rows, session_rows[1:]):
            if current["trigger_group"] != "tool":
                continue
            edges.append(
                {
                    "session_id": session_id,
                    "prev_request_id": previous["request_id"],
                    "next_request_id": current["request_id"],
                    "tool_call_output_tokens": previous["output_tokens"],
                    "tool_call_latency_ms": max(
                        current["request_ready_time_ms"] - previous["request_end_time_ms"],
                        0,
                    ),
                    "added_context_tokens": max(
                        current["input_tokens"] - previous["output_tokens"],
                        0,
                    ),
                }
            )
    return edges


def _ecdf(values: list[float]) -> tuple[np.ndarray, np.ndarray]:
    arr = np.asarray([value for value in values if value is not None], dtype=float)
    arr = np.sort(arr)
    if arr.size == 0:
        return arr, arr
    xs, counts = np.unique(arr, return_counts=True)
    ys = np.cumsum(counts, dtype=float) / arr.size
    return xs, ys


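# Weighted ECDF over pre-aggregated rows: each row contributes one (value, count) pair,
# rows with non-positive counts are skipped, and the CDF is the cumulative count share
# over values sorted in ascending order.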
def _ecdf_from_weighted_rows(rows: list[dict], *, value_key: str, count_key: str) -> tuple[np.ndarray, np.ndarray]:
    weighted = sorted(
        (
            safe_float(row[value_key]),
            safe_int(row[count_key]),
        )
        for row in rows
        if safe_int(row.get(count_key)) > 0
    )
    total = sum(count for _, count in weighted)
    if total <= 0:
        return np.asarray([]), np.asarray([])
    xs = np.asarray([value for value, _ in weighted], dtype=float)
    ys = np.asarray(np.cumsum([count for _, count in weighted], dtype=float) / total, dtype=float)
    return xs, ys


def _stats(values: list[float], labels: tuple[str, ...]) -> dict[str, float]:
    cleaned = [value for value in values if value is not None]
    if not cleaned:
        return {label: 0.0 for label in labels}
    mapping = {"mean": float(np.mean(cleaned))}
    for label in labels:
        if label == "mean":
            continue
        mapping[label] = percentile(cleaned, int(label[1:]) / 100)
    return mapping


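# Weighted counterpart of _stats: the mean is count-weighted, and each percentile label
# (e.g. "p90") is resolved by walking the sorted (value, count) pairs until the running
# count reaches that fraction of the total.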
def _weighted_stats(rows: list[dict], *, value_key: str, count_key: str, labels: tuple[str, ...]) -> dict[str, float]:
    weighted = sorted(
        (
            safe_float(row[value_key]),
            safe_int(row[count_key]),
        )
        for row in rows
        if safe_int(row.get(count_key)) > 0
    )
    total = sum(count for _, count in weighted)
    if total <= 0:
        return {label: 0.0 for label in labels}
    result = {}
    weighted_sum = sum(value * count for value, count in weighted)
    result["mean"] = weighted_sum / total
    for label in labels:
        if label == "mean":
            continue
        target = int(label[1:]) / 100 * total
        seen = 0
        value_at_target = weighted[-1][0]
        for value, count in weighted:
            seen += count
            if seen >= target:
                value_at_target = value
                break
        result[label] = value_at_target
    return result


def _format_stat_text(title: str, stats: dict[str, float], labels: tuple[str, ...]) -> str:
    parts = [title]
    for label in labels:
        value = stats.get(label, 0.0)
        if abs(value - round(value)) < 1e-6:
            parts.append(f"{label}={int(round(value))}")
        else:
            parts.append(f"{label}={value:.2f}")
    return " ".join(parts)


def _add_footer(fig: plt.Figure, lines: list[str]) -> None:
    fig.subplots_adjust(bottom=0.24)
    y = 0.06
    for line in lines:
        fig.text(0.5, y, line, ha="center", va="bottom", fontsize=9.5)
        y -= 0.035


def _plot_two_series_cdf_with_zoom(
    fig_dir: Path,
    *,
    stem: str,
    title: str,
    xlabel: str,
    first_label: str,
    first_values: list[float],
    first_color: str,
    second_label: str,
    second_values: list[float],
    second_color: str,
    zoom_quantile: float,
    stats_labels: tuple[str, ...],
) -> None:
    first_xs, first_ys = _ecdf(first_values)
    second_xs, second_ys = _ecdf(second_values)
    zoom_max = max(
        percentile(first_values, zoom_quantile) if first_values else 0.0,
        percentile(second_values, zoom_quantile) if second_values else 0.0,
    )

    fig, axes = plt.subplots(1, 2, figsize=(12.4, 4.8))
    for ax, subtitle in zip(axes, ["Full Range", f"Zoom: <= p{int(zoom_quantile * 100)}"]):
        ax.step(first_xs, first_ys, where="post", linewidth=2.2, color=first_color, label=first_label)
        ax.step(second_xs, second_ys, where="post", linewidth=2.2, color=second_color, label=second_label)
        ax.set_title(subtitle)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("CDF")
        _finalize_axes(ax)
    axes[1].set_xlim(0, zoom_max if zoom_max > 0 else 1)
    axes[0].legend(loc="lower right")
    fig.suptitle(title, y=0.98)
    _add_footer(
        fig,
        [
            _format_stat_text(first_label, _stats(first_values, stats_labels), stats_labels),
            _format_stat_text(second_label, _stats(second_values, stats_labels), stats_labels),
        ],
    )
    _save(fig, fig_dir, stem)


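# Single-series CDF plot with two input modes: raw per-request values, or pre-aggregated
# (value, count) rows via weighted_value_key / weighted_count_key. When zoom_quantile is
# set, a second panel repeats the curve with the x-axis capped near that percentile.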
def _plot_single_cdf(
    fig_dir: Path,
    *,
    stem: str,
    title: str,
    xlabel: str,
    label: str,
    values: list[float] | None = None,
    weighted_rows: list[dict] | None = None,
    weighted_value_key: str | None = None,
    weighted_count_key: str | None = None,
    color: str = PALETTE["blue"],
    zoom_quantile: float | None = None,
    stats_labels: tuple[str, ...] = ("mean", "p50", "p90", "p95", "p99"),
) -> None:
    values = values or []
    weighted_rows = weighted_rows or []
    if weighted_rows:
        xs, ys = _ecdf_from_weighted_rows(
            weighted_rows,
            value_key=weighted_value_key,
            count_key=weighted_count_key,
        )
        stats = _weighted_stats(
            weighted_rows,
            value_key=weighted_value_key,
            count_key=weighted_count_key,
            labels=stats_labels,
        )
        zoom_max = stats.get(f"p{int(zoom_quantile * 100)}", 0.0) if zoom_quantile is not None else 0.0
    else:
        xs, ys = _ecdf(values)
        stats = _stats(values, stats_labels)
        zoom_max = percentile(values, zoom_quantile) if zoom_quantile is not None and values else 0.0

    panel_count = 2 if zoom_quantile is not None else 1
    fig, axes = plt.subplots(1, panel_count, figsize=(12.4, 4.8) if panel_count == 2 else (8.2, 4.8))
    if panel_count == 1:
        axes = [axes]
    axes[0].step(xs, ys, where="post", linewidth=2.2, color=color)
    axes[0].set_title("Full Range")
    axes[0].set_xlabel(xlabel)
    axes[0].set_ylabel("CDF")
    _finalize_axes(axes[0])
    if panel_count == 2:
        axes[1].step(xs, ys, where="post", linewidth=2.2, color=color)
        axes[1].set_title(f"Zoom: <= p{int(zoom_quantile * 100)}")
        axes[1].set_xlabel(xlabel)
        axes[1].set_ylabel("CDF")
        axes[1].set_xlim(0, zoom_max if zoom_max > 0 else 1)
        _finalize_axes(axes[1])
    fig.suptitle(title, y=0.98)
    _add_footer(fig, [_format_stat_text(label, stats, stats_labels)])
    _save(fig, fig_dir, stem)


def _plot_session_turns_cdf(fig_dir: Path, request_rows: list[dict]) -> None:
    session_sizes = Counter(row["session_id"] for row in request_rows)
    values = list(session_sizes.values())
    xs, ys = _ecdf(values)
    max_turn = max(values) if values else 1
    zoom_max = max(int(np.ceil(max_turn * 0.10)), 1)

    fig, axes = plt.subplots(1, 2, figsize=(12.4, 4.8))
    for ax, subtitle in zip(axes, ["Full Range", f"Zoom: <= {zoom_max} turns (first 10% of max turn)"]):
        ax.step(xs, ys, where="post", linewidth=2.2, color=PALETTE["green"])
        ax.set_title(subtitle)
        ax.set_xlabel("Turns per session")
        ax.set_ylabel("CDF")
        _finalize_axes(ax)
    axes[1].set_xlim(0.5, zoom_max + 0.5)
    fig.suptitle("Session Turns CDF", y=0.98)
    _add_footer(
        fig,
        [
            _format_stat_text(
                "Session turns",
                _stats(values, ("mean", "p50", "p90", "p95", "p99")),
                ("mean", "p50", "p90", "p95", "p99"),
            )
        ],
    )
    _save(fig, fig_dir, "02_session_turns_cdf")


def _plot_request_length_by_turn(fig_dir: Path, request_rows: list[dict]) -> None:
    values_by_turn = defaultdict(list)
    for row in request_rows:
        if row["turn"] > 0:
            values_by_turn[row["turn"]].append(row["input_tokens"])
    turns = sorted(values_by_turn)
    mean_values = [float(np.mean(values_by_turn[turn])) for turn in turns]
    p50_values = [percentile(values_by_turn[turn], 0.50) for turn in turns]
    p99_values = [percentile(values_by_turn[turn], 0.99) for turn in turns]

    fig, ax = plt.subplots(figsize=(8.6, 4.8))
    ax.plot(turns, mean_values, color=PALETTE["blue"], linewidth=2.0, label="mean")
    ax.plot(turns, p50_values, color=PALETTE["orange"], linewidth=2.0, label="p50")
    ax.plot(turns, p99_values, color=PALETTE["red"], linewidth=2.0, label="p99")
    ax.set_title("Request Input Length by Turn")
    ax.set_xlabel("Turn")
    ax.set_ylabel("Input tokens")
    ax.legend(loc="upper left")
    ax.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True))
    plt.setp(ax.get_xticklabels(), rotation=20, ha="right")
    _finalize_axes(ax)
    fig.tight_layout()
    _save(fig, fig_dir, "03_request_length_by_turn")


def _plot_trigger_role_pie(fig_dir: Path, request_rows: list[dict]) -> None:
    label_order = ["user", "tool", "assistant"]
    color_by_label = {
        "user": PALETTE["orange"],
        "tool": PALETTE["green"],
        "assistant": PALETTE["blue"],
    }
    counts = Counter(row["trigger_group"] for row in request_rows)
    labels = [label for label in label_order if counts[label] > 0]
    values = [counts[label] for label in labels]
    colors = [color_by_label[label] for label in labels]

    def _autopct(pct):
        total = sum(values)
        count = int(round(pct * total / 100.0))
        return f"{pct:.1f}%\n({count})"

    fig, ax = plt.subplots(figsize=(9.0, 5.8))
    wedges, _texts, autotexts = ax.pie(
        values,
        autopct=_autopct,
        startangle=90,
        colors=colors,
        wedgeprops={"linewidth": 0.8, "edgecolor": "white"},
        textprops={"fontsize": 9},
    )
    for autotext in autotexts:
        autotext.set_fontsize(8.5)
    ax.legend(
        wedges,
        [f"{label} ({counts[label]:,})" for label in labels],
        title="Trigger source",
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
    )
    ax.set_title("Request Trigger Role Proportion")
    fig.tight_layout()
    _save(fig, fig_dir, "04_request_trigger_role_pie")


def _plot_session_gap_cdf(fig_dir: Path, session_rows_by_id: dict[str, list[dict]]) -> None:
    ready_gaps = []
    end_ready_gaps = []
    for session_rows in session_rows_by_id.values():
        for previous, current in zip(session_rows, session_rows[1:]):
            ready_gaps.append(max(current["request_ready_time_ms"] - previous["request_ready_time_ms"], 0))
            end_ready_gaps.append(max(current["request_ready_time_ms"] - previous["request_end_time_ms"], 0))
    _plot_two_series_cdf_with_zoom(
        fig_dir,
        stem="session_inter_request_gap_cdf",
        title="Session Inter-Request Gap CDF",
        xlabel="Milliseconds",
        first_label="ready->ready",
        first_values=ready_gaps,
        first_color=PALETTE["purple"],
        second_label="end->ready",
        second_values=end_ready_gaps,
        second_color=PALETTE["gray"],
        zoom_quantile=0.90,
        stats_labels=("mean", "p50", "p90", "p95", "p99"),
    )


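# For every user-triggered request, count the streak of immediately following
# tool-triggered requests within the same session; the distribution of those streak
# lengths is plotted as a CDF.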
def _plot_consecutive_tool_calls_cdf(fig_dir: Path, session_rows_by_id: dict[str, list[dict]]) -> None:
    values = []
    for session_rows in session_rows_by_id.values():
        for index, row in enumerate(session_rows):
            if row["trigger_group"] != "user":
                continue
            count = 0
            next_index = index + 1
            while next_index < len(session_rows) and session_rows[next_index]["trigger_group"] == "tool":
                count += 1
                next_index += 1
            values.append(count)
    _plot_single_cdf(
        fig_dir,
        stem="07_consecutive_tool_call_count_cdf",
        title="Consecutive Tool Calls After One User Input",
        xlabel="Consecutive tool-triggered rounds",
        label="Consecutive tool calls",
        values=values,
        color=PALETTE["green"],
    )


def _plot_alive_kvcache_timeline(fig_dir: Path, timeline_rows: list[dict]) -> None:
    fig, ax = plt.subplots(figsize=(10.2, 4.8))
    if timeline_rows:
        base_ts = safe_int(timeline_rows[0]["timestamp_ms"])
    else:
        base_ts = 0
    xs = [
        max(safe_int(row["timestamp_ms"]) - base_ts, 0) / 60000.0
        for row in timeline_rows
    ]
    ys = [safe_int(row["alive_block_count"]) for row in timeline_rows]
    ax.step(xs, ys, where="post", color=PALETTE["purple"], linewidth=1.8)
    ax.set_title("Alive KV-Cache Blocks Over Time")
    ax.set_xlabel("Elapsed time (minutes)")
    ax.set_ylabel("Alive block count")
    ax.xaxis.set_major_locator(MultipleLocator(10))
    plt.setp(ax.get_xticklabels(), rotation=20, ha="right")
    _finalize_axes(ax)
    fig.tight_layout()
    _save(fig, fig_dir, "11_alive_kvcache_blocks_timeline")


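# Per-bucket bars use the bucketed theoretical prefix-hit blocks over the theoretical
# prompt units summed within each input-length bucket; the trailing "Overall" bar uses
# the unbucketed prefix-hit total over all prompt units, so it reflects the global
# (bucket-agnostic) reuse ratio.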
def _plot_bucket_reuse_ratio(fig_dir: Path, request_rows: list[dict]) -> None:
    by_bucket = defaultdict(lambda: {"prompt_blocks": 0, "reused_blocks": 0})
    total_prompt_blocks = 0
    total_reused_blocks = 0
    for row in request_rows:
        bucket = row["input_length_bucket"] or "unknown"
        prompt_blocks = row["theoretical_prompt_unit_length"]
        reused_blocks = row["bucketed_theoretical_prefix_hit_blocks"]
        by_bucket[bucket]["prompt_blocks"] += prompt_blocks
        by_bucket[bucket]["reused_blocks"] += reused_blocks
        total_prompt_blocks += prompt_blocks
        total_reused_blocks += row["theoretical_prefix_hit_blocks"]

    labels = list(by_bucket)
    ratios = [
        (by_bucket[label]["reused_blocks"] / by_bucket[label]["prompt_blocks"])
        if by_bucket[label]["prompt_blocks"]
        else 0.0
        for label in labels
    ]
    reused_counts = [by_bucket[label]["reused_blocks"] for label in labels]
    labels.append("Overall")
    ratios.append((total_reused_blocks / total_prompt_blocks) if total_prompt_blocks else 0.0)
    reused_counts.append(total_reused_blocks)

    fig, ax = plt.subplots(figsize=(9.2, 4.8))
    bars = ax.bar(
        labels,
        ratios,
        color=[PALETTE["blue"], PALETTE["orange"], PALETTE["green"], PALETTE["purple"], PALETTE["teal"]][: len(labels)],
        width=0.68,
        edgecolor="white",
        linewidth=0.8,
    )
    for bar, ratio, reused_count in zip(bars, ratios, reused_counts):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            ratio + max(ratios + [0.0]) * 0.03 + 1e-9,
            f"{ratio:.2%}\nreused={reused_count:,}",
            ha="center",
            va="bottom",
            fontsize=8.8,
        )
    ax.set_title("Bucketed KV-Cache Reuse Ratio vs Global Reuse Ratio")
    ax.set_xlabel("Input-length bucket")
    ax.set_ylabel("Reuse ratio")
    ax.set_ylim(0, max(ratios + [0.0]) * 1.25 + 1e-9)
    _finalize_axes(ax)
    fig.tight_layout()
    _save(fig, fig_dir, "12_bucket_kvcache_reuse_ratio")


def _plot_session_cross_bucket_miss(fig_dir: Path, rows: list[dict]) -> None:
    labels = [row["bucket"] for row in rows]
    miss_ratios = [safe_float(row["cross_bucket_edge_fraction"]) for row in rows]
    loss_ratios = [safe_float(row["reduced_reused_blocks_ratio"]) for row in rows]
    miss_blocks = [safe_int(row["cross_bucket_shared_prefix_units_sum"]) for row in rows]
    x = np.arange(len(labels))
    width = 0.36

    fig, ax = plt.subplots(figsize=(9.2, 4.8))
    left = ax.bar(x - width / 2, miss_ratios, width=width, color=PALETTE["red"], label="cross-bucket miss ratio")
    right = ax.bar(
        x + width / 2,
        loss_ratios,
        width=width,
        color=PALETTE["gold"],
        label="reduced reused blocks / bucket reuse",
    )
    y_pad = max(miss_ratios + loss_ratios + [0.0]) * 0.03 + 1e-9
    for bar, value, count in zip(left, miss_ratios, miss_blocks):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            value + y_pad,
            f"{value:.2%}\nmiss={count:,}",
            ha="center",
            va="bottom",
            fontsize=8.8,
        )
    for bar, value in zip(right, loss_ratios):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            value + y_pad,
            f"{value:.2%}",
            ha="center",
            va="bottom",
            fontsize=8.8,
        )
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_title("Session Cross-Bucket KV-Cache Miss and Reuse Loss")
    ax.set_xlabel("Child bucket")
    ax.set_ylabel("Ratio")
    ax.legend(loc="upper left")
    ax.set_ylim(0, max(miss_ratios + loss_ratios + [0.0]) * 1.25 + 1e-9)
    _finalize_axes(ax)
    fig.tight_layout()
    _save(fig, fig_dir, "13_session_cross_bucket_kvcache_miss")


def _write_manifest(fig_dir: Path, manifest: dict) -> None:
    (fig_dir / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")


def _write_readme(fig_dir: Path, dataset_title: str) -> None:
    lines = [
        f"# {dataset_title}",
        "",
        "This directory contains the PNG figures rendered from `details/` data.",
        "",
        "Figures:",
    ]
    for stem in FIGURE_STEMS:
        lines.append(f"- `{stem}.png`")
    lines.append("- `session_inter_request_gap_cdf.png`")
    (fig_dir / "README.md").write_text("\n".join(lines) + "\n", encoding="utf-8")


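# Entry point: reads request_metrics.csv plus the theoretical block / timeline /
# session-bucket CSVs and details_summary.json from the resolved details/ directory,
# renders the 13 numbered figures plus the extra inter-request gap CDF into fig_dir,
# and finishes by writing manifest.json and README.md.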
def render_figures(
    *,
    analysis_dir: str | Path,
    fig_dir: str | Path,
    dataset_title: str,
    show_progress: bool = False,
) -> dict:
    analysis_root = Path(analysis_dir)
    fig_root = Path(fig_dir)
    details_root = resolve_details_dir(analysis_root)
    _clear_dir_files(fig_root)
    _apply_style()

    request_rows = _load_request_metrics(details_root / "request_metrics.csv")
    request_rows = _sort_request_rows(request_rows)
    session_rows_by_id = _build_session_sequences(request_rows)
    tool_round_edges = _build_tool_round_edges(session_rows_by_id)
    reuse_gap_rows = _read_csv_rows(details_root / "theoretical_block_reuse_gaps.csv")
    block_lifetime_rows = _read_csv_rows(details_root / "theoretical_block_lifetimes.csv")
    timeline_rows = _read_csv_rows(details_root / "theoretical_alive_block_timeline.csv")
    session_bucket_rows = _read_csv_rows(details_root / "session_bucket_boundary_miss.csv")
    details_summary = _read_json(details_root / "details_summary.json")

    progress = tqdm(
        total=len(FIGURE_STEMS) + 1,
        desc="Render figures",
        unit="artifact",
        dynamic_ncols=True,
        disable=not show_progress,
    )

    if show_progress:
        progress.set_postfix(current="01_input_output_length_cdf")
    _plot_two_series_cdf_with_zoom(
        fig_root,
        stem="01_input_output_length_cdf",
        title="Input / Output Length CDF",
        xlabel="Tokens",
        first_label="Input",
        first_values=[row["input_tokens"] for row in request_rows],
        first_color=PALETTE["blue"],
        second_label="Output",
        second_values=[row["output_tokens"] for row in request_rows],
        second_color=PALETTE["orange"],
        zoom_quantile=0.80,
        stats_labels=("mean", "p50", "p80", "p90", "p95", "p99"),
    )
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="02_session_turns_cdf")
    _plot_session_turns_cdf(fig_root, request_rows)
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="03_request_length_by_turn")
    _plot_request_length_by_turn(fig_root, request_rows)
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="04_request_trigger_role_pie")
    _plot_trigger_role_pie(fig_root, request_rows)
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="05_tool_call_output_length_cdf")
    _plot_single_cdf(
        fig_root,
        stem="05_tool_call_output_length_cdf",
        title="Tool Call Output Length CDF",
        xlabel="Output tokens",
        label="Tool-call output length",
        values=[row["tool_call_output_tokens"] for row in tool_round_edges],
        color=PALETTE["teal"],
        zoom_quantile=0.90,
    )
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="06_tool_call_latency_cdf")
    _plot_single_cdf(
        fig_root,
        stem="06_tool_call_latency_cdf",
        title="Tool Call Latency CDF",
        xlabel="Milliseconds",
        label="Tool-call latency",
        values=[row["tool_call_latency_ms"] for row in tool_round_edges],
        color=PALETTE["red"],
        zoom_quantile=0.90,
    )
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="07_consecutive_tool_call_count_cdf")
    _plot_consecutive_tool_calls_cdf(fig_root, session_rows_by_id)
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="08_tool_call_added_context_cdf")
    _plot_single_cdf(
        fig_root,
        stem="08_tool_call_added_context_cdf",
        title="Added Context After Tool Call CDF",
        xlabel="Added context tokens",
        label="Added context",
        values=[row["added_context_tokens"] for row in tool_round_edges],
        color=PALETTE["purple"],
    )
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="09_kvcache_block_reuse_time_cdf")
    _plot_single_cdf(
        fig_root,
        stem="09_kvcache_block_reuse_time_cdf",
        title="KV-Cache Block Reuse Time CDF",
        xlabel="Milliseconds",
        label="Reuse time",
        weighted_rows=reuse_gap_rows,
        weighted_value_key="reuse_gap_ms",
        weighted_count_key="count",
        color=PALETTE["gold"],
        zoom_quantile=0.90,
    )
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="10_kvcache_block_lifecycle_cdf")
    _plot_single_cdf(
        fig_root,
        stem="10_kvcache_block_lifecycle_cdf",
        title="KV-Cache Block Lifecycle CDF",
        xlabel="Milliseconds",
        label="Block lifecycle",
        values=[safe_int(row["lifetime_ms"]) for row in block_lifetime_rows],
        color=PALETTE["gray"],
    )
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="11_alive_kvcache_blocks_timeline")
    _plot_alive_kvcache_timeline(fig_root, timeline_rows)
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="12_bucket_kvcache_reuse_ratio")
    _plot_bucket_reuse_ratio(fig_root, request_rows)
    if show_progress:
        progress.update(1)
        progress.set_postfix(current="13_session_cross_bucket_kvcache_miss")
    _plot_session_cross_bucket_miss(fig_root, session_bucket_rows)
    _plot_session_gap_cdf(fig_root, session_rows_by_id)

    if show_progress:
        progress.update(1)
        progress.set_postfix(current="manifest.json + README.md")

    manifest = {
        "dataset_title": dataset_title,
        "figure_count": len(FIGURE_STEMS),
        "analysis_dir": str(analysis_root),
        "request_count": details_summary.get("request_count", 0),
        "global_reuse_ratio": details_summary.get("global_reuse_ratio", 0.0),
        "figures": [f"{stem}.png" for stem in FIGURE_STEMS],
        "extra_figures": ["session_inter_request_gap_cdf.png"],
    }
    _write_manifest(fig_root, manifest)
    _write_readme(fig_root, dataset_title)
    if show_progress:
        progress.update(1)
        progress.close()

    return {
        "fig_dir": str(fig_root),
        "manifest_path": str(fig_root / "manifest.json"),
        "readme_path": str(fig_root / "README.md"),
    }
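# Usage sketch (illustrative only; the paths and title below are hypothetical, not part
# of this module): once an analysis run has produced the details/ CSVs, figures can be
# rendered with something like
#
#     result = render_figures(
#         analysis_dir="analysis/run_001",
#         fig_dir="analysis/run_001/figures",
#         dataset_title="Run 001 Trace Figures",
#         show_progress=True,
#     )
#     print(result["manifest_path"], result["readme_path"])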