from __future__ import annotations import statistics import sys import time from pathlib import Path ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import torch from kernels.triton.tiled_matmul import triton_tiled_matmul from reference.torch_matmul import torch_matmul def benchmark(fn, *args, warmup: int = 5, reps: int = 20) -> float: for _ in range(warmup): fn(*args) if args[0].is_cuda: torch.cuda.synchronize() times_ms = [] for _ in range(reps): if args[0].is_cuda: torch.cuda.synchronize() start = time.perf_counter() fn(*args) if args[0].is_cuda: torch.cuda.synchronize() times_ms.append((time.perf_counter() - start) * 1e3) return statistics.median(times_ms) def main() -> None: device = "cuda" if torch.cuda.is_available() else "cpu" for m, k, n in [(128, 128, 128), (512, 512, 512)]: a = torch.randn(m, k, device=device) b = torch.randn(k, n, device=device) ref_ms = benchmark(torch_matmul, a, b) print(f"torch_matmul {m}x{k}x{n}: {ref_ms:.3f} ms") if device == "cuda": try: triton_ms = benchmark(triton_tiled_matmul, a, b) print(f"triton_tiled_matmul {m}x{k}x{n}: {triton_ms:.3f} ms") except (NotImplementedError, RuntimeError) as exc: print(f"triton_tiled_matmul {m}x{k}x{n}: skipped ({exc})") if __name__ == "__main__": main()