Files
replaysim/tools/qwen_to_frontier.py

392 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
"""Convert Qwen JSONL traces to Frontier trace-replay CSV fixtures."""
from __future__ import annotations
import argparse
import csv
import json
import math
import os
import sys
from pathlib import Path
from typing import Any
# Column order of the Frontier trace-replay CSV; every key is produced by
# convert_row() for each converted input row.
CSV_FIELDS = [
    "arrived_at",  # source timestamp multiplied by --timestamp-scale
    "num_prefill_tokens",  # source input_length
    "num_decode_tokens",  # source output_length
    "session_id",  # source chat_id
    "block_hash_ids",  # source hash_ids joined with "|"
]
# Key set of each ReplayServe sidecar JSONL record (also recorded in the
# manifest so consumers can check the schema they were given).
SIDECAR_FIELDS = [
    "request_id",  # assigned by the caller of convert_row()
    "chat_id",
    "parent_chat_id",
    "turn",
    "type",  # copied as-is; None when the source row has no "type"
    "timestamp",  # scaled, same value as the CSV arrived_at column
    "input_length",
    "output_length",
    "hash_ids",
    "block_token_counts",  # tokens covered by each hash block
]
def positive_int(value: str) -> int:
    """Argparse ``type=`` callback: parse *value* as an int that is > 0.

    A non-numeric string lets ``int()``'s ``ValueError`` propagate, which
    argparse reports as an invalid value.
    """
    number = int(value)
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be positive")
def positive_float(value: str) -> float:
    """Argparse ``type=`` callback: parse *value* as a finite float that is > 0.

    ``float()`` accepts ``"nan"`` and ``"inf"``. Since ``nan <= 0`` is False,
    a plain sign check would let NaN through, and a NaN or infinite
    ``--timestamp-scale`` would poison every scaled timestamp. Non-finite
    values are therefore rejected explicitly.

    Raises:
        argparse.ArgumentTypeError: if the value is non-finite or not > 0.
        ValueError: if *value* is not a float literal (from ``float()``).
    """
    parsed = float(value)
    if not math.isfinite(parsed):
        raise argparse.ArgumentTypeError("must be a finite number")
    if parsed <= 0:
        raise argparse.ArgumentTypeError("must be positive")
    return parsed
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the converter's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as the zero-argument form always did;
            passing a list lets tests or wrappers drive the parser directly.

    Returns:
        The parsed namespace. ``--limit``, ``--max-tokens`` and
        ``--block-size`` are validated by ``positive_int``;
        ``--timestamp-scale`` by ``positive_float``.
    """
    parser = argparse.ArgumentParser(
        description="Convert Qwen JSONL to Frontier CSV plus ReplayServe sidecar."
    )
    parser.add_argument("--input", required=True, type=Path, help="Qwen JSONL path.")
    parser.add_argument(
        "--frontier-csv", required=True, type=Path, help="Output Frontier CSV path."
    )
    parser.add_argument(
        "--sidecar-jsonl",
        required=True,
        type=Path,
        help="Output ReplayServe sidecar JSONL path.",
    )
    parser.add_argument(
        "--source-jsonl",
        type=Path,
        help="Optional path for the original source JSONL slice.",
    )
    parser.add_argument(
        "--manifest-json", type=Path, help="Optional path for fixture manifest JSON."
    )
    parser.add_argument(
        "--fixture-name", help="Optional fixture name stored in the manifest."
    )
    parser.add_argument(
        "--limit", type=positive_int, help="Maximum number of rows to convert."
    )
    parser.add_argument(
        "--max-tokens",
        type=positive_int,
        default=32768,
        help=(
            "Token budget for input_length + output_length; rows above it are "
            "counted as overflows but never clipped (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--block-size",
        type=positive_int,
        default=16,
        help=(
            "Tokens per hash block; each row must carry exactly "
            "ceil(input_length / block_size) hash_ids (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--timestamp-scale",
        type=positive_float,
        default=1.0,
        help="Multiply each source timestamp before writing fixture files.",
    )
    parser.add_argument(
        "--fail-on-overflow",
        action="store_true",
        help="Hard fail if input_length + output_length exceeds --max-tokens.",
    )
    return parser.parse_args(argv)
def require_int(row: dict[str, Any], key: str, line_number: int) -> int:
    """Return ``row[key]`` when it holds an integer, else raise ``ValueError``.

    ``bool`` is a subclass of ``int`` in Python, so JSON ``true``/``false``
    would otherwise pass as 1/0; they are rejected explicitly.

    Raises:
        ValueError: if *key* is absent (chained from the ``KeyError``) or its
            value is not a non-bool int. Messages carry *line_number*.
    """
    try:
        candidate = row[key]
    except KeyError as missing:
        raise ValueError(f"line {line_number}: missing field {key!r}") from missing
    is_plain_int = isinstance(candidate, int) and not isinstance(candidate, bool)
    if not is_plain_int:
        raise ValueError(f"line {line_number}: field {key!r} must be an int")
    return candidate
def require_number(row: dict[str, Any], key: str, line_number: int) -> int | float:
try:
value = row[key]
except KeyError as exc:
raise ValueError(f"line {line_number}: missing field {key!r}") from exc
if isinstance(value, bool) or not isinstance(value, (int, float)):
raise ValueError(f"line {line_number}: field {key!r} must be numeric")
return value
def require_hash_ids(row: dict[str, Any], line_number: int) -> list[int]:
    """Return a fresh copy of ``row["hash_ids"]`` after checking it is a list of ints.

    Raises:
        ValueError: if the field is missing, is not a list, or contains a
            non-int element (bools included). The message names the first
            offending index.
    """
    try:
        candidates = row["hash_ids"]
    except KeyError as missing:
        raise ValueError(f"line {line_number}: missing field 'hash_ids'") from missing
    if not isinstance(candidates, list):
        raise ValueError(f"line {line_number}: field 'hash_ids' must be a list")
    for position, entry in enumerate(candidates):
        # bool subclasses int, so JSON true/false must be rejected explicitly.
        if isinstance(entry, bool) or not isinstance(entry, int):
            raise ValueError(
                f"line {line_number}: hash_ids[{position}] must be an int"
            )
    return list(candidates)
def block_token_counts(input_length: int, hash_count: int, block_size: int) -> list[int]:
    """Return how many prompt tokens each of the *hash_count* blocks covers.

    Every block but the last holds ``block_size`` tokens. The last holds
    ``input_length % block_size``, or a full ``block_size`` when the prompt
    divides evenly. No blocks yields an empty list.
    """
    if hash_count == 0:
        return []
    remainder = input_length % block_size
    tail = remainder if remainder else block_size
    counts = [block_size] * (hash_count - 1)
    counts.append(tail)
    return counts
def convert_row(
    row: dict[str, Any],
    request_id: int,
    line_number: int,
    block_size: int,
    max_tokens: int,
    fail_on_overflow: bool,
    timestamp_scale: float,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Validate one decoded Qwen trace row and map it onto the fixture schemas.

    Args:
        row: The decoded JSON value of one input line. Any value other than a
            JSON object is rejected with a line-numbered error instead of
            surfacing as an opaque ``TypeError`` on the first field lookup.
        request_id: Identifier stored verbatim in the sidecar row.
        line_number: 1-based input line, used only in error messages.
        block_size: Tokens per hash block; ``hash_ids`` must contain exactly
            ``ceil(input_length / block_size)`` entries.
        max_tokens: Budget checked against ``input_length + output_length``.
        fail_on_overflow: Raise instead of merely flagging rows over budget.
        timestamp_scale: Factor applied to the source ``timestamp``.

    Returns:
        ``(frontier_row, sidecar_row, stats)``:

        * a dict keyed by ``CSV_FIELDS``;
        * a dict keyed by ``SIDECAR_FIELDS``. ``type`` is copied as-is and is
          ``None`` when the source row lacks it;
        * per-row statistics (token counts, scaled timestamp,
          ``partial_final_block`` and ``overflow`` flags) for the caller to
          aggregate.

    Raises:
        ValueError: if the row is not an object, a required field is missing
            or mistyped, a length is not positive, the hash count does not
            match, or ``fail_on_overflow`` is set and the row exceeds
            ``max_tokens``.
    """
    if not isinstance(row, dict):
        raise ValueError(
            f"line {line_number}: expected a JSON object, got {type(row).__name__}"
        )
    chat_id = require_int(row, "chat_id", line_number)
    parent_chat_id = require_int(row, "parent_chat_id", line_number)
    timestamp = float(require_number(row, "timestamp", line_number)) * timestamp_scale
    input_length = require_int(row, "input_length", line_number)
    output_length = require_int(row, "output_length", line_number)
    turn = require_int(row, "turn", line_number)
    request_type = row.get("type")
    hash_ids = require_hash_ids(row, line_number)
    if input_length <= 0:
        raise ValueError(f"line {line_number}: input_length must be positive")
    if output_length <= 0:
        raise ValueError(f"line {line_number}: output_length must be positive")
    # Exact integer ceiling division. math.ceil(a / b) goes through a float
    # and can miscount once the operands exceed float precision (2**53).
    expected_hash_count = -(-input_length // block_size)
    if len(hash_ids) != expected_hash_count:
        raise ValueError(
            f"line {line_number}: len(hash_ids)={len(hash_ids)} does not match "
            f"ceil(input_length / block_size)={expected_hash_count}"
        )
    total_tokens = input_length + output_length
    overflow = total_tokens > max_tokens
    if overflow and fail_on_overflow:
        raise ValueError(
            f"line {line_number}: total_tokens={total_tokens} exceeds "
            f"max_tokens={max_tokens}"
        )
    counts = block_token_counts(input_length, len(hash_ids), block_size)
    frontier_row = {
        "arrived_at": timestamp,
        "num_prefill_tokens": input_length,
        "num_decode_tokens": output_length,
        "session_id": chat_id,
        "block_hash_ids": "|".join(str(item) for item in hash_ids),
    }
    sidecar_row = {
        "request_id": request_id,
        "chat_id": chat_id,
        "parent_chat_id": parent_chat_id,
        "turn": turn,
        "type": request_type,
        "timestamp": timestamp,
        "input_length": input_length,
        "output_length": output_length,
        "hash_ids": hash_ids,
        "block_token_counts": counts,
    }
    stats = {
        "total_tokens": total_tokens,
        "input_length": input_length,
        "output_length": output_length,
        "timestamp": timestamp,
        "partial_final_block": input_length % block_size != 0,
        # Over-budget rows are flagged but never clipped.
        "overflow": overflow,
    }
    return frontier_row, sidecar_row, stats
def tmp_path(path: Path) -> Path:
    """Return the hidden sibling ``.<name>.tmp`` used to stage writes to *path*."""
    hidden_name = "." + path.name + ".tmp"
    return path.with_name(hidden_name)
def ensure_parent(path: Path | None) -> None:
if path is not None:
path.parent.mkdir(parents=True, exist_ok=True)
def publish_tmp_files(paths: list[tuple[Path, Path]]) -> None:
    """Move each staged ``(temporary, final)`` file onto its final path.

    Pairs are processed in list order with ``os.replace``, which overwrites
    an existing destination. Each individual rename is atomic on POSIX, but
    the batch as a whole is not: if one rename fails, earlier pairs remain
    published.
    """
    for pair in paths:
        os.replace(*pair)
def cleanup_tmp_files(paths: list[tuple[Path, Path]]) -> None:
    """Best-effort removal of the staged temporary file of each pair.

    Final paths are never touched. Temporaries that were never created are
    skipped silently; any other unlink error propagates.
    """
    for staged, _final in paths:
        staged.unlink(missing_ok=True)
def _source_slice_line(
    raw_line: str,
    row: dict[str, Any],
    scaled_timestamp: float,
    timestamp_scale: float,
) -> str:
    """Return the newline-terminated line to copy into the source JSONL slice.

    With ``timestamp_scale == 1.0`` the raw input line is copied verbatim.
    Otherwise the row is re-serialized with ``timestamp`` replaced by the
    scaled value, so the slice agrees with the CSV and the sidecar.
    """
    if timestamp_scale == 1.0:
        return raw_line if raw_line.endswith("\n") else raw_line + "\n"
    source_row = dict(row)
    source_row["timestamp"] = scaled_timestamp
    return json.dumps(source_row, sort_keys=True, separators=(",", ":")) + "\n"


def _new_summary() -> dict[str, Any]:
    """Return the zeroed run statistics that are reported in the manifest."""
    return {
        "row_count": 0,
        "overflow_count": 0,
        "max_total_tokens": 0,
        "max_input_length": 0,
        "max_output_length": 0,
        "first_timestamp": None,
        "last_timestamp": None,
        "timestamp_monotonic": True,
        "partial_final_block_rows": 0,
    }


def _update_summary(summary: dict[str, Any], stats: dict[str, Any]) -> None:
    """Fold one row's ``stats`` (as returned by convert_row) into *summary*."""
    summary["row_count"] += 1
    summary["overflow_count"] += int(stats["overflow"])
    summary["max_total_tokens"] = max(
        summary["max_total_tokens"], int(stats["total_tokens"])
    )
    summary["max_input_length"] = max(
        summary["max_input_length"], int(stats["input_length"])
    )
    summary["max_output_length"] = max(
        summary["max_output_length"], int(stats["output_length"])
    )
    summary["partial_final_block_rows"] += int(stats["partial_final_block"])
    timestamp = float(stats["timestamp"])
    if summary["first_timestamp"] is None:
        summary["first_timestamp"] = timestamp
    previous = summary["last_timestamp"]
    if previous is not None and timestamp < previous:
        summary["timestamp_monotonic"] = False
    summary["last_timestamp"] = timestamp


def _convert_rows(args: argparse.Namespace, summary: dict[str, Any]) -> None:
    """Stream input rows into the staged CSV, sidecar and optional source slice.

    Everything is written to the ``tmp_path`` siblings only. Publishing is
    left to the caller. *summary* is updated in place once per converted row.

    Raises:
        ValueError: for undecodable JSON (with the real input line number)
            or any validation failure reported by convert_row.
    """
    with (
        args.input.open("r", encoding="utf-8") as input_file,
        tmp_path(args.frontier_csv).open("w", encoding="utf-8", newline="") as csv_file,
        tmp_path(args.sidecar_jsonl).open("w", encoding="utf-8") as sidecar_file,
    ):
        csv_writer = csv.DictWriter(
            csv_file, fieldnames=CSV_FIELDS, lineterminator="\n"
        )
        csv_writer.writeheader()
        source_file = None
        if args.source_jsonl is not None:
            source_file = tmp_path(args.source_jsonl).open("w", encoding="utf-8")
        try:
            for line_number, raw_line in enumerate(input_file, start=1):
                if args.limit is not None and summary["row_count"] >= args.limit:
                    break
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    row = json.loads(stripped)
                except json.JSONDecodeError as exc:
                    # The decoder's own position always says "line 1" because
                    # it only sees this one line; report the real file line.
                    raise ValueError(
                        f"line {line_number}: invalid JSON: {exc.msg}"
                    ) from exc
                frontier_row, sidecar_row, stats = convert_row(
                    row=row,
                    # 0-based position of this row in the CSV, sidecar and
                    # source slice. Unlike line_number - 1, it stays aligned
                    # with those files when the input contains blank lines.
                    request_id=summary["row_count"],
                    line_number=line_number,
                    block_size=args.block_size,
                    max_tokens=args.max_tokens,
                    fail_on_overflow=args.fail_on_overflow,
                    timestamp_scale=args.timestamp_scale,
                )
                csv_writer.writerow(frontier_row)
                sidecar_file.write(
                    json.dumps(sidecar_row, sort_keys=True, separators=(",", ":"))
                    + "\n"
                )
                if source_file is not None:
                    source_file.write(
                        _source_slice_line(
                            raw_line, row, stats["timestamp"], args.timestamp_scale
                        )
                    )
                _update_summary(summary, stats)
        finally:
            if source_file is not None:
                source_file.close()


def _build_manifest(
    args: argparse.Namespace, summary: dict[str, Any]
) -> dict[str, Any]:
    """Assemble the fixture manifest from the CLI options and run statistics."""
    return {
        "fixture_name": args.fixture_name,
        "generated_by": "tools/qwen_to_frontier.py",
        "input_jsonl": str(args.input),
        "source_jsonl": (
            str(args.source_jsonl) if args.source_jsonl is not None else None
        ),
        "frontier_csv": str(args.frontier_csv),
        "sidecar_jsonl": str(args.sidecar_jsonl),
        "csv_fields": CSV_FIELDS,
        "sidecar_fields": SIDECAR_FIELDS,
        "limit": args.limit,
        "block_size": args.block_size,
        "max_tokens": args.max_tokens,
        "fail_on_overflow": args.fail_on_overflow,
        "timestamp_scale": args.timestamp_scale,
        "adapter_semantics": {
            "timestamp": "arrived_at",
            "input_length": "num_prefill_tokens",
            "output_length": "num_decode_tokens",
            "chat_id": "session_id",
            "hash_ids": "block_hash_ids joined by |",
            "block_token_counts": (
                "full blocks use block_size tokens; final partial block "
                "uses input_length % block_size, or block_size when zero"
            ),
        },
        # row_count, overflow_count, max_*, first/last_timestamp,
        # timestamp_monotonic, partial_final_block_rows.
        **summary,
    }


def main() -> int:
    """Run the Qwen JSONL -> Frontier CSV / ReplayServe sidecar converter.

    Every output is first written to a hidden ``.<name>.tmp`` sibling and
    moved into place only after the whole input converted cleanly. A run
    that fails before publishing therefore leaves previously published
    fixtures untouched and removes its temporary files.

    Returns:
        0 on success. 1 when conversion failed; the error is printed to
        stderr. ``KeyboardInterrupt`` and similar exceptions still remove the
        temporary files and are then re-raised.
    """
    args = parse_args()
    temporary_paths: list[tuple[Path, Path]] = [
        (tmp_path(args.frontier_csv), args.frontier_csv),
        (tmp_path(args.sidecar_jsonl), args.sidecar_jsonl),
    ]
    for optional_output in (args.source_jsonl, args.manifest_json):
        if optional_output is not None:
            temporary_paths.append((tmp_path(optional_output), optional_output))
    summary = _new_summary()
    try:
        for output_path in (
            args.frontier_csv,
            args.sidecar_jsonl,
            args.source_jsonl,
            args.manifest_json,
        ):
            ensure_parent(output_path)
        _convert_rows(args, summary)
        if args.manifest_json is not None:
            manifest = _build_manifest(args, summary)
            with tmp_path(args.manifest_json).open("w", encoding="utf-8") as manifest_file:
                json.dump(manifest, manifest_file, indent=2, sort_keys=True)
                manifest_file.write("\n")
        publish_tmp_files(temporary_paths)
    except Exception as exc:
        cleanup_tmp_files(temporary_paths)
        print(f"qwen_to_frontier.py: error: {exc}", file=sys.stderr)
        return 1
    except BaseException:
        # e.g. KeyboardInterrupt: do not leave staged .tmp files behind.
        cleanup_tmp_files(temporary_paths)
        raise
    overflow_count = summary["overflow_count"]
    if overflow_count and not args.fail_on_overflow:
        print(
            f"qwen_to_frontier.py: warning: {overflow_count} rows exceed "
            f"max_tokens={args.max_tokens}; no clipping was applied",
            file=sys.stderr,
        )
    print(
        f"converted rows={summary['row_count']} "
        f"max_total_tokens={summary['max_total_tokens']} "
        f"overflows={overflow_count}",
        file=sys.stderr,
    )
    return 0
if __name__ == "__main__":
    # main() returns the process exit status (0 ok, 1 conversion error).
    sys.exit(main())