111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
import io
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
import unittest
|
|
|
|
from rich.console import Console
|
|
|
|
from mm.cli import (
|
|
GPUStat,
|
|
MachineConfig,
|
|
MachineResult,
|
|
build_parser,
|
|
gpu_summary,
|
|
load_machines,
|
|
parse_nvidia_smi_csv,
|
|
render_dashboard,
|
|
)
|
|
|
|
|
|
class ConfigTests(unittest.TestCase):
    """Checks for ``load_machines`` against YAML machine lists written to disk."""

    def test_loads_string_and_object_machine_entries(self) -> None:
        """A ``machines`` sequence may mix bare aliases with alias/label objects.

        The bare-string entry is expected to come back with its alias reused
        as the label.
        """
        # NOTE(review): the nesting of this YAML document was reconstructed;
        # confirm it matches the intended fixture.
        document = """
machines:
  - dash0
  - alias: dash1
    label: training-1
"""
        with TemporaryDirectory() as workdir:
            yaml_file = Path(workdir) / "list.yaml"
            yaml_file.write_text(document, encoding="utf-8")

            loaded = load_machines(yaml_file)

            aliases = [entry.alias for entry in loaded]
            self.assertEqual(aliases, ["dash0", "dash1"])
            labels = [entry.label for entry in loaded]
            self.assertEqual(labels, ["dash0", "training-1"])

    def test_loads_mapping_machine_entries(self) -> None:
        """A ``machines`` mapping keyed by alias is accepted as well.

        An entry with no ``label`` is expected to report its alias as the
        label.
        """
        # NOTE(review): the nesting of this YAML document was reconstructed;
        # confirm it matches the intended fixture.
        document = """
machines:
  dash0:
  dash1:
    label: training-1
"""
        with TemporaryDirectory() as workdir:
            yaml_file = Path(workdir) / "list.yaml"
            yaml_file.write_text(document, encoding="utf-8")

            loaded = load_machines(yaml_file)

            aliases = [entry.alias for entry in loaded]
            self.assertEqual(aliases, ["dash0", "dash1"])
            labels = [entry.label for entry in loaded]
            self.assertEqual(labels, ["dash0", "training-1"])
|
|
|
|
|
|
class NvidiaSmiParsingTests(unittest.TestCase):
    """Checks for ``parse_nvidia_smi_csv``."""

    def test_parse_csv_output(self) -> None:
        """Two CSV rows yield two GPU records carrying the asserted field values."""
        rows = [
            "0, NVIDIA A100-SXM4-80GB, 12000, 81920, 84, 61, 294.3, 400.0\n",
            "1, NVIDIA A100-SXM4-80GB, 0, 81920, 0, 32, 49.8, 400.0\n",
        ]

        parsed = parse_nvidia_smi_csv("".join(rows))

        # Check the count first so a short result fails with a clear message
        # rather than an IndexError below.
        self.assertEqual(len(parsed), 2)

        first_gpu = parsed[0]
        self.assertEqual(first_gpu.index, "0")
        self.assertEqual(first_gpu.memory_used_mib, 12000)
        self.assertEqual(first_gpu.utilization_pct, 84)

        second_gpu = parsed[1]
        self.assertEqual(second_gpu.power_draw_w, 49.8)
|
|
|
|
|
|
class RenderingTests(unittest.TestCase):
    """Checks for the text produced by ``render_dashboard`` and ``gpu_summary``."""

    def test_render_dashboard_includes_last_refresh_time(self) -> None:
        """The ``refreshed_at`` value appears in the recorded console output."""
        sink = io.StringIO()
        # Recording with colours disabled lets the output be exported and
        # compared as plain text.
        console = Console(file=sink, record=True, color_system=None, width=120)

        render_dashboard([], Path("list.yaml"), console, refreshed_at="12:34:56")

        rendered = console.export_text()
        self.assertIn("last refresh: 12:34:56", rendered)

    def test_gpu_summary_uses_fixed_width_decimal_metrics(self) -> None:
        """``gpu_summary(...).plain`` matches the expected one-line summary."""
        machine = MachineConfig(alias="dash0", label="dash0")
        busy_gpu = GPUStat("0", "GPU", 88 * 1024, 96 * 1024, 0, None, None, None)
        idle_gpu = GPUStat("1", "GPU", 0, 96 * 1024, 0, None, None, None)
        result = MachineResult(machine=machine, gpus=(busy_gpu, idle_gpu))

        # NOTE(review): confirm the exact spacing of this expected string
        # against the fixed-width output of gpu_summary.
        expected = "0 88.0/ 96.0G 0.0% 1 0.0/ 96.0G 0.0%"
        self.assertEqual(gpu_summary(result).plain, expected)
|
|
|
|
|
|
class CLITests(unittest.TestCase):
    """Checks for the argument parser returned by ``build_parser``."""

    def test_interval_flag_sets_refresh_interval(self) -> None:
        """``--interval 15`` is exposed as ``interval == 15`` on the parsed namespace."""
        parser = build_parser()
        namespace = parser.parse_args(["--interval", "15"])

        self.assertEqual(namespace.interval, 15)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases when this module is executed directly as a script.
    unittest.main()
|