Initial project scaffold

This commit is contained in:
wjh
2026-04-10 13:15:06 +00:00
commit a4a6b1f1c8
94 changed files with 3964 additions and 0 deletions

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
import statistics
import sys
import time
from pathlib import Path
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
import torch
from kernels.triton.tiled_matmul import triton_tiled_matmul
from reference.torch_matmul import torch_matmul
def benchmark(fn, *args, warmup: int = 5, reps: int = 20) -> float:
    """Return the median wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is first called ``warmup`` times without timing, then ``reps``
    times with each call timed individually via ``time.perf_counter``.
    When any positional argument is a CUDA tensor,
    ``torch.cuda.synchronize()`` is called before the clock starts and again
    before it stops, so each sample is bracketed by device synchronization.

    Args:
        fn: Callable to measure.
        *args: Positional arguments forwarded unchanged to ``fn`` on every
            call. May be empty and may contain non-tensor values.
        warmup: Number of untimed calls made before measuring.
        reps: Number of timed calls; must be at least 1.

    Returns:
        The median per-call duration in milliseconds.

    Raises:
        statistics.StatisticsError: If ``reps`` is less than 1, because no
            samples are collected for the median.
    """
    # Decide once whether synchronization is needed. Looking at every
    # argument (instead of only args[0]) avoids crashing on empty or
    # non-tensor arguments and still detects a CUDA tensor in any position.
    needs_sync = any(
        isinstance(arg, torch.Tensor) and arg.is_cuda for arg in args
    )

    def sync() -> None:
        if needs_sync:
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn(*args)
    # Drain the warmup work so it cannot leak into the first sample.
    sync()

    times_ms = []
    for _ in range(reps):
        sync()
        start = time.perf_counter()
        fn(*args)
        sync()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
def main() -> None:
    """Benchmark ``torch_matmul`` and ``triton_tiled_matmul`` and print timings.

    For each ``(m, k, n)`` shape, two random matrices of sizes ``m x k`` and
    ``k x n`` are created on CUDA when available, otherwise on the CPU.
    ``torch_matmul`` is always timed. ``triton_tiled_matmul`` is timed only
    on CUDA; if it raises ``NotImplementedError`` or ``RuntimeError`` the
    shape is reported as skipped together with the exception message.
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    shapes = ((128, 128, 128), (512, 512, 512))

    for m, k, n in shapes:
        lhs = torch.randn(m, k, device=device)
        rhs = torch.randn(k, n, device=device)

        ref_ms = benchmark(torch_matmul, lhs, rhs)
        print(f"torch_matmul {m}x{k}x{n}: {ref_ms:.3f} ms")

        # The Triton kernel is only attempted when running on a GPU.
        if not has_cuda:
            continue
        try:
            triton_ms = benchmark(triton_tiled_matmul, lhs, rhs)
            print(f"triton_tiled_matmul {m}x{k}x{n}: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_tiled_matmul {m}x{k}x{n}: skipped ({exc})")
# Run the benchmarks only when this file is executed as a script, so that
# importing the module does not trigger any timing runs.
if __name__ == "__main__":
    main()