Clean marked trial engine processes

This commit is contained in:
2026-05-16 15:51:04 +08:00
parent cf9b8b3f68
commit d0c89dac48
2 changed files with 76 additions and 7 deletions

View File

@@ -409,7 +409,44 @@ def _process_group_exists(pgid: int) -> bool:
return False
def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float = 30.0) -> None:
def _pids_matching_env(marker_env: dict[str, str] | None) -> list[int]:
if not marker_env:
return []
expected = {
f"{key}={value}".encode()
for key, value in marker_env.items()
}
pids: list[int] = []
proc_root = Path("/proc")
for entry in proc_root.iterdir():
if not entry.name.isdigit():
continue
pid = int(entry.name)
if pid == os.getpid():
continue
try:
environ = (entry / "environ").read_bytes()
except (FileNotFoundError, PermissionError, ProcessLookupError):
continue
if expected.issubset(set(environ.split(b"\0"))):
pids.append(pid)
return sorted(pids)
def _signal_pids(pids: list[int], sig: signal.Signals) -> None:
for pid in pids:
try:
os.kill(pid, sig)
except (ProcessLookupError, PermissionError):
continue
def _terminate_process_tree(
process: subprocess.Popen[str],
*,
timeout_s: float = 30.0,
marker_env: dict[str, str] | None = None,
) -> None:
try:
pgid = os.getpgid(process.pid)
except ProcessLookupError:
@@ -420,16 +457,18 @@ def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float
try:
os.killpg(pgid, signal.SIGTERM)
except ProcessLookupError:
return
pass
_signal_pids(_pids_matching_env(marker_env), signal.SIGTERM)
deadline = time.monotonic() + timeout_s
while time.monotonic() < deadline:
if not _process_group_exists(pgid):
if not _process_group_exists(pgid) and not _pids_matching_env(marker_env):
return
time.sleep(0.1)
try:
os.killpg(pgid, signal.SIGKILL)
except ProcessLookupError:
return
pass
_signal_pids(_pids_matching_env(marker_env), signal.SIGKILL)
if process.poll() is None:
process.wait(timeout=timeout_s)
@@ -448,12 +487,17 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
probe_details_path = artifact_dir / "probe_details.jsonl"
if probe_details_path.exists():
probe_details_path.unlink()
trial_marker_env = {
"AITUNER_STUDY_ID": trial.study_id,
"AITUNER_TRIAL_ID": trial.trial_id,
}
with engine_log_path.open("w", encoding="utf-8") as engine_log:
def launch_process() -> subprocess.Popen[str]:
launch_env = {**recipe.env, **trial_marker_env}
return subprocess.Popen( # noqa: S603
recipe.argv,
cwd=recipe.cwd,
env=recipe.env,
env=launch_env,
stdout=engine_log,
stderr=subprocess.STDOUT,
text=True,
@@ -546,7 +590,11 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
probe_history.append(probe_record)
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
if early_stopped and restart_after_early_stop:
_terminate_process_tree(process, timeout_s=30.0)
_terminate_process_tree(
process,
timeout_s=30.0,
marker_env=trial_marker_env,
)
process = launch_process()
_wait_for_server_or_exit(
process,
@@ -691,4 +739,4 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
StudyStore.write_json(Path(trial.result_path), result)
return result
finally:
_terminate_process_tree(process, timeout_s=30.0)
_terminate_process_tree(process, timeout_s=30.0, marker_env=trial_marker_env)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import json
import os
import signal
import subprocess
import tempfile
import unittest
@@ -4573,6 +4574,26 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(mock_killpg.call_args_list[0].args[0], 1234)
process.wait.assert_not_called()
def test_terminate_process_tree_signals_marker_processes_when_group_missing(self) -> None:
process = mock.Mock()
process.pid = 1234
process.poll.return_value = 0
marker_env = {"AITUNER_TRIAL_ID": "trial-0001"}
with mock.patch("aituner.worker.os.getpgid", side_effect=ProcessLookupError):
with mock.patch("aituner.worker.os.killpg", side_effect=ProcessLookupError):
with mock.patch(
"aituner.worker._pids_matching_env",
side_effect=[[2222], []],
) as mock_pids:
with mock.patch("aituner.worker._signal_pids") as mock_signal:
_terminate_process_tree(
process,
timeout_s=1.0,
marker_env=marker_env,
)
self.assertEqual(mock_pids.call_args_list[0].args[0], marker_env)
self.assertEqual(mock_signal.call_args_list[0].args, ([2222], signal.SIGTERM))
def test_openai_url_avoids_double_v1(self) -> None:
self.assertEqual(
_openai_url("http://example.com", "/v1/chat/completions"),