Support q exit in interval mode
This commit is contained in:
64
mm/cli.py
64
mm/cli.py
@@ -3,13 +3,17 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence
|
||||
from typing import Any, Iterator, Sequence
|
||||
|
||||
import yaml
|
||||
from rich import box
|
||||
@@ -369,13 +373,71 @@ def format_memory_pair(gpu: GPUStat) -> str:
|
||||
return f"{used:5.1f}/{total:5.1f}M"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def terminal_key_mode(input_stream: Any) -> Iterator[None]:
|
||||
if not input_stream.isatty():
|
||||
yield
|
||||
return
|
||||
|
||||
import termios
|
||||
import tty
|
||||
|
||||
fd = input_stream.fileno()
|
||||
original_attrs = termios.tcgetattr(fd)
|
||||
try:
|
||||
tty.setcbreak(fd)
|
||||
yield
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, original_attrs)
|
||||
|
||||
|
||||
def read_key(input_stream: Any, timeout: float) -> str | None:
|
||||
readable, _, _ = select.select([input_stream], [], [], timeout)
|
||||
if not readable:
|
||||
return None
|
||||
return input_stream.read(1)
|
||||
|
||||
|
||||
def watch_quit_key(input_stream: Any, stop_event: threading.Event) -> None:
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
key = read_key(input_stream, 0.1)
|
||||
except OSError:
|
||||
return
|
||||
if key and key.lower() == "q":
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
return
|
||||
|
||||
|
||||
@contextmanager
|
||||
def quit_key_interrupt(input_stream: Any) -> Iterator[None]:
|
||||
if not input_stream.isatty():
|
||||
yield
|
||||
return
|
||||
|
||||
stop_event = threading.Event()
|
||||
thread = threading.Thread(
|
||||
target=watch_quit_key,
|
||||
args=(input_stream, stop_event),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
stop_event.set()
|
||||
thread.join(timeout=0.2)
|
||||
|
||||
|
||||
def refresh_dashboard(
|
||||
machines: Sequence[MachineConfig],
|
||||
config_path: Path,
|
||||
timeout: float,
|
||||
interval: float,
|
||||
console: Console,
|
||||
input_stream: Any = sys.stdin,
|
||||
) -> None:
|
||||
with terminal_key_mode(input_stream), quit_key_interrupt(input_stream):
|
||||
while True:
|
||||
started_at = time.monotonic()
|
||||
results = collect_status(machines, timeout)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import io
|
||||
import signal
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
@@ -14,6 +17,8 @@ from mm.cli import (
|
||||
load_machines,
|
||||
parse_nvidia_smi_csv,
|
||||
render_dashboard,
|
||||
refresh_dashboard,
|
||||
watch_quit_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -105,6 +110,37 @@ class CLITests(unittest.TestCase):
|
||||
|
||||
self.assertEqual(args.interval, 15)
|
||||
|
||||
def test_watch_quit_key_sends_sigint_for_q_key(self) -> None:
|
||||
with (
|
||||
patch("mm.cli.read_key", return_value="q"),
|
||||
patch("mm.cli.os.getpid", return_value=123),
|
||||
patch("mm.cli.os.kill") as kill,
|
||||
):
|
||||
watch_quit_key(io.StringIO(), threading.Event())
|
||||
|
||||
kill.assert_called_once_with(123, signal.SIGINT)
|
||||
|
||||
def test_refresh_dashboard_sleeps_between_refreshes(self) -> None:
|
||||
console = Console(file=io.StringIO(), record=True, color_system=None, width=120)
|
||||
|
||||
with (
|
||||
patch("mm.cli.collect_status", return_value=[]) as collect_status,
|
||||
patch("mm.cli.render_dashboard"),
|
||||
patch("mm.cli.time.sleep", side_effect=KeyboardInterrupt) as sleep,
|
||||
):
|
||||
with self.assertRaises(KeyboardInterrupt):
|
||||
refresh_dashboard(
|
||||
[],
|
||||
Path("list.yaml"),
|
||||
timeout=1,
|
||||
interval=10,
|
||||
console=console,
|
||||
input_stream=io.StringIO(),
|
||||
)
|
||||
|
||||
collect_status.assert_called_once_with([], 1)
|
||||
sleep.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user