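"""Numerically compare each Triton kernel against its pure-torch reference.

On CUDA devices the compiled kernel_lab extension ops are checked as well.
Usage (the file name depends on where this script lives under the repo root):

    python <this_script>.py --task vector_add [--device cpu]

Tasks: vector_add, row_softmax, online_softmax, matmul, attention.
"""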
from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Make the repository root importable when this script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from kernels.triton.online_softmax import triton_online_softmax
from kernels.triton.row_softmax import triton_row_softmax
from kernels.triton.tiled_matmul import triton_tiled_matmul
from kernels.triton.vector_add import triton_vector_add
from reference.torch_attention import torch_attention
from reference.torch_matmul import torch_matmul
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from reference.torch_vector_add import torch_vector_add
from tools.lab_extension import build_extension


def compare_vector_add(device: str) -> None:
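    """Check triton_vector_add (and, on CUDA, the extension op) against torch_vector_add."""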
    x = torch.randn(4097, device=device)  # non-power-of-two length exercises tail masking
    y = torch.randn(4097, device=device)
    ref = torch_vector_add(x, y)
    print("torch reference ready")
    try:
        torch.testing.assert_close(triton_vector_add(x, y), ref)
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(torch.ops.kernel_lab.vector_add(x, y), ref)
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def compare_softmax(device: str, variant: str) -> None:
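    """Check the row or online Triton softmax (and matching extension op) against its torch reference."""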
    x = torch.randn(128, 257, device=device)  # 257 columns: non-power-of-two width exercises partial tiles
    ref = torch_row_softmax(x) if variant == "row" else torch_online_softmax(x)
    print("torch reference ready")
    triton_fn = triton_row_softmax if variant == "row" else triton_online_softmax
    try:
        torch.testing.assert_close(triton_fn(x), ref, atol=1e-4, rtol=1e-4)
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    op_name = "row_softmax" if variant == "row" else "online_softmax"
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(
                getattr(torch.ops.kernel_lab, op_name)(x), ref, atol=1e-4, rtol=1e-4
            )
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def compare_matmul(device: str) -> None:
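    """Check triton_tiled_matmul (and the extension op) against torch_matmul at 1e-3 tolerances."""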
    a = torch.randn(64, 96, device=device)
    b = torch.randn(96, 48, device=device)
    ref = torch_matmul(a, b)
    print("torch reference ready")
    try:
        torch.testing.assert_close(triton_tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3)
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(
                torch.ops.kernel_lab.tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3
            )
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def compare_attention(device: str) -> None:
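    """Check the non-causal flash-attention forward pass against torch_attention."""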
    q = torch.randn(1, 2, 16, 32, device=device)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 2, 16, 32, device=device)
    v = torch.randn(1, 2, 16, 32, device=device)
    ref = torch_attention(q, k, v, causal=False)
    print("torch reference ready")
    try:
        torch.testing.assert_close(
            triton_flash_attention_fwd(q, k, v, causal=False), ref, atol=2e-3, rtol=2e-3
        )
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(
                torch.ops.kernel_lab.flash_attention_fwd(q, k, v, False),
                ref,
                atol=2e-3,
                rtol=2e-3,
            )
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def main() -> None:
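    """Parse --task/--device and dispatch to the matching comparison."""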
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        choices=["vector_add", "row_softmax", "online_softmax", "matmul", "attention"],
        required=True,
    )
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    if args.task == "vector_add":
        compare_vector_add(args.device)
    elif args.task == "row_softmax":
        compare_softmax(args.device, "row")
    elif args.task == "online_softmax":
        compare_softmax(args.device, "online")
    elif args.task == "matmul":
        compare_matmul(args.device)
    else:
        compare_attention(args.device)


if __name__ == "__main__":
    main()