Files
replaysim/tools/analyze_trace_window.py

189 lines
7.7 KiB
Python
Executable File

#!/usr/bin/env python3
"""Analyze Qwen/ReplayServe sidecar rows around a request id."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the trace-window analysis.

    Options:
        --fixture-dir  (required, Path) directory that is expected to hold
                       the sidecar file read by ``main``.
        --request-id   (required, int) request id to analyze.
        --window       (int, default 10) size of the local window.
        --top-k        (int, default 15) number of overlaps to report.
        --output-dir   (required, Path) directory for the generated reports.

    Returns:
        The populated ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Analyze sidecar prefix overlap.")
    # Declared as data so the flag set can be read at a glance; registration
    # order matches the order shown in ``--help``.
    option_specs = (
        ("--fixture-dir", {"required": True, "type": Path}),
        ("--request-id", {"required": True, "type": int}),
        ("--window", {"type": int, "default": 10}),
        ("--top-k", {"type": int, "default": 15}),
        ("--output-dir", {"required": True, "type": Path}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of dictionaries.

    Lines that are empty or contain only whitespace are skipped. Every other
    line must decode to a JSON object.

    Args:
        path: File to read; it is opened as UTF-8 text.

    Returns:
        The decoded objects, in file order.

    Raises:
        json.JSONDecodeError: A line is not valid JSON. The message is
            prefixed with ``path`` and the 1-based line number in the file.
        ValueError: A line decodes to something other than a JSON object.
    """
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                row = json.loads(stripped)
            except json.JSONDecodeError as err:
                # Re-raise with the same exception type so existing handlers
                # keep working, but say which file and which file line failed.
                # The "line/column" suffix that JSONDecodeError appends is
                # relative to the stripped record, not to the file.
                raise json.JSONDecodeError(
                    f"{path}: line {line_number}: invalid JSON: {err.msg}",
                    err.doc,
                    err.pos,
                ) from err
            if not isinstance(row, dict):
                raise ValueError(f"{path}: line {line_number}: expected object")
            rows.append(row)
    return rows
def common_prefix_len(left: list[int], right: list[int]) -> int:
    """Return how many leading positions of ``left`` and ``right`` are equal.

    Comparison stops at the first differing pair or when the shorter
    sequence is exhausted, whichever comes first.
    """
    matched = 0
    # ``matched`` is the 1-based position of the pair being inspected, so on a
    # mismatch the number of equal pairs seen so far is one less.
    for matched, (ours, theirs) in enumerate(zip(left, right), start=1):
        if ours != theirs:
            return matched - 1
    return matched
def summarize_row(row: dict[str, Any], block_size: int = 16) -> dict[str, Any]:
    """Build a compact, type-normalized summary of one sidecar row.

    Args:
        row: A decoded sidecar record. The keys ``request_id``, ``chat_id``,
            ``parent_chat_id``, ``turn``, ``type``, ``timestamp``,
            ``input_length``, ``output_length``, ``hash_ids`` and
            ``block_token_counts`` are read; a missing key raises ``KeyError``.
        block_size: Modulus applied to ``input_length`` to decide whether the
            final block is partial (presumably the number of tokens per full
            block — confirm against the trace producer).

    Returns:
        A dictionary with the identifying fields coerced to ``int``/``float``,
        the token totals, the number of hash ids, the first 12 hash ids, the
        last hash id (``None`` when there are none), the partial-final-block
        flag, and the last entry of ``block_token_counts`` (``0`` when empty).
    """
    prompt_tokens = int(row["input_length"])
    generated_tokens = int(row["output_length"])
    hashes = list(map(int, row["hash_ids"]))
    per_block_tokens = list(map(int, row["block_token_counts"]))

    # Integer identity fields first, so the summary keeps a stable key order.
    summary: dict[str, Any] = {
        field: int(row[field])
        for field in ("request_id", "chat_id", "parent_chat_id", "turn")
    }
    summary["type"] = row["type"]
    summary["timestamp"] = float(row["timestamp"])
    summary.update(
        input_length=prompt_tokens,
        output_length=generated_tokens,
        total_tokens=prompt_tokens + generated_tokens,
        hash_count=len(hashes),
        first_hash_ids=hashes[:12],
        last_hash_id=hashes[-1] if hashes else None,
        partial_final_block=prompt_tokens % block_size != 0,
        final_block_token_count=per_block_tokens[-1] if per_block_tokens else 0,
    )
    return summary
def _collect_prior_overlaps(
    rows: list[dict[str, Any]],
    target: dict[str, Any],
    request_id: int,
) -> list[dict[str, Any]]:
    """Summarize earlier rows whose ``hash_ids`` share a prefix with the target.

    A row counts as "prior" when its ``request_id`` is strictly smaller than
    ``request_id``. Rows that share no leading hash id are omitted. The result
    is sorted by shared block count and then request id, both descending.
    """
    target_hashes = [int(value) for value in target["hash_ids"]]
    target_counts = [int(value) for value in target["block_token_counts"]]
    target_input_length = int(target["input_length"])

    overlaps: list[dict[str, Any]] = []
    for row in rows:
        if int(row["request_id"]) >= request_id:
            continue
        lcp_blocks = common_prefix_len(
            target_hashes, [int(value) for value in row["hash_ids"]]
        )
        if lcp_blocks <= 0:
            continue
        # Shared tokens are counted with the target's own per-block counts.
        lcp_tokens = sum(target_counts[:lcp_blocks])
        overlaps.append(
            {
                **summarize_row(row),
                "common_prefix_blocks_with_target": lcp_blocks,
                "common_prefix_tokens_with_target": lcp_tokens,
                "target_prefix_fraction_blocks": (
                    lcp_blocks / len(target_hashes) if target_hashes else 0.0
                ),
                "target_prefix_fraction_tokens": (
                    lcp_tokens / target_input_length
                    if target_input_length > 0
                    else 0.0
                ),
            }
        )
    overlaps.sort(
        key=lambda item: (
            item["common_prefix_blocks_with_target"],
            item["request_id"],
        ),
        reverse=True,
    )
    return overlaps


def _write_markdown_report(
    md_path: Path,
    fixture_dir: Path,
    request_id: int,
    target_summary: dict[str, Any],
    overlaps: list[dict[str, Any]],
    top_k: int,
    local_window: list[dict[str, Any]],
) -> None:
    """Write the Markdown report: target facts, top overlaps, local window."""
    lines = [
        f"# Request {request_id} Trace Analysis\n\n",
        f"- Fixture: `{fixture_dir}`\n",
        f"- Timestamp: `{target_summary['timestamp']}`\n",
        f"- Chat: `{target_summary['chat_id']}` "
        f"parent `{target_summary['parent_chat_id']}` "
        f"turn `{target_summary['turn']}`\n",
        f"- Input/output/total tokens: `{target_summary['input_length']}` / "
        f"`{target_summary['output_length']}` / "
        f"`{target_summary['total_tokens']}`\n",
        f"- Hash blocks: `{target_summary['hash_count']}`\n",
        f"- Partial final block: `{target_summary['partial_final_block']}` "
        f"final count `{target_summary['final_block_token_count']}`\n",
        "\n## Top Prior Prefix Overlaps\n\n",
    ]
    # The emptiness check looks at ALL overlaps, not just the top-k slice, so
    # a top_k of 0 yields an empty table rather than a false "no overlap" note.
    if not overlaps:
        lines.append("No prior request shares a first block with the target.\n")
    else:
        lines.append(
            "| prior request | timestamp | input | output | lcp blocks | lcp tokens | partial final |\n"
        )
        lines.append("|---:|---:|---:|---:|---:|---:|---|\n")
        lines.extend(
            f"| {item['request_id']} | {item['timestamp']} | "
            f"{item['input_length']} | {item['output_length']} | "
            f"{item['common_prefix_blocks_with_target']} | "
            f"{item['common_prefix_tokens_with_target']} | "
            f"{item['partial_final_block']} |\n"
            for item in overlaps[:top_k]
        )
    lines.append("\n## Local Window\n\n")
    lines.append(
        "| request | timestamp | input | output | blocks | partial final | first hashes |\n"
    )
    lines.append("|---:|---:|---:|---:|---:|---|---|\n")
    lines.extend(
        f"| {item['request_id']} | {item['timestamp']} | "
        f"{item['input_length']} | {item['output_length']} | "
        f"{item['hash_count']} | {item['partial_final_block']} | "
        f"`{item['first_hash_ids']}` |\n"
        for item in local_window
    )
    with md_path.open("w", encoding="utf-8") as handle:
        handle.writelines(lines)


def main() -> int:
    """Analyze the sidecar rows around one request and write two reports.

    Reads ``<fixture-dir>/sidecar.jsonl``, locates the row whose
    ``request_id`` matches ``--request-id``, and produces:

    * the prior rows (smaller request id) that share a leading run of
      ``hash_ids`` with the target, strongest overlap first;
    * a local window of up to ``--window`` rows on each side of the target's
      position in the file;
    * the rows whose ``chat_id`` or ``request_id`` equals the target's
      ``parent_chat_id``.

    The results are written to ``request_<id>_analysis.json`` and
    ``request_<id>_analysis.md`` inside ``--output-dir`` (created if needed),
    and both paths are printed to stdout.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: The requested id does not occur in the sidecar file.
    """
    args = parse_args()
    sidecar_path = args.fixture_dir / "sidecar.jsonl"
    rows = load_jsonl(sidecar_path)

    # Map request id -> position in the file. If an id repeats, the last
    # occurrence wins.
    index_by_id = {int(row["request_id"]): index for index, row in enumerate(rows)}
    target_index = index_by_id.get(args.request_id)
    if target_index is None:
        raise SystemExit(f"request_id {args.request_id} not found in {sidecar_path}")
    target = rows[target_index]
    target_summary = summarize_row(target)

    overlaps = _collect_prior_overlaps(rows, target, args.request_id)

    # Center the window on the target's actual position in the file rather
    # than on the numeric request id, which is only a valid index when ids are
    # 0-based and contiguous. A negative window is treated as 0.
    window = max(args.window, 0)
    start = max(0, target_index - window)
    end = target_index + window + 1
    local_window = [summarize_row(row) for row in rows[start:end]]

    parent_chat_id = int(target["parent_chat_id"])
    parent_rows = [
        summarize_row(row)
        for row in rows
        if int(row["chat_id"]) == parent_chat_id
        or int(row["request_id"]) == parent_chat_id
    ]

    result = {
        "fixture_dir": str(args.fixture_dir),
        "sidecar": str(sidecar_path),
        "request_id": args.request_id,
        "target": target_summary,
        "local_window": local_window,
        "top_prior_prefix_overlaps": overlaps[: args.top_k],
        "prior_overlap_count": len(overlaps),
        "parent_candidates": parent_rows,
        "interpretation": {
            "prefix_overlap_semantics": (
                "Frontier prefix cache matches consecutive block_hash_ids from "
                "the start of the prompt. common_prefix_tokens_with_target uses "
                "the target sidecar block_token_counts, preserving partial final "
                "block token counts."
            ),
            # Reuse the summary's flag so the block size is defined in exactly
            # one place (summarize_row's default) instead of a second literal.
            "partial_final_block_related": target_summary["partial_final_block"],
        },
    }

    args.output_dir.mkdir(parents=True, exist_ok=True)
    json_path = args.output_dir / f"request_{args.request_id}_analysis.json"
    md_path = args.output_dir / f"request_{args.request_id}_analysis.md"
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(result, handle, indent=2, sort_keys=True)
        handle.write("\n")
    _write_markdown_report(
        md_path,
        args.fixture_dir,
        args.request_id,
        target_summary,
        overlaps,
        args.top_k,
        local_window,
    )
    print(json_path)
    print(md_path)
    return 0
# Script entry point: run the analysis and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)