"""Prepare canonical trace windows from the raw formatted Bailian traces.

For every selected workload the script locates the two-hour source trace file
and its ``*_prompt.jsonl`` companion, keeps the rows whose timestamp falls in
the configured window, merges each kept row with its prompt payload, attaches
a deterministic ``sampling_u`` score, and writes one JSONL file per window plus
a ``windows.json`` manifest under the output root.
"""

from __future__ import annotations

import argparse
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_LEGACY_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
)
DEFAULT_THINKING_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
)
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"

LEGACY_TARGET_DATES = [
    "2026-03-11",
    "2026-03-12",
    "2026-03-13",
    "2026-03-16",
    "2026-03-17",
]
THINKING_WINDOWS = [
    ("2026-03-21", "1000"),
    ("2026-03-22", "1000"),
    ("2026-03-23", "1000"),
    ("2026-03-24", "1000"),
    ("2026-03-25", "1000"),
    ("2026-03-26", "1000"),
    ("2026-03-27", "1000"),
]

# Per-slot settings. ``start_hour`` selects the source file (which spans
# ``start_hour`` .. ``start_hour + 2``); ``window_start`` / ``window_end`` are
# compared against each row's ``timestamp`` field to pick the rows to keep.
WINDOW_SPECS = {
    "1000": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "2200": {
        "start_hour": 21,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "22:00-22:10",
        "slot_token": "2200",
    },
}

# Per-workload settings. A spec either lists explicit ``date_slot_tokens``
# pairs or provides ``target_dates`` x ``slot_tokens`` to be expanded.
TRACE_SPECS = {
    "chat": {
        "block_size": 64,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "slot_tokens": ("1000", "2200"),
    },
    "coder": {
        "block_size": 512,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "slot_tokens": ("1000", "2200"),
    },
    "thinking": {
        "block_size": 64,
        "source_key": "thinking",
        "date_slot_tokens": THINKING_WINDOWS,
    },
}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser(
        description=(
            "Prepare canonical trace windows under the current aituner repo by reading the "
            "raw formatted Bailian trace directories and merging prompt payloads."
        )
    )
    parser.add_argument("--legacy-source", type=Path, default=DEFAULT_LEGACY_SOURCE)
    parser.add_argument("--thinking-source", type=Path, default=DEFAULT_THINKING_SOURCE)
    parser.add_argument("--output-root", type=Path, default=DEFAULT_OUTPUT_ROOT)
    parser.add_argument(
        "--workloads",
        default="chat,coder,thinking",
        help=(
            "Comma-separated workloads to materialize. "
            "Defaults to chat,coder,thinking."
        ),
    )
    parser.add_argument("--sample-seed", type=int, default=20260325)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args()


def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
    """Return a deterministic score in ``[0, 1)`` for one request row.

    The score is an 8-byte BLAKE2b digest of a canonical JSON rendering of the
    seed, window id, row index and a fixed set of row fields, scaled by 2**64.
    Identical inputs therefore always produce the identical score.
    """
    payload = json.dumps(
        {
            "seed": seed,
            "window_id": window_id,
            "index": index,
            "timestamp": row.get("timestamp"),
            "input_length": row.get("input_length"),
            "output_length": row.get("output_length"),
            "chat_id": row.get("chat_id"),
            "turn": row.get("turn"),
        },
        # Sorted keys and fixed separators keep the hashed bytes canonical.
        sort_keys=True,
        separators=(",", ":"),
    ).encode("utf-8")
    digest = hashlib.blake2b(payload, digest_size=8).digest()
    return int.from_bytes(digest, byteorder="big", signed=False) / float(1 << 64)


def build_source_filename(*, trace_type: str, date_text: str, slot_token: str) -> str:
    """Return the source trace file name for a workload, date and slot.

    ``date_text`` is expected in ``YYYY-MM-DD`` form; the name encodes the
    workload's block size and the ``MMDDHH-MMDDHH`` span that starts at the
    slot's ``start_hour`` and ends two hours later.
    """
    month, day = date_text[5:7], date_text[8:10]
    start_hour = int(WINDOW_SPECS[slot_token]["start_hour"])
    block_size = TRACE_SPECS[trace_type]["block_size"]
    return (
        f"qwen_{trace_type}_blksz_{block_size}_"
        f"{month}{day}{start_hour:02d}-{month}{day}{start_hour + 2:02d}.jsonl"
    )


def parse_workloads(text: str) -> tuple[str, ...]:
    """Split a comma-separated workload list into a validated tuple.

    Blank entries are ignored. Repeated names are collapsed to their first
    occurrence: a repeated workload would otherwise yield duplicate window ids
    that share (and clobber) a single output file.

    Raises:
        ValueError: if no workload remains or a name is not in ``TRACE_SPECS``.
    """
    names = (part.strip() for part in str(text).split(","))
    # dict.fromkeys de-duplicates while preserving the order given by the user.
    workloads = tuple(dict.fromkeys(name for name in names if name))
    if not workloads:
        raise ValueError("--workloads must contain at least one workload")
    unknown = [item for item in workloads if item not in TRACE_SPECS]
    if unknown:
        raise ValueError(f"Unknown workloads: {unknown}")
    return workloads


def build_windows(
    source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
) -> list[dict[str, Any]]:
    """Describe every window to materialize for the given workloads.

    Args:
        source_dirs: source directory per ``source_key`` of ``TRACE_SPECS``.
        workloads: workload names, each a key of ``TRACE_SPECS``.

    Returns:
        One dict per (workload, date, slot) holding the window id, the
        selection bounds, and the source trace / prompt file paths.
    """
    windows: list[dict[str, Any]] = []
    for trace_type in workloads:
        spec = TRACE_SPECS[trace_type]
        source_dir = source_dirs[str(spec["source_key"])]
        if "date_slot_tokens" in spec:
            date_slot_tokens = list(spec["date_slot_tokens"])
        else:
            # Slot-major order: all dates of the first slot, then the next slot.
            date_slot_tokens = [
                (date_text, slot_token)
                for slot_token in spec["slot_tokens"]
                for date_text in spec["target_dates"]
            ]

        # Running position of each window within its slot, per workload.
        sample_order_by_group: dict[str, int] = {}
        for date_text, slot_token in date_slot_tokens:
            window_spec = WINDOW_SPECS[slot_token]
            date_token = date_text.replace("-", "")
            sample_order = sample_order_by_group.get(slot_token, 0)
            sample_order_by_group[slot_token] = sample_order + 1

            window_id = f"{trace_type}_w{date_token}_{window_spec['slot_token']}"
            source_trace_path = source_dir / build_source_filename(
                trace_type=trace_type, date_text=date_text, slot_token=slot_token
            )
            # Derive the companion from the file name only, so a ".jsonl"
            # substring in a parent directory name is never rewritten.
            source_prompt_path = source_trace_path.with_name(
                f"{source_trace_path.stem}_prompt.jsonl"
            )
            windows.append(
                {
                    "window_id": window_id,
                    "trace_type": trace_type,
                    "block_size": spec["block_size"],
                    "window_start": float(window_spec["window_start"]),
                    "window_end": float(window_spec["window_end"]),
                    "window_index": 6,
                    "sample_order": sample_order,
                    "date": date_text,
                    "slot_token": window_spec["slot_token"],
                    "slot_label": window_spec["slot_label"],
                    "source_trace_path": str(source_trace_path),
                    "source_prompt_path": str(source_prompt_path),
                }
            )
    return windows


def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``trace_row`` with the prompt text attached.

    The prompt is copied only when it is a non-empty string.

    Raises:
        ValueError: if the two rows disagree on ``chat_id`` or ``turn``.
    """
    if (
        trace_row.get("chat_id") != prompt_row.get("chat_id")
        or trace_row.get("turn") != prompt_row.get("turn")
    ):
        raise ValueError(
            "trace/prompt rows are misaligned: "
            f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
            f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
        )
    merged = dict(trace_row)
    prompt = prompt_row.get("prompt")
    if isinstance(prompt, str) and prompt:
        merged["prompt"] = prompt
    return merged


@dataclass
class WindowStats:
    """Running aggregates over the rows written to one window file."""

    num_requests: int = 0
    sum_input_length: int = 0
    max_input_length: int = 0
    first_request_ts: float | None = None
    last_request_ts: float | None = None
    first_request_index: int | None = None
    last_request_index: int | None = None

    def record(self, row: dict[str, Any]) -> None:
        """Fold one written row into the aggregates.

        ``row`` must carry a ``timestamp``; a missing or falsy
        ``input_length`` counts as 0.
        """
        input_length = int(row.get("input_length") or 0)
        timestamp = float(row["timestamp"])
        if self.num_requests == 0:
            self.first_request_ts = timestamp
            self.first_request_index = 0
        self.last_request_ts = timestamp
        # Indices are zero-based, so the index of this row is the prior count.
        self.last_request_index = self.num_requests
        self.num_requests += 1
        self.sum_input_length += input_length
        self.max_input_length = max(self.max_input_length, input_length)


def _find_window(bucket: list[dict[str, Any]], timestamp: float) -> dict[str, Any] | None:
    """Return the first window in ``bucket`` whose [start, end) holds ``timestamp``."""
    for window in bucket:
        if float(window["window_start"]) <= timestamp < float(window["window_end"]):
            return window
    return None


def materialize_windows(
    windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
) -> dict[str, WindowStats]:
    """Write ``<window_id>.jsonl`` for every window and return per-window stats.

    Each source trace file is read in lockstep with its prompt file. Rows whose
    timestamp falls inside a window are merged with their prompt row, rebased
    so that ``timestamp`` is relative to the window start (the original value
    is kept as ``source_timestamp``), given a ``sampling_u`` score and appended
    to the window's output file.

    Args:
        windows: window descriptions as produced by ``build_windows``.
        sample_seed: seed forwarded to ``stable_uniform``.
        traces_dir: existing directory that receives the output files.

    Returns:
        ``WindowStats`` keyed by window id.

    Raises:
        ValueError: if a kept trace row and its prompt row are misaligned.
    """
    # Windows that share a source file pair are served by a single pass.
    grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
    for window in windows:
        source_pair = (
            Path(window["source_trace_path"]),
            Path(window["source_prompt_path"]),
        )
        grouped.setdefault(source_pair, []).append(window)

    stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
    handles: dict[str, Any] = {}
    try:
        # Handles are opened inside the try block so that a failure part-way
        # through still closes the ones already opened.
        for window in windows:
            window_id = str(window["window_id"])
            handles[window_id] = (traces_dir / f"{window_id}.jsonl").open(
                "w", encoding="utf-8"
            )

        for trace_path, prompt_path in sorted(grouped):
            bucket = grouped[(trace_path, prompt_path)]
            bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
            matched_rows = 0
            # Read as UTF-8 explicitly: outputs are written as UTF-8, and the
            # locale default may differ and mis-decode non-ASCII prompts.
            with trace_path.open(encoding="utf-8") as trace_handle, prompt_path.open(
                encoding="utf-8"
            ) as prompt_handle:
                # NOTE(review): zip() stops at the shorter file, so a length
                # mismatch between trace and prompt files goes unreported.
                for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
                    trace_raw = trace_raw.strip()
                    if not trace_raw:
                        continue
                    trace_row = json.loads(trace_raw)
                    timestamp = float(trace_row.get("timestamp") or 0.0)
                    matched_window = _find_window(bucket, timestamp)
                    if matched_window is None:
                        continue

                    # The prompt line is parsed only for rows that are kept.
                    prompt_raw = prompt_raw.strip()
                    if not prompt_raw:
                        continue
                    merged = _merge_trace_and_prompt(trace_row, json.loads(prompt_raw))

                    window_id = str(matched_window["window_id"])
                    stats = stats_by_window[window_id]
                    out = dict(merged)
                    out["source_timestamp"] = timestamp
                    out["timestamp"] = timestamp - float(matched_window["window_start"])
                    # The score hashes the merged row, i.e. the source timestamp.
                    out["sampling_u"] = stable_uniform(
                        seed=sample_seed,
                        window_id=window_id,
                        index=stats.num_requests,
                        row=merged,
                    )
                    handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
                    stats.record(out)
                    matched_rows += 1
            print(
                f"materialized {trace_path.name} -> matched_rows={matched_rows}",
                flush=True,
            )
    finally:
        for handle in handles.values():
            handle.close()
    return stats_by_window


def build_output_window(
    window: dict[str, Any],
    stats: WindowStats,
    trace_relpath: str,
    *,
    sample_seed: int,
) -> dict[str, Any]:
    """Return the manifest entry for one materialized window.

    The entry copies ``window``, rebases the bounds so the window starts at
    0.0, and adds the output file path, the sampling metadata and the
    aggregates collected in ``stats``.
    """
    output = dict(window)
    output.update(
        {
            "trace_file": trace_relpath,
            "window_start": 0.0,
            "window_end": float(window["window_end"]) - float(window["window_start"]),
            "num_requests": stats.num_requests,
            "sum_input_length": stats.sum_input_length,
            "max_input_length": stats.max_input_length,
            "num_excluded_too_long": 0,
            "sampling_u_field": "sampling_u",
            "sampling_seed": int(sample_seed),
            "sampling_strategy": "fixed_uniform_score",
            "first_request_ts": stats.first_request_ts,
            "last_request_ts": stats.last_request_ts,
            "first_request_index": stats.first_request_index,
            "last_request_index": stats.last_request_index,
        }
    )
    return output


def main() -> int:
    """Materialize the requested windows and write ``windows.json``.

    An existing output root is refused unless ``--overwrite`` is given. With
    ``--overwrite`` files are rewritten in place; files left by an earlier run
    that this run does not produce are not removed.

    Returns:
        0 on success.
    """
    args = parse_args()
    workloads = parse_workloads(args.workloads)
    output_root = args.output_root.resolve()
    if output_root.exists() and not args.overwrite:
        raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
    output_root.mkdir(parents=True, exist_ok=True)
    traces_dir = output_root / "traces"
    traces_dir.mkdir(parents=True, exist_ok=True)

    source_dirs = {
        "legacy": args.legacy_source.resolve(),
        "thinking": args.thinking_source.resolve(),
    }
    windows = build_windows(source_dirs, workloads=workloads)
    stats_by_window = materialize_windows(
        windows,
        sample_seed=args.sample_seed,
        traces_dir=traces_dir,
    )

    rendered_windows: list[dict[str, Any]] = [
        build_output_window(
            window,
            stats_by_window[str(window["window_id"])],
            trace_relpath=f"traces/{window['window_id']}.jsonl",
            sample_seed=args.sample_seed,
        )
        for window in windows
    ]

    windows_payload = {
        "sample_seed": args.sample_seed,
        "source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
        "u_field": "sampling_u",
        "window_duration_seconds": 600.0,
        "windows": rendered_windows,
    }
    (output_root / "windows.json").write_text(
        json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(output_root)
    print(f"windows={len(rendered_windows)}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())