Add raw trace window preparation script
This commit is contained in:
312
scripts/prepare_trace_windows.py
Normal file
312
scripts/prepare_trace_windows.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_LEGACY_SOURCE = Path(
|
||||
"/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
|
||||
)
|
||||
DEFAULT_THINKING_SOURCE = Path(
|
||||
"/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
|
||||
)
|
||||
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"
|
||||
LEGACY_TARGET_DATES = ["2026-03-11", "2026-03-12", "2026-03-13", "2026-03-16", "2026-03-17"]
|
||||
THINKING_WINDOWS = [
|
||||
("2026-03-21", "peak_weekend"),
|
||||
("2026-03-22", "peak_weekend"),
|
||||
("2026-03-23", "peak"),
|
||||
("2026-03-24", "peak"),
|
||||
("2026-03-25", "peak"),
|
||||
("2026-03-26", "peak"),
|
||||
("2026-03-27", "peak"),
|
||||
]
|
||||
WINDOW_SPECS = {
|
||||
"peak": {
|
||||
"start_hour": 9,
|
||||
"window_start": 3600.0,
|
||||
"window_end": 4200.0,
|
||||
"slot_label": "10:00-10:10",
|
||||
"slot_token": "1000",
|
||||
},
|
||||
"peak_weekend": {
|
||||
"start_hour": 9,
|
||||
"window_start": 3600.0,
|
||||
"window_end": 4200.0,
|
||||
"slot_label": "10:00-10:10",
|
||||
"slot_token": "1000",
|
||||
},
|
||||
"valley": {
|
||||
"start_hour": 21,
|
||||
"window_start": 3600.0,
|
||||
"window_end": 4200.0,
|
||||
"slot_label": "22:00-22:10",
|
||||
"slot_token": "2200",
|
||||
},
|
||||
}
|
||||
TRACE_SPECS = {
|
||||
"chat": {
|
||||
"block_size": 64,
|
||||
"source_key": "legacy",
|
||||
"target_dates": LEGACY_TARGET_DATES,
|
||||
"day_parts": ("peak", "valley"),
|
||||
},
|
||||
"coder": {
|
||||
"block_size": 512,
|
||||
"source_key": "legacy",
|
||||
"target_dates": LEGACY_TARGET_DATES,
|
||||
"day_parts": ("peak", "valley"),
|
||||
},
|
||||
"thinking": {
|
||||
"block_size": 64,
|
||||
"source_key": "thinking",
|
||||
"date_day_parts": THINKING_WINDOWS,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Prepare canonical trace windows under the current aituner repo by reading the "
|
||||
"raw formatted Bailian trace directories and merging prompt payloads."
|
||||
)
|
||||
)
|
||||
parser.add_argument("--legacy-source", type=Path, default=DEFAULT_LEGACY_SOURCE)
|
||||
parser.add_argument("--thinking-source", type=Path, default=DEFAULT_THINKING_SOURCE)
|
||||
parser.add_argument("--output-root", type=Path, default=DEFAULT_OUTPUT_ROOT)
|
||||
parser.add_argument(
|
||||
"--workloads",
|
||||
default="chat,coder,thinking",
|
||||
help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
|
||||
)
|
||||
parser.add_argument("--sample-seed", type=int, default=20260325)
|
||||
parser.add_argument("--overwrite", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
|
||||
payload = json.dumps(
|
||||
{
|
||||
"seed": seed,
|
||||
"window_id": window_id,
|
||||
"index": index,
|
||||
"timestamp": row.get("timestamp"),
|
||||
"input_length": row.get("input_length"),
|
||||
"output_length": row.get("output_length"),
|
||||
"chat_id": row.get("chat_id"),
|
||||
"turn": row.get("turn"),
|
||||
},
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
digest = hashlib.blake2b(payload, digest_size=8).digest()
|
||||
return int.from_bytes(digest, byteorder="big", signed=False) / float(1 << 64)
|
||||
|
||||
|
||||
def build_source_filename(*, trace_type: str, date_text: str, day_part: str) -> str:
|
||||
month, day = date_text[5:7], date_text[8:10]
|
||||
start_hour = int(WINDOW_SPECS[day_part]["start_hour"])
|
||||
return (
|
||||
f"qwen_{trace_type}_blksz_{TRACE_SPECS[trace_type]['block_size']}_"
|
||||
f"{month}{day}{start_hour:02d}-{month}{day}{start_hour + 2:02d}.jsonl"
|
||||
)
|
||||
|
||||
|
||||
def parse_workloads(text: str) -> tuple[str, ...]:
|
||||
workloads = tuple(part.strip() for part in str(text).split(",") if part.strip())
|
||||
if not workloads:
|
||||
raise ValueError("--workloads must contain at least one workload")
|
||||
unknown = [item for item in workloads if item not in TRACE_SPECS]
|
||||
if unknown:
|
||||
raise ValueError(f"Unknown workloads: {unknown}")
|
||||
return workloads
|
||||
|
||||
|
||||
def build_windows(
|
||||
source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
|
||||
) -> list[dict[str, Any]]:
|
||||
windows: list[dict[str, Any]] = []
|
||||
for trace_type in workloads:
|
||||
spec = TRACE_SPECS[trace_type]
|
||||
source_dir = source_dirs[str(spec["source_key"])]
|
||||
if "date_day_parts" in spec:
|
||||
date_day_parts = list(spec["date_day_parts"])
|
||||
else:
|
||||
date_day_parts = [
|
||||
(date_text, day_part)
|
||||
for day_part in spec["day_parts"]
|
||||
for date_text in spec["target_dates"]
|
||||
]
|
||||
sample_order_by_group: dict[str, int] = {}
|
||||
for date_text, day_part in date_day_parts:
|
||||
window_spec = WINDOW_SPECS[day_part]
|
||||
date_token = date_text.replace("-", "")
|
||||
sample_order = sample_order_by_group.setdefault(day_part, 0)
|
||||
sample_order_by_group[day_part] += 1
|
||||
window_id = f"{trace_type}_w{date_token}_{day_part}_{window_spec['slot_token']}"
|
||||
source_trace_path = source_dir / build_source_filename(
|
||||
trace_type=trace_type, date_text=date_text, day_part=day_part
|
||||
)
|
||||
source_prompt_path = Path(str(source_trace_path).replace(".jsonl", "_prompt.jsonl"))
|
||||
windows.append(
|
||||
{
|
||||
"window_id": window_id,
|
||||
"trace_type": trace_type,
|
||||
"block_size": TRACE_SPECS[trace_type]["block_size"],
|
||||
"window_start": float(window_spec["window_start"]),
|
||||
"window_end": float(window_spec["window_end"]),
|
||||
"window_index": 6,
|
||||
"sample_order": sample_order,
|
||||
"date": date_text,
|
||||
"day_part": day_part,
|
||||
"slot_label": window_spec["slot_label"],
|
||||
"source_trace_path": str(source_trace_path),
|
||||
"source_prompt_path": str(source_prompt_path),
|
||||
}
|
||||
)
|
||||
return windows
|
||||
|
||||
|
||||
def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
|
||||
if trace_row.get("chat_id") != prompt_row.get("chat_id") or trace_row.get("turn") != prompt_row.get("turn"):
|
||||
raise ValueError(
|
||||
"trace/prompt rows are misaligned: "
|
||||
f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
|
||||
f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
|
||||
)
|
||||
merged = dict(trace_row)
|
||||
prompt = prompt_row.get("prompt")
|
||||
if isinstance(prompt, str) and prompt:
|
||||
merged["prompt"] = prompt
|
||||
return merged
|
||||
|
||||
|
||||
def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]:
|
||||
grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
|
||||
for window in windows:
|
||||
trace_path = Path(window["source_trace_path"])
|
||||
prompt_path = Path(window["source_prompt_path"])
|
||||
grouped.setdefault((trace_path, prompt_path), []).append(window)
|
||||
|
||||
extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows}
|
||||
for (trace_path, prompt_path), bucket in grouped.items():
|
||||
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
||||
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
||||
for index, (trace_raw, prompt_raw) in enumerate(zip(trace_handle, prompt_handle)):
|
||||
trace_raw = trace_raw.strip()
|
||||
prompt_raw = prompt_raw.strip()
|
||||
if not trace_raw or not prompt_raw:
|
||||
continue
|
||||
trace_row = json.loads(trace_raw)
|
||||
prompt_row = json.loads(prompt_raw)
|
||||
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
||||
timestamp = float(merged.get("timestamp") or 0.0)
|
||||
for window in bucket:
|
||||
start = float(window["window_start"])
|
||||
end = float(window["window_end"])
|
||||
if start <= timestamp < end:
|
||||
out = dict(merged)
|
||||
out["source_timestamp"] = timestamp
|
||||
out["timestamp"] = timestamp - start
|
||||
out["sampling_u"] = stable_uniform(
|
||||
seed=sample_seed,
|
||||
window_id=str(window["window_id"]),
|
||||
index=len(extracted[str(window["window_id"])]),
|
||||
row=merged,
|
||||
)
|
||||
extracted[str(window["window_id"])].append(out)
|
||||
break
|
||||
return extracted
|
||||
|
||||
|
||||
def build_output_window(
|
||||
window: dict[str, Any],
|
||||
rows: list[dict[str, Any]],
|
||||
trace_relpath: str,
|
||||
*,
|
||||
sample_seed: int,
|
||||
) -> dict[str, Any]:
|
||||
output = dict(window)
|
||||
output["trace_file"] = trace_relpath
|
||||
output["window_start"] = 0.0
|
||||
output["window_end"] = float(window["window_end"]) - float(window["window_start"])
|
||||
output["num_requests"] = len(rows)
|
||||
output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows))
|
||||
output["max_input_length"] = int(
|
||||
max((int(row.get("input_length") or 0) for row in rows), default=0)
|
||||
)
|
||||
output["num_excluded_too_long"] = 0
|
||||
output["sampling_u_field"] = "sampling_u"
|
||||
output["sampling_seed"] = int(sample_seed)
|
||||
output["sampling_strategy"] = "fixed_uniform_score"
|
||||
if rows:
|
||||
output["first_request_ts"] = float(rows[0]["timestamp"])
|
||||
output["last_request_ts"] = float(rows[-1]["timestamp"])
|
||||
output["first_request_index"] = 0
|
||||
output["last_request_index"] = len(rows) - 1
|
||||
else:
|
||||
output["first_request_ts"] = None
|
||||
output["last_request_ts"] = None
|
||||
output["first_request_index"] = None
|
||||
output["last_request_index"] = None
|
||||
return output
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
workloads = parse_workloads(args.workloads)
|
||||
output_root = args.output_root.resolve()
|
||||
if output_root.exists():
|
||||
if not args.overwrite:
|
||||
raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
|
||||
output_root.mkdir(parents=True, exist_ok=True)
|
||||
traces_dir = output_root / "traces"
|
||||
traces_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
source_dirs = {
|
||||
"legacy": args.legacy_source.resolve(),
|
||||
"thinking": args.thinking_source.resolve(),
|
||||
}
|
||||
windows = build_windows(source_dirs, workloads=workloads)
|
||||
extracted = extract_windows(windows, sample_seed=args.sample_seed)
|
||||
|
||||
rendered_windows: list[dict[str, Any]] = []
|
||||
for window in windows:
|
||||
rows = extracted[str(window["window_id"])]
|
||||
trace_filename = f"{window['window_id']}.jsonl"
|
||||
trace_path = traces_dir / trace_filename
|
||||
with trace_path.open("w", encoding="utf-8") as handle:
|
||||
for row in rows:
|
||||
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
rendered_windows.append(
|
||||
build_output_window(
|
||||
window,
|
||||
rows,
|
||||
trace_relpath=f"traces/{trace_filename}",
|
||||
sample_seed=args.sample_seed,
|
||||
)
|
||||
)
|
||||
|
||||
windows_payload = {
|
||||
"sample_seed": args.sample_seed,
|
||||
"source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
|
||||
"u_field": "sampling_u",
|
||||
"window_duration_seconds": 600.0,
|
||||
"windows": rendered_windows,
|
||||
}
|
||||
(output_root / "windows.json").write_text(
|
||||
json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(output_root)
|
||||
print(f"windows={len(rendered_windows)}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user