diff --git a/scripts/prepare_trace_windows.py b/scripts/prepare_trace_windows.py index edab4f5..89b18ad 100644 --- a/scripts/prepare_trace_windows.py +++ b/scripts/prepare_trace_windows.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse import hashlib import json +import os from pathlib import Path from typing import Any from dataclasses import dataclass @@ -222,10 +223,19 @@ def materialize_windows( stats_by_window = {str(window["window_id"]): WindowStats() for window in windows} handles: dict[str, Any] = {} + final_paths: dict[str, Path] = {} + temp_paths: dict[str, Path] = {} + completed = False try: for window in windows: window_id = str(window["window_id"]) - handles[window_id] = (traces_dir / f"{window_id}.jsonl").open("w", encoding="utf-8") + final_path = traces_dir / f"{window_id}.jsonl" + temp_path = traces_dir / f".{window_id}.jsonl.tmp.{os.getpid()}" + if temp_path.exists(): + temp_path.unlink() + final_paths[window_id] = final_path + temp_paths[window_id] = temp_path + handles[window_id] = temp_path.open("w", encoding="utf-8") for trace_path, prompt_path in sorted(grouped.keys()): bucket = grouped[(trace_path, prompt_path)] @@ -270,9 +280,17 @@ def materialize_windows( f"materialized {trace_path.name} -> matched_rows={matched_rows}", flush=True, ) + completed = True finally: for handle in handles.values(): handle.close() + if completed: + for window_id, temp_path in temp_paths.items(): + os.replace(temp_path, final_paths[window_id]) + else: + for temp_path in temp_paths.values(): + if temp_path.exists(): + temp_path.unlink() return stats_by_window @@ -342,10 +360,17 @@ def main() -> int: "window_duration_seconds": 600.0, "windows": rendered_windows, } - (output_root / "windows.json").write_text( - json.dumps(windows_payload, ensure_ascii=False, indent=2) + "\n", - encoding="utf-8", - ) + windows_path = output_root / "windows.json" + windows_tmp_path = output_root / f".windows.json.tmp.{os.getpid()}" + try: + 
def test_prepare_trace_windows_preserves_existing_files_on_failure(self) -> None:
    """Check that a failing run neither clobbers existing traces nor leaks temp files.

    A sentinel trace file is written before ``scripts/prepare_trace_windows.py``
    is invoked with ``--overwrite``. The test then asserts that the script
    exited non-zero, that the sentinel content is unchanged, and that no
    ``*.tmp.*`` file remains in the traces directory.

    NOTE(review): the fixture is expected to make the script fail; the exact
    failure cause lives in the script and is not visible from this test.
    """
    # Local import: only needed to launch the script with the interpreter
    # that is running the tests instead of whatever "python3" is on PATH.
    import sys

    # tests/test_core_flow.py -> the repository root is one level above tests/.
    # Deriving it from __file__ keeps the test independent of the checkout path.
    repo_root = Path(__file__).resolve().parents[1]

    record_line = (
        json.dumps(
            {
                "chat_id": "c1",
                "turn": 1,
                "timestamp": 3605.0,
                "input_length": 20,
                "output_length": 7,
                "prompt": "prompt",
            }
        )
        + "\n"
    )

    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        legacy_source = tmp_path / "legacy"
        thinking_source = tmp_path / "thinking"
        output_root = tmp_path / "trace_windows"
        traces_dir = output_root / "traces"
        legacy_source.mkdir()
        thinking_source.mkdir()
        traces_dir.mkdir(parents=True)

        # Each trace stem gets both a trace file and its "_prompt" companion,
        # all holding the same single record.
        for filename in [
            "qwen_chat_blksz_64_031109-031111",
            "qwen_chat_blksz_64_031121-031123",
        ]:
            for suffix in [".jsonl", "_prompt.jsonl"]:
                path = legacy_source / f"{filename}{suffix}"
                path.write_text(record_line, encoding="utf-8")

        # Pre-existing output that a failed run must leave untouched.
        sentinel = traces_dir / "chat_w20260311_1000.jsonl"
        sentinel.write_text("sentinel\n", encoding="utf-8")

        proc = subprocess.run(
            [
                sys.executable,
                "scripts/prepare_trace_windows.py",
                "--legacy-source",
                str(legacy_source),
                "--thinking-source",
                str(thinking_source),
                "--output-root",
                str(output_root),
                "--workloads",
                "chat",
                "--overwrite",
            ],
            cwd=repo_root,
            capture_output=True,
            text=True,
        )

        self.assertNotEqual(proc.returncode, 0)
        self.assertEqual(sentinel.read_text(encoding="utf-8"), "sentinel\n")
        self.assertEqual(sorted(path.name for path in traces_dir.glob("*.tmp.*")), [])