Support q exit in interval mode

This commit is contained in:
2026-06-25 16:51:23 +08:00
parent 214dddced4
commit 766f5ade82
2 changed files with 108 additions and 10 deletions

View File

@@ -3,13 +3,17 @@ from __future__ import annotations
import argparse
import csv
import os
import select
import signal
import subprocess
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Sequence
from typing import Any, Iterator, Sequence
import yaml
from rich import box
@@ -369,23 +373,81 @@ def format_memory_pair(gpu: GPUStat) -> str:
return f"{used:5.1f}/{total:5.1f}M"
@contextmanager
def terminal_key_mode(input_stream: Any) -> Iterator[None]:
if not input_stream.isatty():
yield
return
import termios
import tty
fd = input_stream.fileno()
original_attrs = termios.tcgetattr(fd)
try:
tty.setcbreak(fd)
yield
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, original_attrs)
def read_key(input_stream: Any, timeout: float) -> str | None:
readable, _, _ = select.select([input_stream], [], [], timeout)
if not readable:
return None
return input_stream.read(1)
def watch_quit_key(input_stream: Any, stop_event: threading.Event) -> None:
while not stop_event.is_set():
try:
key = read_key(input_stream, 0.1)
except OSError:
return
if key and key.lower() == "q":
os.kill(os.getpid(), signal.SIGINT)
return
@contextmanager
def quit_key_interrupt(input_stream: Any) -> Iterator[None]:
if not input_stream.isatty():
yield
return
stop_event = threading.Event()
thread = threading.Thread(
target=watch_quit_key,
args=(input_stream, stop_event),
daemon=True,
)
thread.start()
try:
yield
finally:
stop_event.set()
thread.join(timeout=0.2)
def refresh_dashboard(
machines: Sequence[MachineConfig],
config_path: Path,
timeout: float,
interval: float,
console: Console,
input_stream: Any = sys.stdin,
) -> None:
while True:
started_at = time.monotonic()
results = collect_status(machines, timeout)
refreshed_at = time.strftime("%H:%M:%S")
console.clear()
render_dashboard(results, config_path, console, refreshed_at=refreshed_at)
with terminal_key_mode(input_stream), quit_key_interrupt(input_stream):
while True:
started_at = time.monotonic()
results = collect_status(machines, timeout)
refreshed_at = time.strftime("%H:%M:%S")
console.clear()
render_dashboard(results, config_path, console, refreshed_at=refreshed_at)
sleep_seconds = interval - (time.monotonic() - started_at)
if sleep_seconds > 0:
time.sleep(sleep_seconds)
sleep_seconds = interval - (time.monotonic() - started_at)
if sleep_seconds > 0:
time.sleep(sleep_seconds)
def build_parser() -> argparse.ArgumentParser:

View File

@@ -1,7 +1,10 @@
import io
import signal
import threading
from pathlib import Path
from tempfile import TemporaryDirectory
import unittest
from unittest.mock import patch
from rich.console import Console
@@ -14,6 +17,8 @@ from mm.cli import (
load_machines,
parse_nvidia_smi_csv,
render_dashboard,
refresh_dashboard,
watch_quit_key,
)
@@ -105,6 +110,37 @@ class CLITests(unittest.TestCase):
self.assertEqual(args.interval, 15)
def test_watch_quit_key_sends_sigint_for_q_key(self) -> None:
with (
patch("mm.cli.read_key", return_value="q"),
patch("mm.cli.os.getpid", return_value=123),
patch("mm.cli.os.kill") as kill,
):
watch_quit_key(io.StringIO(), threading.Event())
kill.assert_called_once_with(123, signal.SIGINT)
def test_refresh_dashboard_sleeps_between_refreshes(self) -> None:
console = Console(file=io.StringIO(), record=True, color_system=None, width=120)
with (
patch("mm.cli.collect_status", return_value=[]) as collect_status,
patch("mm.cli.render_dashboard"),
patch("mm.cli.time.sleep", side_effect=KeyboardInterrupt) as sleep,
):
with self.assertRaises(KeyboardInterrupt):
refresh_dashboard(
[],
Path("list.yaml"),
timeout=1,
interval=10,
console=console,
input_stream=io.StringIO(),
)
collect_status.assert_called_once_with([], 1)
sleep.assert_called_once()
if __name__ == "__main__":
unittest.main()