392 lines
14 KiB
Python
Executable File
392 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Convert Qwen JSONL traces to Frontier trace-replay CSV fixtures."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import math
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
# Column order of the Frontier trace-replay CSV: csv.DictWriter uses this list
# for the header row and for the field order of every data row; it is also
# copied verbatim into the manifest.
CSV_FIELDS = [
    "arrived_at", "num_prefill_tokens", "num_decode_tokens",
    "session_id", "block_hash_ids",
]

# Field names of each ReplayServe sidecar record, recorded in the manifest.
# NOTE(review): keep in sync with the keys convert_row puts in sidecar_row.
SIDECAR_FIELDS = [
    "request_id", "chat_id", "parent_chat_id", "turn", "type",
    "timestamp", "input_length", "output_length",
    "hash_ids", "block_token_counts",
]
|
|
|
|
|
|
def positive_int(value: str) -> int:
    """Argparse ``type=`` callback: parse *value* as an integer greater than 0.

    A string that is not an integer makes ``int`` raise ``ValueError``, which
    argparse reports as an invalid value; zero and negative numbers raise
    ``argparse.ArgumentTypeError``.
    """
    number = int(value)
    if number > 0:
        return number
    raise argparse.ArgumentTypeError("must be positive")
|
|
|
|
|
|
def positive_float(value: str) -> float:
    """Argparse ``type=`` callback: parse *value* as a finite float greater than 0.

    ``float()`` accepts "nan" and "inf". NaN slips past a plain ``<= 0``
    test (every comparison with NaN is False), and either value would turn
    every scaled timestamp into nan/inf, so both are rejected explicitly.

    Raises:
        argparse.ArgumentTypeError: if the value is not finite or not > 0.
        ValueError: if *value* is not a float literal (argparse reports it).
    """
    parsed = float(value)
    if not math.isfinite(parsed):
        raise argparse.ArgumentTypeError("must be finite")
    if parsed <= 0:
        raise argparse.ArgumentTypeError("must be positive")
    return parsed
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the converter.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing zero-argument callers are unchanged;
            passing a list allows in-process use and testing.

    Returns:
        The parsed namespace. Path options are ``pathlib.Path`` objects;
        optional options that were not given are ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Convert Qwen JSONL to Frontier CSV plus ReplayServe sidecar."
    )
    parser.add_argument("--input", required=True, type=Path, help="Qwen JSONL path.")
    parser.add_argument(
        "--frontier-csv", required=True, type=Path, help="Output Frontier CSV path."
    )
    parser.add_argument(
        "--sidecar-jsonl",
        required=True,
        type=Path,
        help="Output ReplayServe sidecar JSONL path.",
    )
    parser.add_argument(
        "--source-jsonl",
        type=Path,
        help="Optional path for the original source JSONL slice.",
    )
    parser.add_argument(
        "--manifest-json", type=Path, help="Optional path for fixture manifest JSON."
    )
    parser.add_argument(
        "--fixture-name", help="Optional fixture name stored in the manifest."
    )
    parser.add_argument(
        "--limit", type=positive_int, help="Maximum number of rows to convert."
    )
    parser.add_argument(
        "--max-tokens",
        type=positive_int,
        default=32768,
        help=(
            "Token budget compared against input_length + output_length; "
            "over-budget rows are counted, not clipped (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--block-size",
        type=positive_int,
        default=16,
        help=(
            "Tokens per hash block; each row must carry exactly "
            "ceil(input_length / block_size) hash_ids (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--timestamp-scale",
        type=positive_float,
        default=1.0,
        help="Multiply each source timestamp before writing fixture files.",
    )
    parser.add_argument(
        "--fail-on-overflow",
        action="store_true",
        help="Hard fail if input_length + output_length exceeds --max-tokens.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def require_int(row: dict[str, Any], key: str, line_number: int) -> int:
    """Return ``row[key]``, insisting that it is an integer.

    ``bool`` is a subclass of ``int`` in Python, so ``True``/``False`` are
    explicitly refused.

    Raises:
        ValueError: prefixed with ``line <line_number>:`` when the key is
            absent or the value is not an integer.
    """
    try:
        value = row[key]
    except KeyError as missing:
        raise ValueError(f"line {line_number}: missing field {key!r}") from missing
    else:
        if isinstance(value, int) and not isinstance(value, bool):
            return value
    raise ValueError(f"line {line_number}: field {key!r} must be an int")
|
|
|
|
|
|
def require_number(row: dict[str, Any], key: str, line_number: int) -> int | float:
|
|
try:
|
|
value = row[key]
|
|
except KeyError as exc:
|
|
raise ValueError(f"line {line_number}: missing field {key!r}") from exc
|
|
if isinstance(value, bool) or not isinstance(value, (int, float)):
|
|
raise ValueError(f"line {line_number}: field {key!r} must be numeric")
|
|
return value
|
|
|
|
|
|
def require_hash_ids(row: dict[str, Any], line_number: int) -> list[int]:
    """Return a new list copy of ``row["hash_ids"]`` after type checking.

    The field must be a JSON array whose every element is an integer (bools
    excluded). The first offending element is reported by its index.

    Raises:
        ValueError: prefixed with ``line <line_number>:`` when the field is
            missing, is not a list, or holds a non-integer element.
    """
    try:
        raw = row["hash_ids"]
    except KeyError as missing:
        raise ValueError(f"line {line_number}: missing field 'hash_ids'") from missing
    if not isinstance(raw, list):
        raise ValueError(f"line {line_number}: field 'hash_ids' must be a list")
    for position, entry in enumerate(raw):
        if isinstance(entry, int) and not isinstance(entry, bool):
            continue
        raise ValueError(f"line {line_number}: hash_ids[{position}] must be an int")
    return list(raw)
|
|
|
|
|
|
def block_token_counts(input_length: int, hash_count: int, block_size: int) -> list[int]:
    """Return how many prompt tokens each of the *hash_count* blocks covers.

    All blocks except the last hold *block_size* tokens. The last holds
    ``input_length % block_size``, or a full *block_size* when that remainder
    is zero. With no blocks the result is an empty list.
    """
    if not hash_count:
        return []
    remainder = input_length % block_size
    tail = remainder if remainder else block_size
    counts = [block_size] * (hash_count - 1)
    counts.append(tail)
    return counts
|
|
|
|
|
|
def convert_row(
    row: dict[str, Any],
    request_id: int,
    line_number: int,
    block_size: int,
    max_tokens: int,
    fail_on_overflow: bool,
    timestamp_scale: float,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Validate one Qwen trace record and build its three output views.

    Args:
        row: Decoded JSON value of one input line; must be a JSON object.
        request_id: Identifier written into the sidecar record.
        line_number: 1-based input line number, used in error messages.
        block_size: Number of tokens per hash block.
        max_tokens: Budget compared against ``input_length + output_length``.
        fail_on_overflow: Raise instead of only flagging an over-budget row.
        timestamp_scale: Multiplier applied to the source timestamp.

    Returns:
        ``(frontier_row, sidecar_row, stats)``: the CSV row keyed by Frontier
        column name, the sidecar JSON record, and per-row statistics that the
        caller aggregates into the manifest.

    Raises:
        ValueError: on any schema or consistency violation, prefixed with
            ``line <line_number>:``.
    """
    # json.loads may yield a list/str/number; without this guard the field
    # lookups below would raise a TypeError that carries no line number.
    if not isinstance(row, dict):
        raise ValueError(
            f"line {line_number}: expected a JSON object, got {type(row).__name__}"
        )

    chat_id = require_int(row, "chat_id", line_number)
    parent_chat_id = require_int(row, "parent_chat_id", line_number)
    timestamp = float(require_number(row, "timestamp", line_number)) * timestamp_scale
    input_length = require_int(row, "input_length", line_number)
    output_length = require_int(row, "output_length", line_number)
    turn = require_int(row, "turn", line_number)
    request_type = row.get("type")  # optional and passed through unvalidated
    hash_ids = require_hash_ids(row, line_number)

    if input_length <= 0:
        raise ValueError(f"line {line_number}: input_length must be positive")
    if output_length <= 0:
        raise ValueError(f"line {line_number}: output_length must be positive")

    # Integer ceiling division is exact for any int. math.ceil(a / b) goes
    # through float and can be off by one once input_length exceeds 2**53.
    expected_hash_count = -(-input_length // block_size)
    if len(hash_ids) != expected_hash_count:
        raise ValueError(
            f"line {line_number}: len(hash_ids)={len(hash_ids)} does not match "
            f"ceil(input_length / block_size)={expected_hash_count}"
        )

    total_tokens = input_length + output_length
    overflow = total_tokens > max_tokens
    if overflow and fail_on_overflow:
        raise ValueError(
            f"line {line_number}: total_tokens={total_tokens} exceeds "
            f"max_tokens={max_tokens}"
        )

    counts = block_token_counts(input_length, len(hash_ids), block_size)
    frontier_row = {
        "arrived_at": timestamp,
        "num_prefill_tokens": input_length,
        "num_decode_tokens": output_length,
        "session_id": chat_id,
        "block_hash_ids": "|".join(str(item) for item in hash_ids),
    }
    sidecar_row = {
        "request_id": request_id,
        "chat_id": chat_id,
        "parent_chat_id": parent_chat_id,
        "turn": turn,
        "type": request_type,
        "timestamp": timestamp,
        "input_length": input_length,
        "output_length": output_length,
        "hash_ids": hash_ids,
        "block_token_counts": counts,
    }
    stats = {
        "total_tokens": total_tokens,
        "input_length": input_length,
        "output_length": output_length,
        "timestamp": timestamp,
        "partial_final_block": input_length % block_size != 0,
        "overflow": overflow,
    }
    return frontier_row, sidecar_row, stats
|
|
|
|
|
|
def tmp_path(path: Path) -> Path:
    """Return the hidden staging sibling of *path*: ``dir/name`` -> ``dir/.name.tmp``."""
    staged_name = "".join((".", path.name, ".tmp"))
    return path.with_name(staged_name)
|
|
|
|
|
|
def ensure_parent(path: Path | None) -> None:
|
|
if path is not None:
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def publish_tmp_files(paths: list[tuple[Path, Path]]) -> None:
    """Move each staged file onto its final path, in list order.

    ``os.replace`` overwrites an existing destination. Each single rename is
    atomic on POSIX, but the group of files is not published atomically.
    """
    for pair in paths:
        os.replace(*pair)
|
|
|
|
|
|
def cleanup_tmp_files(paths: list[tuple[Path, Path]]) -> None:
    """Delete every staging file named in *paths*; final paths are untouched.

    Staging files that do not exist are skipped silently.
    """
    for staged, _final in paths:
        staged.unlink(missing_ok=True)
|
|
|
|
|
|
def main() -> int:
    """Convert the input trace and publish every requested fixture file.

    All outputs are written to hidden ``.<name>.tmp`` siblings first. They are
    moved onto their final paths only after every row converted successfully
    and every staging file is closed, so a failed run never overwrites
    existing fixtures with partial output. Leftover staging files are removed
    on every exit path, including KeyboardInterrupt.

    Returns:
        0 on success; 1 after printing an error message to stderr.
    """
    args = parse_args()

    # (staging, final) pairs; optional outputs appear only when requested.
    temporary_paths: list[tuple[Path, Path]] = [
        (tmp_path(final), final)
        for final in (
            args.frontier_csv,
            args.sidecar_jsonl,
            args.source_jsonl,
            args.manifest_json,
        )
        if final is not None
    ]

    row_count = 0
    overflow_count = 0
    max_total_tokens = 0
    max_input_length = 0
    max_output_length = 0
    first_timestamp: float | None = None
    last_timestamp: float | None = None
    timestamp_monotonic = True
    partial_final_block_rows = 0

    try:
        # Inside the try so e.g. a permission error is reported cleanly.
        for output_path in (
            args.frontier_csv,
            args.sidecar_jsonl,
            args.source_jsonl,
            args.manifest_json,
        ):
            ensure_parent(output_path)

        with (
            args.input.open("r", encoding="utf-8") as input_file,
            tmp_path(args.frontier_csv).open("w", encoding="utf-8", newline="") as csv_file,
            tmp_path(args.sidecar_jsonl).open("w", encoding="utf-8") as sidecar_file,
        ):
            csv_writer = csv.DictWriter(
                csv_file, fieldnames=CSV_FIELDS, lineterminator="\n"
            )
            csv_writer.writeheader()

            source_file = None
            if args.source_jsonl is not None:
                source_file = tmp_path(args.source_jsonl).open("w", encoding="utf-8")

            try:
                for line_number, raw_line in enumerate(input_file, start=1):
                    if args.limit is not None and row_count >= args.limit:
                        break
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    try:
                        row = json.loads(stripped)
                    except json.JSONDecodeError as exc:
                        # The decoder's own position is relative to this one
                        # line, so add the input line number.
                        raise ValueError(
                            f"line {line_number}: invalid JSON: {exc}"
                        ) from exc

                    frontier_row, sidecar_row, stats = convert_row(
                        row=row,
                        # Index of this row in the emitted CSV/sidecar/source
                        # slice; blank input lines must not leave id gaps.
                        request_id=row_count,
                        line_number=line_number,
                        block_size=args.block_size,
                        max_tokens=args.max_tokens,
                        fail_on_overflow=args.fail_on_overflow,
                        timestamp_scale=args.timestamp_scale,
                    )
                    csv_writer.writerow(frontier_row)
                    sidecar_file.write(
                        json.dumps(sidecar_row, sort_keys=True, separators=(",", ":"))
                        + "\n"
                    )

                    if source_file is not None:
                        if args.timestamp_scale == 1.0:
                            # Unscaled: copy the source line verbatim.
                            source_file.write(
                                raw_line if raw_line.endswith("\n") else raw_line + "\n"
                            )
                        else:
                            # Scaled: re-serialize so the slice agrees with
                            # the timestamps written to the other outputs.
                            source_row = dict(row)
                            source_row["timestamp"] = stats["timestamp"]
                            source_file.write(
                                json.dumps(
                                    source_row, sort_keys=True, separators=(",", ":")
                                )
                                + "\n"
                            )

                    row_count += 1
                    overflow_count += int(stats["overflow"])
                    max_total_tokens = max(max_total_tokens, int(stats["total_tokens"]))
                    max_input_length = max(max_input_length, int(stats["input_length"]))
                    max_output_length = max(max_output_length, int(stats["output_length"]))
                    partial_final_block_rows += int(stats["partial_final_block"])

                    timestamp = float(stats["timestamp"])
                    if first_timestamp is None:
                        first_timestamp = timestamp
                    if last_timestamp is not None and timestamp < last_timestamp:
                        timestamp_monotonic = False
                    last_timestamp = timestamp
            finally:
                if source_file is not None:
                    source_file.close()

        # Every staging file is closed at this point, so the renames below
        # only ever publish fully flushed data.
        if args.manifest_json is not None:
            manifest = {
                "fixture_name": args.fixture_name,
                "generated_by": "tools/qwen_to_frontier.py",
                "input_jsonl": str(args.input),
                "source_jsonl": str(args.source_jsonl) if args.source_jsonl else None,
                "frontier_csv": str(args.frontier_csv),
                "sidecar_jsonl": str(args.sidecar_jsonl),
                "csv_fields": CSV_FIELDS,
                "sidecar_fields": SIDECAR_FIELDS,
                "limit": args.limit,
                "row_count": row_count,
                "block_size": args.block_size,
                "max_tokens": args.max_tokens,
                "fail_on_overflow": args.fail_on_overflow,
                "timestamp_scale": args.timestamp_scale,
                "overflow_count": overflow_count,
                "max_total_tokens": max_total_tokens,
                "max_input_length": max_input_length,
                "max_output_length": max_output_length,
                "first_timestamp": first_timestamp,
                "last_timestamp": last_timestamp,
                "timestamp_monotonic": timestamp_monotonic,
                "partial_final_block_rows": partial_final_block_rows,
                "adapter_semantics": {
                    "timestamp": "arrived_at",
                    "input_length": "num_prefill_tokens",
                    "output_length": "num_decode_tokens",
                    "chat_id": "session_id",
                    "hash_ids": "block_hash_ids joined by |",
                    "block_token_counts": (
                        "full blocks use block_size tokens; final partial block "
                        "uses input_length % block_size, or block_size when zero"
                    ),
                },
            }
            with tmp_path(args.manifest_json).open("w", encoding="utf-8") as manifest_file:
                json.dump(manifest, manifest_file, indent=2, sort_keys=True)
                manifest_file.write("\n")

        publish_tmp_files(temporary_paths)
    except Exception as exc:
        print(f"qwen_to_frontier.py: error: {exc}", file=sys.stderr)
        return 1
    finally:
        # After a successful publish the staging files are gone, so this only
        # removes leftovers of a failed or interrupted run.
        cleanup_tmp_files(temporary_paths)

    if overflow_count and not args.fail_on_overflow:
        print(
            f"qwen_to_frontier.py: warning: {overflow_count} rows exceed "
            f"max_tokens={args.max_tokens}; no clipping was applied",
            file=sys.stderr,
        )
    print(
        f"converted rows={row_count} max_total_tokens={max_total_tokens} "
        f"overflows={overflow_count}",
        file=sys.stderr,
    )
    return 0
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|