Add ReplayServe Frontier vLLM alignment report
This commit is contained in:
405
tools/vllm_synthetic_replay.py
Normal file
405
tools/vllm_synthetic_replay.py
Normal file
@@ -0,0 +1,405 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Replay a ReplayServe fixture on vLLM with synthetic prompt token blocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import csv
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def positive_int(value: str) -> int:
|
||||
parsed = int(value)
|
||||
if parsed <= 0:
|
||||
raise argparse.ArgumentTypeError("must be positive")
|
||||
return parsed
|
||||
|
||||
|
||||
def positive_float(value: str) -> float:
|
||||
parsed = float(value)
|
||||
if parsed <= 0:
|
||||
raise argparse.ArgumentTypeError("must be positive")
|
||||
return parsed
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Run an online vLLM smoke/replay using synthetic prompt_token_ids "
|
||||
"derived from ReplayServe block hashes."
|
||||
)
|
||||
)
|
||||
parser.add_argument("--fixture-dir", required=True, type=Path)
|
||||
parser.add_argument("--model", required=True, type=str)
|
||||
parser.add_argument("--output-dir", required=True, type=Path)
|
||||
parser.add_argument("--tensor-parallel-size", type=positive_int, default=1)
|
||||
parser.add_argument("--limit", type=positive_int)
|
||||
parser.add_argument("--block-size", type=positive_int, default=16)
|
||||
parser.add_argument("--max-model-len", type=positive_int, default=32768)
|
||||
parser.add_argument("--max-num-seqs", type=positive_int, default=128)
|
||||
parser.add_argument("--max-num-batched-tokens", type=positive_int, default=32768)
|
||||
parser.add_argument("--gpu-memory-utilization", type=positive_float, default=0.9)
|
||||
parser.add_argument("--time-scale", type=positive_float, default=1.0)
|
||||
parser.add_argument(
|
||||
"--max-output-tokens",
|
||||
type=positive_int,
|
||||
help="Cap each row's output_length for smoke tests.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--dtype", default="auto")
|
||||
parser.add_argument("--enforce-eager", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action=argparse.BooleanOptionalAction, default=True)
|
||||
parser.add_argument("--enable-prefix-caching", action=argparse.BooleanOptionalAction, default=True)
|
||||
parser.add_argument("--enable-chunked-prefill", action=argparse.BooleanOptionalAction, default=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_jsonl(path: Path) -> list[dict[str, Any]]:
|
||||
rows: list[dict[str, Any]] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line_number, line in enumerate(handle, start=1):
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
row = json.loads(stripped)
|
||||
if not isinstance(row, dict):
|
||||
raise ValueError(f"{path}: line {line_number}: expected object")
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
ordered = sorted(values)
|
||||
index = min(len(ordered) - 1, max(0, int((len(ordered) - 1) * pct)))
|
||||
return ordered[index]
|
||||
|
||||
|
||||
def block_seed(hash_id: int, seed: int) -> int:
|
||||
digest = hashlib.blake2b(
|
||||
f"{seed}:{hash_id}".encode("utf-8"), digest_size=8
|
||||
).digest()
|
||||
return int.from_bytes(digest, "big")
|
||||
|
||||
|
||||
def block_tokens(
|
||||
hash_id: int,
|
||||
*,
|
||||
seed: int,
|
||||
block_size: int,
|
||||
vocab_size: int,
|
||||
special_ids: set[int],
|
||||
) -> list[int]:
|
||||
rng = random.Random(block_seed(hash_id, seed))
|
||||
low = 1000
|
||||
high = max(low + 1, vocab_size - 1000)
|
||||
tokens: list[int] = []
|
||||
while len(tokens) < block_size:
|
||||
token_id = rng.randrange(low, high)
|
||||
if token_id not in special_ids:
|
||||
tokens.append(token_id)
|
||||
return tokens
|
||||
|
||||
|
||||
def make_prompt_token_ids(
|
||||
row: dict[str, Any],
|
||||
*,
|
||||
seed: int,
|
||||
block_size: int,
|
||||
vocab_size: int,
|
||||
special_ids: set[int],
|
||||
) -> list[int]:
|
||||
hash_ids = [int(value) for value in row["hash_ids"]]
|
||||
counts = [int(value) for value in row["block_token_counts"]]
|
||||
if len(hash_ids) != len(counts):
|
||||
raise ValueError(f"request {row.get('request_id')}: hash/count length mismatch")
|
||||
|
||||
token_ids: list[int] = []
|
||||
for hash_id, count in zip(hash_ids, counts):
|
||||
token_ids.extend(
|
||||
block_tokens(
|
||||
hash_id,
|
||||
seed=seed,
|
||||
block_size=block_size,
|
||||
vocab_size=vocab_size,
|
||||
special_ids=special_ids,
|
||||
)[:count]
|
||||
)
|
||||
expected = int(row["input_length"])
|
||||
if len(token_ids) != expected:
|
||||
raise ValueError(
|
||||
f"request {row.get('request_id')}: synthetic prompt length "
|
||||
f"{len(token_ids)} != input_length {expected}"
|
||||
)
|
||||
return token_ids
|
||||
|
||||
|
||||
def estimate_prefix_reuse(rows: list[dict[str, Any]]) -> dict[int, dict[str, int | float]]:
|
||||
trie: dict[int, dict[Any, Any]] = {}
|
||||
estimates: dict[int, dict[str, int | float]] = {}
|
||||
for row in rows:
|
||||
request_id = int(row["request_id"])
|
||||
hash_ids = [int(value) for value in row["hash_ids"]]
|
||||
counts = [int(value) for value in row["block_token_counts"]]
|
||||
|
||||
node = trie
|
||||
hit_blocks = 0
|
||||
for hash_id in hash_ids:
|
||||
if hash_id not in node:
|
||||
break
|
||||
hit_blocks += 1
|
||||
node = node[hash_id]
|
||||
|
||||
node = trie
|
||||
for hash_id in hash_ids:
|
||||
node = node.setdefault(hash_id, {})
|
||||
|
||||
query_tokens = int(row["input_length"])
|
||||
hit_tokens = sum(counts[:hit_blocks])
|
||||
estimates[request_id] = {
|
||||
"query_blocks": len(hash_ids),
|
||||
"hit_blocks": hit_blocks,
|
||||
"query_tokens": query_tokens,
|
||||
"hit_tokens": hit_tokens,
|
||||
"block_hit_ratio": hit_blocks / len(hash_ids) if hash_ids else 0.0,
|
||||
"token_hit_ratio": hit_tokens / query_tokens if query_tokens else 0.0,
|
||||
}
|
||||
return estimates
|
||||
|
||||
|
||||
async def run_replay(args: argparse.Namespace) -> dict[str, Any]:
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
||||
from vllm.inputs import TokensPrompt
|
||||
except Exception as exc: # pragma: no cover - exercised on GPU host.
|
||||
raise RuntimeError(f"failed to import vLLM runtime dependencies: {exc}") from exc
|
||||
|
||||
sidecar_path = args.fixture_dir / "sidecar.jsonl"
|
||||
rows = load_jsonl(sidecar_path)
|
||||
if args.limit is not None:
|
||||
rows = rows[: args.limit]
|
||||
if not rows:
|
||||
raise ValueError("no rows selected")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
|
||||
special_ids = {int(value) for value in tokenizer.all_special_ids}
|
||||
vocab_size = len(tokenizer)
|
||||
synthetic_prompts = {
|
||||
int(row["request_id"]): make_prompt_token_ids(
|
||||
row,
|
||||
seed=args.seed,
|
||||
block_size=args.block_size,
|
||||
vocab_size=vocab_size,
|
||||
special_ids=special_ids,
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
prefix_reuse = estimate_prefix_reuse(rows)
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=args.model,
|
||||
tokenizer=args.model,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
dtype=args.dtype,
|
||||
max_model_len=args.max_model_len,
|
||||
block_size=args.block_size,
|
||||
enable_prefix_caching=args.enable_prefix_caching,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
max_num_batched_tokens=args.max_num_batched_tokens,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
enforce_eager=args.enforce_eager,
|
||||
disable_log_stats=True,
|
||||
)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
output_rows: list[dict[str, Any]] = []
|
||||
first_timestamp = float(rows[0]["timestamp"])
|
||||
replay_start = time.perf_counter()
|
||||
|
||||
async def run_one(row: dict[str, Any]) -> None:
|
||||
request_id = int(row["request_id"])
|
||||
scheduled_arrival_s = (float(row["timestamp"]) - first_timestamp) * args.time_scale
|
||||
await asyncio.sleep(max(0.0, replay_start + scheduled_arrival_s - time.perf_counter()))
|
||||
|
||||
prompt_token_ids = synthetic_prompts[request_id]
|
||||
requested_output_tokens = int(row["output_length"])
|
||||
effective_output_tokens = requested_output_tokens
|
||||
if args.max_output_tokens is not None:
|
||||
effective_output_tokens = min(effective_output_tokens, args.max_output_tokens)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=effective_output_tokens,
|
||||
min_tokens=effective_output_tokens,
|
||||
ignore_eos=True,
|
||||
detokenize=False,
|
||||
seed=args.seed + request_id,
|
||||
)
|
||||
arrival_wall = time.perf_counter()
|
||||
first_token_wall: float | None = None
|
||||
last_output_tokens = 0
|
||||
final_output: Any = None
|
||||
generator = engine.generate(
|
||||
TokensPrompt(prompt_token_ids=prompt_token_ids),
|
||||
sampling_params,
|
||||
request_id=str(request_id),
|
||||
)
|
||||
async for output in generator:
|
||||
final_output = output
|
||||
if output.outputs:
|
||||
token_count = len(output.outputs[0].token_ids)
|
||||
if token_count > 0 and first_token_wall is None:
|
||||
first_token_wall = time.perf_counter()
|
||||
last_output_tokens = token_count
|
||||
done_wall = time.perf_counter()
|
||||
|
||||
finish_reason = ""
|
||||
if final_output is not None and final_output.outputs:
|
||||
finish_reason = str(final_output.outputs[0].finish_reason)
|
||||
ttft_s = None if first_token_wall is None else first_token_wall - arrival_wall
|
||||
e2e_s = done_wall - arrival_wall
|
||||
tpot_s = None
|
||||
if first_token_wall is not None and last_output_tokens > 1:
|
||||
tpot_s = (done_wall - first_token_wall) / (last_output_tokens - 1)
|
||||
reuse = prefix_reuse[request_id]
|
||||
output_rows.append(
|
||||
{
|
||||
"request_id": request_id,
|
||||
"scheduled_arrival_s": scheduled_arrival_s,
|
||||
"arrival_delay_s": arrival_wall - replay_start - scheduled_arrival_s,
|
||||
"input_length": int(row["input_length"]),
|
||||
"requested_output_length": requested_output_tokens,
|
||||
"effective_output_length": effective_output_tokens,
|
||||
"generated_output_tokens": last_output_tokens,
|
||||
"ttft_s": ttft_s,
|
||||
"tpot_s": tpot_s,
|
||||
"e2e_s": e2e_s,
|
||||
"finish_reason": finish_reason,
|
||||
"prefix_query_blocks_est": reuse["query_blocks"],
|
||||
"prefix_hit_blocks_est": reuse["hit_blocks"],
|
||||
"prefix_query_tokens_est": reuse["query_tokens"],
|
||||
"prefix_hit_tokens_est": reuse["hit_tokens"],
|
||||
"prefix_block_hit_ratio_est": reuse["block_hit_ratio"],
|
||||
"prefix_token_hit_ratio_est": reuse["token_hit_ratio"],
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*(run_one(row) for row in rows))
|
||||
finally:
|
||||
engine.shutdown()
|
||||
|
||||
replay_end = time.perf_counter()
|
||||
output_rows.sort(key=lambda item: int(item["request_id"]))
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
request_metrics_path = args.output_dir / "request_metrics.csv"
|
||||
fieldnames = list(output_rows[0].keys())
|
||||
with request_metrics_path.open("w", encoding="utf-8", newline="") as handle:
|
||||
writer = csv.DictWriter(handle, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(output_rows)
|
||||
|
||||
ttft_values = [float(row["ttft_s"]) for row in output_rows if row["ttft_s"] is not None]
|
||||
tpot_values = [float(row["tpot_s"]) for row in output_rows if row["tpot_s"] is not None]
|
||||
e2e_values = [float(row["e2e_s"]) for row in output_rows]
|
||||
generated_tokens = sum(int(row["generated_output_tokens"]) for row in output_rows)
|
||||
prompt_tokens = sum(int(row["input_length"]) for row in output_rows)
|
||||
wall_s = replay_end - replay_start
|
||||
summary = {
|
||||
"status": "pass",
|
||||
"fixture_dir": str(args.fixture_dir),
|
||||
"model": args.model,
|
||||
"tensor_parallel_size": args.tensor_parallel_size,
|
||||
"cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES", ""),
|
||||
"rows": len(output_rows),
|
||||
"block_size": args.block_size,
|
||||
"max_model_len": args.max_model_len,
|
||||
"max_num_seqs": args.max_num_seqs,
|
||||
"max_num_batched_tokens": args.max_num_batched_tokens,
|
||||
"gpu_memory_utilization": args.gpu_memory_utilization,
|
||||
"enable_prefix_caching": args.enable_prefix_caching,
|
||||
"enable_chunked_prefill": args.enable_chunked_prefill,
|
||||
"time_scale": args.time_scale,
|
||||
"max_output_tokens": args.max_output_tokens,
|
||||
"synthetic_replay": {
|
||||
"semantics": (
|
||||
"Each trace block hash is deterministically mapped to a stable "
|
||||
"block of prompt token ids; equal hashes reuse equal token blocks. "
|
||||
"This preserves arrival, length, and block-prefix sharing patterns, "
|
||||
"but it is not original text/token recovery."
|
||||
),
|
||||
"seed": args.seed,
|
||||
"vocab_size": vocab_size,
|
||||
"special_token_ids_excluded": sorted(special_ids),
|
||||
},
|
||||
"wall_time_s": wall_s,
|
||||
"requests_per_second": len(output_rows) / wall_s if wall_s else 0.0,
|
||||
"prompt_tokens_per_second": prompt_tokens / wall_s if wall_s else 0.0,
|
||||
"generated_tokens_per_second": generated_tokens / wall_s if wall_s else 0.0,
|
||||
"total_prompt_tokens": prompt_tokens,
|
||||
"total_generated_tokens": generated_tokens,
|
||||
"ttft_s": {
|
||||
"mean": statistics.fmean(ttft_values) if ttft_values else None,
|
||||
"p50": percentile(ttft_values, 0.50),
|
||||
"p95": percentile(ttft_values, 0.95),
|
||||
},
|
||||
"tpot_s": {
|
||||
"mean": statistics.fmean(tpot_values) if tpot_values else None,
|
||||
"p50": percentile(tpot_values, 0.50),
|
||||
"p95": percentile(tpot_values, 0.95),
|
||||
},
|
||||
"e2e_s": {
|
||||
"mean": statistics.fmean(e2e_values) if e2e_values else None,
|
||||
"p50": percentile(e2e_values, 0.50),
|
||||
"p95": percentile(e2e_values, 0.95),
|
||||
},
|
||||
"estimated_prefix_reuse": {
|
||||
"query_blocks": sum(int(row["prefix_query_blocks_est"]) for row in output_rows),
|
||||
"hit_blocks": sum(int(row["prefix_hit_blocks_est"]) for row in output_rows),
|
||||
"query_tokens": sum(int(row["prefix_query_tokens_est"]) for row in output_rows),
|
||||
"hit_tokens": sum(int(row["prefix_hit_tokens_est"]) for row in output_rows),
|
||||
},
|
||||
"request_metrics_csv": str(request_metrics_path),
|
||||
}
|
||||
reuse = summary["estimated_prefix_reuse"]
|
||||
summary["estimated_prefix_reuse"]["block_hit_ratio"] = (
|
||||
reuse["hit_blocks"] / reuse["query_blocks"] if reuse["query_blocks"] else 0.0
|
||||
)
|
||||
summary["estimated_prefix_reuse"]["token_hit_ratio"] = (
|
||||
reuse["hit_tokens"] / reuse["query_tokens"] if reuse["query_tokens"] else 0.0
|
||||
)
|
||||
with (args.output_dir / "summary.json").open("w", encoding="utf-8") as handle:
|
||||
json.dump(summary, handle, indent=2, sort_keys=True)
|
||||
handle.write("\n")
|
||||
return summary
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
try:
|
||||
summary = asyncio.run(run_replay(args))
|
||||
except Exception as exc:
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with (args.output_dir / "summary.json").open("w", encoding="utf-8") as handle:
|
||||
json.dump({"status": "fail", "error": str(exc)}, handle, indent=2)
|
||||
handle.write("\n")
|
||||
print(f"vllm_synthetic_replay.py: error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
print(json.dumps(summary, indent=2, sort_keys=True))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user