From 4e1401f50c67c02049799a090df89a6a7e12fd1f Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sat, 4 Apr 2026 21:49:03 +0800 Subject: [PATCH] Stream trace window materialization --- scripts/prepare_trace_windows.py | 153 +++++++++++++++++++------------ tests/test_core_flow.py | 97 +++++++++++++++++++- 2 files changed, 191 insertions(+), 59 deletions(-) diff --git a/scripts/prepare_trace_windows.py b/scripts/prepare_trace_windows.py index 09d2d48..ec49109 100644 --- a/scripts/prepare_trace_windows.py +++ b/scripts/prepare_trace_windows.py @@ -5,6 +5,7 @@ import hashlib import json from pathlib import Path from typing import Any +from dataclasses import dataclass REPO_ROOT = Path(__file__).resolve().parents[1] @@ -186,52 +187,97 @@ def _merge_trace_and_prompt(trace_row: dict[str, Any], prompt_row: dict[str, Any return merged -def extract_windows(windows: list[dict[str, Any]], *, sample_seed: int) -> dict[str, list[dict[str, Any]]]: +@dataclass +class WindowStats: + num_requests: int = 0 + sum_input_length: int = 0 + max_input_length: int = 0 + first_request_ts: float | None = None + last_request_ts: float | None = None + first_request_index: int | None = None + last_request_index: int | None = None + + def record(self, row: dict[str, Any]) -> None: + input_length = int(row.get("input_length") or 0) + timestamp = float(row["timestamp"]) + if self.num_requests == 0: + self.first_request_ts = timestamp + self.first_request_index = 0 + self.last_request_ts = timestamp + self.last_request_index = self.num_requests + self.num_requests += 1 + self.sum_input_length += input_length + self.max_input_length = max(self.max_input_length, input_length) + + +def materialize_windows( + windows: list[dict[str, Any]], *, sample_seed: int, traces_dir: Path +) -> dict[str, WindowStats]: grouped: dict[tuple[Path, Path], list[dict[str, Any]]] = {} for window in windows: trace_path = Path(window["source_trace_path"]) prompt_path = Path(window["source_prompt_path"]) 
grouped.setdefault((trace_path, prompt_path), []).append(window) - extracted: dict[str, list[dict[str, Any]]] = {str(window["window_id"]): [] for window in windows} - for (trace_path, prompt_path), bucket in grouped.items(): - bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"]))) - with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle: - for index, (trace_raw, prompt_raw) in enumerate(zip(trace_handle, prompt_handle)): - trace_raw = trace_raw.strip() - prompt_raw = prompt_raw.strip() - if not trace_raw or not prompt_raw: - continue - trace_row = json.loads(trace_raw) - timestamp = float(trace_row.get("timestamp") or 0.0) - matched_window: dict[str, Any] | None = None - for window in bucket: - start = float(window["window_start"]) - end = float(window["window_end"]) - if start <= timestamp < end: - matched_window = window - break - if matched_window is None: - continue - prompt_row = json.loads(prompt_raw) - merged = _merge_trace_and_prompt(trace_row, prompt_row) - out = dict(merged) - start = float(matched_window["window_start"]) - out["source_timestamp"] = timestamp - out["timestamp"] = timestamp - start - out["sampling_u"] = stable_uniform( - seed=sample_seed, - window_id=str(matched_window["window_id"]), - index=len(extracted[str(matched_window["window_id"])]), - row=merged, - ) - extracted[str(matched_window["window_id"])].append(out) - return extracted + stats_by_window = {str(window["window_id"]): WindowStats() for window in windows} + handles: dict[str, Any] = {} + try: + for window in windows: + window_id = str(window["window_id"]) + handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8") + + for trace_path, prompt_path in sorted(grouped.keys()): + bucket = grouped[(trace_path, prompt_path)] + bucket.sort(key=lambda item: (float(item["window_start"]), str(item["window_id"]))) + matched_rows = 0 + with trace_path.open() as trace_handle, prompt_path.open() as prompt_handle: + for 
trace_raw, prompt_raw in zip(trace_handle, prompt_handle): + trace_raw = trace_raw.strip() + if not trace_raw: + continue + trace_row = json.loads(trace_raw) + timestamp = float(trace_row.get("timestamp") or 0.0) + matched_window: dict[str, Any] | None = None + for window in bucket: + start = float(window["window_start"]) + end = float(window["window_end"]) + if start <= timestamp < end: + matched_window = window + break + if matched_window is None: + continue + prompt_raw = prompt_raw.strip() + if not prompt_raw: + continue + prompt_row = json.loads(prompt_raw) + merged = _merge_trace_and_prompt(trace_row, prompt_row) + window_id = str(matched_window["window_id"]) + out = dict(merged) + start = float(matched_window["window_start"]) + out["source_timestamp"] = timestamp + out["timestamp"] = timestamp - start + out["sampling_u"] = stable_uniform( + seed=sample_seed, + window_id=window_id, + index=stats_by_window[window_id].num_requests, + row=merged, + ) + handles[window_id].write(json.dumps(out, ensure_ascii=False) + "\n") + stats_by_window[window_id].record(out) + matched_rows += 1 + print( + f"materialized {trace_path.name} -> matched_rows={matched_rows}", + flush=True, + ) + finally: + for handle in handles.values(): + handle.close() + return stats_by_window def build_output_window( window: dict[str, Any], - rows: list[dict[str, Any]], + stats: WindowStats, trace_relpath: str, *, sample_seed: int, @@ -240,25 +286,17 @@ def build_output_window( output["trace_file"] = trace_relpath output["window_start"] = 0.0 output["window_end"] = float(window["window_end"]) - float(window["window_start"]) - output["num_requests"] = len(rows) - output["sum_input_length"] = int(sum(int(row.get("input_length") or 0) for row in rows)) - output["max_input_length"] = int( - max((int(row.get("input_length") or 0) for row in rows), default=0) - ) + output["num_requests"] = stats.num_requests + output["sum_input_length"] = stats.sum_input_length + output["max_input_length"] = 
stats.max_input_length output["num_excluded_too_long"] = 0 output["sampling_u_field"] = "sampling_u" output["sampling_seed"] = int(sample_seed) output["sampling_strategy"] = "fixed_uniform_score" - if rows: - output["first_request_ts"] = float(rows[0]["timestamp"]) - output["last_request_ts"] = float(rows[-1]["timestamp"]) - output["first_request_index"] = 0 - output["last_request_index"] = len(rows) - 1 - else: - output["first_request_ts"] = None - output["last_request_ts"] = None - output["first_request_index"] = None - output["last_request_index"] = None + output["first_request_ts"] = stats.first_request_ts + output["last_request_ts"] = stats.last_request_ts + output["first_request_index"] = stats.first_request_index + output["last_request_index"] = stats.last_request_index return output @@ -278,20 +316,19 @@ def main() -> int: "thinking": args.thinking_source.resolve(), } windows = build_windows(source_dirs, workloads=workloads) - extracted = extract_windows(windows, sample_seed=args.sample_seed) + stats_by_window = materialize_windows( + windows, + sample_seed=args.sample_seed, + traces_dir=traces_dir, + ) rendered_windows: list[dict[str, Any]] = [] for window in windows: - rows = extracted[str(window["window_id"])] trace_filename = f"{window['window_id']}.jsonl" - trace_path = traces_dir / trace_filename - with trace_path.open("w", encoding="utf-8") as handle: - for row in rows: - handle.write(json.dumps(row, ensure_ascii=False) + "\n") rendered_windows.append( build_output_window( window, - rows, + stats_by_window[str(window["window_id"])], trace_relpath=f"traces/{trace_filename}", sample_seed=args.sample_seed, ) diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index 4706c37..e66ca84 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import subprocess import tempfile import unittest from pathlib import Path @@ -258,7 +259,101 @@ class 
CoreFlowTests(unittest.TestCase): self.assertTrue(evaluations[0].passed) self.assertFalse(evaluations[1].passed) self.assertEqual(summary["slo_pass_rate"], 0.5) - self.assertFalse(summary["feasible"]) + + def test_prepare_trace_windows_materializes_repo_local_assets(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + legacy_source = tmp_path / "legacy" + thinking_source = tmp_path / "thinking" + legacy_source.mkdir() + thinking_source.mkdir() + + for filename in [ + "qwen_chat_blksz_64_031109-031111", + "qwen_chat_blksz_64_031121-031123", + "qwen_chat_blksz_64_031209-031211", + "qwen_chat_blksz_64_031221-031223", + "qwen_chat_blksz_64_031309-031311", + "qwen_chat_blksz_64_031321-031323", + "qwen_chat_blksz_64_031609-031611", + "qwen_chat_blksz_64_031621-031623", + "qwen_chat_blksz_64_031709-031711", + "qwen_chat_blksz_64_031721-031723", + ]: + for suffix in [".jsonl", "_prompt.jsonl"]: + path = legacy_source / f"{filename}{suffix}" + path.write_text("", encoding="utf-8") + + peak_trace = legacy_source / "qwen_chat_blksz_64_031109-031111.jsonl" + peak_prompt = legacy_source / "qwen_chat_blksz_64_031109-031111_prompt.jsonl" + peak_trace.write_text( + "\n".join( + [ + json.dumps( + { + "chat_id": "c1", + "turn": 1, + "timestamp": 3599.0, + "input_length": 10, + "output_length": 3, + } + ), + json.dumps( + { + "chat_id": "c2", + "turn": 2, + "timestamp": 3605.0, + "input_length": 20, + "output_length": 7, + } + ), + ] + ) + + "\n", + encoding="utf-8", + ) + peak_prompt.write_text( + "\n".join( + [ + json.dumps({"chat_id": "c1", "turn": 1, "prompt": "ignore me"}), + json.dumps({"chat_id": "c2", "turn": 2, "prompt": "real prompt"}), + ] + ) + + "\n", + encoding="utf-8", + ) + + output_root = tmp_path / "trace_windows" + subprocess.run( + [ + "python3", + "scripts/prepare_trace_windows.py", + "--legacy-source", + str(legacy_source), + "--thinking-source", + str(thinking_source), + "--output-root", + str(output_root), + "--workloads", + 
"chat", + "--overwrite", + ], + check=True, + cwd=str(Path(__file__).resolve().parents[1]), + ) + + windows_payload = json.loads((output_root / "windows.json").read_text(encoding="utf-8")) + windows = {item["window_id"]: item for item in windows_payload["windows"]} + self.assertIn("chat_w20260311_peak_1000", windows) + self.assertEqual(windows["chat_w20260311_peak_1000"]["num_requests"], 1) + + trace_path = output_root / windows["chat_w20260311_peak_1000"]["trace_file"] + rows = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0]["prompt"], "real prompt") + self.assertEqual(rows[0]["timestamp"], 5.0) + self.assertEqual(rows[0]["output_length"], 7) + self.assertIsInstance(rows[0]["sampling_u"], float) def test_binary_search_max_feasible(self) -> None: result = binary_search_max_feasible(