Files
mm/tests/test_cli.py

147 lines
4.4 KiB
Python

import io
import signal
import threading
from pathlib import Path
from tempfile import TemporaryDirectory
import unittest
from unittest.mock import patch
from rich.console import Console
from mm.cli import (
GPUStat,
MachineConfig,
MachineResult,
build_parser,
gpu_summary,
load_machines,
parse_nvidia_smi_csv,
render_dashboard,
refresh_dashboard,
watch_quit_key,
)
class ConfigTests(unittest.TestCase):
    """Check ``load_machines`` against list-style and mapping-style configs."""

    def _load_config(self, yaml_text: str) -> list:
        """Write *yaml_text* to a temporary ``list.yaml`` and load it.

        The config is read, and the result materialized into a list, while the
        temporary directory still exists.
        """
        with TemporaryDirectory() as tmpdir:
            config_path = Path(tmpdir) / "list.yaml"
            config_path.write_text(yaml_text, encoding="utf-8")
            return list(load_machines(config_path))

    def test_loads_string_and_object_machine_entries(self) -> None:
        """A ``machines`` list may mix bare aliases with alias/label objects."""
        # The nesting is spelled out line by line: ``label`` must be indented
        # under the ``- alias: dash1`` item, otherwise YAML reads it as a
        # top-level key and the label assertion below cannot hold.
        machines = self._load_config(
            "\n"
            "machines:\n"
            "  - dash0\n"
            "  - alias: dash1\n"
            "    label: training-1\n"
        )
        self.assertEqual([machine.alias for machine in machines], ["dash0", "dash1"])
        self.assertEqual([machine.label for machine in machines], ["dash0", "training-1"])

    def test_loads_mapping_machine_entries(self) -> None:
        """A ``machines`` mapping keyed by alias may carry an optional label."""
        # ``dash0``/``dash1`` must be children of ``machines`` and ``label`` a
        # child of ``dash1``; at column 0 they would all be top-level keys.
        machines = self._load_config(
            "\n"
            "machines:\n"
            "  dash0:\n"
            "  dash1:\n"
            "    label: training-1\n"
        )
        self.assertEqual([machine.alias for machine in machines], ["dash0", "dash1"])
        self.assertEqual([machine.label for machine in machines], ["dash0", "training-1"])
class NvidiaSmiParsingTests(unittest.TestCase):
    """Check ``parse_nvidia_smi_csv`` on a two-GPU CSV sample."""

    def test_parse_csv_output(self) -> None:
        """Two CSV rows yield two records with the asserted field values."""
        csv_rows = (
            "0, NVIDIA A100-SXM4-80GB, 12000, 81920, 84, 61, 294.3, 400.0\n",
            "1, NVIDIA A100-SXM4-80GB, 0, 81920, 0, 32, 49.8, 400.0\n",
        )
        parsed = parse_nvidia_smi_csv("".join(csv_rows))

        self.assertEqual(len(parsed), 2)

        busy_gpu = parsed[0]
        self.assertEqual(busy_gpu.index, "0")
        self.assertEqual(busy_gpu.memory_used_mib, 12000)
        self.assertEqual(busy_gpu.utilization_pct, 84)

        idle_gpu = parsed[1]
        self.assertEqual(idle_gpu.power_draw_w, 49.8)
class RenderingTests(unittest.TestCase):
    """Check the text produced by ``render_dashboard`` and ``gpu_summary``."""

    def test_render_dashboard_includes_last_refresh_time(self) -> None:
        """The recorded console output contains the supplied refresh time."""
        sink = io.StringIO()
        console = Console(file=sink, record=True, color_system=None, width=120)

        render_dashboard([], Path("list.yaml"), console, refreshed_at="12:34:56")

        rendered = console.export_text()
        self.assertIn("last refresh: 12:34:56", rendered)

    def test_gpu_summary_uses_fixed_width_decimal_metrics(self) -> None:
        """The plain text of the summary for two GPUs matches the expected line."""
        busy_gpu = GPUStat("0", "GPU", 88 * 1024, 96 * 1024, 0, None, None, None)
        idle_gpu = GPUStat("1", "GPU", 0, 96 * 1024, 0, None, None, None)
        result = MachineResult(
            machine=MachineConfig(alias="dash0", label="dash0"),
            gpus=(busy_gpu, idle_gpu),
        )

        # NOTE(review): if this file came through a whitespace-collapsing copy,
        # runs of padding spaces in the expected text may have been lost --
        # confirm against gpu_summary's actual output.
        expected = "0 88.0/ 96.0G 0.0% 1 0.0/ 96.0G 0.0%"
        self.assertEqual(gpu_summary(result).plain, expected)
class CLITests(unittest.TestCase):
    """Exercise argument parsing, the quit-key watcher and the refresh loop."""

    def test_interval_flag_sets_refresh_interval(self) -> None:
        """``--interval 15`` produces an ``interval`` attribute equal to 15."""
        parsed = build_parser().parse_args(["--interval", "15"])
        self.assertEqual(parsed.interval, 15)

    def test_watch_quit_key_sends_sigint_for_q_key(self) -> None:
        """With ``read_key`` returning "q", SIGINT is sent to the patched pid."""
        key_patch = patch("mm.cli.read_key", return_value="q")
        pid_patch = patch("mm.cli.os.getpid", return_value=123)
        kill_patch = patch("mm.cli.os.kill")

        with key_patch, pid_patch, kill_patch as kill:
            watch_quit_key(io.StringIO(), threading.Event())

        kill.assert_called_once_with(123, signal.SIGINT)

    def test_refresh_dashboard_sleeps_between_refreshes(self) -> None:
        """One status collection happens before the first (interrupted) sleep."""
        console = Console(
            file=io.StringIO(),
            record=True,
            color_system=None,
            width=120,
        )
        status_patch = patch("mm.cli.collect_status", return_value=[])
        render_patch = patch("mm.cli.render_dashboard")
        # The patched sleep raises KeyboardInterrupt, which the test expects to
        # propagate out of refresh_dashboard.
        sleep_patch = patch("mm.cli.time.sleep", side_effect=KeyboardInterrupt)

        with status_patch as collect_status, render_patch, sleep_patch as sleep:
            with self.assertRaises(KeyboardInterrupt):
                refresh_dashboard(
                    [],
                    Path("list.yaml"),
                    timeout=1,
                    interval=10,
                    console=console,
                    input_stream=io.StringIO(),
                )

        collect_status.assert_called_once_with([], 1)
        sleep.assert_called_once()
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()