406 lines
16 KiB
Python
406 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""Replay a ReplayServe fixture on vLLM with synthetic prompt token blocks."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import csv
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import random
|
|
import statistics
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
def positive_int(value: str) -> int:
    """Argparse ``type=`` converter that only accepts integers above zero.

    Args:
        value: Raw command-line token.

    Returns:
        The token converted with ``int()``.

    Raises:
        argparse.ArgumentTypeError: If the converted integer is zero or negative.
        ValueError: Propagated from ``int()`` when *value* is not an integer literal.
    """
    number = int(value)
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be positive")
|
|
|
|
|
|
def positive_float(value: str) -> float:
    """Argparse ``type=`` converter that only accepts finite floats above zero.

    Args:
        value: Raw command-line token.

    Returns:
        The token converted with ``float()``.

    Raises:
        argparse.ArgumentTypeError: If the converted value is zero, negative,
            NaN, or infinite.
        ValueError: Propagated from ``float()`` when *value* is not a float literal.
    """
    parsed = float(value)
    # A plain ``parsed <= 0`` test lets NaN through (every comparison with NaN
    # is False) and also accepts +inf; the chained comparison rejects both.
    if not 0 < parsed < float("inf"):
        raise argparse.ArgumentTypeError("must be positive")
    return parsed
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing zero-argument callers are
            unaffected; passing a list allows programmatic use.

    Returns:
        The populated namespace. ``--fixture-dir``, ``--model`` and
        ``--output-dir`` are required; ``--limit`` and ``--max-output-tokens``
        are ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run an online vLLM smoke/replay using synthetic prompt_token_ids "
            "derived from ReplayServe block hashes."
        )
    )
    parser.add_argument("--fixture-dir", required=True, type=Path)
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--output-dir", required=True, type=Path)
    parser.add_argument("--tensor-parallel-size", type=positive_int, default=1)
    parser.add_argument("--limit", type=positive_int)
    parser.add_argument("--block-size", type=positive_int, default=16)
    parser.add_argument("--max-model-len", type=positive_int, default=32768)
    parser.add_argument("--max-num-seqs", type=positive_int, default=128)
    parser.add_argument("--max-num-batched-tokens", type=positive_int, default=32768)
    parser.add_argument("--gpu-memory-utilization", type=positive_float, default=0.9)
    parser.add_argument("--time-scale", type=positive_float, default=1.0)
    parser.add_argument(
        "--max-output-tokens",
        type=positive_int,
        help="Cap each row's output_length for smoke tests.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dtype", default="auto")
    parser.add_argument("--enforce-eager", action="store_true")
    # BooleanOptionalAction also registers the matching ``--no-...`` switches.
    parser.add_argument("--trust-remote-code", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--enable-prefix-caching", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--enable-chunked-prefill", action=argparse.BooleanOptionalAction, default=True)
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file into a list of dictionaries.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        path: File to read; opened as UTF-8 text.

    Returns:
        One dict per non-blank line, in file order.

    Raises:
        ValueError: If a line is not valid JSON or does not decode to a JSON
            object. The message carries the path and 1-based line number.
        OSError: Propagated from opening or reading the file.
    """
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            try:
                row = json.loads(stripped)
            except json.JSONDecodeError as exc:
                # Re-raise with file/line context so a bad fixture row can be
                # located; JSONDecodeError is itself a ValueError, so callers
                # that catch ValueError keep working.
                raise ValueError(
                    f"{path}: line {line_number}: invalid JSON: {exc.msg}"
                ) from exc
            if not isinstance(row, dict):
                raise ValueError(f"{path}: line {line_number}: expected object")
            rows.append(row)
    return rows
|
|
|
|
|
|
def percentile(values: list[float], pct: float) -> float | None:
    """Pick the sample at fractional rank *pct* from the sorted *values*.

    The rank is ``int((len(values) - 1) * pct)`` (truncated, no interpolation)
    and is clamped to the valid index range, so the result is always one of
    the input samples.

    Args:
        values: Samples; the input list is not modified.
        pct: Fraction of the way through the sorted samples (0.5 for the median).

    Returns:
        The selected sample, or ``None`` when *values* is empty.
    """
    count = len(values)
    if count == 0:
        return None
    last = count - 1
    rank = int(last * pct)
    if rank < 0:
        rank = 0
    elif rank > last:
        rank = last
    return sorted(values)[rank]
|
|
|
|
|
|
def block_seed(hash_id: int, seed: int) -> int:
    """Derive a 64-bit integer from the pair (*seed*, *hash_id*).

    The text ``"<seed>:<hash_id>"`` is hashed with BLAKE2b using an 8-byte
    digest, and the digest is read as a big-endian unsigned integer.

    Args:
        hash_id: Block hash identifier.
        seed: Run-level seed mixed into the hash.

    Returns:
        An integer in ``[0, 2**64)``.
    """
    key = f"{seed}:{hash_id}".encode("utf-8")
    hasher = hashlib.blake2b(key, digest_size=8)
    # Parsing the hex digest in base 16 equals a big-endian read of the bytes.
    return int(hasher.hexdigest(), 16)
|
|
|
|
|
|
def block_tokens(
    hash_id: int,
    *,
    seed: int,
    block_size: int,
    vocab_size: int,
    special_ids: set[int],
) -> list[int]:
    """Generate the deterministic synthetic token ids for one block hash.

    A ``random.Random`` seeded from ``block_seed(hash_id, seed)`` draws ids
    until ``block_size`` non-special ids have been collected, so equal
    (*hash_id*, *seed*) pairs always produce the same block.

    Ids are normally drawn from ``[1000, vocab_size - 1000)``, collapsing to
    the single id 1000 when that range would be empty. When the vocabulary is
    too small to contain id 1000 at all, the whole range ``[0, vocab_size)``
    is used instead so that no out-of-vocabulary id is ever produced.

    Args:
        hash_id: Block hash identifier.
        seed: Run-level seed.
        block_size: Number of token ids to return.
        vocab_size: Size of the tokenizer vocabulary.
        special_ids: Token ids that must never be emitted.

    Returns:
        A list of ``block_size`` token ids.

    Raises:
        ValueError: If *vocab_size* is not positive, or if every id in the
            sampling range is in *special_ids* (which would otherwise make the
            rejection loop spin forever).
    """
    if vocab_size <= 0:
        raise ValueError(f"vocab_size must be positive, got {vocab_size}")

    low = 1000
    high = max(low + 1, vocab_size - 1000)
    if low >= vocab_size:
        # Tiny vocabulary: id 1000 does not exist, so sample the full range.
        low, high = 0, vocab_size

    # Only a range no larger than the special set can be fully excluded, so
    # the linear scan runs just in that (rare) case.
    if high - low <= len(special_ids) and all(
        candidate in special_ids for candidate in range(low, high)
    ):
        raise ValueError(
            f"no usable token ids in [{low}, {high}) after excluding special ids"
        )

    rng = random.Random(block_seed(hash_id, seed))
    tokens: list[int] = []
    while len(tokens) < block_size:
        token_id = rng.randrange(low, high)
        if token_id not in special_ids:
            tokens.append(token_id)
    return tokens
|
|
|
|
|
|
def make_prompt_token_ids(
    row: dict[str, Any],
    *,
    seed: int,
    block_size: int,
    vocab_size: int,
    special_ids: set[int],
) -> list[int]:
    """Assemble the synthetic prompt for one fixture row.

    For each ``(hash_id, count)`` pair taken from ``row["hash_ids"]`` and
    ``row["block_token_counts"]``, the first ``count`` ids of
    ``block_tokens(hash_id, ...)`` are appended to the prompt.

    Args:
        row: Fixture row; reads ``hash_ids``, ``block_token_counts``,
            ``input_length`` and (for error messages only) ``request_id``.
        seed: Run-level seed forwarded to ``block_tokens``.
        block_size: Tokens generated per block, and the upper bound for each count.
        vocab_size: Forwarded to ``block_tokens``.
        special_ids: Forwarded to ``block_tokens``.

    Returns:
        The concatenated token ids; its length equals ``row["input_length"]``.

    Raises:
        ValueError: If the two per-block lists differ in length, if a count is
            negative or larger than *block_size*, or if the assembled length
            differs from ``input_length``.
        KeyError: If a required key is missing from *row*.
    """
    hash_ids = [int(value) for value in row["hash_ids"]]
    counts = [int(value) for value in row["block_token_counts"]]
    if len(hash_ids) != len(counts):
        raise ValueError(f"request {row.get('request_id')}: hash/count length mismatch")

    token_ids: list[int] = []
    for hash_id, count in zip(hash_ids, counts):
        # Slicing with a negative count would drop tokens from the end and an
        # oversized count would be silently truncated; either can yield a wrong
        # prompt, so reject such counts explicitly.
        if not 0 <= count <= block_size:
            raise ValueError(
                f"request {row.get('request_id')}: block token count {count} "
                f"outside 0..{block_size}"
            )
        token_ids.extend(
            block_tokens(
                hash_id,
                seed=seed,
                block_size=block_size,
                vocab_size=vocab_size,
                special_ids=special_ids,
            )[:count]
        )

    expected = int(row["input_length"])
    if len(token_ids) != expected:
        raise ValueError(
            f"request {row.get('request_id')}: synthetic prompt length "
            f"{len(token_ids)} != input_length {expected}"
        )
    return token_ids
|
|
|
|
|
|
def estimate_prefix_reuse(rows: list[dict[str, Any]]) -> dict[int, dict[str, int | float]]:
    """Estimate, per request, how much of its block prefix earlier rows covered.

    Rows are processed in list order against a trie keyed by block hash id.
    For each row the number of leading blocks already present as a path from
    the trie root is counted as hits, and the row's full hash sequence is then
    added to the trie. Nothing is ever evicted from the trie.

    Args:
        rows: Fixture rows; each must provide ``request_id``, ``hash_ids``,
            ``block_token_counts`` and ``input_length``.

    Returns:
        A mapping from request id to a dict with ``query_blocks``,
        ``hit_blocks``, ``query_tokens``, ``hit_tokens``, ``block_hit_ratio``
        and ``token_hit_ratio``. Ratios are ``0.0`` when their denominator is
        zero.
    """
    root: dict[int, dict[Any, Any]] = {}
    result: dict[int, dict[str, int | float]] = {}

    for entry in rows:
        rid = int(entry["request_id"])
        blocks = [int(item) for item in entry["hash_ids"]]
        sizes = [int(item) for item in entry["block_token_counts"]]

        # Single walk: count matches until the first miss, and create any
        # missing child nodes along the way.
        cursor = root
        matched = 0
        still_matching = True
        for block in blocks:
            if still_matching and block in cursor:
                matched += 1
            else:
                still_matching = False
            cursor = cursor.setdefault(block, {})

        total_tokens = int(entry["input_length"])
        reused_tokens = sum(sizes[:matched])

        if blocks:
            block_ratio = matched / len(blocks)
        else:
            block_ratio = 0.0
        if total_tokens:
            token_ratio = reused_tokens / total_tokens
        else:
            token_ratio = 0.0

        result[rid] = {
            "query_blocks": len(blocks),
            "hit_blocks": matched,
            "query_tokens": total_tokens,
            "hit_tokens": reused_tokens,
            "block_hit_ratio": block_ratio,
            "token_hit_ratio": token_ratio,
        }
    return result
|
|
|
|
|
|
async def run_replay(args: argparse.Namespace) -> dict[str, Any]:
    """Replay the fixture's sidecar rows against an in-process vLLM engine.

    Steps:
      1. Load ``<fixture_dir>/sidecar.jsonl`` (optionally truncated to
         ``args.limit`` rows) and build one synthetic prompt per row.
      2. Create an ``AsyncLLMEngine`` from the CLI settings.
      3. Submit every row as its own coroutine; each sleeps until its scheduled
         arrival offset (row timestamp minus the earliest timestamp, multiplied
         by ``args.time_scale``) and then streams the generation while
         recording time-to-first-token, time-per-output-token and end-to-end
         latency.
      4. Write ``request_metrics.csv`` and ``summary.json`` into
         ``args.output_dir``.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        The summary dictionary that was written to ``summary.json``.

    Raises:
        RuntimeError: If ``transformers`` or ``vllm`` cannot be imported.
        ValueError: If no rows are selected, or a row fails prompt synthesis.
    """
    try:
        from transformers import AutoTokenizer
        from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        from vllm.inputs import TokensPrompt
    except Exception as exc:  # pragma: no cover - exercised on GPU host.
        raise RuntimeError(f"failed to import vLLM runtime dependencies: {exc}") from exc

    sidecar_path = args.fixture_dir / "sidecar.jsonl"
    rows = load_jsonl(sidecar_path)
    if args.limit is not None:
        rows = rows[: args.limit]
    if not rows:
        raise ValueError("no rows selected")

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    special_ids = {int(value) for value in tokenizer.all_special_ids}
    vocab_size = len(tokenizer)
    # All prompts are synthesized up front so that no generation work happens
    # inside the timed replay window.
    synthetic_prompts = {
        int(row["request_id"]): make_prompt_token_ids(
            row,
            seed=args.seed,
            block_size=args.block_size,
            vocab_size=vocab_size,
            special_ids=special_ids,
        )
        for row in rows
    }
    prefix_reuse = estimate_prefix_reuse(rows)

    engine_args = AsyncEngineArgs(
        model=args.model,
        tokenizer=args.model,
        trust_remote_code=args.trust_remote_code,
        tensor_parallel_size=args.tensor_parallel_size,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        block_size=args.block_size,
        enable_prefix_caching=args.enable_prefix_caching,
        enable_chunked_prefill=args.enable_chunked_prefill,
        max_num_seqs=args.max_num_seqs,
        max_num_batched_tokens=args.max_num_batched_tokens,
        gpu_memory_utilization=args.gpu_memory_utilization,
        enforce_eager=args.enforce_eager,
        disable_log_stats=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    output_rows: list[dict[str, Any]] = []
    # Use the earliest timestamp as the baseline rather than the first row's:
    # identical for time-ordered fixtures, and it keeps every offset
    # non-negative (preserving inter-arrival gaps) if rows are out of order.
    first_timestamp = min(float(row["timestamp"]) for row in rows)
    replay_start = time.perf_counter()

    async def run_one(row: dict[str, Any]) -> None:
        """Wait for the row's arrival time, run it, and append its metrics."""
        request_id = int(row["request_id"])
        scheduled_arrival_s = (float(row["timestamp"]) - first_timestamp) * args.time_scale
        await asyncio.sleep(max(0.0, replay_start + scheduled_arrival_s - time.perf_counter()))

        prompt_token_ids = synthetic_prompts[request_id]
        requested_output_tokens = int(row["output_length"])
        effective_output_tokens = requested_output_tokens
        if args.max_output_tokens is not None:
            effective_output_tokens = min(effective_output_tokens, args.max_output_tokens)

        # min_tokens == max_tokens together with ignore_eos is presumably meant
        # to pin the generation length to effective_output_tokens; the actual
        # count is still recorded below as generated_output_tokens.
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=effective_output_tokens,
            min_tokens=effective_output_tokens,
            ignore_eos=True,
            detokenize=False,
            seed=args.seed + request_id,
        )
        arrival_wall = time.perf_counter()
        first_token_wall: float | None = None
        last_output_tokens = 0
        final_output: Any = None
        generator = engine.generate(
            TokensPrompt(prompt_token_ids=prompt_token_ids),
            sampling_params,
            request_id=str(request_id),
        )
        async for output in generator:
            final_output = output
            if output.outputs:
                token_count = len(output.outputs[0].token_ids)
                if token_count > 0 and first_token_wall is None:
                    first_token_wall = time.perf_counter()
                last_output_tokens = token_count
        done_wall = time.perf_counter()

        finish_reason = ""
        if final_output is not None and final_output.outputs:
            finish_reason = str(final_output.outputs[0].finish_reason)
        ttft_s = None if first_token_wall is None else first_token_wall - arrival_wall
        e2e_s = done_wall - arrival_wall
        tpot_s = None
        # TPOT averages over the tokens after the first, so it needs >= 2 tokens.
        if first_token_wall is not None and last_output_tokens > 1:
            tpot_s = (done_wall - first_token_wall) / (last_output_tokens - 1)
        reuse = prefix_reuse[request_id]
        output_rows.append(
            {
                "request_id": request_id,
                "scheduled_arrival_s": scheduled_arrival_s,
                "arrival_delay_s": arrival_wall - replay_start - scheduled_arrival_s,
                "input_length": int(row["input_length"]),
                "requested_output_length": requested_output_tokens,
                "effective_output_length": effective_output_tokens,
                "generated_output_tokens": last_output_tokens,
                "ttft_s": ttft_s,
                "tpot_s": tpot_s,
                "e2e_s": e2e_s,
                "finish_reason": finish_reason,
                "prefix_query_blocks_est": reuse["query_blocks"],
                "prefix_hit_blocks_est": reuse["hit_blocks"],
                "prefix_query_tokens_est": reuse["query_tokens"],
                "prefix_hit_tokens_est": reuse["hit_tokens"],
                "prefix_block_hit_ratio_est": reuse["block_hit_ratio"],
                "prefix_token_hit_ratio_est": reuse["token_hit_ratio"],
            }
        )

    try:
        await asyncio.gather(*(run_one(row) for row in rows))
        # Stop the clock before tearing the engine down so that shutdown time
        # does not inflate wall_time_s or deflate the throughput figures.
        replay_end = time.perf_counter()
    finally:
        engine.shutdown()

    output_rows.sort(key=lambda item: int(item["request_id"]))
    args.output_dir.mkdir(parents=True, exist_ok=True)
    request_metrics_path = args.output_dir / "request_metrics.csv"
    fieldnames = list(output_rows[0].keys())
    with request_metrics_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(output_rows)

    ttft_values = [float(row["ttft_s"]) for row in output_rows if row["ttft_s"] is not None]
    tpot_values = [float(row["tpot_s"]) for row in output_rows if row["tpot_s"] is not None]
    e2e_values = [float(row["e2e_s"]) for row in output_rows]
    generated_tokens = sum(int(row["generated_output_tokens"]) for row in output_rows)
    prompt_tokens = sum(int(row["input_length"]) for row in output_rows)
    wall_s = replay_end - replay_start
    summary = {
        "status": "pass",
        "fixture_dir": str(args.fixture_dir),
        "model": args.model,
        "tensor_parallel_size": args.tensor_parallel_size,
        "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES", ""),
        "rows": len(output_rows),
        "block_size": args.block_size,
        "max_model_len": args.max_model_len,
        "max_num_seqs": args.max_num_seqs,
        "max_num_batched_tokens": args.max_num_batched_tokens,
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "enable_prefix_caching": args.enable_prefix_caching,
        "enable_chunked_prefill": args.enable_chunked_prefill,
        "time_scale": args.time_scale,
        "max_output_tokens": args.max_output_tokens,
        "synthetic_replay": {
            "semantics": (
                "Each trace block hash is deterministically mapped to a stable "
                "block of prompt token ids; equal hashes reuse equal token blocks. "
                "This preserves arrival, length, and block-prefix sharing patterns, "
                "but it is not original text/token recovery."
            ),
            "seed": args.seed,
            "vocab_size": vocab_size,
            "special_token_ids_excluded": sorted(special_ids),
        },
        "wall_time_s": wall_s,
        "requests_per_second": len(output_rows) / wall_s if wall_s else 0.0,
        "prompt_tokens_per_second": prompt_tokens / wall_s if wall_s else 0.0,
        "generated_tokens_per_second": generated_tokens / wall_s if wall_s else 0.0,
        "total_prompt_tokens": prompt_tokens,
        "total_generated_tokens": generated_tokens,
        "ttft_s": {
            "mean": statistics.fmean(ttft_values) if ttft_values else None,
            "p50": percentile(ttft_values, 0.50),
            "p95": percentile(ttft_values, 0.95),
        },
        "tpot_s": {
            "mean": statistics.fmean(tpot_values) if tpot_values else None,
            "p50": percentile(tpot_values, 0.50),
            "p95": percentile(tpot_values, 0.95),
        },
        "e2e_s": {
            "mean": statistics.fmean(e2e_values) if e2e_values else None,
            "p50": percentile(e2e_values, 0.50),
            "p95": percentile(e2e_values, 0.95),
        },
        "estimated_prefix_reuse": {
            "query_blocks": sum(int(row["prefix_query_blocks_est"]) for row in output_rows),
            "hit_blocks": sum(int(row["prefix_hit_blocks_est"]) for row in output_rows),
            "query_tokens": sum(int(row["prefix_query_tokens_est"]) for row in output_rows),
            "hit_tokens": sum(int(row["prefix_hit_tokens_est"]) for row in output_rows),
        },
        "request_metrics_csv": str(request_metrics_path),
    }
    # Aggregate ratios are derived from the summed counts, not averaged per row.
    reuse = summary["estimated_prefix_reuse"]
    summary["estimated_prefix_reuse"]["block_hit_ratio"] = (
        reuse["hit_blocks"] / reuse["query_blocks"] if reuse["query_blocks"] else 0.0
    )
    summary["estimated_prefix_reuse"]["token_hit_ratio"] = (
        reuse["hit_tokens"] / reuse["query_tokens"] if reuse["query_tokens"] else 0.0
    )
    with (args.output_dir / "summary.json").open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2, sort_keys=True)
        handle.write("\n")
    return summary
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point.

    Parses the arguments and runs the replay. On success the summary is
    printed to stdout as JSON. If the replay raises, a minimal failure record
    is written to ``<output_dir>/summary.json`` and the error is reported on
    stderr.

    Returns:
        ``0`` on success, ``1`` when the replay raised an exception.
    """
    args = parse_args()
    try:
        summary = asyncio.run(run_replay(args))
    except Exception as exc:
        # Top-level boundary: persist the failure so downstream tooling that
        # reads summary.json sees a "fail" status instead of a missing file.
        args.output_dir.mkdir(parents=True, exist_ok=True)
        failure_text = json.dumps({"status": "fail", "error": str(exc)}, indent=2)
        summary_path = args.output_dir / "summary.json"
        summary_path.write_text(failure_text + "\n", encoding="utf-8")
        print(f"vllm_synthetic_replay.py: error: {exc}", file=sys.stderr)
        return 1
    else:
        print(json.dumps(summary, indent=2, sort_keys=True))
        return 0
|
|
|
|
|
|
# Run only when executed as a script; main()'s return value is passed to
# SystemExit and so becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|