Initial project scaffold
This commit is contained in:
51
tasks/03_tiled_matmul/bench.py
Normal file
51
tasks/03_tiled_matmul/bench.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Two directory levels above this file's folder; prepended to sys.path so the
# project-local imports that follow can be resolved when this file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[2]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
|
||||
|
||||
import torch
|
||||
|
||||
from kernels.triton.tiled_matmul import triton_tiled_matmul
|
||||
from reference.torch_matmul import torch_matmul
|
||||
|
||||
|
||||
def benchmark(fn, *args, warmup: int = 5, reps: int = 20) -> float:
    """Return the median wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is first called ``warmup`` times without timing, then ``reps``
    times with every call timed individually via ``time.perf_counter``.
    If any positional argument reports a truthy ``is_cuda`` attribute,
    ``torch.cuda.synchronize()`` is called before the clock starts and again
    before it stops, so pending device work is waited for inside the timed
    region instead of leaking into the next sample.

    Args:
        fn: Callable to measure.
        *args: Positional arguments forwarded unchanged to ``fn``.  May be
            empty, and may contain non-tensor values.
        warmup: Number of untimed calls made before measuring.
        reps: Number of timed calls; must be at least 1.

    Returns:
        The median of the per-call durations, in milliseconds.

    Raises:
        statistics.StatisticsError: If ``reps`` yields no samples
            (``reps < 1``).
    """
    # Decide once whether device synchronisation is required.  Inspecting
    # every argument through getattr keeps the helper usable for callables
    # with no positional arguments or with non-tensor first arguments, and
    # still syncs when the CUDA tensor is not the first argument.
    needs_sync = any(getattr(arg, "is_cuda", False) for arg in args)

    for _ in range(warmup):
        fn(*args)
    if needs_sync:
        torch.cuda.synchronize()

    times_ms = []
    for _ in range(reps):
        if needs_sync:
            # Drain any outstanding device work before starting the clock.
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if needs_sync:
            # Wait for the work issued by fn before stopping the clock.
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
|
||||
|
||||
|
||||
def main() -> None:
    """Benchmark the matmul implementations on a fixed set of shapes.

    For each ``(m, k, n)`` shape, random operands are created on CUDA when
    it is available and on the CPU otherwise.  ``torch_matmul`` is always
    timed and reported.  ``triton_tiled_matmul`` is timed only on CUDA; if
    it raises ``NotImplementedError`` or ``RuntimeError`` the shape is
    reported as skipped together with the exception message.
    """
    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    shapes = ((128, 128, 128), (512, 512, 512))

    for m, k, n in shapes:
        lhs = torch.randn(m, k, device=device)
        rhs = torch.randn(k, n, device=device)

        baseline_ms = benchmark(torch_matmul, lhs, rhs)
        print(f"torch_matmul {m}x{k}x{n}: {baseline_ms:.3f} ms")

        # The Triton kernel is only attempted on a CUDA device.
        if not on_gpu:
            continue

        try:
            kernel_ms = benchmark(triton_tiled_matmul, lhs, rhs)
            print(f"triton_tiled_matmul {m}x{k}x{n}: {kernel_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            # Best-effort: report the failure and move on to the next shape.
            print(f"triton_tiled_matmul {m}x{k}x{n}: skipped ({exc})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user