from __future__ import annotations

import argparse
import statistics
import sys
import time
from pathlib import Path

# Make the repo root importable when this script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention
from tools.lab_extension import build_extension


def benchmark(fn, *args, warmup: int = 5, reps: int = 20, **kwargs) -> float:
    """Return the median wall-clock time of ``fn(*args, **kwargs)`` in milliseconds."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    # CUDA launches are asynchronous; synchronize so perf_counter measures kernel time.
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args, **kwargs)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--dim", type=int, default=64)
    parser.add_argument("--causal", action="store_true")
    args = parser.parse_args()

    q = torch.randn(args.batch, args.heads, args.seq, args.dim, device=args.device)
    k = torch.randn(args.batch, args.heads, args.seq, args.dim, device=args.device)
    v = torch.randn(args.batch, args.heads, args.seq, args.dim, device=args.device)

    if args.mode in {"all", "torch"}:
        elapsed_ms = benchmark(torch_attention, q, k, v, causal=args.causal)
        print(f"torch: {elapsed_ms:.3f} ms")

    if args.device == "cuda" and args.mode in {"all", "triton"}:
        try:
            elapsed_ms = benchmark(triton_flash_attention_fwd, q, k, v, causal=args.causal)
            print(f"triton: {elapsed_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton: skipped ({exc})")

    if args.device == "cuda" and args.mode in {"all", "cuda"}:
        ext = build_extension(verbose=False)
        if ext is None or not hasattr(torch.ops, "kernel_lab"):
            print("cuda: skipped (extension unavailable)")
        else:
            try:
                elapsed_ms = benchmark(
                    torch.ops.kernel_lab.flash_attention_fwd, q, k, v, args.causal
                )
                print(f"cuda: {elapsed_ms:.3f} ms")
            except Exception as exc:
                print(f"cuda: skipped ({exc})")


if __name__ == "__main__":
    main()
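

# Example invocations (a sketch; the benchmarks/bench_attention.py path is an
# assumption about where this script lives in the repo, while the flags are the
# ones defined by the parser above):
#
#   python benchmarks/bench_attention.py --mode torch --device cpu
#   python benchmarks/bench_attention.py --mode all --causal --seq 1024 --dim 64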