Initial project scaffold

This commit is contained in:
wjh
2026-04-10 13:15:06 +00:00
commit a4a6b1f1c8
94 changed files with 3964 additions and 0 deletions

2
tools/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""Helper modules and scripts for the lab."""

58
tools/check_env.py Normal file
View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import platform
import shutil
import subprocess
import torch
def run_command(cmd: list[str]) -> str:
if shutil.which(cmd[0]) is None:
return "not found"
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
text = (result.stdout or result.stderr).strip()
return text or f"command exited with code {result.returncode}"
except Exception as exc: # pragma: no cover - defensive
return f"error: {exc}"
def main() -> None:
print("=== System ===")
print("python:", platform.python_version())
print("platform:", platform.platform())
print("\n=== PyTorch ===")
print("torch:", torch.__version__)
print("torch.cuda.is_available():", torch.cuda.is_available())
print("torch.version.cuda:", torch.version.cuda)
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
print("cuda device count:", device_count)
for idx in range(device_count):
name = torch.cuda.get_device_name(idx)
capability = torch.cuda.get_device_capability(idx)
print(f"device {idx}: {name} | capability={capability[0]}.{capability[1]}")
else:
print("no CUDA device visible to PyTorch")
print("\n=== Triton ===")
try:
import triton # type: ignore
print("triton:", triton.__version__)
except Exception as exc:
print("triton import failed:", exc)
print("\n=== Toolkit / Driver Hints ===")
print("nvcc --version:")
print(run_command(["nvcc", "--version"]))
print("\nnvidia-smi:")
print(run_command(["nvidia-smi"]))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,138 @@
from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
import torch
from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from kernels.triton.online_softmax import triton_online_softmax
from kernels.triton.row_softmax import triton_row_softmax
from kernels.triton.tiled_matmul import triton_tiled_matmul
from kernels.triton.vector_add import triton_vector_add
from reference.torch_attention import torch_attention
from reference.torch_matmul import torch_matmul
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from reference.torch_vector_add import torch_vector_add
from tools.lab_extension import build_extension
def compare_vector_add(device: str) -> None:
x = torch.randn(4097, device=device)
y = torch.randn(4097, device=device)
ref = torch_vector_add(x, y)
print("torch reference ready")
try:
torch.testing.assert_close(triton_vector_add(x, y), ref)
print("triton matches torch")
except Exception as exc:
print(f"triton unavailable: {exc}")
ext = build_extension(verbose=False) if device == "cuda" else None
if ext is not None and hasattr(torch.ops, "kernel_lab"):
try:
torch.testing.assert_close(torch.ops.kernel_lab.vector_add(x, y), ref)
print("cuda op matches torch")
except Exception as exc:
print(f"cuda op unavailable: {exc}")
def compare_softmax(device: str, variant: str) -> None:
x = torch.randn(128, 257, device=device)
ref = torch_row_softmax(x) if variant == "row" else torch_online_softmax(x)
print("torch reference ready")
triton_fn = triton_row_softmax if variant == "row" else triton_online_softmax
try:
torch.testing.assert_close(triton_fn(x), ref, atol=1e-4, rtol=1e-4)
print("triton matches torch")
except Exception as exc:
print(f"triton unavailable: {exc}")
ext = build_extension(verbose=False) if device == "cuda" else None
op_name = "row_softmax" if variant == "row" else "online_softmax"
if ext is not None and hasattr(torch.ops, "kernel_lab"):
try:
torch.testing.assert_close(
getattr(torch.ops.kernel_lab, op_name)(x), ref, atol=1e-4, rtol=1e-4
)
print("cuda op matches torch")
except Exception as exc:
print(f"cuda op unavailable: {exc}")
def compare_matmul(device: str) -> None:
a = torch.randn(64, 96, device=device)
b = torch.randn(96, 48, device=device)
ref = torch_matmul(a, b)
print("torch reference ready")
try:
torch.testing.assert_close(triton_tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3)
print("triton matches torch")
except Exception as exc:
print(f"triton unavailable: {exc}")
ext = build_extension(verbose=False) if device == "cuda" else None
if ext is not None and hasattr(torch.ops, "kernel_lab"):
try:
torch.testing.assert_close(
torch.ops.kernel_lab.tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3
)
print("cuda op matches torch")
except Exception as exc:
print(f"cuda op unavailable: {exc}")
def compare_attention(device: str) -> None:
q = torch.randn(1, 2, 16, 32, device=device)
k = torch.randn(1, 2, 16, 32, device=device)
v = torch.randn(1, 2, 16, 32, device=device)
ref = torch_attention(q, k, v, causal=False)
print("torch reference ready")
try:
torch.testing.assert_close(
triton_flash_attention_fwd(q, k, v, causal=False), ref, atol=2e-3, rtol=2e-3
)
print("triton matches torch")
except Exception as exc:
print(f"triton unavailable: {exc}")
ext = build_extension(verbose=False) if device == "cuda" else None
if ext is not None and hasattr(torch.ops, "kernel_lab"):
try:
torch.testing.assert_close(
torch.ops.kernel_lab.flash_attention_fwd(q, k, v, False),
ref,
atol=2e-3,
rtol=2e-3,
)
print("cuda op matches torch")
except Exception as exc:
print(f"cuda op unavailable: {exc}")
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--task",
choices=["vector_add", "row_softmax", "online_softmax", "matmul", "attention"],
required=True,
)
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_args()
if args.task == "vector_add":
compare_vector_add(args.device)
elif args.task == "row_softmax":
compare_softmax(args.device, "row")
elif args.task == "online_softmax":
compare_softmax(args.device, "online")
elif args.task == "matmul":
compare_matmul(args.device)
else:
compare_attention(args.device)
if __name__ == "__main__":
main()

57
tools/lab_extension.py Normal file
View File

@@ -0,0 +1,57 @@
from __future__ import annotations
import os
from pathlib import Path
from typing import Optional
import torch
try:
from torch.utils.cpp_extension import load
except ImportError: # pragma: no cover - depends on torch install
load = None
ROOT = Path(__file__).resolve().parents[1]
CUDA_DIR = ROOT / "kernels" / "cuda"
def _format_torch_cuda_arch(raw_arch: str) -> str:
if raw_arch.isdigit() and len(raw_arch) == 3:
return f"{raw_arch[:2]}.{raw_arch[2]}"
return raw_arch
def build_extension(verbose: bool = True) -> Optional[object]:
"""Build or load the lab extension if the local environment allows it."""
if load is None:
print("torch.utils.cpp_extension.load is unavailable in this PyTorch build.")
return None
if not torch.cuda.is_available():
print("CUDA is not available; skipping extension build.")
return None
arch = _format_torch_cuda_arch(os.environ.get("KERNEL_LAB_CUDA_ARCH", "120"))
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", arch)
sources = [
str(CUDA_DIR / "binding" / "binding.cpp"),
str(CUDA_DIR / "src" / "vector_add.cu"),
str(CUDA_DIR / "src" / "row_softmax.cu"),
str(CUDA_DIR / "src" / "tiled_matmul.cu"),
str(CUDA_DIR / "src" / "online_softmax.cu"),
str(CUDA_DIR / "src" / "flash_attention_fwd.cu"),
]
try:
return load(
name="kernel_lab_ext",
sources=sources,
extra_include_paths=[str(CUDA_DIR / "include")],
extra_cflags=["-O0", "-std=c++17"],
extra_cuda_cflags=["-O0", "-lineinfo"],
verbose=verbose,
)
except Exception as exc: # pragma: no cover - environment-dependent
print(f"Extension build/load failed: {exc}")
return None

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import torch
def main() -> None:
if not torch.cuda.is_available():
print("CUDA is not available.")
return
for idx in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(idx)
print(f"device {idx}: {props.name}")
print(f" capability: {props.major}.{props.minor}")
print(f" total memory (GB): {props.total_memory / 1e9:.2f}")
print(f" multiprocessors: {props.multi_processor_count}")
print(f" max threads per block: {props.max_threads_per_block}")
print(f" warp size: {props.warp_size}")
if __name__ == "__main__":
main()

10
tools/profile_ncu.sh Executable file
View File

@@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -eq 0 ]]; then
echo "usage: $0 <command ...>"
exit 1
fi
ncu --set full --target-processes all "$@"

11
tools/profile_nsys.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -eq 0 ]]; then
echo "usage: $0 <command ...>"
exit 1
fi
mkdir -p profile-output
nsys profile --trace=cuda,nvtx,osrt --sample=none -o profile-output/profile "$@"

8
tools/run_all_benchmarks.sh Executable file
View File

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
set -euo pipefail
python bench/bench_vector_add.py "$@"
python bench/bench_softmax.py "$@"
python bench/bench_matmul.py "$@"
python bench/bench_attention.py "$@"

5
tools/run_all_tests.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/usr/bin/env bash
set -euo pipefail
pytest -q