"""Prepare canonical trace windows from raw formatted Bailian trace dumps.

aituner/scripts/prepare_trace_windows.py
"""
from __future__ import annotations
import argparse
import hashlib
import json
from pathlib import Path
from typing import Any
# Repo checkout that contains this script; outputs default to a sibling dir.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Raw formatted Bailian trace dumps. "legacy" covers the chat/coder week,
# "thinking" the later week with its own per-date day-part schedule.
DEFAULT_LEGACY_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
)
DEFAULT_THINKING_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
)
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"
# Weekdays sampled for the legacy (chat/coder) workloads.
LEGACY_TARGET_DATES = ["2026-03-11", "2026-03-12", "2026-03-13", "2026-03-16", "2026-03-17"]
# Explicit (date, day_part) schedule for the thinking workload; weekend days
# are tagged separately even though their window spec matches "peak".
THINKING_WINDOWS = [
    ("2026-03-21", "peak_weekend"),
    ("2026-03-22", "peak_weekend"),
    ("2026-03-23", "peak"),
    ("2026-03-24", "peak"),
    ("2026-03-25", "peak"),
    ("2026-03-26", "peak"),
    ("2026-03-27", "peak"),
]
# Each day-part maps to a 10-minute slice of a 2-hour source file:
#   start_hour      — hour of the 2-hour source file (used in its filename)
#   window_start/end — offsets in seconds within that file (3600-4200 = the
#                      second hour's first 10 minutes)
#   slot_label/token — human-readable and compact names for the wall-clock slot
WINDOW_SPECS = {
    "peak": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "peak_weekend": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "valley": {
        "start_hour": 21,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "22:00-22:10",
        "slot_token": "2200",
    },
}
# Per-workload extraction plan. Either an explicit "date_day_parts" schedule
# (thinking) or a "target_dates" x "day_parts" cartesian product (legacy).
TRACE_SPECS = {
    "chat": {
        "block_size": 64,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "day_parts": ("peak", "valley"),
    },
    "coder": {
        "block_size": 512,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "day_parts": ("peak", "valley"),
    },
    "thinking": {
        "block_size": 64,
        "source_key": "thinking",
        "date_day_parts": THINKING_WINDOWS,
    },
}
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI for the trace-window preparation script."""
    description = (
        "Prepare canonical trace windows under the current aituner repo by reading the "
        "raw formatted Bailian trace directories and merging prompt payloads."
    )
    parser = argparse.ArgumentParser(description=description)
    # The three path options share the same shape; register them together.
    for flag, default in (
        ("--legacy-source", DEFAULT_LEGACY_SOURCE),
        ("--thinking-source", DEFAULT_THINKING_SOURCE),
        ("--output-root", DEFAULT_OUTPUT_ROOT),
    ):
        parser.add_argument(flag, type=Path, default=default)
    parser.add_argument(
        "--workloads",
        default="chat,coder,thinking",
        help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
    )
    parser.add_argument("--sample-seed", type=int, default=20260325)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args()
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
    """Derive a deterministic pseudo-uniform value in [0, 1) for one trace row.

    The value depends only on (seed, window_id, index) and a few stable row
    fields, so repeated runs over the same trace sample identically.
    """
    key_fields = {
        "seed": seed,
        "window_id": window_id,
        "index": index,
        "timestamp": row.get("timestamp"),
        "input_length": row.get("input_length"),
        "output_length": row.get("output_length"),
        "chat_id": row.get("chat_id"),
        "turn": row.get("turn"),
    }
    # Canonical JSON (sorted keys, no whitespace) keeps the digest stable
    # regardless of dict insertion order.
    canonical = json.dumps(key_fields, sort_keys=True, separators=(",", ":")).encode("utf-8")
    raw64 = hashlib.blake2b(canonical, digest_size=8).digest()
    # Map the 64-bit digest onto [0, 1).
    return int.from_bytes(raw64, byteorder="big", signed=False) / float(1 << 64)
def build_source_filename(*, trace_type: str, date_text: str, day_part: str) -> str:
    """Return the raw 2-hour source filename for one (workload, date, day-part).

    ``date_text`` is ISO "YYYY-MM-DD"; the filename encodes MMDDHH-MMDDHH for
    the two-hour span starting at the day-part's start hour.
    """
    mm = date_text[5:7]
    dd = date_text[8:10]
    begin = int(WINDOW_SPECS[day_part]["start_hour"])
    block = TRACE_SPECS[trace_type]["block_size"]
    span = f"{mm}{dd}{begin:02d}-{mm}{dd}{begin + 2:02d}"
    return f"qwen_{trace_type}_blksz_{block}_{span}.jsonl"
def parse_workloads(text: str) -> tuple[str, ...]:
    """Split the --workloads CSV into a validated tuple of workload names.

    Raises ValueError when the list is empty or names a workload that is not
    declared in TRACE_SPECS.
    """
    pieces = [piece.strip() for piece in str(text).split(",")]
    workloads = tuple(piece for piece in pieces if piece)
    if not workloads:
        raise ValueError("--workloads must contain at least one workload")
    unknown = [item for item in workloads if item not in TRACE_SPECS]
    if unknown:
        raise ValueError(f"Unknown workloads: {unknown}")
    return workloads
def build_windows(
    source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
) -> list[dict[str, Any]]:
    """Materialize one window descriptor per (workload, date, day-part).

    Legacy workloads enumerate day-parts in the outer loop so "sample_order"
    counts per day-part across dates; thinking uses its explicit schedule.
    """
    result: list[dict[str, Any]] = []
    for workload in workloads:
        spec = TRACE_SPECS[workload]
        raw_dir = source_dirs[str(spec["source_key"])]
        if "date_day_parts" in spec:
            schedule = list(spec["date_day_parts"])
        else:
            # Day-part outer, date inner — ordering feeds sample_order below.
            schedule = []
            for part in spec["day_parts"]:
                for day in spec["target_dates"]:
                    schedule.append((day, part))
        counters: dict[str, int] = {}
        for day, part in schedule:
            slot = WINDOW_SPECS[part]
            order = counters.get(part, 0)
            counters[part] = order + 1
            compact_day = day.replace("-", "")
            wid = f"{workload}_w{compact_day}_{part}_{slot['slot_token']}"
            trace_file = raw_dir / build_source_filename(
                trace_type=workload, date_text=day, day_part=part
            )
            # Prompt payloads live next to the trace with a _prompt suffix.
            prompt_file = Path(str(trace_file).replace(".jsonl", "_prompt.jsonl"))
            result.append(
                {
                    "window_id": wid,
                    "trace_type": workload,
                    "block_size": TRACE_SPECS[workload]["block_size"],
                    "window_start": float(slot["window_start"]),
                    "window_end": float(slot["window_end"]),
                    "window_index": 6,
                    "sample_order": order,
                    "date": day,
                    "day_part": part,
                    "slot_label": slot["slot_label"],
                    "source_trace_path": str(trace_file),
                    "source_prompt_path": str(prompt_file),
                }
            )
    return result
def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
if trace_row.get("chat_id") != prompt_row.get("chat_id") or trace_row.get("turn") != prompt_row.get("turn"):
raise ValueError(
"trace/prompt rows are misaligned: "
f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
)
merged = dict(trace_row)
prompt = prompt_row.get("prompt")
if isinstance(prompt, str) and prompt:
merged["prompt"] = prompt
return merged
def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]:
    """Read each (trace, prompt) file pair once and slice rows into windows.

    Rows are matched line-by-line between the two files, merged, then assigned
    to the first window (sorted by start time, then id) whose [start, end)
    range contains the row's timestamp. Timestamps in the output are rebased
    to the window start, and each row gets a deterministic "sampling_u" score.
    """
    by_pair: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
    for win in windows:
        pair = (Path(win["source_trace_path"]), Path(win["source_prompt_path"]))
        by_pair.setdefault(pair, []).append(win)
    results: dict[str, list[dict[str, Any]]] = {str(win["window_id"]): [] for win in windows}
    for (trace_file, prompt_file), wins in by_pair.items():
        wins.sort(key=lambda w: (float(w["window_start"]), str(w["window_id"])))
        with trace_file.open() as tf, prompt_file.open() as pf:
            for trace_line, prompt_line in zip(tf, pf):
                trace_line = trace_line.strip()
                prompt_line = prompt_line.strip()
                # Skip blank lines in either file (keeps the pairing intact
                # only when both are blank together, matching original intent).
                if not trace_line or not prompt_line:
                    continue
                merged = _merge_trace_and_prompt(json.loads(trace_line), json.loads(prompt_line))
                ts = float(merged.get("timestamp") or 0.0)
                for win in wins:
                    lo = float(win["window_start"])
                    hi = float(win["window_end"])
                    if not (lo <= ts < hi):
                        continue
                    wid = str(win["window_id"])
                    row_out = dict(merged)
                    row_out["source_timestamp"] = ts
                    row_out["timestamp"] = ts - lo
                    # Index is the row's position within its window so far,
                    # keeping sampling_u independent of other windows.
                    row_out["sampling_u"] = stable_uniform(
                        seed=sample_seed,
                        window_id=wid,
                        index=len(results[wid]),
                        row=merged,
                    )
                    results[wid].append(row_out)
                    break
    return results
def build_output_window(
    window: dict[str, Any],
    rows: list[dict[str, Any]],
    trace_relpath: str,
    *,
    sample_seed: int,
) -> dict[str, Any]:
    """Return the windows.json metadata entry for one extracted window.

    Copies the input descriptor, rebases the window to start at 0.0, and adds
    request-count / length statistics plus the sampling bookkeeping fields.
    """
    meta = dict(window)
    meta["trace_file"] = trace_relpath
    # Rebase: extracted row timestamps were already shifted to window start.
    meta["window_start"] = 0.0
    meta["window_end"] = float(window["window_end"]) - float(window["window_start"])
    meta["num_requests"] = len(rows)
    lengths = [int(row.get("input_length") or 0) for row in rows]
    meta["sum_input_length"] = int(sum(lengths))
    meta["max_input_length"] = int(max(lengths, default=0))
    meta["num_excluded_too_long"] = 0
    meta["sampling_u_field"] = "sampling_u"
    meta["sampling_seed"] = int(sample_seed)
    meta["sampling_strategy"] = "fixed_uniform_score"
    if rows:
        meta["first_request_ts"] = float(rows[0]["timestamp"])
        meta["last_request_ts"] = float(rows[-1]["timestamp"])
        meta["first_request_index"] = 0
        meta["last_request_index"] = len(rows) - 1
    else:
        for key in ("first_request_ts", "last_request_ts", "first_request_index", "last_request_index"):
            meta[key] = None
    return meta
def main() -> int:
    """Script entry point: extract every window and write traces + windows.json.

    Refuses to touch an existing output root unless --overwrite is passed.
    Emits one JSONL trace file per window plus a windows.json manifest, and
    prints the output root and window count.
    """
    args = parse_args()
    workloads = parse_workloads(args.workloads)
    output_root = args.output_root.resolve()
    if output_root.exists() and not args.overwrite:
        raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
    output_root.mkdir(parents=True, exist_ok=True)
    traces_dir = output_root / "traces"
    traces_dir.mkdir(parents=True, exist_ok=True)
    source_dirs = {
        "legacy": args.legacy_source.resolve(),
        "thinking": args.thinking_source.resolve(),
    }
    windows = build_windows(source_dirs, workloads=workloads)
    extracted = extract_windows(windows, sample_seed=args.sample_seed)
    rendered_windows: list[dict[str, Any]] = []
    for window in windows:
        rows = extracted[str(window["window_id"])]
        trace_filename = f"{window['window_id']}.jsonl"
        payload_lines = [json.dumps(row, ensure_ascii=False) + "\n" for row in rows]
        with (traces_dir / trace_filename).open("w", encoding="utf-8") as handle:
            handle.writelines(payload_lines)
        rendered_windows.append(
            build_output_window(
                window,
                rows,
                trace_relpath=f"traces/{trace_filename}",
                sample_seed=args.sample_seed,
            )
        )
    windows_payload = {
        "sample_seed": args.sample_seed,
        "source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
        "u_field": "sampling_u",
        "window_duration_seconds": 600.0,
        "windows": rendered_windows,
    }
    (output_root / "windows.json").write_text(
        json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(output_root)
    print(f"windows={len(rendered_windows)}")
    return 0
# Script entry: exit status comes from main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())