"""Benchmark row-wise and online softmax implementations across backends.

Times the PyTorch reference, the Triton kernel, and the ``kernel_lab`` CUDA
extension op for the selected softmax variant, printing the median latency and
an estimated ("logical") bandwidth for each backend that is available.
"""

from __future__ import annotations

import argparse
import statistics
import sys
import time
from pathlib import Path

# Put the directory one level above this script's folder on sys.path so the
# project packages imported below resolve when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch  # noqa: E402

from kernels.triton.online_softmax import triton_online_softmax  # noqa: E402
from kernels.triton.row_softmax import triton_row_softmax  # noqa: E402
from reference.torch_online_softmax import torch_online_softmax  # noqa: E402
from reference.torch_row_softmax import torch_row_softmax  # noqa: E402
from tools.lab_extension import build_extension  # noqa: E402


def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    """Return the median wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is called ``warmup`` times untimed, then ``reps`` times timed. When
    the first positional argument is a CUDA tensor, the device is synchronized
    around every timed call so the measurement includes the asynchronously
    launched GPU work rather than only the launch overhead.

    Raises:
        ValueError: if ``reps`` is less than 1 (there would be no samples).
    """
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")
    # The input's device cannot change between calls, so decide once.
    needs_sync = bool(args) and getattr(args[0], "is_cuda", False)

    for _ in range(warmup):
        fn(*args)
    if needs_sync:
        torch.cuda.synchronize()

    times_ms: list[float] = []
    for _ in range(reps):
        if needs_sync:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if needs_sync:
            # Wait for the kernel to finish before stopping the clock.
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def report(name: str, elapsed_ms: float, x: torch.Tensor) -> None:
    """Print the latency and estimated bandwidth for one backend run on ``x``.

    Bandwidth uses this script's fixed traffic model of three tensor-sized
    transfers per call (``3 * x.numel() * x.element_size()`` bytes), so it is
    a "logical" figure, not a measured one. A non-positive ``elapsed_ms``
    reports ``inf`` instead of dividing by zero.
    """
    logical_bytes = 3 * x.numel() * x.element_size()
    if elapsed_ms > 0:
        gbps = logical_bytes / (elapsed_ms * 1e-3) / 1e9
    else:
        gbps = float("inf")
    print(f"{name}: {elapsed_ms:.3f} ms | logical bandwidth {gbps:.2f} GB/s")


def _run_triton(variant: str, fn, x: torch.Tensor) -> None:
    """Benchmark the Triton kernel, printing a skip line if it cannot run."""
    name = f"triton_{variant}_softmax"
    try:
        elapsed_ms = benchmark(fn, x)
    except (NotImplementedError, RuntimeError) as exc:
        print(f"{name}: skipped ({exc})")
        return
    report(name, elapsed_ms, x)


def _run_cuda(variant: str, op_name: str, x: torch.Tensor) -> None:
    """Build/load the ``kernel_lab`` extension and benchmark its ``op_name`` op."""
    name = f"cuda_{variant}_softmax"
    ext = build_extension(verbose=False)
    # NOTE(review): torch.ops appears to create namespaces lazily on attribute
    # access, so this hasattr check is likely always True. A missing op is
    # instead caught by the lookup in the try block below.
    if ext is None or not hasattr(torch.ops, "kernel_lab"):
        print(f"{name}: skipped (extension unavailable)")
        return
    try:
        cuda_fn = getattr(torch.ops.kernel_lab, op_name)
        elapsed_ms = benchmark(cuda_fn, x)
    except Exception as exc:  # Optional backend: any failure is reported as a skip.
        print(f"{name}: skipped ({exc})")
        return
    report(name, elapsed_ms, x)


def main() -> None:
    """Parse CLI options, build a random input, and benchmark the chosen backends."""
    parser = argparse.ArgumentParser(
        description="Benchmark softmax implementations (torch / triton / cuda)."
    )
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--variant", choices=["row", "online"], default="row")
    parser.add_argument("--rows", type=int, default=4096)
    parser.add_argument("--cols", type=int, default=1024)
    args = parser.parse_args()
    if args.rows <= 0 or args.cols <= 0:
        parser.error("--rows and --cols must be positive integers")

    x = torch.randn(args.rows, args.cols, device=args.device)
    is_row = args.variant == "row"
    ref_fn = torch_row_softmax if is_row else torch_online_softmax
    triton_fn = triton_row_softmax if is_row else triton_online_softmax
    cuda_op = "row_softmax" if is_row else "online_softmax"

    if args.mode in {"all", "torch"}:
        report(f"torch_{args.variant}_softmax", benchmark(ref_fn, x), x)

    # Use the resolved tensor device rather than the raw --device string so
    # indexed devices such as "cuda:1" still run the GPU backends.
    on_cuda = x.is_cuda
    if args.mode in {"all", "triton"}:
        if on_cuda:
            _run_triton(args.variant, triton_fn, x)
        else:
            print(f"triton_{args.variant}_softmax: skipped (requires a CUDA device)")
    if args.mode in {"all", "cuda"}:
        if on_cuda:
            _run_cuda(args.variant, cuda_op, x)
        else:
            print(f"cuda_{args.variant}_softmax: skipped (requires a CUDA device)")


if __name__ == "__main__":
    main()