# aituner/scripts/prepare_trace_windows.py
from __future__ import annotations
import argparse
import hashlib
import json
from pathlib import Path
from typing import Any
from dataclasses import dataclass
# Repo root inferred from this file's location (two directories above the
# script). NOTE(review): assumes the layout <repo>/scripts/<this file> —
# confirm parents[1] is really the intended root.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Default locations of the raw formatted Bailian trace dumps.
DEFAULT_LEGACY_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
)
DEFAULT_THINKING_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
)
# Canonical windows are materialized under <repo>/trace_windows by default.
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"

# Dates covered by the legacy (chat/coder) traces.
LEGACY_TARGET_DATES = ["2026-03-11", "2026-03-12", "2026-03-13", "2026-03-16", "2026-03-17"]
# Explicit (date, slot_token) pairs for the thinking workload; every date
# uses the "1000" (10:00) slot only.
THINKING_WINDOWS = [
    ("2026-03-21", "1000"),
    ("2026-03-22", "1000"),
    ("2026-03-23", "1000"),
    ("2026-03-24", "1000"),
    ("2026-03-25", "1000"),
    ("2026-03-26", "1000"),
    ("2026-03-27", "1000"),
]
# Slot definitions. Each raw source file starts at start_hour and
# window_start/window_end select a slice of it in seconds — here
# 3600-4200 s, i.e. the 10-minute period named by slot_label.
WINDOW_SPECS = {
    "1000": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "2200": {
        "start_hour": 21,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "22:00-22:10",
        "slot_token": "2200",
    },
}
# Per-workload specs: block size (appears in the source file names), which
# source directory to read ("legacy" vs "thinking"), and which (date, slot)
# windows to extract — either a date×slot cross product or explicit pairs.
TRACE_SPECS = {
    "chat": {
        "block_size": 64,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "slot_tokens": ("1000", "2200"),
    },
    "coder": {
        "block_size": 512,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "slot_tokens": ("1000", "2200"),
    },
    "thinking": {
        "block_size": 64,
        "source_key": "thinking",
        "date_slot_tokens": THINKING_WINDOWS,
    },
}
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Prepare canonical trace windows under the current aituner repo by reading the "
            "raw formatted Bailian trace directories and merging prompt payloads."
        )
    )
    add = parser.add_argument
    add("--legacy-source", type=Path, default=DEFAULT_LEGACY_SOURCE)
    add("--thinking-source", type=Path, default=DEFAULT_THINKING_SOURCE)
    add("--output-root", type=Path, default=DEFAULT_OUTPUT_ROOT)
    add(
        "--workloads",
        default="chat,coder,thinking",
        help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
    )
    add("--sample-seed", type=int, default=20260325)
    add("--overwrite", action="store_true")
    return parser.parse_args()
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
    """Derive a deterministic pseudo-uniform score in [0, 1) for one row.

    The score depends only on the seed, the window identity, the row's index
    within the window, and a handful of stable row fields — so re-running the
    script always reproduces the same sampling scores.
    """
    fingerprint = {
        "seed": seed,
        "window_id": window_id,
        "index": index,
        "timestamp": row.get("timestamp"),
        "input_length": row.get("input_length"),
        "output_length": row.get("output_length"),
        "chat_id": row.get("chat_id"),
        "turn": row.get("turn"),
    }
    # Canonical JSON (sorted keys, no whitespace) makes the hash input stable.
    encoded = json.dumps(fingerprint, sort_keys=True, separators=(",", ":")).encode("utf-8")
    raw = hashlib.blake2b(encoded, digest_size=8).digest()
    # Map the unsigned 64-bit digest onto [0, 1).
    return int.from_bytes(raw, byteorder="big", signed=False) / float(1 << 64)
def build_source_filename(*, trace_type: str, date_text: str, slot_token: str) -> str:
    """Compose the raw 2-hour source file name for one workload/date/slot.

    ``date_text`` is ISO formatted (YYYY-MM-DD); the hour range starts at the
    slot's configured start hour and always spans two hours.
    """
    month = date_text[5:7]
    day = date_text[8:10]
    hour = int(WINDOW_SPECS[slot_token]["start_hour"])
    block_size = TRACE_SPECS[trace_type]["block_size"]
    begin_token = f"{month}{day}{hour:02d}"
    end_token = f"{month}{day}{hour + 2:02d}"
    return f"qwen_{trace_type}_blksz_{block_size}_{begin_token}-{end_token}.jsonl"
def parse_workloads(text: str) -> tuple[str, ...]:
    """Split a comma-separated workload list and validate every entry.

    Raises ValueError when the list is empty or names a workload that is not
    declared in TRACE_SPECS.
    """
    tokens = [piece.strip() for piece in str(text).split(",")]
    workloads = tuple(token for token in tokens if token)
    if not workloads:
        raise ValueError("--workloads must contain at least one workload")
    unknown = [name for name in workloads if name not in TRACE_SPECS]
    if unknown:
        raise ValueError(f"Unknown workloads: {unknown}")
    return workloads
def build_windows(
    source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
) -> list[dict[str, Any]]:
    """Expand each requested workload into per-(date, slot) window descriptors.

    Every descriptor carries the window's bounds within its 2-hour source
    file, the resolved trace/prompt source paths, and bookkeeping fields used
    later during materialization and rendering.
    """
    windows: list[dict[str, Any]] = []
    for trace_type in workloads:
        spec = TRACE_SPECS[trace_type]
        source_dir = source_dirs[str(spec["source_key"])]
        # A workload either lists explicit (date, slot) pairs or crosses its
        # slot tokens with its target dates (slot-major order).
        if "date_slot_tokens" in spec:
            pairs = list(spec["date_slot_tokens"])
        else:
            pairs = [
                (date_text, slot_token)
                for slot_token in spec["slot_tokens"]
                for date_text in spec["target_dates"]
            ]
        next_order: dict[str, int] = {}
        for date_text, slot_token in pairs:
            window_spec = WINDOW_SPECS[slot_token]
            # Running per-slot counter: Nth date seen for this slot token.
            sample_order = next_order.get(slot_token, 0)
            next_order[slot_token] = sample_order + 1
            date_token = date_text.replace("-", "")
            window_id = f"{trace_type}_w{date_token}_{window_spec['slot_token']}"
            trace_path = source_dir / build_source_filename(
                trace_type=trace_type, date_text=date_text, slot_token=slot_token
            )
            # Prompt payloads live beside the trace with a "_prompt" suffix.
            prompt_path = Path(str(trace_path).replace(".jsonl", "_prompt.jsonl"))
            windows.append(
                {
                    "window_id": window_id,
                    "trace_type": trace_type,
                    "block_size": TRACE_SPECS[trace_type]["block_size"],
                    "window_start": float(window_spec["window_start"]),
                    "window_end": float(window_spec["window_end"]),
                    # 3600 s into the file / 600 s windows -> sub-window 6.
                    "window_index": 6,
                    "sample_order": sample_order,
                    "date": date_text,
                    "slot_token": window_spec["slot_token"],
                    "slot_label": window_spec["slot_label"],
                    "source_trace_path": str(trace_path),
                    "source_prompt_path": str(prompt_path),
                }
            )
    return windows
def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
if trace_row.get("chat_id") != prompt_row.get("chat_id") or trace_row.get("turn") != prompt_row.get("turn"):
raise ValueError(
"trace/prompt rows are misaligned: "
f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
)
merged = dict(trace_row)
prompt = prompt_row.get("prompt")
if isinstance(prompt, str) and prompt:
merged["prompt"] = prompt
return merged
@dataclass
class WindowStats:
    """Aggregate statistics accumulated while streaming one window's rows."""

    num_requests: int = 0
    sum_input_length: int = 0
    max_input_length: int = 0
    first_request_ts: float | None = None
    last_request_ts: float | None = None
    first_request_index: int | None = None
    last_request_index: int | None = None

    def record(self, row: dict[str, Any]) -> None:
        """Fold one merged request row into the running statistics."""
        length = int(row.get("input_length") or 0)
        ts = float(row["timestamp"])
        if not self.num_requests:
            # Very first row for this window.
            self.first_request_ts = ts
            self.first_request_index = 0
        self.last_request_ts = ts
        self.last_request_index = self.num_requests
        self.num_requests += 1
        self.sum_input_length += length
        self.max_input_length = max(self.max_input_length, length)
def materialize_windows(
    windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
) -> dict[str, WindowStats]:
    """Stream the raw trace/prompt file pairs and write per-window JSONL files.

    For each window descriptor, rows whose source timestamp falls inside
    [window_start, window_end) are merged with their prompt payload, re-based
    to a window-local timestamp, tagged with a deterministic sampling score,
    and appended to ``traces_dir/<window_id>.jsonl``.

    Returns per-window aggregate statistics keyed by window id.
    """
    # Windows sharing the same (trace, prompt) source pair are handled in a
    # single pass over that pair of files.
    grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
    for window in windows:
        trace_path = Path(window["source_trace_path"])
        prompt_path = Path(window["source_prompt_path"])
        grouped.setdefault((trace_path, prompt_path), []).append(window)
    stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
    handles: dict[str, Any] = {}
    try:
        # Open one output handle per window up front; all are closed in the
        # finally block even if a source file fails mid-stream.
        for window in windows:
            window_id = str(window["window_id"])
            handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8")
        for trace_path, prompt_path in sorted(grouped.keys()):
            bucket = grouped[(trace_path, prompt_path)]
            bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
            matched_rows = 0
            # Trace and prompt files are assumed aligned line-by-line; zip()
            # advances both in lockstep (blank lines skip the pair).
            with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
                for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
                    trace_raw = trace_raw.strip()
                    if not trace_raw:
                        continue
                    trace_row = json.loads(trace_raw)
                    timestamp = float(trace_row.get("timestamp") or 0.0)
                    # First window (in window_start order) whose half-open
                    # [start, end) range contains this timestamp wins.
                    matched_window: dict[str, Any] | None = None
                    for window in bucket:
                        start = float(window["window_start"])
                        end = float(window["window_end"])
                        if start <= timestamp < end:
                            matched_window = window
                            break
                    if matched_window is None:
                        continue
                    # Prompt line is only parsed after a window match (cheap
                    # skip for out-of-window rows).
                    prompt_raw = prompt_raw.strip()
                    if not prompt_raw:
                        continue
                    prompt_row = json.loads(prompt_raw)
                    merged = _merge_trace_and_prompt(trace_row, prompt_row)
                    window_id = str(matched_window["window_id"])
                    out = dict(merged)
                    start = float(matched_window["window_start"])
                    # Preserve the original timestamp and re-base the working
                    # one so each materialized window starts at t=0.
                    out["source_timestamp"] = timestamp
                    out["timestamp"] = timestamp - start
                    # sampling_u is computed BEFORE record(), so the index is
                    # this row's 0-based position within its window.
                    out["sampling_u"] = stable_uniform(
                        seed=sample_seed,
                        window_id=window_id,
                        index=stats_by_window[window_id].num_requests,
                        row=merged,
                    )
                    handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
                    stats_by_window[window_id].record(out)
                    matched_rows += 1
            print(
                f"materialized {trace_path.name} -> matched_rows={matched_rows}",
                flush=True,
            )
    finally:
        for handle in handles.values():
            handle.close()
    return stats_by_window
def build_output_window(
    window: dict[str, Any],
    stats: WindowStats,
    trace_relpath: str,
    *,
    sample_seed: int,
) -> dict[str, Any]:
    """Render the final windows.json entry for one materialized window.

    Copies the descriptor, re-bases the window bounds to start at 0 (the
    materialized traces are already shifted to t=0), and attaches the
    streamed statistics plus sampling metadata.
    """
    duration = float(window["window_end"]) - float(window["window_start"])
    rendered = dict(window)
    rendered.update(
        {
            "trace_file": trace_relpath,
            "window_start": 0.0,
            "window_end": duration,
            "num_requests": stats.num_requests,
            "sum_input_length": stats.sum_input_length,
            "max_input_length": stats.max_input_length,
            "num_excluded_too_long": 0,
            "sampling_u_field": "sampling_u",
            "sampling_seed": int(sample_seed),
            "sampling_strategy": "fixed_uniform_score",
            "first_request_ts": stats.first_request_ts,
            "last_request_ts": stats.last_request_ts,
            "first_request_index": stats.first_request_index,
            "last_request_index": stats.last_request_index,
        }
    )
    return rendered
def main() -> int:
    """Script entry point: build, materialize, and describe all trace windows."""
    args = parse_args()
    workloads = parse_workloads(args.workloads)
    output_root = args.output_root.resolve()
    # Refuse to clobber an existing output tree unless explicitly allowed.
    if output_root.exists() and not args.overwrite:
        raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
    output_root.mkdir(parents=True, exist_ok=True)
    traces_dir = output_root / "traces"
    traces_dir.mkdir(parents=True, exist_ok=True)
    source_dirs = {
        "legacy": args.legacy_source.resolve(),
        "thinking": args.thinking_source.resolve(),
    }
    windows = build_windows(source_dirs, workloads=workloads)
    stats_by_window = materialize_windows(
        windows,
        sample_seed=args.sample_seed,
        traces_dir=traces_dir,
    )
    rendered_windows: list[dict[str, Any]] = [
        build_output_window(
            window,
            stats_by_window[str(window["window_id"])],
            trace_relpath=f"traces/{window['window_id']}.jsonl",
            sample_seed=args.sample_seed,
        )
        for window in windows
    ]
    windows_payload = {
        "sample_seed": args.sample_seed,
        "source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
        "u_field": "sampling_u",
        "window_duration_seconds": 600.0,
        "windows": rendered_windows,
    }
    manifest_path = output_root / "windows.json"
    manifest_path.write_text(
        json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(output_root)
    print(f"windows={len(rendered_windows)}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer exit code to the shell.
    raise SystemExit(main())