"""Benchmark harness comparing torch, Triton, and CUDA softmax kernels."""
from __future__ import annotations
|
|
|
|
import argparse
|
|
import statistics
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Make the repository root importable so the `kernels`, `reference`, and
# `tools` packages below resolve when this script is run directly
# (i.e. not installed as a package).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
|
|
import torch
|
|
|
|
from kernels.triton.online_softmax import triton_online_softmax
|
|
from kernels.triton.row_softmax import triton_row_softmax
|
|
from reference.torch_online_softmax import torch_online_softmax
|
|
from reference.torch_row_softmax import torch_row_softmax
|
|
from tools.lab_extension import build_extension
|
|
|
|
|
|
def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    """Return the median wall-clock time of ``fn(*args)`` in milliseconds.

    Runs ``warmup`` untimed calls first, then times ``reps`` calls with
    ``time.perf_counter``.  When the first positional argument is a CUDA
    tensor, ``torch.cuda.synchronize`` is called before starting and after
    stopping each timer so the measurement covers kernel execution rather
    than just the asynchronous launch.

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.  The first one, if
            present and tensor-like, decides whether CUDA sync is needed.
        warmup: Number of untimed warm-up iterations.
        reps: Number of timed iterations; must be >= 1.

    Returns:
        Median elapsed time per call, in milliseconds.

    Raises:
        ValueError: If ``reps`` is less than 1 (a clearer error than the
            ``StatisticsError`` that ``median([])`` would raise).
    """
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")
    # Hoist the device check out of the timing loop.  The original indexed
    # args[0].is_cuda unconditionally, which raised IndexError for no-arg
    # callables and AttributeError for non-tensor first arguments.
    needs_sync = bool(args) and getattr(args[0], "is_cuda", False)
    for _ in range(warmup):
        fn(*args)
    if needs_sync:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if needs_sync:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if needs_sync:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
|
|
|
|
|
|
def report(name: str, elapsed_ms: float, x: torch.Tensor) -> None:
    """Print elapsed time and logical memory bandwidth for tensor *x*.

    Bandwidth assumes three logical passes over the tensor's bytes,
    divided by the elapsed wall-clock time.
    """
    total_bytes = x.numel() * x.element_size() * 3
    seconds = elapsed_ms * 1e-3
    gbps = total_bytes / seconds / 1e9
    print(f"{name}: {elapsed_ms:.3f} ms | logical bandwidth {gbps:.2f} GB/s")
|
|
|
|
|
|
def main() -> None:
    """Parse CLI flags and benchmark the selected softmax implementations.

    Torch runs on any device; the Triton and CUDA paths are attempted only
    when the benchmark device is CUDA, and each is skipped (with a message)
    rather than crashing when its backend is unavailable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--variant", choices=["row", "online"], default="row")
    parser.add_argument("--rows", type=int, default=4096)
    parser.add_argument("--cols", type=int, default=1024)
    args = parser.parse_args()

    x = torch.randn(args.rows, args.cols, device=args.device)

    # Resolve all three implementations for the chosen variant up front.
    if args.variant == "row":
        ref_fn, triton_fn, cuda_name = torch_row_softmax, triton_row_softmax, "row_softmax"
    else:
        ref_fn, triton_fn, cuda_name = torch_online_softmax, triton_online_softmax, "online_softmax"

    on_cuda = args.device == "cuda"

    if args.mode in {"all", "torch"}:
        report(f"torch_{args.variant}_softmax", benchmark(ref_fn, x), x)

    if on_cuda and args.mode in {"all", "triton"}:
        try:
            report(f"triton_{args.variant}_softmax", benchmark(triton_fn, x), x)
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_{args.variant}_softmax: skipped ({exc})")

    if on_cuda and args.mode in {"all", "cuda"}:
        ext = build_extension(verbose=False)
        if ext is None or not hasattr(torch.ops, "kernel_lab"):
            print(f"cuda_{args.variant}_softmax: skipped (extension unavailable)")
        else:
            try:
                cuda_fn = getattr(torch.ops.kernel_lab, cuda_name)
                report(f"cuda_{args.variant}_softmax", benchmark(cuda_fn, x), x)
            except Exception as exc:
                print(f"cuda_{args.variant}_softmax: skipped ({exc})")
|
|
|
|
|
|
# Script entry point: run the benchmark CLI when executed directly.
if __name__ == "__main__":
    main()
|