51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
import statistics
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Make the repository root importable so the `kernels` and `reference`
# packages resolve when this file is run directly as a script.
# NOTE(review): parents[2] assumes this file sits two directories below
# the repo root — confirm if the file is ever moved.
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
|
|
import torch
|
|
|
|
from kernels.triton.vector_add import triton_vector_add
|
|
from reference.torch_vector_add import torch_vector_add
|
|
|
|
|
|
def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    """Return the median wall-clock time of one ``fn(*args)`` call, in ms.

    Performs ``warmup`` untimed calls first, then times ``reps`` calls with
    ``time.perf_counter`` and reports the median, which is robust against
    outliers (lazy initialization, OS scheduling jitter).

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        warmup: Number of untimed warmup calls.
        reps: Number of timed repetitions.

    Returns:
        Median elapsed time per call, in milliseconds.
    """
    # Hoist the loop-invariant CUDA check out of the timing loop, and guard
    # against no args / non-tensor first arg (the original indexed
    # args[0].is_cuda unconditionally, raising IndexError/AttributeError).
    use_cuda = bool(args) and getattr(args[0], "is_cuda", False)
    for _ in range(warmup):
        fn(*args)
    if use_cuda:
        # CUDA kernels launch asynchronously; synchronize so pending warmup
        # work does not bleed into the first timed iteration.
        torch.cuda.synchronize()
    times_ms: list[float] = []
    for _ in range(reps):
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if use_cuda:
            # Wait for the kernel to finish so the measurement covers the
            # actual device work, not just the launch.
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
|
|
|
|
|
|
def main() -> None:
    """Benchmark the reference and Triton vector-add implementations."""
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    size = 1 << 20  # 1M-element vectors
    x = torch.randn(size, device=device)
    y = torch.randn(size, device=device)

    ref_ms = benchmark(torch_vector_add, x, y)
    print(f"torch_vector_add: {ref_ms:.3f} ms")

    # The Triton kernel requires a GPU; nothing more to do on CPU.
    if not use_gpu:
        return
    try:
        triton_ms = benchmark(triton_vector_add, x, y)
    except (NotImplementedError, RuntimeError) as exc:
        print(f"triton_vector_add: skipped ({exc})")
    else:
        print(f"triton_vector_add: {triton_ms:.3f} ms")
|
|
|
|
|
|
# Entry point when executed directly as a script.
if __name__ == "__main__":
    main()
|