Add raw trace window preparation script
This commit is contained in:
312
scripts/prepare_trace_windows.py
Normal file
312
scripts/prepare_trace_windows.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
DEFAULT_LEGACY_SOURCE = Path(
|
||||||
|
"/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
|
||||||
|
)
|
||||||
|
DEFAULT_THINKING_SOURCE = Path(
|
||||||
|
"/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
|
||||||
|
)
|
||||||
|
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"
|
||||||
|
LEGACY_TARGET_DATES = ["2026-03-11", "2026-03-12", "2026-03-13", "2026-03-16", "2026-03-17"]
|
||||||
|
THINKING_WINDOWS = [
|
||||||
|
("2026-03-21", "peak_weekend"),
|
||||||
|
("2026-03-22", "peak_weekend"),
|
||||||
|
("2026-03-23", "peak"),
|
||||||
|
("2026-03-24", "peak"),
|
||||||
|
("2026-03-25", "peak"),
|
||||||
|
("2026-03-26", "peak"),
|
||||||
|
("2026-03-27", "peak"),
|
||||||
|
]
|
||||||
|
WINDOW_SPECS = {
|
||||||
|
"peak": {
|
||||||
|
"start_hour": 9,
|
||||||
|
"window_start": 3600.0,
|
||||||
|
"window_end": 4200.0,
|
||||||
|
"slot_label": "10:00-10:10",
|
||||||
|
"slot_token": "1000",
|
||||||
|
},
|
||||||
|
"peak_weekend": {
|
||||||
|
"start_hour": 9,
|
||||||
|
"window_start": 3600.0,
|
||||||
|
"window_end": 4200.0,
|
||||||
|
"slot_label": "10:00-10:10",
|
||||||
|
"slot_token": "1000",
|
||||||
|
},
|
||||||
|
"valley": {
|
||||||
|
"start_hour": 21,
|
||||||
|
"window_start": 3600.0,
|
||||||
|
"window_end": 4200.0,
|
||||||
|
"slot_label": "22:00-22:10",
|
||||||
|
"slot_token": "2200",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
TRACE_SPECS = {
|
||||||
|
"chat": {
|
||||||
|
"block_size": 64,
|
||||||
|
"source_key": "legacy",
|
||||||
|
"target_dates": LEGACY_TARGET_DATES,
|
||||||
|
"day_parts": ("peak", "valley"),
|
||||||
|
},
|
||||||
|
"coder": {
|
||||||
|
"block_size": 512,
|
||||||
|
"source_key": "legacy",
|
||||||
|
"target_dates": LEGACY_TARGET_DATES,
|
||||||
|
"day_parts": ("peak", "valley"),
|
||||||
|
},
|
||||||
|
"thinking": {
|
||||||
|
"block_size": 64,
|
||||||
|
"source_key": "thinking",
|
||||||
|
"date_day_parts": THINKING_WINDOWS,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=(
|
||||||
|
"Prepare canonical trace windows under the current aituner repo by reading the "
|
||||||
|
"raw formatted Bailian trace directories and merging prompt payloads."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
parser.add_argument("--legacy-source", type=Path, default=DEFAULT_LEGACY_SOURCE)
|
||||||
|
parser.add_argument("--thinking-source", type=Path, default=DEFAULT_THINKING_SOURCE)
|
||||||
|
parser.add_argument("--output-root", type=Path, default=DEFAULT_OUTPUT_ROOT)
|
||||||
|
parser.add_argument(
|
||||||
|
"--workloads",
|
||||||
|
default="chat,coder,thinking",
|
||||||
|
help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--sample-seed", type=int, default=20260325)
|
||||||
|
parser.add_argument("--overwrite", action="store_true")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
|
||||||
|
payload = json.dumps(
|
||||||
|
{
|
||||||
|
"seed": seed,
|
||||||
|
"window_id": window_id,
|
||||||
|
"index": index,
|
||||||
|
"timestamp": row.get("timestamp"),
|
||||||
|
"input_length": row.get("input_length"),
|
||||||
|
"output_length": row.get("output_length"),
|
||||||
|
"chat_id": row.get("chat_id"),
|
||||||
|
"turn": row.get("turn"),
|
||||||
|
},
|
||||||
|
sort_keys=True,
|
||||||
|
separators=(",", ":"),
|
||||||
|
).encode("utf-8")
|
||||||
|
digest = hashlib.blake2b(payload, digest_size=8).digest()
|
||||||
|
return int.from_bytes(digest, byteorder="big", signed=False) / float(1 << 64)
|
||||||
|
|
||||||
|
|
||||||
|
def build_source_filename(*, trace_type: str, date_text: str, day_part: str) -> str:
|
||||||
|
month, day = date_text[5:7], date_text[8:10]
|
||||||
|
start_hour = int(WINDOW_SPECS[day_part]["start_hour"])
|
||||||
|
return (
|
||||||
|
f"qwen_{trace_type}_blksz_{TRACE_SPECS[trace_type]['block_size']}_"
|
||||||
|
f"{month}{day}{start_hour:02d}-{month}{day}{start_hour + 2:02d}.jsonl"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_workloads(text: str) -> tuple[str, ...]:
|
||||||
|
workloads = tuple(part.strip() for part in str(text).split(",") if part.strip())
|
||||||
|
if not workloads:
|
||||||
|
raise ValueError("--workloads must contain at least one workload")
|
||||||
|
unknown = [item for item in workloads if item not in TRACE_SPECS]
|
||||||
|
if unknown:
|
||||||
|
raise ValueError(f"Unknown workloads: {unknown}")
|
||||||
|
return workloads
|
||||||
|
|
||||||
|
|
||||||
|
def build_windows(
|
||||||
|
source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
windows: list[dict[str, Any]] = []
|
||||||
|
for trace_type in workloads:
|
||||||
|
spec = TRACE_SPECS[trace_type]
|
||||||
|
source_dir = source_dirs[str(spec["source_key"])]
|
||||||
|
if "date_day_parts" in spec:
|
||||||
|
date_day_parts = list(spec["date_day_parts"])
|
||||||
|
else:
|
||||||
|
date_day_parts = [
|
||||||
|
(date_text, day_part)
|
||||||
|
for day_part in spec["day_parts"]
|
||||||
|
for date_text in spec["target_dates"]
|
||||||
|
]
|
||||||
|
sample_order_by_group: dict[str, int] = {}
|
||||||
|
for date_text, day_part in date_day_parts:
|
||||||
|
window_spec = WINDOW_SPECS[day_part]
|
||||||
|
date_token = date_text.replace("-", "")
|
||||||
|
sample_order = sample_order_by_group.setdefault(day_part, 0)
|
||||||
|
sample_order_by_group[day_part] += 1
|
||||||
|
window_id = f"{trace_type}_w{date_token}_{day_part}_{window_spec['slot_token']}"
|
||||||
|
source_trace_path = source_dir / build_source_filename(
|
||||||
|
trace_type=trace_type, date_text=date_text, day_part=day_part
|
||||||
|
)
|
||||||
|
source_prompt_path = Path(str(source_trace_path).replace(".jsonl", "_prompt.jsonl"))
|
||||||
|
windows.append(
|
||||||
|
{
|
||||||
|
"window_id": window_id,
|
||||||
|
"trace_type": trace_type,
|
||||||
|
"block_size": TRACE_SPECS[trace_type]["block_size"],
|
||||||
|
"window_start": float(window_spec["window_start"]),
|
||||||
|
"window_end": float(window_spec["window_end"]),
|
||||||
|
"window_index": 6,
|
||||||
|
"sample_order": sample_order,
|
||||||
|
"date": date_text,
|
||||||
|
"day_part": day_part,
|
||||||
|
"slot_label": window_spec["slot_label"],
|
||||||
|
"source_trace_path": str(source_trace_path),
|
||||||
|
"source_prompt_path": str(source_prompt_path),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
if trace_row.get("chat_id") != prompt_row.get("chat_id") or trace_row.get("turn") != prompt_row.get("turn"):
|
||||||
|
raise ValueError(
|
||||||
|
"trace/prompt rows are misaligned: "
|
||||||
|
f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
|
||||||
|
f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
|
||||||
|
)
|
||||||
|
merged = dict(trace_row)
|
||||||
|
prompt = prompt_row.get("prompt")
|
||||||
|
if isinstance(prompt, str) and prompt:
|
||||||
|
merged["prompt"] = prompt
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]:
|
||||||
|
grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
|
||||||
|
for window in windows:
|
||||||
|
trace_path = Path(window["source_trace_path"])
|
||||||
|
prompt_path = Path(window["source_prompt_path"])
|
||||||
|
grouped.setdefault((trace_path, prompt_path), []).append(window)
|
||||||
|
|
||||||
|
extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows}
|
||||||
|
for (trace_path, prompt_path), bucket in grouped.items():
|
||||||
|
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
|
||||||
|
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
|
||||||
|
for index, (trace_raw, prompt_raw) in enumerate(zip(trace_handle, prompt_handle)):
|
||||||
|
trace_raw = trace_raw.strip()
|
||||||
|
prompt_raw = prompt_raw.strip()
|
||||||
|
if not trace_raw or not prompt_raw:
|
||||||
|
continue
|
||||||
|
trace_row = json.loads(trace_raw)
|
||||||
|
prompt_row = json.loads(prompt_raw)
|
||||||
|
merged = _merge_trace_and_prompt(trace_row, prompt_row)
|
||||||
|
timestamp = float(merged.get("timestamp") or 0.0)
|
||||||
|
for window in bucket:
|
||||||
|
start = float(window["window_start"])
|
||||||
|
end = float(window["window_end"])
|
||||||
|
if start <= timestamp < end:
|
||||||
|
out = dict(merged)
|
||||||
|
out["source_timestamp"] = timestamp
|
||||||
|
out["timestamp"] = timestamp - start
|
||||||
|
out["sampling_u"] = stable_uniform(
|
||||||
|
seed=sample_seed,
|
||||||
|
window_id=str(window["window_id"]),
|
||||||
|
index=len(extracted[str(window["window_id"])]),
|
||||||
|
row=merged,
|
||||||
|
)
|
||||||
|
extracted[str(window["window_id"])].append(out)
|
||||||
|
break
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_window(
|
||||||
|
window: dict[str, Any],
|
||||||
|
rows: list[dict[str, Any]],
|
||||||
|
trace_relpath: str,
|
||||||
|
*,
|
||||||
|
sample_seed: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
output = dict(window)
|
||||||
|
output["trace_file"] = trace_relpath
|
||||||
|
output["window_start"] = 0.0
|
||||||
|
output["window_end"] = float(window["window_end"]) - float(window["window_start"])
|
||||||
|
output["num_requests"] = len(rows)
|
||||||
|
output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows))
|
||||||
|
output["max_input_length"] = int(
|
||||||
|
max((int(row.get("input_length") or 0) for row in rows), default=0)
|
||||||
|
)
|
||||||
|
output["num_excluded_too_long"] = 0
|
||||||
|
output["sampling_u_field"] = "sampling_u"
|
||||||
|
output["sampling_seed"] = int(sample_seed)
|
||||||
|
output["sampling_strategy"] = "fixed_uniform_score"
|
||||||
|
if rows:
|
||||||
|
output["first_request_ts"] = float(rows[0]["timestamp"])
|
||||||
|
output["last_request_ts"] = float(rows[-1]["timestamp"])
|
||||||
|
output["first_request_index"] = 0
|
||||||
|
output["last_request_index"] = len(rows) - 1
|
||||||
|
else:
|
||||||
|
output["first_request_ts"] = None
|
||||||
|
output["last_request_ts"] = None
|
||||||
|
output["first_request_index"] = None
|
||||||
|
output["last_request_index"] = None
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
args = parse_args()
|
||||||
|
workloads = parse_workloads(args.workloads)
|
||||||
|
output_root = args.output_root.resolve()
|
||||||
|
if output_root.exists():
|
||||||
|
if not args.overwrite:
|
||||||
|
raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
|
||||||
|
output_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
traces_dir = output_root / "traces"
|
||||||
|
traces_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
source_dirs = {
|
||||||
|
"legacy": args.legacy_source.resolve(),
|
||||||
|
"thinking": args.thinking_source.resolve(),
|
||||||
|
}
|
||||||
|
windows = build_windows(source_dirs, workloads=workloads)
|
||||||
|
extracted = extract_windows(windows, sample_seed=args.sample_seed)
|
||||||
|
|
||||||
|
rendered_windows: list[dict[str, Any]] = []
|
||||||
|
for window in windows:
|
||||||
|
rows = extracted[str(window["window_id"])]
|
||||||
|
trace_filename = f"{window['window_id']}.jsonl"
|
||||||
|
trace_path = traces_dir / trace_filename
|
||||||
|
with trace_path.open("w", encoding="utf-8") as handle:
|
||||||
|
for row in rows:
|
||||||
|
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||||
|
rendered_windows.append(
|
||||||
|
build_output_window(
|
||||||
|
window,
|
||||||
|
rows,
|
||||||
|
trace_relpath=f"traces/{trace_filename}",
|
||||||
|
sample_seed=args.sample_seed,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
windows_payload = {
|
||||||
|
"sample_seed": args.sample_seed,
|
||||||
|
"source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
|
||||||
|
"u_field": "sampling_u",
|
||||||
|
"window_duration_seconds": 600.0,
|
||||||
|
"windows": rendered_windows,
|
||||||
|
}
|
||||||
|
(output_root / "windows.json").write_text(
|
||||||
|
json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
print(output_root)
|
||||||
|
print(f"windows={len(rendered_windows)}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
Reference in New Issue
Block a user