"""Prepare canonical Bailian trace windows (chat/coder/thinking) for the aituner repo."""
from __future__ import annotations
|
|
|
|
import argparse
|
|
import hashlib
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from dataclasses import dataclass
|
|
|
|
|
|
# This script is expected to live one directory below the repo root
# (Path(__file__).parents[1] is the repo root).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Raw formatted Bailian trace directories (defaults; overridable via
# --legacy-source / --thinking-source).
DEFAULT_LEGACY_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
)
DEFAULT_THINKING_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
)
# Where materialized windows are written (overridable via --output-root).
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "trace_windows"
# Dates covered by the legacy source (used by the chat and coder workloads).
LEGACY_TARGET_DATES = [
    "2026-03-11",
    "2026-03-12",
    "2026-03-13",
    "2026-03-14",
    "2026-03-15",
    "2026-03-16",
    "2026-03-17",
]
# Explicit (date, slot_token) pairs for the thinking workload; only the
# 10:00 slot is materialized for these dates.
THINKING_WINDOWS = [
    ("2026-03-21", "1000"),
    ("2026-03-22", "1000"),
    ("2026-03-23", "1000"),
    ("2026-03-24", "1000"),
    ("2026-03-25", "1000"),
    ("2026-03-26", "1000"),
    ("2026-03-27", "1000"),
]
# Per-slot window geometry. `start_hour` selects the 2-hour source file
# (see build_source_filename); `window_start`/`window_end` are second
# offsets into that file, so 3600.0-4200.0 is the 10-minute span starting
# one hour in (e.g. 10:00-10:10 within the 09:00-11:00 file).
WINDOW_SPECS = {
    "1000": {
        "start_hour": 9,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "10:00-10:10",
        "slot_token": "1000",
    },
    "2200": {
        "start_hour": 21,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": "22:00-22:10",
        "slot_token": "2200",
    },
}
# Per-workload configuration: block size (appears in the source filename),
# which source directory to read, and which (date, slot) windows to build.
# chat/coder take the cross product of slot_tokens x target_dates; thinking
# lists its pairs explicitly via date_slot_tokens.
TRACE_SPECS = {
    "chat": {
        "block_size": 64,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "slot_tokens": ("1000", "2200"),
    },
    "coder": {
        "block_size": 512,
        "source_key": "legacy",
        "target_dates": LEGACY_TARGET_DATES,
        "slot_tokens": ("1000", "2200"),
    },
    "thinking": {
        "block_size": 64,
        "source_key": "thinking",
        "date_slot_tokens": THINKING_WINDOWS,
    },
}
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this preparation script."""
    description = (
        "Prepare canonical trace windows under the current aituner repo by reading the "
        "raw formatted Bailian trace directories and merging prompt payloads."
    )
    parser = argparse.ArgumentParser(description=description)
    # Source directories and output location, with repo-relative defaults.
    parser.add_argument("--legacy-source", type=Path, default=DEFAULT_LEGACY_SOURCE)
    parser.add_argument("--thinking-source", type=Path, default=DEFAULT_THINKING_SOURCE)
    parser.add_argument("--output-root", type=Path, default=DEFAULT_OUTPUT_ROOT)
    parser.add_argument(
        "--workloads",
        default="chat,coder,thinking",
        help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
    )
    # Seed feeding the deterministic per-row sampling scores.
    parser.add_argument("--sample-seed", type=int, default=20260325)
    parser.add_argument("--overwrite", action="store_true")
    return parser.parse_args()
|
|
|
|
|
|
def stable_uniform(*, seed: int, window_id: str, index: int, row: dict[str, Any]) -> float:
    """Return a deterministic pseudo-uniform score in [0, 1) for one trace row.

    The score depends only on (seed, window_id, index) and a fixed subset of
    row fields, so re-running the pipeline reproduces identical values.
    """
    key = {
        "seed": seed,
        "window_id": window_id,
        "index": index,
        "timestamp": row.get("timestamp"),
        "input_length": row.get("input_length"),
        "output_length": row.get("output_length"),
        "chat_id": row.get("chat_id"),
        "turn": row.get("turn"),
    }
    # Canonical JSON (sorted keys, no whitespace) so the digest is stable.
    encoded = json.dumps(key, sort_keys=True, separators=(",", ":")).encode("utf-8")
    digest = hashlib.blake2b(encoded, digest_size=8).digest()
    # 8-byte digest -> unsigned 64-bit integer -> scaled into [0, 1).
    value = int.from_bytes(digest, byteorder="big", signed=False)
    return value / float(1 << 64)
|
|
|
|
|
|
def build_source_filename(*, trace_type: str, date_text: str, slot_token: str) -> str:
    """Map (workload, ISO date, slot token) to the raw 2-hour trace filename."""
    # date_text is ISO "YYYY-MM-DD"; the filename carries only MMDD.
    month = date_text[5:7]
    day = date_text[8:10]
    begin = int(WINDOW_SPECS[slot_token]["start_hour"])
    block_size = TRACE_SPECS[trace_type]["block_size"]
    # Source files span two hours, e.g. "031109-031111" for 09:00-11:00.
    span = f"{month}{day}{begin:02d}-{month}{day}{begin + 2:02d}"
    return f"qwen_{trace_type}_blksz_{block_size}_{span}.jsonl"
|
|
|
|
|
|
def parse_workloads(text: str) -> tuple[str, ...]:
    """Split a comma-separated workload list and validate each entry.

    Raises ValueError when the list is empty or contains a workload name not
    present in TRACE_SPECS.
    """
    pieces = [chunk.strip() for chunk in str(text).split(",")]
    selected = tuple(name for name in pieces if name)
    if not selected:
        raise ValueError("--workloads must contain at least one workload")
    bad = [name for name in selected if name not in TRACE_SPECS]
    if bad:
        raise ValueError(f"Unknown workloads: {bad}")
    return selected
|
|
|
|
|
|
def build_windows(
    source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
) -> list[dict[str, Any]]:
    """Expand workload specs into one descriptor dict per (date, slot) window."""
    windows: list[dict[str, Any]] = []
    for trace_type in workloads:
        spec = TRACE_SPECS[trace_type]
        source_dir = source_dirs[str(spec["source_key"])]
        # "thinking" lists its (date, slot) pairs explicitly; chat/coder take
        # the cross product of slot tokens with the legacy target dates.
        if "date_slot_tokens" in spec:
            pairs = list(spec["date_slot_tokens"])
        else:
            pairs = [
                (date_text, slot_token)
                for slot_token in spec["slot_tokens"]
                for date_text in spec["target_dates"]
            ]
        # sample_order counts, per slot token, how many dates precede this one.
        per_slot_counter: dict[str, int] = {}
        for date_text, slot_token in pairs:
            window_spec = WINDOW_SPECS[slot_token]
            sample_order = per_slot_counter.get(slot_token, 0)
            per_slot_counter[slot_token] = sample_order + 1
            date_token = date_text.replace("-", "")
            window_id = f"{trace_type}_w{date_token}_{window_spec['slot_token']}"
            trace_path = source_dir / build_source_filename(
                trace_type=trace_type, date_text=date_text, slot_token=slot_token
            )
            # Prompt payloads live next to the trace with a "_prompt" suffix.
            prompt_path = Path(str(trace_path).replace(".jsonl", "_prompt.jsonl"))
            windows.append(
                {
                    "window_id": window_id,
                    "trace_type": trace_type,
                    "block_size": TRACE_SPECS[trace_type]["block_size"],
                    "window_start": float(window_spec["window_start"]),
                    "window_end": float(window_spec["window_end"]),
                    # NOTE(review): hard-coded index — presumably the seventh
                    # 10-minute slice of the 2-hour source file (matching the
                    # 3600s offset in WINDOW_SPECS); confirm with consumers.
                    "window_index": 6,
                    "sample_order": sample_order,
                    "date": date_text,
                    "slot_token": window_spec["slot_token"],
                    "slot_label": window_spec["slot_label"],
                    "source_trace_path": str(trace_path),
                    "source_prompt_path": str(prompt_path),
                }
            )
    return windows
|
|
|
|
|
|
def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
|
|
if trace_row.get("chat_id") != prompt_row.get("chat_id") or trace_row.get("turn") != prompt_row.get("turn"):
|
|
raise ValueError(
|
|
"trace/prompt rows are misaligned: "
|
|
f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
|
|
f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
|
|
)
|
|
merged = dict(trace_row)
|
|
prompt = prompt_row.get("prompt")
|
|
if isinstance(prompt, str) and prompt:
|
|
merged["prompt"] = prompt
|
|
return merged
|
|
|
|
|
|
@dataclass
class WindowStats:
    """Running aggregates for one materialized window."""

    # Field order is part of the dataclass init signature; do not reorder.
    num_requests: int = 0
    sum_input_length: int = 0
    max_input_length: int = 0
    first_request_ts: float | None = None
    last_request_ts: float | None = None
    first_request_index: int | None = None
    last_request_index: int | None = None

    def record(self, row: dict[str, Any]) -> None:
        """Fold one output row into the aggregates.

        Requires row["timestamp"]; a missing/None input_length counts as 0.
        """
        length = int(row.get("input_length") or 0)
        ts = float(row["timestamp"])
        if not self.num_requests:
            # First row seen in this window.
            self.first_request_ts = ts
            self.first_request_index = 0
        self.last_request_ts = ts
        self.last_request_index = self.num_requests
        self.num_requests += 1
        self.sum_input_length += length
        if length > self.max_input_length:
            self.max_input_length = length
|
|
|
|
|
|
def materialize_windows(
    windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
) -> dict[str, WindowStats]:
    """Stream each source trace/prompt pair once and write per-window JSONL files.

    Output files are written to temp paths and renamed into place only if the
    whole run succeeds, so readers never observe partial files. Returns the
    per-window aggregate stats keyed by window_id.
    """
    # Windows that share the same (trace, prompt) file pair are handled in a
    # single pass over that pair.
    grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
    for window in windows:
        trace_path = Path(window["source_trace_path"])
        prompt_path = Path(window["source_prompt_path"])
        grouped.setdefault((trace_path, prompt_path), []).append(window)

    stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
    handles: dict[str, Any] = {}          # window_id -> open temp file handle
    final_paths: dict[str, Path] = {}     # window_id -> published path
    temp_paths: dict[str, Path] = {}      # window_id -> in-progress temp path
    completed = False
    try:
        # Open one pid-suffixed temp output per window; a stale temp from a
        # previous run is removed first.
        for window in windows:
            window_id = str(window["window_id"])
            final_path = traces_dir / f"{window_id}.jsonl"
            temp_path = traces_dir / f".{window_id}.jsonl.tmp.{os.getpid()}"
            if temp_path.exists():
                temp_path.unlink()
            final_paths[window_id] = final_path
            temp_paths[window_id] = temp_path
            handles[window_id] = temp_path.open("w", encoding="utf-8")

        for trace_path, prompt_path in sorted(grouped.keys()):
            bucket = grouped[(trace_path, prompt_path)]
            # Deterministic matching order for rows that could fall into
            # multiple windows: earliest start first, then window_id.
            bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
            matched_rows = 0
            # NOTE(review): the trace and prompt files are assumed line-aligned
            # (row i of each describes the same request) — zip also silently
            # truncates to the shorter file; confirm upstream formatting
            # guarantees both.
            with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
                for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
                    trace_raw = trace_raw.strip()
                    if not trace_raw:
                        continue
                    trace_row = json.loads(trace_raw)
                    # Missing/None timestamp is treated as 0.0 (falls outside
                    # every window here, so such rows are dropped).
                    timestamp = float(trace_row.get("timestamp") or 0.0)
                    # First window whose [start, end) span contains the row.
                    matched_window: dict[str, Any] | None = None
                    for window in bucket:
                        start = float(window["window_start"])
                        end = float(window["window_end"])
                        if start <= timestamp < end:
                            matched_window = window
                            break
                    if matched_window is None:
                        continue
                    # Prompt side is only parsed for rows that matched a window.
                    prompt_raw = prompt_raw.strip()
                    if not prompt_raw:
                        continue
                    prompt_row = json.loads(prompt_raw)
                    merged = _merge_trace_and_prompt(trace_row, prompt_row)
                    window_id = str(matched_window["window_id"])
                    out = dict(merged)
                    start = float(matched_window["window_start"])
                    # Preserve the raw timestamp and rebase the working one to
                    # the window's own origin.
                    out["source_timestamp"] = timestamp
                    out["timestamp"] = timestamp - start
                    # index = rows already recorded for this window, so the
                    # sampling score is reproducible per (window, position).
                    out["sampling_u"] = stable_uniform(
                        seed=sample_seed,
                        window_id=window_id,
                        index=stats_by_window[window_id].num_requests,
                        row=merged,
                    )
                    handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
                    stats_by_window[window_id].record(out)
                    matched_rows += 1
            print(
                f"materialized {trace_path.name} -> matched_rows={matched_rows}",
                flush=True,
            )
        completed = True
    finally:
        for handle in handles.values():
            handle.close()
        if completed:
            # Publish atomically: rename every temp file over its final name.
            for window_id, temp_path in temp_paths.items():
                os.replace(temp_path, final_paths[window_id])
        else:
            # Failure path: remove partial temp files, leave finals untouched.
            for temp_path in temp_paths.values():
                if temp_path.exists():
                    temp_path.unlink()
    return stats_by_window
|
|
|
|
|
|
def build_output_window(
    window: dict[str, Any],
    stats: WindowStats,
    trace_relpath: str,
    *,
    sample_seed: int,
) -> dict[str, Any]:
    """Render the manifest entry for one materialized window.

    Timestamps in the written trace are rebased to the window origin, so the
    manifest reports window_start=0.0 and window_end=duration. The input
    `window` dict is not mutated.
    """
    duration = float(window["window_end"]) - float(window["window_start"])
    entry = dict(window)
    entry.update(
        trace_file=trace_relpath,
        window_start=0.0,
        window_end=duration,
        num_requests=stats.num_requests,
        sum_input_length=stats.sum_input_length,
        max_input_length=stats.max_input_length,
        num_excluded_too_long=0,
        sampling_u_field="sampling_u",
        sampling_seed=int(sample_seed),
        sampling_strategy="fixed_uniform_score",
        first_request_ts=stats.first_request_ts,
        last_request_ts=stats.last_request_ts,
        first_request_index=stats.first_request_index,
        last_request_index=stats.last_request_index,
    )
    return entry
|
|
|
|
|
|
def main() -> int:
    """Entry point: build window descriptors, materialize traces, write manifest.

    Returns 0 on success; raises SystemExit if the output root already exists
    and --overwrite was not passed.
    """
    args = parse_args()
    workloads = parse_workloads(args.workloads)
    output_root = args.output_root.resolve()
    if output_root.exists():
        # Refuse to touch an existing output tree unless the caller opted in.
        # NOTE(review): with --overwrite the old tree is written over in place,
        # not removed first — stale files from a previous run may remain.
        if not args.overwrite:
            raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
    output_root.mkdir(parents=True, exist_ok=True)
    traces_dir = output_root / "traces"
    traces_dir.mkdir(parents=True, exist_ok=True)

    # Keys here must match the "source_key" values used in TRACE_SPECS.
    source_dirs = {
        "legacy": args.legacy_source.resolve(),
        "thinking": args.thinking_source.resolve(),
    }
    windows = build_windows(source_dirs, workloads=workloads)
    stats_by_window = materialize_windows(
        windows,
        sample_seed=args.sample_seed,
        traces_dir=traces_dir,
    )

    # One manifest entry per window, pointing at its repo-relative trace file.
    rendered_windows: list[dict[str, Any]] = []
    for window in windows:
        trace_filename = f"{window['window_id']}.jsonl"
        rendered_windows.append(
            build_output_window(
                window,
                stats_by_window[str(window["window_id"])],
                trace_relpath=f"traces/{trace_filename}",
                sample_seed=args.sample_seed,
            )
        )

    windows_payload = {
        "sample_seed": args.sample_seed,
        "source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
        "u_field": "sampling_u",
        "window_duration_seconds": 600.0,
        "windows": rendered_windows,
    }
    # Write the manifest atomically: full content to a pid-suffixed temp file,
    # then os.replace over the final name.
    windows_path = output_root / "windows.json"
    windows_tmp_path = output_root / f".windows.json.tmp.{os.getpid()}"
    try:
        windows_tmp_path.write_text(
            json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )
        os.replace(windows_tmp_path, windows_path)
    finally:
        # Clean up the temp file if the replace did not happen.
        if windows_tmp_path.exists():
            windows_tmp_path.unlink()
    print(output_root)
    print(f"windows={len(rendered_windows)}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s integer return value to the shell as the exit code.
    raise SystemExit(main())
|