#!/usr/bin/env python3
"""Validate ReplayServe fixture directories."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import math
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
def positive_int(value: str) -> int:
    """Convert a command-line string to an integer strictly greater than zero.

    Intended for use as an argparse ``type=`` callable.

    Args:
        value: Raw argument text.

    Returns:
        The parsed integer.

    Raises:
        ValueError: If ``int(value)`` cannot parse the text.
        argparse.ArgumentTypeError: If the parsed integer is zero or negative.
    """
    number = int(value)
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be positive")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the process command line.

    Accepts one or more fixture directories (as ``Path`` objects) plus the
    optional ``--max-tokens`` (default 32768) and ``--block-size`` (default 16)
    flags, both of which must be positive integers.

    Returns:
        The namespace with ``fixture_dirs``, ``max_tokens`` and ``block_size``.
    """
    cli = argparse.ArgumentParser(description="Validate ReplayServe fixtures.")
    cli.add_argument("fixture_dirs", nargs="+", type=Path)
    # Both numeric flags share the same converter; only name and default vary.
    for flag, default in (("--max-tokens", 32768), ("--block-size", 16)):
        cli.add_argument(flag, type=positive_int, default=default)
    return cli.parse_args()
|
|
|
|
|
|
def parse_block_hash_ids(value: str) -> list[int]:
    """Decode a ``|``-separated list of integers.

    Surrounding whitespace is ignored and empty segments (including a fully
    blank input) contribute nothing to the result.

    Args:
        value: Text such as ``"3|7|9"``.

    Returns:
        The integers in their original order; ``[]`` for blank input.

    Raises:
        ValueError: If a non-empty segment is not a valid integer.
    """
    segments = value.strip().split("|")
    # filter(None, ...) drops the empty strings produced by blank input or "||".
    return list(map(int, filter(None, segments)))
|
|
|
|
|
|
def expected_block_counts(input_length: int, block_size: int) -> list[int]:
    """Return how many tokens fall into each block of an input.

    Every block holds ``block_size`` tokens except possibly the last one,
    which holds the remainder when ``input_length`` is not a multiple of
    ``block_size``.

    Args:
        input_length: Number of input tokens.
        block_size: Number of tokens per full block.

    Returns:
        One count per block, e.g. ``[16, 16, 8]`` for 40 tokens in blocks of
        16; ``[]`` when no block is needed.

    Raises:
        ZeroDivisionError: If ``block_size`` is zero.
    """
    # Exact integer ceiling division. math.ceil(a / b) rounds through a float
    # and miscounts blocks once the operands exceed float precision.
    hash_count = -(-input_length // block_size)
    if hash_count == 0:
        return []
    # A zero remainder means the final block is completely full.
    last_count = input_length % block_size or block_size
    return [block_size] * (hash_count - 1) + [last_count]
|
|
|
|
|
|
def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file whose every non-blank line is a JSON object.

    Blank lines are skipped but still counted, so reported line numbers match
    the file as seen in an editor.

    Args:
        path: File to read (UTF-8).

    Returns:
        The decoded objects in file order.

    Raises:
        ValueError: If a line is not valid JSON or decodes to a non-object.
    """
    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as stream:
        for lineno, raw in enumerate(stream, start=1):
            text = raw.strip()
            if text:
                try:
                    decoded = json.loads(text)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"{path}: line {lineno}: invalid JSON") from exc
                if isinstance(decoded, dict):
                    records.append(decoded)
                else:
                    raise ValueError(
                        f"{path}: line {lineno}: JSON value must be object"
                    )
    return records
|
|
|
|
|
|
def load_csv(path: Path) -> list[dict[str, str]]:
    """Read a CSV file and return its rows keyed by header name.

    Args:
        path: CSV file to read (UTF-8, first line is the header).

    Returns:
        One dict per data row, in file order.

    Raises:
        ValueError: If the header lacks a required column, or a data row is
            too short to supply a value for one.
    """
    required = {
        "arrived_at",
        "num_prefill_tokens",
        "num_decode_tokens",
        "session_id",
        "block_hash_ids",
    }
    with path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        missing = required - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"{path}: missing CSV columns: {sorted(missing)}")
        rows: list[dict[str, str]] = []
        for row in reader:
            # DictReader pads short rows with None. Reject that here, where the
            # line number is known, instead of letting a later int()/strip()
            # call fail without context.
            absent = sorted(name for name in required if row[name] is None)
            if absent:
                raise ValueError(
                    f"{path}: line {reader.line_num}: missing values for "
                    f"CSV columns: {absent}"
                )
            rows.append(row)
        return rows
|
|
|
|
|
|
def require_paths(fixture_dir: Path) -> tuple[Path, Path, Path, Path]:
    """Locate the four files that make up a fixture directory.

    Args:
        fixture_dir: Directory expected to contain ``source.jsonl``,
            ``frontier.csv``, ``sidecar.jsonl`` and ``manifest.json``.

    Returns:
        The source, CSV, sidecar and manifest paths, in that order.

    Raises:
        ValueError: Naming the first of those files that does not exist.
    """
    names = ("source.jsonl", "frontier.csv", "sidecar.jsonl", "manifest.json")
    source_path, csv_path, sidecar_path, manifest_path = (
        fixture_dir / name for name in names
    )
    candidates = (source_path, csv_path, sidecar_path, manifest_path)
    absent = next((item for item in candidates if not item.exists()), None)
    if absent is not None:
        raise ValueError(f"{fixture_dir}: missing {absent.name}")
    return source_path, csv_path, sidecar_path, manifest_path
|
|
|
|
|
|
def _field(row: dict[str, Any], key: str, context: str, convert: Any = None) -> Any:
    """Return ``row[key]``, optionally converted, failing with a located message.

    A missing key, or a value that *convert* rejects, is re-raised as a
    ``ValueError`` prefixed with *context*, so the report names the fixture,
    row and file instead of surfacing a bare ``KeyError`` or ``TypeError``.
    """
    try:
        raw = row[key]
    except KeyError:
        raise ValueError(f"{context}: missing field {key!r}") from None
    if convert is None:
        return raw
    try:
        return convert(raw)
    except (AttributeError, OverflowError, TypeError, ValueError) as exc:
        raise ValueError(f"{context}: invalid {key} value {raw!r}") from exc


def validate_fixture(fixture_dir: Path, block_size: int, max_tokens: int) -> str:
    """Cross-check the four files of one fixture directory.

    The checks run in this order and stop at the first failure:

    1. All four files exist and parse; the manifest is a JSON object.
    2. ``source.jsonl``, ``frontier.csv`` and ``sidecar.jsonl`` have the same
       number of rows, and the manifest's ``row_count``, ``block_size`` and
       ``max_tokens`` match the CSV row count and the expected values.
    3. For each row: prefill plus decode tokens fit in ``max_tokens``,
       ``arrived_at`` never decreases, the number of block hash ids matches
       the input length, the CSV row agrees with the source row, and the
       sidecar row agrees with both.
    4. The manifest's ``max_total_tokens``, ``partial_final_block_rows``,
       ``overflow_count`` and ``timestamp_monotonic`` match what was observed.

    Args:
        fixture_dir: Directory holding the fixture files.
        block_size: Expected tokens per hash block.
        max_tokens: Expected upper bound on prefill plus decode tokens per row.

    Returns:
        A one-line summary with the row count, the largest per-row token
        total and the number of rows whose last block is partial.

    Raises:
        ValueError: Describing the first inconsistency or malformed value.
    """
    source_path, csv_path, sidecar_path, manifest_path = require_paths(fixture_dir)
    source_rows = load_jsonl(source_path)
    csv_rows = load_csv(csv_path)
    sidecar_rows = load_jsonl(sidecar_path)
    with manifest_path.open("r", encoding="utf-8") as handle:
        try:
            manifest = json.load(handle)
        except json.JSONDecodeError as exc:
            raise ValueError(f"{manifest_path}: invalid JSON") from exc
    # Every check below uses manifest.get(); reject arrays and scalars up front
    # with a message that names the file.
    if not isinstance(manifest, dict):
        raise ValueError(f"{manifest_path}: JSON value must be object")

    row_count = len(csv_rows)
    if len(source_rows) != row_count or len(sidecar_rows) != row_count:
        raise ValueError(
            f"{fixture_dir}: row count mismatch source={len(source_rows)} "
            f"csv={row_count} sidecar={len(sidecar_rows)}"
        )
    if manifest.get("row_count") != row_count:
        raise ValueError(
            f"{fixture_dir}: manifest row_count={manifest.get('row_count')} "
            f"does not match csv rows={row_count}"
        )
    if manifest.get("block_size") != block_size:
        raise ValueError(
            f"{fixture_dir}: manifest block_size={manifest.get('block_size')} "
            f"does not match expected {block_size}"
        )
    if manifest.get("max_tokens") != max_tokens:
        raise ValueError(
            f"{fixture_dir}: manifest max_tokens={manifest.get('max_tokens')} "
            f"does not match expected {max_tokens}"
        )

    # Loop-invariant, so built once rather than per row.
    required_sidecar_keys = {
        "request_id",
        "chat_id",
        "parent_chat_id",
        "turn",
        "type",
        "timestamp",
        "input_length",
        "output_length",
        "hash_ids",
        "block_token_counts",
    }

    previous_timestamp: float | None = None
    max_total_tokens = 0
    partial_final_block_rows = 0
    for index, (source, csv_row, sidecar) in enumerate(
        zip(source_rows, csv_rows, sidecar_rows)
    ):
        prefix = f"{fixture_dir}: row {index}"
        csv_ctx = f"{prefix}: csv"
        source_ctx = f"{prefix}: source"
        sidecar_ctx = f"{prefix}: sidecar"

        # Token budget.
        input_length = _field(csv_row, "num_prefill_tokens", csv_ctx, int)
        output_length = _field(csv_row, "num_decode_tokens", csv_ctx, int)
        total_tokens = input_length + output_length
        if total_tokens > max_tokens:
            raise ValueError(
                f"{prefix}: total_tokens={total_tokens} exceeds max_tokens={max_tokens}"
            )
        max_total_tokens = max(max_total_tokens, total_tokens)

        # Arrival order.
        timestamp = _field(csv_row, "arrived_at", csv_ctx, float)
        if previous_timestamp is not None and timestamp < previous_timestamp:
            raise ValueError(f"{prefix}: timestamp is not monotonic")
        previous_timestamp = timestamp

        # Block hashes. Ceiling division in integer arithmetic avoids the
        # float rounding of math.ceil(a / b).
        hash_ids = _field(csv_row, "block_hash_ids", csv_ctx, parse_block_hash_ids)
        expected_hash_count = -(-input_length // block_size)
        if len(hash_ids) != expected_hash_count:
            raise ValueError(
                f"{prefix}: hash count {len(hash_ids)} != {expected_hash_count}"
            )
        counts = expected_block_counts(input_length, block_size)
        if sum(counts) != input_length:
            raise ValueError(f"{prefix}: expected block counts do not sum to input")
        partial_final_block_rows += int(input_length % block_size != 0)

        # CSV row versus source row.
        if _field(csv_row, "session_id", csv_ctx, int) != _field(
            source, "chat_id", source_ctx, int
        ):
            raise ValueError(f"{prefix}: session_id does not match source chat_id")
        if timestamp != _field(source, "timestamp", source_ctx, float):
            raise ValueError(f"{prefix}: arrived_at does not match source timestamp")
        if input_length != _field(source, "input_length", source_ctx, int):
            raise ValueError(f"{prefix}: num_prefill_tokens does not match source")
        if output_length != _field(source, "output_length", source_ctx, int):
            raise ValueError(f"{prefix}: num_decode_tokens does not match source")
        if hash_ids != _field(source, "hash_ids", source_ctx):
            raise ValueError(f"{prefix}: block_hash_ids do not match source hash_ids")

        # Sidecar row versus source and CSV rows.
        missing = required_sidecar_keys - set(sidecar)
        if missing:
            raise ValueError(f"{prefix}: missing sidecar keys {sorted(missing)}")
        if _field(sidecar, "request_id", sidecar_ctx, int) != index:
            raise ValueError(f"{prefix}: sidecar request_id mismatch")
        for key, convert in (
            ("chat_id", int),
            ("parent_chat_id", int),
            ("turn", int),
            ("type", None),
            ("timestamp", float),
        ):
            if _field(sidecar, key, sidecar_ctx, convert) != _field(
                source, key, source_ctx, convert
            ):
                raise ValueError(f"{prefix}: sidecar {key} mismatch")
        for key, expected in (
            ("input_length", input_length),
            ("output_length", output_length),
        ):
            if _field(sidecar, key, sidecar_ctx, int) != expected:
                raise ValueError(f"{prefix}: sidecar {key} mismatch")
        if sidecar["hash_ids"] != hash_ids:
            raise ValueError(f"{prefix}: sidecar hash_ids mismatch")
        if sidecar["block_token_counts"] != counts:
            raise ValueError(f"{prefix}: sidecar block_token_counts mismatch")

    if manifest.get("max_total_tokens") != max_total_tokens:
        raise ValueError(
            f"{fixture_dir}: manifest max_total_tokens="
            f"{manifest.get('max_total_tokens')} does not match {max_total_tokens}"
        )
    if manifest.get("partial_final_block_rows") != partial_final_block_rows:
        raise ValueError(
            f"{fixture_dir}: manifest partial_final_block_rows="
            f"{manifest.get('partial_final_block_rows')} does not match "
            f"{partial_final_block_rows}"
        )
    if manifest.get("overflow_count") != 0:
        raise ValueError(f"{fixture_dir}: manifest overflow_count is not zero")
    if manifest.get("timestamp_monotonic") is not True:
        raise ValueError(f"{fixture_dir}: manifest timestamp_monotonic is not true")

    return (
        f"{fixture_dir.name}: rows={row_count} max_total_tokens={max_total_tokens} "
        f"partial_final_block_rows={partial_final_block_rows}"
    )
|
|
|
|
|
|
def main() -> int:
    """Validate every fixture directory named on the command line.

    Prints one summary line per fixture to stdout. The first failure of any
    kind is reported on stderr and stops the run.

    Returns:
        ``0`` when every fixture validated, ``1`` after the first error.
    """
    options = parse_args()
    try:
        for directory in options.fixture_dirs:
            summary = validate_fixture(
                fixture_dir=directory,
                block_size=options.block_size,
                max_tokens=options.max_tokens,
            )
            print(summary)
    except Exception as failure:
        # Top-level boundary: turn any failure into a one-line diagnostic.
        print(f"validate_fixtures.py: error: {failure}", file=sys.stderr)
        return 1
    else:
        return 0
|
|
|
|
|
|
# Script entry point: the process exit status is main()'s return value.
if __name__ == "__main__":
    sys.exit(main())
|