Clean marked trial engine processes
This commit is contained in:
@@ -409,7 +409,44 @@ def _process_group_exists(pgid: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float = 30.0) -> None:
|
||||
def _pids_matching_env(marker_env: dict[str, str] | None) -> list[int]:
|
||||
if not marker_env:
|
||||
return []
|
||||
expected = {
|
||||
f"{key}={value}".encode()
|
||||
for key, value in marker_env.items()
|
||||
}
|
||||
pids: list[int] = []
|
||||
proc_root = Path("/proc")
|
||||
for entry in proc_root.iterdir():
|
||||
if not entry.name.isdigit():
|
||||
continue
|
||||
pid = int(entry.name)
|
||||
if pid == os.getpid():
|
||||
continue
|
||||
try:
|
||||
environ = (entry / "environ").read_bytes()
|
||||
except (FileNotFoundError, PermissionError, ProcessLookupError):
|
||||
continue
|
||||
if expected.issubset(set(environ.split(b"\0"))):
|
||||
pids.append(pid)
|
||||
return sorted(pids)
|
||||
|
||||
|
||||
def _signal_pids(pids: list[int], sig: signal.Signals) -> None:
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, sig)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
continue
|
||||
|
||||
|
||||
def _terminate_process_tree(
|
||||
process: subprocess.Popen[str],
|
||||
*,
|
||||
timeout_s: float = 30.0,
|
||||
marker_env: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
pgid = os.getpgid(process.pid)
|
||||
except ProcessLookupError:
|
||||
@@ -420,16 +457,18 @@ def _terminate_process_tree(process: subprocess.Popen[str], *, timeout_s: float
|
||||
try:
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
pass
|
||||
_signal_pids(_pids_matching_env(marker_env), signal.SIGTERM)
|
||||
deadline = time.monotonic() + timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if not _process_group_exists(pgid):
|
||||
if not _process_group_exists(pgid) and not _pids_matching_env(marker_env):
|
||||
return
|
||||
time.sleep(0.1)
|
||||
try:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
pass
|
||||
_signal_pids(_pids_matching_env(marker_env), signal.SIGKILL)
|
||||
if process.poll() is None:
|
||||
process.wait(timeout=timeout_s)
|
||||
|
||||
@@ -448,12 +487,17 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||
probe_details_path = artifact_dir / "probe_details.jsonl"
|
||||
if probe_details_path.exists():
|
||||
probe_details_path.unlink()
|
||||
trial_marker_env = {
|
||||
"AITUNER_STUDY_ID": trial.study_id,
|
||||
"AITUNER_TRIAL_ID": trial.trial_id,
|
||||
}
|
||||
with engine_log_path.open("w", encoding="utf-8") as engine_log:
|
||||
def launch_process() -> subprocess.Popen[str]:
|
||||
launch_env = {**recipe.env, **trial_marker_env}
|
||||
return subprocess.Popen( # noqa: S603
|
||||
recipe.argv,
|
||||
cwd=recipe.cwd,
|
||||
env=recipe.env,
|
||||
env=launch_env,
|
||||
stdout=engine_log,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
@@ -546,7 +590,11 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||
probe_history.append(probe_record)
|
||||
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
|
||||
if early_stopped and restart_after_early_stop:
|
||||
_terminate_process_tree(process, timeout_s=30.0)
|
||||
_terminate_process_tree(
|
||||
process,
|
||||
timeout_s=30.0,
|
||||
marker_env=trial_marker_env,
|
||||
)
|
||||
process = launch_process()
|
||||
_wait_for_server_or_exit(
|
||||
process,
|
||||
@@ -691,4 +739,4 @@ def run_trial(trial_spec_path: Path) -> dict[str, Any]:
|
||||
StudyStore.write_json(Path(trial.result_path), result)
|
||||
return result
|
||||
finally:
|
||||
_terminate_process_tree(process, timeout_s=30.0)
|
||||
_terminate_process_tree(process, timeout_s=30.0, marker_env=trial_marker_env)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -4573,6 +4574,26 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(mock_killpg.call_args_list[0].args[0], 1234)
|
||||
process.wait.assert_not_called()
|
||||
|
||||
def test_terminate_process_tree_signals_marker_processes_when_group_missing(self) -> None:
|
||||
process = mock.Mock()
|
||||
process.pid = 1234
|
||||
process.poll.return_value = 0
|
||||
marker_env = {"AITUNER_TRIAL_ID": "trial-0001"}
|
||||
with mock.patch("aituner.worker.os.getpgid", side_effect=ProcessLookupError):
|
||||
with mock.patch("aituner.worker.os.killpg", side_effect=ProcessLookupError):
|
||||
with mock.patch(
|
||||
"aituner.worker._pids_matching_env",
|
||||
side_effect=[[2222], []],
|
||||
) as mock_pids:
|
||||
with mock.patch("aituner.worker._signal_pids") as mock_signal:
|
||||
_terminate_process_tree(
|
||||
process,
|
||||
timeout_s=1.0,
|
||||
marker_env=marker_env,
|
||||
)
|
||||
self.assertEqual(mock_pids.call_args_list[0].args[0], marker_env)
|
||||
self.assertEqual(mock_signal.call_args_list[0].args, ([2222], signal.SIGTERM))
|
||||
|
||||
def test_openai_url_avoids_double_v1(self) -> None:
|
||||
self.assertEqual(
|
||||
_openai_url("http://example.com", "/v1/chat/completions"),
|
||||
|
||||
Reference in New Issue
Block a user