Add interval refresh mode

This commit is contained in:
2026-06-25 15:26:18 +08:00
parent b8f007fd5c
commit 214dddced4
3 changed files with 68 additions and 1 deletions

View File

@@ -31,4 +31,7 @@ available on the remote machine.
mm mm
mm --config ~/.config/mm/list.yaml mm --config ~/.config/mm/list.yaml
mm --timeout 8 mm --timeout 8
mm --interval 15
``` ```
Interval mode shows the last refresh time in the header.

View File

@@ -5,6 +5,7 @@ import csv
import os import os
import subprocess import subprocess
import sys import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -266,8 +267,12 @@ def render_dashboard(
results: Sequence[MachineResult], results: Sequence[MachineResult],
config_path: Path, config_path: Path,
console: Console, console: Console,
refreshed_at: str | None = None,
) -> None: ) -> None:
console.print(Text(f"GPU status from {config_path}", style="dim")) title = f"GPU status from {config_path}"
if refreshed_at:
title = f"{title} (last refresh: {refreshed_at})"
console.print(Text(title, style="dim"))
console.print(render_compact_table(results)) console.print(render_compact_table(results))
@@ -364,6 +369,25 @@ def format_memory_pair(gpu: GPUStat) -> str:
return f"{used:5.1f}/{total:5.1f}M" return f"{used:5.1f}/{total:5.1f}M"
def refresh_dashboard(
machines: Sequence[MachineConfig],
config_path: Path,
timeout: float,
interval: float,
console: Console,
) -> None:
while True:
started_at = time.monotonic()
results = collect_status(machines, timeout)
refreshed_at = time.strftime("%H:%M:%S")
console.clear()
render_dashboard(results, config_path, console, refreshed_at=refreshed_at)
sleep_seconds = interval - (time.monotonic() - started_at)
if sleep_seconds > 0:
time.sleep(sleep_seconds)
def build_parser() -> argparse.ArgumentParser: def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="mm", prog="mm",
@@ -381,6 +405,12 @@ def build_parser() -> argparse.ArgumentParser:
default=10.0, default=10.0,
help="SSH timeout per machine in seconds. Default: 10.", help="SSH timeout per machine in seconds. Default: 10.",
) )
parser.add_argument(
"-i",
"--interval",
type=float,
help="Refresh every N seconds until interrupted.",
)
parser.add_argument("--no-color", action="store_true", help="Disable terminal colors.") parser.add_argument("--no-color", action="store_true", help="Disable terminal colors.")
parser.add_argument("--version", action="version", version=f"mm {__version__}") parser.add_argument("--version", action="version", version=f"mm {__version__}")
return parser return parser
@@ -394,6 +424,9 @@ def run(argv: Sequence[str] | None = None) -> int:
if args.timeout <= 0: if args.timeout <= 0:
console.print("[red]error:[/] --timeout must be greater than 0") console.print("[red]error:[/] --timeout must be greater than 0")
return 2 return 2
if args.interval is not None and args.interval <= 0:
console.print("[red]error:[/] --interval must be greater than 0")
return 2
try: try:
config_path = resolve_config_path(args.config) config_path = resolve_config_path(args.config)
@@ -402,6 +435,13 @@ def run(argv: Sequence[str] | None = None) -> int:
console.print(f"[red]error:[/] {exc}") console.print(f"[red]error:[/] {exc}")
return 2 return 2
if args.interval is not None:
try:
refresh_dashboard(machines, config_path, args.timeout, args.interval, console)
except KeyboardInterrupt:
console.print()
return 0
results = collect_status(machines, args.timeout) results = collect_status(machines, args.timeout)
render_dashboard(results, config_path, console) render_dashboard(results, config_path, console)
return 1 if any(result.error for result in results) else 0 return 1 if any(result.error for result in results) else 0

View File

@@ -1,14 +1,19 @@
import io
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import unittest import unittest
from rich.console import Console
from mm.cli import ( from mm.cli import (
GPUStat, GPUStat,
MachineConfig, MachineConfig,
MachineResult, MachineResult,
build_parser,
gpu_summary, gpu_summary,
load_machines, load_machines,
parse_nvidia_smi_csv, parse_nvidia_smi_csv,
render_dashboard,
) )
@@ -67,6 +72,18 @@ class NvidiaSmiParsingTests(unittest.TestCase):
class RenderingTests(unittest.TestCase): class RenderingTests(unittest.TestCase):
def test_render_dashboard_includes_last_refresh_time(self) -> None:
console = Console(
file=io.StringIO(),
record=True,
color_system=None,
width=120,
)
render_dashboard([], Path("list.yaml"), console, refreshed_at="12:34:56")
self.assertIn("last refresh: 12:34:56", console.export_text())
def test_gpu_summary_uses_fixed_width_decimal_metrics(self) -> None: def test_gpu_summary_uses_fixed_width_decimal_metrics(self) -> None:
result = MachineResult( result = MachineResult(
machine=MachineConfig(alias="dash0", label="dash0"), machine=MachineConfig(alias="dash0", label="dash0"),
@@ -82,5 +99,12 @@ class RenderingTests(unittest.TestCase):
) )
class CLITests(unittest.TestCase):
def test_interval_flag_sets_refresh_interval(self) -> None:
args = build_parser().parse_args(["--interval", "15"])
self.assertEqual(args.interval, 15)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()