From d0c89dac48a6e9201db5577566dd1ad98250270d Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sat, 16 May 2026 15:51:04 +0800 Subject: [PATCH] Clean marked trial engine processes --- src/aituner/worker.py | 62 ++++++++++++++++++++++++++++++++++++----- tests/test_core_flow.py | 21 ++++++++++++++ 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/src/aituner/worker.py b/src/aituner/worker.py index 7931f12..4a7930e 100644 --- a/src/aituner/worker.py +++ b/src/aituner/worker.py @@ -409,7 +409,44 @@ def _process_group_exists(pgid: int) -> bool: return False -def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float = 30.0) -> None: +def _pids_matching_env(marker_env: dict[str, str] | None) -> list[int]: + if not marker_env: + return [] + expected = { + f"{key}={value}".encode() + for key, value in marker_env.items() + } + pids: list[int] = [] + proc_root = Path("/proc") + for entry in proc_root.iterdir(): + if not entry.name.isdigit(): + continue + pid = int(entry.name) + if pid == os.getpid(): + continue + try: + environ = (entry / "environ").read_bytes() + except (FileNotFoundError, PermissionError, ProcessLookupError): + continue + if expected.issubset(set(environ.split(b"\0"))): + pids.append(pid) + return sorted(pids) + + +def _signal_pids(pids: list[int], sig: signal.Signals) -> None: + for pid in pids: + try: + os.kill(pid, sig) + except (ProcessLookupError, PermissionError): + continue + + +def _terminate_process_tree( + process: subprocess.Popen[str], + *, + timeout_s: float = 30.0, + marker_env: dict[str, str] | None = None, +) -> None: try: pgid = os.getpgid(process.pid) except ProcessLookupError: @@ -420,16 +457,18 @@ def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float try: os.killpg(pgid, signal.SIGTERM) except ProcessLookupError: - return + pass + _signal_pids(_pids_matching_env(marker_env), signal.SIGTERM) deadline = time.monotonic() + timeout_s while time.monotonic() < deadline: - if not _process_group_exists(pgid): + if not _process_group_exists(pgid) and not _pids_matching_env(marker_env): return time.sleep(0.1) try: os.killpg(pgid, signal.SIGKILL) except ProcessLookupError: - return + pass + _signal_pids(_pids_matching_env(marker_env), signal.SIGKILL) if process.poll() is None: process.wait(timeout=timeout_s) @@ -448,12 +487,17 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]: probe_details_path = artifact_dir / "probe_details.jsonl" if probe_details_path.exists(): probe_details_path.unlink() + trial_marker_env = { + "AITUNER_STUDY_ID": trial.study_id, + "AITUNER_TRIAL_ID": trial.trial_id, + } with engine_log_path.open("w", encoding="utf-8") as engine_log: def launch_process() -> subprocess.Popen[str]: + launch_env = {**recipe.env, **trial_marker_env} return subprocess.Popen( # noqa: S603 recipe.argv, cwd=recipe.cwd, - env=recipe.env, + env=launch_env, stdout=engine_log, stderr=subprocess.STDOUT, text=True, @@ -546,7 +590,11 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]: probe_history.append(probe_record) StudyStore.write_json(Path(trial.probe_log_path), probe_history) if early_stopped and restart_after_early_stop: - _terminate_process_tree(process, timeout_s=30.0) + _terminate_process_tree( + process, + timeout_s=30.0, + marker_env=trial_marker_env, + ) process = launch_process() _wait_for_server_or_exit( process, @@ -691,4 +739,4 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]: StudyStore.write_json(Path(trial.result_path), result) return result finally: - _terminate_process_tree(process, timeout_s=30.0) + _terminate_process_tree(process, timeout_s=30.0, marker_env=trial_marker_env) diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index c4da979..69b0547 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -2,6 +2,7 @@ from __future__ import annotations import json import os +import signal import subprocess import tempfile import unittest @@ -4573,6 +4574,26 @@ class CoreFlowTests(unittest.TestCase): self.assertEqual(mock_killpg.call_args_list[0].args[0], 1234) process.wait.assert_not_called() + def test_terminate_process_tree_signals_marker_processes_when_group_missing(self) -> None: + process = mock.Mock() + process.pid = 1234 + process.poll.return_value = 0 + marker_env = {"AITUNER_TRIAL_ID": "trial-0001"} + with mock.patch("aituner.worker.os.getpgid", side_effect=ProcessLookupError): + with mock.patch("aituner.worker.os.killpg", side_effect=ProcessLookupError): + with mock.patch( + "aituner.worker._pids_matching_env", + side_effect=[[2222], []], + ) as mock_pids: + with mock.patch("aituner.worker._signal_pids") as mock_signal: + _terminate_process_tree( + process, + timeout_s=1.0, + marker_env=marker_env, + ) + self.assertEqual(mock_pids.call_args_list[0].args[0], marker_env) + self.assertEqual(mock_signal.call_args_list[0].args, ([2222], signal.SIGTERM)) + def test_openai_url_avoids_double_v1(self) -> None: self.assertEqual( _openai_url("http://example.com", "/v1/chat/completions"),