Files
replaysim/tools/analyze_vllm_prefix_log.py

99 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""Summarize vLLM scheduler prefix-cache `computed:` log lines."""
from __future__ import annotations
import argparse
import json
import re
from pathlib import Path
from typing import Any
# Matches scheduler lines of the form
#   "Request <id> started running, prompt: <n>, computed: <n>"
# Groups: 1 = numeric request id, 2 = prompt tokens, 3 = computed tokens.
# Only purely numeric request ids are recognized.
START_RE = re.compile(
    r"Request (\d+) started running, "
    r"prompt: (\d+), computed: (\d+)"
)
# Matches "Request <id> preempted"; group 1 = numeric request id.
PREEMPT_RE = re.compile(r"Request (\d+) preempted")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            (the default) parses ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``stdout_log`` (Path) and ``summary_json``
        (Path, or ``None`` when the option is omitted).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Parse vLLM scheduler logs and report observed computed-token "
            "prefix-cache behavior. Repeated starts indicate preemption or "
            "re-admission, so all-start sums are not equivalent to per-request "
            "prefix hits."
        )
    )
    parser.add_argument(
        "stdout_log",
        type=Path,
        help="Log file containing vLLM scheduler 'started running' lines.",
    )
    parser.add_argument(
        "--summary-json",
        type=Path,
        help=(
            "Optional summary JSON; its estimated_prefix_reuse.hit_tokens "
            "value is compared against each computed-token aggregate."
        ),
    )
    return parser.parse_args(argv)
def load_estimated_hit_tokens(path: Path | None) -> int | None:
    """Return ``estimated_prefix_reuse.hit_tokens`` from a summary JSON file.

    Args:
        path: Summary JSON path, or ``None`` when no summary was supplied.

    Returns:
        The hit-token estimate as an ``int``, or ``None`` when *path* is
        ``None`` or the file does not provide the value (top level not an
        object, ``estimated_prefix_reuse`` missing / ``null`` / not an
        object, or ``hit_tokens`` missing / ``null``).

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if path is None:
        return None
    summary = json.loads(path.read_text(encoding="utf-8"))
    # `.get(key, {})` alone does not cover an explicit JSON null (or a
    # non-object value), which would crash with AttributeError on the next
    # `.get`; treat any such shape as "no estimate available".
    reuse = summary.get("estimated_prefix_reuse") if isinstance(summary, dict) else None
    if not isinstance(reuse, dict):
        return None
    hit_tokens = reuse.get("hit_tokens")
    return int(hit_tokens) if hit_tokens is not None else None
def _scan_log(path: Path) -> tuple[dict[int, list[dict[str, int]]], list[int]]:
    """Collect per-request start events and preempted request ids from *path*.

    Returns ``(by_request, preempted_request_ids)``:
    - ``by_request`` maps request id -> list of starts in log order, each start
      being ``{"prompt_tokens": int, "computed_tokens": int}``;
    - ``preempted_request_ids`` lists ids from preemption lines in log order
      (an id appears once per preemption line).

    The file is streamed line by line instead of being loaded whole. Neither
    pattern can match across a newline, so the matches are the same as
    matching against the full text.
    """
    by_request: dict[int, list[dict[str, int]]] = {}
    preempted_request_ids: list[int] = []
    with path.open(encoding="utf-8", errors="replace") as log:
        for line in log:
            for match in START_RE.finditer(line):
                by_request.setdefault(int(match.group(1)), []).append(
                    {
                        "prompt_tokens": int(match.group(2)),
                        "computed_tokens": int(match.group(3)),
                    }
                )
            preempted_request_ids.extend(
                int(match.group(1)) for match in PREEMPT_RE.finditer(line)
            )
    return by_request, preempted_request_ids


def _computed_totals(by_request: dict[int, list[dict[str, int]]]) -> dict[str, int]:
    """Sum ``computed_tokens`` over requests under four attribution policies.

    ``all_starts`` counts every start (double-counting re-admitted requests);
    the other three keep one value per request: first start, last start, or
    the maximum across that request's starts.
    """
    totals = {
        "all_starts": 0,
        "first_start_per_request": 0,
        "last_start_per_request": 0,
        "max_per_request": 0,
    }
    for starts in by_request.values():
        # Every list is non-empty: an entry is only created when a start is appended.
        computed = [start["computed_tokens"] for start in starts]
        totals["all_starts"] += sum(computed)
        totals["first_start_per_request"] += computed[0]
        totals["last_start_per_request"] += computed[-1]
        totals["max_per_request"] += max(computed)
    return totals


def main() -> int:
    """Analyze the log named on the command line and print a JSON report to stdout.

    When ``--summary-json`` provides an estimate, the report also records
    whether each computed-token aggregate equals it exactly.

    Returns:
        Process exit status (always 0).
    """
    import sys

    args = parse_args()
    by_request, preempted_request_ids = _scan_log(args.stdout_log)
    if not by_request:
        # An all-zero report is indistinguishable from "no prefix hits";
        # make a format mismatch visible without changing the JSON output.
        print(
            f"warning: no 'started running' lines matched in {args.stdout_log}",
            file=sys.stderr,
        )
    # Numeric ids of requests started more than once (preemption / re-admission).
    repeated_ids = [
        request_id
        for request_id, starts in sorted(by_request.items())
        if len(starts) > 1
    ]
    estimated_hit_tokens = load_estimated_hit_tokens(args.summary_json)
    computed_tokens = _computed_totals(by_request)
    result: dict[str, Any] = {
        "stdout_log": str(args.stdout_log),
        "starts_total": sum(len(starts) for starts in by_request.values()),
        "unique_requests": len(by_request),
        "preemptions": len(preempted_request_ids),
        "preempted_request_ids": preempted_request_ids,
        "repeated_request_ids": repeated_ids,
        "computed_tokens": computed_tokens,
        # JSON object keys must be strings, so ids are stringified here.
        "repeated_starts": {
            str(request_id): by_request[request_id] for request_id in repeated_ids
        },
    }
    if estimated_hit_tokens is not None:
        result["estimated_prefix_hit_tokens"] = estimated_hit_tokens
        result["matches_estimate"] = {
            name: value == estimated_hit_tokens
            for name, value in computed_tokens.items()
        }
    print(json.dumps(result, indent=2, sort_keys=True))
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)