Files
kernel-lab/bench/bench_vector_add.py
2026-04-10 13:15:06 +00:00

75 lines
2.4 KiB
Python

from __future__ import annotations
import argparse
import statistics
import sys
import time
from pathlib import Path
# Project root = the parent of this file's directory.  It is placed at the
# front of sys.path (once) so the project-local packages imported further
# down resolve when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if not sys.path.count(str(ROOT)):
    sys.path.insert(0, str(ROOT))
import torch
from kernels.triton.vector_add import triton_vector_add
from reference.torch_vector_add import torch_vector_add
from tools.lab_extension import build_extension
def benchmark(fn, *args, warmup: int = 5, reps: int = 30) -> float:
    """Return the median wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is first called ``warmup`` times without timing, then ``reps``
    times with ``time.perf_counter`` read around each call.  If any positional
    argument reports ``is_cuda``, ``torch.cuda.synchronize()`` is called after
    the warmup and immediately before and after every timed call, so the
    timed region waits for outstanding GPU work.

    Args:
        fn: Callable under test; invoked as ``fn(*args)``.
        *args: Positional arguments forwarded to ``fn``.  May be empty and may
            contain non-tensor values.
        warmup: Number of untimed calls made before measuring.
        reps: Number of timed calls; must be at least 1.

    Returns:
        The median of the per-call timings, in milliseconds.

    Raises:
        statistics.StatisticsError: If ``reps`` is less than 1, because there
            are no samples to take a median of.
    """
    # Loop-invariant: decide once whether CUDA synchronization is required.
    # getattr() keeps this safe for empty ``args`` and for non-tensor values.
    needs_sync = any(getattr(arg, "is_cuda", False) for arg in args)

    for _ in range(warmup):
        fn(*args)
    if needs_sync:
        torch.cuda.synchronize()

    times_ms = []
    for _ in range(reps):
        # Drain pending GPU work so it is not billed to this sample.
        if needs_sync:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        # Wait for the work launched by fn before stopping the clock.
        if needs_sync:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
def report(name: str, elapsed_ms: float, x: torch.Tensor) -> None:
    """Print one benchmark result line with its effective memory bandwidth.

    Args:
        name: Label printed at the start of the line.
        elapsed_ms: Measured time per call, in milliseconds.
        x: Tensor whose element count and element size define the byte total.

    The byte total is three times the size of ``x`` (presumably two operand
    reads plus one result write of the same size -- confirm against the
    kernels being measured).  An ``elapsed_ms`` of exactly zero is reported as
    infinite bandwidth instead of raising ``ZeroDivisionError``.
    """
    bytes_moved = 3 * x.numel() * x.element_size()
    if elapsed_ms == 0:
        # A zero timing carries no rate information; avoid dividing by zero.
        gbps = float("inf")
    else:
        # ms -> s, then bytes/s -> GB/s (decimal gigabytes).
        gbps = bytes_moved / (elapsed_ms * 1e-3) / 1e9
    print(f"{name}: {elapsed_ms:.3f} ms | effective bandwidth {gbps:.2f} GB/s")
def main() -> None:
    """Parse command-line options and benchmark the selected implementations.

    Options:
        --device: Torch device string for the input tensors.  Defaults to
            "cuda" when ``torch.cuda.is_available()`` is true, else "cpu".
        --mode: Which implementation(s) to run: "all", "torch", "triton" or
            "cuda".  Defaults to "all".
        --numel: Number of elements in each input vector (default ``1 << 24``).

    The Triton and CUDA-extension benchmarks run only when the inputs live on
    a CUDA device; otherwise a "skipped" line is printed for each one that was
    selected.  Failures in those two benchmarks are reported as "skipped"
    lines rather than aborting the run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--numel", type=int, default=1 << 24)
    args = parser.parse_args()

    x = torch.randn(args.numel, device=args.device)
    y = torch.randn(args.numel, device=args.device)
    # Check where the tensor actually lives instead of comparing the raw
    # option string, so indexed devices such as "cuda:0" still count as CUDA.
    on_cuda = x.is_cuda

    if args.mode in {"all", "torch"}:
        report("torch", benchmark(torch_vector_add, x, y), x)

    if args.mode in {"all", "triton"}:
        if not on_cuda:
            print("triton: skipped (requires a CUDA device)")
        else:
            try:
                report("triton", benchmark(triton_vector_add, x, y), x)
            except (NotImplementedError, RuntimeError) as exc:
                print(f"triton: skipped ({exc})")

    if args.mode in {"all", "cuda"}:
        if not on_cuda:
            print("cuda: skipped (requires a CUDA device)")
        else:
            ext = build_extension(verbose=False)
            if ext is None or not hasattr(torch.ops, "kernel_lab"):
                print("cuda: skipped (extension unavailable)")
            else:
                try:
                    report("cuda", benchmark(torch.ops.kernel_lab.vector_add, x, y), x)
                except Exception as exc:
                    # Top-level boundary: any failure in the optional
                    # extension is reported and the script exits normally.
                    print(f"cuda: skipped ({exc})")
# Run the benchmark only when this file is executed as a script, so that
# importing it has no side effect beyond the definitions above.
if __name__ == "__main__":
    main()