Initial commit

2026-04-21 15:44:47 +00:00
commit bce3fe1395
40 changed files with 1758724 additions and 0 deletions

.gitignore vendored Normal file

@@ -0,0 +1,12 @@
.venv/
__pycache__/
*.py[cod]
*.so
*.egg-info/
outputs/
scripts/
logs/
tmp/
trace-*

README.md Normal file

@@ -0,0 +1,233 @@
# ali-trace
This repository is organized in two stages:
1. `trace_formatter` turns provider-native raw traces into a unified formatter output.
2. `trace_analyzer` only analyzes the formatter-produced `*-raw.jsonl`, plus the release jsonl with `hash_ids` that build-release derives from the same source.
Main constraints:
- `trace_analyzer` no longer consumes provider-native raw logs directly.
- `details/` holds only the data needed for plotting.
- `figures/` holds only the final images.
- The main analysis entry point maintains exactly 13 figures, and the data kept is the minimal set that serves those 13 figures.
Naive wall-clock time fields such as `time` are interpreted as `UTC+8`.
## Install
```bash
uv sync
```
## Step 1: Format Raw Trace
Format raw traces into a single unified, time-sorted `*-raw.jsonl`.
Format an entire directory:
```bash
python -m trace_formatter format trace-glm5
```
Format a single raw shard:
```bash
python -m trace_formatter format trace-qwen3-coder/0419-1500-1700.jsonl
```
Default output:
```text
trace-*-formatted/<mmddhh-mmddhh>-raw.jsonl
```
The formatter does the following:
1. Extracts provider-native raw traces and emits a `*-raw.jsonl` with a unified schema.
2. Rebuilds logical sessions from `user_id + messages` and updates `session_id / parent_request_id / turn / chat_id` (see the sketch after this list).
3. Keeps the fields the analyzer needs, including:
   - `canonical_prompt`
   - `usage`
   - the original `raw_messages`
   - `meta.total_cost_time_ms`
   - `meta.backend_first_request_time_ms`
   - `meta.backend_first_response_time_ms`
4. For models that ship their own `tool_parser.py`, such as `Qwen3-Coder-480B-A35B-Instruct`, prefers the model's own parser for `messages -> canonical_prompt`.
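A minimal sketch of the session-rebuild rule in step 2, assuming simplified row dicts; the real logic lives in `trace_formatter/sessionization.py` and handles more edge cases with the same user-turn-prefix idea:
```python
# Hypothetical, simplified sketch -- not the actual sessionization code.
def continues_session(parent: dict, child: dict) -> bool:
    """True when `child` plausibly replays `parent`'s conversation for the same user."""
    if parent["user_id"] != child["user_id"]:
        return False
    parent_users = [m["content"] for m in parent["messages"] if m["role"] == "user"]
    child_users = [m["content"] for m in child["messages"] if m["role"] == "user"]
    # Matching on user turns (not assistant turns) tolerates regenerated answers and
    # dropped assistant suffixes; sharing only a system prompt never matches, because
    # the first user turns already differ.
    return bool(parent_users) and child_users[: len(parent_users)] == parent_users
```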
For filenames like `0417-1500-1530.jsonl`, the formatter also infers the time window and truncates by ready-time, so requests outside the window are not stitched into the result.
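A hedged sketch of that filename-to-window inference (hypothetical helper; the real implementation in `trace_formatter/time_windows.py` also normalizes mixed ready-time scales):
```python
import re
from datetime import datetime, timezone

# Hypothetical sketch; the year default and timezone handling are assumptions.
def infer_window_ms(filename: str, year: int = 2026) -> tuple[int, int]:
    """Map shard names like `0417-1500-1530.jsonl` to a [start, end) window in epoch ms."""
    match = re.fullmatch(r"(\d{2})(\d{2})-(\d{2})(\d{2})-(\d{2})(\d{2})\.jsonl", filename)
    if match is None:
        raise ValueError(f"unrecognized shard name: {filename}")
    month, day, sh, sm, eh, em = (int(g) for g in match.groups())
    start = datetime(year, month, day, sh, sm, tzinfo=timezone.utc)
    end = datetime(year, month, day, eh, em, tzinfo=timezone.utc)
    return int(start.timestamp() * 1000), int(end.timestamp() * 1000)

# Truncation then keeps a row iff start_ms <= ready_ms < end_ms.
```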
Examples:
- `trace-glm5` -> `trace-glm5-formatted/041715-041717-raw.jsonl`
- `trace-qwen3-coder/0419-1500-1700.jsonl` -> `trace-qwen3-coder-formatted/041915-041917-raw.jsonl`
## Step 2: Build Release Trace
Build an open-source-ready release jsonl from `*-raw.jsonl`.
```bash
python -m trace_formatter build-release trace-glm5-formatted/041715-041717-raw.jsonl --jobs 16
```
Default output:
```text
trace-*-formatted/<mmddhh-mmddhh>.jsonl
```
This release file contains the tokenizer/block-level `hash_ids` used for theoretical kvcache reuse analysis.
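As a rough illustration of what `hash_ids` enables, a hedged sketch that estimates the theoretical block-reuse ratio from a release file (field names follow the release schema; the analyzer's real accounting is more detailed):
```python
import json

def theoretical_reuse_ratio(release_path: str) -> float:
    """Fraction of kvcache blocks already produced by an earlier request."""
    seen = set()
    total = hits = 0
    with open(release_path, encoding="utf-8") as handle:
        for line in handle:
            row = json.loads(line)
            for block_hash in row["hash_ids"]:
                total += 1
                if block_hash in seen:
                    hits += 1  # block seen before -> theoretically reusable
                seen.add(block_hash)
    return hits / total if total else 0.0
```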
## Step 3: Analyze Formatted Trace
Main analysis entry point:
```bash
python -m trace_analyzer analyze trace-glm5-formatted/041715-041717-raw.jsonl
```
Default output layout:
```text
outputs/
  analysis/
    <model>-<dataset>/
      features.csv
      summary.json
      report.md
      analysis_snapshot.json
      details/
      figures/
```
For example:
- `trace-glm5-formatted/041715-041717-raw.jsonl` -> `outputs/analysis/glm5-041715-041717/`
- `trace-qwen3-coder-formatted/041915-041917-raw.jsonl` -> `outputs/analysis/qwen3-coder-041915-041917/`
`analyze` automatically looks for the sibling release file in the same directory, i.e. the `.jsonl` obtained by stripping `-raw` from `*-raw.jsonl`. If it does not exist, run `trace_formatter build-release` first.
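For example, when the sibling release file is missing:
```bash
# Build the release file first, then analyze; analyze auto-discovers
# trace-glm5-formatted/041715-041717.jsonl as the release input.
python -m trace_formatter build-release trace-glm5-formatted/041715-041717-raw.jsonl
python -m trace_analyzer analyze trace-glm5-formatted/041715-041717-raw.jsonl
```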
### Data currently kept in `details/`
`details/` is the data layer and no longer holds PNGs; a sketch of consuming these files follows the list.
- `request_metrics.csv`
  The request-level main table. Used for:
  - input/output length
  - session turns / per-turn request length
  - trigger-role share
  - tool-call output length / latency / added context
  - bucket reuse ratio
- `theoretical_block_reuse_gaps.csv`
  Distribution of the gap between each reuse of a kvcache block and its previous reuse (or first appearance).
- `theoretical_block_lifetimes.csv`
  Each block's lifetime from first appearance to last reuse.
- `theoretical_alive_block_timeline.csv`
  Event timeline of how many blocks are alive over time.
- `session_bucket_boundary_miss.csv`
  Per-child-bucket summary of the misses / reused-block losses caused by consecutive session requests crossing bucket boundaries.
- `details_summary.json`
  details schema version, bucket definitions, and the global reuse summary.
- `progress.json`
  Progress and resource-usage information for large trace runs.
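A hedged sketch of consuming these files for custom plots, mirroring figure 09; the column name `gap_ms` is an assumption, so check `details_summary.json` for the actual schema:
```python
import csv

import matplotlib.pyplot as plt

with open("details/theoretical_block_reuse_gaps.csv", encoding="utf-8") as handle:
    gaps = sorted(float(row["gap_ms"]) for row in csv.DictReader(handle))  # assumed column name
cdf = [(i + 1) / len(gaps) for i in range(len(gaps))]
plt.plot(gaps, cdf)
plt.xlabel("reuse gap (ms)")
plt.ylabel("CDF")
plt.savefig("reuse_gap_cdf.png", dpi=600)
```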
### The 13 figures currently kept in `figures/`
`figures/` is the presentation layer; it holds only images and lightweight description files.
1. `01_input_output_length_cdf.png`
2. `02_session_turns_cdf.png`
3. `03_request_length_by_turn.png`
4. `04_request_trigger_role_pie.png`
5. `05_tool_call_output_length_cdf.png`
6. `06_tool_call_latency_cdf.png`
7. `07_consecutive_tool_call_count_cdf.png`
8. `08_tool_call_added_context_cdf.png`
9. `09_kvcache_block_reuse_time_cdf.png`
10. `10_kvcache_block_lifecycle_cdf.png`
11. `11_alive_kvcache_blocks_timeline.png`
12. `12_bucket_kvcache_reuse_ratio.png`
13. `13_session_cross_bucket_kvcache_miss.png`
Also present:
- `manifest.json`
- `README.md`
All figures are rendered as `600dpi` PNGs.
## Repo Structure
Below is the code structure currently worth attention, to help tell which files are on the main path and which are compatibility/legacy paths.
### Top Level
- `pyproject.toml`: package configuration, dependencies, CLI entrypoints.
- `README.md`: usage instructions and the current code structure.
- `trace_formatter/`: formatter and release builder.
- `trace_analyzer/`: analyzer.
- `trace_model_meta/`: tokenizer / chat template / model-specific parser resources.
- `tests/`: regression tests.
### `trace_formatter/`
- `trace_formatter/cli.py`: the `format` / `build-release` CLI.
- `trace_formatter/formatting.py`: main formatter logic and main release-builder logic.
- `trace_formatter/raw_parser.py`: provider/model-specific message parsing and `messages -> canonical_prompt`.
- `trace_formatter/sessionization.py`: logical session reconstruction.
- `trace_formatter/time_windows.py`: shard time-window parsing and ready-time truncation.
### `trace_analyzer/`
- `trace_analyzer/cli.py`
  Main CLI. `analyze` follows the current main path:
  `preparation.py -> resume_advanced.py -> reporting.py -> figures.py`
- `trace_analyzer/parser.py`
  Only reads the formatter-generated `*-raw.jsonl` and deserializes it into `TraceRecord`.
  It is not stale: the `parse / features / report / study` subcommands all use it.
- `trace_analyzer/models.py`
  The analyzer's internal dataclass schema boundary.
  It is not stale: `parser.py`, `features.py`, and `report.py` all depend on it.
- `trace_analyzer/preparation.py`
  Streams `*-raw.jsonl` and produces `features.csv`.
  It is the first step of the `analyze` main path, not stale.
- `trace_analyzer/features.py`
  Extracts request-level features from `TraceRecord`.
- `trace_analyzer/resume_advanced.py`
  The data-layer detail analyzer on the current `analyze` main path.
  It now only generates the plotting data under `details/`; it no longer produces PNGs or the bulky legacy transition/context-change/agentic outputs.
- `trace_analyzer/figures.py`
  The single plotting entry point.
  It now reads data exclusively from `details/` and writes the 13 figures under `figures/`.
- `trace_analyzer/reporting.py`
  Generates `summary.json / report.md / analysis_snapshot.json` from `features.csv` and `details_summary.json`.
- `trace_analyzer/report.py`
  The legacy record-oriented report path, mainly serving the in-memory `parse / features / report / study` subcommands.
- `trace_analyzer/study.py`
  The legacy in-memory study path and helper functions, kept for compatibility and mainly used by the `study` subcommand.
  It is no longer the detail-output entry point of the `analyze` main path.
- `trace_analyzer/helpers.py`
  Small utility functions.
- `trace_analyzer/layout.py`
  Path compatibility and existence checks for `details/` and the legacy `advanced/`.
### `trace_model_meta/`
- `trace_model_meta/registry.py`: model-family inference and tokenizer/chat-template resolution.
- `trace_model_meta/Qwen/Qwen3-Coder-480B-A35B-Instruct/qwen3coder_tool_parser.py`:
  Qwen3-Coder's own tool parser; the formatter prefers it for tool-call prompt serialization.
### `tests/`
- `tests/test_ali_trace_pipeline.py`: formatter + analyzer end-to-end regression.
- `tests/test_trace_analyzer.py`: analyzer unit tests and subcommand tests.
## Design Notes
The boundary after the current cleanup is:
- trace formatting work lives in `trace_formatter/`
- trace analysis work lives in `trace_analyzer/`
- the analysis input of `trace_analyzer` is the formatter-produced `*-raw.jsonl`, not provider-native raw logs
- `details/` holds data only
- `figures/` holds images only

pyproject.toml Normal file

@@ -0,0 +1,29 @@
[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"

[project]
name = "ali-trace"
version = "0.1.0"
description = "Two-step trace formatting and analysis pipeline for coding-agent traces."
requires-python = ">=3.11"
dependencies = [
    "jinja2",
    "matplotlib",
    "psutil",
    "tokenizers",
    "tqdm",
    "transformers",
]

[project.scripts]
trace-formatter = "trace_formatter.cli:main"
trace-analyzer = "trace_analyzer.cli:main"

[tool.setuptools.packages.find]
include = ["trace_analyzer", "trace_formatter", "trace_model_meta"]

[dependency-groups]
dev = [
    "pytest>=9.0.3",
]

tests/conftest.py Normal file

@@ -0,0 +1,7 @@
import sys
from pathlib import Path
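# Make the repository root importable so the tests run without installing the package.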
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))

tests/test_ali_trace_pipeline.py Normal file

@@ -0,0 +1,728 @@
import json
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from trace_analyzer.cli import main as analyzer_main
from trace_analyzer.parser import load_records
from trace_formatter.cli import main as formatter_main
from trace_formatter.formatting import build_unified_row, discover_source_files, format_and_sort_trace
def utc_ms(value: str) -> int:
return int(datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=timezone.utc).timestamp() * 1000)
def wall_clock_ms_with_offset(value: str, offset_hours: int) -> int:
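# A wall clock at UTC+offset reads offset_hours ahead of UTC, so subtract to get epoch ms.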
return utc_ms(value) - offset_hours * 60 * 60 * 1000
def make_raw_row(
request_id: str,
ready_ms: int,
tool_role: bool = False,
*,
raw_session_id: str = "sess-1",
user_id: str = "user-1",
messages: list[dict] | None = None,
time_text: str = "2026-04-17 15:00:00.000",
) -> dict:
if messages is None:
messages = [{"role": "user", "content": "hello"}]
if tool_role:
messages.extend(
[
{"role": "assistant", "content": "calling"},
{"role": "tool", "content": "tool-output"},
]
)
return {
"__time__": str(ready_ms // 1000),
"request_id": request_id,
"session_id": raw_session_id,
"request_model": "glm-5",
"time": time_text,
"status_code": "1000",
"status_name": "ok",
"total_cost_time": "250",
"request_params": json.dumps(
{
"header": {
"attributes": {
"x-dashscope-inner-requestreadytime": str(ready_ms),
"user_id": user_id,
}
},
"payload": {
"input": {"messages": messages},
"parameters": {
"tools": [{"type": "function", "function": {"name": "read_file"}}],
},
},
}
),
"response_params": json.dumps(
{
"header": {
"attributes": {
"x-ds-backend-first-request-time": str(ready_ms - 100),
"x-ds-backend-first-response-time": str(ready_ms + 150),
}
},
"payload": {
"output": {
"choices": [
{
"message": {
"role": "assistant",
"content": "done",
}
}
]
},
"usage": {
"input_tokens": 20,
"output_tokens": 5,
"total_tokens": 25,
"output_tokens_details": {"reasoning_tokens": 1},
"prompt_tokens_details": {"cached_tokens": 10},
},
}
}
),
}
def make_qwen_raw_row(
request_id: str,
ready_ms: int,
*,
raw_session_id: str = "qwen-sess-1",
user_id: str = "qwen-user-1",
time_text: str = "2026-04-19 15:00:00.000",
) -> dict:
return {
"__time__": str(ready_ms // 1000),
"request_id": request_id,
"session_id": raw_session_id,
"request_model": "qwen3-coder-plus-2025-09-23",
"time": time_text,
"status_code": "200",
"status_name": "OK",
"total_cost_time": "800",
"request_params": json.dumps(
{
"header": {
"attributes": {
"x-dashscope-inner-requestreadytime": str(ready_ms),
"user_id": user_id,
}
},
"payload": {
"input": {
"messages": [
{"role": "system", "content": "You are Qwen."},
{"role": "user", "content": "Write a function."},
]
},
"parameters": {
"tools": [
{
"type": "function",
"function": {
"name": "run_command",
"parameters": {
"type": "object",
"properties": {"cmd": {"type": "string"}},
},
},
}
]
},
},
},
ensure_ascii=False,
),
"response_params": json.dumps(
{
"header": {
"attributes": {
"x-ds-backend-first-request-time": str(ready_ms - 200),
"x-ds-backend-first-response-time": str(ready_ms + 300),
}
},
"payload": {
"output": {
"choices": [
{
"delta": {"role": "assistant", "content": "Sure."},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 128,
"completion_tokens": 16,
"total_tokens": 144,
"prompt_tokens_details": {"cached_tokens": 32},
},
}
}
},
ensure_ascii=False,
),
}
class AliTracePipelineTest(unittest.TestCase):
def test_build_unified_row_keeps_analysis_fields(self):
ready_ms = utc_ms("2026-04-17 15:00:00.321")
row = build_unified_row(make_raw_row("req-1", ready_ms), source_file="a.jsonl", source_line=3)
self.assertEqual(row["meta"]["request_id"], "req-1")
self.assertEqual(row["sort_time_ms"], ready_ms)
self.assertEqual(row["usage"]["cached_tokens"], 10)
self.assertEqual(row["meta"]["raw_session_id"], "sess-1")
self.assertEqual(row["meta"]["user_id"], "user-1")
self.assertEqual(row["meta"]["model_family"], "glm5")
self.assertEqual(row["meta"]["backend_first_request_time_ms"], ready_ms - 100)
self.assertEqual(row["meta"]["backend_first_response_time_ms"], ready_ms + 150)
self.assertEqual(row["meta"]["total_cost_time_ms"], 250)
self.assertEqual(row["declared_tools"][0]["name"], "read_file")
self.assertEqual(row["message_events"][0]["role"], "user")
self.assertEqual(row["raw_messages"][0]["content"], "hello")
self.assertIn("[gMASK]<sop>", row["canonical_prompt"])
def test_build_unified_row_supports_qwen_raw_trace_defaults(self):
ready_ms = utc_ms("2026-04-19 15:00:00.321")
row = build_unified_row(make_qwen_raw_row("qwen-req-1", ready_ms), source_file="qwen.jsonl", source_line=5)
self.assertEqual(row["meta"]["model_family"], "qwen3-coder")
self.assertEqual(row["usage"]["input_tokens"], 128)
self.assertEqual(row["usage"]["output_tokens"], 16)
self.assertEqual(row["usage"]["cached_tokens"], 32)
self.assertEqual(row["meta"]["backend_first_request_time_ms"], ready_ms - 200)
self.assertEqual(row["meta"]["backend_first_response_time_ms"], ready_ms + 300)
self.assertEqual(row["declared_tools"][0]["name"], "run_command")
self.assertIn("<|im_start|>system", row["canonical_prompt"])
def test_format_and_sort_trace_outputs_time_sorted_unified_rows(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows_by_file = {
"0417-1530-1600.jsonl": [
make_raw_row(
"req-late",
utc_ms("2026-04-17 15:00:03.000"),
time_text="2026-04-17 15:30:30.000",
),
make_raw_row(
"req-middle",
utc_ms("2026-04-17 15:00:02.000"),
tool_role=True,
time_text="2026-04-17 15:30:20.000",
),
],
"0417-1500-1530.jsonl": [
make_raw_row(
"req-first",
utc_ms("2026-04-17 15:00:01.000"),
time_text="2026-04-17 15:00:10.000",
),
],
}
for filename, rows in rows_by_file.items():
with (input_dir / filename).open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
discovered = discover_source_files(input_dir)
self.assertEqual([path.name for path in discovered], sorted(rows_by_file))
stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=512)
self.assertEqual(stats["row_count"], 3)
records = load_records(output_path)
self.assertEqual([record.meta.request_id for record in records], ["req-first", "req-middle", "req-late"])
self.assertEqual(records[1].usage.cached_tokens, 10)
def test_format_and_sort_trace_reconstructs_logical_sessions_from_message_history(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows = [
make_raw_row(
"req-root-a",
utc_ms("2026-04-17 15:00:01.000"),
raw_session_id="111",
user_id="user-a",
messages=[{"role": "user", "content": "hello"}],
),
make_raw_row(
"req-root-b",
utc_ms("2026-04-17 15:00:01.500"),
raw_session_id="111",
user_id="user-b",
messages=[{"role": "user", "content": "hello"}],
),
make_raw_row(
"req-turn-2-a",
utc_ms("2026-04-17 15:00:02.000"),
raw_session_id="999",
user_id="user-a",
messages=[
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
{"role": "user", "content": "continue"},
],
),
]
with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual(formatted_rows[0]["meta"]["turn"], 1)
self.assertEqual(formatted_rows[1]["meta"]["turn"], 1)
self.assertEqual(formatted_rows[2]["meta"]["turn"], 2)
self.assertNotEqual(formatted_rows[0]["meta"]["session_id"], formatted_rows[1]["meta"]["session_id"])
self.assertEqual(formatted_rows[0]["meta"]["session_id"], formatted_rows[2]["meta"]["session_id"])
self.assertEqual(formatted_rows[2]["meta"]["parent_request_id"], "req-root-a")
def test_format_and_sort_trace_does_not_merge_sessions_on_shared_system_prompt_only(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows = [
make_raw_row(
"req-a",
utc_ms("2026-04-17 15:00:01.000"),
user_id="user-a",
messages=[
{"role": "system", "content": "shared system prompt"},
{"role": "user", "content": "question a"},
],
),
make_raw_row(
"req-b",
utc_ms("2026-04-17 15:00:02.000"),
user_id="user-a",
messages=[
{"role": "system", "content": "shared system prompt"},
{"role": "user", "content": "question b"},
],
),
]
with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual([row["meta"]["turn"] for row in formatted_rows], [1, 1])
self.assertNotEqual(formatted_rows[0]["meta"]["session_id"], formatted_rows[1]["meta"]["session_id"])
self.assertEqual([row["meta"]["parent_request_id"] for row in formatted_rows], ["", ""])
def test_format_and_sort_trace_reconstructs_sessions_when_child_drops_parent_suffix(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows = [
make_raw_row(
"req-root",
utc_ms("2026-04-17 15:00:01.000"),
raw_session_id="111",
user_id="user-a",
messages=[
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "draft answer"},
{"role": "user", "content": "continue"},
{"role": "assistant", "content": "long hidden reasoning that may be evicted"},
],
),
make_raw_row(
"req-turn-2",
utc_ms("2026-04-17 15:00:02.000"),
raw_session_id="999",
user_id="user-a",
messages=[
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "draft answer"},
{"role": "user", "content": "continue"},
],
),
]
with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual(formatted_rows[0]["meta"]["turn"], 1)
self.assertEqual(formatted_rows[1]["meta"]["turn"], 2)
self.assertEqual(formatted_rows[0]["meta"]["session_id"], formatted_rows[1]["meta"]["session_id"])
self.assertEqual(formatted_rows[1]["meta"]["parent_request_id"], "req-root")
def test_format_and_sort_trace_reconstructs_sessions_when_last_non_user_messages_change(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows = [
make_raw_row(
"req-root",
utc_ms("2026-04-17 15:00:01.000"),
user_id="user-a",
messages=[
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "draft answer"},
],
),
make_raw_row(
"req-regenerated",
utc_ms("2026-04-17 15:00:02.000"),
user_id="user-a",
messages=[
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "regenerated answer"},
],
),
]
with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual([row["meta"]["turn"] for row in formatted_rows], [1, 2])
self.assertEqual(formatted_rows[0]["meta"]["session_id"], formatted_rows[1]["meta"]["session_id"])
self.assertEqual(formatted_rows[1]["meta"]["parent_request_id"], "req-root")
def test_format_and_sort_trace_truncates_requests_before_inferred_window_start_by_ready_time(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows = [
make_raw_row(
"req-too-early",
utc_ms("2026-04-17 14:59:59.900"),
time_text="2026-04-17 15:00:00.100",
),
make_raw_row(
"req-kept",
utc_ms("2026-04-17 15:00:00.100"),
time_text="2026-04-17 15:00:00.100",
),
]
with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
self.assertEqual(stats["row_count"], 1)
self.assertEqual(stats["truncated_row_count"], 1)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual([row["meta"]["request_id"] for row in formatted_rows], ["req-kept"])
def test_format_and_sort_trace_filters_empty_messages_and_empty_response_params(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
empty_messages_row = make_raw_row("req-empty-messages", utc_ms("2026-04-17 15:00:01.000"))
empty_messages_request = json.loads(empty_messages_row["request_params"])
empty_messages_request["payload"]["input"]["messages"] = []
empty_messages_row["request_params"] = json.dumps(empty_messages_request)
empty_response_row = make_raw_row("req-empty-response", utc_ms("2026-04-17 15:00:02.000"))
empty_response_row["response_params"] = None
kept_row = make_raw_row("req-kept", utc_ms("2026-04-17 15:00:03.000"))
with (input_dir / "0417-1500-1530.jsonl").open("w", encoding="utf-8") as handle:
for row in [empty_messages_row, empty_response_row, kept_row]:
handle.write(json.dumps(row) + "\n")
stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
self.assertEqual(stats["row_count"], 1)
self.assertEqual(stats["filtered_row_count"], 2)
self.assertEqual(stats["filtered_empty_messages_row_count"], 1)
self.assertEqual(stats["filtered_empty_response_params_row_count"], 1)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual([row["meta"]["request_id"] for row in formatted_rows], ["req-kept"])
def test_trace_formatter_cli_formats_one_raw_jsonl_file(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
raw_path = root / "trace-glm5.jsonl"
output_root = root / "outputs" / "formatted"
with raw_path.open("w", encoding="utf-8") as handle:
handle.write(json.dumps(make_raw_row("req-1", utc_ms("2026-04-17 15:00:01.000"))) + "\n")
handle.write(json.dumps(make_raw_row("req-2", utc_ms("2026-04-17 15:00:02.000"))) + "\n")
exit_code = formatter_main(
["format", str(raw_path), "--output-root", str(output_root), "--chunk-bytes", "256"]
)
self.assertEqual(exit_code, 0)
raw_formatted_path = output_root / "trace-glm5-raw.jsonl"
self.assertTrue(raw_formatted_path.exists())
records = load_records(raw_formatted_path)
self.assertEqual([record.meta.request_id for record in records], ["req-1", "req-2"])
def test_trace_formatter_cli_defaults_to_trace_formatted_dir_for_trace_file(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
trace_dir = root / "trace-demo"
trace_dir.mkdir()
raw_path = trace_dir / "0417-1500-1530.jsonl"
with raw_path.open("w", encoding="utf-8") as handle:
handle.write(json.dumps(make_raw_row("req-1", utc_ms("2026-04-17 15:00:01.000"))) + "\n")
exit_code = formatter_main(["format", str(raw_path)])
self.assertEqual(exit_code, 0)
raw_formatted_path = root / "trace-demo-formatted" / "041715-041715-raw.jsonl"
self.assertTrue(raw_formatted_path.exists())
records = load_records(raw_formatted_path)
self.assertEqual([record.meta.request_id for record in records], ["req-1"])
def test_trace_formatter_cli_build_release_from_raw_in_second_stage(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
raw_path = root / "trace-glm5.jsonl"
output_root = root / "outputs" / "formatted"
with raw_path.open("w", encoding="utf-8") as handle:
handle.write(json.dumps(make_raw_row("req-1", utc_ms("2026-04-17 15:00:01.000"))) + "\n")
handle.write(json.dumps(make_raw_row("req-2", utc_ms("2026-04-17 15:00:02.000"))) + "\n")
format_exit_code = formatter_main(
["format", str(raw_path), "--output-root", str(output_root), "--chunk-bytes", "256"]
)
self.assertEqual(format_exit_code, 0)
raw_formatted_path = output_root / "trace-glm5-raw.jsonl"
release_formatted_path = output_root / "trace-glm5.jsonl"
release_exit_code = formatter_main(
["build-release", str(raw_formatted_path), "--jobs", "2", "--block-size", "8"]
)
self.assertEqual(release_exit_code, 0)
self.assertTrue(release_formatted_path.exists())
release_rows = [json.loads(line) for line in release_formatted_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual(len(release_rows), 2)
self.assertEqual(sorted(release_rows[0]), ["chat_id", "hash_ids", "input_length", "output_length", "parent_chat_id", "timestamp", "turn", "type"])
self.assertEqual([row["chat_id"] for row in release_rows], [0, 1])
def test_trace_formatter_cli_can_write_progress_log(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
raw_path = root / "trace-glm5.jsonl"
output_root = root / "formatted"
log_path = root / "tmp" / "format.log"
with raw_path.open("w", encoding="utf-8") as handle:
handle.write(json.dumps(make_raw_row("req-1", utc_ms("2026-04-17 15:00:01.000"))) + "\n")
exit_code = formatter_main(
[
"format",
str(raw_path),
"--output-root",
str(output_root),
"--chunk-bytes",
"256",
"--log-file",
str(log_path),
]
)
self.assertEqual(exit_code, 0)
self.assertTrue(log_path.exists())
self.assertIn("Scan raw trace", log_path.read_text(encoding="utf-8"))
def test_format_and_sort_trace_infers_window_in_ready_time_scale_when_wall_clock_has_timezone_offset(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
rows = [
make_raw_row(
"req-too-early",
wall_clock_ms_with_offset("2026-04-19 14:59:59.900", 8),
time_text="2026-04-19 14:59:59.900",
),
make_raw_row(
"req-kept-start",
wall_clock_ms_with_offset("2026-04-19 15:00:00.100", 8),
time_text="2026-04-19 15:00:00.100",
),
make_raw_row(
"req-kept-end",
wall_clock_ms_with_offset("2026-04-19 16:59:59.900", 8),
time_text="2026-04-19 16:59:59.900",
),
make_raw_row(
"req-too-late",
wall_clock_ms_with_offset("2026-04-19 17:00:00.000", 8),
time_text="2026-04-19 17:00:00.000",
),
]
with (input_dir / "0419-1500-1700.jsonl").open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
self.assertEqual(stats["row_count"], 2)
self.assertEqual(stats["truncated_row_count"], 2)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual([row["meta"]["request_id"] for row in formatted_rows], ["req-kept-start", "req-kept-end"])
def test_format_and_sort_trace_normalizes_mixed_ready_time_scales_before_sorting(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
input_dir = root / "raw"
input_dir.mkdir()
output_path = root / "formatted.jsonl"
standard_row = make_raw_row(
"req-standard",
wall_clock_ms_with_offset("2026-04-19 15:01:00.000", 8),
time_text="2026-04-19 15:01:00.000",
)
anomalous_row = make_raw_row(
"req-anomalous",
utc_ms("2026-04-19 15:00:10.000"),
time_text="2026-04-19 15:00:10.000",
)
with (input_dir / "0419-1500-1700.jsonl").open("w", encoding="utf-8") as handle:
handle.write(json.dumps(standard_row) + "\n")
handle.write(json.dumps(anomalous_row) + "\n")
stats = format_and_sort_trace(input_dir=input_dir, output_path=output_path, chunk_bytes=256)
self.assertEqual(stats["row_count"], 1)
self.assertEqual(stats["truncated_row_count"], 1)
formatted_rows = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual([row["meta"]["request_id"] for row in formatted_rows], ["req-standard"])
def test_trace_analyzer_analyze_writes_reports_and_figures_under_outputs(self):
with tempfile.TemporaryDirectory() as temp_dir:
root = Path(temp_dir)
raw_dir = root / "trace-glm5"
raw_dir.mkdir()
formatted_root = root / "outputs" / "formatted"
analysis_root = root / "outputs" / "analysis"
raw_path = raw_dir / "0417-1500-1530.jsonl"
with raw_path.open("w", encoding="utf-8") as handle:
handle.write(
json.dumps(
make_raw_row(
"req-1",
utc_ms("2026-04-17 15:00:01.000"),
user_id="user-1",
)
)
+ "\n"
)
handle.write(
json.dumps(
make_raw_row(
"req-2",
utc_ms("2026-04-17 15:00:02.000"),
user_id="user-1",
messages=[
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
{"role": "user", "content": "continue"},
],
)
)
+ "\n"
)
formatter_exit_code = formatter_main(["format", str(raw_dir), "--output-root", str(formatted_root)])
self.assertEqual(formatter_exit_code, 0)
formatted_path = formatted_root / "041715-041715-raw.jsonl"
release_exit_code = formatter_main(["build-release", str(formatted_path), "--jobs", "1", "--block-size", "8"])
self.assertEqual(release_exit_code, 0)
analyzer_exit_code = analyzer_main(
[
"analyze",
str(formatted_path),
"--output-root",
str(analysis_root),
"--segment-mode",
"bytes",
"--block-size",
"8",
]
)
self.assertEqual(analyzer_exit_code, 0)
analysis_dir = analysis_root / "041715-041715"
self.assertFalse((analysis_dir / "normalized.jsonl").exists())
self.assertTrue((analysis_dir / "features.csv").exists())
self.assertTrue((analysis_dir / "summary.json").exists())
self.assertTrue((analysis_dir / "report.md").exists())
self.assertTrue((analysis_dir / "details" / "details_summary.json").exists())
self.assertFalse(any((analysis_dir / "details").glob("*.sqlite*")))
self.assertTrue((analysis_dir / "details" / "request_metrics.csv").exists())
self.assertTrue((analysis_dir / "details" / "theoretical_block_reuse_gaps.csv").exists())
self.assertTrue((analysis_dir / "details" / "session_bucket_boundary_miss.csv").exists())
self.assertTrue((analysis_dir / "details" / "theoretical_alive_block_timeline.csv").exists())
self.assertTrue((analysis_dir / "figures" / "01_input_output_length_cdf.png").exists())
self.assertTrue((analysis_dir / "figures" / "02_session_turns_cdf.png").exists())
self.assertTrue((analysis_dir / "figures" / "13_session_cross_bucket_kvcache_miss.png").exists())
self.assertTrue((analysis_dir / "figures" / "manifest.json").exists())
self.assertFalse((root / "figs").exists())
features_mtime_ns = (analysis_dir / "features.csv").stat().st_mtime_ns
details_summary_mtime_ns = (analysis_dir / "details" / "details_summary.json").stat().st_mtime_ns
analyzer_exit_code = analyzer_main(
[
"analyze",
str(formatted_path),
"--output-root",
str(analysis_root),
"--segment-mode",
"bytes",
"--block-size",
"8",
]
)
self.assertEqual(analyzer_exit_code, 0)
self.assertEqual((analysis_dir / "features.csv").stat().st_mtime_ns, features_mtime_ns)
self.assertEqual(
(analysis_dir / "details" / "details_summary.json").stat().st_mtime_ns,
details_summary_mtime_ns,
)

tests/test_trace_analyzer.py Normal file

@@ -0,0 +1,772 @@
import csv
import json
import subprocess
import sys
import tempfile
import unittest
from pathlib import Path
from trace_analyzer.features import compute_features
from trace_analyzer.parser import default_output_dir, infer_analysis_dataset_name, load_records
from trace_analyzer.report import build_summary
from trace_analyzer.study import (
build_alive_block_timeline,
build_input_length_bucket_defs,
compute_theoretical_cache,
parse_input_length_bucket_thresholds,
summarize_cache_reuse_by_input_length_bucket,
summarize_session_bucket_boundary_miss,
)
from trace_formatter.formatting import format_and_sort_trace
def make_record(
request_id,
session_id,
messages,
tools,
usage,
total_cost_time,
status_code="1000",
model="glm-5",
):
return {
"request_id": request_id,
"session_id": session_id,
"request_model": model,
"time": "2026-04-09 09:00:00.000",
"status_code": status_code,
"status_name": "ok",
"total_cost_time": str(total_cost_time),
"request_params": json.dumps(
{
"payload": {
"input": {"messages": messages},
"parameters": {"tools": tools},
}
},
ensure_ascii=False,
),
"response_params": json.dumps(
{
"header": {
"attributes": {
"x-ds-backend-first-request-time": "123",
"x-ds-backend-first-response-time": "456",
}
},
"payload": {
"usage": usage,
}
},
ensure_ascii=False,
),
}
def make_qwen_record(request_id="qwen-1", session_id="sess-qwen-1", total_cost_time=900):
return {
"request_id": request_id,
"session_id": session_id,
"request_model": "qwen3-coder-plus-2025-09-23",
"time": "2026-04-19 15:00:00.000",
"status_code": "200",
"status_name": "OK",
"total_cost_time": str(total_cost_time),
"request_params": json.dumps(
{
"header": {
"attributes": {
"x-dashscope-inner-requestreadytime": "1776582000000",
}
},
"payload": {
"input": {
"messages": [
{"role": "system", "content": "You are Qwen."},
{"role": "user", "content": "List files"},
]
},
"parameters": {
"tools": [
{
"type": "function",
"function": {
"name": "run_command",
"parameters": {
"type": "object",
"properties": {"cmd": {"type": "string"}},
},
},
}
]
},
},
},
ensure_ascii=False,
),
"response_params": json.dumps(
{
"header": {
"attributes": {
"x-ds-backend-first-request-time": "1776581999083",
"x-ds-backend-first-response-time": "1776581999918",
}
},
"payload": {
"output": {
"choices": [
{
"delta": {"role": "assistant", "content": "ls"},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 90,
"completion_tokens": 12,
"total_tokens": 102,
"prompt_tokens_details": {"cached_tokens": 24},
},
}
}
},
ensure_ascii=False,
),
}
class TraceAnalyzerTest(unittest.TestCase):
def write_raw_fixture(self, rows):
temp_dir = tempfile.TemporaryDirectory()
path = Path(temp_dir.name) / "trace.jsonl"
with open(path, "w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
self.addCleanup(temp_dir.cleanup)
return path
def format_fixture(self, rows):
raw_path = self.write_raw_fixture(rows)
formatted_path = raw_path.parent / "trace-raw.jsonl"
format_and_sort_trace(
input_dir=raw_path,
output_path=formatted_path,
chunk_bytes=256,
truncate_to_window=False,
)
return formatted_path
def test_infer_analysis_dataset_name_includes_model_slug_from_formatted_parent(self):
path = Path("trace-qwen3-coder-formatted/041915-041917-raw.jsonl")
self.assertEqual(
infer_analysis_dataset_name(path),
"qwen3-coder-041915-041917",
)
self.assertEqual(
default_output_dir(path),
Path("outputs/analysis/qwen3-coder-041915-041917"),
)
def test_load_records_parses_formatter_output(self):
path = self.format_fixture(
[
make_record(
"req-1",
"sess-1",
[
{
"role": "system",
"content": [
{
"text": "sys",
"cache_control": {"type": "ephemeral"},
}
],
},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "working"},
{"role": "tool", "content": "tool-output"},
],
[{"type": "function", "function": {"name": "read"}}],
{
"input_tokens": 100,
"output_tokens": 20,
"total_tokens": 120,
"output_tokens_details": {"reasoning_tokens": 7},
"prompt_tokens_details": {"cached_tokens": 60},
},
total_cost_time=500,
)
]
)
records = load_records(path)
self.assertEqual(len(records), 1)
record = records[0]
self.assertEqual(record.meta.request_id, "req-1")
self.assertEqual(record.meta.line_number, 1)
self.assertEqual(record.messages[0].has_cache_control, True)
self.assertEqual(record.declared_tools[0].name, "read")
self.assertEqual(record.usage.cached_tokens, 60)
self.assertEqual(record.usage.reasoning_tokens, 7)
self.assertEqual(record.meta.backend_first_request_time_ms, 123)
self.assertEqual(record.meta.backend_first_response_time_ms, 456)
self.assertEqual(record.meta.total_cost_time_ms, 500)
self.assertEqual(record.raw_messages[0]["role"], "system")
self.assertTrue(record.canonical_prompt)
self.assertIn("[gMASK]<sop>", record.canonical_prompt)
self.assertIn("<|assistant|><think>", record.canonical_prompt)
self.assertIn("<tools>", record.canonical_prompt)
def test_load_records_keeps_formatter_normalized_tool_calls_and_textual_tool_content(self):
path = self.format_fixture(
[
make_record(
"req-tool-shape",
"sess-tool-shape",
[
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": "calling tool",
"tool_calls": [
{
"id": "call-1",
"type": "function",
"name": "read_file",
"arguments": "{\"path\": \"/tmp/a.txt\"}",
}
],
},
{
"role": "tool",
"content": [
{
"text": "tool-output",
"cache_control": {"type": "ephemeral"},
}
],
},
],
[{"type": "function", "function": {"name": "read_file"}}],
{
"input_tokens": 50,
"output_tokens": 10,
"total_tokens": 60,
"prompt_tokens_details": {"cached_tokens": 0},
},
total_cost_time=100,
)
]
)
record = load_records(path)[0]
self.assertIn("<tool_call>read_file", record.canonical_prompt)
self.assertIn("<arg_key>path</arg_key>", record.canonical_prompt)
self.assertIn("<tool_response>tool-output</tool_response>", record.canonical_prompt)
def test_load_records_supports_qwen_formatter_output(self):
path = self.format_fixture([make_qwen_record()])
record = load_records(path)[0]
self.assertEqual(record.meta.provider, "qwen3-coder")
self.assertEqual(record.usage.input_tokens, 90)
self.assertEqual(record.usage.output_tokens, 12)
self.assertEqual(record.usage.cached_tokens, 24)
self.assertIn("<|im_start|>system", record.canonical_prompt)
self.assertIn("run_command", record.canonical_prompt)
def test_qwen_formatter_serializes_assistant_tool_calls_with_model_tool_parser_shape(self):
row = make_qwen_record()
request_params = json.loads(row["request_params"])
request_params["payload"]["input"]["messages"].append(
{
"role": "assistant",
"content": "calling tool",
"tool_calls": [
{
"id": "call-1",
"type": "function",
"function": {
"name": "run_command",
"arguments": "{\"cmd\": \"ls -la\"}",
},
}
],
}
)
row["request_params"] = json.dumps(request_params, ensure_ascii=False)
path = self.format_fixture([row])
record = load_records(path)[0]
self.assertIn("<tool_call>", record.canonical_prompt)
self.assertIn("<function=run_command>", record.canonical_prompt)
self.assertIn("<parameter=cmd>", record.canonical_prompt)
self.assertIn("ls -la", record.canonical_prompt)
def test_load_records_rejects_provider_raw_trace(self):
path = self.write_raw_fixture([make_qwen_record()])
with self.assertRaisesRegex(ValueError, r"formatter-generated \*-raw\.jsonl"):
load_records(path)
def test_input_length_bucket_cache_reuse_summary(self):
summary, bucket_rows = summarize_cache_reuse_by_input_length_bucket(
[
{
"input_tokens": 100,
"cached_tokens": 25,
"cache_hit_ratio": 0.25,
"theoretical_prompt_unit_length": 120,
"theoretical_prefix_hit_units": 60,
"theoretical_prefix_hit_ratio": 0.5,
},
{
"input_tokens": 40000,
"cached_tokens": 20000,
"cache_hit_ratio": 0.5,
"theoretical_prompt_unit_length": 42000,
"theoretical_prefix_hit_units": 31500,
"theoretical_prefix_hit_ratio": 0.75,
},
]
)
self.assertEqual(summary["request_count"], 2)
bucket_by_name = {row["bucket"]: row for row in bucket_rows}
self.assertEqual(bucket_by_name["0-32Ki"]["request_count"], 1)
self.assertEqual(bucket_by_name["32-85Ki"]["request_count"], 1)
self.assertAlmostEqual(bucket_by_name["0-32Ki"]["weighted_actual_cache_hit_ratio"], 0.25)
self.assertAlmostEqual(bucket_by_name["32-85Ki"]["weighted_theoretical_cache_hit_ratio"], 0.75)
def test_input_length_bucket_cache_reuse_summary_supports_custom_buckets(self):
summary, bucket_rows = summarize_cache_reuse_by_input_length_bucket(
[
{
"input_tokens": 40,
"cached_tokens": 10,
"cache_hit_ratio": 0.25,
"theoretical_prompt_unit_length": 50,
"theoretical_prefix_hit_units": 20,
"theoretical_prefix_hit_ratio": 0.4,
},
{
"input_tokens": 60,
"cached_tokens": 30,
"cache_hit_ratio": 0.5,
"theoretical_prompt_unit_length": 70,
"theoretical_prefix_hit_units": 35,
"theoretical_prefix_hit_ratio": 0.5,
},
],
bucket_defs=build_input_length_bucket_defs([50]),
)
self.assertEqual(
summary["bucket_definition"]["buckets"],
[
{
"bucket": "0-50",
"input_tokens_min_inclusive": 0,
"input_tokens_max_exclusive": 50,
},
{
"bucket": "50+",
"input_tokens_min_inclusive": 50,
"input_tokens_max_exclusive": None,
},
],
)
bucket_by_name = {row["bucket"]: row for row in bucket_rows}
self.assertEqual(bucket_by_name["0-50"]["request_count"], 1)
self.assertEqual(bucket_by_name["50+"]["request_count"], 1)
def test_input_length_bucket_cache_reuse_summary_tracks_bucketed_theoretical_upper_bound(self):
summary, bucket_rows = summarize_cache_reuse_by_input_length_bucket(
[
{
"input_tokens": 100,
"cached_tokens": 20,
"cache_hit_ratio": 0.2,
"theoretical_prompt_unit_length": 100,
"theoretical_prefix_hit_units": 60,
"theoretical_prefix_hit_ratio": 0.6,
"bucketed_theoretical_prefix_hit_units": 40,
"bucketed_theoretical_prefix_hit_ratio": 0.4,
}
],
bucket_defs=build_input_length_bucket_defs([200]),
)
self.assertEqual(summary["request_count"], 1)
row = bucket_rows[0]
self.assertEqual(row["bucket"], "0-200")
self.assertAlmostEqual(row["weighted_theoretical_cache_hit_ratio"], 0.6)
self.assertAlmostEqual(row["weighted_bucketed_theoretical_cache_hit_ratio"], 0.4)
self.assertAlmostEqual(row["weighted_bucket_boundary_loss_ratio"], 0.2)
self.assertAlmostEqual(row["bucketed_theoretical_reused_request_fraction"], 1.0)
def test_session_bucket_boundary_miss_summary_counts_cross_bucket_shared_prefix_loss(self):
summary, bucket_rows = summarize_session_bucket_boundary_miss(
[
{
"child_bucket": "0-32Ki",
"child_input_tokens": 100,
"shared_prefix_units": 8,
"is_cross_bucket": 1,
},
{
"child_bucket": "0-32Ki",
"child_input_tokens": 120,
"shared_prefix_units": 4,
"is_cross_bucket": 0,
},
],
bucket_defs=build_input_length_bucket_defs(),
)
self.assertEqual(summary["edge_count"], 2)
self.assertEqual(summary["cross_bucket_edge_count"], 1)
self.assertAlmostEqual(summary["cross_bucket_shared_prefix_unit_fraction"], 8 / 12)
bucket_by_name = {row["bucket"]: row for row in bucket_rows}
self.assertEqual(bucket_by_name["0-32Ki"]["edge_count"], 2)
self.assertAlmostEqual(
bucket_by_name["0-32Ki"]["cross_bucket_shared_prefix_unit_fraction"],
8 / 12,
)
def test_build_alive_block_timeline_counts_live_blocks_from_first_seen_to_last_reuse(self):
summary, rows = build_alive_block_timeline(
[
{"first_seen_ms": 10, "span_end_ms": 20},
{"first_seen_ms": 15, "span_end_ms": 15},
]
)
self.assertEqual(summary["peak_alive_blocks"], 2)
self.assertEqual(rows[0]["timestamp_ms"], 10)
self.assertEqual(rows[0]["alive_block_count"], 1)
row_by_ts = {row["timestamp_ms"]: row for row in rows}
self.assertEqual(row_by_ts[15]["alive_block_count"], 2)
self.assertEqual(row_by_ts[16]["alive_block_count"], 1)
self.assertEqual(row_by_ts[21]["alive_block_count"], 0)
def test_parse_input_length_bucket_thresholds_supports_ki_units(self):
self.assertEqual(
parse_input_length_bucket_thresholds("32Ki;85Ki;128Ki"),
[32 * 1024, 85 * 1024, 128 * 1024],
)
def test_compute_features_detects_bursts_and_cache(self):
path = self.format_fixture(
[
make_record(
"req-2",
"sess-2",
[
{"role": "user", "content": "u"},
{"role": "assistant", "content": "a"},
{"role": "tool", "content": "t1"},
{"role": "tool", "content": "t2"},
{"role": "tool", "content": "t3"},
{"role": "assistant", "content": "done"},
],
[{"type": "function", "function": {"name": "exec"}}],
{
"input_tokens": 40000,
"output_tokens": 200,
"total_tokens": 40200,
"prompt_tokens_details": {"cached_tokens": 1000},
},
total_cost_time=9000,
)
]
)
features = compute_features(load_records(path))
feature = features[0]
self.assertEqual(feature.assistant_to_tool_count, 1)
self.assertEqual(feature.tool_to_tool_count, 2)
self.assertEqual(feature.max_consecutive_tool_msgs, 3)
self.assertAlmostEqual(feature.cache_hit_ratio, 0.025)
self.assertIn("cache-cold", feature.pattern_labels)
self.assertIn("long-context-no-cache", feature.pattern_labels)
def test_compute_theoretical_cache_detects_prefix_reuse(self):
rows = [
make_record(
"req-a",
"sess-a",
[{"role": "user", "content": "prefix shared"}],
[],
{
"input_tokens": 10,
"output_tokens": 1,
"total_tokens": 11,
"prompt_tokens_details": {"cached_tokens": 0},
},
total_cost_time=10,
),
make_record(
"req-b",
"sess-a",
[{"role": "user", "content": "prefix shared and more"}],
[],
{
"input_tokens": 20,
"output_tokens": 1,
"total_tokens": 21,
"prompt_tokens_details": {"cached_tokens": 0},
},
total_cost_time=10,
),
]
first = json.loads(rows[0]["request_params"])
second = json.loads(rows[1]["request_params"])
first["header"] = {"attributes": {"x-dashscope-inner-requestreadytime": "1000"}}
second["header"] = {"attributes": {"x-dashscope-inner-requestreadytime": "2000"}}
rows[0]["request_params"] = json.dumps(first, ensure_ascii=False)
rows[1]["request_params"] = json.dumps(second, ensure_ascii=False)
theoretical = compute_theoretical_cache(
load_records(self.format_fixture(rows)),
block_size=8,
segment_mode="bytes",
)
request_rows = {row["request_id"]: row for row in theoretical["request_rows"]}
self.assertEqual(request_rows["req-a"]["theoretical_prefix_hit_ratio"], 0.0)
self.assertGreater(request_rows["req-b"]["theoretical_prefix_hit_ratio"], 0.0)
self.assertTrue(theoretical["reuse_gap_rows"])
reused_blocks = [row for row in theoretical["block_rows"] if row["reuse_count"] > 0]
self.assertTrue(reused_blocks)
self.assertIn("last_reuse_ms", reused_blocks[0])
self.assertIn("span_ms", reused_blocks[0])
self.assertGreaterEqual(reused_blocks[0]["lifetime_ms"], 0)
def test_report_cli_writes_outputs(self):
path = self.format_fixture(
[
make_record(
"req-3",
"sess-3",
[
{"role": "user", "content": "u"},
{"role": "assistant", "content": "a"},
],
[],
{
"input_tokens": 20,
"output_tokens": 5,
"total_tokens": 25,
"prompt_tokens_details": {"cached_tokens": 0},
},
total_cost_time=30,
),
make_record(
"req-4",
"sess-3",
[
{"role": "user", "content": "u"},
{"role": "assistant", "content": "a"},
{"role": "tool", "content": "t"},
{"role": "assistant", "content": "done"},
],
[{"type": "function", "function": {"name": "read"}}],
{
"input_tokens": 200,
"output_tokens": 50,
"total_tokens": 250,
"prompt_tokens_details": {"cached_tokens": 150},
},
total_cost_time=300,
),
]
)
with tempfile.TemporaryDirectory() as temp_dir:
completed = subprocess.run(
[
sys.executable,
"-m",
"trace_analyzer",
"report",
str(path),
"--output-dir",
temp_dir,
"--limit",
"2",
],
cwd=Path(__file__).resolve().parents[1],
check=True,
capture_output=True,
text=True,
)
self.assertIn("report.md", completed.stdout)
summary_path = Path(temp_dir) / "summary.json"
report_path = Path(temp_dir) / "report.md"
features_path = Path(temp_dir) / "features.csv"
self.assertTrue(summary_path.exists())
self.assertTrue(report_path.exists())
self.assertTrue(features_path.exists())
summary = json.loads(summary_path.read_text(encoding="utf-8"))
self.assertIn("tool_patterns", summary)
self.assertIn("cache_patterns", summary)
with open(features_path, "r", encoding="utf-8") as handle:
rows = list(csv.DictReader(handle))
self.assertEqual(len(rows), 2)
def test_study_cli_writes_advanced_outputs(self):
raw_rows = [
make_record(
"req-6",
"sess-6",
[{"role": "user", "content": "hello world"}],
[{"type": "function", "function": {"name": "read"}}],
{
"input_tokens": 40,
"output_tokens": 2,
"total_tokens": 42,
"prompt_tokens_details": {"cached_tokens": 10},
},
total_cost_time=100,
),
make_record(
"req-7",
"sess-6",
[
{"role": "assistant", "content": "a"},
{"role": "tool", "content": "result"},
{"role": "user", "content": "hello world again"},
],
[{"type": "function", "function": {"name": "read"}}],
{
"input_tokens": 60,
"output_tokens": 3,
"total_tokens": 63,
"prompt_tokens_details": {"cached_tokens": 20},
},
total_cost_time=150,
),
]
path = self.format_fixture(raw_rows)
with tempfile.TemporaryDirectory() as temp_dir:
subprocess.run(
[
sys.executable,
"-m",
"trace_analyzer",
"study",
str(path),
"--output-dir",
temp_dir,
"--block-size",
"8",
"--segment-mode",
"bytes",
"--input-length-buckets",
"50",
],
cwd=Path(__file__).resolve().parents[1],
check=True,
capture_output=True,
text=True,
)
self.assertTrue((Path(temp_dir) / "details" / "request_metrics.csv").exists())
self.assertTrue((Path(temp_dir) / "details" / "cdf_lengths.png").exists())
self.assertTrue((Path(temp_dir) / "details" / "tools_catalog.csv").exists())
bucket_summary = json.loads(
(Path(temp_dir) / "details" / "input_length_bucket_cache_reuse_summary.json").read_text(
encoding="utf-8"
)
)
self.assertEqual(
[row["bucket"] for row in bucket_summary["bucket_definition"]["buckets"]],
["0-50", "50+"],
)
def test_study_cli_reuses_existing_base_outputs(self):
raw_rows = [
make_record(
"req-8",
"sess-8",
[{"role": "user", "content": "prefix shared"}],
[{"type": "function", "function": {"name": "read"}}],
{
"input_tokens": 20,
"output_tokens": 2,
"total_tokens": 22,
"prompt_tokens_details": {"cached_tokens": 0},
},
total_cost_time=80,
),
make_record(
"req-9",
"sess-8",
[{"role": "user", "content": "prefix shared again"}],
[{"type": "function", "function": {"name": "read"}}],
{
"input_tokens": 30,
"output_tokens": 3,
"total_tokens": 33,
"prompt_tokens_details": {"cached_tokens": 10},
},
total_cost_time=120,
),
]
path = self.format_fixture(raw_rows)
with tempfile.TemporaryDirectory() as temp_dir:
subprocess.run(
[
sys.executable,
"-m",
"trace_analyzer",
"report",
str(path),
"--output-dir",
temp_dir,
],
cwd=Path(__file__).resolve().parents[1],
check=True,
capture_output=True,
text=True,
)
completed = subprocess.run(
[
sys.executable,
"-m",
"trace_analyzer",
"study",
str(path),
"--output-dir",
temp_dir,
"--block-size",
"8",
"--segment-mode",
"bytes",
],
cwd=Path(__file__).resolve().parents[1],
check=True,
capture_output=True,
text=True,
)
self.assertIn("details_summary.json", completed.stdout)
self.assertTrue((Path(temp_dir) / "details" / "progress.json").exists())
report_text = (Path(temp_dir) / "report.md").read_text(encoding="utf-8")
self.assertIn("Study Outputs", report_text)
def test_build_summary_contains_expected_keys(self):
path = self.format_fixture(
[
make_record(
"req-5",
"sess-5",
[{"role": "user", "content": "u"}],
[],
{
"input_tokens": 10,
"output_tokens": 1,
"total_tokens": 11,
"prompt_tokens_details": {"cached_tokens": 0},
},
total_cost_time=5,
)
]
)
records = load_records(path)
features = compute_features(records)
summary = build_summary(records, features)
self.assertIn("record_count", summary)
self.assertIn("tool_patterns", summary)
self.assertIn("cache_patterns", summary)
self.assertIn("anomalies", summary)
if __name__ == "__main__":
unittest.main()

trace_analyzer/__init__.py Normal file

@@ -0,0 +1,10 @@
"""Trace analysis toolkit for coding-agent request logs."""
def main(argv=None):
from .cli import main as cli_main
return cli_main(argv)
__all__ = ["main"]

trace_analyzer/__main__.py Normal file

@@ -0,0 +1,5 @@
from .cli import main
if __name__ == "__main__":
raise SystemExit(main())

trace_analyzer/cli.py Normal file

@@ -0,0 +1,477 @@
import argparse
import json
from pathlib import Path
from tqdm.auto import tqdm
from .figures import render_figures
from .features import compute_features
from .layout import details_outputs_exist
from .parser import default_output_dir, infer_analysis_dataset_name, load_records, path_looks_like_release_trace
from .preparation import stream_prepare
from .report import write_features, write_normalized, write_report
from .reporting import write_reports
from .resume_advanced import collect_existing_detail_paths, run_advanced_from_existing
from .study import parse_input_length_bucket_thresholds, run_study
def build_parser():
parser = argparse.ArgumentParser(description="Analyze coding-agent trace patterns.")
subparsers = parser.add_subparsers(dest="command", required=True)
analyze_parser = subparsers.add_parser(
"analyze",
help="Run the full analysis workflow from one formatter-generated *-raw.jsonl trace.",
)
analyze_parser.add_argument("input", help="Path to the formatter-generated *-raw.jsonl trace.")
analyze_parser.add_argument(
"--release-input",
default=None,
help="Path to the formatter-generated release .jsonl with hash_ids. Defaults to the sibling file without the `-raw` suffix.",
)
analyze_parser.add_argument(
"--dataset-name",
default=None,
help="Dataset name used for output paths and figure titles. Defaults to the formatted trace stem.",
)
analyze_parser.add_argument(
"--output-dir",
default=None,
help="Explicit analysis output directory. Defaults to outputs/analysis/<dataset>/",
)
analyze_parser.add_argument("--output-root", default="outputs/analysis")
analyze_parser.add_argument(
"--figure-dir",
default=None,
help="Explicit figure directory. Defaults to <output-dir>/figures/.",
)
analyze_parser.add_argument(
"--block-size",
type=int,
default=256,
help="Block size for theoretical cache analysis.",
)
analyze_parser.add_argument(
"--segment-mode",
default="tokenizer",
choices=["bytes", "tokenizer"],
help="How to segment prompts for theoretical cache analysis.",
)
analyze_parser.add_argument(
"--tokenizer-path",
default=None,
help="Local path or model id for tokenizer mode. Defaults to the local resolved tokenizer path.",
)
analyze_parser.add_argument(
"--tokenizer-batch-size",
type=int,
default=64,
help="Batch size used by tokenizer-based theoretical cache analysis.",
)
analyze_parser.add_argument(
"--model-family",
default="auto",
help="Model family for tokenizer/chat-template metadata. Defaults to auto-detect.",
)
analyze_parser.add_argument(
"--model-meta-dir",
default=None,
help="Override the base directory that contains model_meta/<provider>/<model>/.",
)
analyze_parser.add_argument(
"--input-length-buckets",
default=None,
help="Semicolon-separated input-length bucket thresholds in tokens, such as `32768;87040;131072` or `32Ki;85Ki;128Ki`.",
)
parse_parser = subparsers.add_parser("parse", help="Normalize a formatter-generated *-raw.jsonl trace.")
_add_common_args(parse_parser)
parse_parser.add_argument(
"--format",
default="jsonl",
choices=["jsonl", "csv", "parquet"],
help="Normalized output format.",
)
features_parser = subparsers.add_parser("features", help="Extract request-level features.")
_add_common_args(features_parser)
report_parser = subparsers.add_parser("report", help="Generate markdown and json summary reports.")
_add_common_args(report_parser)
report_parser.add_argument(
"--normalized-format",
default="jsonl",
choices=["jsonl", "csv", "parquet"],
help="Also emit normalized records in this format.",
)
study_parser = subparsers.add_parser(
"study",
help="Generate data tables and CDF plots for lengths, cache reuse, and tool timing.",
)
_add_common_args(study_parser)
study_parser.add_argument(
"--normalized-format",
default="jsonl",
choices=["jsonl", "csv", "parquet"],
help="Normalized output format.",
)
study_parser.add_argument(
"--block-size",
type=int,
default=256,
help="Block size for theoretical cache analysis.",
)
study_parser.add_argument(
"--segment-mode",
default="tokenizer",
choices=["bytes", "tokenizer"],
help="How to segment prompts for theoretical cache analysis.",
)
study_parser.add_argument(
"--tokenizer-path",
default=None,
help="Local path or model id for tokenizer mode. Defaults to the local resolved tokenizer path.",
)
study_parser.add_argument(
"--tokenizer-batch-size",
type=int,
default=64,
help="Batch size used by tokenizer-based theoretical cache analysis.",
)
study_parser.add_argument(
"--model-family",
default="auto",
help="Model family for tokenizer/chat-template metadata. Defaults to auto-detect.",
)
study_parser.add_argument(
"--model-meta-dir",
default=None,
help="Override the base directory that contains model_meta/<provider>/<model>/.",
)
study_parser.add_argument(
"--input-length-buckets",
default=None,
help="Semicolon-separated input-length bucket thresholds in tokens, such as `32768;87040;131072` or `32Ki;85Ki;128Ki`.",
)
resume_parser = subparsers.add_parser(
"resume-details",
aliases=["resume-advanced"],
help="Reuse existing source trace (*-raw.jsonl or legacy normalized.jsonl) + features.csv and compute only detailed analysis outputs.",
)
resume_parser.add_argument("input", help="Path to formatter-generated *-raw.jsonl")
resume_parser.add_argument("features", help="Path to existing features.csv")
resume_parser.add_argument(
"--release-input",
default=None,
help="Path to the formatter-generated release .jsonl with hash_ids. Defaults to the sibling file without the `-raw` suffix.",
)
resume_parser.add_argument(
"--output-dir",
required=True,
help="Existing output directory to receive detailed analysis outputs.",
)
resume_parser.add_argument(
"--block-size",
type=int,
default=256,
help="Block size for theoretical cache analysis.",
)
resume_parser.add_argument(
"--segment-mode",
default="tokenizer",
choices=["bytes", "tokenizer"],
help="How to segment prompts for theoretical cache analysis.",
)
resume_parser.add_argument(
"--tokenizer-path",
default=None,
help="Local path or model id for tokenizer mode. Defaults to the local resolved tokenizer path.",
)
resume_parser.add_argument(
"--tokenizer-batch-size",
type=int,
default=64,
help="Batch size used by tokenizer-based theoretical cache analysis.",
)
resume_parser.add_argument(
"--model-family",
default="auto",
help="Model family for tokenizer/chat-template metadata. Defaults to auto-detect.",
)
resume_parser.add_argument(
"--model-meta-dir",
default=None,
help="Override the base directory that contains model_meta/<provider>/<model>/.",
)
resume_parser.add_argument(
"--limit",
type=int,
default=None,
help="Only process the first N source/features rows. Useful for throughput benchmarking.",
)
resume_parser.add_argument(
"--input-length-buckets",
default=None,
help="Semicolon-separated input-length bucket thresholds in tokens, such as `32768;87040;131072` or `32Ki;85Ki;128Ki`.",
)
return parser
def _add_common_args(parser):
parser.add_argument("input", help="Path to the formatter-generated *-raw.jsonl trace.")
parser.add_argument("--limit", type=int, default=None, help="Limit number of input lines.")
parser.add_argument(
"--output-dir",
default=None,
help="Output directory. Defaults to outputs/analysis/<input_stem>/",
)
def resolve_output_dir(input_path, output_dir):
return Path(output_dir) if output_dir else default_output_dir(input_path)
def _normalize_dataset_name(name: str) -> str:
text = str(name)
return text[:-4] if text.endswith("-raw") else text
def _resolve_analysis_output_dir(args):
dataset_name = args.dataset_name or _normalize_dataset_name(infer_analysis_dataset_name(args.input))
output_dir = Path(args.output_dir) if args.output_dir else Path(args.output_root) / dataset_name
figure_dir = Path(args.figure_dir) if args.figure_dir else output_dir / "figures"
return dataset_name, output_dir, figure_dir
def _resolve_release_input_path(raw_input: str, release_input: str | None) -> Path:
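# Example (matches the README naming convention):
#   trace-glm5-formatted/041715-041717-raw.jsonl -> trace-glm5-formatted/041715-041717.jsonl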
if release_input:
return Path(release_input)
raw_path = Path(raw_input)
name = raw_path.name
if name.endswith("-raw.jsonl"):
candidate = raw_path.with_name(name[:-len("-raw.jsonl")] + ".jsonl")
else:
raise ValueError("Expected a formatter-generated *-raw.jsonl input, or pass --release-input explicitly.")
return candidate
def _resolve_existing_release_input_path(raw_input: str, release_input: str | None) -> Path | None:
candidate = _resolve_release_input_path(raw_input, release_input)
if path_looks_like_release_trace(candidate):
return candidate
return None
def _existing_base_outputs(output_dir):
features = output_dir / "features.csv"
report = output_dir / "report.md"
if features.exists():
return {
"features": features,
"report": report if report.exists() else None,
}
return None
def _existing_detail_outputs(output_dir):
if not details_outputs_exist(output_dir):
return None
return collect_existing_detail_paths(output_dir)
def _stage_message(progress, step: int, total_steps: int, message: str) -> None:
tqdm.write(f"Stage {step}/{total_steps}: {message}")
progress.update(1)
progress.set_postfix(current=message)
def main(argv=None):
parser = build_parser()
args = parser.parse_args(argv)
if args.command == "analyze":
dataset_name, output_dir, figure_dir = _resolve_analysis_output_dir(args)
input_length_bucket_thresholds = parse_input_length_bucket_thresholds(args.input_length_buckets)
release_input_path = _resolve_existing_release_input_path(args.input, args.release_input)
if release_input_path is None:
raise FileNotFoundError(
f"Release trace not found for raw trace {args.input}. "
"Run `python -m trace_formatter build-release <raw-trace>` first, or pass --release-input."
)
total_steps = 4
progress = tqdm(
total=total_steps,
desc="Analyze trace",
unit="stage",
dynamic_ncols=True,
)
try:
prepare_result = None
reusable_base = _existing_base_outputs(output_dir)
if reusable_base:
_stage_message(progress, 1, total_steps, "reuse existing features.csv")
prepare_result = {
"features_path": str(reusable_base["features"]),
"reused": True,
}
else:
_stage_message(progress, 1, total_steps, "prepare features.csv")
prepare_result = stream_prepare(args.input, output_dir, show_progress=True)
reusable_details = _existing_detail_outputs(output_dir)
if reusable_details:
_stage_message(progress, 2, total_steps, "reuse existing details/")
advanced_paths = reusable_details
else:
_stage_message(
progress,
2,
total_steps,
"detailed analysis: request metrics, tool/session stats, kvcache stats",
)
advanced_paths = run_advanced_from_existing(
args.input,
release_input_path,
prepare_result["features_path"],
output_dir,
input_length_bucket_thresholds=input_length_bucket_thresholds,
show_progress=True,
)
_stage_message(progress, 3, total_steps, "reporting: summary.json, report.md, analysis_snapshot.json")
report_result = write_reports(
features_path=prepare_result["features_path"],
output_dir=output_dir,
pipeline_summary={
"dataset_name": dataset_name,
"formatted_path": str(Path(args.input)),
"release_path": str(release_input_path),
**{key: str(value) for key, value in advanced_paths.items()},
},
)
_stage_message(
progress,
4,
total_steps,
"figures: 13 approved request/session/tool/kvcache plots",
)
figure_result = render_figures(
analysis_dir=output_dir,
fig_dir=figure_dir,
dataset_title=dataset_name,
show_progress=True,
)
finally:
progress.close()
print(
json.dumps(
{
"dataset_name": dataset_name,
"formatted_path": str(Path(args.input)),
"output_dir": str(output_dir),
"prepare": prepare_result,
"details": {key: str(value) for key, value in advanced_paths.items()},
"report": report_result,
"figures": figure_result,
"release_path": str(release_input_path),
},
ensure_ascii=False,
indent=2,
)
)
return 0
if args.command in {"resume-details", "resume-advanced"}:
input_length_bucket_thresholds = parse_input_length_bucket_thresholds(args.input_length_buckets)
release_input_path = _resolve_existing_release_input_path(args.input, args.release_input)
if release_input_path is None:
raise FileNotFoundError(
f"Release trace not found for raw trace {args.input}. "
"Run `python -m trace_formatter build-release <raw-trace>` first, or pass --release-input."
)
paths = run_advanced_from_existing(
args.input,
release_input_path,
args.features,
args.output_dir,
input_length_bucket_thresholds=input_length_bucket_thresholds,
show_progress=True,
limit=args.limit,
)
for path in paths.values():
print(path)
return 0
output_dir = resolve_output_dir(args.input, args.output_dir)
if args.command == "study" and args.limit is None:
input_length_bucket_thresholds = parse_input_length_bucket_thresholds(args.input_length_buckets)
reusable = _existing_base_outputs(output_dir)
if reusable:
release_input_path = _resolve_existing_release_input_path(args.input, None)
if release_input_path is not None:
paths = _existing_detail_outputs(output_dir)
if paths is None:
paths = run_advanced_from_existing(
args.input,
release_input_path,
reusable["features"],
output_dir,
input_length_bucket_thresholds=input_length_bucket_thresholds,
show_progress=True,
)
for path in paths.values():
print(path)
return 0
show_progress = args.command == "study"
records = load_records(
args.input,
limit=args.limit,
show_progress=show_progress,
progress_desc="Load trace",
)
if args.command == "parse":
path = write_normalized(records, output_dir, output_format=args.format)
print(path)
return 0
features = compute_features(records)
if args.command == "features":
path = write_features(features, output_dir)
print(path)
return 0
if args.command == "study":
input_length_bucket_thresholds = parse_input_length_bucket_thresholds(args.input_length_buckets)
paths = run_study(
records,
output_dir,
normalized_format=args.normalized_format,
source_path=args.input,
block_size=args.block_size,
segment_mode=args.segment_mode,
tokenizer_path=args.tokenizer_path,
model_family=args.model_family,
model_meta_dir=args.model_meta_dir,
input_length_bucket_thresholds=input_length_bucket_thresholds,
show_progress=show_progress,
tokenizer_batch_size=args.tokenizer_batch_size,
)
for path in paths.values():
print(path)
return 0
normalized_path = write_normalized(records, output_dir, output_format=args.normalized_format)
features_path = write_features(features, output_dir)
summary_path, report_path = write_report(records, features, output_dir)
print(normalized_path)
print(features_path)
print(summary_path)
print(report_path)
return 0
if __name__ == "__main__":
raise SystemExit(main())

trace_analyzer/features.py (new file, 117 lines)
@@ -0,0 +1,117 @@
from collections import Counter
from dataclasses import asdict
from .helpers import percentile, safe_div
from .models import TraceFeatures
LONG_CONTEXT_THRESHOLD = 32000
HIGH_CACHE_THRESHOLD = 0.8
TOOL_BURST_THRESHOLD = 4
TOOL_LOOP_THRESHOLD = 3
def _transition_count(roles, left, right):
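# Counts adjacent left->right role pairs, e.g.
# _transition_count(["user", "assistant", "tool", "tool"], "tool", "tool") == 1.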
return sum(1 for current, nxt in zip(roles, roles[1:]) if current == left and nxt == right)
def _tool_bursts(roles):
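# Run lengths of consecutive "tool" messages, e.g.
# _tool_bursts(["assistant", "tool", "tool", "user", "tool"]) == [2, 1].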
bursts = []
current = 0
for role in roles:
if role == "tool":
current += 1
elif current:
bursts.append(current)
current = 0
if current:
bursts.append(current)
return bursts
def compute_features(records):
features = []
for record in records:
role_counts = Counter(record.role_sequence)
bursts = _tool_bursts(record.role_sequence)
input_tokens = record.usage.input_tokens
output_tokens = record.usage.output_tokens
cached_tokens = record.usage.cached_tokens
latency_ms = record.meta.total_cost_time_ms
cache_hit_ratio = safe_div(cached_tokens, input_tokens)
tool_to_tool_count = _transition_count(record.role_sequence, "tool", "tool")
feature = TraceFeatures(
request_id=record.meta.request_id,
session_id=record.meta.session_id,
model=record.meta.request_model,
status_code=record.meta.status_code,
time=record.meta.time,
message_count=len(record.messages),
conversation_depth=len(record.messages),
declared_tool_count=len(record.declared_tools),
assistant_msg_count=role_counts.get("assistant", 0),
tool_msg_count=role_counts.get("tool", 0),
user_msg_count=role_counts.get("user", 0),
system_msg_count=role_counts.get("system", 0),
assistant_to_tool_count=_transition_count(record.role_sequence, "assistant", "tool"),
tool_to_assistant_count=_transition_count(record.role_sequence, "tool", "assistant"),
tool_to_tool_count=tool_to_tool_count,
assistant_to_user_count=_transition_count(record.role_sequence, "assistant", "user"),
user_to_assistant_count=_transition_count(record.role_sequence, "user", "assistant"),
max_consecutive_tool_msgs=max(bursts) if bursts else 0,
avg_tool_burst_len=safe_div(sum(bursts), len(bursts)) if bursts else 0.0,
has_tool_loop=1 if tool_to_tool_count > 0 else 0,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=record.usage.total_tokens,
reasoning_tokens=record.usage.reasoning_tokens,
cached_tokens=cached_tokens,
cache_hit_ratio=cache_hit_ratio,
uncached_prompt_tokens=max(input_tokens - cached_tokens, 0),
output_input_ratio=safe_div(output_tokens, input_tokens),
latency_ms=latency_ms,
ms_per_input_token=safe_div(latency_ms, input_tokens),
ms_per_output_token=safe_div(latency_ms, output_tokens),
long_context=1 if input_tokens >= LONG_CONTEXT_THRESHOLD else 0,
high_cache=1 if cache_hit_ratio >= HIGH_CACHE_THRESHOLD else 0,
tool_burst_alert=1 if (max(bursts) if bursts else 0) >= TOOL_BURST_THRESHOLD else 0,
tool_loop_alert=1 if tool_to_tool_count >= TOOL_LOOP_THRESHOLD else 0,
)
feature.pattern_labels = base_pattern_labels(feature)
features.append(feature)
apply_batch_thresholds(features)
return features
def base_pattern_labels(feature):
labels = []
if feature.tool_msg_count == 0 and feature.declared_tool_count == 0:
labels.append("single-shot")
if feature.tool_msg_count > 0 and feature.tool_msg_count >= feature.assistant_msg_count:
labels.append("tool-heavy")
if feature.max_consecutive_tool_msgs >= TOOL_BURST_THRESHOLD:
labels.append("tool-burst")
if feature.cache_hit_ratio >= HIGH_CACHE_THRESHOLD:
labels.append("cache-efficient")
if feature.cache_hit_ratio <= 0.1:
labels.append("cache-cold")
return labels
def apply_batch_thresholds(features):
if not features:
return
latency_p90 = percentile([feature.latency_ms for feature in features], 0.9)
for feature in features:
feature.slow_request = 1 if feature.latency_ms >= latency_p90 else 0
if feature.slow_request and feature.high_cache:
feature.pattern_labels.append("slow-despite-cache")
if feature.input_tokens >= LONG_CONTEXT_THRESHOLD and feature.cache_hit_ratio <= 0.1:
feature.pattern_labels.append("long-context-no-cache")
feature.pattern_labels = sorted(set(feature.pattern_labels))
def feature_to_row(feature):
row = asdict(feature)
row["pattern_labels"] = ";".join(feature.pattern_labels)
return row

trace_analyzer/figures.py (new file, 809 lines)
@@ -0,0 +1,809 @@
from __future__ import annotations
import csv
import json
from collections import Counter, defaultdict
from pathlib import Path
import matplotlib
import numpy as np
from tqdm.auto import tqdm
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator
from trace_analyzer.helpers import percentile, safe_float, safe_int
from trace_analyzer.layout import resolve_details_dir
PALETTE = {
"blue": "#2B6CB0",
"orange": "#DD6B20",
"green": "#2F855A",
"red": "#C53030",
"purple": "#6B46C1",
"gray": "#4A5568",
"teal": "#0F766E",
"gold": "#B7791F",
"pink": "#D53F8C",
"grid": "#CBD5E0",
}
FIGURE_STEMS = [
"01_input_output_length_cdf",
"02_session_turns_cdf",
"03_request_length_by_turn",
"04_request_trigger_role_pie",
"05_tool_call_output_length_cdf",
"06_tool_call_latency_cdf",
"07_consecutive_tool_call_count_cdf",
"08_tool_call_added_context_cdf",
"09_kvcache_block_reuse_time_cdf",
"10_kvcache_block_lifecycle_cdf",
"11_alive_kvcache_blocks_timeline",
"12_bucket_kvcache_reuse_ratio",
"13_session_cross_bucket_kvcache_miss",
]
def _ensure_dir(path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
def _clear_dir_files(path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
for child in path.iterdir():
if child.is_file():
child.unlink()
def _apply_style() -> None:
plt.rcParams.update(
{
"figure.figsize": (8.0, 4.8),
"figure.dpi": 600,
"savefig.dpi": 600,
"font.family": "DejaVu Serif",
"font.size": 11,
"axes.titlesize": 13,
"axes.labelsize": 12,
"axes.linewidth": 0.9,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
"legend.fontsize": 10,
"legend.frameon": False,
}
)
def _finalize_axes(ax: plt.Axes, *, grid_axis: str = "y") -> None:
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.grid(axis=grid_axis, color=PALETTE["grid"], alpha=0.5, linewidth=0.8)
ax.tick_params(axis="both", which="major", length=4, width=0.8)
def _save(fig: plt.Figure, fig_dir: Path, stem: str) -> None:
fig.savefig(fig_dir / f"{stem}.png", bbox_inches="tight")
plt.close(fig)
def _read_json(path: Path) -> dict:
return json.loads(path.read_text(encoding="utf-8"))
def _read_csv_rows(path: Path) -> list[dict]:
with path.open("r", encoding="utf-8") as handle:
return list(csv.DictReader(handle))
def _load_request_metrics(path: Path) -> list[dict]:
rows = []
with path.open("r", encoding="utf-8") as handle:
for row in csv.DictReader(handle):
rows.append(
{
"request_id": row.get("request_id", ""),
"session_id": row.get("session_id", ""),
"turn": safe_int(row.get("turn")),
"trigger_group": row.get("trigger_group", "") or "unknown",
"input_tokens": safe_int(row.get("input_tokens")),
"output_tokens": safe_int(row.get("output_tokens")),
"request_ready_time_ms": safe_int(row.get("request_ready_time_ms")),
"request_end_time_ms": safe_int(row.get("request_end_time_ms")),
"input_length_bucket": row.get("input_length_bucket", ""),
"theoretical_prompt_unit_length": safe_int(row.get("theoretical_prompt_unit_length")),
"theoretical_prefix_hit_blocks": safe_int(row.get("theoretical_prefix_hit_blocks")),
"bucketed_theoretical_prefix_hit_blocks": safe_int(
row.get("bucketed_theoretical_prefix_hit_blocks")
),
}
)
return rows
def _sort_request_rows(rows: list[dict]) -> list[dict]:
return sorted(
rows,
key=lambda row: (
row["request_ready_time_ms"],
row["turn"],
row["request_id"],
),
)
def _build_session_sequences(request_rows: list[dict]) -> dict[str, list[dict]]:
sessions = defaultdict(list)
for row in request_rows:
sessions[row["session_id"]].append(row)
for session_rows in sessions.values():
session_rows.sort(
key=lambda row: (
row["request_ready_time_ms"],
row["turn"],
row["request_id"],
)
)
return sessions
def _build_tool_round_edges(session_rows_by_id: dict[str, list[dict]]) -> list[dict]:
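# Pairs each tool-triggered request with its predecessor in the same session:
# tool_call_latency_ms = next request_ready - previous request_end, and
# added_context_tokens = next input_tokens - previous output_tokens; both are
# clamped at zero so out-of-order timestamps cannot produce negative values.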
edges = []
for session_id, session_rows in session_rows_by_id.items():
for previous, current in zip(session_rows, session_rows[1:]):
if current["trigger_group"] != "tool":
continue
edges.append(
{
"session_id": session_id,
"prev_request_id": previous["request_id"],
"next_request_id": current["request_id"],
"tool_call_output_tokens": previous["output_tokens"],
"tool_call_latency_ms": max(
current["request_ready_time_ms"] - previous["request_end_time_ms"],
0,
),
"added_context_tokens": max(
current["input_tokens"] - previous["output_tokens"],
0,
),
}
)
return edges
def _ecdf(values: list[float]) -> tuple[np.ndarray, np.ndarray]:
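# Unweighted empirical CDF, e.g. _ecdf([1, 1, 2]) -> xs=[1, 2], ys=[2/3, 1.0].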
arr = np.asarray([value for value in values if value is not None], dtype=float)
arr = np.sort(arr)
if arr.size == 0:
return arr, arr
xs, counts = np.unique(arr, return_counts=True)
ys = np.cumsum(counts, dtype=float) / arr.size
return xs, ys
def _ecdf_from_weighted_rows(rows: list[dict], *, value_key: str, count_key: str) -> tuple[np.ndarray, np.ndarray]:
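# Weighted ECDF over (value, count) rows. With illustrative keys, rows like
# [{"gap": 10, "n": 3}, {"gap": 50, "n": 1}] and value_key="gap", count_key="n"
# yield xs=[10, 50], ys=[0.75, 1.0]; rows whose count is <= 0 are dropped.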
weighted = sorted(
(
safe_float(row[value_key]),
safe_int(row[count_key]),
)
for row in rows
if safe_int(row.get(count_key)) > 0
)
total = sum(count for _, count in weighted)
if total <= 0:
return np.asarray([]), np.asarray([])
xs = np.asarray([value for value, _ in weighted], dtype=float)
ys = np.asarray(np.cumsum([count for _, count in weighted], dtype=float) / total, dtype=float)
return xs, ys
def _stats(values: list[float], labels: tuple[str, ...]) -> dict[str, float]:
cleaned = [value for value in values if value is not None]
if not cleaned:
return {label: 0.0 for label in labels}
mapping = {"mean": float(np.mean(cleaned))}
for label in labels:
if label == "mean":
continue
mapping[label] = percentile(cleaned, int(label[1:]) / 100)
return mapping
def _weighted_stats(rows: list[dict], *, value_key: str, count_key: str, labels: tuple[str, ...]) -> dict[str, float]:
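# Percentile labels resolve against the weighted multiset: with counts
# (10 x3, 50 x1) the mean is 20.0 and "p50" targets the 2nd of 4 units, so 10.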
weighted = sorted(
(
safe_float(row[value_key]),
safe_int(row[count_key]),
)
for row in rows
if safe_int(row.get(count_key)) > 0
)
total = sum(count for _, count in weighted)
if total <= 0:
return {label: 0.0 for label in labels}
result = {}
weighted_sum = sum(value * count for value, count in weighted)
result["mean"] = weighted_sum / total
for label in labels:
if label == "mean":
continue
target = int(label[1:]) / 100 * total
seen = 0
value_at_target = weighted[-1][0]
for value, count in weighted:
seen += count
if seen >= target:
value_at_target = value
break
result[label] = value_at_target
return result
def _format_stat_text(title: str, stats: dict[str, float], labels: tuple[str, ...]) -> str:
parts = [title]
for label in labels:
value = stats.get(label, 0.0)
if abs(value - round(value)) < 1e-6:
parts.append(f"{label}={int(round(value))}")
else:
parts.append(f"{label}={value:.2f}")
return " ".join(parts)
def _add_footer(fig: plt.Figure, lines: list[str]) -> None:
fig.subplots_adjust(bottom=0.24)
y = 0.06
for line in lines:
fig.text(0.5, y, line, ha="center", va="bottom", fontsize=9.5)
y -= 0.035
def _plot_two_series_cdf_with_zoom(
fig_dir: Path,
*,
stem: str,
title: str,
xlabel: str,
first_label: str,
first_values: list[float],
first_color: str,
second_label: str,
second_values: list[float],
second_color: str,
zoom_quantile: float,
stats_labels: tuple[str, ...],
) -> None:
first_xs, first_ys = _ecdf(first_values)
second_xs, second_ys = _ecdf(second_values)
zoom_max = max(
percentile(first_values, zoom_quantile) if first_values else 0.0,
percentile(second_values, zoom_quantile) if second_values else 0.0,
)
fig, axes = plt.subplots(1, 2, figsize=(12.4, 4.8))
for ax, subtitle in zip(axes, ["Full Range", f"Zoom: <= p{int(zoom_quantile * 100)}"]):
ax.step(first_xs, first_ys, where="post", linewidth=2.2, color=first_color, label=first_label)
ax.step(second_xs, second_ys, where="post", linewidth=2.2, color=second_color, label=second_label)
ax.set_title(subtitle)
ax.set_xlabel(xlabel)
ax.set_ylabel("CDF")
_finalize_axes(ax)
axes[1].set_xlim(0, zoom_max if zoom_max > 0 else 1)
axes[0].legend(loc="lower right")
fig.suptitle(title, y=0.98)
_add_footer(
fig,
[
_format_stat_text(first_label, _stats(first_values, stats_labels), stats_labels),
_format_stat_text(second_label, _stats(second_values, stats_labels), stats_labels),
],
)
_save(fig, fig_dir, stem)
def _plot_single_cdf(
fig_dir: Path,
*,
stem: str,
title: str,
xlabel: str,
label: str,
values: list[float] | None = None,
weighted_rows: list[dict] | None = None,
weighted_value_key: str | None = None,
weighted_count_key: str | None = None,
color: str = PALETTE["blue"],
zoom_quantile: float | None = None,
stats_labels: tuple[str, ...] = ("mean", "p50", "p90", "p95", "p99"),
) -> None:
values = values or []
weighted_rows = weighted_rows or []
if weighted_rows:
xs, ys = _ecdf_from_weighted_rows(
weighted_rows,
value_key=weighted_value_key,
count_key=weighted_count_key,
)
stats = _weighted_stats(
weighted_rows,
value_key=weighted_value_key,
count_key=weighted_count_key,
labels=stats_labels,
)
zoom_max = stats.get(f"p{int(zoom_quantile * 100)}", 0.0) if zoom_quantile is not None else 0.0
else:
xs, ys = _ecdf(values)
stats = _stats(values, stats_labels)
zoom_max = percentile(values, zoom_quantile) if zoom_quantile is not None and values else 0.0
panel_count = 2 if zoom_quantile is not None else 1
fig, axes = plt.subplots(1, panel_count, figsize=(12.4, 4.8) if panel_count == 2 else (8.2, 4.8))
if panel_count == 1:
axes = [axes]
axes[0].step(xs, ys, where="post", linewidth=2.2, color=color)
axes[0].set_title("Full Range")
axes[0].set_xlabel(xlabel)
axes[0].set_ylabel("CDF")
_finalize_axes(axes[0])
if panel_count == 2:
axes[1].step(xs, ys, where="post", linewidth=2.2, color=color)
axes[1].set_title(f"Zoom: <= p{int(zoom_quantile * 100)}")
axes[1].set_xlabel(xlabel)
axes[1].set_ylabel("CDF")
axes[1].set_xlim(0, zoom_max if zoom_max > 0 else 1)
_finalize_axes(axes[1])
fig.suptitle(title, y=0.98)
_add_footer(fig, [_format_stat_text(label, stats, stats_labels)])
_save(fig, fig_dir, stem)
def _plot_session_turns_cdf(fig_dir: Path, request_rows: list[dict]) -> None:
session_sizes = Counter(row["session_id"] for row in request_rows)
values = list(session_sizes.values())
xs, ys = _ecdf(values)
max_turn = max(values) if values else 1
zoom_max = max(int(np.ceil(max_turn * 0.10)), 1)
fig, axes = plt.subplots(1, 2, figsize=(12.4, 4.8))
for ax, subtitle in zip(axes, ["Full Range", f"Zoom: <= {zoom_max} turns (first 10% of max turn)"]):
ax.step(xs, ys, where="post", linewidth=2.2, color=PALETTE["green"])
ax.set_title(subtitle)
ax.set_xlabel("Turns per session")
ax.set_ylabel("CDF")
_finalize_axes(ax)
axes[1].set_xlim(0.5, zoom_max + 0.5)
fig.suptitle("Session Turns CDF", y=0.98)
_add_footer(
fig,
[
_format_stat_text(
"Session turns",
_stats(values, ("mean", "p50", "p90", "p95", "p99")),
("mean", "p50", "p90", "p95", "p99"),
)
],
)
_save(fig, fig_dir, "02_session_turns_cdf")
def _plot_request_length_by_turn(fig_dir: Path, request_rows: list[dict]) -> None:
values_by_turn = defaultdict(list)
for row in request_rows:
if row["turn"] > 0:
values_by_turn[row["turn"]].append(row["input_tokens"])
turns = sorted(values_by_turn)
mean_values = [float(np.mean(values_by_turn[turn])) for turn in turns]
p50_values = [percentile(values_by_turn[turn], 0.50) for turn in turns]
p99_values = [percentile(values_by_turn[turn], 0.99) for turn in turns]
fig, ax = plt.subplots(figsize=(8.6, 4.8))
ax.plot(turns, mean_values, color=PALETTE["blue"], linewidth=2.0, label="mean")
ax.plot(turns, p50_values, color=PALETTE["orange"], linewidth=2.0, label="p50")
ax.plot(turns, p99_values, color=PALETTE["red"], linewidth=2.0, label="p99")
ax.set_title("Request Input Length by Turn")
ax.set_xlabel("Turn")
ax.set_ylabel("Input tokens")
ax.legend(loc="upper left")
ax.xaxis.set_major_locator(MaxNLocator(nbins=12, integer=True))
plt.setp(ax.get_xticklabels(), rotation=20, ha="right")
_finalize_axes(ax)
fig.tight_layout()
_save(fig, fig_dir, "03_request_length_by_turn")
def _plot_trigger_role_pie(fig_dir: Path, request_rows: list[dict]) -> None:
label_order = ["user", "tool", "assistant"]
color_by_label = {
"user": PALETTE["orange"],
"tool": PALETTE["green"],
"assistant": PALETTE["blue"],
}
counts = Counter(row["trigger_group"] for row in request_rows)
labels = [label for label in label_order if counts[label] > 0]
values = [counts[label] for label in labels]
colors = [color_by_label[label] for label in labels]
def _autopct(pct):
total = sum(values)
count = int(round(pct * total / 100.0))
return f"{pct:.1f}%\n({count})"
fig, ax = plt.subplots(figsize=(9.0, 5.8))
wedges, _texts, autotexts = ax.pie(
values,
autopct=_autopct,
startangle=90,
colors=colors,
wedgeprops={"linewidth": 0.8, "edgecolor": "white"},
textprops={"fontsize": 9},
)
for autotext in autotexts:
autotext.set_fontsize(8.5)
ax.legend(
wedges,
[f"{label} ({counts[label]:,})" for label in labels],
title="Trigger source",
loc="center left",
bbox_to_anchor=(1.02, 0.5),
)
ax.set_title("Request Trigger Role Proportion")
fig.tight_layout()
_save(fig, fig_dir, "04_request_trigger_role_pie")
def _plot_session_gap_cdf(fig_dir: Path, session_rows_by_id: dict[str, list[dict]]) -> None:
ready_gaps = []
end_ready_gaps = []
for session_rows in session_rows_by_id.values():
for previous, current in zip(session_rows, session_rows[1:]):
ready_gaps.append(max(current["request_ready_time_ms"] - previous["request_ready_time_ms"], 0))
end_ready_gaps.append(max(current["request_ready_time_ms"] - previous["request_end_time_ms"], 0))
_plot_two_series_cdf_with_zoom(
fig_dir,
stem="session_inter_request_gap_cdf",
title="Session Inter-Request Gap CDF",
xlabel="Milliseconds",
first_label="ready->ready",
first_values=ready_gaps,
first_color=PALETTE["purple"],
second_label="end->ready",
second_values=end_ready_gaps,
second_color=PALETTE["gray"],
zoom_quantile=0.90,
stats_labels=("mean", "p50", "p90", "p95", "p99"),
)
def _plot_consecutive_tool_calls_cdf(fig_dir: Path, session_rows_by_id: dict[str, list[dict]]) -> None:
values = []
for session_rows in session_rows_by_id.values():
for index, row in enumerate(session_rows):
if row["trigger_group"] != "user":
continue
count = 0
next_index = index + 1
while next_index < len(session_rows) and session_rows[next_index]["trigger_group"] == "tool":
count += 1
next_index += 1
values.append(count)
_plot_single_cdf(
fig_dir,
stem="07_consecutive_tool_call_count_cdf",
title="Consecutive Tool Calls After One User Input",
xlabel="Consecutive tool-triggered rounds",
label="Consecutive tool calls",
values=values,
color=PALETTE["green"],
)
def _plot_alive_kvcache_timeline(fig_dir: Path, timeline_rows: list[dict]) -> None:
fig, ax = plt.subplots(figsize=(10.2, 4.8))
if timeline_rows:
base_ts = safe_int(timeline_rows[0]["timestamp_ms"])
else:
base_ts = 0
xs = [
max(safe_int(row["timestamp_ms"]) - base_ts, 0) / 60000.0
for row in timeline_rows
]
ys = [safe_int(row["alive_block_count"]) for row in timeline_rows]
ax.step(xs, ys, where="post", color=PALETTE["purple"], linewidth=1.8)
ax.set_title("Alive KV-Cache Blocks Over Time")
ax.set_xlabel("Elapsed time (minutes)")
ax.set_ylabel("Alive block count")
ax.xaxis.set_major_locator(MultipleLocator(10))
plt.setp(ax.get_xticklabels(), rotation=20, ha="right")
_finalize_axes(ax)
fig.tight_layout()
_save(fig, fig_dir, "11_alive_kvcache_blocks_timeline")
def _plot_bucket_reuse_ratio(fig_dir: Path, request_rows: list[dict]) -> None:
by_bucket = defaultdict(lambda: {"prompt_blocks": 0, "reused_blocks": 0})
total_prompt_blocks = 0
total_reused_blocks = 0
for row in request_rows:
bucket = row["input_length_bucket"] or "unknown"
prompt_blocks = row["theoretical_prompt_unit_length"]
reused_blocks = row["bucketed_theoretical_prefix_hit_blocks"]
by_bucket[bucket]["prompt_blocks"] += prompt_blocks
by_bucket[bucket]["reused_blocks"] += reused_blocks
total_prompt_blocks += prompt_blocks
total_reused_blocks += row["theoretical_prefix_hit_blocks"]
labels = list(by_bucket)
ratios = [
(by_bucket[label]["reused_blocks"] / by_bucket[label]["prompt_blocks"])
if by_bucket[label]["prompt_blocks"]
else 0.0
for label in labels
]
reused_counts = [by_bucket[label]["reused_blocks"] for label in labels]
labels.append("Overall")
ratios.append((total_reused_blocks / total_prompt_blocks) if total_prompt_blocks else 0.0)
reused_counts.append(total_reused_blocks)
fig, ax = plt.subplots(figsize=(9.2, 4.8))
bars = ax.bar(
labels,
ratios,
color=[PALETTE["blue"], PALETTE["orange"], PALETTE["green"], PALETTE["purple"], PALETTE["teal"]][: len(labels)],
width=0.68,
edgecolor="white",
linewidth=0.8,
)
for bar, ratio, reused_count in zip(bars, ratios, reused_counts):
ax.text(
bar.get_x() + bar.get_width() / 2,
ratio + max(ratios + [0.0]) * 0.03 + 1e-9,
f"{ratio:.2%}\nreused={reused_count:,}",
ha="center",
va="bottom",
fontsize=8.8,
)
ax.set_title("Bucketed KV-Cache Reuse Ratio vs Global Reuse Ratio")
ax.set_xlabel("Input-length bucket")
ax.set_ylabel("Reuse ratio")
ax.set_ylim(0, max(ratios + [0.0]) * 1.25 + 1e-9)
_finalize_axes(ax)
fig.tight_layout()
_save(fig, fig_dir, "12_bucket_kvcache_reuse_ratio")
def _plot_session_cross_bucket_miss(fig_dir: Path, rows: list[dict]) -> None:
labels = [row["bucket"] for row in rows]
miss_ratios = [safe_float(row["cross_bucket_edge_fraction"]) for row in rows]
loss_ratios = [safe_float(row["reduced_reused_blocks_ratio"]) for row in rows]
miss_blocks = [safe_int(row["cross_bucket_shared_prefix_units_sum"]) for row in rows]
x = np.arange(len(labels))
width = 0.36
fig, ax = plt.subplots(figsize=(9.2, 4.8))
left = ax.bar(x - width / 2, miss_ratios, width=width, color=PALETTE["red"], label="cross-bucket miss ratio")
right = ax.bar(
x + width / 2,
loss_ratios,
width=width,
color=PALETTE["gold"],
label="reduced reused blocks / bucket reuse",
)
y_pad = max(miss_ratios + loss_ratios + [0.0]) * 0.03 + 1e-9
for bar, value, count in zip(left, miss_ratios, miss_blocks):
ax.text(
bar.get_x() + bar.get_width() / 2,
value + y_pad,
f"{value:.2%}\nmiss={count:,}",
ha="center",
va="bottom",
fontsize=8.8,
)
for bar, value in zip(right, loss_ratios):
ax.text(
bar.get_x() + bar.get_width() / 2,
value + y_pad,
f"{value:.2%}",
ha="center",
va="bottom",
fontsize=8.8,
)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_title("Session Cross-Bucket KV-Cache Miss and Reuse Loss")
ax.set_xlabel("Child bucket")
ax.set_ylabel("Ratio")
ax.legend(loc="upper left")
ax.set_ylim(0, max(miss_ratios + loss_ratios + [0.0]) * 1.25 + 1e-9)
_finalize_axes(ax)
fig.tight_layout()
_save(fig, fig_dir, "13_session_cross_bucket_kvcache_miss")
def _write_manifest(fig_dir: Path, manifest: dict) -> None:
(fig_dir / "manifest.json").write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8")
def _write_readme(fig_dir: Path, dataset_title: str) -> None:
lines = [
f"# {dataset_title}",
"",
"This directory contains the PNG figures rendered from `details/` data.",
"",
"Figures:",
]
for stem in FIGURE_STEMS:
lines.append(f"- `{stem}.png`")
lines.append("- `session_inter_request_gap_cdf.png`")
(fig_dir / "README.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
def render_figures(
*,
analysis_dir: str | Path,
fig_dir: str | Path,
dataset_title: str,
show_progress: bool = False,
) -> dict:
analysis_root = Path(analysis_dir)
fig_root = Path(fig_dir)
details_root = resolve_details_dir(analysis_root)
_clear_dir_files(fig_root)
_apply_style()
request_rows = _load_request_metrics(details_root / "request_metrics.csv")
request_rows = _sort_request_rows(request_rows)
session_rows_by_id = _build_session_sequences(request_rows)
tool_round_edges = _build_tool_round_edges(session_rows_by_id)
reuse_gap_rows = _read_csv_rows(details_root / "theoretical_block_reuse_gaps.csv")
block_lifetime_rows = _read_csv_rows(details_root / "theoretical_block_lifetimes.csv")
timeline_rows = _read_csv_rows(details_root / "theoretical_alive_block_timeline.csv")
session_bucket_rows = _read_csv_rows(details_root / "session_bucket_boundary_miss.csv")
details_summary = _read_json(details_root / "details_summary.json")
progress = tqdm(
total=len(FIGURE_STEMS) + 1,
desc="Render figures",
unit="artifact",
dynamic_ncols=True,
disable=not show_progress,
)
if show_progress:
progress.set_postfix(current="01_input_output_length_cdf")
_plot_two_series_cdf_with_zoom(
fig_root,
stem="01_input_output_length_cdf",
title="Input / Output Length CDF",
xlabel="Tokens",
first_label="Input",
first_values=[row["input_tokens"] for row in request_rows],
first_color=PALETTE["blue"],
second_label="Output",
second_values=[row["output_tokens"] for row in request_rows],
second_color=PALETTE["orange"],
zoom_quantile=0.80,
stats_labels=("mean", "p50", "p80", "p90", "p95", "p99"),
)
if show_progress:
progress.update(1)
progress.set_postfix(current="02_session_turns_cdf")
_plot_session_turns_cdf(fig_root, request_rows)
if show_progress:
progress.update(1)
progress.set_postfix(current="03_request_length_by_turn")
_plot_request_length_by_turn(fig_root, request_rows)
if show_progress:
progress.update(1)
progress.set_postfix(current="04_request_trigger_role_pie")
_plot_trigger_role_pie(fig_root, request_rows)
if show_progress:
progress.update(1)
progress.set_postfix(current="05_tool_call_output_length_cdf")
_plot_single_cdf(
fig_root,
stem="05_tool_call_output_length_cdf",
title="Tool Call Output Length CDF",
xlabel="Output tokens",
label="Tool-call output length",
values=[row["tool_call_output_tokens"] for row in tool_round_edges],
color=PALETTE["teal"],
zoom_quantile=0.90,
)
if show_progress:
progress.update(1)
progress.set_postfix(current="06_tool_call_latency_cdf")
_plot_single_cdf(
fig_root,
stem="06_tool_call_latency_cdf",
title="Tool Call Latency CDF",
xlabel="Milliseconds",
label="Tool-call latency",
values=[row["tool_call_latency_ms"] for row in tool_round_edges],
color=PALETTE["red"],
zoom_quantile=0.90,
)
if show_progress:
progress.update(1)
progress.set_postfix(current="07_consecutive_tool_call_count_cdf")
_plot_consecutive_tool_calls_cdf(fig_root, session_rows_by_id)
if show_progress:
progress.update(1)
progress.set_postfix(current="08_tool_call_added_context_cdf")
_plot_single_cdf(
fig_root,
stem="08_tool_call_added_context_cdf",
title="Added Context After Tool Call CDF",
xlabel="Added context tokens",
label="Added context",
values=[row["added_context_tokens"] for row in tool_round_edges],
color=PALETTE["purple"],
)
if show_progress:
progress.update(1)
progress.set_postfix(current="09_kvcache_block_reuse_time_cdf")
_plot_single_cdf(
fig_root,
stem="09_kvcache_block_reuse_time_cdf",
title="KV-Cache Block Reuse Time CDF",
xlabel="Milliseconds",
label="Reuse time",
weighted_rows=reuse_gap_rows,
weighted_value_key="reuse_gap_ms",
weighted_count_key="count",
color=PALETTE["gold"],
zoom_quantile=0.90,
)
if show_progress:
progress.update(1)
progress.set_postfix(current="10_kvcache_block_lifecycle_cdf")
_plot_single_cdf(
fig_root,
stem="10_kvcache_block_lifecycle_cdf",
title="KV-Cache Block Lifecycle CDF",
xlabel="Milliseconds",
label="Block lifecycle",
values=[safe_int(row["lifetime_ms"]) for row in block_lifetime_rows],
color=PALETTE["gray"],
)
if show_progress:
progress.update(1)
progress.set_postfix(current="11_alive_kvcache_blocks_timeline")
_plot_alive_kvcache_timeline(fig_root, timeline_rows)
if show_progress:
progress.update(1)
progress.set_postfix(current="12_bucket_kvcache_reuse_ratio")
_plot_bucket_reuse_ratio(fig_root, request_rows)
if show_progress:
progress.update(1)
progress.set_postfix(current="13_session_cross_bucket_kvcache_miss")
_plot_session_cross_bucket_miss(fig_root, session_bucket_rows)
_plot_session_gap_cdf(fig_root, session_rows_by_id)
if show_progress:
progress.update(1)
progress.set_postfix(current="manifest.json + README.md")
manifest = {
"dataset_title": dataset_title,
"figure_count": len(FIGURE_STEMS),
"analysis_dir": str(analysis_root),
"request_count": details_summary.get("request_count", 0),
"global_reuse_ratio": details_summary.get("global_reuse_ratio", 0.0),
"figures": [f"{stem}.png" for stem in FIGURE_STEMS],
"extra_figures": ["session_inter_request_gap_cdf.png"],
}
_write_manifest(fig_root, manifest)
_write_readme(fig_root, dataset_title)
if show_progress:
progress.update(1)
progress.close()
return {
"fig_dir": str(fig_root),
"manifest_path": str(fig_root / "manifest.json"),
"readme_path": str(fig_root / "README.md"),
}

trace_analyzer/helpers.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import json
from statistics import mean, median
def parse_jsonish(value):
"""Parse nested JSON strings until a non-string value is reached."""
current = value
while isinstance(current, str):
text = current.strip()
if not text:
return current
try:
current = json.loads(text)
except json.JSONDecodeError:
return current
return current
def safe_int(value, default=0):
if value is None or value == "":
return default
try:
return int(value)
except (TypeError, ValueError):
return default
def safe_float(value, default=0.0):
if value is None or value == "":
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def percentile(values, pct):
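# Linear interpolation between closest ranks, e.g.
# percentile([1, 2, 3, 4], 0.5) == 2.5 and percentile([1, 2, 3, 4, 5], 0.9) == 4.6.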
if not values:
return 0.0
ordered = sorted(values)
if len(ordered) == 1:
return float(ordered[0])
rank = pct * (len(ordered) - 1)
low = int(rank)
high = min(low + 1, len(ordered) - 1)
fraction = rank - low
return ordered[low] + (ordered[high] - ordered[low]) * fraction
def series_stats(values):
cleaned = [v for v in values if v is not None]
if not cleaned:
return {
"count": 0,
"min": 0,
"max": 0,
"mean": 0.0,
"median": 0.0,
"p90": 0.0,
}
return {
"count": len(cleaned),
"min": min(cleaned),
"max": max(cleaned),
"mean": mean(cleaned),
"median": median(cleaned),
"p90": percentile(cleaned, 0.9),
}
def safe_div(numerator, denominator):
if not denominator:
return 0.0
return numerator / denominator
def compact_json(data):
return json.dumps(data, ensure_ascii=False, separators=(",", ":"))

trace_analyzer/layout.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from __future__ import annotations
from pathlib import Path
import json
DETAILS_DIR_NAME = "details"
LEGACY_DETAILS_DIR_NAME = "advanced"
DETAILS_SUMMARY_FILENAME = "details_summary.json"
LEGACY_DETAILS_SUMMARY_FILENAME = "advanced_summary.json"
def preferred_details_dir(output_dir: str | Path) -> Path:
return Path(output_dir) / DETAILS_DIR_NAME
def legacy_details_dir(output_dir: str | Path) -> Path:
return Path(output_dir) / LEGACY_DETAILS_DIR_NAME
def resolve_existing_details_dir(output_dir: str | Path) -> Path | None:
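# Preference order: a populated details/ wins, then a populated legacy advanced/,
# then whichever of the two directories merely exists, else None.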
preferred = preferred_details_dir(output_dir)
if _details_dir_has_outputs(preferred):
return preferred
legacy = legacy_details_dir(output_dir)
if _details_dir_has_outputs(legacy):
return legacy
if preferred.exists():
return preferred
if legacy.exists():
return legacy
return None
def resolve_details_dir(output_dir: str | Path) -> Path:
existing = resolve_existing_details_dir(output_dir)
if existing is not None:
return existing
return preferred_details_dir(output_dir)
def resolve_details_summary_path(output_dir: str | Path) -> Path | None:
for details_dir in [preferred_details_dir(output_dir), legacy_details_dir(output_dir)]:
for filename in [DETAILS_SUMMARY_FILENAME, LEGACY_DETAILS_SUMMARY_FILENAME]:
path = details_dir / filename
if path.exists():
return path
return None
def details_outputs_exist(output_dir: str | Path) -> bool:
return _details_dir_has_outputs(preferred_details_dir(output_dir)) or _details_dir_has_outputs(
legacy_details_dir(output_dir)
)
def _details_dir_has_outputs(details_dir: Path) -> bool:
if not details_dir.exists():
return False
required_files = [
details_dir / "request_metrics.csv",
details_dir / "theoretical_block_reuse_gaps.csv",
details_dir / "theoretical_block_lifetimes.csv",
details_dir / "theoretical_alive_block_timeline.csv",
details_dir / "session_bucket_boundary_miss.csv",
]
if not all(path.exists() for path in required_files):
return False
summary_path = details_dir / DETAILS_SUMMARY_FILENAME
if not summary_path.exists():
return False
try:
payload = json.loads(summary_path.read_text(encoding="utf-8"))
except Exception:
return False
return int(payload.get("schema_version", 0) or 0) >= 3

trace_analyzer/models.py (new file, 94 lines)
@@ -0,0 +1,94 @@
from dataclasses import dataclass, field
@dataclass
class MessageEvent:
role: str
content_type: str
text_len: int
has_cache_control: bool = False
item_count: int = 0
@dataclass
class ToolSpec:
name: str
tool_type: str
@dataclass
class UsageStats:
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
reasoning_tokens: int = 0
cached_tokens: int = 0
@dataclass
class RequestMeta:
provider: str
line_number: int
request_id: str
session_id: str
request_model: str
time: str
status_code: str
status_name: str
request_ready_time_ms: int
request_end_time_ms: int
total_cost_time_ms: int
backend_first_request_time_ms: int = 0
backend_first_response_time_ms: int = 0
@dataclass
class TraceRecord:
meta: RequestMeta
canonical_prompt: str = ""
messages: list[MessageEvent] = field(default_factory=list)
role_sequence: list[str] = field(default_factory=list)
declared_tools: list[ToolSpec] = field(default_factory=list)
usage: UsageStats = field(default_factory=UsageStats)
raw_messages: list[dict] = field(default_factory=list)
@dataclass
class TraceFeatures:
request_id: str
session_id: str
model: str
status_code: str
time: str
message_count: int
conversation_depth: int
declared_tool_count: int
assistant_msg_count: int
tool_msg_count: int
user_msg_count: int
system_msg_count: int
assistant_to_tool_count: int
tool_to_assistant_count: int
tool_to_tool_count: int
assistant_to_user_count: int
user_to_assistant_count: int
max_consecutive_tool_msgs: int
avg_tool_burst_len: float
has_tool_loop: int
input_tokens: int
output_tokens: int
total_tokens: int
reasoning_tokens: int
cached_tokens: int
cache_hit_ratio: float
uncached_prompt_tokens: int
output_input_ratio: float
latency_ms: int
ms_per_input_token: float
ms_per_output_token: float
long_context: int
high_cache: int
tool_burst_alert: int
tool_loop_alert: int
slow_request: int = 0
pattern_labels: list[str] = field(default_factory=list)

trace_analyzer/parser.py (new file, 230 lines)
@@ -0,0 +1,230 @@
import json
import os
from dataclasses import asdict
from pathlib import Path
import psutil
from tqdm.auto import tqdm
from .helpers import safe_int
from .models import MessageEvent, RequestMeta, ToolSpec, TraceRecord, UsageStats
class FormattedAliTraceAdapter:
name = "formatted"
def detect(self, raw):
if not isinstance(raw.get("meta"), dict):
return False
required_keys = {"canonical_prompt", "usage", "message_events", "declared_tools", "role_sequence"}
if not required_keys.issubset(raw.keys()):
return False
schema_version = str(raw.get("schema_version", "")).strip()
return bool(schema_version) or "request_id" in raw["meta"]
def parse_line(self, raw, line_number=0):
meta_payload = raw.get("meta", {}) if isinstance(raw.get("meta", {}), dict) else {}
usage_payload = raw.get("usage", {}) if isinstance(raw.get("usage", {}), dict) else {}
message_events_payload = raw.get("message_events", [])
declared_tools_payload = raw.get("declared_tools", [])
usage = UsageStats(
input_tokens=safe_int(usage_payload.get("input_tokens")),
output_tokens=safe_int(usage_payload.get("output_tokens")),
total_tokens=safe_int(usage_payload.get("total_tokens")),
reasoning_tokens=safe_int(usage_payload.get("reasoning_tokens")),
cached_tokens=safe_int(usage_payload.get("cached_tokens")),
)
messages = [
MessageEvent(
role=str(message.get("role", "unknown")),
content_type=str(message.get("content_type", "unknown")),
text_len=safe_int(message.get("text_len")),
has_cache_control=bool(message.get("has_cache_control")),
item_count=safe_int(message.get("item_count")),
)
for message in message_events_payload
if isinstance(message, dict)
]
declared_tools = [
ToolSpec(
name=str(tool.get("name", "")),
tool_type=str(tool.get("tool_type", "function")),
)
for tool in declared_tools_payload
if isinstance(tool, dict)
]
inferred_family = str(meta_payload.get("model_family", "")).strip()
inferred_provider = str(meta_payload.get("provider", "")).strip()
if not inferred_provider:
inferred_provider = inferred_family or self.name
meta = RequestMeta(
provider=inferred_provider,
line_number=line_number,
request_id=str(meta_payload.get("request_id", "")),
session_id=str(meta_payload.get("session_id", "")),
request_model=str(meta_payload.get("request_model", "")),
time=str(meta_payload.get("time", "")),
status_code=str(meta_payload.get("status_code", "")),
status_name=str(meta_payload.get("status_name", "")),
request_ready_time_ms=safe_int(meta_payload.get("request_ready_time_ms")),
request_end_time_ms=safe_int(meta_payload.get("request_end_time_ms")),
total_cost_time_ms=safe_int(meta_payload.get("total_cost_time_ms")),
backend_first_request_time_ms=safe_int(meta_payload.get("backend_first_request_time_ms")),
backend_first_response_time_ms=safe_int(meta_payload.get("backend_first_response_time_ms")),
)
return TraceRecord(
meta=meta,
canonical_prompt=str(raw.get("canonical_prompt", "")),
messages=messages,
role_sequence=[
str(role)
for role in raw.get("role_sequence", [message.role for message in messages])
],
declared_tools=declared_tools,
usage=usage,
raw_messages=[
message
for message in raw.get("raw_messages", [])
if isinstance(message, dict)
],
)
def _looks_like_release_trace(raw):
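# A release line carries block hash-id metadata, roughly (illustrative values):
# {"chat_id": "...", "parent_chat_id": "...", "timestamp": 0, "input_length": 0,
#  "output_length": 0, "turn": 1, "hash_ids": [...]}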
expected_keys = {"chat_id", "parent_chat_id", "timestamp", "input_length", "output_length", "turn", "hash_ids"}
return expected_keys.issubset(raw.keys())
def path_looks_like_release_trace(path):
path = Path(path)
if not path.exists():
return False
try:
with path.open("r", encoding="utf-8") as handle:
for line in handle:
line = line.strip()
if not line:
continue
return _looks_like_release_trace(json.loads(line))
except Exception:
return False
return False
def get_adapter(raw):
adapter = FormattedAliTraceAdapter()
if adapter.detect(raw):
return adapter
if _looks_like_release_trace(raw):
raise ValueError("trace_analyzer currently analyzes formatter-generated *-raw.jsonl, not release hash-id traces.")
raise ValueError("trace_analyzer only accepts formatter-generated *-raw.jsonl inputs.")
def _estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):
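# Heuristic only: headroom shrinks linearly from 1.25x at fraction_done=0 to
# 1.0x at fraction_done=1, so halfway through the estimate is baseline * 1.125.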
baseline = max(current_rss_mb, peak_rss_mb)
headroom = 1.0 + 0.25 * max(0.0, 1.0 - fraction_done)
return baseline * headroom
def load_records(path, limit=None, show_progress=False, progress_desc="Load trace"):
records = []
path = str(path)
progress = None
process = psutil.Process(os.getpid()) if show_progress else None
peak_rss_mb = 0.0
total_bytes = os.path.getsize(path) if show_progress else 0
if show_progress:
progress = tqdm(
total=total_bytes,
desc=progress_desc,
unit="B",
unit_scale=True,
dynamic_ncols=True,
)
with open(path, "r", encoding="utf-8") as handle:
for line_number, line in enumerate(handle, start=1):
if limit is not None and len(records) >= limit:
break
raw_line = line
line = line.strip()
if not line:
if progress is not None:
progress.update(len(raw_line.encode("utf-8")))
continue
raw = json.loads(line)
adapter = get_adapter(raw)
try:
record = adapter.parse_line(raw, line_number=line_number)
except Exception as exc:
if progress is not None:
progress.close()
raise ValueError(f"Failed to parse line {line_number} in {path}: {exc}") from exc
records.append(record)
if progress is not None:
progress.update(len(raw_line.encode("utf-8")))
current_rss_mb = process.memory_info().rss / (1024 * 1024)
peak_rss_mb = max(peak_rss_mb, current_rss_mb)
fraction_done = progress.n / progress.total if progress.total else 0.0
progress.set_postfix(
records=len(records),
rss_mb=f"{current_rss_mb:.0f}",
est_peak_mb=f"{_estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):.0f}",
)
if progress is not None:
progress.close()
return records
def flatten_record(record):
return {
"provider": record.meta.provider,
"line_number": record.meta.line_number,
"request_id": record.meta.request_id,
"session_id": record.meta.session_id,
"request_model": record.meta.request_model,
"time": record.meta.time,
"status_code": record.meta.status_code,
"status_name": record.meta.status_name,
"request_ready_time_ms": record.meta.request_ready_time_ms,
"request_end_time_ms": record.meta.request_end_time_ms,
"total_cost_time_ms": record.meta.total_cost_time_ms,
"backend_first_request_time_ms": record.meta.backend_first_request_time_ms,
"backend_first_response_time_ms": record.meta.backend_first_response_time_ms,
"message_count": len(record.messages),
"role_sequence": ";".join(record.role_sequence),
"declared_tool_count": len(record.declared_tools),
"declared_tool_names": ";".join(tool.name for tool in record.declared_tools if tool.name),
"input_tokens": record.usage.input_tokens,
"output_tokens": record.usage.output_tokens,
"total_tokens": record.usage.total_tokens,
"reasoning_tokens": record.usage.reasoning_tokens,
"cached_tokens": record.usage.cached_tokens,
}
def record_to_dict(record):
return asdict(record)
def infer_analysis_dataset_name(input_path):
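# e.g. trace-glm5-formatted/041715-041717-raw.jsonl -> "glm5-041715-041717",
# matching the outputs/analysis/<model>-<dataset>/ layout.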
resolved = Path(input_path)
stem = resolved.stem
if stem.endswith("-raw"):
stem = stem[:-4]
parent_name = resolved.parent.name
model_slug = ""
if parent_name.startswith("trace-") and parent_name.endswith("-formatted"):
model_slug = parent_name[len("trace-") : -len("-formatted")]
if model_slug and not stem.startswith(f"{model_slug}-"):
return f"{model_slug}-{stem}"
return stem
def default_output_dir(input_path):
return Path("outputs") / "analysis" / infer_analysis_dataset_name(input_path)

@@ -0,0 +1,221 @@
from __future__ import annotations
import csv
import json
import os
from pathlib import Path
from trace_analyzer.helpers import percentile
from trace_analyzer.parser import get_adapter
from tqdm.auto import tqdm
def stream_prepare(input_path: str | Path, output_dir: str | Path, *, show_progress: bool = False) -> dict:
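# Streams the raw trace once, writing one features.csv row per request, then
# reopens features.csv to apply the batch-level slow_request p90 threshold.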
input_file = Path(input_path)
output_root = Path(output_dir)
output_root.mkdir(parents=True, exist_ok=True)
features_path = output_root / "features.csv"
total_bytes = os.path.getsize(input_file) if show_progress and input_file.exists() else 0
progress = tqdm(
total=total_bytes,
desc="Prepare features",
unit="B",
unit_scale=True,
dynamic_ncols=True,
disable=not show_progress,
)
try:
with input_file.open("r", encoding="utf-8") as input_handle, features_path.open(
"w", encoding="utf-8", newline=""
) as features_handle:
writer: csv.DictWriter | None = None
kept_rows = 0
for line_number, line in enumerate(input_handle, start=1):
stripped = line.strip()
if not stripped:
if show_progress:
progress.update(len(line.encode("utf-8")))
continue
raw = json.loads(stripped)
adapter = get_adapter(raw)
record = adapter.parse_line(raw, line_number=line_number)
role_sequence = record.role_sequence
role_pairs = list(zip(role_sequence, role_sequence[1:]))
tool_bursts = _tool_bursts(role_sequence)
max_tool_burst = max(tool_bursts) if tool_bursts else 0
avg_tool_burst = _safe_div(sum(tool_bursts), len(tool_bursts)) if tool_bursts else 0.0
tool_to_tool_count = sum(1 for current, nxt in role_pairs if current == "tool" and nxt == "tool")
tool_msg_count = sum(message.role == "tool" for message in record.messages)
assistant_msg_count = sum(message.role == "assistant" for message in record.messages)
cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens)
feature_row = {
"request_id": record.meta.request_id,
"session_id": record.meta.session_id,
"model": record.meta.request_model,
"status_code": record.meta.status_code,
"time": record.meta.time,
"message_count": len(record.messages),
"conversation_depth": len(record.messages),
"declared_tool_count": len(record.declared_tools),
"assistant_msg_count": assistant_msg_count,
"tool_msg_count": tool_msg_count,
"user_msg_count": sum(message.role == "user" for message in record.messages),
"system_msg_count": sum(message.role == "system" for message in record.messages),
"assistant_to_tool_count": sum(
1
for current, nxt in role_pairs
if current == "assistant" and nxt == "tool"
),
"tool_to_assistant_count": sum(
1
for current, nxt in role_pairs
if current == "tool" and nxt == "assistant"
),
"tool_to_tool_count": tool_to_tool_count,
"assistant_to_user_count": sum(
1
for current, nxt in role_pairs
if current == "assistant" and nxt == "user"
),
"user_to_assistant_count": sum(
1
for current, nxt in role_pairs
if current == "user" and nxt == "assistant"
),
"max_consecutive_tool_msgs": max_tool_burst,
"avg_tool_burst_len": avg_tool_burst,
"has_tool_loop": 1 if tool_to_tool_count > 0 else 0,
"input_tokens": record.usage.input_tokens,
"output_tokens": record.usage.output_tokens,
"total_tokens": record.usage.total_tokens,
"reasoning_tokens": record.usage.reasoning_tokens,
"cached_tokens": record.usage.cached_tokens,
"cache_hit_ratio": cache_hit_ratio,
"uncached_prompt_tokens": max(record.usage.input_tokens - record.usage.cached_tokens, 0),
"output_input_ratio": _safe_div(record.usage.output_tokens, record.usage.input_tokens),
"latency_ms": record.meta.total_cost_time_ms,
"ms_per_input_token": _safe_div(record.meta.total_cost_time_ms, record.usage.input_tokens),
"ms_per_output_token": _safe_div(record.meta.total_cost_time_ms, record.usage.output_tokens),
"long_context": 1 if record.usage.input_tokens >= 32000 else 0,
"high_cache": 1 if cache_hit_ratio >= 0.8 else 0,
"tool_burst_alert": 1 if max_tool_burst >= 4 else 0,
"tool_loop_alert": 1 if tool_to_tool_count >= 3 else 0,
"slow_request": 0,
"pattern_labels": _pattern_labels(
record,
cache_hit_ratio=cache_hit_ratio,
tool_msg_count=tool_msg_count,
assistant_msg_count=assistant_msg_count,
max_tool_burst=max_tool_burst,
),
}
if writer is None:
writer = csv.DictWriter(features_handle, fieldnames=list(feature_row.keys()))
writer.writeheader()
writer.writerow(feature_row)
kept_rows += 1
if show_progress:
progress.update(len(line.encode("utf-8")))
progress.set_postfix(
rows=kept_rows,
features=features_path.name,
)
finally:
if show_progress:
progress.close()
if show_progress:
tqdm.write("Finalize features.csv: apply slow_request p90 latency threshold")
_apply_slow_request_threshold(features_path)
return {
"features_path": str(features_path),
}
def _safe_div(numerator: float, denominator: float) -> float:
return (numerator / denominator) if denominator else 0.0
def _tool_bursts(role_sequence: list[str]) -> list[int]:
bursts: list[int] = []
current = 0
for role in role_sequence:
if role == "tool":
current += 1
elif current:
bursts.append(current)
current = 0
if current:
bursts.append(current)
return bursts
def _max_tool_burst(role_sequence: list[str]) -> int:
bursts = _tool_bursts(role_sequence)
return max(bursts) if bursts else 0
def _avg_tool_burst(role_sequence: list[str]) -> float:
bursts = _tool_bursts(role_sequence)
return _safe_div(sum(bursts), len(bursts)) if bursts else 0.0
def _pattern_labels(
record,
*,
cache_hit_ratio: float | None = None,
tool_msg_count: int | None = None,
assistant_msg_count: int | None = None,
max_tool_burst: int | None = None,
) -> str:
labels: list[str] = []
if tool_msg_count is None:
tool_msg_count = sum(message.role == "tool" for message in record.messages)
if assistant_msg_count is None:
assistant_msg_count = sum(message.role == "assistant" for message in record.messages)
if cache_hit_ratio is None:
cache_hit_ratio = _safe_div(record.usage.cached_tokens, record.usage.input_tokens)
if max_tool_burst is None:
max_tool_burst = _max_tool_burst(record.role_sequence)
if tool_msg_count == 0 and len(record.declared_tools) == 0:
labels.append("single-shot")
if tool_msg_count > 0 and tool_msg_count >= assistant_msg_count:
labels.append("tool-heavy")
if max_tool_burst >= 4:
labels.append("tool-burst")
if cache_hit_ratio >= 0.8:
labels.append("cache-efficient")
if cache_hit_ratio <= 0.1:
labels.append("cache-cold")
if record.usage.input_tokens >= 32000 and cache_hit_ratio <= 0.1:
labels.append("long-context-no-cache")
return ";".join(sorted(set(labels)))
def _apply_slow_request_threshold(features_path: Path) -> None:
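# Two passes on purpose: the p90 threshold needs the full latency_ms column, so
# pass one only collects latencies and pass two rewrites rows through a .tmp
# file that then replaces features.csv in one step.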
with features_path.open("r", encoding="utf-8") as handle:
latencies = [int(row["latency_ms"]) for row in csv.DictReader(handle)]
if not latencies:
return
latencies.sort()
p90_latency = percentile(latencies, 0.9)
temp_path = features_path.with_suffix(features_path.suffix + ".tmp")
with features_path.open("r", encoding="utf-8") as input_handle, temp_path.open(
"w", encoding="utf-8", newline=""
) as output_handle:
reader = csv.DictReader(input_handle)
writer = None
for row in reader:
slow_request = 1 if int(row["latency_ms"]) >= p90_latency else 0
pattern_labels = {label for label in row.get("pattern_labels", "").split(";") if label}
row["slow_request"] = str(slow_request)
if slow_request and row.get("high_cache") == "1":
pattern_labels.add("slow-despite-cache")
row["pattern_labels"] = ";".join(sorted(pattern_labels))
if writer is None:
writer = csv.DictWriter(output_handle, fieldnames=list(row.keys()))
writer.writeheader()
writer.writerow(row)
temp_path.replace(features_path)

trace_analyzer/report.py (new file, 271 lines)
@@ -0,0 +1,271 @@
import csv
import json
from collections import Counter
from pathlib import Path
from .features import feature_to_row
from .helpers import series_stats
from .parser import flatten_record, record_to_dict
def ensure_output_dir(path):
path.mkdir(parents=True, exist_ok=True)
return path
def write_jsonl(path, rows):
with open(path, "w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
def write_csv(path, rows):
if not rows:
with open(path, "w", encoding="utf-8", newline="") as handle:
handle.write("")
return
fieldnames = list(rows[0].keys())
with open(path, "w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
def write_parquet(path, rows):
try:
import pyarrow as pa
import pyarrow.parquet as pq
except ImportError as exc:
raise RuntimeError("Parquet output requires pyarrow to be installed.") from exc
table = pa.Table.from_pylist(rows)
pq.write_table(table, path)
def write_normalized(records, output_dir, output_format="jsonl"):
output_dir = ensure_output_dir(output_dir)
rows = [record_to_dict(record) for record in records]
if output_format == "jsonl":
path = output_dir / "normalized.jsonl"
write_jsonl(path, rows)
return path
if output_format == "csv":
path = output_dir / "normalized.csv"
write_csv(path, [flatten_record(record) for record in records])
return path
if output_format == "parquet":
path = output_dir / "normalized.parquet"
write_parquet(path, rows)
return path
raise ValueError(f"Unsupported format: {output_format}")
def write_features(features, output_dir):
output_dir = ensure_output_dir(output_dir)
path = output_dir / "features.csv"
write_csv(path, [feature_to_row(feature) for feature in features])
return path
def build_summary(records, features):
model_counts = Counter(feature.model or "unknown" for feature in features)
status_code_counts = Counter(feature.status_code or "unknown" for feature in features)
role_transition_counts = Counter()
for feature in features:
role_transition_counts["assistant->tool"] += feature.assistant_to_tool_count
role_transition_counts["tool->assistant"] += feature.tool_to_assistant_count
role_transition_counts["tool->tool"] += feature.tool_to_tool_count
role_transition_counts["assistant->user"] += feature.assistant_to_user_count
role_transition_counts["user->assistant"] += feature.user_to_assistant_count
latency_stats = series_stats([feature.latency_ms for feature in features])
cache_ratio_stats = series_stats([feature.cache_hit_ratio for feature in features])
cached_token_stats = series_stats([feature.cached_tokens for feature in features])
declared_tool_stats = series_stats([feature.declared_tool_count for feature in features])
burst_stats = series_stats([feature.max_consecutive_tool_msgs for feature in features])
high_burst_requests = sorted(
[
{
"request_id": feature.request_id,
"session_id": feature.session_id,
"max_consecutive_tool_msgs": feature.max_consecutive_tool_msgs,
"tool_to_tool_count": feature.tool_to_tool_count,
}
for feature in features
if feature.tool_burst_alert
],
key=lambda item: (item["max_consecutive_tool_msgs"], item["tool_to_tool_count"]),
reverse=True,
)[:10]
slow_despite_cache = sorted(
[
{
"request_id": feature.request_id,
"session_id": feature.session_id,
"latency_ms": feature.latency_ms,
"cache_hit_ratio": feature.cache_hit_ratio,
}
for feature in features
if "slow-despite-cache" in feature.pattern_labels
],
key=lambda item: item["latency_ms"],
reverse=True,
)[:10]
long_context_no_cache = sorted(
[
{
"request_id": feature.request_id,
"session_id": feature.session_id,
"input_tokens": feature.input_tokens,
"cache_hit_ratio": feature.cache_hit_ratio,
}
for feature in features
if "long-context-no-cache" in feature.pattern_labels
],
key=lambda item: item["input_tokens"],
reverse=True,
)[:10]
cache_buckets = []
for label, low, high in [
("lt_0_2", 0.0, 0.2),
("0_2_to_0_8", 0.2, 0.8),
("ge_0_8", 0.8, 1.01),
]:
bucket = [feature for feature in features if low <= feature.cache_hit_ratio < high]
cache_buckets.append(
{
"bucket": label,
"count": len(bucket),
"avg_latency_ms": series_stats([feature.latency_ms for feature in bucket])["mean"],
"avg_cache_hit_ratio": series_stats([feature.cache_hit_ratio for feature in bucket])["mean"],
}
)
return {
"record_count": len(records),
"success_count": sum(1 for feature in features if feature.status_code in {"1000", "200"}),
"session_count": len({record.meta.session_id for record in records if record.meta.session_id}),
"model_counts": dict(model_counts),
"status_code_counts": dict(status_code_counts),
"thresholds": {
"long_context": 32000,
"high_cache": 0.8,
"tool_burst_alert": 4,
"tool_loop_alert": 3,
"slow_request_p90_latency_ms": latency_stats["p90"],
},
"tool_patterns": {
"role_transitions": dict(role_transition_counts),
"declared_tool_count": declared_tool_stats,
"max_consecutive_tool_msgs": burst_stats,
"tool_burst_alert_count": sum(feature.tool_burst_alert for feature in features),
"tool_loop_alert_count": sum(feature.tool_loop_alert for feature in features),
"high_burst_requests": high_burst_requests,
},
"cache_patterns": {
"cached_tokens": cached_token_stats,
"cache_hit_ratio": cache_ratio_stats,
"latency_ms": latency_stats,
"cache_buckets": cache_buckets,
},
"anomalies": {
"slow_despite_cache": slow_despite_cache,
"long_context_no_cache": long_context_no_cache,
},
}
def _format_top_requests(rows, columns):
if not rows:
return "_none_"
header = "| " + " | ".join(columns) + " |"
divider = "| " + " | ".join(["---"] * len(columns)) + " |"
lines = [header, divider]
for row in rows:
lines.append("| " + " | ".join(_render_value(row.get(column, "")) for column in columns) + " |")
return "\n".join(lines)
def _render_value(value):
if isinstance(value, float):
return f"{value:.4f}".rstrip("0").rstrip(".")
return str(value)
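For reference, the markdown-table helper renders top-N rows like this (made-up values):

```python
rows = [{"request_id": "req-1", "session_id": "s-1", "latency_ms": 5120, "cache_hit_ratio": 0.8731}]
print(_format_top_requests(rows, ["request_id", "session_id", "latency_ms", "cache_hit_ratio"]))
# | request_id | session_id | latency_ms | cache_hit_ratio |
# | --- | --- | --- | --- |
# | req-1 | s-1 | 5120 | 0.8731 |
```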
def _render_mapping(mapping):
if isinstance(mapping, dict):
rendered = {key: _render_mapping(value) for key, value in mapping.items()}
return json.dumps(rendered, ensure_ascii=False)
if isinstance(mapping, list):
return [_render_mapping(value) for value in mapping]
if isinstance(mapping, float):
return float(f"{mapping:.4f}")
return mapping
def build_markdown_report(summary):
lines = [
"# Trace Analysis Report",
"",
"## Data Overview",
f"- Records: {summary['record_count']}",
f"- Success count: {summary['success_count']}",
f"- Session count: {summary['session_count']}",
f"- Models: {_render_mapping(summary['model_counts'])}",
f"- Status codes: {_render_mapping(summary['status_code_counts'])}",
"",
"## Tool Patterns",
f"- Role transitions: {_render_mapping(summary['tool_patterns']['role_transitions'])}",
f"- Declared tool count stats: {_render_mapping(summary['tool_patterns']['declared_tool_count'])}",
f"- Max consecutive tool msg stats: {_render_mapping(summary['tool_patterns']['max_consecutive_tool_msgs'])}",
f"- Tool burst alerts: {summary['tool_patterns']['tool_burst_alert_count']}",
f"- Tool loop alerts: {summary['tool_patterns']['tool_loop_alert_count']}",
"",
"High burst requests:",
_format_top_requests(
summary["tool_patterns"]["high_burst_requests"],
["request_id", "session_id", "max_consecutive_tool_msgs", "tool_to_tool_count"],
),
"",
"## Cache Patterns",
f"- Cached token stats: {_render_mapping(summary['cache_patterns']['cached_tokens'])}",
f"- Cache hit ratio stats: {_render_mapping(summary['cache_patterns']['cache_hit_ratio'])}",
f"- Latency stats: {_render_mapping(summary['cache_patterns']['latency_ms'])}",
"",
"Cache buckets:",
_format_top_requests(
summary["cache_patterns"]["cache_buckets"],
["bucket", "count", "avg_latency_ms", "avg_cache_hit_ratio"],
),
"",
"## Anomalies",
"Slow despite cache:",
_format_top_requests(
summary["anomalies"]["slow_despite_cache"],
["request_id", "session_id", "latency_ms", "cache_hit_ratio"],
),
"",
"Long context no cache:",
_format_top_requests(
summary["anomalies"]["long_context_no_cache"],
["request_id", "session_id", "input_tokens", "cache_hit_ratio"],
),
"",
]
return "\n".join(lines)
def write_report(records, features, output_dir):
output_dir = ensure_output_dir(output_dir)
summary = build_summary(records, features)
summary_path = output_dir / "summary.json"
with open(summary_path, "w", encoding="utf-8") as handle:
json.dump(summary, handle, ensure_ascii=False, indent=2)
report_path = output_dir / "report.md"
with open(report_path, "w", encoding="utf-8") as handle:
handle.write(build_markdown_report(summary))
return summary_path, report_path

228
trace_analyzer/reporting.py Normal file

@@ -0,0 +1,228 @@
from __future__ import annotations
import csv
import json
from collections import Counter
from pathlib import Path
from trace_analyzer.helpers import safe_float, safe_int, series_stats
from trace_analyzer.layout import resolve_details_summary_path
from trace_analyzer.report import build_markdown_report
_INT_FEATURE_FIELDS = (
    "message_count", "conversation_depth", "declared_tool_count",
    "assistant_msg_count", "tool_msg_count", "user_msg_count", "system_msg_count",
    "assistant_to_tool_count", "tool_to_assistant_count", "tool_to_tool_count",
    "assistant_to_user_count", "user_to_assistant_count",
    "max_consecutive_tool_msgs", "has_tool_loop",
    "input_tokens", "output_tokens", "total_tokens", "reasoning_tokens",
    "cached_tokens", "uncached_prompt_tokens", "latency_ms",
    "long_context", "high_cache", "tool_burst_alert", "tool_loop_alert", "slow_request",
)
_FLOAT_FEATURE_FIELDS = (
    "avg_tool_burst_len", "cache_hit_ratio", "output_input_ratio",
    "ms_per_input_token", "ms_per_output_token",
)
def _iter_feature_rows(features_path: str | Path):
    with Path(features_path).open("r", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            for field in _INT_FEATURE_FIELDS:
                row[field] = safe_int(row.get(field))
            for field in _FLOAT_FEATURE_FIELDS:
                row[field] = safe_float(row.get(field))
            row["pattern_labels"] = [label for label in str(row.get("pattern_labels", "")).split(";") if label]
            yield row
def build_summary_from_features(features_path: str | Path) -> dict:
model_counts = Counter()
status_code_counts = Counter()
role_transition_counts = Counter()
session_ids: set[str] = set()
latencies: list[int] = []
cache_ratios: list[float] = []
cached_tokens_list: list[int] = []
declared_tool_counts: list[int] = []
burst_values: list[int] = []
record_count = 0
success_count = 0
high_burst_requests: list[dict] = []
slow_despite_cache: list[dict] = []
long_context_no_cache: list[dict] = []
tool_burst_alert_count = 0
tool_loop_alert_count = 0
cache_bucket_input = {
"lt_0_2": {"latencies": [], "ratios": [], "count": 0},
"0_2_to_0_8": {"latencies": [], "ratios": [], "count": 0},
"ge_0_8": {"latencies": [], "ratios": [], "count": 0},
}
for row in _iter_feature_rows(features_path):
record_count += 1
model_counts[row.get("model") or "unknown"] += 1
status_code_counts[row.get("status_code") or "unknown"] += 1
if row.get("session_id"):
session_ids.add(row["session_id"])
if row.get("status_code") in {"1000", "200"}:
success_count += 1
role_transition_counts["assistant->tool"] += row["assistant_to_tool_count"]
role_transition_counts["tool->assistant"] += row["tool_to_assistant_count"]
role_transition_counts["tool->tool"] += row["tool_to_tool_count"]
role_transition_counts["assistant->user"] += row["assistant_to_user_count"]
role_transition_counts["user->assistant"] += row["user_to_assistant_count"]
latencies.append(row["latency_ms"])
cache_ratios.append(row["cache_hit_ratio"])
cached_tokens_list.append(row["cached_tokens"])
declared_tool_counts.append(row["declared_tool_count"])
burst_values.append(row["max_consecutive_tool_msgs"])
tool_burst_alert_count += row["tool_burst_alert"]
tool_loop_alert_count += row["tool_loop_alert"]
if row["tool_burst_alert"]:
high_burst_requests.append(
{
"request_id": row["request_id"],
"session_id": row["session_id"],
"max_consecutive_tool_msgs": row["max_consecutive_tool_msgs"],
"tool_to_tool_count": row["tool_to_tool_count"],
}
)
high_burst_requests.sort(
key=lambda item: (item["max_consecutive_tool_msgs"], item["tool_to_tool_count"]),
reverse=True,
)
del high_burst_requests[10:]
if "slow-despite-cache" in row["pattern_labels"]:
slow_despite_cache.append(
{
"request_id": row["request_id"],
"session_id": row["session_id"],
"latency_ms": row["latency_ms"],
"cache_hit_ratio": row["cache_hit_ratio"],
}
)
slow_despite_cache.sort(key=lambda item: item["latency_ms"], reverse=True)
del slow_despite_cache[10:]
if "long-context-no-cache" in row["pattern_labels"]:
long_context_no_cache.append(
{
"request_id": row["request_id"],
"session_id": row["session_id"],
"input_tokens": row["input_tokens"],
"cache_hit_ratio": row["cache_hit_ratio"],
}
)
long_context_no_cache.sort(key=lambda item: item["input_tokens"], reverse=True)
del long_context_no_cache[10:]
ratio = row["cache_hit_ratio"]
if ratio < 0.2:
bucket_name = "lt_0_2"
elif ratio < 0.8:
bucket_name = "0_2_to_0_8"
else:
bucket_name = "ge_0_8"
cache_bucket_input[bucket_name]["count"] += 1
cache_bucket_input[bucket_name]["latencies"].append(row["latency_ms"])
cache_bucket_input[bucket_name]["ratios"].append(row["cache_hit_ratio"])
latency_stats = series_stats(latencies)
cache_ratio_stats = series_stats(cache_ratios)
cached_token_stats = series_stats(cached_tokens_list)
declared_tool_stats = series_stats(declared_tool_counts)
burst_stats = series_stats(burst_values)
cache_buckets = []
for label in ["lt_0_2", "0_2_to_0_8", "ge_0_8"]:
bucket = cache_bucket_input[label]
cache_buckets.append(
{
"bucket": label,
"count": bucket["count"],
"avg_latency_ms": series_stats(bucket["latencies"])["mean"],
"avg_cache_hit_ratio": series_stats(bucket["ratios"])["mean"],
}
)
return {
"record_count": record_count,
"success_count": success_count,
"session_count": len(session_ids),
"model_counts": dict(model_counts),
"status_code_counts": dict(status_code_counts),
"thresholds": {
"long_context": 32000,
"high_cache": 0.8,
"tool_burst_alert": 4,
"tool_loop_alert": 3,
"slow_request_p90_latency_ms": latency_stats["p90"],
},
"tool_patterns": {
"role_transitions": dict(role_transition_counts),
"declared_tool_count": declared_tool_stats,
"max_consecutive_tool_msgs": burst_stats,
"tool_burst_alert_count": tool_burst_alert_count,
"tool_loop_alert_count": tool_loop_alert_count,
"high_burst_requests": high_burst_requests,
},
"cache_patterns": {
"cached_tokens": cached_token_stats,
"cache_hit_ratio": cache_ratio_stats,
"latency_ms": latency_stats,
"cache_buckets": cache_buckets,
},
"anomalies": {
"slow_despite_cache": slow_despite_cache,
"long_context_no_cache": long_context_no_cache,
},
}
def write_reports(
*,
features_path: str | Path,
output_dir: str | Path,
pipeline_summary: dict | None = None,
) -> dict:
output_root = Path(output_dir)
output_root.mkdir(parents=True, exist_ok=True)
summary = build_summary_from_features(features_path)
summary_path = output_root / "summary.json"
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
report_path = output_root / "report.md"
report_path.write_text(build_markdown_report(summary), encoding="utf-8")
combined = {
"summary": summary,
"pipeline": pipeline_summary or {},
}
details_summary_path = resolve_details_summary_path(output_root)
if details_summary_path is not None:
combined["details_summary"] = json.loads(details_summary_path.read_text(encoding="utf-8"))
combined_path = output_root / "analysis_snapshot.json"
combined_path.write_text(json.dumps(combined, ensure_ascii=False, indent=2), encoding="utf-8")
return {
"summary_path": str(summary_path),
"report_path": str(report_path),
"analysis_snapshot_path": str(combined_path),
}
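Unlike `report.write_report`, this entry point never materializes records in memory; it streams `features.csv` once. A hypothetical invocation, with paths following the README output layout:

```python
# Paths are illustrative; features.csv is produced by the analyze pipeline.
paths = write_reports(
    features_path="outputs/analysis/glm5-041715-041717/features.csv",
    output_dir="outputs/analysis/glm5-041715-041717",
)
print(paths["summary_path"], paths["report_path"], paths["analysis_snapshot_path"])
```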


@@ -0,0 +1,801 @@
from __future__ import annotations
import csv
import json
import time
from collections import Counter
from itertools import islice
from pathlib import Path
import psutil
from tqdm.auto import tqdm
from .layout import DETAILS_SUMMARY_FILENAME, preferred_details_dir, resolve_details_dir
PROGRESS_FLUSH_INTERVAL_S = 5.0
PROGRESS_REFRESH_INTERVAL_S = 0.5
PROGRESS_REFRESH_INTERVAL_REQ = 256
DEFAULT_INPUT_LENGTH_BUCKET_THRESHOLDS = [32 * 1024, 85 * 1024, 128 * 1024]
# Index layout of the per-block meta list stored in InMemoryBlockCache.
FIRST_SEEN_MS = 0
LAST_SEEN_MS = 1
LAST_REUSE_MS = 2
FIRST_REQUEST_ID = 3
LAST_REQUEST_ID = 4
LAST_REUSE_REQUEST_ID = 5
REUSE_COUNT = 6
def _format_bucket_boundary(value: int) -> str:
if value == 0:
return "0"
if value % (1024 * 1024) == 0:
return f"{value // (1024 * 1024)}Mi"
if value % 1024 == 0:
return f"{value // 1024}Ki"
return str(value)
def build_input_length_bucket_defs(thresholds=None):
parsed_thresholds = (
list(DEFAULT_INPUT_LENGTH_BUCKET_THRESHOLDS)
if thresholds is None
else sorted(set(int(value) for value in thresholds))
)
if not parsed_thresholds:
raise ValueError("At least one input-length bucket threshold is required.")
if any(value <= 0 for value in parsed_thresholds):
raise ValueError("Input-length bucket thresholds must be positive integers.")
if parsed_thresholds == DEFAULT_INPUT_LENGTH_BUCKET_THRESHOLDS:
return [
("0-32Ki", 0, 32 * 1024),
("32-85Ki", 32 * 1024, 85 * 1024),
("85-128Ki", 85 * 1024, 128 * 1024),
("128Ki+", 128 * 1024, None),
]
bucket_defs = []
lower_bound = 0
for upper_bound in parsed_thresholds:
bucket_defs.append(
(
f"{_format_bucket_boundary(lower_bound)}-{_format_bucket_boundary(upper_bound)}",
lower_bound,
upper_bound,
)
)
lower_bound = upper_bound
bucket_defs.append((f"{_format_bucket_boundary(lower_bound)}+", lower_bound, None))
return bucket_defs
def assign_input_length_bucket(input_tokens: int, bucket_defs=None) -> str:
bucket_defs = bucket_defs or build_input_length_bucket_defs()
for bucket_label, lower_bound, upper_bound in bucket_defs:
if input_tokens >= lower_bound and (upper_bound is None or input_tokens < upper_bound):
return bucket_label
return bucket_defs[-1][0]
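With the default thresholds, bucket assignment behaves as follows (lower bounds inclusive, upper bounds exclusive):

```python
defs = build_input_length_bucket_defs()        # default 32Ki / 85Ki / 128Ki thresholds
[label for label, _, _ in defs]                # ['0-32Ki', '32-85Ki', '85-128Ki', '128Ki+']
assign_input_length_bucket(20_000, defs)       # '0-32Ki'
assign_input_length_bucket(32 * 1024, defs)    # '32-85Ki'
assign_input_length_bucket(200_000, defs)      # '128Ki+'
```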
def write_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None) -> Path:
path.parent.mkdir(parents=True, exist_ok=True)
if fieldnames is None and rows:
fieldnames = list(rows[0].keys())
fieldnames = fieldnames or []
with path.open("w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=fieldnames)
if fieldnames:
writer.writeheader()
if rows:
writer.writerows(rows)
return path
def _estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):
baseline = max(current_rss_mb, peak_rss_mb)
headroom = 1.0 + 0.25 * max(0.0, 1.0 - fraction_done)
return baseline * headroom
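The estimate is a simple headroom heuristic: up to 25% extra at the start of a run, shrinking linearly to zero as the run completes. For example:

```python
# Halfway through (fraction_done=0.5) with 4096 MiB resident and no higher
# observed peak: headroom = 1 + 0.25 * 0.5 = 1.125, so the estimate is 4608 MiB.
_estimate_peak_rss_mb(4096.0, 4096.0, 0.5)  # 4608.0
```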
def _progress_postfix(process, peak_rss_mb, fraction_done, **extra):
current_rss_mb = process.memory_info().rss / (1024 * 1024)
peak_rss_mb = max(peak_rss_mb, current_rss_mb)
postfix = {
"rss_mb": f"{current_rss_mb:.0f}",
"est_peak_mb": f"{_estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, fraction_done):.0f}",
}
postfix.update(extra)
return postfix, peak_rss_mb
def _format_duration(seconds):
if seconds is None or seconds < 0:
return "?"
if seconds < 60:
return f"{seconds:.0f}s"
if seconds < 3600:
return f"{seconds / 60:.1f}m"
return f"{seconds / 3600:.2f}h"
def _write_progress_state(
path,
*,
total_requests,
processed_requests,
started_at,
current_rss_mb,
peak_rss_mb,
est_peak_mb,
source_path,
features_path,
last_request_id,
block_state_count,
bucket_state_count,
):
elapsed_s = max(time.monotonic() - started_at, 1e-9)
req_per_s = processed_requests / elapsed_s
eta_s = ((total_requests - processed_requests) / req_per_s) if req_per_s > 0 and processed_requests < total_requests else 0.0
payload = {
"source_path": str(source_path),
"features_path": str(features_path),
"total_requests": total_requests,
"processed_requests": processed_requests,
"fraction_done": (processed_requests / total_requests) if total_requests else 1.0,
"elapsed_s": elapsed_s,
"req_per_s": req_per_s,
"eta_s": eta_s,
"eta_human": _format_duration(eta_s),
"rss_mb": current_rss_mb,
"peak_rss_mb": peak_rss_mb,
"est_peak_mb": est_peak_mb,
"block_state_count": block_state_count,
"bucket_state_count": bucket_state_count,
"last_request_id": last_request_id,
"updated_at_epoch_s": time.time(),
}
tmp_path = path.with_suffix(path.suffix + ".tmp")
tmp_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
tmp_path.replace(path)
def _count_lines(path):
with open(path, "r", encoding="utf-8") as handle:
return sum(1 for _ in handle)
def _count_feature_rows(path):
total_lines = _count_lines(path)
return max(total_lines - 1, 0)
class InMemoryBlockCache:
def __init__(self):
self.state = {}
def get(self, block_id):
return self.state.get(block_id)
def put(self, block_id, meta):
self.state[block_id] = meta
def iter_blocks(self):
for block_id, meta in self.state.items():
yield (
block_id,
meta[FIRST_SEEN_MS],
meta[LAST_SEEN_MS],
meta[LAST_REUSE_MS],
meta[FIRST_REQUEST_ID],
meta[LAST_REQUEST_ID],
meta[LAST_REUSE_REQUEST_ID],
meta[REUSE_COUNT],
)
def __len__(self):
return len(self.state)
def _normalize_source_row(row):
    meta = row.get("meta")
    meta = meta if isinstance(meta, dict) else {}
declared_tools = row.get("declared_tools", [])
raw_messages = row.get("raw_messages", [])
return {
"meta": meta,
"declared_tools": [tool for tool in declared_tools if isinstance(tool, dict)],
"raw_messages": [message for message in raw_messages if isinstance(message, dict)],
}
def _read_source_minimal(path):
with open(path, "r", encoding="utf-8") as handle:
for line in handle:
row = _normalize_source_row(json.loads(line))
meta = row["meta"]
yield {
"request_id": meta["request_id"],
"session_id": meta["session_id"],
"request_ready_time_ms": meta["request_ready_time_ms"],
"request_end_time_ms": meta["request_end_time_ms"],
"declared_tool_names": [
tool["name"] for tool in row.get("declared_tools", []) if tool.get("name")
],
"raw_messages": row["raw_messages"],
}
def _count_child_refs_by_chat_id(path, limit=None):
counts = Counter()
for index, row in enumerate(_iter_release_rows(path), start=1):
if limit is not None and index > limit:
break
parent_chat_id = int(row.get("parent_chat_id", -1) or -1)
if parent_chat_id != -1:
counts[parent_chat_id] += 1
return counts
def _new_block_meta(request_id, ready_ms):
    # [first_seen_ms, last_seen_ms, last_reuse_ms, first_request_id,
    #  last_request_id, last_reuse_request_id, reuse_count]
    return [ready_ms, ready_ms, 0, request_id, request_id, "", 0]
def _build_alive_block_timeline_from_events(events):
alive_rows = []
alive_count = 0
peak_alive_blocks = 0
for timestamp_ms in sorted(events):
alive_count += events[timestamp_ms]
peak_alive_blocks = max(peak_alive_blocks, alive_count)
alive_rows.append(
{
"timestamp_ms": timestamp_ms,
"delta_alive_blocks": events[timestamp_ms],
"alive_block_count": alive_count,
}
)
return {
"peak_alive_blocks": peak_alive_blocks,
"event_count": len(alive_rows),
}, alive_rows
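The events map holds +1/-1 deltas keyed by timestamp (+1 at `first_seen_ms`, -1 just past the lifecycle end, as built in the lifetime loop further down). A worked example with made-up timestamps:

```python
from collections import Counter

events = Counter({1000: 2, 1500: 1, 2001: -1, 3001: -2})
summary, rows = _build_alive_block_timeline_from_events(events)
summary                                     # {'peak_alive_blocks': 3, 'event_count': 4}
[row["alive_block_count"] for row in rows]  # [2, 3, 2, 0]
```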
def _compute_prefix_hits(
global_store,
bucket_store,
*,
hash_ids,
request_id,
ready_ms,
reuse_gap_counts=None,
):
global_prefix_active = True
bucket_prefix_active = True
global_prefix_match_blocks = 0
bucket_prefix_match_blocks = 0
global_source_request_id = ""
bucket_source_request_id = ""
for block_id in hash_ids:
global_meta = global_store.get(block_id)
if global_meta is not None and global_prefix_active:
global_prefix_match_blocks += 1
global_source_request_id = global_meta[LAST_REQUEST_ID]
if reuse_gap_counts is not None:
reuse_gap_counts[max(ready_ms - global_meta[LAST_SEEN_MS], 0)] += 1
global_meta[LAST_REUSE_MS] = ready_ms
global_meta[LAST_REUSE_REQUEST_ID] = request_id
global_meta[REUSE_COUNT] += 1
elif global_meta is None:
global_prefix_active = False
global_meta = _new_block_meta(request_id, ready_ms)
else:
global_prefix_active = False
global_meta[LAST_SEEN_MS] = ready_ms
global_meta[LAST_REQUEST_ID] = request_id
global_store.put(block_id, global_meta)
bucket_meta = bucket_store.get(block_id)
if bucket_meta is not None and bucket_prefix_active:
bucket_prefix_match_blocks += 1
bucket_source_request_id = bucket_meta[LAST_REQUEST_ID]
bucket_meta[LAST_REUSE_MS] = ready_ms
bucket_meta[LAST_REUSE_REQUEST_ID] = request_id
bucket_meta[REUSE_COUNT] += 1
elif bucket_meta is None:
bucket_prefix_active = False
bucket_meta = _new_block_meta(request_id, ready_ms)
else:
bucket_prefix_active = False
bucket_meta[LAST_SEEN_MS] = ready_ms
bucket_meta[LAST_REQUEST_ID] = request_id
bucket_store.put(block_id, bucket_meta)
return (
global_prefix_match_blocks,
global_source_request_id,
bucket_prefix_match_blocks,
bucket_source_request_id,
)
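Only a contiguous leading run of already-seen blocks counts as reuse; the first unseen block ends the run, and later matches update metadata but add nothing to the hit count. A standalone sketch of the same rule over a plain set, ignoring the per-block metadata:

```python
def prefix_hits_sketch(seen: set[int], hash_ids: list[int]) -> int:
    # Count leading blocks already in `seen`; the first miss ends the run.
    hits = 0
    for block_id in hash_ids:
        if block_id not in seen:
            break
        hits += 1
    seen.update(hash_ids)  # every prompt block becomes reusable afterwards
    return hits

seen: set[int] = set()
prefix_hits_sketch(seen, [1, 2, 3])  # 0 -- cold start
prefix_hits_sketch(seen, [1, 2, 4])  # 2 -- prefix breaks at the third block
prefix_hits_sketch(seen, [9, 2, 3])  # 0 -- first block misses, later matches don't count
```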
def _iter_release_rows(path):
with open(path, "r", encoding="utf-8") as handle:
for line in handle:
row = json.loads(line)
yield {
"chat_id": int(row.get("chat_id", -1) or -1),
"parent_chat_id": int(row.get("parent_chat_id", -1) or -1),
"timestamp": row.get("timestamp"),
"turn": int(row.get("turn", 0) or 0),
"type": row.get("type", ""),
"input_length": int(row.get("input_length", 0) or 0),
"output_length": int(row.get("output_length", 0) or 0),
"hash_ids": [int(value) for value in row.get("hash_ids", [])],
}
def _message_signature(message: dict) -> str:
    # Messages are compared by role only; content edits do not break the shared prefix.
    return str(message.get("role", ""))
def _common_prefix_message_count(previous_messages, current_messages):
count = 0
for previous, current in zip(previous_messages, current_messages):
if _message_signature(previous) != _message_signature(current):
break
count += 1
return count
def _classify_trigger(previous_messages, current_messages):
common_prefix_count = _common_prefix_message_count(previous_messages, current_messages)
appended_messages = current_messages[common_prefix_count:]
appended_message_count = len(appended_messages)
last_role = str(current_messages[-1].get("role", "unknown")) if current_messages else "unknown"
trigger_group = last_role
trigger_detail = f"last_message_role={last_role}"
return {
"common_prefix_message_count": common_prefix_count,
"appended_message_count": appended_message_count,
"first_new_role": str(appended_messages[0].get("role", "unknown")) if appended_messages else "",
"trigger_group": trigger_group,
"trigger_detail": trigger_detail,
}
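Because `_message_signature` compares roles only, the trigger classifier reduces to counting the shared role prefix and inspecting what was appended. For example:

```python
previous = [{"role": "system"}, {"role": "user"}, {"role": "assistant"}]
current = previous + [{"role": "tool"}, {"role": "tool"}]
_classify_trigger(previous, current)
# {'common_prefix_message_count': 3, 'appended_message_count': 2,
#  'first_new_role': 'tool', 'trigger_group': 'tool',
#  'trigger_detail': 'last_message_role=tool'}
```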
def _bucket_definition_rows(bucket_defs):
rows = []
for bucket, lower_bound, upper_bound in bucket_defs:
rows.append(
{
"bucket": bucket,
"input_tokens_min_inclusive": lower_bound,
"input_tokens_max_exclusive": upper_bound,
}
)
return rows
def _clear_details_dir(details_dir: Path) -> None:
details_dir.mkdir(parents=True, exist_ok=True)
for path in details_dir.iterdir():
if path.is_file():
path.unlink()
def collect_existing_detail_paths(output_dir):
details_dir = resolve_details_dir(output_dir)
return {
"details_dir": details_dir,
"progress": details_dir / "progress.json",
"request_metrics": details_dir / "request_metrics.csv",
"theoretical_block_reuse_gaps": details_dir / "theoretical_block_reuse_gaps.csv",
"theoretical_block_lifetimes": details_dir / "theoretical_block_lifetimes.csv",
"theoretical_alive_block_timeline": details_dir / "theoretical_alive_block_timeline.csv",
"session_bucket_boundary_miss": details_dir / "session_bucket_boundary_miss.csv",
"details_summary": details_dir / DETAILS_SUMMARY_FILENAME,
}
def run_advanced_from_existing(
source_path,
release_path,
features_path,
output_dir,
input_length_bucket_thresholds=None,
show_progress=True,
limit=None,
):
output_dir = Path(output_dir)
details_dir = preferred_details_dir(output_dir)
_clear_details_dir(details_dir)
source_path = Path(source_path)
release_path = Path(release_path)
features_path = Path(features_path)
total_requests = limit if limit is not None else _count_feature_rows(features_path)
release_request_count = _count_lines(release_path)
if limit is None and release_request_count != total_requests:
raise ValueError(
f"release/features row count mismatch: release={release_request_count} vs features={total_requests}"
)
process = psutil.Process()
peak_rss_mb = 0.0
started_at = time.monotonic()
bucket_defs = build_input_length_bucket_defs(input_length_bucket_thresholds)
child_ref_counts = _count_child_refs_by_chat_id(release_path, limit=limit)
store = InMemoryBlockCache()
bucket_stores = {bucket_label: InMemoryBlockCache() for bucket_label, _, _ in bucket_defs}
progress_state_path = details_dir / "progress.json"
next_progress_flush_at = started_at + PROGRESS_FLUSH_INTERVAL_S
request_metrics_path = details_dir / "request_metrics.csv"
processed_requests = 0
last_request_id = ""
reuse_gap_counts = Counter()
bucket_reused_block_totals = Counter()
total_prompt_blocks = 0
total_global_reused_blocks = 0
session_last = {}
chat_state_for_children = {}
session_bucket_totals = {
bucket_label: {
"edge_count": 0,
"reusable_edge_count": 0,
"cross_bucket_edge_count": 0,
"shared_prefix_units_sum": 0,
"cross_bucket_shared_prefix_units_sum": 0,
}
for bucket_label, _, _ in bucket_defs
}
with request_metrics_path.open("w", encoding="utf-8", newline="") as request_metrics_handle, features_path.open(
"r", encoding="utf-8"
) as features_handle:
feature_reader = csv.DictReader(features_handle)
source_iter = _read_source_minimal(source_path)
release_iter = _iter_release_rows(release_path)
if limit is not None:
feature_reader = islice(feature_reader, limit)
source_iter = islice(source_iter, limit)
release_iter = islice(release_iter, limit)
request_metrics_writer = None
progress = tqdm(
total=total_requests,
desc="Build details",
unit="req",
dynamic_ncols=True,
disable=not show_progress,
)
last_progress_refresh_at = started_at
try:
for source_row, feature_row, release_row in zip(source_iter, feature_reader, release_iter):
request_id = source_row["request_id"]
session_id = source_row["session_id"]
ready_ms = int(source_row["request_ready_time_ms"])
end_ms = int(source_row["request_end_time_ms"])
tool_names = source_row["declared_tool_names"]
raw_messages = source_row["raw_messages"]
hash_ids = release_row["hash_ids"]
release_input_length = int(release_row["input_length"])
release_output_length = int(release_row["output_length"])
feature_input_tokens = int(feature_row["input_tokens"])
feature_output_tokens = int(feature_row["output_tokens"])
if feature_input_tokens != release_input_length:
raise ValueError(
f"release/raw mismatch at request {request_id}: "
f"features.input_tokens={feature_row['input_tokens']} vs release.input_length={release_input_length}"
)
if feature_output_tokens != release_output_length:
raise ValueError(
f"release/raw mismatch at request {request_id}: "
f"features.output_tokens={feature_row['output_tokens']} vs release.output_length={release_output_length}"
)
input_tokens = feature_input_tokens
bucket_label = assign_input_length_bucket(input_tokens, bucket_defs)
bucket_store = bucket_stores[bucket_label]
(
prefix_match_blocks,
global_source_request_id,
bucketed_prefix_match_blocks,
bucketed_source_request_id,
) = _compute_prefix_hits(
store,
bucket_store,
hash_ids=hash_ids,
request_id=request_id,
ready_ms=ready_ms,
reuse_gap_counts=reuse_gap_counts,
)
prompt_block_count = len(hash_ids)
theoretical_prefix_hit_ratio = prefix_match_blocks / prompt_block_count if prompt_block_count else 0.0
bucketed_theoretical_prefix_hit_ratio = (
bucketed_prefix_match_blocks / prompt_block_count if prompt_block_count else 0.0
)
previous_session_state = session_last.get(session_id)
trigger = _classify_trigger(
previous_session_state["raw_messages"] if previous_session_state is not None else [],
raw_messages,
)
feature_row["request_ready_time_ms"] = ready_ms
feature_row["request_end_time_ms"] = end_ms
feature_row["turn"] = release_row["turn"]
feature_row["chat_id"] = release_row["chat_id"]
feature_row["parent_chat_id"] = release_row["parent_chat_id"]
feature_row["trigger_group"] = trigger["trigger_group"]
feature_row["trigger_detail"] = trigger["trigger_detail"]
feature_row["first_new_role"] = trigger["first_new_role"]
feature_row["common_prefix_message_count"] = trigger["common_prefix_message_count"]
feature_row["appended_message_count"] = trigger["appended_message_count"]
feature_row["input_length_bucket"] = bucket_label
feature_row["declared_tool_names"] = ";".join(tool_names)
feature_row["theoretical_prompt_unit_length"] = prompt_block_count
feature_row["theoretical_prefix_hit_blocks"] = prefix_match_blocks
feature_row["theoretical_prefix_hit_ratio"] = theoretical_prefix_hit_ratio
feature_row["theoretical_source_request_id"] = global_source_request_id
feature_row["bucketed_theoretical_prefix_hit_blocks"] = bucketed_prefix_match_blocks
feature_row["bucketed_theoretical_prefix_hit_ratio"] = bucketed_theoretical_prefix_hit_ratio
feature_row["bucketed_theoretical_source_request_id"] = bucketed_source_request_id
feature_row["theoretical_bucket_boundary_loss_blocks"] = max(
prefix_match_blocks - bucketed_prefix_match_blocks,
0,
)
feature_row["theoretical_bucket_boundary_loss_ratio"] = (
feature_row["theoretical_bucket_boundary_loss_blocks"] / prompt_block_count
if prompt_block_count
else 0.0
)
if request_metrics_writer is None:
request_metrics_writer = csv.DictWriter(
request_metrics_handle,
fieldnames=list(feature_row.keys()),
)
request_metrics_writer.writeheader()
request_metrics_writer.writerow(feature_row)
chat_id = release_row["chat_id"]
parent_chat_id = release_row["parent_chat_id"]
if parent_chat_id != -1:
parent_state = chat_state_for_children.get(parent_chat_id)
if parent_state is not None:
shared_prefix_units = 0
for parent_block_id, child_block_id in zip(parent_state["hash_ids"], hash_ids):
if parent_block_id != child_block_id:
break
shared_prefix_units += 1
bucket_totals = session_bucket_totals[bucket_label]
bucket_totals["edge_count"] += 1
if shared_prefix_units > 0:
bucket_totals["reusable_edge_count"] += 1
if parent_state["bucket_label"] != bucket_label:
bucket_totals["cross_bucket_edge_count"] += 1
bucket_totals["cross_bucket_shared_prefix_units_sum"] += shared_prefix_units
bucket_totals["shared_prefix_units_sum"] += shared_prefix_units
remaining_children = child_ref_counts.get(parent_chat_id, 0) - 1
if remaining_children > 0:
child_ref_counts[parent_chat_id] = remaining_children
else:
child_ref_counts.pop(parent_chat_id, None)
chat_state_for_children.pop(parent_chat_id, None)
if chat_id != -1 and child_ref_counts.get(chat_id, 0) > 0:
chat_state_for_children[chat_id] = {
"bucket_label": bucket_label,
"hash_ids": hash_ids,
}
total_prompt_blocks += prompt_block_count
total_global_reused_blocks += prefix_match_blocks
bucket_reused_block_totals[bucket_label] += bucketed_prefix_match_blocks
session_last[session_id] = {
"request_id": request_id,
"request_ready_time_ms": ready_ms,
"request_end_time_ms": end_ms,
"raw_messages": raw_messages,
}
processed_requests += 1
last_request_id = request_id
progress.update(1)
now = time.monotonic()
should_refresh_progress = (
processed_requests == 1
or processed_requests % PROGRESS_REFRESH_INTERVAL_REQ == 0
or now - last_progress_refresh_at >= PROGRESS_REFRESH_INTERVAL_S
or processed_requests == total_requests
)
if should_refresh_progress:
fraction_done = progress.n / progress.total if progress.total else 0.0
elapsed_s = max(now - started_at, 1e-9)
req_per_s = progress.n / elapsed_s
eta_s = ((progress.total - progress.n) / req_per_s) if req_per_s > 0 and progress.total else 0.0
total_bucket_state_count = sum(len(each_store) for each_store in bucket_stores.values())
postfix, peak_rss_mb = _progress_postfix(
process,
peak_rss_mb,
fraction_done,
req_s=f"{req_per_s:.1f}",
eta=_format_duration(eta_s),
blocks=len(store),
bucket_blocks=total_bucket_state_count,
sessions=len(session_last),
)
progress.set_postfix(postfix)
last_progress_refresh_at = now
if processed_requests and now >= next_progress_flush_at:
current_rss_mb = process.memory_info().rss / (1024 * 1024)
peak_rss_mb = max(peak_rss_mb, current_rss_mb)
est_peak_mb = _estimate_peak_rss_mb(
current_rss_mb,
peak_rss_mb,
(processed_requests / total_requests) if total_requests else 1.0,
)
_write_progress_state(
progress_state_path,
total_requests=total_requests,
processed_requests=processed_requests,
started_at=started_at,
current_rss_mb=current_rss_mb,
peak_rss_mb=peak_rss_mb,
est_peak_mb=est_peak_mb,
source_path=f"{source_path} + {release_path}",
features_path=features_path,
last_request_id=last_request_id,
block_state_count=len(store),
bucket_state_count=total_bucket_state_count,
)
next_progress_flush_at = now + PROGRESS_FLUSH_INTERVAL_S
finally:
progress.close()
theoretical_block_reuse_gaps_path = details_dir / "theoretical_block_reuse_gaps.csv"
write_csv(
theoretical_block_reuse_gaps_path,
[
{"reuse_gap_ms": reuse_gap_ms, "count": count}
for reuse_gap_ms, count in sorted(reuse_gap_counts.items())
],
fieldnames=["reuse_gap_ms", "count"],
)
theoretical_block_lifetimes_path = details_dir / "theoretical_block_lifetimes.csv"
alive_block_events = Counter()
block_lifetime_rows = []
for (
block_hash,
first_seen_ms,
last_seen_ms,
last_reuse_ms,
first_request_id,
last_request_id_for_block,
last_reuse_request_id,
reuse_count,
) in store.iter_blocks():
lifecycle_end_ms = last_reuse_ms if reuse_count > 0 else first_seen_ms
lifetime_ms = max(lifecycle_end_ms - first_seen_ms, 0)
block_lifetime_rows.append(
{
"hash": block_hash,
"first_request_id": first_request_id,
"last_request_id": last_request_id_for_block,
"first_seen_ms": first_seen_ms,
"last_seen_ms": last_seen_ms,
"last_reuse_ms": last_reuse_ms,
"last_reuse_request_id": last_reuse_request_id,
"reuse_count": reuse_count,
"lifetime_ms": lifetime_ms,
"span_end_ms": lifecycle_end_ms,
"reused": 1 if reuse_count > 0 else 0,
}
)
alive_block_events[first_seen_ms] += 1
alive_block_events[lifecycle_end_ms + 1] -= 1
write_csv(theoretical_block_lifetimes_path, block_lifetime_rows)
alive_block_timeline_summary, alive_block_timeline_rows = _build_alive_block_timeline_from_events(alive_block_events)
theoretical_alive_block_timeline_path = details_dir / "theoretical_alive_block_timeline.csv"
write_csv(theoretical_alive_block_timeline_path, alive_block_timeline_rows)
session_bucket_boundary_rows = []
for bucket_label, _, _ in bucket_defs:
bucket_totals = session_bucket_totals[bucket_label]
total_bucket_reused_blocks = bucket_reused_block_totals[bucket_label]
session_bucket_boundary_rows.append(
{
"bucket": bucket_label,
"edge_count": bucket_totals["edge_count"],
"reusable_edge_count": bucket_totals["reusable_edge_count"],
"cross_bucket_edge_count": bucket_totals["cross_bucket_edge_count"],
"cross_bucket_edge_fraction": (
bucket_totals["cross_bucket_edge_count"] / bucket_totals["edge_count"]
if bucket_totals["edge_count"]
else 0.0
),
"shared_prefix_units_sum": bucket_totals["shared_prefix_units_sum"],
"cross_bucket_shared_prefix_units_sum": bucket_totals["cross_bucket_shared_prefix_units_sum"],
"cross_bucket_shared_prefix_unit_fraction": (
bucket_totals["cross_bucket_shared_prefix_units_sum"] / bucket_totals["shared_prefix_units_sum"]
if bucket_totals["shared_prefix_units_sum"]
else 0.0
),
"bucket_total_reused_blocks": total_bucket_reused_blocks,
"reduced_reused_blocks_ratio": (
bucket_totals["cross_bucket_shared_prefix_units_sum"] / total_bucket_reused_blocks
if total_bucket_reused_blocks
else 0.0
),
}
)
session_bucket_boundary_miss_path = details_dir / "session_bucket_boundary_miss.csv"
write_csv(session_bucket_boundary_miss_path, session_bucket_boundary_rows)
details_summary_path = details_dir / DETAILS_SUMMARY_FILENAME
details_summary = {
"schema_version": 3,
"request_count": total_requests,
"figure_count": 13,
"cache_analysis_mode": "release_hash_ids",
"release_path": str(release_path),
"bucket_definition": {"buckets": _bucket_definition_rows(bucket_defs)},
"global_prompt_blocks": total_prompt_blocks,
"global_reused_blocks": total_global_reused_blocks,
"global_reuse_ratio": (total_global_reused_blocks / total_prompt_blocks) if total_prompt_blocks else 0.0,
"alive_block_timeline_summary": alive_block_timeline_summary,
"detail_files": [
"request_metrics.csv",
"theoretical_block_reuse_gaps.csv",
"theoretical_block_lifetimes.csv",
"theoretical_alive_block_timeline.csv",
"session_bucket_boundary_miss.csv",
DETAILS_SUMMARY_FILENAME,
"progress.json",
],
}
details_summary_path.write_text(json.dumps(details_summary, ensure_ascii=False, indent=2), encoding="utf-8")
current_rss_mb = process.memory_info().rss / (1024 * 1024)
peak_rss_mb = max(peak_rss_mb, current_rss_mb)
est_peak_mb = _estimate_peak_rss_mb(current_rss_mb, peak_rss_mb, 1.0)
_write_progress_state(
progress_state_path,
total_requests=total_requests,
processed_requests=processed_requests,
started_at=started_at,
current_rss_mb=current_rss_mb,
peak_rss_mb=peak_rss_mb,
est_peak_mb=est_peak_mb,
source_path=f"{source_path} + {release_path}",
features_path=features_path,
last_request_id=last_request_id,
block_state_count=len(store),
bucket_state_count=sum(len(bucket_store) for bucket_store in bucket_stores.values()),
)
return {
"details_dir": details_dir,
"progress": progress_state_path,
"request_metrics": request_metrics_path,
"theoretical_block_reuse_gaps": theoretical_block_reuse_gaps_path,
"theoretical_block_lifetimes": theoretical_block_lifetimes_path,
"theoretical_alive_block_timeline": theoretical_alive_block_timeline_path,
"session_bucket_boundary_miss": session_bucket_boundary_miss_path,
"details_summary": details_summary_path,
}

3264
trace_analyzer/study.py Normal file

File diff suppressed because it is too large


@@ -0,0 +1,12 @@
"""Public formatter entrypoints for raw trace inputs."""
SCHEMA_VERSION = "2026.04.21"
def main(argv=None):
from .cli import main as cli_main
return cli_main(argv)
__all__ = ["main"]


@@ -0,0 +1,5 @@
from .cli import main
if __name__ == "__main__":
raise SystemExit(main())

200
trace_formatter/cli.py Normal file

@@ -0,0 +1,200 @@
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from .formatting import derive_output_label, discover_source_files, export_release_ready_trace, format_and_sort_trace
from .time_windows import infer_time_window
def _default_output_root(input_path: str | Path) -> Path:
resolved = Path(input_path)
if resolved.is_dir():
return resolved.parent / f"{resolved.name}-formatted"
if resolved.parent.name.startswith("trace-"):
return resolved.parent.parent / f"{resolved.parent.name}-formatted"
return resolved.parent / f"{resolved.stem}-formatted"
def _resolve_raw_output_path(args: argparse.Namespace) -> Path:
if args.output:
explicit = Path(args.output)
return explicit if explicit.stem.endswith("-raw") else explicit.with_name(f"{explicit.stem}-raw.jsonl")
output_root = Path(args.output_root) if args.output_root else _default_output_root(args.input)
source_files = discover_source_files(args.input)
    should_infer_window = (args.start_time and args.end_time) or not args.no_truncate_to_window
    time_window = None
    if should_infer_window:
        time_window = infer_time_window(
            source_files,
            start_time=None if args.no_truncate_to_window else args.start_time,
            end_time=None if args.no_truncate_to_window else args.end_time,
        )
label = derive_output_label(args.input, time_window=time_window)
return output_root / f"{label}-raw.jsonl"
def _resolve_release_output_path(args: argparse.Namespace) -> Path:
if args.output:
explicit = Path(args.output)
return explicit if not explicit.stem.endswith("-raw") else explicit.with_name(f"{explicit.stem[:-4]}.jsonl")
input_path = Path(args.input)
if input_path.stem.endswith("-raw"):
return input_path.with_name(f"{input_path.stem[:-4]}.jsonl")
return input_path.with_suffix(".jsonl")
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Format raw trace shards into one time-sorted trace jsonl.")
subparsers = parser.add_subparsers(dest="command", required=True)
format_parser = subparsers.add_parser(
"format",
help="Format a raw trace directory or one .jsonl/.jsonl.zst file into one unified *-raw jsonl.",
)
format_parser.add_argument("input", help="Raw trace directory or one .jsonl/.jsonl.zst file.")
format_parser.add_argument(
"--output",
default=None,
help="Explicit raw output jsonl path. Defaults to a sibling trace-*-formatted/<label>-raw.jsonl path.",
)
format_parser.add_argument(
"--output-root",
default=None,
help="Base directory used when --output is omitted. Defaults to a sibling trace-*-formatted directory.",
)
format_parser.add_argument("--tmp-dir", default=None)
format_parser.add_argument("--chunk-bytes", type=int, default=128 * 1024 * 1024)
format_parser.add_argument(
"--log-file",
default=None,
help="Optional log file path. When set, progress bars are mirrored to this file.",
)
format_parser.add_argument(
"--no-progress",
action="store_true",
help="Disable progress bars during formatting.",
)
format_parser.add_argument(
"--start-time",
default=None,
help="Explicit UTC+8 start time for ready-time truncation, e.g. '2026-04-17 15:00:00.000'.",
)
format_parser.add_argument(
"--end-time",
default=None,
help="Explicit UTC+8 end time for ready-time truncation, e.g. '2026-04-17 17:00:00.000'.",
)
format_parser.add_argument(
"--no-truncate-to-window",
action="store_true",
help="Disable ready-time window truncation inferred from shard names or --start-time/--end-time.",
)
format_parser.add_argument(
"--build-release",
action="store_true",
help="Also build the open-source-ready release jsonl after formatting the raw output.",
)
format_parser.add_argument(
"--release-output",
default=None,
help="Explicit release jsonl path used only with --build-release.",
)
format_parser.add_argument(
"--release-jobs",
type=int,
default=min(os.cpu_count() or 1, 16),
help="Worker processes used by release building when --build-release is enabled.",
)
release_parser = subparsers.add_parser(
"build-release",
help="Build the open-source-ready release jsonl from one formatted *-raw trace.",
)
release_parser.add_argument("input", help="Path to the formatted *-raw jsonl.")
release_parser.add_argument(
"--output",
default=None,
help="Explicit release output jsonl path. Defaults to the sibling path without the -raw suffix.",
)
release_parser.add_argument("--tmp-dir", default=None)
release_parser.add_argument("--block-size", type=int, default=512)
release_parser.add_argument(
"--jobs",
type=int,
default=min(os.cpu_count() or 1, 16),
help="Worker processes used for release tokenization shards.",
)
release_parser.add_argument(
"--log-file",
default=None,
help="Optional log file path. When set, progress bars are mirrored to this file.",
)
release_parser.add_argument(
"--no-progress",
action="store_true",
help="Disable progress bars during release building.",
)
return parser
def main(argv: list[str] | None = None) -> int:
args = build_parser().parse_args(argv)
if args.command == "format":
raw_output_path = _resolve_raw_output_path(args)
result = format_and_sort_trace(
input_dir=args.input,
output_path=raw_output_path,
tmp_dir=args.tmp_dir,
chunk_bytes=args.chunk_bytes,
start_time=args.start_time,
end_time=args.end_time,
truncate_to_window=not args.no_truncate_to_window,
show_progress=not args.no_progress,
log_file=args.log_file,
)
payload = {
"input_path": str(Path(args.input)),
"formatted_name": raw_output_path.stem,
**result,
}
if args.build_release:
release_output_path = Path(args.release_output) if args.release_output else raw_output_path.with_name(
f"{raw_output_path.stem[:-4]}.jsonl"
)
release_result = export_release_ready_trace(
raw_input_path=raw_output_path,
release_output_path=release_output_path,
tmp_dir=args.tmp_dir,
jobs=args.release_jobs,
show_progress=not args.no_progress,
log_file=args.log_file,
)
payload.update(release_result)
print(json.dumps(payload, ensure_ascii=False, indent=2))
return 0
if args.command == "build-release":
release_output_path = _resolve_release_output_path(args)
result = export_release_ready_trace(
raw_input_path=args.input,
release_output_path=release_output_path,
tmp_dir=args.tmp_dir,
block_size=args.block_size,
jobs=args.jobs,
show_progress=not args.no_progress,
log_file=args.log_file,
)
payload = {
"input_path": str(Path(args.input)),
"formatted_name": release_output_path.stem,
**result,
}
print(json.dumps(payload, ensure_ascii=False, indent=2))
return 0
raise ValueError(f"Unsupported command: {args.command}")
if __name__ == "__main__":
raise SystemExit(main())
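For reference, the `--build-release` flag chains Step 1 and Step 2 in a single invocation (directory name illustrative):

```bash
python -m trace_formatter format trace-glm5 --build-release --release-jobs 16
```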


@@ -0,0 +1,849 @@
from __future__ import annotations
import hashlib
import heapq
import json
import os
import shutil
import sys
import subprocess
import tempfile
from array import array
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import contextmanager, nullcontext
from dataclasses import asdict
from pathlib import Path
from typing import Iterator, TextIO
from trace_analyzer.helpers import parse_jsonish, safe_int
from tokenizers import Tokenizer
from tqdm.auto import tqdm
from trace_model_meta import infer_model_family_from_request_model, resolve_tokenizer_path
from . import SCHEMA_VERSION
from .raw_parser import get_raw_adapter
from .sessionization import (
LogicalSessionizer,
build_message_fingerprints,
build_sequence_hashes,
decode_prefix_hashes,
decode_roles,
encode_prefix_hashes,
encode_roles,
extract_user_id,
)
from .time_windows import infer_time_offset_ms, infer_time_window, parse_time_to_ms
def _is_supported_trace_file(path: Path) -> bool:
return path.name.endswith(".jsonl") or path.name.endswith(".jsonl.zst")
def derive_trace_name(input_path: str | Path) -> str:
resolved = Path(input_path)
if resolved.is_dir():
return resolved.name
name = resolved.name
if name.endswith(".jsonl.zst"):
return name[: -len(".jsonl.zst")]
if name.endswith(".jsonl"):
return name[: -len(".jsonl")]
return resolved.stem
def default_formatted_name(input_path: str | Path) -> str:
base_name = derive_trace_name(input_path)
return base_name if base_name.endswith("-formatted") else f"{base_name}-formatted"
def derive_output_label(input_path: str | Path, *, time_window=None) -> str:
if time_window is not None and getattr(time_window, "label", None):
return str(time_window.label)
return derive_trace_name(input_path)
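The label helpers strip trace-file suffixes and fall back to the directory or stem name (paths illustrative):

```python
derive_trace_name("trace-qwen3-coder/0419-1500-1700.jsonl")   # '0419-1500-1700'
derive_trace_name("trace-glm5/0417-1500-1530.jsonl.zst")      # '0417-1500-1530'
derive_output_label("trace-glm5")  # 'trace-glm5' when no time window is inferred
```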
def discover_source_files(input_dir: str | Path) -> list[Path]:
root = Path(input_dir)
if not root.exists():
raise FileNotFoundError(f"Input path does not exist: {root}")
if root.is_file():
if not _is_supported_trace_file(root):
raise FileNotFoundError(f"Input file must be .jsonl or .jsonl.zst: {root}")
return [root]
preferred: dict[str, Path] = {}
for path in sorted(root.iterdir()):
if not path.is_file():
continue
if path.name.endswith(".jsonl.zst"):
stem = path.name[: -len(".jsonl.zst")]
preferred[stem] = path
elif path.name.endswith(".jsonl"):
preferred.setdefault(path.stem, path)
files = [preferred[key] for key in sorted(preferred)]
if not files:
raise FileNotFoundError(f"No .jsonl or .jsonl.zst files found under {root}")
return files
@contextmanager
def open_trace_text(path: str | Path) -> Iterator[TextIO]:
resolved = Path(path)
if resolved.suffix == ".zst":
proc = subprocess.Popen(
["zstdcat", str(resolved)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
)
if proc.stdout is None:
raise RuntimeError(f"Failed to stream {resolved}")
try:
yield proc.stdout
finally:
stdout = proc.stdout
stdout.close()
stderr = proc.stderr.read() if proc.stderr else ""
return_code = proc.wait()
if return_code != 0:
raise RuntimeError(f"zstdcat failed for {resolved}: {stderr.strip()}")
return
with resolved.open("r", encoding="utf-8") as handle:
yield handle
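Both plain and zstd-compressed shards go through the same context manager; `.zst` inputs are streamed via an external `zstdcat` process, which must be on PATH. A hypothetical read loop:

```python
import json

with open_trace_text("trace-glm5/0417-1500-1530.jsonl.zst") as handle:
    for line in handle:
        row = json.loads(line)  # one raw trace record per line
        break
```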
def _normalize_time_ms(*, raw_time_ms: int, wall_clock_ms: int, time_offset_ms: int) -> int:
if not raw_time_ms:
return wall_clock_ms - time_offset_ms if wall_clock_ms and time_offset_ms else wall_clock_ms
if not wall_clock_ms or not time_offset_ms:
return raw_time_ms
delta_ms = wall_clock_ms - raw_time_ms
tolerance_ms = 10 * 60 * 1000
if abs(delta_ms - time_offset_ms) <= tolerance_ms:
return raw_time_ms
if abs(delta_ms) <= tolerance_ms:
return raw_time_ms - time_offset_ms
return raw_time_ms
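Illustrative arithmetic for the ±10-minute tolerance above, using the UTC+8 offset mentioned in the README (values made up):

```python
offset_ms = 8 * 3600 * 1000  # UTC+8 wall clock vs UTC-like raw timestamps
raw_ms = 1_760_000_000_000

# Wall clock runs ahead of the raw field by roughly the offset: raw is kept as-is.
_normalize_time_ms(raw_time_ms=raw_ms, wall_clock_ms=raw_ms + offset_ms, time_offset_ms=offset_ms)  # raw_ms

# The raw field itself carries wall-clock time: it is shifted back by the offset.
_normalize_time_ms(raw_time_ms=raw_ms, wall_clock_ms=raw_ms, time_offset_ms=offset_ms)  # raw_ms - offset_ms
```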
def _extract_sort_time_ms(raw: dict, attributes: dict, *, time_offset_ms: int = 0) -> int:
wall_clock_ms = parse_time_to_ms(str(raw.get("time", ""))) if raw.get("time") else 0
ready_ms = safe_int(attributes.get("x-dashscope-inner-requestreadytime"))
if ready_ms:
return _normalize_time_ms(raw_time_ms=ready_ms, wall_clock_ms=wall_clock_ms, time_offset_ms=time_offset_ms)
if wall_clock_ms:
return _normalize_time_ms(raw_time_ms=0, wall_clock_ms=wall_clock_ms, time_offset_ms=time_offset_ms)
raw_epoch_seconds = safe_int(raw.get("__time__"))
if raw_epoch_seconds:
return raw_epoch_seconds * 1000
return 0
def _extract_response_message(response_payload: dict) -> dict:
output = response_payload.get("output", {}) if isinstance(response_payload, dict) else {}
if not isinstance(output, dict):
return {}
choices = output.get("choices", [])
if not isinstance(choices, list) or not choices:
return {}
message = choices[0].get("message", {})
return message if isinstance(message, dict) else {}
def _extract_usage(response_payload: dict) -> dict:
usage_payload = response_payload.get("usage", {}) if isinstance(response_payload, dict) else {}
output_payload = response_payload.get("output", {}) if isinstance(response_payload, dict) else {}
if (not isinstance(usage_payload, dict) or not usage_payload) and isinstance(output_payload, dict):
usage_payload = output_payload.get("usage", {})
output_details = parse_jsonish(usage_payload.get("output_tokens_details", {}))
prompt_details = parse_jsonish(usage_payload.get("prompt_tokens_details", {}))
return {
"input_tokens": safe_int(usage_payload.get("input_tokens", usage_payload.get("prompt_tokens"))),
"output_tokens": safe_int(usage_payload.get("output_tokens", usage_payload.get("completion_tokens"))),
"total_tokens": safe_int(usage_payload.get("total_tokens")),
"reasoning_tokens": safe_int(
output_details.get("reasoning_tokens") if isinstance(output_details, dict) else 0
),
"cached_tokens": safe_int(
prompt_details.get("cached_tokens") if isinstance(prompt_details, dict) else 0
),
}
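The extractor accepts both OpenAI-style (`prompt_tokens`/`completion_tokens`) and dashscope-style (`input_tokens`/`output_tokens`) usage payloads. A sketch, assuming `safe_int(None) == 0` and that `parse_jsonish` passes dicts through unchanged:

```python
payload = {
    "usage": {
        "prompt_tokens": 1200,
        "completion_tokens": 80,
        "total_tokens": 1280,
        "prompt_tokens_details": {"cached_tokens": 1024},
    }
}
_extract_usage(payload)
# {'input_tokens': 1200, 'output_tokens': 80, 'total_tokens': 1280,
#  'reasoning_tokens': 0, 'cached_tokens': 1024}
```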
def _extract_request_components(raw: dict) -> tuple[dict, dict | None, dict, dict, list]:
request_params = parse_jsonish(raw.get("request_params", {}))
response_params = parse_jsonish(raw.get("response_params", {}))
request_header = request_params.get("header", {}) if isinstance(request_params, dict) else {}
request_attributes = request_header.get("attributes", {}) if isinstance(request_header, dict) else {}
request_payload = request_params.get("payload", {}) if isinstance(request_params, dict) else {}
request_input = request_payload.get("input", {}) if isinstance(request_payload, dict) else {}
messages = request_input.get("messages", [])
return request_params, response_params, request_payload, request_attributes, messages
def _build_unified_row_from_components(
raw: dict,
*,
request_params: dict,
response_params: dict | None,
request_payload: dict,
request_attributes: dict,
messages: list,
source_file: str,
source_line: int,
time_offset_ms: int = 0,
) -> dict:
adapter = get_raw_adapter(raw)
request_parameters = request_payload.get("parameters", {}) if isinstance(request_payload, dict) else {}
response_payload = response_params.get("payload", {}) if isinstance(response_params, dict) else {}
response_header = response_params.get("header", {}) if isinstance(response_params, dict) else {}
response_attributes = response_header.get("attributes", {}) if isinstance(response_header, dict) else {}
sort_time_ms = _extract_sort_time_ms(raw, request_attributes, time_offset_ms=time_offset_ms)
total_cost_time_ms = safe_int(raw.get("total_cost_time"))
request_end_time_ms = sort_time_ms + total_cost_time_ms if sort_time_ms else total_cost_time_ms
declared_tools = request_parameters.get("tools", [])
canonical_prompt = adapter.build_canonical_prompt(request_payload)
usage = _extract_usage(response_payload)
message_events = [asdict(adapter.parse_message(message)) for message in messages if isinstance(message, dict)]
tool_specs = [asdict(adapter.parse_tool(tool)) for tool in declared_tools if isinstance(tool, dict)]
role_sequence = [event["role"] for event in message_events]
user_id = extract_user_id(request_params)
model_family = infer_model_family_from_request_model(raw.get("request_model")) or "glm5"
raw_messages = [message for message in messages if isinstance(message, dict)]
backend_first_request_time_ms = safe_int(response_attributes.get("x-ds-backend-first-request-time"))
backend_first_response_time_ms = safe_int(response_attributes.get("x-ds-backend-first-response-time"))
return {
"schema_version": SCHEMA_VERSION,
"sort_time_ms": sort_time_ms,
"meta": {
"model_family": model_family,
"request_id": str(raw.get("request_id", "")),
"session_id": "",
"raw_session_id": str(raw.get("session_id", "")),
"user_id": user_id,
"parent_request_id": "",
"parent_chat_id": -1,
"chat_id": -1,
"turn": 0,
"request_model": str(raw.get("request_model", "")),
"time": str(raw.get("time", "")),
"status_code": str(raw.get("status_code", "")),
"status_name": str(raw.get("status_name", "")),
"request_ready_time_ms": sort_time_ms,
"request_end_time_ms": request_end_time_ms,
"total_cost_time_ms": total_cost_time_ms,
"backend_first_request_time_ms": backend_first_request_time_ms,
"backend_first_response_time_ms": backend_first_response_time_ms,
},
"role_sequence": role_sequence,
"message_events": message_events,
"declared_tools": tool_specs,
"usage": usage,
"canonical_prompt": canonical_prompt,
"response_message": _extract_response_message(response_payload),
"raw_messages": raw_messages,
}
def _has_empty_response_params(raw: dict, response_params) -> bool:
raw_value = raw.get("response_params")
if raw_value is None:
return True
if isinstance(raw_value, str) and raw_value.strip().lower() in {"", "none", "null"}:
return True
return response_params is None or (isinstance(response_params, dict) and not response_params)
def build_unified_row(raw: dict, *, source_file: str, source_line: int, time_offset_ms: int = 0) -> dict:
request_params, response_params, request_payload, request_attributes, messages = _extract_request_components(raw)
return _build_unified_row_from_components(
raw,
request_params=request_params,
response_params=response_params,
request_payload=request_payload,
request_attributes=request_attributes,
messages=messages,
source_file=source_file,
source_line=source_line,
time_offset_ms=time_offset_ms,
)
def _build_unified_row_payload(
raw: dict,
*,
source_file: str,
source_line: int,
time_offset_ms: int = 0,
) -> tuple[int, str, str, str, str, str]:
request_params, response_params, request_payload, request_attributes, messages = _extract_request_components(raw)
return _build_unified_row_payload_from_components(
raw,
request_params=request_params,
response_params=response_params,
request_payload=request_payload,
request_attributes=request_attributes,
messages=messages,
source_file=source_file,
source_line=source_line,
time_offset_ms=time_offset_ms,
)
def _build_unified_row_payload_from_components(
raw: dict,
*,
request_params: dict,
response_params: dict | None,
request_payload: dict,
request_attributes: dict,
messages: list,
source_file: str,
source_line: int,
time_offset_ms: int = 0,
) -> tuple[int, str, str, str, str, str]:
normalized_messages = [message for message in messages if isinstance(message, dict)]
row = _build_unified_row_from_components(
raw,
request_params=request_params,
response_params=response_params,
request_payload=request_payload,
request_attributes=request_attributes,
messages=messages,
source_file=source_file,
source_line=source_line,
time_offset_ms=time_offset_ms,
)
message_fingerprints = build_message_fingerprints(normalized_messages)
return (
safe_int(row.get("sort_time_ms")),
str(row["meta"].get("user_id", "")),
str(row["meta"].get("request_id", "")),
encode_prefix_hashes(build_sequence_hashes(message_fingerprints)),
encode_roles([str(message.get("role", "unknown")) for message in normalized_messages]),  # "unknown" mirrors stable_message_fingerprint; an empty role would be dropped by decode_roles and desync the prefix flags
json.dumps(row, ensure_ascii=False, separators=(",", ":")),
)
def _write_chunk(rows: list[tuple[int, int, str, str, str, str, str]], tmp_dir: Path, chunk_index: int) -> Path:
rows.sort(key=lambda item: (item[0], item[1]))
path = tmp_dir / f"chunk_{chunk_index:05d}.jsonl"
with path.open("w", encoding="utf-8") as handle:
for sort_time_ms, seq, user_id, request_id, sequence_hashes, roles, row_json in rows:
handle.write(f"{sort_time_ms}\t{seq}\t{user_id}\t{request_id}\t{sequence_hashes}\t{roles}\t{row_json}\n")
return path
def _iter_chunk_rows(path: Path) -> Iterator[tuple[int, int, str, str, str, str, str]]:
with path.open("r", encoding="utf-8") as handle:
for line in handle:
sort_text, seq_text, user_id, request_id, sequence_hashes, roles, row_json = line.rstrip("\n").split("\t", 6)
yield int(sort_text), int(seq_text), user_id, request_id, sequence_hashes, roles, row_json
def _replace_json_field_once(row_json: str, *, key: str, value_text: str) -> str:
target = f'"{key}":""'
if target in row_json:
return row_json.replace(target, f'"{key}":{value_text}', 1)
numeric_target = None
if key in {"parent_chat_id", "chat_id"}:
numeric_target = f'"{key}":-1'
elif key == "turn":
numeric_target = f'"{key}":0'
if numeric_target and numeric_target in row_json:
return row_json.replace(numeric_target, f'"{key}":{value_text}', 1)
raise ValueError(f"Unable to patch {key} in formatted row json")
def _apply_session_assignment_to_row_json(row_json: str, assignment) -> str:
patched = _replace_json_field_once(
row_json,
key="session_id",
value_text=json.dumps(assignment.session_id, ensure_ascii=False),
)
patched = _replace_json_field_once(
patched,
key="parent_request_id",
value_text=json.dumps(assignment.parent_request_id, ensure_ascii=False),
)
patched = _replace_json_field_once(
patched,
key="parent_chat_id",
value_text=str(assignment.parent_chat_id),
)
patched = _replace_json_field_once(
patched,
key="chat_id",
value_text=str(assignment.chat_id),
)
return _replace_json_field_once(
patched,
key="turn",
value_text=str(assignment.turn),
)
class _TeeStream:
def __init__(self, *streams):
self._streams = [stream for stream in streams if stream is not None]
def write(self, data):
for stream in self._streams:
stream.write(data)
return len(data)
def flush(self):
for stream in self._streams:
stream.flush()
@contextmanager
def _open_progress_stream(log_file: str | Path | None):
if log_file is None:
yield sys.stderr
return
path = Path(log_file)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as handle:
yield _TeeStream(sys.stderr, handle)
def _block_digest(block: list[int]) -> bytes:
digest = hashlib.blake2b(digest_size=16)
digest.update(len(block).to_bytes(4, "little", signed=False))
digest.update(array("I", block).tobytes())
return digest.digest()
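# Sketch: token ids are split into fixed-size blocks and each block gets a
# 16-byte BLAKE2b digest, so identical blocks collide to the same hex id
# across requests regardless of where they appear in a prompt.
#
#     _block_digest([101, 102, 103]).hex()  # stable 32-char hex string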
def _load_release_tokenizer(model_family: str) -> Tokenizer:
resolved = Path(resolve_tokenizer_path(model_family=model_family))
tokenizer_file = resolved / "tokenizer.json" if resolved.is_dir() else resolved
return Tokenizer.from_file(str(tokenizer_file))
def _infer_window_start_ms_from_raw_rows(raw_path: Path) -> int:
with raw_path.open("r", encoding="utf-8") as handle:
for line in handle:
stripped = line.strip()
if not stripped:
continue
row = json.loads(stripped)
meta = row.get("meta", {}) if isinstance(row.get("meta", {}), dict) else {}
return safe_int(meta.get("request_ready_time_ms", row.get("sort_time_ms")))
return 0
def _compute_release_row_core(row: dict, *, base_ms: int) -> dict:
meta = row.get("meta", {}) if isinstance(row.get("meta", {}), dict) else {}
ready_ms = safe_int(meta.get("request_ready_time_ms", row.get("sort_time_ms")))
timestamp_ms = ready_ms - base_ms if ready_ms and base_ms else 0
return {
"chat_id": safe_int(meta.get("chat_id")),
"parent_chat_id": safe_int(meta.get("parent_chat_id", -1), default=-1),
"timestamp": round(timestamp_ms / 1000.0, 3),
"input_length": safe_int(row.get("usage", {}).get("input_tokens")),
"output_length": safe_int(row.get("usage", {}).get("output_tokens")),
"type": "coder",
"turn": safe_int(meta.get("turn")),
}
def _compute_release_segments(path: Path, jobs: int) -> list[tuple[int, int, int]]:
total_size = path.stat().st_size
if total_size <= 0:
return [(0, 0, 0)]
shard_count = max(1, min(jobs, total_size))
boundaries = [0]
with path.open("rb") as handle:
for index in range(1, shard_count):
target = total_size * index // shard_count
handle.seek(target)
handle.readline()
boundary = handle.tell()
if boundary > boundaries[-1]:
boundaries.append(boundary)
if boundaries[-1] != total_size:
boundaries.append(total_size)
segments: list[tuple[int, int, int]] = []
for index, (start, end) in enumerate(zip(boundaries, boundaries[1:])):
if end > start:
segments.append((index, start, end))
return segments or [(0, 0, total_size)]
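# Sketch of the intent (hypothetical sizes): boundaries are advanced to the
# next newline so every shard starts at a line boundary, e.g. a 100-byte
# file with jobs=4 might yield
#
#     [(0, 0, 26), (1, 26, 51), (2, 51, 77), (3, 77, 100)]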
def _build_release_shard(
*,
raw_input_path: str,
shard_index: int,
start_offset: int,
end_offset: int,
shard_output_path: str,
block_size: int,
base_ms: int,
) -> dict:
input_path = Path(raw_input_path)
output_path = Path(shard_output_path)
tokenizer_cache: dict[str, Tokenizer] = {}
row_count = 0
with input_path.open("rb") as source, output_path.open("w", encoding="utf-8") as destination:
source.seek(start_offset)
while source.tell() < end_offset:
line_bytes = source.readline()
if not line_bytes:
break
stripped = line_bytes.strip()
if not stripped:
continue
row = json.loads(stripped)
meta = row.get("meta", {}) if isinstance(row.get("meta", {}), dict) else {}
model_family = str(meta.get("model_family", "") or "glm5")
tokenizer = tokenizer_cache.get(model_family)
if tokenizer is None:
tokenizer = _load_release_tokenizer(model_family)
tokenizer_cache[model_family] = tokenizer
token_ids = tokenizer.encode(str(row.get("canonical_prompt", ""))).ids
digest_hexes = []
for index in range(0, len(token_ids), block_size):
block = token_ids[index:index + block_size]
digest_hexes.append(_block_digest(block).hex())
core = _compute_release_row_core(row, base_ms=base_ms)
destination.write(json.dumps(core, ensure_ascii=False, separators=(",", ":")))
destination.write("\t")
destination.write(",".join(digest_hexes))
destination.write("\n")
row_count += 1
return {
"shard_index": shard_index,
"shard_output_path": str(output_path),
"row_count": row_count,
"size_bytes": output_path.stat().st_size if output_path.exists() else 0,
}
def export_release_ready_trace(
*,
raw_input_path: str | Path,
release_output_path: str | Path,
window_start_ms: int | None = None,
block_size: int = 512,
jobs: int | None = None,
tmp_dir: str | Path | None = None,
show_progress: bool = False,
progress_stream=None,
log_file: str | Path | None = None,
) -> dict:
if progress_stream is None:
# Re-enter with an owned stream so progress output reaches both stderr and log_file.
with _open_progress_stream(log_file) as owned_progress_stream:
return export_release_ready_trace(
raw_input_path=raw_input_path,
release_output_path=release_output_path,
window_start_ms=window_start_ms,
block_size=block_size,
jobs=jobs,
tmp_dir=tmp_dir,
show_progress=show_progress,
progress_stream=owned_progress_stream,
)
input_path = Path(raw_input_path)
release_destination = Path(release_output_path)
release_destination.parent.mkdir(parents=True, exist_ok=True)
requested_jobs = jobs if jobs is not None else min(os.cpu_count() or 1, 16)
shard_jobs = max(1, requested_jobs)
base_ms = window_start_ms or _infer_window_start_ms_from_raw_rows(input_path)
segments = _compute_release_segments(input_path, shard_jobs)
next_block_id = 0
block_ids_by_digest: dict[str, int] = {}
row_count = 0
with tempfile.TemporaryDirectory(dir=tmp_dir) as temp_root:
shard_root = Path(temp_root) / "release-shards"
shard_root.mkdir(parents=True, exist_ok=True)
shard_specs = [
{
"raw_input_path": str(input_path),
"shard_index": shard_index,
"start_offset": start_offset,
"end_offset": end_offset,
"shard_output_path": str(shard_root / f"shard_{shard_index:05d}.jsonl"),
"block_size": block_size,
"base_ms": base_ms,
}
for shard_index, start_offset, end_offset in segments
]
shard_progress = tqdm(
total=len(shard_specs),
desc="Build release shards",
unit="shard",
dynamic_ncols=True,
file=progress_stream or sys.stderr,
disable=not show_progress,
)
shard_results: list[dict] = []
try:
if len(shard_specs) == 1:
shard_results.append(_build_release_shard(**shard_specs[0]))
if show_progress:
shard_progress.update(1)
else:
with ProcessPoolExecutor(max_workers=len(shard_specs)) as executor:
futures = [executor.submit(_build_release_shard, **spec) for spec in shard_specs]
for future in as_completed(futures):
shard_results.append(future.result())
if show_progress:
shard_progress.update(1)
finally:
if show_progress:
shard_progress.close()
shard_results.sort(key=lambda item: item["shard_index"])
shard_paths = [Path(item["shard_output_path"]) for item in shard_results]
finalize_progress = tqdm(
total=sum(item["size_bytes"] for item in shard_results),
desc="Finalize release trace",
unit="B",
unit_scale=True,
dynamic_ncols=True,
file=progress_stream or sys.stderr,
disable=not show_progress,
)
try:
with release_destination.open("w", encoding="utf-8") as destination:
for shard_path in shard_paths:
with shard_path.open("r", encoding="utf-8") as source:
for line in source:
stripped = line.rstrip("\n")
if not stripped:
if show_progress:
finalize_progress.update(len(line.encode("utf-8")))
continue
core_json, _, digest_text = stripped.partition("\t")
release_row = json.loads(core_json)
hash_ids = []
if digest_text:
for digest_hex in digest_text.split(","):
if not digest_hex:
continue
block_id = block_ids_by_digest.get(digest_hex)
if block_id is None:
block_id = next_block_id
next_block_id += 1
block_ids_by_digest[digest_hex] = block_id
hash_ids.append(block_id)
release_row["hash_ids"] = hash_ids
destination.write(json.dumps(release_row, ensure_ascii=False))
destination.write("\n")
row_count += 1
if show_progress:
finalize_progress.update(len(line.encode("utf-8")))
finalize_progress.set_postfix(rows=row_count, unique_blocks=next_block_id)
finally:
if show_progress:
finalize_progress.close()
return {
"release_output_path": str(release_destination),
"release_row_count": row_count,
"release_unique_block_count": next_block_id,
"release_shard_count": len(segments),
}
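# Minimal usage sketch (paths are hypothetical; this mirrors the
# `python -m trace_formatter build-release` step):
#
#     stats = export_release_ready_trace(
#         raw_input_path="trace-glm5-formatted/041715-041717-raw.jsonl",
#         release_output_path="trace-glm5-formatted/041715-041717.jsonl",
#         jobs=16, show_progress=True,
#     )
#     # stats["release_unique_block_count"] is the number of distinct blocks.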
def format_and_sort_trace(
*,
input_dir: str | Path,
output_path: str | Path,
tmp_dir: str | Path | None = None,
chunk_bytes: int = 128 * 1024 * 1024,
start_time: str | None = None,
end_time: str | None = None,
truncate_to_window: bool = True,
show_progress: bool = False,
log_file: str | Path | None = None,
) -> dict:
source_files = discover_source_files(input_dir)
destination = Path(output_path)
destination.parent.mkdir(parents=True, exist_ok=True)
time_offset_ms = infer_time_offset_ms(source_files[0]) if source_files else 0
time_window = infer_time_window(source_files, start_time=start_time, end_time=end_time) if truncate_to_window else None
total_input_bytes = sum(path.stat().st_size for path in source_files if path.suffix != ".zst")
has_zst = any(path.suffix == ".zst" for path in source_files)
with _open_progress_stream(log_file) as progress_stream, tempfile.TemporaryDirectory(dir=tmp_dir) as temp_root:
temp_raw_destination = Path(temp_root) / "formatted-raw.tmp.jsonl"
chunk_root = Path(temp_root)
chunk_paths: list[Path] = []
chunk_rows: list[tuple[int, int, str, str, str, str, str]] = []
chunk_size_bytes = 0
total_rows = 0
global_seq = 0
min_sort_time_ms: int | None = None
max_sort_time_ms: int | None = None
user_scoped_rows = 0
truncated_rows = 0
filtered_rows = 0
filtered_empty_messages_rows = 0
filtered_empty_response_rows = 0
scan_progress = tqdm(
total=None if has_zst else total_input_bytes,
desc="Scan raw trace",
unit="B" if not has_zst else "line",
unit_scale=not has_zst,
dynamic_ncols=True,
file=progress_stream,
disable=not show_progress,
)
try:
for source_file in source_files:
with open_trace_text(source_file) as handle:
for source_line, line in enumerate(handle, start=1):
stripped = line.strip()
if not stripped:
if show_progress:
scan_progress.update(1 if has_zst else len(line.encode("utf-8")))
continue
raw = json.loads(stripped)
request_params, response_params, request_payload, attributes, messages = _extract_request_components(raw)
empty_messages = not isinstance(messages, list) or len(messages) == 0
empty_response_params = _has_empty_response_params(raw, response_params)
if empty_messages or empty_response_params:
filtered_rows += 1
if empty_messages:
filtered_empty_messages_rows += 1
if empty_response_params:
filtered_empty_response_rows += 1
if show_progress:
scan_progress.update(1 if has_zst else len(line.encode("utf-8")))
scan_progress.set_postfix(
kept=total_rows,
filtered=filtered_rows,
truncated=truncated_rows,
chunks=len(chunk_paths),
)
continue
if time_window is not None:
ready_time_ms = _extract_sort_time_ms(raw, attributes, time_offset_ms=time_offset_ms)
if ready_time_ms and (ready_time_ms < time_window.start_ms or ready_time_ms >= time_window.end_ms):
truncated_rows += 1
if show_progress:
scan_progress.update(1 if has_zst else len(line.encode("utf-8")))
scan_progress.set_postfix(
kept=total_rows,
filtered=filtered_rows,
truncated=truncated_rows,
chunks=len(chunk_paths),
)
continue
sort_time_ms, user_id, request_id, sequence_hashes, roles, row_json = _build_unified_row_payload_from_components(
raw,
request_params=request_params,
response_params=response_params,
request_payload=request_payload,
request_attributes=attributes,
messages=messages,
source_file=source_file.name,
source_line=source_line,
time_offset_ms=time_offset_ms,
)
chunk_rows.append((sort_time_ms, global_seq, user_id, request_id, sequence_hashes, roles, row_json))
chunk_size_bytes += (
len(row_json.encode("utf-8"))
+ len(user_id.encode("utf-8"))
+ len(request_id.encode("utf-8"))
+ len(sequence_hashes.encode("utf-8"))
+ len(roles.encode("utf-8"))
+ 64
)
total_rows += 1
global_seq += 1
min_sort_time_ms = sort_time_ms if min_sort_time_ms is None else min(min_sort_time_ms, sort_time_ms)
max_sort_time_ms = sort_time_ms if max_sort_time_ms is None else max(max_sort_time_ms, sort_time_ms)
if user_id:
user_scoped_rows += 1
if chunk_size_bytes >= chunk_bytes:
chunk_paths.append(_write_chunk(chunk_rows, chunk_root, len(chunk_paths)))
chunk_rows = []
chunk_size_bytes = 0
if show_progress:
scan_progress.update(1 if has_zst else len(line.encode("utf-8")))
scan_progress.set_postfix(
kept=total_rows,
filtered=filtered_rows,
truncated=truncated_rows,
chunks=len(chunk_paths),
)
finally:
if show_progress:
scan_progress.close()
if chunk_rows:
chunk_paths.append(_write_chunk(chunk_rows, chunk_root, len(chunk_paths)))
iterators = [_iter_chunk_rows(path) for path in chunk_paths]
sessionizer = LogicalSessionizer()
merge_progress = tqdm(
total=total_rows,
desc="Merge formatted trace",
unit="row",
dynamic_ncols=True,
file=progress_stream,
disable=not show_progress,
)
try:
with temp_raw_destination.open("w", encoding="utf-8") as output_handle:
for _, _, user_id, request_id, sequence_hashes, roles, row_json in heapq.merge(
*iterators, key=lambda item: (item[0], item[1])
):
assignment = sessionizer.assign_precomputed(
user_id=user_id,
request_id=request_id,
sequence_hashes=decode_prefix_hashes(sequence_hashes),
roles=decode_roles(roles),
)
output_handle.write(_apply_session_assignment_to_row_json(row_json, assignment))
output_handle.write("\n")
if show_progress:
merge_progress.update(1)
merge_progress.set_postfix(rows=total_rows, chunks=len(chunk_paths))
finally:
if show_progress:
merge_progress.close()
shutil.move(str(temp_raw_destination), str(destination))
return {
"output_path": str(destination),
"row_count": total_rows,
"source_file_count": len(source_files),
"chunk_count": len(chunk_paths),
"min_sort_time_ms": min_sort_time_ms or 0,
"max_sort_time_ms": max_sort_time_ms or 0,
"rows_with_user_id": user_scoped_rows,
"truncated_row_count": truncated_rows,
"filtered_row_count": filtered_rows,
"filtered_empty_messages_row_count": filtered_empty_messages_rows,
"filtered_empty_response_params_row_count": filtered_empty_response_rows,
"window_start_ms": time_window.start_ms if time_window is not None else None,
"window_end_ms": time_window.end_ms if time_window is not None else None,
}
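# Minimal usage sketch (hypothetical paths; this is what the
# `python -m trace_formatter format` entry point drives):
#
#     summary = format_and_sort_trace(
#         input_dir="trace-glm5",
#         output_path="trace-glm5-formatted/041715-041717-raw.jsonl",
#         show_progress=True,
#     )
#     # summary["row_count"] counts kept rows after window/empty filtering.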


@@ -0,0 +1,370 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from jinja2 import Environment
import re
from trace_analyzer.helpers import compact_json, parse_jsonish
from trace_model_meta import infer_model_family_from_request_model, resolve_chat_template_path
@dataclass
class ParsedMessageEvent:
role: str
content_type: str
text_len: int
has_cache_control: bool = False
item_count: int = 0
@dataclass
class ParsedToolSpec:
name: str
tool_type: str
@dataclass(frozen=True)
class Qwen3CoderToolParserSpec:
tool_call_start_token: str = "<tool_call>"
tool_call_end_token: str = "</tool_call>"
tool_call_prefix: str = "<function="
function_end_token: str = "</function>"
parameter_prefix: str = "<parameter="
parameter_end_token: str = "</parameter>"
def _template_tojson(value, ensure_ascii=True):
return json.dumps(value, ensure_ascii=ensure_ascii)
@lru_cache(maxsize=4)
def _load_chat_template(path: str):
environment = Environment()
environment.filters["tojson"] = _template_tojson
return environment.from_string(Path(path).read_text(encoding="utf-8"))
def _load_glm5_chat_template():
return _load_chat_template(str(resolve_chat_template_path("glm5")))
def _load_qwen3_coder_chat_template():
return _load_chat_template(str(resolve_chat_template_path("qwen3-coder")))
@lru_cache(maxsize=1)
def _load_qwen3_coder_tool_parser_spec() -> Qwen3CoderToolParserSpec:
defaults = Qwen3CoderToolParserSpec()
parser_path = resolve_chat_template_path("qwen3-coder").with_name("qwen3coder_tool_parser.py")
if not parser_path.exists():
return defaults
text = parser_path.read_text(encoding="utf-8")
def _extract(name: str, fallback: str) -> str:
pattern = rf'self\.{re.escape(name)}:\s*str\s*=\s*"([^"]+)"'
match = re.search(pattern, text)
return match.group(1) if match else fallback
return Qwen3CoderToolParserSpec(
tool_call_start_token=_extract("tool_call_start_token", defaults.tool_call_start_token),
tool_call_end_token=_extract("tool_call_end_token", defaults.tool_call_end_token),
tool_call_prefix=_extract("tool_call_prefix", defaults.tool_call_prefix),
function_end_token=_extract("function_end_token", defaults.function_end_token),
parameter_prefix=_extract("parameter_prefix", defaults.parameter_prefix),
parameter_end_token=_extract("parameter_end_token", defaults.parameter_end_token),
)
def _stringify_message_content_for_template(content):
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and isinstance(item.get("text"), str):
parts.append(item["text"])
else:
parts.append(compact_json(item))
return "".join(parts)
if isinstance(content, dict):
if isinstance(content.get("text"), str):
return content["text"]
return compact_json(content)
if content is None:
return ""
return str(content)
def _normalize_message_content_for_template(content, role=""):
if role == "tool" and isinstance(content, (list, dict)):
has_named_tool_refs = False
if isinstance(content, list):
has_named_tool_refs = any(isinstance(item, dict) and item.get("name") for item in content)
elif isinstance(content, dict):
has_named_tool_refs = bool(content.get("name"))
if not has_named_tool_refs:
return _stringify_message_content_for_template(content)
if isinstance(content, list):
normalized_items = []
for item in content:
if isinstance(item, dict) and "text" in item and "type" not in item:
normalized_item = dict(item)
normalized_item["type"] = "text"
normalized_items.append(normalized_item)
else:
normalized_items.append(item)
return normalized_items
if isinstance(content, dict) and "text" in content:
normalized_item = dict(content)
normalized_item.setdefault("type", "text")
return [normalized_item]
return content
def _normalize_tool_call_for_template(tool_call):
if not isinstance(tool_call, dict):
return tool_call
normalized = dict(tool_call)
function = normalized.get("function")
normalized_function = dict(function) if isinstance(function, dict) else None
if normalized_function is None and ("name" in normalized or "arguments" in normalized):
normalized_function = {}
if normalized_function is not None:
if "name" not in normalized_function and normalized.get("name"):
normalized_function["name"] = normalized["name"]
if "arguments" not in normalized_function and "arguments" in normalized:
normalized_function["arguments"] = normalized["arguments"]
arguments = parse_jsonish(normalized_function.get("arguments", {}))
normalized_function["arguments"] = arguments if isinstance(arguments, dict) else {"__raw_arguments__": arguments}
normalized["function"] = normalized_function
return normalized
def _normalize_tool_spec_for_template(tool):
if not isinstance(tool, dict):
return tool
normalized = dict(tool)
function = normalized.get("function")
if isinstance(function, dict):
normalized_function = dict(function)
parameters = parse_jsonish(normalized_function.get("parameters", {}))
if isinstance(parameters, dict):
normalized_function["parameters"] = parameters
normalized["function"] = normalized_function
return normalized
def _normalize_qwen_message_for_template(message):
if not isinstance(message, dict):
return message
normalized_message = dict(message)
normalized_message["content"] = _stringify_message_content_for_template(message.get("content"))
normalized_tool_calls = []
for tool_call in message.get("tool_calls") or []:  # tolerate an explicit null
normalized_tool_call = _normalize_tool_call_for_template(tool_call)
if isinstance(normalized_tool_call, dict):
function = normalized_tool_call.get("function")
if isinstance(function, dict):
arguments = function.get("arguments", {})
if not isinstance(arguments, dict):
function["arguments"] = {"__raw_arguments__": arguments}
normalized_tool_calls.append(normalized_tool_call)
normalized_message["tool_calls"] = normalized_tool_calls
return normalized_message
def _render_qwen3_coder_tool_call(tool_call, spec: Qwen3CoderToolParserSpec) -> str:
normalized_tool_call = _normalize_tool_call_for_template(tool_call)
if not isinstance(normalized_tool_call, dict):
return ""
function = normalized_tool_call.get("function", {})
if not isinstance(function, dict):
return ""
function_name = str(function.get("name", "")).strip()
if not function_name:
return ""
parts = [
spec.tool_call_start_token,
f"{spec.tool_call_prefix}{function_name}>",
]
arguments = function.get("arguments", {})
if isinstance(arguments, dict):
for arg_name, arg_value in arguments.items():
rendered_value = (
json.dumps(arg_value, ensure_ascii=False)
if isinstance(arg_value, (dict, list))
else str(arg_value)
)
parts.extend(
[
f"{spec.parameter_prefix}{arg_name}>",
rendered_value,
spec.parameter_end_token,
]
)
parts.extend([spec.function_end_token, spec.tool_call_end_token])
return "\n".join(parts)
def _normalize_qwen_message_with_tool_parser(message, spec: Qwen3CoderToolParserSpec):
normalized_message = _normalize_qwen_message_for_template(message)
if not isinstance(normalized_message, dict):
return normalized_message
if normalized_message.get("role") != "assistant":
return normalized_message
tool_calls = normalized_message.get("tool_calls", [])
if not isinstance(tool_calls, list) or not tool_calls:
return normalized_message
rendered_tool_calls = [chunk for chunk in (_render_qwen3_coder_tool_call(tool_call, spec) for tool_call in tool_calls) if chunk]
if not rendered_tool_calls:
return normalized_message
prefix = str(normalized_message.get("content", "") or "").strip()
combined_parts = [prefix] if prefix else []
combined_parts.extend(rendered_tool_calls)
normalized_message["content"] = "\n".join(combined_parts)
normalized_message["tool_calls"] = []
return normalized_message
def build_glm5_canonical_prompt(payload):
input_payload = payload.get("input", {}) if isinstance(payload, dict) else {}
parameters = payload.get("parameters", {}) if isinstance(payload, dict) else {}
messages = []
for message in input_payload.get("messages", []):
if not isinstance(message, dict):
messages.append(message)
continue
normalized_message = dict(message)
normalized_message["content"] = _normalize_message_content_for_template(
message.get("content"),
role=str(message.get("role", "")),
)
normalized_message["tool_calls"] = [
_normalize_tool_call_for_template(tool_call) for tool_call in message.get("tool_calls", [])
]
messages.append(normalized_message)
tools = [_normalize_tool_spec_for_template(tool) for tool in parameters.get("tools", []) if isinstance(tool, dict)]
return _load_glm5_chat_template().render(
messages=messages,
tools=tools,
add_generation_prompt=True,
)
def build_qwen3_coder_canonical_prompt(payload):
input_payload = payload.get("input", {}) if isinstance(payload, dict) else {}
parameters = payload.get("parameters", {}) if isinstance(payload, dict) else {}
tool_parser_spec = _load_qwen3_coder_tool_parser_spec()
messages = [
_normalize_qwen_message_with_tool_parser(message, tool_parser_spec)
for message in input_payload.get("messages", [])
]
if not messages:
messages = [{"role": "system", "content": ""}]
tools = [_normalize_tool_spec_for_template(tool) for tool in parameters.get("tools", []) if isinstance(tool, dict)]
return _load_qwen3_coder_chat_template().render(
messages=messages,
tools=tools,
add_generation_prompt=True,
)
class GLMTraceAdapter:
name = "glm"
def detect(self, raw):
return "request_params" in raw and "response_params" in raw
def parse_message(self, message):
role = str(message.get("role", "unknown"))
content = message.get("content")
if isinstance(content, str):
return ParsedMessageEvent(
role=role,
content_type="text",
text_len=len(content),
has_cache_control=False,
item_count=1,
)
if isinstance(content, list):
text_len = 0
has_cache_control = False
item_types = []
for item in content:
if isinstance(item, str):
text_len += len(item)
item_types.append("text")
continue
if not isinstance(item, dict):
text_len += len(compact_json(item))
item_types.append(type(item).__name__)
continue
if "text" in item and isinstance(item["text"], str):
text_len += len(item["text"])
else:
text_len += len(compact_json(item))
has_cache_control = has_cache_control or "cache_control" in item
item_types.append(str(item.get("type", "text" if "text" in item else "object")))
content_type = "+".join(sorted(set(item_types))) if item_types else "list"
return ParsedMessageEvent(
role=role,
content_type=content_type,
text_len=text_len,
has_cache_control=has_cache_control,
item_count=len(content),
)
if isinstance(content, dict):
return ParsedMessageEvent(
role=role,
content_type=str(content.get("type", "object")),
text_len=len(compact_json(content)),
has_cache_control="cache_control" in content,
item_count=1,
)
return ParsedMessageEvent(
role=role,
content_type=type(content).__name__,
text_len=len(str(content)),
has_cache_control=False,
item_count=1,
)
def parse_tool(self, tool):
function = tool.get("function", {})
name = ""
if isinstance(function, dict):
name = str(function.get("name", ""))
return ParsedToolSpec(
name=name,
tool_type=str(tool.get("type", "function")),
)
def build_canonical_prompt(self, payload):
return build_glm5_canonical_prompt(payload)
class QwenTraceAdapter(GLMTraceAdapter):
name = "qwen3-coder"
def detect(self, raw):
if "request_params" not in raw:
return False
return infer_model_family_from_request_model(raw.get("request_model")) == "qwen3-coder"
def build_canonical_prompt(self, payload):
return build_qwen3_coder_canonical_prompt(payload)
def get_raw_adapter(raw: dict[str, Any]):
for adapter in [QwenTraceAdapter(), GLMTraceAdapter()]:
if adapter.detect(raw):
return adapter
raise ValueError("Unsupported raw trace format for trace_formatter.")


@@ -0,0 +1,217 @@
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass
from typing import Any
def normalize_content(content: Any) -> str:
if isinstance(content, str):
return content
try:
return json.dumps(content, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
except Exception:
return str(content)
def serialize_tool_calls(tool_calls: Any) -> str:
if tool_calls is None:
return ""
if isinstance(tool_calls, dict):
tool_calls = [tool_calls]
if not isinstance(tool_calls, list):
return normalize_content(tool_calls)
return "\n".join(
json.dumps(item, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
for item in tool_calls
)
def stable_message_fingerprint(message: dict[str, Any]) -> str:
role = str(message.get("role", "unknown"))
content = normalize_content(message.get("content"))
tool_calls = serialize_tool_calls(message.get("tool_calls"))
digest = hashlib.blake2b(digest_size=16)
digest.update(role.encode("utf-8", "ignore"))
digest.update(b"\x1f")
digest.update(content.encode("utf-8", "ignore"))
digest.update(b"\x1f")
digest.update(tool_calls.encode("utf-8", "ignore"))
return digest.hexdigest()
def build_prefix_hashes(messages: list[dict[str, Any]]) -> list[str]:
digest = hashlib.blake2b(digest_size=16)
prefixes: list[str] = []
for message in messages:
digest.update(stable_message_fingerprint(message).encode("ascii"))
digest.update(b"\n")
prefixes.append(digest.hexdigest())
return prefixes
def build_message_fingerprints(messages: list[dict[str, Any]]) -> list[str]:
return [stable_message_fingerprint(message) for message in messages]
def build_sequence_hashes(message_fingerprints: list[str]) -> list[str]:
digest = hashlib.blake2b(digest_size=16)
prefixes: list[str] = []
for fingerprint in message_fingerprints:
digest.update(fingerprint.encode("ascii"))
digest.update(b"\n")
prefixes.append(digest.hexdigest())
return prefixes
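# Prefix property sketch: requests sharing their first k message
# fingerprints share their first k sequence hashes, which is exactly what
# the sessionizer indexes on.
#
#     a = build_sequence_hashes(["f1", "f2"])
#     b = build_sequence_hashes(["f1", "f2", "f3"])
#     assert a == b[:2]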
def encode_prefix_hashes(prefix_hashes: list[str]) -> str:
return ",".join(prefix_hashes)
def decode_prefix_hashes(encoded: str) -> list[str]:
if not encoded:
return []
return [item for item in encoded.split(",") if item]
def encode_message_fingerprints(message_fingerprints: list[str]) -> str:
return ",".join(message_fingerprints)
def decode_message_fingerprints(encoded: str) -> list[str]:
if not encoded:
return []
return [item for item in encoded.split(",") if item]
def encode_roles(roles: list[str]) -> str:
return ",".join(roles)
def decode_roles(encoded: str) -> list[str]:
if not encoded:
return []
return [item for item in encoded.split(",") if item]
def extract_user_id(request_params: dict[str, Any]) -> str:
header = request_params.get("header", {}) if isinstance(request_params, dict) else {}
attributes = header.get("attributes", {}) if isinstance(header, dict) else {}
return str(attributes.get("user_id", "") or "")
def build_root_session_id(user_id: str, request_id: str) -> str:
digest = hashlib.blake2b(digest_size=10)
digest.update(user_id.encode("utf-8", "ignore"))
digest.update(b"\x00")
digest.update(request_id.encode("utf-8", "ignore"))
return f"ls-{digest.hexdigest()}"
@dataclass
class SessionAssignment:
session_id: str
parent_request_id: str
parent_chat_id: int
chat_id: int
turn: int
@dataclass
class _SessionNode:
request_id: str
session_id: str
chat_id: int
turn: int
message_count: int
class LogicalSessionizer:
def __init__(self) -> None:
self._index: dict[tuple[str, str], _SessionNode] = {}
self._next_chat_id = 0
def assign_precomputed(
self,
*,
user_id: str,
request_id: str,
sequence_hashes: list[str],
roles: list[str],
) -> SessionAssignment:
parent: _SessionNode | None = None
scope_user_id = user_id or f"missing-user:{request_id}"
# user_prefix_flags[i] is True once the prefix ending at message i contains
# at least one user message; prefixes without user input never link sessions.
has_user_prefix = False
user_prefix_flags: list[bool] = []
for role in roles:
if role == "user":
has_user_prefix = True
user_prefix_flags.append(has_user_prefix)
# Probe the longest prefixes first so the deepest indexed node wins as parent.
for prefix_len in range(len(sequence_hashes), 0, -1):
if prefix_len - 1 >= len(user_prefix_flags) or not user_prefix_flags[prefix_len - 1]:
continue
candidate = self._index.get((scope_user_id, sequence_hashes[prefix_len - 1]))
if candidate is None:
continue
# A full-length match only counts if the candidate had more messages;
# otherwise it is this same message list rather than an earlier turn.
if prefix_len < len(sequence_hashes) or candidate.message_count > prefix_len:
parent = candidate
break
if parent is None:
session_id = build_root_session_id(scope_user_id, request_id)
parent_request_id = ""
parent_chat_id = -1
turn = 1
else:
session_id = parent.session_id
parent_request_id = parent.request_id
parent_chat_id = parent.chat_id
turn = parent.turn + 1
chat_id = self._next_chat_id
self._next_chat_id += 1
if sequence_hashes:
node = _SessionNode(
request_id=request_id,
session_id=session_id,
chat_id=chat_id,
turn=turn,
message_count=len(sequence_hashes),
)
self._index[(scope_user_id, sequence_hashes[-1])] = node
# Also index the prefix that strips at most three trailing non-user
# messages, so the next turn can still attach when the provider appended
# assistant/tool messages after the last user message.
trailing_non_user = 0
for role in reversed(roles):
if role == "user":
break
trailing_non_user += 1
if trailing_non_user > 2:
break
prefix_len = len(sequence_hashes) - trailing_non_user
if prefix_len > 0 and prefix_len - 1 < len(user_prefix_flags) and user_prefix_flags[prefix_len - 1]:
self._index[(scope_user_id, sequence_hashes[prefix_len - 1])] = node
return SessionAssignment(
session_id=session_id,
parent_request_id=parent_request_id,
parent_chat_id=parent_chat_id,
chat_id=chat_id,
turn=turn,
)
def assign(
self,
*,
user_id: str,
request_id: str,
message_fingerprints: list[str],
roles: list[str],
) -> SessionAssignment:
return self.assign_precomputed(
user_id=user_id,
request_id=request_id,
sequence_hashes=build_sequence_hashes(message_fingerprints),
roles=roles,
)
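# Minimal usage sketch (hypothetical ids): a request that extends an earlier
# request's message prefix joins the same logical session as turn 2.
#
#     s = LogicalSessionizer()
#     first = s.assign(user_id="u1", request_id="r1",
#         message_fingerprints=["m1"], roles=["user"])
#     second = s.assign(user_id="u1", request_id="r2",
#         message_fingerprints=["m1", "m2", "m3"],
#         roles=["user", "assistant", "user"])
#     assert second.session_id == first.session_id and second.turn == 2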


@@ -0,0 +1,108 @@
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from trace_analyzer.helpers import parse_jsonish, safe_int
WINDOW_RE = re.compile(r"(?P<day>\d{4})-(?P<start>\d{4})-(?P<end>\d{4})$")
UTC_PLUS_8 = timezone(timedelta(hours=8))
@dataclass(frozen=True)
class TimeWindow:
label: str
start_ms: int
end_ms: int
def parse_time_to_ms(value: str) -> int:
text = str(value or "").strip()
if not text:
return 0
for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"):
try:
dt = datetime.strptime(text, fmt).replace(tzinfo=UTC_PLUS_8)
return int(dt.timestamp() * 1000)
except ValueError:
continue
raise ValueError(f"Unsupported timestamp format: {value!r}")
def _read_first_timestamp(path: Path) -> str:
with path.open("r", encoding="utf-8") as handle:
for line in handle:
stripped = line.strip()
if not stripped:
continue
raw = json.loads(stripped)
value = str(raw.get("time", "")).strip()
if value:
return value
raise ValueError(f"Could not find time field in {path}")
def _read_first_timestamp_and_ready_ms(path: Path) -> tuple[str, int]:
with path.open("r", encoding="utf-8") as handle:
for line in handle:
stripped = line.strip()
if not stripped:
continue
raw = json.loads(stripped)
value = str(raw.get("time", "")).strip()
if not value:
continue
request_params = parse_jsonish(raw.get("request_params", {}))
header = request_params.get("header", {}) if isinstance(request_params, dict) else {}
attributes = header.get("attributes", {}) if isinstance(header, dict) else {}
ready_ms = safe_int(attributes.get("x-dashscope-inner-requestreadytime"))
return value, ready_ms
raise ValueError(f"Could not find time field in {path}")
def infer_time_offset_ms(path: Path) -> int:
# Infer the shard's wall-clock vs ready-time skew, rounded to whole hours
# (effectively a timezone offset), from the first row that carries both.
first_time, first_ready_ms = _read_first_timestamp_and_ready_ms(path)
if not first_ready_ms:
return 0
wall_clock_ms = parse_time_to_ms(first_time)
hour_ms = 60 * 60 * 1000
return int(round((wall_clock_ms - first_ready_ms) / hour_ms)) * hour_ms
def infer_time_window(
source_files: list[Path],
*,
start_time: str | None = None,
end_time: str | None = None,
) -> TimeWindow | None:
if start_time and end_time:
start_ms = parse_time_to_ms(start_time)
end_ms = parse_time_to_ms(end_time)
label = (
f"{datetime.fromtimestamp(start_ms / 1000, tz=UTC_PLUS_8):%m%d%H}"
f"-{datetime.fromtimestamp(end_ms / 1000, tz=UTC_PLUS_8):%m%d%H}"
)
return TimeWindow(label=label, start_ms=start_ms, end_ms=end_ms)
if not source_files:
return None
first_match = WINDOW_RE.match(source_files[0].stem)
last_match = WINDOW_RE.match(source_files[-1].stem)
if first_match is None or last_match is None:
return None
first_time = _read_first_timestamp(source_files[0])
time_offset_ms = infer_time_offset_ms(source_files[0])
base_date = first_time.split(" ", 1)[0]
start_hhmm = first_match.group("start")
end_hhmm = last_match.group("end")
start_text = f"{base_date} {start_hhmm[:2]}:{start_hhmm[2:]}:00.000"
end_text = f"{base_date} {end_hhmm[:2]}:{end_hhmm[2:]}:00.000"
start_ms = parse_time_to_ms(start_text) - time_offset_ms
end_ms = parse_time_to_ms(end_text) - time_offset_ms
label = f"{first_match.group('day')}{start_hhmm[:2]}-{last_match.group('day')}{end_hhmm[:2]}"
return TimeWindow(label=label, start_ms=start_ms, end_ms=end_ms)
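# Sketch (UTC+8 wall clock, hypothetical shard names): for source files
# 0417-1500-1530.jsonl ... 0417-1630-1700.jsonl the inferred window covers
# 15:00-17:00 on 04/17, labeled "041715-041717", with start/end shifted by
# the offset inferred from the first shard.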


@@ -0,0 +1,117 @@
{% macro render_extra_keys(json_dict, handled_keys) %}
{%- if json_dict is mapping %}
{%- for json_key in json_dict if json_key not in handled_keys %}
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
{%- else %}
{{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
{%- endif %}
{%- endfor %}
{%- endif %}
{% endmacro %}
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = [] %}
{%- endif %}
{%- if system_message is defined %}
{{- "<|im_start|>system\n" + system_message }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
{%- endif %}
{%- endif %}
{%- if tools is iterable and tools | length > 0 %}
{{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
{{- "<tools>" }}
{%- for tool in tools %}
{%- if tool.function is defined %}
{%- set tool = tool.function %}
{%- endif %}
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
{%- if tool.description is defined %}
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
{%- endif %}
{{- '\n<parameters>' }}
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{{- '\n<parameter>' }}
{{- '\n<name>' ~ param_name ~ '</name>' }}
{%- if param_fields.type is defined %}
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
{%- endif %}
{%- if param_fields.description is defined %}
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
{%- endif %}
{%- set handled_keys = ['name', 'type', 'description'] %}
{{- render_extra_keys(param_fields, handled_keys) }}
{{- '\n</parameter>' }}
{%- endfor %}
{%- endif %}
{% set handled_keys = ['type', 'properties'] %}
{{- render_extra_keys(tool.parameters, handled_keys) }}
{{- '\n</parameters>' }}
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
{{- render_extra_keys(tool, handled_keys) }}
{{- '\n</function>' }}
{%- endfor %}
{{- "\n</tools>" }}
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
{%- endif %}
{%- if system_message is defined %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if tools is iterable and tools | length > 0 %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in loop_messages %}
{%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
{{- '<|im_start|>' + message.role }}
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
{{- '\n' + message.content | trim + '\n' }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
{%- if tool_call.arguments is defined %}
{%- for args_name, args_value in tool_call.arguments|items %}
{{- '<parameter=' + args_name + '>\n' }}
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
{{- args_value }}
{{- '\n</parameter>\n' }}
{%- endfor %}
{%- endif %}
{{- '</function>\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>user\n' }}
{%- endif %}
{{- '<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>\n' }}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>\n' }}
{%- elif loop.last %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}


@@ -0,0 +1,39 @@
{
"architectures": [
"Qwen3MoeForCausalLM"
],
"attention_dropout": 0.0,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 6144,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 262144,
"max_window_layers": 62,
"mlp_only_layers": [],
"model_type": "qwen3_moe",
"moe_intermediate_size": 2560,
"norm_topk_prob": true,
"num_attention_heads": 96,
"num_experts": 160,
"num_experts_per_tok": 8,
"num_hidden_layers": 62,
"num_key_value_heads": 8,
"output_router_logits": false,
"qkv_bias": false,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000000,
"router_aux_loss_coef": 0.0,
"shared_expert_intermediate_size": 0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_qk_norm": true,
"use_sliding_window": false,
"vocab_size": 151936
}


@@ -0,0 +1,689 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any, List, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("qwen3_coder")
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []
# Sentinel tokens for streaming mode
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
# Regex patterns
self.tool_call_complete_regex = re.compile(
r"<tool_call>(.*?)</tool_call>", re.DOTALL)
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Qwen3 XML Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
logger.info(
f"vLLM Successfully import tool parser {self.__class__.__name__} !"
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
def _get_arguments_config(
self, func_name: str,
tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (hasattr(
config, "function") and hasattr(config.function, "name")):
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
return {}
def _convert_param_value(self, param_value: str, param_name: str,
param_config: dict, func_name: str) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.warning(
f"Parsed parameter '{param_name}' is not defined in the tool "
f"parameters for tool '{func_name}', directly returning the string value."
)
return param_value
if isinstance(param_config[param_name],
dict) and "type" in param_config[param_name]:
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif param_type.startswith("int") or param_type.startswith(
"uint") or param_type.startswith(
"long") or param_type.startswith(
"short") or param_type.startswith("unsigned"):
try:
param_value = int(param_value)
except Exception:
logger.warning(
f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
f"'{func_name}', degenerating to string.")
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
param_value = float_param_value if float_param_value - int(
float_param_value) != 0 else int(float_param_value)
except Exception:
logger.warning(
f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
f"'{func_name}', degenerating to string.")
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
if param_value not in ["true", "false"]:
logger.warning(
f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false."
)
return param_value == "true"
else:
if param_type in ["object", "array", "arr"
] or param_type.startswith(
"dict") or param_type.startswith("list"):
try:
param_value = json.loads(param_value)
return param_value
except Exception:
logger.warning(
f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
f"'{func_name}', will try other methods to parse it.")
try:
param_value = ast.literal_eval(param_value) # safer
except Exception:
logger.warning(
f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
)
return param_value
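# Coercion sketch (hypothetical schema; `parser` is a Qwen3CoderToolParser
# instance):
#
#     cfg = {"line": {"type": "integer"}, "force": {"type": "boolean"}}
#     parser._convert_param_value("42", "line", cfg, "edit")    # -> 42
#     parser._convert_param_value("true", "force", cfg, "edit") # -> True
#     parser._convert_param_value("null", "line", cfg, "edit")  # -> None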
def _parse_xml_function_call(
self, function_call_str: str,
tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
# Extract function name
end_index = function_call_str.index(">")
function_name = function_call_str[:end_index]
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1:]
param_dict = {}
for match_text in self.tool_call_parameter_regex.findall(parameters):
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1:])
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name)
return ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(param_dict,
ensure_ascii=False)),
)
def _get_function_calls(self, model_output: str) -> List[str]:
# Find all tool calls
matched_ranges = self.tool_call_regex.findall(model_output)
raw_tool_calls = [
match[0] if match[0] else match[1] for match in matched_ranges
]
# Back-off strategy if no tool_call tags found
if len(raw_tool_calls) == 0:
raw_tool_calls = [model_output]
raw_function_calls = []
for tool_call in raw_tool_calls:
raw_function_calls.extend(
self.tool_call_function_regex.findall(tool_call))
function_calls = [
match[0] if match[0] else match[1] for match in raw_function_calls
]
return function_calls
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# Quick check to avoid unnecessary processing
if self.tool_call_prefix not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
function_calls = self._get_function_calls(model_output)
if len(function_calls) == 0:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools)
for function_call_str in function_calls
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self.prev_tool_call_arr.clear() # Clear previous calls
for tool_call in tool_calls:
if tool_call:
self.prev_tool_call_arr.append({
"name":
tool_call.function.name,
"arguments":
tool_call.function.arguments,
})
# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
content_index = content_index if content_index >= 0 else model_output.find(
self.tool_call_prefix)
content = model_output[:content_index] # .rstrip()
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
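# Non-streaming sketch for a hypothetical completion in the XML-ish layout
# the chat template produces:
#
#     out = ('Let me check.\n<tool_call>\n<function=read_file>\n'
#            '<parameter=path>\n/tmp/a.txt\n</parameter>\n'
#            '</function>\n</tool_call>')
#     # parser.extract_tool_calls(out, request) -> tools_called=True with one
#     # FunctionCall(name="read_file", arguments='{"path": "/tmp/a.txt"}')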
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# Store request for type conversion
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# If no delta text, return None unless it's an EOS token after tool calls
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started is False
# because it might have been reset after processing all tools
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text))
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token) - current_text.count(
self.tool_call_end_token)
if open_calls == 0:
# Return empty delta message to allow finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
self.accumulated_params = {}
# Check if there are more tool calls
tool_starts = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts:
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
return None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if self.tool_call_start_token_id in delta_token_ids or self.tool_call_start_token in delta_text:
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[:delta_text.index(
self.tool_call_start_token)]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if current_text.rstrip().endswith(self.tool_call_end_token):
# We just ended a tool call, skip whitespace
if delta_text.strip() == "":
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
tool_starts = []
idx = 0
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_starts.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_starts):
# No more tool calls to process yet
return None
tool_start_idx = tool_starts[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token,
tool_start_idx)
if tool_end_idx == -1:
tool_text = current_text[tool_start_idx:]
else:
tool_text = current_text[tool_start_idx:tool_end_idx +
len(self.tool_call_end_token)]
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
# This ensures finish_reason="tool_calls" even if parsing isn't complete
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr)
if not already_added:
                        self.prev_tool_call_arr.append({
                            "name": self.current_function_name,
                            "arguments": "{}",  # placeholder; updated after full parse
                        })
# Send header with function info
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""),
type="function",
)
])
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if not self.json_started and self.parameter_prefix not in delta_text:
self.json_started = True
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
])
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract the complete tool call to update prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_content_end = tool_text.find(self.function_end_token,
func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content, self.streaming_request.tools
if self.streaming_request else None)
if parsed_tool:
# Update existing entry in prev_tool_call_arr with complete arguments
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get(
"name") == parsed_tool.function.name:
self.prev_tool_call_arr[i][
"arguments"] = parsed_tool.function.arguments
break
except Exception:
pass # Ignore parsing errors during streaming
result = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
])
                # Reset per-function state; json_closed was set above and is
                # cleared when the enclosing tool call ends.
                self.in_function = False
                self.accumulated_params = {}
return result
# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
# Check if we should start a new parameter
if not self.in_param and self.param_count < len(param_starts):
if len(param_starts) > self.param_count:
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(
self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or function end
next_param_idx = value_text.find(
self.parameter_prefix)
func_end_idx = value_text.find(
self.function_end_token)
if next_param_idx != -1 and (func_end_idx == -1
or next_param_idx
< func_end_idx):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.tool_call_end_token in tool_text:
# Tool call is complete, so parameter must be complete too
# Use all remaining text before function end as value
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Store raw value for later processing
self.accumulated_params[
self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name,
self.streaming_request.tools
if self.streaming_request else None)
# Convert the parameter value to the appropriate type
converted_value = self._convert_param_value(
param_value, self.current_param_name,
param_config, self.current_function_name)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(converted_value,
ensure_ascii=False)
if self.param_count == 0:
json_fragment = f'"{self.current_param_name}": {serialized_value}'
else:
json_fragment = f', "{self.current_param_name}": {serialized_value}'
self.param_count += 1
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=json_fragment),
)
])
            # Incremental parameter-value streaming; effectively dead in the
            # current implementation because complete parameters are handled
            # above, but kept as a fallback.
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
# Store complete value
full_value = self.current_param_value + value_chunk
self.accumulated_params[
self.current_param_name] = full_value
                    # The value was streamed incrementally as a quoted JSON
                    # string, so the converted/serialized form is not
                    # re-emitted here.
self.in_param = False
self.current_param_value = ""
# Just close the current parameter string
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments='"'), # Close the string quote
)
])
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = json.dumps(
self.current_param_value, ensure_ascii=False
)[1:-1] if self.current_param_value else ""
self.current_param_value += value_chunk
full_escaped = json.dumps(self.current_param_value,
ensure_ascii=False)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
if delta_escaped:
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped),
)
])
return None
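For reference, the XML layout this parser consumes is the Qwen3-Coder tool-call format: a `<tool_call>` block wrapping `<function=NAME>` with `<parameter=NAME>value</parameter>` entries. Below is a minimal, self-contained sketch of the non-streaming extraction; the sample output and regexes are illustrative stand-ins for the class attributes (`tool_call_regex`, `tool_call_function_regex`, …), not the exact definitions from this file:

```python
import re

# Illustrative model output in the Qwen3-Coder XML tool-call format.
sample = (
    "Let me check the weather.\n"
    "<tool_call>\n<function=get_weather>\n"
    "<parameter=city>\nBeijing\n</parameter>\n"
    "</function>\n</tool_call>"
)

# Stand-ins for the tool_call / function / parameter regexes used above:
# grab each <tool_call> block, then each <function=...> body inside it.
tool_call_re = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
function_re = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
param_re = re.compile(r"<parameter=([^>]+)>\n?(.*?)\n?</parameter>", re.DOTALL)

for block in tool_call_re.findall(sample):
    for name, body in function_re.findall(block):
        args = dict(param_re.findall(body))
        print(name, args)  # -> get_weather {'city': 'Beijing'}
```

`extract_tool_calls` above does essentially this, then type-converts each parameter against the tool schema (`_convert_param_value`) before serializing the arguments to JSON; the streaming variant emits the same JSON incrementally as `DeltaToolCall` fragments.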

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,86 @@
[gMASK]<sop>
{%- if tools -%}
<|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{% for tool in tools %}
{{ tool | tojson(ensure_ascii=False) }}
{% endfor %}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>{%- endif -%}
{%- macro visible_text(content) -%}
{%- if content is string -%}
{{- content }}
{%- elif content is iterable and content is not mapping -%}
{%- for item in content -%}
{%- if item is mapping and item.type == 'text' -%}
{{- item.text }}
{%- elif item is string -%}
{{- item }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{- content }}
{%- endif -%}
{%- endmacro -%}
{%- set ns = namespace(last_user_index=-1) %}
{%- for m in messages %}
{%- if m.role == 'user' %}
{%- set ns.last_user_index = loop.index0 -%}
{%- endif %}
{%- endfor %}
{%- for m in messages -%}
{%- if m.role == 'user' -%}<|user|>{{ visible_text(m.content) }}
{%- elif m.role == 'assistant' -%}
<|assistant|>
{%- set reasoning_content = '' %}
{%- set content = visible_text(m.content) %}
{%- if m.reasoning_content is string %}
{%- set reasoning_content = m.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if ((clear_thinking is defined and not clear_thinking) or loop.index0 > ns.last_user_index) and reasoning_content -%}
{{ '<think>' + reasoning_content.strip() + '</think>'}}
{%- else -%}
{{ '</think>' }}
{%- endif -%}
{%- if content.strip() -%}
{{ content.strip() }}
{%- endif -%}
{% if m.tool_calls %}
{% for tc in m.tool_calls %}
{%- if tc.function %}
{%- set tc = tc.function %}
{%- endif %}
{{- '<tool_call>' + tc.name -}}
{% set _args = tc.arguments %}{% for k, v in _args.items() %}<arg_key>{{ k }}</arg_key><arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>{% endfor %}</tool_call>{% endfor %}
{% endif %}
{%- elif m.role == 'tool' -%}
{%- if m.content is string -%}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|observation|>' }}
{%- endif %}
{{- '<tool_response>' }}
{{- m.content }}
{{- '</tool_response>' }}
{%- else -%}
<|observation|>{% for tr in m.content %}
<tool_response>{{ tr.output if tr.output is defined else tr }}</tool_response>{% endfor -%}
{% endif -%}
{%- elif m.role == 'system' -%}
<|system|>{{ visible_text(m.content) }}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
<|assistant|>{{- '</think>' if (enable_thinking is defined and not enable_thinking) else '<think>' -}}
{%- endif -%}
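A minimal rendering sketch for this GLM-5 chat template. The directory path is illustrative, and it assumes a transformers version that loads a standalone `chat_template.jinja` next to `tokenizer.json`; plain Jinja2 will not render the template as-is, because `tojson(ensure_ascii=False)` relies on transformers' extended filter:

```python
from transformers import AutoTokenizer

# Illustrative local path: the model-meta dir holding tokenizer.json,
# tokenizer_config.json and this chat_template.jinja.
tok = AutoTokenizer.from_pretrained("model_meta/ZhipuAI/GLM-5-FP8")

messages = [{"role": "user", "content": "1+1=?"}]
prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)  # roughly: [gMASK]<sop><|user|>1+1=?<|assistant|><think>
```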

View File

@@ -0,0 +1,782 @@
{
"architectures": [
"GlmMoeDsaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"dtype": "bfloat16",
"eos_token_id": [
154820,
154827,
154829
],
"ep_size": 1,
"first_k_dense_replace": 3,
"hidden_act": "silu",
"head_dim": 64,
"hidden_size": 6144,
"index_head_dim": 128,
"index_n_heads": 32,
"index_topk": 2048,
"indexer_rope_interleave": true,
"initializer_range": 0.02,
"intermediate_size": 12288,
"kv_lora_rank": 512,
"max_position_embeddings": 202752,
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"model_type": "glm_moe_dsa",
"n_group": 1,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": true,
"num_attention_heads": 64,
"num_experts_per_tok": 8,
"num_hidden_layers": 78,
"num_key_value_heads": 64,
"num_nextn_predict_layers": 1,
"pad_token_id": 154820,
"pretraining_tp": 1,
"q_lora_rank": 2048,
"qk_head_dim": 256,
"qk_nope_head_dim": 192,
"qk_rope_head_dim": 64,
"rms_norm_eps": 1e-05,
"rope_interleave": true,
"rope_parameters": {
"rope_theta": 1000000,
"rope_type": "default"
},
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": false,
"topk_group": 1,
"topk_method": "noaux_tc",
"transformers_version": "5.0.2.dev0",
"use_cache": true,
"v_head_dim": 256,
"vocab_size": 154880,
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [
128,
128
],
"modules_to_not_convert": [
"lm_head",
"model.embed_tokens",
"model.layers.0.input_layernorm",
"model.layers.0.post_attention_layernorm",
"model.layers.0.self_attn.indexer.k_norm",
"model.layers.0.self_attn.indexer.k_norm.bias",
"model.layers.0.self_attn.indexers_proj",
"model.layers.0.self_attn.kv_a_layernorm",
"model.layers.0.self_attn.q_a_layernorm",
"model.layers.1.input_layernorm",
"model.layers.1.post_attention_layernorm",
"model.layers.1.self_attn.indexer.k_norm",
"model.layers.1.self_attn.indexer.k_norm.bias",
"model.layers.1.self_attn.indexers_proj",
"model.layers.1.self_attn.kv_a_layernorm",
"model.layers.1.self_attn.q_a_layernorm",
"model.layers.2.input_layernorm",
"model.layers.2.post_attention_layernorm",
"model.layers.2.self_attn.indexer.k_norm",
"model.layers.2.self_attn.indexer.k_norm.bias",
"model.layers.2.self_attn.indexers_proj",
"model.layers.2.self_attn.kv_a_layernorm",
"model.layers.2.self_attn.q_a_layernorm",
"model.layers.3.input_layernorm",
"model.layers.3.mlp.gate",
"model.layers.3.mlp.gate.e_score_correction_bias",
"model.layers.3.post_attention_layernorm",
"model.layers.3.self_attn.indexer.k_norm",
"model.layers.3.self_attn.indexer.k_norm.bias",
"model.layers.3.self_attn.indexers_proj",
"model.layers.3.self_attn.kv_a_layernorm",
"model.layers.3.self_attn.q_a_layernorm",
"model.layers.4.input_layernorm",
"model.layers.4.mlp.gate",
"model.layers.4.mlp.gate.e_score_correction_bias",
"model.layers.4.post_attention_layernorm",
"model.layers.4.self_attn.indexer.k_norm",
"model.layers.4.self_attn.indexer.k_norm.bias",
"model.layers.4.self_attn.indexers_proj",
"model.layers.4.self_attn.kv_a_layernorm",
"model.layers.4.self_attn.q_a_layernorm",
"model.layers.5.input_layernorm",
"model.layers.5.mlp.gate",
"model.layers.5.mlp.gate.e_score_correction_bias",
"model.layers.5.post_attention_layernorm",
"model.layers.5.self_attn.indexer.k_norm",
"model.layers.5.self_attn.indexer.k_norm.bias",
"model.layers.5.self_attn.indexers_proj",
"model.layers.5.self_attn.kv_a_layernorm",
"model.layers.5.self_attn.q_a_layernorm",
"model.layers.6.input_layernorm",
"model.layers.6.mlp.gate",
"model.layers.6.mlp.gate.e_score_correction_bias",
"model.layers.6.post_attention_layernorm",
"model.layers.6.self_attn.indexer.k_norm",
"model.layers.6.self_attn.indexer.k_norm.bias",
"model.layers.6.self_attn.indexers_proj",
"model.layers.6.self_attn.kv_a_layernorm",
"model.layers.6.self_attn.q_a_layernorm",
"model.layers.7.input_layernorm",
"model.layers.7.mlp.gate",
"model.layers.7.mlp.gate.e_score_correction_bias",
"model.layers.7.post_attention_layernorm",
"model.layers.7.self_attn.indexer.k_norm",
"model.layers.7.self_attn.indexer.k_norm.bias",
"model.layers.7.self_attn.indexers_proj",
"model.layers.7.self_attn.kv_a_layernorm",
"model.layers.7.self_attn.q_a_layernorm",
"model.layers.8.input_layernorm",
"model.layers.8.mlp.gate",
"model.layers.8.mlp.gate.e_score_correction_bias",
"model.layers.8.post_attention_layernorm",
"model.layers.8.self_attn.indexer.k_norm",
"model.layers.8.self_attn.indexer.k_norm.bias",
"model.layers.8.self_attn.indexers_proj",
"model.layers.8.self_attn.kv_a_layernorm",
"model.layers.8.self_attn.q_a_layernorm",
"model.layers.9.input_layernorm",
"model.layers.9.mlp.gate",
"model.layers.9.mlp.gate.e_score_correction_bias",
"model.layers.9.post_attention_layernorm",
"model.layers.9.self_attn.indexer.k_norm",
"model.layers.9.self_attn.indexer.k_norm.bias",
"model.layers.9.self_attn.indexers_proj",
"model.layers.9.self_attn.kv_a_layernorm",
"model.layers.9.self_attn.q_a_layernorm",
"model.layers.10.input_layernorm",
"model.layers.10.mlp.gate",
"model.layers.10.mlp.gate.e_score_correction_bias",
"model.layers.10.post_attention_layernorm",
"model.layers.10.self_attn.indexer.k_norm",
"model.layers.10.self_attn.indexer.k_norm.bias",
"model.layers.10.self_attn.indexers_proj",
"model.layers.10.self_attn.kv_a_layernorm",
"model.layers.10.self_attn.q_a_layernorm",
"model.layers.11.input_layernorm",
"model.layers.11.mlp.gate",
"model.layers.11.mlp.gate.e_score_correction_bias",
"model.layers.11.post_attention_layernorm",
"model.layers.11.self_attn.indexer.k_norm",
"model.layers.11.self_attn.indexer.k_norm.bias",
"model.layers.11.self_attn.indexers_proj",
"model.layers.11.self_attn.kv_a_layernorm",
"model.layers.11.self_attn.q_a_layernorm",
"model.layers.12.input_layernorm",
"model.layers.12.mlp.gate",
"model.layers.12.mlp.gate.e_score_correction_bias",
"model.layers.12.post_attention_layernorm",
"model.layers.12.self_attn.indexer.k_norm",
"model.layers.12.self_attn.indexer.k_norm.bias",
"model.layers.12.self_attn.indexers_proj",
"model.layers.12.self_attn.kv_a_layernorm",
"model.layers.12.self_attn.q_a_layernorm",
"model.layers.13.input_layernorm",
"model.layers.13.mlp.gate",
"model.layers.13.mlp.gate.e_score_correction_bias",
"model.layers.13.post_attention_layernorm",
"model.layers.13.self_attn.indexer.k_norm",
"model.layers.13.self_attn.indexer.k_norm.bias",
"model.layers.13.self_attn.indexers_proj",
"model.layers.13.self_attn.kv_a_layernorm",
"model.layers.13.self_attn.q_a_layernorm",
"model.layers.14.input_layernorm",
"model.layers.14.mlp.gate",
"model.layers.14.mlp.gate.e_score_correction_bias",
"model.layers.14.post_attention_layernorm",
"model.layers.14.self_attn.indexer.k_norm",
"model.layers.14.self_attn.indexer.k_norm.bias",
"model.layers.14.self_attn.indexers_proj",
"model.layers.14.self_attn.kv_a_layernorm",
"model.layers.14.self_attn.q_a_layernorm",
"model.layers.15.input_layernorm",
"model.layers.15.mlp.gate",
"model.layers.15.mlp.gate.e_score_correction_bias",
"model.layers.15.post_attention_layernorm",
"model.layers.15.self_attn.indexer.k_norm",
"model.layers.15.self_attn.indexer.k_norm.bias",
"model.layers.15.self_attn.indexers_proj",
"model.layers.15.self_attn.kv_a_layernorm",
"model.layers.15.self_attn.q_a_layernorm",
"model.layers.16.input_layernorm",
"model.layers.16.mlp.gate",
"model.layers.16.mlp.gate.e_score_correction_bias",
"model.layers.16.post_attention_layernorm",
"model.layers.16.self_attn.indexer.k_norm",
"model.layers.16.self_attn.indexer.k_norm.bias",
"model.layers.16.self_attn.indexers_proj",
"model.layers.16.self_attn.kv_a_layernorm",
"model.layers.16.self_attn.q_a_layernorm",
"model.layers.17.input_layernorm",
"model.layers.17.mlp.gate",
"model.layers.17.mlp.gate.e_score_correction_bias",
"model.layers.17.post_attention_layernorm",
"model.layers.17.self_attn.indexer.k_norm",
"model.layers.17.self_attn.indexer.k_norm.bias",
"model.layers.17.self_attn.indexers_proj",
"model.layers.17.self_attn.kv_a_layernorm",
"model.layers.17.self_attn.q_a_layernorm",
"model.layers.18.input_layernorm",
"model.layers.18.mlp.gate",
"model.layers.18.mlp.gate.e_score_correction_bias",
"model.layers.18.post_attention_layernorm",
"model.layers.18.self_attn.indexer.k_norm",
"model.layers.18.self_attn.indexer.k_norm.bias",
"model.layers.18.self_attn.indexers_proj",
"model.layers.18.self_attn.kv_a_layernorm",
"model.layers.18.self_attn.q_a_layernorm",
"model.layers.19.input_layernorm",
"model.layers.19.mlp.gate",
"model.layers.19.mlp.gate.e_score_correction_bias",
"model.layers.19.post_attention_layernorm",
"model.layers.19.self_attn.indexer.k_norm",
"model.layers.19.self_attn.indexer.k_norm.bias",
"model.layers.19.self_attn.indexers_proj",
"model.layers.19.self_attn.kv_a_layernorm",
"model.layers.19.self_attn.q_a_layernorm",
"model.layers.20.input_layernorm",
"model.layers.20.mlp.gate",
"model.layers.20.mlp.gate.e_score_correction_bias",
"model.layers.20.post_attention_layernorm",
"model.layers.20.self_attn.indexer.k_norm",
"model.layers.20.self_attn.indexer.k_norm.bias",
"model.layers.20.self_attn.indexers_proj",
"model.layers.20.self_attn.kv_a_layernorm",
"model.layers.20.self_attn.q_a_layernorm",
"model.layers.21.input_layernorm",
"model.layers.21.mlp.gate",
"model.layers.21.mlp.gate.e_score_correction_bias",
"model.layers.21.post_attention_layernorm",
"model.layers.21.self_attn.indexer.k_norm",
"model.layers.21.self_attn.indexer.k_norm.bias",
"model.layers.21.self_attn.indexers_proj",
"model.layers.21.self_attn.kv_a_layernorm",
"model.layers.21.self_attn.q_a_layernorm",
"model.layers.22.input_layernorm",
"model.layers.22.mlp.gate",
"model.layers.22.mlp.gate.e_score_correction_bias",
"model.layers.22.post_attention_layernorm",
"model.layers.22.self_attn.indexer.k_norm",
"model.layers.22.self_attn.indexer.k_norm.bias",
"model.layers.22.self_attn.indexers_proj",
"model.layers.22.self_attn.kv_a_layernorm",
"model.layers.22.self_attn.q_a_layernorm",
"model.layers.23.input_layernorm",
"model.layers.23.mlp.gate",
"model.layers.23.mlp.gate.e_score_correction_bias",
"model.layers.23.post_attention_layernorm",
"model.layers.23.self_attn.indexer.k_norm",
"model.layers.23.self_attn.indexer.k_norm.bias",
"model.layers.23.self_attn.indexers_proj",
"model.layers.23.self_attn.kv_a_layernorm",
"model.layers.23.self_attn.q_a_layernorm",
"model.layers.24.input_layernorm",
"model.layers.24.mlp.gate",
"model.layers.24.mlp.gate.e_score_correction_bias",
"model.layers.24.post_attention_layernorm",
"model.layers.24.self_attn.indexer.k_norm",
"model.layers.24.self_attn.indexer.k_norm.bias",
"model.layers.24.self_attn.indexers_proj",
"model.layers.24.self_attn.kv_a_layernorm",
"model.layers.24.self_attn.q_a_layernorm",
"model.layers.25.input_layernorm",
"model.layers.25.mlp.gate",
"model.layers.25.mlp.gate.e_score_correction_bias",
"model.layers.25.post_attention_layernorm",
"model.layers.25.self_attn.indexer.k_norm",
"model.layers.25.self_attn.indexer.k_norm.bias",
"model.layers.25.self_attn.indexers_proj",
"model.layers.25.self_attn.kv_a_layernorm",
"model.layers.25.self_attn.q_a_layernorm",
"model.layers.26.input_layernorm",
"model.layers.26.mlp.gate",
"model.layers.26.mlp.gate.e_score_correction_bias",
"model.layers.26.post_attention_layernorm",
"model.layers.26.self_attn.indexer.k_norm",
"model.layers.26.self_attn.indexer.k_norm.bias",
"model.layers.26.self_attn.indexers_proj",
"model.layers.26.self_attn.kv_a_layernorm",
"model.layers.26.self_attn.q_a_layernorm",
"model.layers.27.input_layernorm",
"model.layers.27.mlp.gate",
"model.layers.27.mlp.gate.e_score_correction_bias",
"model.layers.27.post_attention_layernorm",
"model.layers.27.self_attn.indexer.k_norm",
"model.layers.27.self_attn.indexer.k_norm.bias",
"model.layers.27.self_attn.indexers_proj",
"model.layers.27.self_attn.kv_a_layernorm",
"model.layers.27.self_attn.q_a_layernorm",
"model.layers.28.input_layernorm",
"model.layers.28.mlp.gate",
"model.layers.28.mlp.gate.e_score_correction_bias",
"model.layers.28.post_attention_layernorm",
"model.layers.28.self_attn.indexer.k_norm",
"model.layers.28.self_attn.indexer.k_norm.bias",
"model.layers.28.self_attn.indexers_proj",
"model.layers.28.self_attn.kv_a_layernorm",
"model.layers.28.self_attn.q_a_layernorm",
"model.layers.29.input_layernorm",
"model.layers.29.mlp.gate",
"model.layers.29.mlp.gate.e_score_correction_bias",
"model.layers.29.post_attention_layernorm",
"model.layers.29.self_attn.indexer.k_norm",
"model.layers.29.self_attn.indexer.k_norm.bias",
"model.layers.29.self_attn.indexers_proj",
"model.layers.29.self_attn.kv_a_layernorm",
"model.layers.29.self_attn.q_a_layernorm",
"model.layers.30.input_layernorm",
"model.layers.30.mlp.gate",
"model.layers.30.mlp.gate.e_score_correction_bias",
"model.layers.30.post_attention_layernorm",
"model.layers.30.self_attn.indexer.k_norm",
"model.layers.30.self_attn.indexer.k_norm.bias",
"model.layers.30.self_attn.indexers_proj",
"model.layers.30.self_attn.kv_a_layernorm",
"model.layers.30.self_attn.q_a_layernorm",
"model.layers.31.input_layernorm",
"model.layers.31.mlp.gate",
"model.layers.31.mlp.gate.e_score_correction_bias",
"model.layers.31.post_attention_layernorm",
"model.layers.31.self_attn.indexer.k_norm",
"model.layers.31.self_attn.indexer.k_norm.bias",
"model.layers.31.self_attn.indexers_proj",
"model.layers.31.self_attn.kv_a_layernorm",
"model.layers.31.self_attn.q_a_layernorm",
"model.layers.32.input_layernorm",
"model.layers.32.mlp.gate",
"model.layers.32.mlp.gate.e_score_correction_bias",
"model.layers.32.post_attention_layernorm",
"model.layers.32.self_attn.indexer.k_norm",
"model.layers.32.self_attn.indexer.k_norm.bias",
"model.layers.32.self_attn.indexers_proj",
"model.layers.32.self_attn.kv_a_layernorm",
"model.layers.32.self_attn.q_a_layernorm",
"model.layers.33.input_layernorm",
"model.layers.33.mlp.gate",
"model.layers.33.mlp.gate.e_score_correction_bias",
"model.layers.33.post_attention_layernorm",
"model.layers.33.self_attn.indexer.k_norm",
"model.layers.33.self_attn.indexer.k_norm.bias",
"model.layers.33.self_attn.indexers_proj",
"model.layers.33.self_attn.kv_a_layernorm",
"model.layers.33.self_attn.q_a_layernorm",
"model.layers.34.input_layernorm",
"model.layers.34.mlp.gate",
"model.layers.34.mlp.gate.e_score_correction_bias",
"model.layers.34.post_attention_layernorm",
"model.layers.34.self_attn.indexer.k_norm",
"model.layers.34.self_attn.indexer.k_norm.bias",
"model.layers.34.self_attn.indexers_proj",
"model.layers.34.self_attn.kv_a_layernorm",
"model.layers.34.self_attn.q_a_layernorm",
"model.layers.35.input_layernorm",
"model.layers.35.mlp.gate",
"model.layers.35.mlp.gate.e_score_correction_bias",
"model.layers.35.post_attention_layernorm",
"model.layers.35.self_attn.indexer.k_norm",
"model.layers.35.self_attn.indexer.k_norm.bias",
"model.layers.35.self_attn.indexers_proj",
"model.layers.35.self_attn.kv_a_layernorm",
"model.layers.35.self_attn.q_a_layernorm",
"model.layers.36.input_layernorm",
"model.layers.36.mlp.gate",
"model.layers.36.mlp.gate.e_score_correction_bias",
"model.layers.36.post_attention_layernorm",
"model.layers.36.self_attn.indexer.k_norm",
"model.layers.36.self_attn.indexer.k_norm.bias",
"model.layers.36.self_attn.indexers_proj",
"model.layers.36.self_attn.kv_a_layernorm",
"model.layers.36.self_attn.q_a_layernorm",
"model.layers.37.input_layernorm",
"model.layers.37.mlp.gate",
"model.layers.37.mlp.gate.e_score_correction_bias",
"model.layers.37.post_attention_layernorm",
"model.layers.37.self_attn.indexer.k_norm",
"model.layers.37.self_attn.indexer.k_norm.bias",
"model.layers.37.self_attn.indexers_proj",
"model.layers.37.self_attn.kv_a_layernorm",
"model.layers.37.self_attn.q_a_layernorm",
"model.layers.38.input_layernorm",
"model.layers.38.mlp.gate",
"model.layers.38.mlp.gate.e_score_correction_bias",
"model.layers.38.post_attention_layernorm",
"model.layers.38.self_attn.indexer.k_norm",
"model.layers.38.self_attn.indexer.k_norm.bias",
"model.layers.38.self_attn.indexers_proj",
"model.layers.38.self_attn.kv_a_layernorm",
"model.layers.38.self_attn.q_a_layernorm",
"model.layers.39.input_layernorm",
"model.layers.39.mlp.gate",
"model.layers.39.mlp.gate.e_score_correction_bias",
"model.layers.39.post_attention_layernorm",
"model.layers.39.self_attn.indexer.k_norm",
"model.layers.39.self_attn.indexer.k_norm.bias",
"model.layers.39.self_attn.indexers_proj",
"model.layers.39.self_attn.kv_a_layernorm",
"model.layers.39.self_attn.q_a_layernorm",
"model.layers.40.input_layernorm",
"model.layers.40.mlp.gate",
"model.layers.40.mlp.gate.e_score_correction_bias",
"model.layers.40.post_attention_layernorm",
"model.layers.40.self_attn.indexer.k_norm",
"model.layers.40.self_attn.indexer.k_norm.bias",
"model.layers.40.self_attn.indexers_proj",
"model.layers.40.self_attn.kv_a_layernorm",
"model.layers.40.self_attn.q_a_layernorm",
"model.layers.41.input_layernorm",
"model.layers.41.mlp.gate",
"model.layers.41.mlp.gate.e_score_correction_bias",
"model.layers.41.post_attention_layernorm",
"model.layers.41.self_attn.indexer.k_norm",
"model.layers.41.self_attn.indexer.k_norm.bias",
"model.layers.41.self_attn.indexers_proj",
"model.layers.41.self_attn.kv_a_layernorm",
"model.layers.41.self_attn.q_a_layernorm",
"model.layers.42.input_layernorm",
"model.layers.42.mlp.gate",
"model.layers.42.mlp.gate.e_score_correction_bias",
"model.layers.42.post_attention_layernorm",
"model.layers.42.self_attn.indexer.k_norm",
"model.layers.42.self_attn.indexer.k_norm.bias",
"model.layers.42.self_attn.indexers_proj",
"model.layers.42.self_attn.kv_a_layernorm",
"model.layers.42.self_attn.q_a_layernorm",
"model.layers.43.input_layernorm",
"model.layers.43.mlp.gate",
"model.layers.43.mlp.gate.e_score_correction_bias",
"model.layers.43.post_attention_layernorm",
"model.layers.43.self_attn.indexer.k_norm",
"model.layers.43.self_attn.indexer.k_norm.bias",
"model.layers.43.self_attn.indexers_proj",
"model.layers.43.self_attn.kv_a_layernorm",
"model.layers.43.self_attn.q_a_layernorm",
"model.layers.44.input_layernorm",
"model.layers.44.mlp.gate",
"model.layers.44.mlp.gate.e_score_correction_bias",
"model.layers.44.post_attention_layernorm",
"model.layers.44.self_attn.indexer.k_norm",
"model.layers.44.self_attn.indexer.k_norm.bias",
"model.layers.44.self_attn.indexers_proj",
"model.layers.44.self_attn.kv_a_layernorm",
"model.layers.44.self_attn.q_a_layernorm",
"model.layers.45.input_layernorm",
"model.layers.45.mlp.gate",
"model.layers.45.mlp.gate.e_score_correction_bias",
"model.layers.45.post_attention_layernorm",
"model.layers.45.self_attn.indexer.k_norm",
"model.layers.45.self_attn.indexer.k_norm.bias",
"model.layers.45.self_attn.indexers_proj",
"model.layers.45.self_attn.kv_a_layernorm",
"model.layers.45.self_attn.q_a_layernorm",
"model.layers.46.input_layernorm",
"model.layers.46.mlp.gate",
"model.layers.46.mlp.gate.e_score_correction_bias",
"model.layers.46.post_attention_layernorm",
"model.layers.46.self_attn.indexer.k_norm",
"model.layers.46.self_attn.indexer.k_norm.bias",
"model.layers.46.self_attn.indexers_proj",
"model.layers.46.self_attn.kv_a_layernorm",
"model.layers.46.self_attn.q_a_layernorm",
"model.layers.47.input_layernorm",
"model.layers.47.mlp.gate",
"model.layers.47.mlp.gate.e_score_correction_bias",
"model.layers.47.post_attention_layernorm",
"model.layers.47.self_attn.indexer.k_norm",
"model.layers.47.self_attn.indexer.k_norm.bias",
"model.layers.47.self_attn.indexers_proj",
"model.layers.47.self_attn.kv_a_layernorm",
"model.layers.47.self_attn.q_a_layernorm",
"model.layers.48.input_layernorm",
"model.layers.48.mlp.gate",
"model.layers.48.mlp.gate.e_score_correction_bias",
"model.layers.48.post_attention_layernorm",
"model.layers.48.self_attn.indexer.k_norm",
"model.layers.48.self_attn.indexer.k_norm.bias",
"model.layers.48.self_attn.indexers_proj",
"model.layers.48.self_attn.kv_a_layernorm",
"model.layers.48.self_attn.q_a_layernorm",
"model.layers.49.input_layernorm",
"model.layers.49.mlp.gate",
"model.layers.49.mlp.gate.e_score_correction_bias",
"model.layers.49.post_attention_layernorm",
"model.layers.49.self_attn.indexer.k_norm",
"model.layers.49.self_attn.indexer.k_norm.bias",
"model.layers.49.self_attn.indexers_proj",
"model.layers.49.self_attn.kv_a_layernorm",
"model.layers.49.self_attn.q_a_layernorm",
"model.layers.50.input_layernorm",
"model.layers.50.mlp.gate",
"model.layers.50.mlp.gate.e_score_correction_bias",
"model.layers.50.post_attention_layernorm",
"model.layers.50.self_attn.indexer.k_norm",
"model.layers.50.self_attn.indexer.k_norm.bias",
"model.layers.50.self_attn.indexers_proj",
"model.layers.50.self_attn.kv_a_layernorm",
"model.layers.50.self_attn.q_a_layernorm",
"model.layers.51.input_layernorm",
"model.layers.51.mlp.gate",
"model.layers.51.mlp.gate.e_score_correction_bias",
"model.layers.51.post_attention_layernorm",
"model.layers.51.self_attn.indexer.k_norm",
"model.layers.51.self_attn.indexer.k_norm.bias",
"model.layers.51.self_attn.indexers_proj",
"model.layers.51.self_attn.kv_a_layernorm",
"model.layers.51.self_attn.q_a_layernorm",
"model.layers.52.input_layernorm",
"model.layers.52.mlp.gate",
"model.layers.52.mlp.gate.e_score_correction_bias",
"model.layers.52.post_attention_layernorm",
"model.layers.52.self_attn.indexer.k_norm",
"model.layers.52.self_attn.indexer.k_norm.bias",
"model.layers.52.self_attn.indexers_proj",
"model.layers.52.self_attn.kv_a_layernorm",
"model.layers.52.self_attn.q_a_layernorm",
"model.layers.53.input_layernorm",
"model.layers.53.mlp.gate",
"model.layers.53.mlp.gate.e_score_correction_bias",
"model.layers.53.post_attention_layernorm",
"model.layers.53.self_attn.indexer.k_norm",
"model.layers.53.self_attn.indexer.k_norm.bias",
"model.layers.53.self_attn.indexers_proj",
"model.layers.53.self_attn.kv_a_layernorm",
"model.layers.53.self_attn.q_a_layernorm",
"model.layers.54.input_layernorm",
"model.layers.54.mlp.gate",
"model.layers.54.mlp.gate.e_score_correction_bias",
"model.layers.54.post_attention_layernorm",
"model.layers.54.self_attn.indexer.k_norm",
"model.layers.54.self_attn.indexer.k_norm.bias",
"model.layers.54.self_attn.indexers_proj",
"model.layers.54.self_attn.kv_a_layernorm",
"model.layers.54.self_attn.q_a_layernorm",
"model.layers.55.input_layernorm",
"model.layers.55.mlp.gate",
"model.layers.55.mlp.gate.e_score_correction_bias",
"model.layers.55.post_attention_layernorm",
"model.layers.55.self_attn.indexer.k_norm",
"model.layers.55.self_attn.indexer.k_norm.bias",
"model.layers.55.self_attn.indexers_proj",
"model.layers.55.self_attn.kv_a_layernorm",
"model.layers.55.self_attn.q_a_layernorm",
"model.layers.56.input_layernorm",
"model.layers.56.mlp.gate",
"model.layers.56.mlp.gate.e_score_correction_bias",
"model.layers.56.post_attention_layernorm",
"model.layers.56.self_attn.indexer.k_norm",
"model.layers.56.self_attn.indexer.k_norm.bias",
"model.layers.56.self_attn.indexers_proj",
"model.layers.56.self_attn.kv_a_layernorm",
"model.layers.56.self_attn.q_a_layernorm",
"model.layers.57.input_layernorm",
"model.layers.57.mlp.gate",
"model.layers.57.mlp.gate.e_score_correction_bias",
"model.layers.57.post_attention_layernorm",
"model.layers.57.self_attn.indexer.k_norm",
"model.layers.57.self_attn.indexer.k_norm.bias",
"model.layers.57.self_attn.indexers_proj",
"model.layers.57.self_attn.kv_a_layernorm",
"model.layers.57.self_attn.q_a_layernorm",
"model.layers.58.input_layernorm",
"model.layers.58.mlp.gate",
"model.layers.58.mlp.gate.e_score_correction_bias",
"model.layers.58.post_attention_layernorm",
"model.layers.58.self_attn.indexer.k_norm",
"model.layers.58.self_attn.indexer.k_norm.bias",
"model.layers.58.self_attn.indexers_proj",
"model.layers.58.self_attn.kv_a_layernorm",
"model.layers.58.self_attn.q_a_layernorm",
"model.layers.59.input_layernorm",
"model.layers.59.mlp.gate",
"model.layers.59.mlp.gate.e_score_correction_bias",
"model.layers.59.post_attention_layernorm",
"model.layers.59.self_attn.indexer.k_norm",
"model.layers.59.self_attn.indexer.k_norm.bias",
"model.layers.59.self_attn.indexers_proj",
"model.layers.59.self_attn.kv_a_layernorm",
"model.layers.59.self_attn.q_a_layernorm",
"model.layers.60.input_layernorm",
"model.layers.60.mlp.gate",
"model.layers.60.mlp.gate.e_score_correction_bias",
"model.layers.60.post_attention_layernorm",
"model.layers.60.self_attn.indexer.k_norm",
"model.layers.60.self_attn.indexer.k_norm.bias",
"model.layers.60.self_attn.indexers_proj",
"model.layers.60.self_attn.kv_a_layernorm",
"model.layers.60.self_attn.q_a_layernorm",
"model.layers.61.input_layernorm",
"model.layers.61.mlp.gate",
"model.layers.61.mlp.gate.e_score_correction_bias",
"model.layers.61.post_attention_layernorm",
"model.layers.61.self_attn.indexer.k_norm",
"model.layers.61.self_attn.indexer.k_norm.bias",
"model.layers.61.self_attn.indexers_proj",
"model.layers.61.self_attn.kv_a_layernorm",
"model.layers.61.self_attn.q_a_layernorm",
"model.layers.62.input_layernorm",
"model.layers.62.mlp.gate",
"model.layers.62.mlp.gate.e_score_correction_bias",
"model.layers.62.post_attention_layernorm",
"model.layers.62.self_attn.indexer.k_norm",
"model.layers.62.self_attn.indexer.k_norm.bias",
"model.layers.62.self_attn.indexers_proj",
"model.layers.62.self_attn.kv_a_layernorm",
"model.layers.62.self_attn.q_a_layernorm",
"model.layers.63.input_layernorm",
"model.layers.63.mlp.gate",
"model.layers.63.mlp.gate.e_score_correction_bias",
"model.layers.63.post_attention_layernorm",
"model.layers.63.self_attn.indexer.k_norm",
"model.layers.63.self_attn.indexer.k_norm.bias",
"model.layers.63.self_attn.indexers_proj",
"model.layers.63.self_attn.kv_a_layernorm",
"model.layers.63.self_attn.q_a_layernorm",
"model.layers.64.input_layernorm",
"model.layers.64.mlp.gate",
"model.layers.64.mlp.gate.e_score_correction_bias",
"model.layers.64.post_attention_layernorm",
"model.layers.64.self_attn.indexer.k_norm",
"model.layers.64.self_attn.indexer.k_norm.bias",
"model.layers.64.self_attn.indexers_proj",
"model.layers.64.self_attn.kv_a_layernorm",
"model.layers.64.self_attn.q_a_layernorm",
"model.layers.65.input_layernorm",
"model.layers.65.mlp.gate",
"model.layers.65.mlp.gate.e_score_correction_bias",
"model.layers.65.post_attention_layernorm",
"model.layers.65.self_attn.indexer.k_norm",
"model.layers.65.self_attn.indexer.k_norm.bias",
"model.layers.65.self_attn.indexers_proj",
"model.layers.65.self_attn.kv_a_layernorm",
"model.layers.65.self_attn.q_a_layernorm",
"model.layers.66.input_layernorm",
"model.layers.66.mlp.gate",
"model.layers.66.mlp.gate.e_score_correction_bias",
"model.layers.66.post_attention_layernorm",
"model.layers.66.self_attn.indexer.k_norm",
"model.layers.66.self_attn.indexer.k_norm.bias",
"model.layers.66.self_attn.indexers_proj",
"model.layers.66.self_attn.kv_a_layernorm",
"model.layers.66.self_attn.q_a_layernorm",
"model.layers.67.input_layernorm",
"model.layers.67.mlp.gate",
"model.layers.67.mlp.gate.e_score_correction_bias",
"model.layers.67.post_attention_layernorm",
"model.layers.67.self_attn.indexer.k_norm",
"model.layers.67.self_attn.indexer.k_norm.bias",
"model.layers.67.self_attn.indexers_proj",
"model.layers.67.self_attn.kv_a_layernorm",
"model.layers.67.self_attn.q_a_layernorm",
"model.layers.68.input_layernorm",
"model.layers.68.mlp.gate",
"model.layers.68.mlp.gate.e_score_correction_bias",
"model.layers.68.post_attention_layernorm",
"model.layers.68.self_attn.indexer.k_norm",
"model.layers.68.self_attn.indexer.k_norm.bias",
"model.layers.68.self_attn.indexers_proj",
"model.layers.68.self_attn.kv_a_layernorm",
"model.layers.68.self_attn.q_a_layernorm",
"model.layers.69.input_layernorm",
"model.layers.69.mlp.gate",
"model.layers.69.mlp.gate.e_score_correction_bias",
"model.layers.69.post_attention_layernorm",
"model.layers.69.self_attn.indexer.k_norm",
"model.layers.69.self_attn.indexer.k_norm.bias",
"model.layers.69.self_attn.indexers_proj",
"model.layers.69.self_attn.kv_a_layernorm",
"model.layers.69.self_attn.q_a_layernorm",
"model.layers.70.input_layernorm",
"model.layers.70.mlp.gate",
"model.layers.70.mlp.gate.e_score_correction_bias",
"model.layers.70.post_attention_layernorm",
"model.layers.70.self_attn.indexer.k_norm",
"model.layers.70.self_attn.indexer.k_norm.bias",
"model.layers.70.self_attn.indexers_proj",
"model.layers.70.self_attn.kv_a_layernorm",
"model.layers.70.self_attn.q_a_layernorm",
"model.layers.71.input_layernorm",
"model.layers.71.mlp.gate",
"model.layers.71.mlp.gate.e_score_correction_bias",
"model.layers.71.post_attention_layernorm",
"model.layers.71.self_attn.indexer.k_norm",
"model.layers.71.self_attn.indexer.k_norm.bias",
"model.layers.71.self_attn.indexers_proj",
"model.layers.71.self_attn.kv_a_layernorm",
"model.layers.71.self_attn.q_a_layernorm",
"model.layers.72.input_layernorm",
"model.layers.72.mlp.gate",
"model.layers.72.mlp.gate.e_score_correction_bias",
"model.layers.72.post_attention_layernorm",
"model.layers.72.self_attn.indexer.k_norm",
"model.layers.72.self_attn.indexer.k_norm.bias",
"model.layers.72.self_attn.indexers_proj",
"model.layers.72.self_attn.kv_a_layernorm",
"model.layers.72.self_attn.q_a_layernorm",
"model.layers.73.input_layernorm",
"model.layers.73.mlp.gate",
"model.layers.73.mlp.gate.e_score_correction_bias",
"model.layers.73.post_attention_layernorm",
"model.layers.73.self_attn.indexer.k_norm",
"model.layers.73.self_attn.indexer.k_norm.bias",
"model.layers.73.self_attn.indexers_proj",
"model.layers.73.self_attn.kv_a_layernorm",
"model.layers.73.self_attn.q_a_layernorm",
"model.layers.74.input_layernorm",
"model.layers.74.mlp.gate",
"model.layers.74.mlp.gate.e_score_correction_bias",
"model.layers.74.post_attention_layernorm",
"model.layers.74.self_attn.indexer.k_norm",
"model.layers.74.self_attn.indexer.k_norm.bias",
"model.layers.74.self_attn.indexers_proj",
"model.layers.74.self_attn.kv_a_layernorm",
"model.layers.74.self_attn.q_a_layernorm",
"model.layers.75.input_layernorm",
"model.layers.75.mlp.gate",
"model.layers.75.mlp.gate.e_score_correction_bias",
"model.layers.75.post_attention_layernorm",
"model.layers.75.self_attn.indexer.k_norm",
"model.layers.75.self_attn.indexer.k_norm.bias",
"model.layers.75.self_attn.indexers_proj",
"model.layers.75.self_attn.kv_a_layernorm",
"model.layers.75.self_attn.q_a_layernorm",
"model.layers.76.input_layernorm",
"model.layers.76.mlp.gate",
"model.layers.76.mlp.gate.e_score_correction_bias",
"model.layers.76.post_attention_layernorm",
"model.layers.76.self_attn.indexer.k_norm",
"model.layers.76.self_attn.indexer.k_norm.bias",
"model.layers.76.self_attn.indexers_proj",
"model.layers.76.self_attn.kv_a_layernorm",
"model.layers.76.self_attn.q_a_layernorm",
"model.layers.77.input_layernorm",
"model.layers.77.mlp.gate",
"model.layers.77.mlp.gate.e_score_correction_bias",
"model.layers.77.post_attention_layernorm",
"model.layers.77.self_attn.indexer.k_norm",
"model.layers.77.self_attn.indexer.k_norm.bias",
"model.layers.77.self_attn.indexers_proj",
"model.layers.77.self_attn.kv_a_layernorm",
"model.layers.77.self_attn.q_a_layernorm",
"model.layers.78.eh_proj",
"model.layers.78.enorm",
"model.layers.78.hnorm",
"model.layers.78.input_layernorm",
"model.layers.78.mlp.gate",
"model.layers.78.mlp.gate.e_score_correction_bias",
"model.layers.78.post_attention_layernorm",
"model.layers.78.self_attn.indexer.k_norm",
"model.layers.78.self_attn.indexer.k_norm.bias",
"model.layers.78.self_attn.indexers_proj",
"model.layers.78.self_attn.kv_a_layernorm",
"model.layers.78.self_attn.q_a_layernorm",
"model.layers.78.shared_head.norm",
"model.norm"
]
}
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,33 @@
{
"backend": "tokenizers",
"clean_up_tokenization_spaces": false,
"do_lower_case": false,
"eos_token": "<|endoftext|>",
"extra_special_tokens": [
"<|endoftext|>",
"[MASK]",
"[gMASK]",
"[sMASK]",
"<sop>",
"<eop>",
"<|system|>",
"<|user|>",
"<|assistant|>",
"<|observation|>",
"<|begin_of_image|>",
"<|end_of_image|>",
"<|begin_of_video|>",
"<|end_of_video|>",
"<|begin_of_audio|>",
"<|end_of_audio|>",
"<|begin_of_transcription|>",
"<|end_of_transcription|>"
],
"is_local": true,
"model_max_length": 202752,
"model_specific_special_tokens": {},
"pad_token": "<|endoftext|>",
"padding_side": "left",
"remove_space": false,
"tokenizer_class": "TokenizersBackend"
}

View File

@@ -0,0 +1,23 @@
from .registry import (
ModelMeta,
detect_model_family_from_features,
detect_model_family_from_records,
detect_model_family_from_trace_file,
get_model_meta,
infer_model_family_from_request_model,
resolve_chat_template_path,
resolve_model_family,
resolve_tokenizer_path,
)
__all__ = [
"ModelMeta",
"detect_model_family_from_features",
"detect_model_family_from_records",
"detect_model_family_from_trace_file",
"get_model_meta",
"infer_model_family_from_request_model",
"resolve_chat_template_path",
"resolve_model_family",
"resolve_tokenizer_path",
]

View File

@@ -0,0 +1,201 @@
from __future__ import annotations
import csv
import json
from dataclasses import dataclass
from pathlib import Path
MODEL_META_ROOT = Path(__file__).resolve().parent
@dataclass(frozen=True)
class ModelMeta:
family: str
provider: str
model_name: str
request_model_hints: tuple[str, ...]
@property
def model_dir(self) -> Path:
return MODEL_META_ROOT / self.provider / self.model_name
@property
def tokenizer_path(self) -> Path:
return self.model_dir / "tokenizer.json"
@property
def chat_template_path(self) -> Path:
return self.model_dir / "chat_template.jinja"
MODEL_REGISTRY = {
"glm5": ModelMeta(
family="glm5",
provider="ZhipuAI",
model_name="GLM-5-FP8",
request_model_hints=("glm", "zhipu"),
),
"qwen3-coder": ModelMeta(
family="qwen3-coder",
provider="Qwen",
model_name="Qwen3-Coder-480B-A35B-Instruct",
request_model_hints=("qwen3-coder", "qwen3 coder", "qwen3_coder"),
),
}
MODEL_ALIASES = {
"glm": "glm5",
"glm5": "glm5",
"zhipu-glm5": "glm5",
"zhipuai-glm5": "glm5",
"qwen": "qwen3-coder",
"qwen3": "qwen3-coder",
"qwen3-coder": "qwen3-coder",
"qwen3_coder": "qwen3-coder",
"coder": "qwen3-coder",
}
def infer_model_family_from_request_model(request_model: str | None) -> str | None:
text = str(request_model or "").strip().lower()
if not text:
return None
for family, meta in MODEL_REGISTRY.items():
if any(hint in text for hint in meta.request_model_hints):
return family
return None
def _infer_model_family_from_path(input_path: str | Path | None) -> str | None:
text = str(input_path or "").strip().lower()
if not text:
return None
if "qwen3-coder" in text or "qwen3_coder" in text:
return "qwen3-coder"
if "glm5" in text or "trace-glm" in text:
return "glm5"
return None
def detect_model_family_from_trace_file(path: str | Path) -> str | None:
resolved = Path(path)
with resolved.open("r", encoding="utf-8") as handle:
for line in handle:
stripped = line.strip()
if not stripped:
continue
raw = json.loads(stripped)
if isinstance(raw.get("meta"), dict):
meta = raw["meta"]
family = str(meta.get("model_family", "")).strip()
if family:
return resolve_model_family(family)
inferred = infer_model_family_from_request_model(meta.get("request_model"))
if inferred:
return inferred
inferred = infer_model_family_from_request_model(raw.get("request_model"))
if inferred:
return inferred
break
return _infer_model_family_from_path(path)
def detect_model_family_from_features(path: str | Path) -> str | None:
resolved = Path(path)
with resolved.open("r", encoding="utf-8") as handle:
reader = csv.DictReader(handle)
for row in reader:
inferred = infer_model_family_from_request_model(row.get("model"))
if inferred:
return inferred
break
return _infer_model_family_from_path(path)
def detect_model_family_from_records(records) -> str | None:
for record in records:
inferred = infer_model_family_from_request_model(record.meta.request_model)
if inferred:
return inferred
break
return None
def resolve_model_family(
model_family: str | None = None,
*,
request_model: str | None = None,
input_path: str | Path | None = None,
features_path: str | Path | None = None,
records=None,
) -> str:
candidate = str(model_family or "auto").strip().lower()
if candidate and candidate != "auto":
if candidate in MODEL_ALIASES:
return MODEL_ALIASES[candidate]
raise ValueError(f"Unsupported model family: {model_family}")
inferred = infer_model_family_from_request_model(request_model)
if inferred:
return inferred
if records is not None:
inferred = detect_model_family_from_records(records)
if inferred:
return inferred
if features_path is not None:
inferred = detect_model_family_from_features(features_path)
if inferred:
return inferred
if input_path is not None:
inferred = detect_model_family_from_trace_file(input_path)
if inferred:
return inferred
return "glm5"
def get_model_meta(model_family: str | None = None, *, model_meta_dir: str | Path | None = None, **kwargs) -> ModelMeta:
family = resolve_model_family(model_family, **kwargs)
base_meta = MODEL_REGISTRY[family]
    if model_meta_dir is None:
        return base_meta
    # A custom meta root is only validated here; files are resolved against
    # model_meta_dir by resolve_chat_template_path / resolve_tokenizer_path.
    custom_model_dir = Path(model_meta_dir) / base_meta.provider / base_meta.model_name
    if not custom_model_dir.exists():
        raise FileNotFoundError(f"Model meta directory not found for {family}: {custom_model_dir}")
    return base_meta
def resolve_chat_template_path(
model_family: str | None = None,
*,
model_meta_dir: str | Path | None = None,
**kwargs,
) -> Path:
family = resolve_model_family(model_family, **kwargs)
meta = MODEL_REGISTRY[family]
model_dir = Path(model_meta_dir) / meta.provider / meta.model_name if model_meta_dir else meta.model_dir
return model_dir / "chat_template.jinja"
def resolve_tokenizer_path(
tokenizer_path: str | Path | None = None,
*,
model_family: str | None = None,
model_meta_dir: str | Path | None = None,
**kwargs,
) -> str:
if tokenizer_path:
path = Path(tokenizer_path)
return str(path.parent if path.is_file() else path)
family = resolve_model_family(model_family, **kwargs)
meta = MODEL_REGISTRY[family]
model_dir = Path(model_meta_dir) / meta.provider / meta.model_name if model_meta_dir else meta.model_dir
return str(model_dir)
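A usage sketch for the registry (the `model_meta` import path is an assumption; adjust to wherever this package is mounted):

```python
from model_meta import get_model_meta, resolve_model_family  # assumed import path

# Explicit aliases normalize to a canonical family...
assert resolve_model_family("qwen3_coder") == "qwen3-coder"

# ...and "auto" walks the inference chain (request_model, records,
# features.csv, trace path), defaulting to "glm5" if nothing matches.
family = resolve_model_family("auto", request_model="Qwen3-Coder-480B-A35B-Instruct")
meta = get_model_meta(family)
print(meta.provider, meta.model_name)  # Qwen Qwen3-Coder-480B-A35B-Instruct
print(meta.tokenizer_path)             # .../Qwen/Qwen3-Coder-480B-A35B-Instruct/tokenizer.json
```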

1180
uv.lock generated Normal file

File diff suppressed because it is too large