58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
try:
    # JIT-compile helper; only present in PyTorch builds shipped with
    # C++/CUDA extension support.
    from torch.utils.cpp_extension import load
except ImportError: # pragma: no cover - depends on torch install
    # Sentinel checked by build_extension(); lets this module import
    # cleanly on minimal PyTorch installs.
    load = None
|
|
|
|
|
|
# Repository root: one directory above the directory containing this file.
ROOT = Path(__file__).resolve().parents[1]
# Location of the CUDA kernel sources (binding/, src/, include/ live here).
CUDA_DIR = ROOT / "kernels" / "cuda"
|
|
|
|
|
|
def _format_torch_cuda_arch(raw_arch: str) -> str:
|
|
if raw_arch.isdigit() and len(raw_arch) == 3:
|
|
return f"{raw_arch[:2]}.{raw_arch[2]}"
|
|
return raw_arch
|
|
|
|
|
|
def build_extension(verbose: bool = True) -> Optional[object]:
    """Build or load the lab extension if the local environment allows it.

    Args:
        verbose: Forwarded to ``torch.utils.cpp_extension.load``; when
            True the compiler invocation is echoed.

    Returns:
        The loaded extension module, or ``None`` when the environment
        cannot build it (no ``load`` helper, no CUDA, or a build error).
    """
    # Guard clauses: degrade gracefully instead of raising.
    if load is None:
        print("torch.utils.cpp_extension.load is unavailable in this PyTorch build.")
        return None
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping extension build.")
        return None

    # Respect a pre-set TORCH_CUDA_ARCH_LIST; otherwise derive it from
    # KERNEL_LAB_CUDA_ARCH (default "120" -> "12.0").
    raw_arch = os.environ.get("KERNEL_LAB_CUDA_ARCH", "120")
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _format_torch_cuda_arch(raw_arch))

    # One binding translation unit plus one .cu file per kernel.
    kernel_stems = (
        "vector_add",
        "row_softmax",
        "tiled_matmul",
        "online_softmax",
        "flash_attention_fwd",
    )
    src_dir = CUDA_DIR / "src"
    sources = [str(CUDA_DIR / "binding" / "binding.cpp")]
    sources += [str(src_dir / f"{stem}.cu") for stem in kernel_stems]

    try:
        # -O0/-lineinfo keep the build debugger- and profiler-friendly.
        return load(
            name="kernel_lab_ext",
            sources=sources,
            extra_include_paths=[str(CUDA_DIR / "include")],
            extra_cflags=["-O0", "-std=c++17"],
            extra_cuda_cflags=["-O0", "-lineinfo"],
            verbose=verbose,
        )
    except Exception as exc:  # pragma: no cover - environment-dependent
        # Best-effort boundary: report and signal failure via None.
        print(f"Extension build/load failed: {exc}")
        return None
|