NOTE(review): line structure and indentation of this patch were reconstructed
from a whitespace-collapsed copy; confirm context-line indentation against the
target files (or apply with --ignore-whitespace). The index hashes are carried
over from the original patch and do not reflect the revised post-image.

diff --git a/src/aituner/worker.py b/src/aituner/worker.py
index fef9901..a4df42c 100644
--- a/src/aituner/worker.py
+++ b/src/aituner/worker.py
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import json
 import math
+import os
+import signal
 import subprocess
 import threading
 import time
@@ -218,6 +220,42 @@ def _wait_for_server_or_exit(
     raise HttpClientError(f"Timed out waiting for {base_url}{healthcheck_path}: {last_error}")
 
 
+def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float = 30.0) -> None:
+    """Terminate ``process`` together with the rest of its process group.
+
+    The caller must have started ``process`` in its own session or process
+    group (``start_new_session=True``); otherwise the signals sent here would
+    also hit the caller's own group.
+
+    SIGTERM is sent to the group first and the leader gets up to ``timeout_s``
+    seconds to exit. The group is then sent SIGKILL in every case, so helpers
+    that outlive the leader or ignore SIGTERM are not left behind.
+    """
+    if process.poll() is not None:
+        return
+    try:
+        pgid = os.getpgid(process.pid)
+    except ProcessLookupError:
+        return
+    try:
+        os.killpg(pgid, signal.SIGTERM)
+    except ProcessLookupError:
+        return
+    deadline = time.monotonic() + timeout_s
+    while time.monotonic() < deadline:
+        if process.poll() is not None:
+            break
+        time.sleep(0.1)
+    # Always sweep the group: the leader exiting does not imply that the
+    # processes it spawned are gone as well.
+    try:
+        os.killpg(pgid, signal.SIGKILL)
+    except ProcessLookupError:
+        pass
+    # Reap the leader; returns immediately when it has already exited.
+    process.wait(timeout=timeout_s)
+
+
 def run_trial(trial_spec_path: Path) -> dict[str, Any]:
     from .store import StudyStore
 
@@ -237,6 +275,7 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
             stdout=engine_log,
             stderr=subprocess.STDOUT,
             text=True,
+            start_new_session=True,
         )
         probe_history: list[dict[str, Any]] = []
         try:
@@ -352,10 +391,4 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
             StudyStore.write_json(Path(trial.result_path), result)
             return result
         finally:
-            if process.poll() is None:
-                process.terminate()
-                try:
-                    process.wait(timeout=30)
-                except subprocess.TimeoutExpired:
-                    process.kill()
-                    process.wait(timeout=30)
+            _terminate_process_tree(process, timeout_s=30.0)
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index 9e45ede..f2ee9ea 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -16,7 +16,7 @@ from aituner.slo import RequestOutcome, summarize_evaluations
 from aituner.spec import Proposal, StudyState, TrialSummary, load_study_spec
 from aituner.store import StudyStore
 from aituner.trace import load_trace_requests, summarize_window
-from aituner.worker import _replay_requests, _wait_for_server_or_exit
+from aituner.worker import _replay_requests, _terminate_process_tree, _wait_for_server_or_exit
 from aituner.trace import TraceRequest
 
 
@@ -910,6 +910,21 @@ class CoreFlowTests(unittest.TestCase):
                 ready_timeout_s=10.0,
             )
 
+    def test_terminate_process_tree_kills_process_group(self) -> None:
+        process = mock.Mock()
+        process.pid = 1234
+        process.poll.side_effect = [None, None, 0]
+        process.wait.return_value = 0
+        with mock.patch("aituner.worker.os.getpgid", return_value=1234):
+            with mock.patch("aituner.worker.os.killpg") as mock_killpg:
+                with mock.patch("aituner.worker.time.sleep"):
+                    _terminate_process_tree(process, timeout_s=1.0)
+        # SIGTERM to the group, then the unconditional SIGKILL sweep.
+        self.assertEqual(mock_killpg.call_count, 2)
+        for call in mock_killpg.call_args_list:
+            self.assertEqual(call[0][0], 1234)
+        process.wait.assert_called_once()
+
     def test_openai_url_avoids_double_v1(self) -> None:
         self.assertEqual(
             _openai_url("http://example.com", "/v1/chat/completions"),