From 214dddced40c394e829aa086a7068090bcee31b8 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 25 Jun 2026 15:26:18 +0800 Subject: [PATCH] Add interval refresh mode --- README.md | 3 +++ mm/cli.py | 42 +++++++++++++++++++++++++++++++++++++++++- tests/test_cli.py | 24 ++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a9f1eb9..d61ab38 100644 --- a/README.md +++ b/README.md @@ -31,4 +31,7 @@ available on the remote machine. mm mm --config ~/.config/mm/list.yaml mm --timeout 8 +mm --interval 15 ``` + +Interval mode shows the last refresh time in the header. diff --git a/mm/cli.py b/mm/cli.py index 86dc92d..931cd26 100644 --- a/mm/cli.py +++ b/mm/cli.py @@ -5,6 +5,7 @@ import csv import os import subprocess import sys +import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path @@ -266,8 +267,12 @@ def render_dashboard( results: Sequence[MachineResult], config_path: Path, console: Console, + refreshed_at: str | None = None, ) -> None: - console.print(Text(f"GPU status from {config_path}", style="dim")) + title = f"GPU status from {config_path}" + if refreshed_at: + title = f"{title} (last refresh: {refreshed_at})" + console.print(Text(title, style="dim")) console.print(render_compact_table(results)) @@ -364,6 +369,25 @@ def format_memory_pair(gpu: GPUStat) -> str: return f"{used:5.1f}/{total:5.1f}M" +def refresh_dashboard( + machines: Sequence[MachineConfig], + config_path: Path, + timeout: float, + interval: float, + console: Console, +) -> None: + while True: + started_at = time.monotonic() + results = collect_status(machines, timeout) + refreshed_at = time.strftime("%H:%M:%S") + console.clear() + render_dashboard(results, config_path, console, refreshed_at=refreshed_at) + + sleep_seconds = interval - (time.monotonic() - started_at) + if sleep_seconds > 0: + time.sleep(sleep_seconds) + + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="mm", @@ -381,6 +405,12 @@ def build_parser() -> argparse.ArgumentParser: default=10.0, help="SSH timeout per machine in seconds. Default: 10.", ) + parser.add_argument( + "-i", + "--interval", + type=float, + help="Refresh every N seconds until interrupted.", + ) parser.add_argument("--no-color", action="store_true", help="Disable terminal colors.") parser.add_argument("--version", action="version", version=f"mm {__version__}") return parser @@ -394,6 +424,9 @@ def run(argv: Sequence[str] | None = None) -> int: if args.timeout <= 0: console.print("[red]error:[/] --timeout must be greater than 0") return 2 + if args.interval is not None and args.interval <= 0: + console.print("[red]error:[/] --interval must be greater than 0") + return 2 try: config_path = resolve_config_path(args.config) @@ -402,6 +435,13 @@ def run(argv: Sequence[str] | None = None) -> int: console.print(f"[red]error:[/] {exc}") return 2 + if args.interval is not None: + try: + refresh_dashboard(machines, config_path, args.timeout, args.interval, console) + except KeyboardInterrupt: + console.print() + return 0 + results = collect_status(machines, args.timeout) render_dashboard(results, config_path, console) return 1 if any(result.error for result in results) else 0 diff --git a/tests/test_cli.py b/tests/test_cli.py index 9390943..73ddd7a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,14 +1,19 @@ +import io from pathlib import Path from tempfile import TemporaryDirectory import unittest +from rich.console import Console + from mm.cli import ( GPUStat, MachineConfig, MachineResult, + build_parser, gpu_summary, load_machines, parse_nvidia_smi_csv, + render_dashboard, ) @@ -67,6 +72,18 @@ class NvidiaSmiParsingTests(unittest.TestCase): class RenderingTests(unittest.TestCase): + def test_render_dashboard_includes_last_refresh_time(self) -> None: + console = Console( + file=io.StringIO(), + record=True, + color_system=None, + width=120, + ) + + render_dashboard([], Path("list.yaml"), console, refreshed_at="12:34:56") + + self.assertIn("last refresh: 12:34:56", console.export_text()) + def test_gpu_summary_uses_fixed_width_decimal_metrics(self) -> None: result = MachineResult( machine=MachineConfig(alias="dash0", label="dash0"), @@ -82,5 +99,12 @@ class RenderingTests(unittest.TestCase): ) +class CLITests(unittest.TestCase): + def test_interval_flag_sets_refresh_interval(self) -> None: + args = build_parser().parse_args(["--interval", "15"]) + + self.assertEqual(args.interval, 15) + + if __name__ == "__main__": unittest.main()