ops: transformer op fwd/bwd CUDA kernels + Tensor wrappers

add/mul/add_bias(+sum_rows)/rms_norm/silu/rope/softmax/cross_entropy,
each with its analytic backward, in csrc/ops/nn.cu (inlined warp/block
reductions). FFI declarations + nn.cu in build.rs (no_cuda gated). Tensor
gains the matching thin wrappers; DType grows I32 for cross-entropy targets.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 15:44:09 +08:00
parent 88fbe0a85d
commit 5aef3742d6
5 changed files with 815 additions and 0 deletions

View File

@@ -32,6 +32,7 @@ fn main() {
.file("../../csrc/test/vecadd.cu")
.file("../../csrc/ops/elementwise.cu")
.file("../../csrc/ops/gemm.cu")
.file("../../csrc/ops/nn.cu")
.compile("xtrain_cuda_kernels");
}

View File

@@ -63,6 +63,120 @@ unsafe extern "C" {
);
}
// Transformer / autograd op kernels (csrc/ops/nn.cu). Forward + backward for the
// ops the Phase T4 tape engine needs. All F32, row-major, contiguous.
#[cfg(not(no_cuda))]
unsafe extern "C" {
// Elementwise: out = a + b ; out = a * b.
pub fn launch_add_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);
pub fn launch_mul_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]. x:[rows,cols], bias:[cols].
pub fn launch_add_bias_f32(
x: *const f32,
bias: *const f32,
out: *mut f32,
rows: i32,
cols: i32,
s: CudaStream,
);
// Column-sum (over rows): dbias[c] = sum_r dout[r,c]. Bias backward.
pub fn launch_sum_rows_f32(
dout: *const f32,
dbias: *mut f32,
rows: i32,
cols: i32,
s: CudaStream,
);
// RMSNorm forward: writes y[rows,cols] and inv_rms[rows] (cached for bwd).
pub fn launch_rms_norm_f32(
x: *const f32,
gamma: *const f32,
y: *mut f32,
inv_rms: *mut f32,
rows: i32,
cols: i32,
eps: f32,
s: CudaStream,
);
pub fn launch_rms_norm_dx_f32(
x: *const f32,
gamma: *const f32,
dy: *const f32,
inv_rms: *const f32,
dx: *mut f32,
rows: i32,
cols: i32,
s: CudaStream,
);
pub fn launch_rms_norm_dgamma_f32(
x: *const f32,
dy: *const f32,
inv_rms: *const f32,
dgamma: *mut f32,
rows: i32,
cols: i32,
s: CudaStream,
);
// SiLU: y = x*sigmoid(x); backward dx.
pub fn launch_silu_f32(x: *const f32, y: *mut f32, n: i32, s: CudaStream);
pub fn launch_silu_dx_f32(x: *const f32, dy: *const f32, dx: *mut f32, n: i32, s: CudaStream);
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = token index.
pub fn launch_rope_f32(
x: *const f32,
y: *mut f32,
tokens: i32,
heads: i32,
head_dim: i32,
theta: f32,
s: CudaStream,
);
pub fn launch_rope_dx_f32(
dy: *const f32,
dx: *mut f32,
tokens: i32,
heads: i32,
head_dim: i32,
theta: f32,
s: CudaStream,
);
// Row-wise softmax + Jacobian backward.
pub fn launch_softmax_f32(x: *const f32, y: *mut f32, rows: i32, cols: i32, s: CudaStream);
pub fn launch_softmax_dx_f32(
y: *const f32,
dy: *const f32,
dx: *mut f32,
rows: i32,
cols: i32,
s: CudaStream,
);
// Cross-entropy: fwd writes probs[rows,cols] + per-row loss[rows];
// bwd dx = scale*(probs - onehot).
pub fn launch_cross_entropy_fwd_f32(
x: *const f32,
target: *const i32,
probs: *mut f32,
loss: *mut f32,
rows: i32,
cols: i32,
s: CudaStream,
);
pub fn launch_cross_entropy_dx_f32(
probs: *const f32,
target: *const i32,
dx: *mut f32,
rows: i32,
cols: i32,
scale: f32,
s: CudaStream,
);
}
// cuBLAS — used ONLY as a correctness reference for the hand-written GEMM in
// tests. Declared (and linked, see build.rs) only when CUDA is compiled in.
#[cfg(not(no_cuda))]

View File

@@ -7,18 +7,22 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
F32,
/// 32-bit signed integers. Used for cross-entropy targets (token ids).
I32,
}
impl DType {
pub fn size_bytes(self) -> usize {
match self {
DType::F32 => 4,
DType::I32 => 4,
}
}
pub fn name(self) -> &'static str {
match self {
DType::F32 => "f32",
DType::I32 => "i32",
}
}
}
@@ -45,3 +49,13 @@ impl TensorDType for f32 {
v as f32
}
}
impl TensorDType for i32 {
const DTYPE: DType = DType::I32;
fn to_f64(self) -> f64 {
self as f64
}
fn from_f64(v: f64) -> Self {
v as i32
}
}

View File

@@ -247,6 +247,334 @@ impl Tensor {
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N] = [K,N]
(da, db)
}
// --- Transformer / autograd op primitives (the T4 kernels) ---
//
// Each is a thin, contiguous-F32-on-GPU wrapper over a kernel in
// csrc/ops/nn.cu. The autograd `Var` layer (xtrain-autodiff) builds nodes on
// top of these; the analytic backwards are derived in docs/03-autograd-engine.md.
/// Elementwise `out = self + other` (same shape).
#[cfg(not(no_cuda))]
pub fn add(&self, other: &Tensor) -> Self {
self.check_binary(other, "add");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_add_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
self.numel() as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("add sync failed");
out
}
/// Elementwise `out = self * other` (same shape, Hadamard product).
#[cfg(not(no_cuda))]
pub fn mul(&self, other: &Tensor) -> Self {
self.check_binary(other, "mul");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_mul_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
self.numel() as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("mul sync failed");
out
}
/// Broadcast bias add: `out[r,c] = self[r,c] + bias[c]`.
/// `self`:[rows,cols], `bias`:[cols].
#[cfg(not(no_cuda))]
pub fn add_bias(&self, bias: &Tensor) -> Self {
assert_eq!(self.ndim(), 2, "add_bias requires 2D input");
assert_eq!(bias.ndim(), 1, "bias must be 1D");
assert_eq!(self.shape[1], bias.shape[0], "bias len != cols");
let (rows, cols) = (self.shape[0], self.shape[1]);
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_add_bias_f32(
self.data_ptr() as *const f32,
bias.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("add_bias sync failed");
out
}
/// Column-sum over rows: `out[c] = sum_r self[r,c]`. This is the bias
/// backward (sum the upstream grad over the broadcast dim). `self`:[rows,cols]
/// → [cols].
#[cfg(not(no_cuda))]
pub fn sum_rows(&self) -> Self {
assert_eq!(self.ndim(), 2, "sum_rows requires 2D input");
let (rows, cols) = (self.shape[0], self.shape[1]);
let out = Tensor::zeros(&[cols], DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_sum_rows_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("sum_rows sync failed");
out
}
/// RMSNorm forward: `y[r,c] = x[r,c] * inv_rms[r] * gamma[c]` with
/// `inv_rms = rsqrt(mean(x²) + eps)`. `self`:[rows,cols], `gamma`:[cols].
/// Returns `(y, inv_rms)`; `inv_rms`:[rows] is cached for backward.
#[cfg(not(no_cuda))]
pub fn rms_norm(&self, gamma: &Tensor, eps: f32) -> (Tensor, Tensor) {
assert_eq!(self.ndim(), 2, "rms_norm requires 2D input");
assert_eq!(gamma.ndim(), 1, "gamma must be 1D");
assert_eq!(self.shape[1], gamma.shape[0], "gamma len != cols");
let (rows, cols) = (self.shape[0], self.shape[1]);
let y = Tensor::zeros(&self.shape, DType::F32, self.device());
let inv_rms = Tensor::zeros(&[rows], DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_rms_norm_f32(
self.data_ptr() as *const f32,
gamma.data_ptr() as *const f32,
y.data_ptr() as *mut f32,
inv_rms.data_ptr() as *mut f32,
rows as i32,
cols as i32,
eps,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("rms_norm sync failed");
(y, inv_rms)
}
/// RMSNorm backward. Inputs are the forward `x`, `gamma`, upstream `dy`, and
/// the cached `inv_rms`. Returns `(dx, dgamma)`.
#[cfg(not(no_cuda))]
pub fn rms_norm_backward(
x: &Tensor,
gamma: &Tensor,
dy: &Tensor,
inv_rms: &Tensor,
) -> (Tensor, Tensor) {
let (rows, cols) = (x.shape[0], x.shape[1]);
let dx = Tensor::zeros(&[rows, cols], DType::F32, x.device());
let dgamma = Tensor::zeros(&[cols], DType::F32, x.device());
unsafe {
xtrain_cuda::ffi::launch_rms_norm_dx_f32(
x.data_ptr() as *const f32,
gamma.data_ptr() as *const f32,
dy.data_ptr() as *const f32,
inv_rms.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
xtrain_cuda::ffi::launch_rms_norm_dgamma_f32(
x.data_ptr() as *const f32,
dy.data_ptr() as *const f32,
inv_rms.data_ptr() as *const f32,
dgamma.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("rms_norm_backward sync failed");
(dx, dgamma)
}
/// SiLU forward: `y = x * sigmoid(x)`, elementwise.
#[cfg(not(no_cuda))]
pub fn silu(&self) -> Self {
assert_eq!(self.dtype, DType::F32, "silu only supports F32");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_silu_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
self.numel() as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("silu sync failed");
out
}
/// SiLU backward: `dx = dy * (sig + x*sig*(1-sig))`, `sig = sigmoid(x)`.
/// Inputs are the forward `x` and upstream `dy`.
#[cfg(not(no_cuda))]
pub fn silu_backward(x: &Tensor, dy: &Tensor) -> Self {
let dx = Tensor::zeros(&x.shape, DType::F32, x.device());
unsafe {
xtrain_cuda::ffi::launch_silu_dx_f32(
x.data_ptr() as *const f32,
dy.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
x.numel() as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("silu_backward sync failed");
dx
}
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; the position
/// of each token is its row index. Returns the rotated tensor.
#[cfg(not(no_cuda))]
pub fn rope(&self, theta: f32) -> Self {
assert_eq!(self.ndim(), 3, "rope requires [tokens,heads,head_dim]");
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
assert_eq!(head_dim % 2, 0, "head_dim must be even");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_rope_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
tokens as i32,
heads as i32,
head_dim as i32,
theta,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("rope sync failed");
out
}
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
/// orthogonal map, so it needs no cached forward values, only `theta`.
#[cfg(not(no_cuda))]
pub fn rope_backward(dy: &Tensor, theta: f32) -> Self {
let (tokens, heads, head_dim) = (dy.shape[0], dy.shape[1], dy.shape[2]);
let dx = Tensor::zeros(&dy.shape, DType::F32, dy.device());
unsafe {
xtrain_cuda::ffi::launch_rope_dx_f32(
dy.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
tokens as i32,
heads as i32,
head_dim as i32,
theta,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("rope_backward sync failed");
dx
}
/// Row-wise safe softmax over the last dim. `self`:[rows,cols].
#[cfg(not(no_cuda))]
pub fn softmax(&self) -> Self {
assert_eq!(self.ndim(), 2, "softmax requires 2D input");
let (rows, cols) = (self.shape[0], self.shape[1]);
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_softmax_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("softmax sync failed");
out
}
/// Softmax backward (Jacobian): `dx[r,c] = y[r,c]*(dy[r,c] - sum_c'(dy*y))`.
/// Inputs are the forward output `y` and upstream `dy`.
#[cfg(not(no_cuda))]
pub fn softmax_backward(y: &Tensor, dy: &Tensor) -> Self {
let (rows, cols) = (y.shape[0], y.shape[1]);
let dx = Tensor::zeros(&y.shape, DType::F32, y.device());
unsafe {
xtrain_cuda::ffi::launch_softmax_dx_f32(
y.data_ptr() as *const f32,
dy.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("softmax_backward sync failed");
dx
}
/// Cross-entropy forward over logits `self`:[rows,cols] with one I32 target
/// per row. Returns `(probs, loss)` where `probs`:[rows,cols] is the softmax
/// (cached for backward) and `loss`:[rows] is the per-row negative log-likelihood.
#[cfg(not(no_cuda))]
pub fn cross_entropy(&self, target: &Tensor) -> (Tensor, Tensor) {
assert_eq!(self.ndim(), 2, "cross_entropy requires 2D logits");
assert_eq!(target.dtype, DType::I32, "target must be I32");
assert_eq!(target.numel(), self.shape[0], "one target per row");
let (rows, cols) = (self.shape[0], self.shape[1]);
let probs = Tensor::zeros(&self.shape, DType::F32, self.device());
let loss = Tensor::zeros(&[rows], DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_cross_entropy_fwd_f32(
self.data_ptr() as *const f32,
target.data_ptr() as *const i32,
probs.data_ptr() as *mut f32,
loss.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("cross_entropy sync failed");
(probs, loss)
}
/// Cross-entropy backward: `dx = scale * (probs - onehot(target))`. With
/// `scale = upstream / rows`, this is the gradient of the mean per-row loss.
#[cfg(not(no_cuda))]
pub fn cross_entropy_backward(probs: &Tensor, target: &Tensor, scale: f32) -> Self {
let (rows, cols) = (probs.shape[0], probs.shape[1]);
let dx = Tensor::zeros(&probs.shape, DType::F32, probs.device());
unsafe {
xtrain_cuda::ffi::launch_cross_entropy_dx_f32(
probs.data_ptr() as *const f32,
target.data_ptr() as *const i32,
dx.data_ptr() as *mut f32,
rows as i32,
cols as i32,
scale,
std::ptr::null_mut(),
);
}
xtrain_cuda::device::synchronize().expect("cross_entropy_backward sync failed");
dx
}
// Shared validation for same-shape binary elementwise ops.
#[cfg(not(no_cuda))]
fn check_binary(&self, other: &Tensor, op: &str) {
assert_eq!(self.dtype, DType::F32, "{op} only supports F32");
assert_eq!(other.dtype, DType::F32, "{op} only supports F32");
assert_eq!(self.shape(), other.shape(), "{op} shape mismatch");
assert_eq!(self.device(), other.device(), "{op} device mismatch");
assert!(
self.is_contiguous() && other.is_contiguous(),
"{op} requires contiguous tensors"
);
}
}
impl std::fmt::Debug for Tensor {