Stream trace window materialization

This commit is contained in:
2026-04-04 21:49:03 +08:00
parent 69f666593e
commit 4e1401f50c
2 changed files with 191 additions and 59 deletions

View File

@@ -5,6 +5,7 @@ import hashlib
import json import json
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from dataclasses import dataclass
REPO_ROOT = Path(__file__).resolve().parents[1] REPO_ROOT = Path(__file__).resolve().parents[1]
@@ -186,21 +187,53 @@ def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any
return merged return merged
def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]: @dataclass
class WindowStats:
num_requests: int = 0
sum_input_length: int = 0
max_input_length: int = 0
first_request_ts: float | None = None
last_request_ts: float | None = None
first_request_index: int | None = None
last_request_index: int | None = None
def record(self, row: dict[str, Any]) -> None:
input_length = int(row.get("input_length") or 0)
timestamp = float(row["timestamp"])
if self.num_requests == 0:
self.first_request_ts = timestamp
self.first_request_index = 0
self.last_request_ts = timestamp
self.last_request_index = self.num_requests
self.num_requests += 1
self.sum_input_length += input_length
self.max_input_length = max(self.max_input_length, input_length)
def materialize_windows(
windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path
) -> dict[str, WindowStats]:
grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {} grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {}
for window in windows: for window in windows:
trace_path = Path(window["source_trace_path"]) trace_path = Path(window["source_trace_path"])
prompt_path = Path(window["source_prompt_path"]) prompt_path = Path(window["source_prompt_path"])
grouped.setdefault((trace_path, prompt_path), []).append(window) grouped.setdefault((trace_path, prompt_path), []).append(window)
extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows} stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
for (trace_path, prompt_path), bucket in grouped.items(): handles: dict[str, Any] = {}
try:
for window in windows:
window_id = str(window["window_id"])
handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8")
for trace_path, prompt_path in sorted(grouped.keys()):
bucket = grouped[(trace_path, prompt_path)]
bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"]))) bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"])))
matched_rows = 0
with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle: with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle:
for index, (trace_raw, prompt_raw) in enumerate(zip(trace_handle, prompt_handle)): for trace_raw, prompt_raw in zip(trace_handle, prompt_handle):
trace_raw = trace_raw.strip() trace_raw = trace_raw.strip()
prompt_raw = prompt_raw.strip() if not trace_raw:
if not trace_raw or not prompt_raw:
continue continue
trace_row = json.loads(trace_raw) trace_row = json.loads(trace_raw)
timestamp = float(trace_row.get("timestamp") or 0.0) timestamp = float(trace_row.get("timestamp") or 0.0)
@@ -213,25 +246,38 @@ def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[
break break
if matched_window is None: if matched_window is None:
continue continue
prompt_raw = prompt_raw.strip()
if not prompt_raw:
continue
prompt_row = json.loads(prompt_raw) prompt_row = json.loads(prompt_raw)
merged = _merge_trace_and_prompt(trace_row, prompt_row) merged = _merge_trace_and_prompt(trace_row, prompt_row)
window_id = str(matched_window["window_id"])
out = dict(merged) out = dict(merged)
start = float(matched_window["window_start"]) start = float(matched_window["window_start"])
out["source_timestamp"] = timestamp out["source_timestamp"] = timestamp
out["timestamp"] = timestamp - start out["timestamp"] = timestamp - start
out["sampling_u"] = stable_uniform( out["sampling_u"] = stable_uniform(
seed=sample_seed, seed=sample_seed,
window_id=str(matched_window["window_id"]), window_id=window_id,
index=len(extracted[str(matched_window["window_id"])]), index=stats_by_window[window_id].num_requests,
row=merged, row=merged,
) )
extracted[str(matched_window["window_id"])].append(out) handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n")
return extracted stats_by_window[window_id].record(out)
matched_rows += 1
print(
f"materialized {trace_path.name} -> matched_rows={matched_rows}",
flush=True,
)
finally:
for handle in handles.values():
handle.close()
return stats_by_window
def build_output_window( def build_output_window(
window: dict[str, Any], window: dict[str, Any],
rows: list[dict[str, Any]], stats: WindowStats,
trace_relpath: str, trace_relpath: str,
*, *,
sample_seed: int, sample_seed: int,
@@ -240,25 +286,17 @@ def build_output_window(
output["trace_file"] = trace_relpath output["trace_file"] = trace_relpath
output["window_start"] = 0.0 output["window_start"] = 0.0
output["window_end"] = float(window["window_end"]) - float(window["window_start"]) output["window_end"] = float(window["window_end"]) - float(window["window_start"])
output["num_requests"] = len(rows) output["num_requests"] = stats.num_requests
output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows)) output["sum_input_length"] = stats.sum_input_length
output["max_input_length"] = int( output["max_input_length"] = stats.max_input_length
max((int(row.get("input_length") or 0) for row in rows), default=0)
)
output["num_excluded_too_long"] = 0 output["num_excluded_too_long"] = 0
output["sampling_u_field"] = "sampling_u" output["sampling_u_field"] = "sampling_u"
output["sampling_seed"] = int(sample_seed) output["sampling_seed"] = int(sample_seed)
output["sampling_strategy"] = "fixed_uniform_score" output["sampling_strategy"] = "fixed_uniform_score"
if rows: output["first_request_ts"] = stats.first_request_ts
output["first_request_ts"] = float(rows[0]["timestamp"]) output["last_request_ts"] = stats.last_request_ts
output["last_request_ts"] = float(rows[-1]["timestamp"]) output["first_request_index"] = stats.first_request_index
output["first_request_index"] = 0 output["last_request_index"] = stats.last_request_index
output["last_request_index"] = len(rows) - 1
else:
output["first_request_ts"] = None
output["last_request_ts"] = None
output["first_request_index"] = None
output["last_request_index"] = None
return output return output
@@ -278,20 +316,19 @@ def main() -> int:
"thinking": args.thinking_source.resolve(), "thinking": args.thinking_source.resolve(),
} }
windows = build_windows(source_dirs, workloads=workloads) windows = build_windows(source_dirs, workloads=workloads)
extracted = extract_windows(windows, sample_seed=args.sample_seed) stats_by_window = materialize_windows(
windows,
sample_seed=args.sample_seed,
traces_dir=traces_dir,
)
rendered_windows: list[dict[str, Any]] = [] rendered_windows: list[dict[str, Any]] = []
for window in windows: for window in windows:
rows = extracted[str(window["window_id"])]
trace_filename = f"{window['window_id']}.jsonl" trace_filename = f"{window['window_id']}.jsonl"
trace_path = traces_dir / trace_filename
with trace_path.open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row, ensure_ascii=False) + "\n")
rendered_windows.append( rendered_windows.append(
build_output_window( build_output_window(
window, window,
rows, stats_by_window[str(window["window_id"])],
trace_relpath=f"traces/{trace_filename}", trace_relpath=f"traces/{trace_filename}",
sample_seed=args.sample_seed, sample_seed=args.sample_seed,
) )

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import subprocess
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@@ -258,7 +259,101 @@ class CoreFlowTests(unittest.TestCase):
self.assertTrue(evaluations[0].passed) self.assertTrue(evaluations[0].passed)
self.assertFalse(evaluations[1].passed) self.assertFalse(evaluations[1].passed)
self.assertEqual(summary["slo_pass_rate"], 0.5) self.assertEqual(summary["slo_pass_rate"], 0.5)
self.assertFalse(summary["feasible"])
def test_prepare_trace_windows_materializes_repo_local_assets(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
legacy_source = tmp_path / "legacy"
thinking_source = tmp_path / "thinking"
legacy_source.mkdir()
thinking_source.mkdir()
for filename in [
"qwen_chat_blksz_64_031109-031111",
"qwen_chat_blksz_64_031121-031123",
"qwen_chat_blksz_64_031209-031211",
"qwen_chat_blksz_64_031221-031223",
"qwen_chat_blksz_64_031309-031311",
"qwen_chat_blksz_64_031321-031323",
"qwen_chat_blksz_64_031609-031611",
"qwen_chat_blksz_64_031621-031623",
"qwen_chat_blksz_64_031709-031711",
"qwen_chat_blksz_64_031721-031723",
]:
for suffix in [".jsonl", "_prompt.jsonl"]:
path = legacy_source / f"{filename}{suffix}"
path.write_text("", encoding="utf-8")
peak_trace = legacy_source / "qwen_chat_blksz_64_031109-031111.jsonl"
peak_prompt = legacy_source / "qwen_chat_blksz_64_031109-031111_prompt.jsonl"
peak_trace.write_text(
"\n".join(
[
json.dumps(
{
"chat_id": "c1",
"turn": 1,
"timestamp": 3599.0,
"input_length": 10,
"output_length": 3,
}
),
json.dumps(
{
"chat_id": "c2",
"turn": 2,
"timestamp": 3605.0,
"input_length": 20,
"output_length": 7,
}
),
]
)
+ "\n",
encoding="utf-8",
)
peak_prompt.write_text(
"\n".join(
[
json.dumps({"chat_id": "c1", "turn": 1, "prompt": "ignore me"}),
json.dumps({"chat_id": "c2", "turn": 2, "prompt": "real prompt"}),
]
)
+ "\n",
encoding="utf-8",
)
output_root = tmp_path / "trace_windows"
subprocess.run(
[
"python3",
"scripts/prepare_trace_windows.py",
"--legacy-source",
str(legacy_source),
"--thinking-source",
str(thinking_source),
"--output-root",
str(output_root),
"--workloads",
"chat",
"--overwrite",
],
check=True,
cwd="/home/gahow/phd/aituner",
)
windows_payload = json.loads((output_root / "windows.json").read_text(encoding="utf-8"))
windows = {item["window_id"]: item for item in windows_payload["windows"]}
self.assertIn("chat_w20260311_peak_1000", windows)
self.assertEqual(windows["chat_w20260311_peak_1000"]["num_requests"], 1)
trace_path = output_root / windows["chat_w20260311_peak_1000"]["trace_file"]
rows = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()]
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0]["prompt"], "real prompt")
self.assertEqual(rows[0]["timestamp"], 5.0)
self.assertEqual(rows[0]["output_length"], 7)
self.assertIsInstance(rows[0]["sampling_u"], float)
def test_binary_search_max_feasible(self) -> None: def test_binary_search_max_feasible(self) -> None:
result = binary_search_max_feasible( result = binary_search_max_feasible(