from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import torch from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd from kernels.triton.online_softmax import triton_online_softmax from kernels.triton.row_softmax import triton_row_softmax from kernels.triton.tiled_matmul import triton_tiled_matmul from kernels.triton.vector_add import triton_vector_add from reference.torch_attention import torch_attention from reference.torch_matmul import torch_matmul from reference.torch_online_softmax import torch_online_softmax from reference.torch_row_softmax import torch_row_softmax from reference.torch_vector_add import torch_vector_add from tools.lab_extension import build_extension def compare_vector_add(device: str) -> None: x = torch.randn(4097, device=device) y = torch.randn(4097, device=device) ref = torch_vector_add(x, y) print("torch reference ready") try: torch.testing.assert_close(triton_vector_add(x, y), ref) print("triton matches torch") except Exception as exc: print(f"triton unavailable: {exc}") ext = build_extension(verbose=False) if device == "cuda" else None if ext is not None and hasattr(torch.ops, "kernel_lab"): try: torch.testing.assert_close(torch.ops.kernel_lab.vector_add(x, y), ref) print("cuda op matches torch") except Exception as exc: print(f"cuda op unavailable: {exc}") def compare_softmax(device: str, variant: str) -> None: x = torch.randn(128, 257, device=device) ref = torch_row_softmax(x) if variant == "row" else torch_online_softmax(x) print("torch reference ready") triton_fn = triton_row_softmax if variant == "row" else triton_online_softmax try: torch.testing.assert_close(triton_fn(x), ref, atol=1e-4, rtol=1e-4) print("triton matches torch") except Exception as exc: print(f"triton unavailable: {exc}") ext = build_extension(verbose=False) if device == "cuda" else None op_name = "row_softmax" if variant == "row" else "online_softmax" if ext is not None and hasattr(torch.ops, "kernel_lab"): try: torch.testing.assert_close( getattr(torch.ops.kernel_lab, op_name)(x), ref, atol=1e-4, rtol=1e-4 ) print("cuda op matches torch") except Exception as exc: print(f"cuda op unavailable: {exc}") def compare_matmul(device: str) -> None: a = torch.randn(64, 96, device=device) b = torch.randn(96, 48, device=device) ref = torch_matmul(a, b) print("torch reference ready") try: torch.testing.assert_close(triton_tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3) print("triton matches torch") except Exception as exc: print(f"triton unavailable: {exc}") ext = build_extension(verbose=False) if device == "cuda" else None if ext is not None and hasattr(torch.ops, "kernel_lab"): try: torch.testing.assert_close( torch.ops.kernel_lab.tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3 ) print("cuda op matches torch") except Exception as exc: print(f"cuda op unavailable: {exc}") def compare_attention(device: str) -> None: q = torch.randn(1, 2, 16, 32, device=device) k = torch.randn(1, 2, 16, 32, device=device) v = torch.randn(1, 2, 16, 32, device=device) ref = torch_attention(q, k, v, causal=False) print("torch reference ready") try: torch.testing.assert_close( triton_flash_attention_fwd(q, k, v, causal=False), ref, atol=2e-3, rtol=2e-3 ) print("triton matches torch") except Exception as exc: print(f"triton unavailable: {exc}") ext = build_extension(verbose=False) if device == "cuda" else None if ext is not None and hasattr(torch.ops, "kernel_lab"): try: torch.testing.assert_close( torch.ops.kernel_lab.flash_attention_fwd(q, k, v, False), ref, atol=2e-3, rtol=2e-3, ) print("cuda op matches torch") except Exception as exc: print(f"cuda op unavailable: {exc}") def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "--task", choices=["vector_add", "row_softmax", "online_softmax", "matmul", "attention"], required=True, ) parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") args = parser.parse_args() if args.task == "vector_add": compare_vector_add(args.device) elif args.task == "row_softmax": compare_softmax(args.device, "row") elif args.task == "online_softmax": compare_softmax(args.device, "online") elif args.task == "matmul": compare_matmul(args.device) else: compare_attention(args.device) if __name__ == "__main__": main()