Initial project scaffold
bench/bench_attention.py (new file)
@@ -0,0 +1,78 @@
from __future__ import annotations

import argparse
import statistics
import sys
import time
from pathlib import Path

# Make the repository root importable when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention
from tools.lab_extension import build_extension


def benchmark(fn, *args, warmup: int = 5, reps: int = 20, **kwargs) -> float:
    """Return the median wall-clock time of fn in milliseconds."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        # Synchronize around each timed call so queued kernels cannot leak
        # into (or out of) the measured interval.
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args, **kwargs)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--batch", type=int, default=2)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--dim", type=int, default=64)
    parser.add_argument("--causal", action="store_true")
    args = parser.parse_args()

    q = torch.randn(args.batch, args.heads, args.seq, args.dim, device=args.device)
    k = torch.randn(args.batch, args.heads, args.seq, args.dim, device=args.device)
    v = torch.randn(args.batch, args.heads, args.seq, args.dim, device=args.device)

    if args.mode in {"all", "torch"}:
        elapsed_ms = benchmark(torch_attention, q, k, v, causal=args.causal)
        print(f"torch: {elapsed_ms:.3f} ms")

    if args.device == "cuda" and args.mode in {"all", "triton"}:
        try:
            elapsed_ms = benchmark(triton_flash_attention_fwd, q, k, v, causal=args.causal)
            print(f"triton: {elapsed_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton: skipped ({exc})")

    if args.device == "cuda" and args.mode in {"all", "cuda"}:
        ext = build_extension(verbose=False)
        if ext is None or not hasattr(torch.ops, "kernel_lab"):
            print("cuda: skipped (extension unavailable)")
        else:
            try:
                elapsed_ms = benchmark(
                    torch.ops.kernel_lab.flash_attention_fwd, q, k, v, args.causal
                )
                print(f"cuda: {elapsed_ms:.3f} ms")
            except Exception as exc:
                print(f"cuda: skipped ({exc})")


if __name__ == "__main__":
    main()
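The host-timer-plus-synchronize pattern in benchmark() also has a device-side equivalent. Below is a minimal sketch of the same loop using torch.cuda.Event timing; benchmark_cuda_events is a hypothetical helper for illustration, not part of this commit, and it assumes the workload runs on a CUDA device. Event timing measures the interval on the GPU itself rather than bracketing it with host calls.

import statistics

import torch


def benchmark_cuda_events(fn, *args, warmup: int = 5, reps: int = 20, **kwargs) -> float:
    # Hypothetical alternative to benchmark(); CUDA-only sketch.
    # Warm up so one-time compilation/caching does not pollute the timings.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args, **kwargs)
        end.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))  # elapsed_time reports milliseconds
    return statistics.median(times_ms)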
bench/bench_matmul.py (new file)
@@ -0,0 +1,81 @@
from __future__ import annotations

import argparse
import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.tiled_matmul import triton_tiled_matmul
from reference.torch_matmul import torch_matmul
from tools.lab_extension import build_extension


def benchmark(fn, *args, warmup: int = 5, reps: int = 20) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def report(name: str, elapsed_ms: float, m: int, n: int, k: int) -> None:
    tflops = (2.0 * m * n * k) / (elapsed_ms * 1e-3) / 1e12
    print(f"{name}: {elapsed_ms:.3f} ms | throughput {tflops:.3f} TFLOP/s")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--m", type=int, default=1024)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--k", type=int, default=1024)
    args = parser.parse_args()

    a = torch.randn(args.m, args.k, device=args.device)
    b = torch.randn(args.k, args.n, device=args.device)

    if args.mode in {"all", "torch"}:
        report("torch", benchmark(torch_matmul, a, b), args.m, args.n, args.k)

    if args.device == "cuda" and args.mode in {"all", "triton"}:
        try:
            report("triton", benchmark(triton_tiled_matmul, a, b), args.m, args.n, args.k)
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton: skipped ({exc})")

    if args.device == "cuda" and args.mode in {"all", "cuda"}:
        ext = build_extension(verbose=False)
        if ext is None or not hasattr(torch.ops, "kernel_lab"):
            print("cuda: skipped (extension unavailable)")
        else:
            try:
                report(
                    "cuda",
                    benchmark(torch.ops.kernel_lab.tiled_matmul, a, b),
                    args.m,
                    args.n,
                    args.k,
                )
            except Exception as exc:
                print(f"cuda: skipped ({exc})")


if __name__ == "__main__":
    main()
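report() counts 2 * m * n * k floating-point operations: each of the m * n output elements takes k multiply-add pairs. A quick worked example of the formula at the script's default sizes, with a purely hypothetical timing:

m = n = k = 1024                      # the script's defaults
elapsed_ms = 1.0                      # hypothetical timing, for illustration only
flops = 2.0 * m * n * k               # 1024**3 multiply-adds, counted as 2 FLOPs each
tflops = flops / (elapsed_ms * 1e-3) / 1e12
print(f"{tflops:.3f} TFLOP/s")        # -> 2.147 TFLOP/s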
bench/bench_softmax.py (new file)
@@ -0,0 +1,82 @@
from __future__ import annotations

import argparse
import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.online_softmax import triton_online_softmax
from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from tools.lab_extension import build_extension


def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def report(name: str, elapsed_ms: float, x: torch.Tensor) -> None:
    logical_bytes = 3 * x.numel() * x.element_size()
    gbps = logical_bytes / (elapsed_ms * 1e-3) / 1e9
    print(f"{name}: {elapsed_ms:.3f} ms | logical bandwidth {gbps:.2f} GB/s")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--variant", choices=["row", "online"], default="row")
    parser.add_argument("--rows", type=int, default=4096)
    parser.add_argument("--cols", type=int, default=1024)
    args = parser.parse_args()

    x = torch.randn(args.rows, args.cols, device=args.device)

    ref_fn = torch_row_softmax if args.variant == "row" else torch_online_softmax
    triton_fn = triton_row_softmax if args.variant == "row" else triton_online_softmax
    cuda_name = "row_softmax" if args.variant == "row" else "online_softmax"

    if args.mode in {"all", "torch"}:
        report(f"torch_{args.variant}_softmax", benchmark(ref_fn, x), x)

    if args.device == "cuda" and args.mode in {"all", "triton"}:
        try:
            report(f"triton_{args.variant}_softmax", benchmark(triton_fn, x), x)
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_{args.variant}_softmax: skipped ({exc})")

    if args.device == "cuda" and args.mode in {"all", "cuda"}:
        ext = build_extension(verbose=False)
        if ext is None or not hasattr(torch.ops, "kernel_lab"):
            print(f"cuda_{args.variant}_softmax: skipped (extension unavailable)")
        else:
            try:
                cuda_fn = getattr(torch.ops.kernel_lab, cuda_name)
                report(f"cuda_{args.variant}_softmax", benchmark(cuda_fn, x), x)
            except Exception as exc:
                print(f"cuda_{args.variant}_softmax: skipped ({exc})")


if __name__ == "__main__":
    main()
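The "logical bandwidth" in report() models 3 * numel * element_size bytes of traffic per call. The kernels themselves are not shown in this commit, so the reading of that constant is an assumption: it matches a classic two-pass row softmax (one read of x for the max/sum pass, a second read of x for the normalize pass, one write of y). A worked example at the default shape:

rows, cols = 4096, 1024                       # the script's default shape
elem_size = 4                                 # float32
logical_bytes = 3 * rows * cols * elem_size   # 50,331,648 bytes
elapsed_ms = 0.5                              # hypothetical timing, for illustration only
gbps = logical_bytes / (elapsed_ms * 1e-3) / 1e9
print(f"{gbps:.2f} GB/s")                     # -> 100.66 GB/s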
bench/bench_vector_add.py (new file)
@@ -0,0 +1,74 @@
from __future__ import annotations

import argparse
import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.vector_add import triton_vector_add
from reference.torch_vector_add import torch_vector_add
from tools.lab_extension import build_extension


def benchmark(fn, *args, warmup: int = 5, reps: int = 30) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def report(name: str, elapsed_ms: float, x: torch.Tensor) -> None:
    bytes_moved = 3 * x.numel() * x.element_size()
    gbps = bytes_moved / (elapsed_ms * 1e-3) / 1e9
    print(f"{name}: {elapsed_ms:.3f} ms | effective bandwidth {gbps:.2f} GB/s")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--mode", choices=["all", "torch", "triton", "cuda"], default="all")
    parser.add_argument("--numel", type=int, default=1 << 24)
    args = parser.parse_args()

    x = torch.randn(args.numel, device=args.device)
    y = torch.randn(args.numel, device=args.device)

    if args.mode in {"all", "torch"}:
        report("torch", benchmark(torch_vector_add, x, y), x)

    if args.device == "cuda" and args.mode in {"all", "triton"}:
        try:
            report("triton", benchmark(triton_vector_add, x, y), x)
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton: skipped ({exc})")

    if args.device == "cuda" and args.mode in {"all", "cuda"}:
        ext = build_extension(verbose=False)
        if ext is None or not hasattr(torch.ops, "kernel_lab"):
            print("cuda: skipped (extension unavailable)")
        else:
            try:
                report("cuda", benchmark(torch.ops.kernel_lab.vector_add, x, y), x)
            except Exception as exc:
                print(f"cuda: skipped ({exc})")


if __name__ == "__main__":
    main()
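For vector add the 3x factor in report() is exact rather than a model: the kernel reads x, reads y, and writes one output of the same size. At the default --numel of 1 << 24 float32 elements, each buffer is 64 MiB:

numel = 1 << 24                       # default --numel: 16,777,216 elements
bytes_moved = 3 * numel * 4           # 201,326,592 bytes (~192 MiB total traffic)
elapsed_ms = 1.0                      # hypothetical timing, for illustration only
print(f"{bytes_moved / (elapsed_ms * 1e-3) / 1e9:.2f} GB/s")  # -> 201.33 GB/s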
bench/compare_impls.py (new file)
@@ -0,0 +1,32 @@
from __future__ import annotations

import argparse
import subprocess
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]

TASK_TO_SCRIPT = {
    "vector_add": ROOT / "bench" / "bench_vector_add.py",
    "softmax": ROOT / "bench" / "bench_softmax.py",
    "matmul": ROOT / "bench" / "bench_matmul.py",
    "attention": ROOT / "bench" / "bench_attention.py",
}


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", choices=sorted(TASK_TO_SCRIPT), required=True)
    parser.add_argument("extra_args", nargs="*")
    args = parser.parse_args()

    cmd = [sys.executable, str(TASK_TO_SCRIPT[args.task]), *args.extra_args]
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    main()
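compare_impls.py forwards extra_args verbatim to the selected per-task script, so any flag that script understands can ride along. A usage sketch (the sizes are arbitrary examples; note that argparse typically rejects option-like positionals, so forwarded flags likely need a "--" separator before them):

import subprocess
import sys

# Roughly: python bench/compare_impls.py --task matmul -- --m 2048 --n 2048 --k 2048
# (the "--" keeps compare_impls' own parser from treating the forwarded flags as its own)
subprocess.run(
    [sys.executable, "bench/compare_impls.py",
     "--task", "matmul", "--", "--m", "2048", "--n", "2048", "--k", "2048"],
    check=True,
)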