trace: make window materialization atomic
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
@@ -222,10 +223,19 @@ def materialize_windows(
|
||||
|
||||
stats_by_window = {str(window["window_id"]): WindowStats() for window in windows}
|
||||
handles: dict[str, Any] = {}
|
||||
final_paths: dict[str, Path] = {}
|
||||
temp_paths: dict[str, Path] = {}
|
||||
completed = False
|
||||
try:
|
||||
for window in windows:
|
||||
window_id = str(window["window_id"])
|
||||
handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8")
|
||||
final_path = traces_dir / f"{window_id}.jsonl"
|
||||
temp_path = traces_dir / f".{window_id}.jsonl.tmp.{os.getpid()}"
|
||||
if temp_path.exists():
|
||||
temp_path.unlink()
|
||||
final_paths[window_id] = final_path
|
||||
temp_paths[window_id] = temp_path
|
||||
handles[window_id] = temp_path.open("w", encoding="utf-8")
|
||||
|
||||
for trace_path, prompt_path in sorted(grouped.keys()):
|
||||
bucket = grouped[(trace_path, prompt_path)]
|
||||
@@ -270,9 +280,17 @@ def materialize_windows(
|
||||
f"materialized {trace_path.name} -> matched_rows={matched_rows}",
|
||||
flush=True,
|
||||
)
|
||||
completed = True
|
||||
finally:
|
||||
for handle in handles.values():
|
||||
handle.close()
|
||||
if completed:
|
||||
for window_id, temp_path in temp_paths.items():
|
||||
os.replace(temp_path, final_paths[window_id])
|
||||
else:
|
||||
for temp_path in temp_paths.values():
|
||||
if temp_path.exists():
|
||||
temp_path.unlink()
|
||||
return stats_by_window
|
||||
|
||||
|
||||
@@ -342,10 +360,17 @@ def main() -> int:
|
||||
"window_duration_seconds": 600.0,
|
||||
"windows": rendered_windows,
|
||||
}
|
||||
(output_root / "windows.json").write_text(
|
||||
windows_path = output_root / "windows.json"
|
||||
windows_tmp_path = output_root / f".windows.json.tmp.{os.getpid()}"
|
||||
try:
|
||||
windows_tmp_path.write_text(
|
||||
json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
os.replace(windows_tmp_path, windows_path)
|
||||
finally:
|
||||
if windows_tmp_path.exists():
|
||||
windows_tmp_path.unlink()
|
||||
print(output_root)
|
||||
print(f"windows={len(rendered_windows)}")
|
||||
return 0
|
||||
|
||||
@@ -853,6 +853,64 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(rows[0]["output_length"], 7)
|
||||
self.assertIsInstance(rows[0]["sampling_u"], float)
|
||||
|
||||
def test_prepare_trace_windows_preserves_existing_files_on_failure(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
legacy_source = tmp_path / "legacy"
|
||||
thinking_source = tmp_path / "thinking"
|
||||
output_root = tmp_path / "trace_windows"
|
||||
traces_dir = output_root / "traces"
|
||||
legacy_source.mkdir()
|
||||
thinking_source.mkdir()
|
||||
traces_dir.mkdir(parents=True)
|
||||
|
||||
for filename in [
|
||||
"qwen_chat_blksz_64_031109-031111",
|
||||
"qwen_chat_blksz_64_031121-031123",
|
||||
]:
|
||||
for suffix in [".jsonl", "_prompt.jsonl"]:
|
||||
path = legacy_source / f"{filename}{suffix}"
|
||||
path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"chat_id": "c1",
|
||||
"turn": 1,
|
||||
"timestamp": 3605.0,
|
||||
"input_length": 20,
|
||||
"output_length": 7,
|
||||
"prompt": "prompt",
|
||||
}
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
sentinel = traces_dir / "chat_w20260311_1000.jsonl"
|
||||
sentinel.write_text("sentinel\n", encoding="utf-8")
|
||||
|
||||
proc = subprocess.run(
|
||||
[
|
||||
"python3",
|
||||
"scripts/prepare_trace_windows.py",
|
||||
"--legacy-source",
|
||||
str(legacy_source),
|
||||
"--thinking-source",
|
||||
str(thinking_source),
|
||||
"--output-root",
|
||||
str(output_root),
|
||||
"--workloads",
|
||||
"chat",
|
||||
"--overwrite",
|
||||
],
|
||||
cwd="/home/gahow/phd/aituner",
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
self.assertNotEqual(proc.returncode, 0)
|
||||
self.assertEqual(sentinel.read_text(encoding="utf-8"), "sentinel\n")
|
||||
self.assertEqual(sorted(path.name for path in traces_dir.glob("*.tmp.*")), [])
|
||||
|
||||
def test_binary_search_max_feasible(self) -> None:
|
||||
result = binary_search_max_feasible(
|
||||
low=0.0,
|
||||
|
||||
Reference in New Issue
Block a user