Support q exit in interval mode
This commit is contained in:
82
mm/cli.py
82
mm/cli.py
@@ -3,13 +3,17 @@ from __future__ import annotations
|
|||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
|
import select
|
||||||
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Sequence
|
from typing import Any, Iterator, Sequence
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from rich import box
|
from rich import box
|
||||||
@@ -369,23 +373,81 @@ def format_memory_pair(gpu: GPUStat) -> str:
|
|||||||
return f"{used:5.1f}/{total:5.1f}M"
|
return f"{used:5.1f}/{total:5.1f}M"
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def terminal_key_mode(input_stream: Any) -> Iterator[None]:
|
||||||
|
if not input_stream.isatty():
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
import termios
|
||||||
|
import tty
|
||||||
|
|
||||||
|
fd = input_stream.fileno()
|
||||||
|
original_attrs = termios.tcgetattr(fd)
|
||||||
|
try:
|
||||||
|
tty.setcbreak(fd)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
termios.tcsetattr(fd, termios.TCSADRAIN, original_attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def read_key(input_stream: Any, timeout: float) -> str | None:
|
||||||
|
readable, _, _ = select.select([input_stream], [], [], timeout)
|
||||||
|
if not readable:
|
||||||
|
return None
|
||||||
|
return input_stream.read(1)
|
||||||
|
|
||||||
|
|
||||||
|
def watch_quit_key(input_stream: Any, stop_event: threading.Event) -> None:
|
||||||
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
|
key = read_key(input_stream, 0.1)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
if key and key.lower() == "q":
|
||||||
|
os.kill(os.getpid(), signal.SIGINT)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def quit_key_interrupt(input_stream: Any) -> Iterator[None]:
|
||||||
|
if not input_stream.isatty():
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
stop_event = threading.Event()
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=watch_quit_key,
|
||||||
|
args=(input_stream, stop_event),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
thread.join(timeout=0.2)
|
||||||
|
|
||||||
|
|
||||||
def refresh_dashboard(
|
def refresh_dashboard(
|
||||||
machines: Sequence[MachineConfig],
|
machines: Sequence[MachineConfig],
|
||||||
config_path: Path,
|
config_path: Path,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
interval: float,
|
interval: float,
|
||||||
console: Console,
|
console: Console,
|
||||||
|
input_stream: Any = sys.stdin,
|
||||||
) -> None:
|
) -> None:
|
||||||
while True:
|
with terminal_key_mode(input_stream), quit_key_interrupt(input_stream):
|
||||||
started_at = time.monotonic()
|
while True:
|
||||||
results = collect_status(machines, timeout)
|
started_at = time.monotonic()
|
||||||
refreshed_at = time.strftime("%H:%M:%S")
|
results = collect_status(machines, timeout)
|
||||||
console.clear()
|
refreshed_at = time.strftime("%H:%M:%S")
|
||||||
render_dashboard(results, config_path, console, refreshed_at=refreshed_at)
|
console.clear()
|
||||||
|
render_dashboard(results, config_path, console, refreshed_at=refreshed_at)
|
||||||
|
|
||||||
sleep_seconds = interval - (time.monotonic() - started_at)
|
sleep_seconds = interval - (time.monotonic() - started_at)
|
||||||
if sleep_seconds > 0:
|
if sleep_seconds > 0:
|
||||||
time.sleep(sleep_seconds)
|
time.sleep(sleep_seconds)
|
||||||
|
|
||||||
|
|
||||||
def build_parser() -> argparse.ArgumentParser:
|
def build_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
import io
|
import io
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
@@ -14,6 +17,8 @@ from mm.cli import (
|
|||||||
load_machines,
|
load_machines,
|
||||||
parse_nvidia_smi_csv,
|
parse_nvidia_smi_csv,
|
||||||
render_dashboard,
|
render_dashboard,
|
||||||
|
refresh_dashboard,
|
||||||
|
watch_quit_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -105,6 +110,37 @@ class CLITests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(args.interval, 15)
|
self.assertEqual(args.interval, 15)
|
||||||
|
|
||||||
|
def test_watch_quit_key_sends_sigint_for_q_key(self) -> None:
|
||||||
|
with (
|
||||||
|
patch("mm.cli.read_key", return_value="q"),
|
||||||
|
patch("mm.cli.os.getpid", return_value=123),
|
||||||
|
patch("mm.cli.os.kill") as kill,
|
||||||
|
):
|
||||||
|
watch_quit_key(io.StringIO(), threading.Event())
|
||||||
|
|
||||||
|
kill.assert_called_once_with(123, signal.SIGINT)
|
||||||
|
|
||||||
|
def test_refresh_dashboard_sleeps_between_refreshes(self) -> None:
|
||||||
|
console = Console(file=io.StringIO(), record=True, color_system=None, width=120)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("mm.cli.collect_status", return_value=[]) as collect_status,
|
||||||
|
patch("mm.cli.render_dashboard"),
|
||||||
|
patch("mm.cli.time.sleep", side_effect=KeyboardInterrupt) as sleep,
|
||||||
|
):
|
||||||
|
with self.assertRaises(KeyboardInterrupt):
|
||||||
|
refresh_dashboard(
|
||||||
|
[],
|
||||||
|
Path("list.yaml"),
|
||||||
|
timeout=1,
|
||||||
|
interval=10,
|
||||||
|
console=console,
|
||||||
|
input_stream=io.StringIO(),
|
||||||
|
)
|
||||||
|
|
||||||
|
collect_status.assert_called_once_with([], 1)
|
||||||
|
sleep.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user