trace: make window materialization atomic

This commit is contained in:
2026-04-12 23:09:30 +08:00
parent 631a076498
commit 4625fba487
2 changed files with 88 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import argparse import argparse
import hashlib import hashlib
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
@@ -222,10 +223,19 @@ def materialize_windows(
stats_by_window = {str(window["window_id"]): WindowStats() for window in windows} stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
handles: dict[str, Any] = {} handles: dict[str, Any] = {}
final_paths: dict[str, Path] = {}
temp_paths: dict[str, Path] = {}
completed = False
try: try:
for window in windows: for window in windows:
window_id = str(window["window_id"]) window_id = str(window["window_id"])
handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8") final_path = traces_dir / f"{window_id}.jsonl"
temp_path = traces_dir / f".{window_id}.jsonl.tmp.{os.getpid()}"
if temp_path.exists():
temp_path.unlink()
final_paths[window_id] = final_path
temp_paths[window_id] = temp_path
handles[window_id] = temp_path.open("w", encoding="utf-8")
for trace_path, prompt_path in sorted(grouped.keys()): for trace_path, prompt_path in sorted(grouped.keys()):
bucket = grouped[(trace_path, prompt_path)] bucket = grouped[(trace_path, prompt_path)]
@@ -270,9 +280,17 @@ def materialize_windows(
f"materialized {trace_path.name} -> matched_rows={matched_rows}", f"materialized {trace_path.name} -> matched_rows={matched_rows}",
flush=True, flush=True,
) )
completed = True
finally: finally:
for handle in handles.values(): for handle in handles.values():
handle.close() handle.close()
if completed:
for window_id, temp_path in temp_paths.items():
os.replace(temp_path, final_paths[window_id])
else:
for temp_path in temp_paths.values():
if temp_path.exists():
temp_path.unlink()
return stats_by_window return stats_by_window
@@ -342,10 +360,17 @@ def main() -> int:
"window_duration_seconds": 600.0, "window_duration_seconds": 600.0,
"windows": rendered_windows, "windows": rendered_windows,
} }
(output_root / "windows.json").write_text( windows_path = output_root / "windows.json"
windows_tmp_path = output_root / f".windows.json.tmp.{os.getpid()}"
try:
windows_tmp_path.write_text(
json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n", json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8", encoding="utf-8",
) )
os.replace(windows_tmp_path, windows_path)
finally:
if windows_tmp_path.exists():
windows_tmp_path.unlink()
print(output_root) print(output_root)
print(f"windows={len(rendered_windows)}") print(f"windows={len(rendered_windows)}")
return 0 return 0

View File

@@ -853,6 +853,64 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(rows[0]["output_length"], 7) self.assertEqual(rows[0]["output_length"], 7)
self.assertIsInstance(rows[0]["sampling_u"], float) self.assertIsInstance(rows[0]["sampling_u"], float)
def test_prepare_trace_windows_preserves_existing_files_on_failure(self) -> None:
    """A failed materialization run must not clobber existing trace files.

    Builds a legacy-source fixture whose rows cannot produce the expected
    windows, runs ``prepare_trace_windows.py`` as a subprocess, and checks
    three things: the script exits non-zero, a sentinel file that already
    existed under ``traces/`` is byte-identical afterwards, and no
    ``*.tmp.*`` scratch files are left behind by the atomic-write path.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        legacy_source = tmp_path / "legacy"
        thinking_source = tmp_path / "thinking"
        output_root = tmp_path / "trace_windows"
        traces_dir = output_root / "traces"
        legacy_source.mkdir()
        thinking_source.mkdir()
        traces_dir.mkdir(parents=True)
        # One minimal trace row per (trace, prompt) file pair for each
        # source basename the script is expected to pick up.
        for filename in [
            "qwen_chat_blksz_64_031109-031111",
            "qwen_chat_blksz_64_031121-031123",
        ]:
            for suffix in [".jsonl", "_prompt.jsonl"]:
                # Fix: derive the fixture path from the loop's basename.
                # The previous literal ignored ``filename``, so the second
                # iteration silently overwrote the first pair of files.
                path = legacy_source / f"{filename}{suffix}"
                path.write_text(
                    json.dumps(
                        {
                            "chat_id": "c1",
                            "turn": 1,
                            "timestamp": 3605.0,
                            "input_length": 20,
                            "output_length": 7,
                            "prompt": "prompt",
                        }
                    )
                    + "\n",
                    encoding="utf-8",
                )
        # Pre-existing output the atomic writer must leave untouched on failure.
        sentinel = traces_dir / "chat_w20260311_1000.jsonl"
        sentinel.write_text("sentinel\n", encoding="utf-8")
        proc = subprocess.run(
            [
                "python3",
                "scripts/prepare_trace_windows.py",
                "--legacy-source",
                str(legacy_source),
                "--thinking-source",
                str(thinking_source),
                "--output-root",
                str(output_root),
                "--workloads",
                "chat",
                "--overwrite",
            ],
            cwd="/home/gahow/phd/aituner",
            capture_output=True,
            text=True,
        )
        self.assertNotEqual(proc.returncode, 0)
        self.assertEqual(sentinel.read_text(encoding="utf-8"), "sentinel\n")
        # No orphaned temp files: failure path must unlink every *.tmp.* it created.
        self.assertEqual(sorted(path.name for path in traces_dir.glob("*.tmp.*")), [])
def test_binary_search_max_feasible(self) -> None: def test_binary_search_max_feasible(self) -> None:
result = binary_search_max_feasible( result = binary_search_max_feasible(
low=0.0, low=0.0,