Initial project scaffold
2  tools/__init__.py  Normal file
@@ -0,0 +1,2 @@
"""Helper modules and scripts for the lab."""

58  tools/check_env.py  Normal file
@@ -0,0 +1,58 @@
from __future__ import annotations

import platform
import shutil
import subprocess

import torch


def run_command(cmd: list[str]) -> str:
    if shutil.which(cmd[0]) is None:
        return "not found"
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)
        text = (result.stdout or result.stderr).strip()
        return text or f"command exited with code {result.returncode}"
    except Exception as exc:  # pragma: no cover - defensive
        return f"error: {exc}"


def main() -> None:
    print("=== System ===")
    print("python:", platform.python_version())
    print("platform:", platform.platform())

    print("\n=== PyTorch ===")
    print("torch:", torch.__version__)
    print("torch.cuda.is_available():", torch.cuda.is_available())
    print("torch.version.cuda:", torch.version.cuda)

    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print("cuda device count:", device_count)
        for idx in range(device_count):
            name = torch.cuda.get_device_name(idx)
            capability = torch.cuda.get_device_capability(idx)
            print(f"device {idx}: {name} | capability={capability[0]}.{capability[1]}")
    else:
        print("no CUDA device visible to PyTorch")

    print("\n=== Triton ===")
    try:
        import triton  # type: ignore

        print("triton:", triton.__version__)
    except Exception as exc:
        print("triton import failed:", exc)

    print("\n=== Toolkit / Driver Hints ===")
    print("nvcc --version:")
    print(run_command(["nvcc", "--version"]))
    print("\nnvidia-smi:")
    print(run_command(["nvidia-smi"]))


if __name__ == "__main__":
    main()
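
A quick smoke test from the repo root (assuming a Python environment with torch installed is active):

    python tools/check_env.py

The toolkit/driver sections degrade gracefully: run_command prints "not found" when nvcc or nvidia-smi is missing from PATH.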

138  tools/compare_against_torch.py  Normal file
@@ -0,0 +1,138 @@
from __future__ import annotations

import argparse
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from kernels.triton.online_softmax import triton_online_softmax
from kernels.triton.row_softmax import triton_row_softmax
from kernels.triton.tiled_matmul import triton_tiled_matmul
from kernels.triton.vector_add import triton_vector_add
from reference.torch_attention import torch_attention
from reference.torch_matmul import torch_matmul
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from reference.torch_vector_add import torch_vector_add
from tools.lab_extension import build_extension


def compare_vector_add(device: str) -> None:
    x = torch.randn(4097, device=device)
    y = torch.randn(4097, device=device)
    ref = torch_vector_add(x, y)
    print("torch reference ready")
    try:
        torch.testing.assert_close(triton_vector_add(x, y), ref)
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(torch.ops.kernel_lab.vector_add(x, y), ref)
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def compare_softmax(device: str, variant: str) -> None:
    x = torch.randn(128, 257, device=device)
    ref = torch_row_softmax(x) if variant == "row" else torch_online_softmax(x)
    print("torch reference ready")
    triton_fn = triton_row_softmax if variant == "row" else triton_online_softmax
    try:
        torch.testing.assert_close(triton_fn(x), ref, atol=1e-4, rtol=1e-4)
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    op_name = "row_softmax" if variant == "row" else "online_softmax"
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(
                getattr(torch.ops.kernel_lab, op_name)(x), ref, atol=1e-4, rtol=1e-4
            )
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def compare_matmul(device: str) -> None:
    a = torch.randn(64, 96, device=device)
    b = torch.randn(96, 48, device=device)
    ref = torch_matmul(a, b)
    print("torch reference ready")
    try:
        torch.testing.assert_close(triton_tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3)
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(
                torch.ops.kernel_lab.tiled_matmul(a, b), ref, atol=1e-3, rtol=1e-3
            )
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def compare_attention(device: str) -> None:
    q = torch.randn(1, 2, 16, 32, device=device)
    k = torch.randn(1, 2, 16, 32, device=device)
    v = torch.randn(1, 2, 16, 32, device=device)
    ref = torch_attention(q, k, v, causal=False)
    print("torch reference ready")
    try:
        torch.testing.assert_close(
            triton_flash_attention_fwd(q, k, v, causal=False), ref, atol=2e-3, rtol=2e-3
        )
        print("triton matches torch")
    except Exception as exc:
        print(f"triton unavailable: {exc}")
    ext = build_extension(verbose=False) if device == "cuda" else None
    if ext is not None and hasattr(torch.ops, "kernel_lab"):
        try:
            torch.testing.assert_close(
                torch.ops.kernel_lab.flash_attention_fwd(q, k, v, False),
                ref,
                atol=2e-3,
                rtol=2e-3,
            )
            print("cuda op matches torch")
        except Exception as exc:
            print(f"cuda op unavailable: {exc}")


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        choices=["vector_add", "row_softmax", "online_softmax", "matmul", "attention"],
        required=True,
    )
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    if args.task == "vector_add":
        compare_vector_add(args.device)
    elif args.task == "row_softmax":
        compare_softmax(args.device, "row")
    elif args.task == "online_softmax":
        compare_softmax(args.device, "online")
    elif args.task == "matmul":
        compare_matmul(args.device)
    else:
        compare_attention(args.device)


if __name__ == "__main__":
    main()
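
Typical invocations from the repo root:

    python tools/compare_against_torch.py --task vector_add
    python tools/compare_against_torch.py --task attention --device cpu

With --device cpu the torch reference still runs, while the Triton and CUDA-extension branches fall through to their "unavailable" messages, so the script is safe to run on machines without a GPU.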

57  tools/lab_extension.py  Normal file
@@ -0,0 +1,57 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Optional

import torch

try:
    from torch.utils.cpp_extension import load
except ImportError:  # pragma: no cover - depends on torch install
    load = None


ROOT = Path(__file__).resolve().parents[1]
CUDA_DIR = ROOT / "kernels" / "cuda"


def _format_torch_cuda_arch(raw_arch: str) -> str:
    if raw_arch.isdigit() and len(raw_arch) == 3:
        return f"{raw_arch[:2]}.{raw_arch[2]}"
    return raw_arch


def build_extension(verbose: bool = True) -> Optional[object]:
    """Build or load the lab extension if the local environment allows it."""
    if load is None:
        print("torch.utils.cpp_extension.load is unavailable in this PyTorch build.")
        return None
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping extension build.")
        return None

    arch = _format_torch_cuda_arch(os.environ.get("KERNEL_LAB_CUDA_ARCH", "120"))
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", arch)

    sources = [
        str(CUDA_DIR / "binding" / "binding.cpp"),
        str(CUDA_DIR / "src" / "vector_add.cu"),
        str(CUDA_DIR / "src" / "row_softmax.cu"),
        str(CUDA_DIR / "src" / "tiled_matmul.cu"),
        str(CUDA_DIR / "src" / "online_softmax.cu"),
        str(CUDA_DIR / "src" / "flash_attention_fwd.cu"),
    ]

    try:
        return load(
            name="kernel_lab_ext",
            sources=sources,
            extra_include_paths=[str(CUDA_DIR / "include")],
            extra_cflags=["-O0", "-std=c++17"],
            extra_cuda_cflags=["-O0", "-lineinfo"],
            verbose=verbose,
        )
    except Exception as exc:  # pragma: no cover - environment-dependent
        print(f"Extension build/load failed: {exc}")
        return None
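
A minimal sketch of driving the build by hand, from the repo root so that tools is importable (8.6 is just an example value for an Ampere-class card):

    KERNEL_LAB_CUDA_ARCH=8.6 python -c "from tools.lab_extension import build_extension; build_extension()"

Note that _format_torch_cuda_arch only rewrites three-digit values (the default 120 becomes 12.0); dotted forms like 8.6 pass through unchanged, and an already-set TORCH_CUDA_ARCH_LIST always wins because of setdefault.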

23  tools/print_device_info.py  Normal file
@@ -0,0 +1,23 @@
from __future__ import annotations

import torch


def main() -> None:
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    for idx in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(idx)
        print(f"device {idx}: {props.name}")
        print(f"  capability: {props.major}.{props.minor}")
        print(f"  total memory (GB): {props.total_memory / 1e9:.2f}")
        print(f"  multiprocessors: {props.multi_processor_count}")
        print(f"  max threads per block: {props.max_threads_per_block}")
        print(f"  warp size: {props.warp_size}")


if __name__ == "__main__":
    main()
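
Run it directly to dump per-device properties:

    python tools/print_device_info.py

Every field comes straight from torch.cuda.get_device_properties(idx), so the output doubles as a quick check of what the installed PyTorch build reports for each GPU.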

10  tools/profile_ncu.sh  Executable file
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
set -euo pipefail

if [[ $# -eq 0 ]]; then
  echo "usage: $0 <command ...>"
  exit 1
fi

ncu --set full --target-processes all "$@"
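
Example, profiling one of the benchmark entry points (assumes ncu from Nsight Compute is on PATH; --set full collects all sections and adds substantial replay overhead):

    tools/profile_ncu.sh python bench/bench_matmul.py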

11  tools/profile_nsys.sh  Executable file
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail

if [[ $# -eq 0 ]]; then
  echo "usage: $0 <command ...>"
  exit 1
fi

mkdir -p profile-output
nsys profile --trace=cuda,nvtx,osrt --sample=none -o profile-output/profile "$@"
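
Example, tracing a benchmark run (the report is written under profile-output/; --sample=none disables CPU sampling so the trace stays focused on CUDA and NVTX activity):

    tools/profile_nsys.sh python bench/bench_attention.py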

8  tools/run_all_benchmarks.sh  Executable file
@@ -0,0 +1,8 @@
#!/usr/bin/env bash
set -euo pipefail

python bench/bench_vector_add.py "$@"
python bench/bench_softmax.py "$@"
python bench/bench_matmul.py "$@"
python bench/bench_attention.py "$@"
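
Arguments are forwarded verbatim to every benchmark via "$@", and set -e stops the suite at the first failure. A hypothetical example, assuming the bench scripts accept a common flag of this shape:

    tools/run_all_benchmarks.sh --device cpu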

5  tools/run_all_tests.sh  Executable file
@@ -0,0 +1,5 @@
#!/usr/bin/env bash
set -euo pipefail

pytest -q