Initial project scaffold
This commit is contained in:
57
tools/lab_extension.py
Normal file
57
tools/lab_extension.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import load
|
||||
except ImportError: # pragma: no cover - depends on torch install
|
||||
load = None
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
CUDA_DIR = ROOT / "kernels" / "cuda"
|
||||
|
||||
|
||||
def _format_torch_cuda_arch(raw_arch: str) -> str:
|
||||
if raw_arch.isdigit() and len(raw_arch) == 3:
|
||||
return f"{raw_arch[:2]}.{raw_arch[2]}"
|
||||
return raw_arch
|
||||
|
||||
|
||||
def build_extension(verbose: bool = True) -> Optional[object]:
|
||||
"""Build or load the lab extension if the local environment allows it."""
|
||||
if load is None:
|
||||
print("torch.utils.cpp_extension.load is unavailable in this PyTorch build.")
|
||||
return None
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA is not available; skipping extension build.")
|
||||
return None
|
||||
|
||||
arch = _format_torch_cuda_arch(os.environ.get("KERNEL_LAB_CUDA_ARCH", "120"))
|
||||
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", arch)
|
||||
|
||||
sources = [
|
||||
str(CUDA_DIR / "binding" / "binding.cpp"),
|
||||
str(CUDA_DIR / "src" / "vector_add.cu"),
|
||||
str(CUDA_DIR / "src" / "row_softmax.cu"),
|
||||
str(CUDA_DIR / "src" / "tiled_matmul.cu"),
|
||||
str(CUDA_DIR / "src" / "online_softmax.cu"),
|
||||
str(CUDA_DIR / "src" / "flash_attention_fwd.cu"),
|
||||
]
|
||||
|
||||
try:
|
||||
return load(
|
||||
name="kernel_lab_ext",
|
||||
sources=sources,
|
||||
extra_include_paths=[str(CUDA_DIR / "include")],
|
||||
extra_cflags=["-O0", "-std=c++17"],
|
||||
extra_cuda_cflags=["-O0", "-lineinfo"],
|
||||
verbose=verbose,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - environment-dependent
|
||||
print(f"Extension build/load failed: {exc}")
|
||||
return None
|
||||
Reference in New Issue
Block a user