Support q exit in interval mode

This commit is contained in:
2026-06-25 16:51:23 +08:00
parent 214dddced4
commit 766f5ade82
2 changed files with 108 additions and 10 deletions

View File

@@ -3,13 +3,17 @@ from __future__ import annotations
import argparse import argparse
import csv import csv
import os import os
import select
import signal
import subprocess import subprocess
import sys import sys
import threading
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Sequence from typing import Any, Iterator, Sequence
import yaml import yaml
from rich import box from rich import box
@@ -369,23 +373,81 @@ def format_memory_pair(gpu: GPUStat) -> str:
return f"{used:5.1f}/{total:5.1f}M" return f"{used:5.1f}/{total:5.1f}M"
@contextmanager
def terminal_key_mode(input_stream: Any) -> Iterator[None]:
if not input_stream.isatty():
yield
return
import termios
import tty
fd = input_stream.fileno()
original_attrs = termios.tcgetattr(fd)
try:
tty.setcbreak(fd)
yield
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, original_attrs)
def read_key(input_stream: Any, timeout: float) -> str | None:
readable, _, _ = select.select([input_stream], [], [], timeout)
if not readable:
return None
return input_stream.read(1)
def watch_quit_key(input_stream: Any, stop_event: threading.Event) -> None:
while not stop_event.is_set():
try:
key = read_key(input_stream, 0.1)
except OSError:
return
if key and key.lower() == "q":
os.kill(os.getpid(), signal.SIGINT)
return
@contextmanager
def quit_key_interrupt(input_stream: Any) -> Iterator[None]:
if not input_stream.isatty():
yield
return
stop_event = threading.Event()
thread = threading.Thread(
target=watch_quit_key,
args=(input_stream, stop_event),
daemon=True,
)
thread.start()
try:
yield
finally:
stop_event.set()
thread.join(timeout=0.2)
def refresh_dashboard( def refresh_dashboard(
machines: Sequence[MachineConfig], machines: Sequence[MachineConfig],
config_path: Path, config_path: Path,
timeout: float, timeout: float,
interval: float, interval: float,
console: Console, console: Console,
input_stream: Any = sys.stdin,
) -> None: ) -> None:
while True: with terminal_key_mode(input_stream), quit_key_interrupt(input_stream):
started_at = time.monotonic() while True:
results = collect_status(machines, timeout) started_at = time.monotonic()
refreshed_at = time.strftime("%H:%M:%S") results = collect_status(machines, timeout)
console.clear() refreshed_at = time.strftime("%H:%M:%S")
render_dashboard(results, config_path, console, refreshed_at=refreshed_at) console.clear()
render_dashboard(results, config_path, console, refreshed_at=refreshed_at)
sleep_seconds = interval - (time.monotonic() - started_at) sleep_seconds = interval - (time.monotonic() - started_at)
if sleep_seconds > 0: if sleep_seconds > 0:
time.sleep(sleep_seconds) time.sleep(sleep_seconds)
def build_parser() -> argparse.ArgumentParser: def build_parser() -> argparse.ArgumentParser:

View File

@@ -1,7 +1,10 @@
import io import io
import signal
import threading
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import unittest import unittest
from unittest.mock import patch
from rich.console import Console from rich.console import Console
@@ -14,6 +17,8 @@ from mm.cli import (
load_machines, load_machines,
parse_nvidia_smi_csv, parse_nvidia_smi_csv,
render_dashboard, render_dashboard,
refresh_dashboard,
watch_quit_key,
) )
@@ -105,6 +110,37 @@ class CLITests(unittest.TestCase):
self.assertEqual(args.interval, 15) self.assertEqual(args.interval, 15)
def test_watch_quit_key_sends_sigint_for_q_key(self) -> None:
with (
patch("mm.cli.read_key", return_value="q"),
patch("mm.cli.os.getpid", return_value=123),
patch("mm.cli.os.kill") as kill,
):
watch_quit_key(io.StringIO(), threading.Event())
kill.assert_called_once_with(123, signal.SIGINT)
def test_refresh_dashboard_sleeps_between_refreshes(self) -> None:
console = Console(file=io.StringIO(), record=True, color_system=None, width=120)
with (
patch("mm.cli.collect_status", return_value=[]) as collect_status,
patch("mm.cli.render_dashboard"),
patch("mm.cli.time.sleep", side_effect=KeyboardInterrupt) as sleep,
):
with self.assertRaises(KeyboardInterrupt):
refresh_dashboard(
[],
Path("list.yaml"),
timeout=1,
interval=10,
console=console,
input_stream=io.StringIO(),
)
collect_status.assert_called_once_with([], 1)
sleep.assert_called_once()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()