from __future__ import annotations

import os
from pathlib import Path
from typing import Optional

import torch

try:
    from torch.utils.cpp_extension import load
except ImportError:  # pragma: no cover - depends on torch install
    load = None

ROOT = Path(__file__).resolve().parents[1]
CUDA_DIR = ROOT / "kernels" / "cuda"


def _format_torch_cuda_arch(raw_arch: str) -> str:
    # Convert a compact arch string such as "120" into the "12.0" form that
    # TORCH_CUDA_ARCH_LIST expects; anything else is passed through unchanged.
    if raw_arch.isdigit() and len(raw_arch) == 3:
        return f"{raw_arch[:2]}.{raw_arch[2]}"
    return raw_arch


def build_extension(verbose: bool = True) -> Optional[object]:
    """Build or load the lab extension if the local environment allows it."""
    if load is None:
        print("torch.utils.cpp_extension.load is unavailable in this PyTorch build.")
        return None
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping extension build.")
        return None

    # Target only the requested architecture unless the caller already set one.
    arch = _format_torch_cuda_arch(os.environ.get("KERNEL_LAB_CUDA_ARCH", "120"))
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", arch)

    sources = [
        str(CUDA_DIR / "binding" / "binding.cpp"),
        str(CUDA_DIR / "src" / "vector_add.cu"),
        str(CUDA_DIR / "src" / "row_softmax.cu"),
        str(CUDA_DIR / "src" / "tiled_matmul.cu"),
        str(CUDA_DIR / "src" / "online_softmax.cu"),
        str(CUDA_DIR / "src" / "flash_attention_fwd.cu"),
    ]

    try:
        return load(
            name="kernel_lab_ext",
            sources=sources,
            extra_include_paths=[str(CUDA_DIR / "include")],
            # -O0 plus -lineinfo keeps the compiled kernels easy to correlate
            # with source lines when inspecting them in a profiler.
            extra_cflags=["-O0", "-std=c++17"],
            extra_cuda_cflags=["-O0", "-lineinfo"],
            verbose=verbose,
        )
    except Exception as exc:  # pragma: no cover - environment-dependent
        print(f"Extension build/load failed: {exc}")
        return None
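

# A minimal smoke-test sketch, not part of the original module. It assumes the
# binding in binding.cpp exposes a hypothetical `vector_add(Tensor, Tensor) ->
# Tensor` entry point; adjust the call to whatever names kernel_lab_ext
# actually exports.
if __name__ == "__main__":
    ext = build_extension(verbose=False)
    if ext is not None:
        a = torch.randn(1024, device="cuda")
        b = torch.randn(1024, device="cuda")
        # Hypothetical entry point; compare the custom kernel against PyTorch.
        out = ext.vector_add(a, b)
        torch.testing.assert_close(out, a + b)
        print("kernel_lab_ext vector_add matches the torch reference.")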