Files
replaysim/tools/validate_fixtures.py

241 lines
9.4 KiB
Python
Executable File

#!/usr/bin/env python3
"""Validate ReplayServe fixture directories."""
from __future__ import annotations
import argparse
import csv
import json
import math
import sys
from pathlib import Path
from typing import Any
def positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers above zero.

    Non-integer text raises ValueError from ``int``; zero or negative values
    raise ``argparse.ArgumentTypeError("must be positive")``.
    """
    number = int(value)
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be positive")
def parse_args() -> argparse.Namespace:
    """Parse ``sys.argv`` into the validator's options.

    Positional ``fixture_dirs`` takes one or more paths. ``--max-tokens``
    (default 32768) and ``--block-size`` (default 16) must be positive
    integers.
    """
    cli = argparse.ArgumentParser(description="Validate ReplayServe fixtures.")
    cli.add_argument("fixture_dirs", nargs="+", type=Path)
    # Both numeric options share the same positive-integer converter.
    for flag, fallback in (("--max-tokens", 32768), ("--block-size", 16)):
        cli.add_argument(flag, type=positive_int, default=fallback)
    return cli.parse_args()
def parse_block_hash_ids(value: str) -> list[int]:
    """Parse a ``|``-separated field of integer block hash ids.

    Surrounding whitespace is ignored. Empty segments (from a blank field,
    ``"1||2"``, or a trailing ``|``) are dropped, so a blank field gives
    ``[]``. A non-integer segment raises ValueError from ``int``.
    """
    segments = value.strip().split("|")
    # filter(None, ...) discards the empty-string segments.
    return list(map(int, filter(None, segments)))
def expected_block_counts(input_length: int, block_size: int) -> list[int]:
    """Return the token count of each block when splitting a prompt.

    Every block holds ``block_size`` tokens except possibly the last, which
    holds the remainder. A zero-length input yields no blocks.

    Args:
        input_length: Number of prompt tokens; must be non-negative.
        block_size: Tokens per full block; must be positive.

    Returns:
        A list whose length is ceil(input_length / block_size) and whose sum
        is ``input_length``.

    Raises:
        ValueError: if ``block_size`` is not positive or ``input_length`` is
            negative.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if input_length < 0:
        raise ValueError(f"input_length must be non-negative, got {input_length}")
    # Integer divmod rather than math.ceil(a / b): float division is inexact
    # above 2**53 and can undercount blocks (so the counts no longer sum to
    # the input).
    full_blocks, remainder = divmod(input_length, block_size)
    counts = [block_size] * full_blocks
    if remainder:
        counts.append(remainder)
    return counts
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a UTF-8 JSON Lines file into a list of objects.

    Whitespace-only lines are skipped. Raises ValueError naming the file and
    1-based line number when a line is not valid JSON or decodes to something
    other than a JSON object.
    """
    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as stream:
        for lineno, raw in enumerate(stream, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}: line {lineno}: invalid JSON") from err
            if isinstance(record, dict):
                records.append(record)
                continue
            raise ValueError(f"{path}: line {lineno}: JSON value must be object")
    return records
def load_csv(path: Path) -> list[dict[str, str]]:
    """Read a UTF-8 CSV file into a list of header-keyed row dicts.

    The header must contain arrived_at, num_prefill_tokens,
    num_decode_tokens, session_id and block_hash_ids; extra columns are kept.
    Raises ValueError listing the missing required columns in sorted order.
    """
    needed = {
        "arrived_at",
        "num_prefill_tokens",
        "num_decode_tokens",
        "session_id",
        "block_hash_ids",
    }
    with path.open("r", encoding="utf-8", newline="") as stream:
        rows = csv.DictReader(stream)
        # fieldnames is None for an empty file, so every column is missing.
        header = rows.fieldnames or []
        absent = needed.difference(header)
        if absent:
            raise ValueError(f"{path}: missing CSV columns: {sorted(absent)}")
        return [*rows]
def require_paths(fixture_dir: Path) -> tuple[Path, Path, Path, Path]:
    """Locate the four fixture files inside ``fixture_dir``.

    Returns the paths of (source.jsonl, frontier.csv, sidecar.jsonl,
    manifest.json), in that order. Raises ValueError naming the first file,
    in that same order, that does not exist.
    """
    paths = (
        fixture_dir / "source.jsonl",
        fixture_dir / "frontier.csv",
        fixture_dir / "sidecar.jsonl",
        fixture_dir / "manifest.json",
    )
    first_missing = next((path for path in paths if not path.exists()), None)
    if first_missing is not None:
        raise ValueError(f"{fixture_dir}: missing {first_missing.name}")
    return paths
# Keys each source.jsonl row must provide. Every one is read during row
# validation, so a missing key is reported with row context instead of
# surfacing as a bare KeyError.
_REQUIRED_SOURCE_KEYS = frozenset(
    {
        "chat_id",
        "parent_chat_id",
        "turn",
        "type",
        "timestamp",
        "input_length",
        "output_length",
        "hash_ids",
    }
)

# Keys each sidecar.jsonl row must provide.
_REQUIRED_SIDECAR_KEYS = frozenset(
    {
        "request_id",
        "chat_id",
        "parent_chat_id",
        "turn",
        "type",
        "timestamp",
        "input_length",
        "output_length",
        "hash_ids",
        "block_token_counts",
    }
)


def _load_manifest(manifest_path: Path) -> dict[str, Any]:
    """Load manifest.json, requiring valid JSON with a top-level object."""
    with manifest_path.open("r", encoding="utf-8") as handle:
        try:
            manifest = json.load(handle)
        except json.JSONDecodeError as exc:
            raise ValueError(f"{manifest_path}: invalid JSON") from exc
    # A non-object manifest would otherwise fail later with an opaque
    # AttributeError on ``.get``.
    if not isinstance(manifest, dict):
        raise ValueError(f"{manifest_path}: JSON value must be object")
    return manifest


def _check_manifest_header(
    fixture_dir: Path,
    manifest: dict[str, Any],
    row_count: int,
    block_size: int,
    max_tokens: int,
) -> None:
    """Check the manifest fields that are known before scanning any rows."""
    if manifest.get("row_count") != row_count:
        raise ValueError(
            f"{fixture_dir}: manifest row_count={manifest.get('row_count')} "
            f"does not match csv rows={row_count}"
        )
    if manifest.get("block_size") != block_size:
        raise ValueError(
            f"{fixture_dir}: manifest block_size={manifest.get('block_size')} "
            f"does not match expected {block_size}"
        )
    if manifest.get("max_tokens") != max_tokens:
        raise ValueError(
            f"{fixture_dir}: manifest max_tokens={manifest.get('max_tokens')} "
            f"does not match expected {max_tokens}"
        )


def _validate_row(
    prefix: str,
    index: int,
    source: dict[str, Any],
    csv_row: dict[str, str],
    sidecar: dict[str, Any],
    previous_timestamp: float | None,
    block_size: int,
    max_tokens: int,
) -> tuple[float, int, bool]:
    """Cross-check one aligned source/CSV/sidecar row triple.

    Returns ``(timestamp, total_tokens, has_partial_final_block)`` so the
    caller can track monotonicity and the manifest summary fields.
    Raises ValueError (prefixed with ``prefix``) on the first mismatch.
    """
    input_length = int(csv_row["num_prefill_tokens"])
    output_length = int(csv_row["num_decode_tokens"])
    # Negative lengths would shrink total_tokens and let an oversized
    # request slip past the max_tokens check.
    if input_length < 0:
        raise ValueError(f"{prefix}: num_prefill_tokens={input_length} is negative")
    if output_length < 0:
        raise ValueError(f"{prefix}: num_decode_tokens={output_length} is negative")
    total_tokens = input_length + output_length
    if total_tokens > max_tokens:
        raise ValueError(
            f"{prefix}: total_tokens={total_tokens} exceeds max_tokens={max_tokens}"
        )

    timestamp = float(csv_row["arrived_at"])
    # NaN compares false against everything, so it would slip through the
    # monotonic check; inf would make every later row look out of order.
    if not math.isfinite(timestamp):
        raise ValueError(f"{prefix}: arrived_at={timestamp} is not finite")
    if previous_timestamp is not None and timestamp < previous_timestamp:
        raise ValueError(f"{prefix}: timestamp is not monotonic")

    hash_ids = parse_block_hash_ids(csv_row["block_hash_ids"])
    # Integer ceiling: exact for any int, unlike math.ceil on a float.
    expected_hash_count = -(-input_length // block_size)
    if len(hash_ids) != expected_hash_count:
        raise ValueError(
            f"{prefix}: hash count {len(hash_ids)} != {expected_hash_count}"
        )
    counts = expected_block_counts(input_length, block_size)
    if sum(counts) != input_length:
        raise ValueError(f"{prefix}: expected block counts do not sum to input")
    has_partial_final_block = input_length % block_size != 0

    missing_source = _REQUIRED_SOURCE_KEYS - set(source)
    if missing_source:
        raise ValueError(f"{prefix}: missing source keys {sorted(missing_source)}")
    if int(csv_row["session_id"]) != int(source["chat_id"]):
        raise ValueError(f"{prefix}: session_id does not match source chat_id")
    if timestamp != float(source["timestamp"]):
        raise ValueError(f"{prefix}: arrived_at does not match source timestamp")
    if input_length != int(source["input_length"]):
        raise ValueError(f"{prefix}: num_prefill_tokens does not match source")
    if output_length != int(source["output_length"]):
        raise ValueError(f"{prefix}: num_decode_tokens does not match source")
    if hash_ids != source["hash_ids"]:
        raise ValueError(f"{prefix}: block_hash_ids do not match source hash_ids")

    missing_sidecar = _REQUIRED_SIDECAR_KEYS - set(sidecar)
    if missing_sidecar:
        raise ValueError(f"{prefix}: missing sidecar keys {sorted(missing_sidecar)}")
    # request_id is expected to equal the 0-based row position.
    if int(sidecar["request_id"]) != index:
        raise ValueError(f"{prefix}: sidecar request_id mismatch")
    if int(sidecar["chat_id"]) != int(source["chat_id"]):
        raise ValueError(f"{prefix}: sidecar chat_id mismatch")
    if int(sidecar["parent_chat_id"]) != int(source["parent_chat_id"]):
        raise ValueError(f"{prefix}: sidecar parent_chat_id mismatch")
    if int(sidecar["turn"]) != int(source["turn"]):
        raise ValueError(f"{prefix}: sidecar turn mismatch")
    if sidecar["type"] != source["type"]:
        raise ValueError(f"{prefix}: sidecar type mismatch")
    if float(sidecar["timestamp"]) != float(source["timestamp"]):
        raise ValueError(f"{prefix}: sidecar timestamp mismatch")
    if int(sidecar["input_length"]) != input_length:
        raise ValueError(f"{prefix}: sidecar input_length mismatch")
    if int(sidecar["output_length"]) != output_length:
        raise ValueError(f"{prefix}: sidecar output_length mismatch")
    if sidecar["hash_ids"] != hash_ids:
        raise ValueError(f"{prefix}: sidecar hash_ids mismatch")
    if sidecar["block_token_counts"] != counts:
        raise ValueError(f"{prefix}: sidecar block_token_counts mismatch")

    return timestamp, total_tokens, has_partial_final_block


def _check_manifest_summary(
    fixture_dir: Path,
    manifest: dict[str, Any],
    max_total_tokens: int,
    partial_final_block_rows: int,
) -> None:
    """Check the manifest fields that summarize the full row scan."""
    if manifest.get("max_total_tokens") != max_total_tokens:
        raise ValueError(
            f"{fixture_dir}: manifest max_total_tokens="
            f"{manifest.get('max_total_tokens')} does not match {max_total_tokens}"
        )
    if manifest.get("partial_final_block_rows") != partial_final_block_rows:
        raise ValueError(
            f"{fixture_dir}: manifest partial_final_block_rows="
            f"{manifest.get('partial_final_block_rows')} does not match "
            f"{partial_final_block_rows}"
        )
    if manifest.get("overflow_count") != 0:
        raise ValueError(f"{fixture_dir}: manifest overflow_count is not zero")
    if manifest.get("timestamp_monotonic") is not True:
        raise ValueError(f"{fixture_dir}: manifest timestamp_monotonic is not true")


def validate_fixture(fixture_dir: Path, block_size: int, max_tokens: int) -> str:
    """Cross-validate one fixture directory and return a one-line summary.

    The directory must contain source.jsonl, frontier.csv, sidecar.jsonl and
    manifest.json. The checks are:

    * The three row files have equal row counts and agree field-by-field.
    * Arrival times are finite and non-decreasing.
    * Token counts are non-negative, with prefill + decode <= ``max_tokens``.
    * Each row's hash-id count equals ceil(prefill / ``block_size``).
    * The manifest's settings and summary statistics match what was scanned.

    Args:
        fixture_dir: Directory holding the fixture files.
        block_size: Tokens per block; must equal the manifest's block_size.
        max_tokens: Per-row prefill + decode limit; must equal the manifest's
            max_tokens.

    Returns:
        ``"<dir name>: rows=N max_total_tokens=M partial_final_block_rows=P"``.

    Raises:
        ValueError: on any validation failure. Malformed numeric fields
            surface as the ValueError/TypeError raised by ``int``/``float``.
    """
    source_path, csv_path, sidecar_path, manifest_path = require_paths(fixture_dir)
    source_rows = load_jsonl(source_path)
    csv_rows = load_csv(csv_path)
    sidecar_rows = load_jsonl(sidecar_path)
    manifest = _load_manifest(manifest_path)

    row_count = len(csv_rows)
    # Checked before zip() so that truncation can never hide extra rows.
    if len(source_rows) != row_count or len(sidecar_rows) != row_count:
        raise ValueError(
            f"{fixture_dir}: row count mismatch source={len(source_rows)} "
            f"csv={row_count} sidecar={len(sidecar_rows)}"
        )
    _check_manifest_header(fixture_dir, manifest, row_count, block_size, max_tokens)

    previous_timestamp: float | None = None
    max_total_tokens = 0
    partial_final_block_rows = 0
    for index, (source, csv_row, sidecar) in enumerate(
        zip(source_rows, csv_rows, sidecar_rows)
    ):
        timestamp, total_tokens, has_partial = _validate_row(
            prefix=f"{fixture_dir}: row {index}",
            index=index,
            source=source,
            csv_row=csv_row,
            sidecar=sidecar,
            previous_timestamp=previous_timestamp,
            block_size=block_size,
            max_tokens=max_tokens,
        )
        previous_timestamp = timestamp
        max_total_tokens = max(max_total_tokens, total_tokens)
        partial_final_block_rows += int(has_partial)

    _check_manifest_summary(
        fixture_dir, manifest, max_total_tokens, partial_final_block_rows
    )
    return (
        f"{fixture_dir.name}: rows={row_count} max_total_tokens={max_total_tokens} "
        f"partial_final_block_rows={partial_final_block_rows}"
    )
def main() -> int:
    """Validate every fixture directory named on the command line.

    Prints one summary line per fixture to stdout. On the first failure,
    prints the error to stderr and returns 1; otherwise returns 0.
    """
    args = parse_args()
    try:
        options = {"block_size": args.block_size, "max_tokens": args.max_tokens}
        # Lazy generator: each summary is printed before the next fixture is
        # validated, so output interleaves exactly as fixtures are checked.
        summaries = (
            validate_fixture(fixture_dir=fixture_dir, **options)
            for fixture_dir in args.fixture_dirs
        )
        for summary in summaries:
            print(summary)
    except Exception as exc:  # CLI boundary: report the message, not a traceback
        print(f"validate_fixtures.py: error: {exc}", file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s return value as the status.
    sys.exit(main())