Files
aituner/scripts/prepare_trace_windows.py
Gahow Wang 0f15bbc3f1 Make the offered-load axis session-coherent
Phase 1 of the two-stop work. Subsampling the trace by per-request uniform score
broke multi-turn sessions (a kept turn-2 could lose its turn-1), which lowered the
realized KV-cache hit rate as offered load dropped — so the feasibility boundary
was measured on a workload with a different C than production, contradicting the
paper's scale-stationary L-C-A premise.

prepare_trace_windows now resolves each row's session root via the parent_chat_id
chain in a single streaming pass and assigns sampling_u per session, so thresholding
keeps or drops whole sessions and preserves intra-session prefix reuse. Rows whose
parent fell outside the span fall back to grouping under the parent id.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 14:16:06 +08:00

407 lines
15 KiB
Python

from __future__ import annotations
import argparse
import hashlib
import json
import os
from pathlib import Path
from typing import Any
from dataclasses import dataclass
# Directory that contains this script's own directory (the second ancestor
# of the resolved script path).
REPO_ROOT = Path(__file__).resolve().parents[1]

# Default locations of the raw formatted trace directories; each can be
# overridden on the command line.
DEFAULT_LEGACY_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260311-260317-formatted"
)
DEFAULT_THINKING_SOURCE = Path(
    "/home/admin/cpfs/wjh/bailian-trace/qwen-trace-260321-260327-formatted"
)

# Default output directory, placed under REPO_ROOT.
DEFAULT_OUTPUT_ROOT = REPO_ROOT.joinpath("trace_windows")

# Dates materialized for the workloads that read from the legacy source.
LEGACY_TARGET_DATES = [
    "2026-03-11", "2026-03-12", "2026-03-13", "2026-03-14",
    "2026-03-15", "2026-03-16", "2026-03-17",
]

# (date, slot token) pairs for the thinking workload: every date is paired
# with the "1000" slot only.
THINKING_WINDOWS = [
    (date_text, "1000")
    for date_text in (
        "2026-03-21",
        "2026-03-22",
        "2026-03-23",
        "2026-03-24",
        "2026-03-25",
        "2026-03-26",
        "2026-03-27",
    )
]
# Per-slot window description, keyed by slot token. "start_hour" is the hour
# used to build the source file name; "window_start"/"window_end" are offsets
# in seconds that row timestamps are compared against, and both slots use the
# same 3600.0-4200.0 span.
WINDOW_SPECS = {
    slot_token: {
        "start_hour": start_hour,
        "window_start": 3600.0,
        "window_end": 4200.0,
        "slot_label": slot_label,
        "slot_token": slot_token,
    }
    for slot_token, start_hour, slot_label in (
        ("1000", 9, "10:00-10:10"),
        ("2200", 21, "22:00-22:10"),
    )
}
# Per-workload description. "chat" and "coder" read from the legacy source
# and are expanded over every target date and both slot tokens; they differ
# only in block size. "thinking" lists its (date, slot) pairs explicitly.
TRACE_SPECS = {
    **{
        trace_type: {
            "block_size": block_size,
            "source_key": "legacy",
            "target_dates": LEGACY_TARGET_DATES,
            "slot_tokens": ("1000", "2200"),
        }
        for trace_type, block_size in (("chat", 64), ("coder", 512))
    },
    "thinking": {
        "block_size": 64,
        "source_key": "thinking",
        "date_slot_tokens": THINKING_WINDOWS,
    },
}
def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options.

    Returns:
        The populated namespace with ``legacy_source``, ``thinking_source``
        and ``output_root`` (paths), ``workloads`` (comma-separated string),
        ``sample_seed`` (int) and ``overwrite`` (bool).
    """
    cli = argparse.ArgumentParser(
        description=(
            "Prepare canonical trace windows under the current aituner repo by reading the "
            "raw formatted Bailian trace directories and merging prompt payloads."
        )
    )

    # The three path-valued options share the same shape; only the default differs.
    path_options = (
        ("--legacy-source", DEFAULT_LEGACY_SOURCE),
        ("--thinking-source", DEFAULT_THINKING_SOURCE),
        ("--output-root", DEFAULT_OUTPUT_ROOT),
    )
    for flag, fallback in path_options:
        cli.add_argument(flag, type=Path, default=fallback)

    cli.add_argument(
        "--workloads",
        default="chat,coder,thinking",
        help="Comma-separated workloads to materialize. Defaults to chat,coder,thinking.",
    )
    cli.add_argument("--sample-seed", type=int, default=20260325)
    cli.add_argument("--overwrite", action="store_true")
    return cli.parse_args()
def resolve_session_root(row: dict[str, Any], root_of: dict[Any, Any]) -> Any:
    """Return the session root id for ``row`` and remember it in ``root_of``.

    A row starts a session when its ``parent_chat_id`` is missing/``None`` or
    is a negative number (booleans are not treated as numbers here); its root
    is then its own ``chat_id``. Any other row inherits the root already
    recorded for its parent. If the parent has not been seen, the parent id
    itself is used as the root, so rows pointing at the same unseen parent
    still share one root.

    Args:
        row: Trace row; ``chat_id`` and ``parent_chat_id`` are read from it.
        root_of: Running ``chat_id -> root`` map, updated in place. Rows must
            be fed parent-before-child for chains to resolve fully.

    Returns:
        The resolved session root (may be ``None`` if the row has neither a
        ``chat_id`` nor a usable parent).
    """
    own_id = row.get("chat_id")
    parent_id = row.get("parent_chat_id")

    if parent_id is None:
        starts_session = True
    elif isinstance(parent_id, bool) or not isinstance(parent_id, (int, float)):
        # Non-numeric ids (and bools) are always treated as real parent links.
        starts_session = False
    else:
        starts_session = int(parent_id) < 0

    if starts_session:
        session_root = own_id
    else:
        session_root = root_of.get(parent_id, parent_id)

    # Rows without an id cannot be referenced as a parent, so nothing to record.
    if own_id is not None:
        root_of[own_id] = session_root
    return session_root
def session_uniform(*, seed: int, window_id: str, session_root: Any) -> float:
    """Return a deterministic pseudo-uniform score for one session in one window.

    The score depends only on ``(seed, window_id, session_root)``, so every
    turn of a session gets the same value and a threshold on it keeps or
    drops the session as a whole.

    The three inputs are serialized as compact JSON with sorted keys, hashed
    with an 8-byte BLAKE2b digest, and the digest (read as a big-endian
    unsigned integer) is divided by 2**64 in floating point. The result is
    nominally in [0, 1).

    Args:
        seed: Sampling seed.
        window_id: Identifier of the window the row belongs to.
        session_root: Session root id; must be JSON-serializable.

    Returns:
        The score as a float.
    """
    key_fields = {
        "seed": seed,
        "window_id": window_id,
        "session_root": session_root,
    }
    canonical = json.dumps(key_fields, sort_keys=True, separators=(",", ":"))

    hasher = hashlib.blake2b(digest_size=8)
    hasher.update(canonical.encode("utf-8"))

    numerator = int.from_bytes(hasher.digest(), byteorder="big", signed=False)
    return numerator / float(1 << 64)
def build_source_filename(*, trace_type: str, date_text: str, slot_token: str) -> str:
    """Build the source trace file name for one workload, date and slot.

    The name carries the workload, its block size, and a two-hour span that
    starts at the slot's ``start_hour`` on the given day.

    Args:
        trace_type: Key into ``TRACE_SPECS``.
        date_text: Date as ``YYYY-MM-DD``; only the month and day characters
            (positions 5-6 and 8-9) are used.
        slot_token: Key into ``WINDOW_SPECS``.

    Returns:
        The ``.jsonl`` file name (no directory part).
    """
    month = date_text[5:7]
    day = date_text[8:10]
    start_hour = int(WINDOW_SPECS[slot_token]["start_hour"])
    filename = (
        f"qwen_{trace_type}_blksz_{TRACE_SPECS[trace_type]['block_size']}_"
        f"{month}{day}{start_hour:02d}-{month}{day}{start_hour + 2:02d}.jsonl"
    )
    return filename
def parse_workloads(text: str) -> tuple[str, ...]:
    """Split a comma-separated workload list and validate every entry.

    Whitespace around entries is stripped and empty entries are ignored.

    Args:
        text: Value of the ``--workloads`` option.

    Returns:
        The workload names in the order given.

    Raises:
        ValueError: If no workload remains after stripping, or if any name is
            not a key of ``TRACE_SPECS``.
    """
    names: list[str] = []
    for piece in str(text).split(","):
        cleaned = piece.strip()
        if cleaned:
            names.append(cleaned)

    if not names:
        raise ValueError("--workloads must contain at least one workload")

    unknown = [name for name in names if name not in TRACE_SPECS]
    if unknown:
        raise ValueError(f"Unknown workloads: {unknown}")
    return tuple(names)
def build_windows(
    source_dirs: dict[str, Path], *, workloads: tuple[str, ...]
) -> list[dict[str, Any]]:
    """Describe every window to materialize for the requested workloads.

    For each workload the (date, slot) pairs come either from the spec's
    explicit ``date_slot_tokens`` or from the product of its ``slot_tokens``
    and ``target_dates`` (all dates of the first slot, then all dates of the
    next slot). Each pair becomes one window record carrying its id, slot
    metadata and the paths of the source trace file and its ``_prompt``
    companion.

    Args:
        source_dirs: Source directory per ``source_key`` of ``TRACE_SPECS``.
        workloads: Workload names, each a key of ``TRACE_SPECS``.

    Returns:
        One dict per window, in workload order then (date, slot) order.

    Raises:
        KeyError: If a workload, slot token or source key is unknown.
    """
    windows: list[dict[str, Any]] = []
    for trace_type in workloads:
        spec = TRACE_SPECS[trace_type]
        source_dir = source_dirs[str(spec["source_key"])]
        if "date_slot_tokens" in spec:
            date_slot_tokens = list(spec["date_slot_tokens"])
        else:
            date_slot_tokens = [
                (date_text, slot_token)
                for slot_token in spec["slot_tokens"]
                for date_text in spec["target_dates"]
            ]

        # sample_order counts windows per slot token, restarting for each workload.
        sample_order_by_group: dict[str, int] = {}
        for date_text, slot_token in date_slot_tokens:
            window_spec = WINDOW_SPECS[slot_token]
            date_token = date_text.replace("-", "")
            sample_order = sample_order_by_group.get(slot_token, 0)
            sample_order_by_group[slot_token] = sample_order + 1
            window_id = f"{trace_type}_w{date_token}_{window_spec['slot_token']}"
            source_trace_path = source_dir / build_source_filename(
                trace_type=trace_type, date_text=date_text, slot_token=slot_token
            )
            # Rewrite only the file name. A plain str.replace over the whole
            # path would also rewrite any directory component that happens to
            # contain ".jsonl".
            source_prompt_path = source_trace_path.with_name(
                f"{source_trace_path.stem}_prompt.jsonl"
            )
            windows.append(
                {
                    "window_id": window_id,
                    "trace_type": trace_type,
                    "block_size": TRACE_SPECS[trace_type]["block_size"],
                    "window_start": float(window_spec["window_start"]),
                    "window_end": float(window_spec["window_end"]),
                    # NOTE(review): fixed at 6 for every window; presumably the
                    # index of the 600 s window starting at 3600 s -- confirm.
                    "window_index": 6,
                    "sample_order": sample_order,
                    "date": date_text,
                    "slot_token": window_spec["slot_token"],
                    "slot_label": window_spec["slot_label"],
                    "source_trace_path": str(source_trace_path),
                    "source_prompt_path": str(source_prompt_path),
                }
            )
    return windows
def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``trace_row`` with the prompt text attached.

    The two rows must describe the same request: their ``chat_id`` and
    ``turn`` values have to match. The ``prompt`` key is set only when the
    prompt row carries a non-empty string; otherwise the copy is returned
    unchanged. Neither input is modified.

    Raises:
        ValueError: If ``chat_id`` or ``turn`` differ between the two rows.
    """
    same_chat = trace_row.get("chat_id") == prompt_row.get("chat_id")
    same_turn = trace_row.get("turn") == prompt_row.get("turn")
    if not (same_chat and same_turn):
        raise ValueError(
            "trace/prompt rows are misaligned: "
            f"{trace_row.get('chat_id')}/{trace_row.get('turn')} vs "
            f"{prompt_row.get('chat_id')}/{prompt_row.get('turn')}"
        )

    prompt = prompt_row.get("prompt")
    if isinstance(prompt, str) and prompt:
        return {**trace_row, "prompt": prompt}
    return dict(trace_row)
@dataclass
class WindowStats:
    """Running summary of the rows written for one window."""

    # Number of rows recorded so far.
    num_requests: int = 0
    # Sum and maximum of the rows' ``input_length`` values.
    sum_input_length: int = 0
    max_input_length: int = 0
    # Timestamp of the first / most recent recorded row (None until a row is seen).
    first_request_ts: float | None = None
    last_request_ts: float | None = None
    # Zero-based position of the first / most recent recorded row.
    first_request_index: int | None = None
    last_request_index: int | None = None

    def record(self, row: dict[str, Any]) -> None:
        """Fold one row into the summary.

        Args:
            row: Must contain ``timestamp``; ``input_length`` is optional and
                a missing or falsy value counts as 0.

        Raises:
            KeyError: If ``timestamp`` is missing.
        """
        length = int(row.get("input_length") or 0)
        ts = float(row["timestamp"])

        # The row's position is the count of rows recorded before it.
        position = self.num_requests
        if position == 0:
            self.first_request_ts = ts
            self.first_request_index = 0
        self.last_request_ts = ts
        self.last_request_index = position

        self.num_requests = position + 1
        self.sum_input_length += length
        if length > self.max_input_length:
            self.max_input_length = length
def materialize_windows(
    windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
) -> dict[str, WindowStats]:
    """Stream the source traces and write one JSONL file per window.

    Windows that share the same (trace, prompt) source pair are served by a
    single pass over those two files, which are read in lockstep line by
    line. A row whose timestamp falls in ``[window_start, window_end)`` of a
    window is merged with its prompt row and written to that window's file
    with these changes: ``timestamp`` is rebased so the window starts at 0,
    and ``source_timestamp``, ``session_root`` and ``sampling_u`` are added.

    Output goes to per-window temporary files under ``traces_dir``. They are
    renamed to ``<window_id>.jsonl`` only when the whole pass finished without
    an exception; otherwise they are removed.

    Args:
        windows: Window records with at least ``window_id``, ``window_start``,
            ``window_end``, ``source_trace_path`` and ``source_prompt_path``.
        sample_seed: Seed forwarded to ``session_uniform``.
        traces_dir: Existing directory that receives the output files.

    Returns:
        ``window_id -> WindowStats`` for every window, including empty ones.

    Raises:
        ValueError: If a trace row and its prompt row do not match (raised by
            ``_merge_trace_and_prompt``).
        OSError: If a source file cannot be opened or an output cannot be written.
    """
    grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
    for window in windows:
        trace_path = Path(window["source_trace_path"])
        prompt_path = Path(window["source_prompt_path"])
        grouped.setdefault((trace_path, prompt_path), []).append(window)

    stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
    handles: dict[str, Any] = {}
    final_paths: dict[str, Path] = {}
    temp_paths: dict[str, Path] = {}
    completed = False
    try:
        for window in windows:
            window_id = str(window["window_id"])
            final_path = traces_dir / f"{window_id}.jsonl"
            temp_path = traces_dir / f".{window_id}.jsonl.tmp.{os.getpid()}"
            if temp_path.exists():
                temp_path.unlink()
            final_paths[window_id] = final_path
            temp_paths[window_id] = temp_path
            handles[window_id] = temp_path.open("w", encoding="utf-8")

        for trace_path, prompt_path in sorted(grouped.keys()):
            bucket = grouped[(trace_path, prompt_path)]
            bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
            matched_rows = 0
            # Session map is per source file: chains are resolved within one file.
            root_of: dict[Any, Any] = {}
            # Read the inputs as UTF-8 explicitly. The output is written as
            # UTF-8 with ensure_ascii=False, so decoding with the locale's
            # default encoding could fail or corrupt non-ASCII prompt text.
            with trace_path.open(encoding="utf-8") as trace_handle, prompt_path.open(
                encoding="utf-8"
            ) as prompt_handle:
                # zip pairs lines by position and stops at the shorter file.
                for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
                    trace_raw = trace_raw.strip()
                    if not trace_raw:
                        continue
                    trace_row = json.loads(trace_raw)
                    # Resolve session linkage for every row (even unmatched ones)
                    # so multi-turn chains crossing the window edge still group.
                    session_root = resolve_session_root(trace_row, root_of)
                    timestamp = float(trace_row.get("timestamp") or 0.0)
                    matched_window: dict[str, Any] | None = None
                    for window in bucket:
                        start = float(window["window_start"])
                        end = float(window["window_end"])
                        if start <= timestamp < end:
                            matched_window = window
                            break
                    if matched_window is None:
                        continue
                    prompt_raw = prompt_raw.strip()
                    if not prompt_raw:
                        # A matched trace row with a blank prompt line is dropped.
                        continue
                    prompt_row = json.loads(prompt_raw)
                    merged = _merge_trace_and_prompt(trace_row, prompt_row)
                    window_id = str(matched_window["window_id"])
                    out = dict(merged)
                    start = float(matched_window["window_start"])
                    out["source_timestamp"] = timestamp
                    out["timestamp"] = timestamp - start
                    out["session_root"] = session_root
                    out["sampling_u"] = session_uniform(
                        seed=sample_seed,
                        window_id=window_id,
                        session_root=session_root,
                    )
                    handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
                    stats_by_window[window_id].record(out)
                    matched_rows += 1
            print(
                f"materialized {trace_path.name} -> matched_rows={matched_rows}",
                flush=True,
            )
        completed = True
    finally:
        for handle in handles.values():
            handle.close()
        if completed:
            # Publish all outputs only after every source file was processed.
            for window_id, temp_path in temp_paths.items():
                os.replace(temp_path, final_paths[window_id])
        else:
            for temp_path in temp_paths.values():
                if temp_path.exists():
                    temp_path.unlink()
    return stats_by_window
def build_output_window(
    window: dict[str, Any],
    stats: WindowStats,
    trace_relpath: str,
    *,
    sample_seed: int,
) -> dict[str, Any]:
    """Build the manifest entry for one materialized window.

    The entry is a copy of ``window`` with the window bounds rebased to start
    at 0, the relative path of the written trace file, the collected
    statistics, and the sampling metadata. ``window`` itself is not modified.

    Args:
        window: Window record; ``window_start`` and ``window_end`` are read.
        stats: Statistics collected while writing the window's rows.
        trace_relpath: Path of the trace file to store in the entry.
        sample_seed: Seed that was used to compute ``sampling_u``.

    Returns:
        The manifest entry as a new dict.
    """
    duration = float(window["window_end"]) - float(window["window_start"])

    output = dict(window)
    output.update(
        {
            "trace_file": trace_relpath,
            "window_start": 0.0,
            "window_end": duration,
            "num_requests": stats.num_requests,
            "sum_input_length": stats.sum_input_length,
            "max_input_length": stats.max_input_length,
            "num_excluded_too_long": 0,
            "sampling_u_field": "sampling_u",
            "sampling_seed": int(sample_seed),
            "sampling_strategy": "session_coherent_uniform_score",
            "first_request_ts": stats.first_request_ts,
            "last_request_ts": stats.last_request_ts,
            "first_request_index": stats.first_request_index,
            "last_request_index": stats.last_request_index,
        }
    )
    return output
def main() -> int:
    """Run the script: materialize all windows and write ``windows.json``.

    Steps: parse the options, refuse to reuse an existing output directory
    unless ``--overwrite`` was given, create the output and ``traces``
    directories, materialize every window, then write the manifest through a
    temporary file that is renamed into place.

    Returns:
        0 on success.

    Raises:
        SystemExit: If the output directory exists and ``--overwrite`` is absent.
    """
    args = parse_args()
    workloads = parse_workloads(args.workloads)

    output_root = args.output_root.resolve()
    if output_root.exists() and not args.overwrite:
        raise SystemExit(f"Output root exists: {output_root} (pass --overwrite to replace)")
    output_root.mkdir(parents=True, exist_ok=True)
    traces_dir = output_root / "traces"
    traces_dir.mkdir(parents=True, exist_ok=True)

    source_dirs = {
        "legacy": args.legacy_source.resolve(),
        "thinking": args.thinking_source.resolve(),
    }
    windows = build_windows(source_dirs, workloads=workloads)
    stats_by_window = materialize_windows(
        windows,
        sample_seed=args.sample_seed,
        traces_dir=traces_dir,
    )

    rendered_windows: list[dict[str, Any]] = []
    for window in windows:
        window_stats = stats_by_window[str(window["window_id"])]
        trace_filename = f"{window['window_id']}.jsonl"
        entry = build_output_window(
            window,
            window_stats,
            trace_relpath=f"traces/{trace_filename}",
            sample_seed=args.sample_seed,
        )
        rendered_windows.append(entry)

    windows_payload = {
        "sample_seed": args.sample_seed,
        "source_trace_dirs": {key: str(path) for key, path in source_dirs.items()},
        "u_field": "sampling_u",
        "window_duration_seconds": 600.0,
        "windows": rendered_windows,
    }
    manifest_text = json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n"

    windows_path = output_root / "windows.json"
    windows_tmp_path = output_root / f".windows.json.tmp.{os.getpid()}"
    try:
        windows_tmp_path.write_text(manifest_text, encoding="utf-8")
        os.replace(windows_tmp_path, windows_path)
    finally:
        # Only present here if the write or the rename failed.
        if windows_tmp_path.exists():
            windows_tmp_path.unlink()

    print(output_root)
    print(f"windows={len(rendered_windows)}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())