From 766f5ade8274f73586feb2330cfca32272ec4648 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 25 Jun 2026 16:51:23 +0800 Subject: [PATCH] Support q exit in interval mode --- mm/cli.py | 82 +++++++++++++++++++++++++++++++++++++++++------ tests/test_cli.py | 36 +++++++++++++++++++++ 2 files changed, 108 insertions(+), 10 deletions(-) diff --git a/mm/cli.py b/mm/cli.py index 931cd26..33810b9 100644 --- a/mm/cli.py +++ b/mm/cli.py @@ -3,13 +3,17 @@ from __future__ import annotations import argparse import csv import os +import select +import signal import subprocess import sys +import threading import time from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Sequence +from typing import Any, Iterator, Sequence import yaml from rich import box @@ -369,23 +373,81 @@ def format_memory_pair(gpu: GPUStat) -> str: return f"{used:5.1f}/{total:5.1f}M" +@contextmanager +def terminal_key_mode(input_stream: Any) -> Iterator[None]: + if not input_stream.isatty(): + yield + return + + import termios + import tty + + fd = input_stream.fileno() + original_attrs = termios.tcgetattr(fd) + try: + tty.setcbreak(fd) + yield + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, original_attrs) + + +def read_key(input_stream: Any, timeout: float) -> str | None: + readable, _, _ = select.select([input_stream], [], [], timeout) + if not readable: + return None + return input_stream.read(1) + + +def watch_quit_key(input_stream: Any, stop_event: threading.Event) -> None: + while not stop_event.is_set(): + try: + key = read_key(input_stream, 0.1) + except OSError: + return + if key and key.lower() == "q": + os.kill(os.getpid(), signal.SIGINT) + return + + +@contextmanager +def quit_key_interrupt(input_stream: Any) -> Iterator[None]: + if not input_stream.isatty(): + yield + return + + stop_event = threading.Event() + thread = threading.Thread( + target=watch_quit_key, + args=(input_stream, stop_event), + daemon=True, + ) + thread.start() + try: + yield + finally: + stop_event.set() + thread.join(timeout=0.2) + + def refresh_dashboard( machines: Sequence[MachineConfig], config_path: Path, timeout: float, interval: float, console: Console, + input_stream: Any = sys.stdin, ) -> None: - while True: - started_at = time.monotonic() - results = collect_status(machines, timeout) - refreshed_at = time.strftime("%H:%M:%S") - console.clear() - render_dashboard(results, config_path, console, refreshed_at=refreshed_at) + with terminal_key_mode(input_stream), quit_key_interrupt(input_stream): + while True: + started_at = time.monotonic() + results = collect_status(machines, timeout) + refreshed_at = time.strftime("%H:%M:%S") + console.clear() + render_dashboard(results, config_path, console, refreshed_at=refreshed_at) - sleep_seconds = interval - (time.monotonic() - started_at) - if sleep_seconds > 0: - time.sleep(sleep_seconds) + sleep_seconds = interval - (time.monotonic() - started_at) + if sleep_seconds > 0: + time.sleep(sleep_seconds) def build_parser() -> argparse.ArgumentParser: diff --git a/tests/test_cli.py b/tests/test_cli.py index 73ddd7a..fdeacbd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,10 @@ import io +import signal +import threading from pathlib import Path from tempfile import TemporaryDirectory import unittest +from unittest.mock import patch from rich.console import Console @@ -14,6 +17,8 @@ from mm.cli import ( load_machines, parse_nvidia_smi_csv, render_dashboard, + refresh_dashboard, + watch_quit_key, ) @@ -105,6 +110,37 @@ class CLITests(unittest.TestCase): self.assertEqual(args.interval, 15) + def test_watch_quit_key_sends_sigint_for_q_key(self) -> None: + with ( + patch("mm.cli.read_key", return_value="q"), + patch("mm.cli.os.getpid", return_value=123), + patch("mm.cli.os.kill") as kill, + ): + watch_quit_key(io.StringIO(), threading.Event()) + + kill.assert_called_once_with(123, signal.SIGINT) + + def test_refresh_dashboard_sleeps_between_refreshes(self) -> None: + console = Console(file=io.StringIO(), record=True, color_system=None, width=120) + + with ( + patch("mm.cli.collect_status", return_value=[]) as collect_status, + patch("mm.cli.render_dashboard"), + patch("mm.cli.time.sleep", side_effect=KeyboardInterrupt) as sleep, + ): + with self.assertRaises(KeyboardInterrupt): + refresh_dashboard( + [], + Path("list.yaml"), + timeout=1, + interval=10, + console=console, + input_stream=io.StringIO(), + ) + + collect_status.assert_called_once_with([], 1) + sleep.assert_called_once() + if __name__ == "__main__": unittest.main()