Initial project scaffold
tasks/06_pytorch_custom_op/__init__.py (new file)
@@ -0,0 +1,2 @@
"""PyTorch custom op task."""
tasks/06_pytorch_custom_op/extension_skeleton.py (new file)
@@ -0,0 +1,26 @@
from __future__ import annotations

import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from tools.lab_extension import build_extension


def main() -> None:
    ext = build_extension(verbose=True)
    if ext is None:
        return
    print("Extension loaded.")
    print("Available torch.ops namespace:", hasattr(torch.ops, "kernel_lab"))
    if hasattr(torch.ops, "kernel_lab"):
        print("Registered ops:", dir(torch.ops.kernel_lab))


if __name__ == "__main__":
    main()
tasks/06_pytorch_custom_op/opcheck_test.py (new file)
@@ -0,0 +1,27 @@
from __future__ import annotations

import pytest
import torch

from tools.lab_extension import build_extension


@pytest.mark.cuda_required
@pytest.mark.skeleton
def test_vector_add_opcheck_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    ext = build_extension(verbose=False)
    if ext is None or not hasattr(torch.ops, "kernel_lab"):
        pytest.skip("extension is unavailable")
    if not hasattr(torch.library, "opcheck"):
        pytest.skip("torch.library.opcheck is unavailable")

    x = torch.randn(32, device="cuda")
    y = torch.randn(32, device="cuda")
    try:
        torch.ops.kernel_lab.vector_add(x, y)
    except Exception as exc:
        pytest.skip(f"operator is not implemented yet: {exc}")

    torch.library.opcheck(torch.ops.kernel_lab.vector_add, (x, y))
tasks/06_pytorch_custom_op/spec.md (new file)
@@ -0,0 +1,45 @@
# Task 06: PyTorch Custom Op

## 1. Problem Statement

Expose a CUDA kernel as a PyTorch operator so Python code can call it and test it like any other operator.

## 2. Expected Input/Output Shapes

For the starter binding, use vector add:

- `x`: `[N]`
- `y`: `[N]`
- output: `[N]`

The same pattern can later be extended to the other operators.
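As a reference for the shapes above: vector add is elementwise over two equal-length inputs. A plain-Python sketch of the intended semantics (purely illustrative, no PyTorch involved):

```python
def vector_add_ref(x: list[float], y: list[float]) -> list[float]:
    """Reference semantics for vector add: out[i] = x[i] + y[i]."""
    if len(x) != len(y):
        raise ValueError(f"shape mismatch: {len(x)} vs {len(y)}")
    return [a + b for a, b in zip(x, y)]
```

A reference like this is what the correctness checks in the implementation checklist would compare the CUDA output against.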

## 3. Performance Intuition

The binding layer is rarely where kernel time is spent, but it determines whether you can test, benchmark, and profile the CUDA implementation from Python at all.

## 4. Memory Access Discussion

The binding itself does not optimize memory traffic; it passes tensors through and dispatches the kernel. Still, it must preserve the kernel's shape, dtype, device, and contiguity assumptions.

## 5. What Triton Is Abstracting

Triton often avoids a separate C++ binding layer entirely: Python launches the JIT-compiled kernel directly.

## 6. What CUDA Makes Explicit

A CUDA kernel bound to PyTorch requires you to spell out function signatures, operator registration, and build integration explicitly.

## 7. Reflection Questions

- What assumptions should the binding validate before calling a CUDA kernel?
- Why is operator registration useful for testing and benchmarking?
- What changes once you want autograd support?

## 8. Implementation Checklist

- Read `kernels/cuda/binding/binding.cpp`
- Build or load the extension from Python
- Call the operator from `torch.ops.kernel_lab`
- Add correctness checks once the CUDA kernel is implemented
- Try `torch.library.opcheck` if your PyTorch build provides it