Files
kernel-lab/tools/compare_against_torch.py
2026-04-10 13:15:06 +00:00

139 lines
5.0 KiB
Python

from __future__ import annotations
import argparse
import sys
from pathlib import Path
# Directory that contains this file's parent directory (``parents[1]`` of the
# resolved script path). It is prepended to sys.path so that the project-local
# ``kernels``, ``reference`` and ``tools`` imports below resolve when this file
# is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
import torch
from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from kernels.triton.online_softmax import triton_online_softmax
from kernels.triton.row_softmax import triton_row_softmax
from kernels.triton.tiled_matmul import triton_tiled_matmul
from kernels.triton.vector_add import triton_vector_add
from reference.torch_attention import torch_attention
from reference.torch_matmul import torch_matmul
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from reference.torch_vector_add import torch_vector_add
from tools.lab_extension import build_extension
def _check_against_reference(
    label: str,
    compute,
    ref: torch.Tensor,
    *,
    atol: float | None = None,
    rtol: float | None = None,
) -> bool:
    """Run one backend, compare its output with ``ref`` and print a verdict.

    Args:
        label: Backend name used as the message prefix (e.g. "triton", "cuda op").
        compute: Zero-argument callable that produces the backend's output. It
            is called inside a guard, so a missing kernel, an unregistered op or
            an unsupported device is reported instead of aborting the script.
        ref: Output of the torch reference implementation.
        atol, rtol: Forwarded to ``torch.testing.assert_close``. Leaving both as
            ``None`` selects that function's dtype-based defaults.

    Returns:
        True if the backend ran and matched ``ref``. False if it could not run
        or produced a result that does not match.
    """
    try:
        actual = compute()
    except Exception as exc:  # best-effort: any failure to run means "unavailable"
        print(f"{label} unavailable: {exc}")
        return False
    # The comparison is kept out of the guard above so that a numerical mismatch
    # (AssertionError from assert_close) is reported as a mismatch rather than
    # being mistaken for an unavailable backend.
    try:
        torch.testing.assert_close(actual, ref, atol=atol, rtol=rtol)
    except AssertionError as exc:
        print(f"{label} MISMATCH vs torch: {exc}")
        return False
    except Exception as exc:  # anything other than a value mismatch stays best-effort
        print(f"{label} unavailable: {exc}")
        return False
    print(f"{label} matches torch")
    return True


def _compare_backends(
    device: str,
    ref: torch.Tensor,
    triton_call,
    cuda_call,
    *,
    atol: float | None = None,
    rtol: float | None = None,
) -> None:
    """Compare the Triton kernel and, when available, the CUDA op against ``ref``.

    Args:
        device: Device string the inputs were created on. The CUDA extension is
            only built when this is exactly ``"cuda"``.
        ref: Output of the torch reference implementation.
        triton_call: Zero-argument callable that runs the Triton kernel.
        cuda_call: Zero-argument callable that runs the ``torch.ops.kernel_lab``
            op. It is only invoked when ``build_extension`` returns something
            other than ``None``.
        atol, rtol: Tolerances shared by both comparisons.
    """
    print("torch reference ready")
    _check_against_reference("triton", triton_call, ref, atol=atol, rtol=rtol)
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None:
        # The op is resolved lazily inside cuda_call, so an op that was never
        # registered surfaces as "cuda op unavailable" instead of crashing.
        # (hasattr(torch.ops, ...) is not used: torch.ops creates namespaces
        # on attribute access, so it cannot detect a missing extension.)
        _check_against_reference("cuda op", cuda_call, ref, atol=atol, rtol=rtol)


def compare_vector_add(device: str) -> None:
    """Check the Triton and CUDA vector-add kernels against the torch reference.

    Uses 4097 elements, which is not a power of two (presumably to exercise
    tail handling in block-tiled kernels).
    """
    x = torch.randn(4097, device=device)
    y = torch.randn(4097, device=device)
    ref = torch_vector_add(x, y)
    _compare_backends(
        device,
        ref,
        lambda: triton_vector_add(x, y),
        lambda: torch.ops.kernel_lab.vector_add(x, y),
    )


def compare_softmax(device: str, variant: str) -> None:
    """Check a softmax kernel variant against its torch reference.

    Args:
        device: Device to create the (128, 257) input on.
        variant: ``"row"`` for row softmax or ``"online"`` for online softmax.

    Raises:
        ValueError: If ``variant`` is neither ``"row"`` nor ``"online"``.
    """
    if variant not in ("row", "online"):
        raise ValueError(
            f"unknown softmax variant {variant!r}; expected 'row' or 'online'"
        )
    x = torch.randn(128, 257, device=device)
    if variant == "row":
        reference_fn, triton_fn, op_name = (
            torch_row_softmax,
            triton_row_softmax,
            "row_softmax",
        )
    else:
        reference_fn, triton_fn, op_name = (
            torch_online_softmax,
            triton_online_softmax,
            "online_softmax",
        )
    _compare_backends(
        device,
        reference_fn(x),
        lambda: triton_fn(x),
        lambda: getattr(torch.ops.kernel_lab, op_name)(x),
        atol=1e-4,
        rtol=1e-4,
    )


def compare_matmul(device: str) -> None:
    """Check the tiled matmul kernels on a (64, 96) @ (96, 48) product."""
    a = torch.randn(64, 96, device=device)
    b = torch.randn(96, 48, device=device)
    ref = torch_matmul(a, b)
    _compare_backends(
        device,
        ref,
        lambda: triton_tiled_matmul(a, b),
        lambda: torch.ops.kernel_lab.tiled_matmul(a, b),
        atol=1e-3,
        rtol=1e-3,
    )


def compare_attention(device: str) -> None:
    """Check the non-causal flash-attention forward kernels against torch.

    q, k and v are random (1, 2, 16, 32) tensors — presumably laid out as
    (batch, heads, seq_len, head_dim); TODO confirm against torch_attention.
    """
    q = torch.randn(1, 2, 16, 32, device=device)
    k = torch.randn(1, 2, 16, 32, device=device)
    v = torch.randn(1, 2, 16, 32, device=device)
    ref = torch_attention(q, k, v, causal=False)
    _compare_backends(
        device,
        ref,
        lambda: triton_flash_attention_fwd(q, k, v, causal=False),
        # The CUDA op takes the causal flag positionally.
        lambda: torch.ops.kernel_lab.flash_attention_fwd(q, k, v, False),
        atol=2e-3,
        rtol=2e-3,
    )
def main() -> None:
    """Parse ``--task`` / ``--device`` and run the single requested comparison.

    ``--device`` defaults to ``"cuda"`` when ``torch.cuda.is_available()``
    reports a GPU, otherwise ``"cpu"``.
    """
    # Dispatch table; its key order also defines the --task choices shown in --help.
    runners = {
        "vector_add": compare_vector_add,
        "row_softmax": lambda device: compare_softmax(device, "row"),
        "online_softmax": lambda device: compare_softmax(device, "online"),
        "matmul": compare_matmul,
        "attention": compare_attention,
    }
    cli = argparse.ArgumentParser()
    cli.add_argument("--task", choices=list(runners), required=True)
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli.add_argument("--device", default=fallback_device)
    options = cli.parse_args()
    runners[options.task](options.device)
# Script entry point. main() returns None, so this exits with status 0 on success.
if __name__ == "__main__":
    raise SystemExit(main())