"""Prepare canonical trace windows from raw formatted Bailian trace files.

For each selected workload the script pairs a ``*.jsonl`` trace file with its
``*_prompt.jsonl`` companion, keeps the requests whose ``timestamp`` falls in
the configured window, rebases timestamps to the window start, attaches a
deterministic ``sampling_u`` score, and writes one JSONL file per window plus
a ``windows.json`` manifest under the output root.
"""
from __future__ import annotations

import argparse
import hashlib
import itertools
import json
from pathlib import Path
from typing import Any


# The script is expected to live one directory below the repository root.
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_LEGACY_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
)
DEFAULT_THINKING_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
)
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"
LEGACY_TARGET_DATES = ["2026-03-11", "2026-03-12", "2026-03-13", "2026-03-16", "2026-03-17"]
# (date, day_part) pairs for the "thinking" workload.
THINKING_WINDOWS = [
    ("2026-03-21", "peak_weekend"),
    ("2026-03-22", "peak_weekend"),
    ("2026-03-23", "peak"),
    ("2026-03-24", "peak"),
    ("2026-03-25", "peak"),
    ("2026-03-26", "peak"),
    ("2026-03-27", "peak"),
]
# Per day-part window definition. ``start_hour`` selects the source file
# (see build_source_filename); ``window_start``/``window_end`` are compared
# against each row's ``timestamp`` in extract_windows.
# NOTE(review): the slot labels suggest row timestamps are seconds measured
# from the start of the source file -- confirm against the raw data.
WINDOW_SPECS = {
    "peak": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "peak_weekend": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "valley": {
        "start_hour": 21,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "22:00-22:10",
        "slot_token": "2200",
    },
}
# Per workload: block size used in the source filename, which source
# directory to read, and which (date, day_part) combinations to materialize.
TRACE_SPECS = {
    "chat": {
        "block_size": 64,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "day_parts": ("peak", "valley"),
    },
    "coder": {
        "block_size": 512,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "day_parts": ("peak", "valley"),
    },
    "thinking": {
        "block_size": 64,
        "source_key": "thinking",
        "date_day_parts": THINKING_WINDOWS,
    },
}


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``legacy_source``, ``thinking_source``,
        ``output_root``, ``workloads``, ``sample_seed`` and ``overwrite``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Prepare canonical trace windows under the current aituner repo by reading the "
            "raw formatted Bailian trace directories and merging prompt payloads."
        )
    )
    parser.add_argument("--legacy-source", type=Path, default=DEFAULT_LEGACY_SOURCE)
    parser.add_argument("--thinking-source", type=Path, default=DEFAULT_THINKING_SOURCE)
    parser.add_argument("--output-root", type=Path, default=DEFAULT_OUTPUT_ROOT)
    parser.add_argument(
        "--workloads",
        default="chat,coder,thinking",
        help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
    )
    parser.add_argument("--sample-seed", type=int, default=20260325)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args(argv)


def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
    """Return a deterministic score in ``[0, 1)`` for one request.

    The score is an 8-byte BLAKE2b digest of a canonical JSON payload built
    from the seed, the window id, the row's index inside the window and a
    few identifying row fields, scaled by ``2**64``.
    """
    payload = json.dumps(
        {
            "seed": seed,
            "window_id": window_id,
            "index": index,
            "timestamp": row.get("timestamp"),
            "input_length": row.get("input_length"),
            "output_length": row.get("output_length"),
            "chat_id": row.get("chat_id"),
            "turn": row.get("turn"),
        },
        # Sorted keys and fixed separators keep the hashed bytes stable.
        sort_keys=True,
        separators=(",", ":"),
    ).encode("utf-8")
    digest = hashlib.blake2b(payload, digest_size=8).digest()
    return int.from_bytes(digest, byteorder="big", signed=False) / float(1 << 64)


def build_source_filename(*, trace_type: str, date_text: str, day_part: str) -> str:
    """Build the raw trace filename for a workload, date and day part.

    ``date_text`` is sliced as ``YYYY-MM-DD``; the name encodes the block
    size and a two-hour range starting at the day part's ``start_hour``.
    """
    month, day = date_text[5:7], date_text[8:10]
    start_hour = int(WINDOW_SPECS[day_part]["start_hour"])
    return (
        f"qwen_{trace_type}_blksz_{TRACE_SPECS[trace_type]['block_size']}_"
        f"{month}{day}{start_hour:02d}-{month}{day}{start_hour + 2:02d}.jsonl"
    )


def parse_workloads(text: str) -> tuple[str, ...]:
    """Split a comma-separated workload list into validated names.

    Blank entries are ignored and repeated names are collapsed (first
    occurrence wins) so that no window id is generated twice.

    Raises:
        ValueError: If no workload remains or a name is not in TRACE_SPECS.
    """
    names = (part.strip() for part in str(text).split(","))
    # dict.fromkeys de-duplicates while preserving the order given.
    workloads = tuple(dict.fromkeys(name for name in names if name))
    if not workloads:
        raise ValueError("--workloads must contain at least one workload")
    unknown = [item for item in workloads if item not in TRACE_SPECS]
    if unknown:
        raise ValueError(f"Unknown workloads: {unknown}")
    return workloads


def build_windows(
    source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
) -> list[dict[str, Any]]:
    """Describe every window to extract for the given workloads.

    Args:
        source_dirs: Maps a TRACE_SPECS ``source_key`` to its directory.
        workloads: Workload names, each a key of TRACE_SPECS.

    Returns:
        One dict per window holding its id, extraction bounds, labels and
        the trace / prompt source paths (as strings).
    """
    windows: list[dict[str, Any]] = []
    for trace_type in workloads:
        spec = TRACE_SPECS[trace_type]
        source_dir = source_dirs[str(spec["source_key"])]
        if "date_day_parts" in spec:
            date_day_parts = list(spec["date_day_parts"])
        else:
            # All dates of the first day part, then all dates of the next.
            date_day_parts = [
                (date_text, day_part)
                for day_part in spec["day_parts"]
                for date_text in spec["target_dates"]
            ]
        # Running counter per day part; becomes each window's sample_order.
        sample_order_by_group: dict[str, int] = {}
        for date_text, day_part in date_day_parts:
            window_spec = WINDOW_SPECS[day_part]
            date_token = date_text.replace("-", "")
            sample_order = sample_order_by_group.get(day_part, 0)
            sample_order_by_group[day_part] = sample_order + 1
            window_id = f"{trace_type}_w{date_token}_{day_part}_{window_spec['slot_token']}"
            source_trace_path = source_dir / build_source_filename(
                trace_type=trace_type, date_text=date_text, day_part=day_part
            )
            # Rewrite only the file name; a plain str.replace on the whole
            # path would also alter directories that contain ".jsonl".
            source_prompt_path = source_trace_path.with_name(
                f"{source_trace_path.stem}_prompt.jsonl"
            )
            windows.append(
                {
                    "window_id": window_id,
                    "trace_type": trace_type,
                    "block_size": TRACE_SPECS[trace_type]["block_size"],
                    "window_start": float(window_spec["window_start"]),
                    "window_end": float(window_spec["window_end"]),
                    # Same fixed value for every window.
                    # NOTE(review): presumably window_start / 600 -- confirm.
                    "window_index": 6,
                    "sample_order": sample_order,
                    "date": date_text,
                    "day_part": day_part,
                    "slot_label": window_spec["slot_label"],
                    "source_trace_path": str(source_trace_path),
                    "source_prompt_path": str(source_prompt_path),
                }
            )
    return windows


def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``trace_row`` with the prompt text attached.

    The prompt is copied only when it is a non-empty string.

    Raises:
        ValueError: If the two rows disagree on ``chat_id`` or ``turn``.
    """
    if trace_row.get("chat_id") != prompt_row.get("chat_id") or trace_row.get("turn") != prompt_row.get("turn"):
        raise ValueError(
            "trace/prompt rows are misaligned: "
            f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
            f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
        )
    merged = dict(trace_row)
    prompt = prompt_row.get("prompt")
    if isinstance(prompt, str) and prompt:
        merged["prompt"] = prompt
    return merged


def _iter_merged_rows(trace_path: Path, prompt_path: Path):
    """Yield merged rows from a trace file and its prompt file, line by line.

    Pairs in which either line is blank are skipped. Trailing blank lines in
    the longer file are tolerated.

    Raises:
        ValueError: If one file ends while the other still has data, or if
            a pair of rows is misaligned.
    """
    with trace_path.open(encoding="utf-8") as trace_handle, prompt_path.open(
        encoding="utf-8"
    ) as prompt_handle:
        # zip_longest (not zip) so a length mismatch is detected instead of
        # silently dropping the tail of the longer file.
        paired = itertools.zip_longest(trace_handle, prompt_handle)
        for line_number, (trace_raw, prompt_raw) in enumerate(paired, start=1):
            if trace_raw is None or prompt_raw is None:
                leftover = prompt_raw if trace_raw is None else trace_raw
                if leftover.strip():
                    exhausted = trace_path if trace_raw is None else prompt_path
                    raise ValueError(
                        f"{exhausted} ended before its companion file "
                        f"(line {line_number})"
                    )
                continue
            trace_raw = trace_raw.strip()
            prompt_raw = prompt_raw.strip()
            if not trace_raw or not prompt_raw:
                continue
            yield _merge_trace_and_prompt(json.loads(trace_raw), json.loads(prompt_raw))


def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]:
    """Read the source files and collect the rows belonging to each window.

    Windows sharing the same (trace, prompt) file pair are served by a single
    pass over those files. A row is assigned to the first window (ordered by
    start, then id) whose half-open range ``[window_start, window_end)``
    contains its timestamp; a missing or falsy timestamp counts as ``0.0``.

    Each emitted row is a copy of the merged row with ``source_timestamp``
    (the original value), ``timestamp`` rebased to the window start, and
    ``sampling_u`` from stable_uniform.

    Returns:
        Mapping of ``window_id`` to its rows in file order.

    Raises:
        ValueError: If a trace file and its prompt file are misaligned or
            differ in length.
    """
    grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
    for window in windows:
        key = (Path(window["source_trace_path"]), Path(window["source_prompt_path"]))
        grouped.setdefault(key, []).append(window)

    extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows}
    for (trace_path, prompt_path), bucket in grouped.items():
        bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
        for merged in _iter_merged_rows(trace_path, prompt_path):
            timestamp = float(merged.get("timestamp") or 0.0)
            for window in bucket:
                start = float(window["window_start"])
                end = float(window["window_end"])
                if not start <= timestamp < end:
                    continue
                window_rows = extracted[str(window["window_id"])]
                out = dict(merged)
                out["source_timestamp"] = timestamp
                out["timestamp"] = timestamp - start
                out["sampling_u"] = stable_uniform(
                    seed=sample_seed,
                    window_id=str(window["window_id"]),
                    # Position of the row inside its window, not in the file.
                    index=len(window_rows),
                    row=merged,
                )
                window_rows.append(out)
                break
    return extracted


def build_output_window(
    window: dict[str, Any],
    rows: list[dict[str, Any]],
    trace_relpath: str,
    *,
    sample_seed: int,
) -> dict[str, Any]:
    """Build the manifest entry for one extracted window.

    Copies ``window``, rebases its bounds so that it starts at ``0.0``, and
    adds the trace file path, request statistics and sampling metadata.
    The first/last request fields are ``None`` when ``rows`` is empty.
    """
    output = dict(window)
    output["trace_file"] = trace_relpath
    output["window_start"] = 0.0
    output["window_end"] = float(window["window_end"]) - float(window["window_start"])
    output["num_requests"] = len(rows)
    output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows))
    output["max_input_length"] = int(
        max((int(row.get("input_length") or 0) for row in rows), default=0)
    )
    output["num_excluded_too_long"] = 0
    output["sampling_u_field"] = "sampling_u"
    output["sampling_seed"] = int(sample_seed)
    output["sampling_strategy"] = "fixed_uniform_score"
    if rows:
        output["first_request_ts"] = float(rows[0]["timestamp"])
        output["last_request_ts"] = float(rows[-1]["timestamp"])
        output["first_request_index"] = 0
        output["last_request_index"] = len(rows) - 1
    else:
        output["first_request_ts"] = None
        output["last_request_ts"] = None
        output["first_request_index"] = None
        output["last_request_index"] = None
    return output


def main(argv: list[str] | None = None) -> int:
    """Materialize the trace windows and the ``windows.json`` manifest.

    Args:
        argv: Optional argument list forwarded to parse_args.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If ``--workloads`` is invalid, or the output root exists
            and ``--overwrite`` was not given.
    """
    args = parse_args(argv)
    try:
        workloads = parse_workloads(args.workloads)
    except ValueError as exc:
        raise SystemExit(f"error: {exc}") from exc
    output_root = args.output_root.resolve()
    if output_root.exists() and not args.overwrite:
        raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")

    source_dirs = {
        "legacy": args.legacy_source.resolve(),
        "thinking": args.thinking_source.resolve(),
    }
    windows = build_windows(source_dirs, workloads=workloads)
    # Read all sources before touching the output tree, so a missing or
    # inconsistent source does not leave an empty output root behind.
    extracted = extract_windows(windows, sample_seed=args.sample_seed)

    output_root.mkdir(parents=True, exist_ok=True)
    traces_dir = output_root / "traces"
    traces_dir.mkdir(parents=True, exist_ok=True)

    rendered_windows: list[dict[str, Any]] = []
    for window in windows:
        rows = extracted[str(window["window_id"])]
        trace_filename = f"{window['window_id']}.jsonl"
        trace_path = traces_dir / trace_filename
        with trace_path.open("w", encoding="utf-8") as handle:
            handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
        rendered_windows.append(
            build_output_window(
                window,
                rows,
                trace_relpath=f"traces/{trace_filename}",
                sample_seed=args.sample_seed,
            )
        )

    windows_payload = {
        "sample_seed": args.sample_seed,
        "source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
        "u_field": "sampling_u",
        "window_duration_seconds": 600.0,
        "windows": rendered_windows,
    }
    (output_root / "windows.json").write_text(
        json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(output_root)
    print(f"windows={len(rendered_windows)}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())