ops: transformer op fwd/bwd CUDA kernels + Tensor wrappers
add/mul/add_bias(+sum_rows)/rms_norm/silu/rope/softmax/cross_entropy, each with its analytic backward, in csrc/ops/nn.cu (inlined warp/block reductions). FFI declarations + nn.cu in build.rs (no_cuda gated). Tensor gains the matching thin wrappers; DType grows I32 for cross-entropy targets. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -32,6 +32,7 @@ fn main() {
|
||||
.file("../../csrc/test/vecadd.cu")
|
||||
.file("../../csrc/ops/elementwise.cu")
|
||||
.file("../../csrc/ops/gemm.cu")
|
||||
.file("../../csrc/ops/nn.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,120 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// Transformer / autograd op kernels (csrc/ops/nn.cu). Forward + backward for the
|
||||
// ops the Phase T4 tape engine needs. All F32, row-major, contiguous.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// Elementwise: out = a + b ; out = a * b.
|
||||
pub fn launch_add_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);
|
||||
pub fn launch_mul_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);
|
||||
|
||||
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]. x:[rows,cols], bias:[cols].
|
||||
pub fn launch_add_bias_f32(
|
||||
x: *const f32,
|
||||
bias: *const f32,
|
||||
out: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Column-sum (over rows): dbias[c] = sum_r dout[r,c]. Bias backward.
|
||||
pub fn launch_sum_rows_f32(
|
||||
dout: *const f32,
|
||||
dbias: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
|
||||
// RMSNorm forward: writes y[rows,cols] and inv_rms[rows] (cached for bwd).
|
||||
pub fn launch_rms_norm_f32(
|
||||
x: *const f32,
|
||||
gamma: *const f32,
|
||||
y: *mut f32,
|
||||
inv_rms: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
eps: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_rms_norm_dx_f32(
|
||||
x: *const f32,
|
||||
gamma: *const f32,
|
||||
dy: *const f32,
|
||||
inv_rms: *const f32,
|
||||
dx: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_rms_norm_dgamma_f32(
|
||||
x: *const f32,
|
||||
dy: *const f32,
|
||||
inv_rms: *const f32,
|
||||
dgamma: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
|
||||
// SiLU: y = x*sigmoid(x); backward dx.
|
||||
pub fn launch_silu_f32(x: *const f32, y: *mut f32, n: i32, s: CudaStream);
|
||||
pub fn launch_silu_dx_f32(x: *const f32, dy: *const f32, dx: *mut f32, n: i32, s: CudaStream);
|
||||
|
||||
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = token index.
|
||||
pub fn launch_rope_f32(
|
||||
x: *const f32,
|
||||
y: *mut f32,
|
||||
tokens: i32,
|
||||
heads: i32,
|
||||
head_dim: i32,
|
||||
theta: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_rope_dx_f32(
|
||||
dy: *const f32,
|
||||
dx: *mut f32,
|
||||
tokens: i32,
|
||||
heads: i32,
|
||||
head_dim: i32,
|
||||
theta: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
|
||||
// Row-wise softmax + Jacobian backward.
|
||||
pub fn launch_softmax_f32(x: *const f32, y: *mut f32, rows: i32, cols: i32, s: CudaStream);
|
||||
pub fn launch_softmax_dx_f32(
|
||||
y: *const f32,
|
||||
dy: *const f32,
|
||||
dx: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
|
||||
// Cross-entropy: fwd writes probs[rows,cols] + per-row loss[rows];
|
||||
// bwd dx = scale*(probs - onehot).
|
||||
pub fn launch_cross_entropy_fwd_f32(
|
||||
x: *const f32,
|
||||
target: *const i32,
|
||||
probs: *mut f32,
|
||||
loss: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_cross_entropy_dx_f32(
|
||||
probs: *const f32,
|
||||
target: *const i32,
|
||||
dx: *mut f32,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
scale: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// cuBLAS — used ONLY as a correctness reference for the hand-written GEMM in
|
||||
// tests. Declared (and linked, see build.rs) only when CUDA is compiled in.
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -7,18 +7,22 @@
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
F32,
|
||||
/// 32-bit signed integers. Used for cross-entropy targets (token ids).
|
||||
I32,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
pub fn size_bytes(self) -> usize {
|
||||
match self {
|
||||
DType::F32 => 4,
|
||||
DType::I32 => 4,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
DType::F32 => "f32",
|
||||
DType::I32 => "i32",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -45,3 +49,13 @@ impl TensorDType for f32 {
|
||||
v as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for i32 {
|
||||
const DTYPE: DType = DType::I32;
|
||||
fn to_f64(self) -> f64 {
|
||||
self as f64
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
v as i32
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +247,334 @@ impl Tensor {
|
||||
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N] = [K,N]
|
||||
(da, db)
|
||||
}
|
||||
|
||||
// --- Transformer / autograd op primitives (the T4 kernels) ---
|
||||
//
|
||||
// Each is a thin, contiguous-F32-on-GPU wrapper over a kernel in
|
||||
// csrc/ops/nn.cu. The autograd `Var` layer (xtrain-autodiff) builds nodes on
|
||||
// top of these; the analytic backwards are derived in docs/03-autograd-engine.md.
|
||||
|
||||
/// Elementwise `out = self + other` (same shape).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn add(&self, other: &Tensor) -> Self {
|
||||
self.check_binary(other, "add");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_add_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("add sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// Elementwise `out = self * other` (same shape, Hadamard product).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn mul(&self, other: &Tensor) -> Self {
|
||||
self.check_binary(other, "mul");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_mul_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("mul sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// Broadcast bias add: `out[r,c] = self[r,c] + bias[c]`.
|
||||
/// `self`:[rows,cols], `bias`:[cols].
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn add_bias(&self, bias: &Tensor) -> Self {
|
||||
assert_eq!(self.ndim(), 2, "add_bias requires 2D input");
|
||||
assert_eq!(bias.ndim(), 1, "bias must be 1D");
|
||||
assert_eq!(self.shape[1], bias.shape[0], "bias len != cols");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_add_bias_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
bias.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("add_bias sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// Column-sum over rows: `out[c] = sum_r self[r,c]`. This is the bias
|
||||
/// backward (sum the upstream grad over the broadcast dim). `self`:[rows,cols]
|
||||
/// → [cols].
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn sum_rows(&self) -> Self {
|
||||
assert_eq!(self.ndim(), 2, "sum_rows requires 2D input");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&[cols], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_sum_rows_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("sum_rows sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// RMSNorm forward: `y[r,c] = x[r,c] * inv_rms[r] * gamma[c]` with
|
||||
/// `inv_rms = rsqrt(mean(x²) + eps)`. `self`:[rows,cols], `gamma`:[cols].
|
||||
/// Returns `(y, inv_rms)`; `inv_rms`:[rows] is cached for backward.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rms_norm(&self, gamma: &Tensor, eps: f32) -> (Tensor, Tensor) {
|
||||
assert_eq!(self.ndim(), 2, "rms_norm requires 2D input");
|
||||
assert_eq!(gamma.ndim(), 1, "gamma must be 1D");
|
||||
assert_eq!(self.shape[1], gamma.shape[0], "gamma len != cols");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let y = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let inv_rms = Tensor::zeros(&[rows], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rms_norm_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
gamma.data_ptr() as *const f32,
|
||||
y.data_ptr() as *mut f32,
|
||||
inv_rms.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
eps,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("rms_norm sync failed");
|
||||
(y, inv_rms)
|
||||
}
|
||||
|
||||
/// RMSNorm backward. Inputs are the forward `x`, `gamma`, upstream `dy`, and
|
||||
/// the cached `inv_rms`. Returns `(dx, dgamma)`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rms_norm_backward(
|
||||
x: &Tensor,
|
||||
gamma: &Tensor,
|
||||
dy: &Tensor,
|
||||
inv_rms: &Tensor,
|
||||
) -> (Tensor, Tensor) {
|
||||
let (rows, cols) = (x.shape[0], x.shape[1]);
|
||||
let dx = Tensor::zeros(&[rows, cols], DType::F32, x.device());
|
||||
let dgamma = Tensor::zeros(&[cols], DType::F32, x.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rms_norm_dx_f32(
|
||||
x.data_ptr() as *const f32,
|
||||
gamma.data_ptr() as *const f32,
|
||||
dy.data_ptr() as *const f32,
|
||||
inv_rms.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
xtrain_cuda::ffi::launch_rms_norm_dgamma_f32(
|
||||
x.data_ptr() as *const f32,
|
||||
dy.data_ptr() as *const f32,
|
||||
inv_rms.data_ptr() as *const f32,
|
||||
dgamma.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("rms_norm_backward sync failed");
|
||||
(dx, dgamma)
|
||||
}
|
||||
|
||||
/// SiLU forward: `y = x * sigmoid(x)`, elementwise.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn silu(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "silu only supports F32");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("silu sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// SiLU backward: `dx = dy * (sig + x*sig*(1-sig))`, `sig = sigmoid(x)`.
|
||||
/// Inputs are the forward `x` and upstream `dy`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn silu_backward(x: &Tensor, dy: &Tensor) -> Self {
|
||||
let dx = Tensor::zeros(&x.shape, DType::F32, x.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_dx_f32(
|
||||
x.data_ptr() as *const f32,
|
||||
dy.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
x.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("silu_backward sync failed");
|
||||
dx
|
||||
}
|
||||
|
||||
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; the position
|
||||
/// of each token is its row index. Returns the rotated tensor.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope(&self, theta: f32) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "rope requires [tokens,heads,head_dim]");
|
||||
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
assert_eq!(head_dim % 2, 0, "head_dim must be even");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rope_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
tokens as i32,
|
||||
heads as i32,
|
||||
head_dim as i32,
|
||||
theta,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("rope sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope_backward(dy: &Tensor, theta: f32) -> Self {
|
||||
let (tokens, heads, head_dim) = (dy.shape[0], dy.shape[1], dy.shape[2]);
|
||||
let dx = Tensor::zeros(&dy.shape, DType::F32, dy.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rope_dx_f32(
|
||||
dy.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
tokens as i32,
|
||||
heads as i32,
|
||||
head_dim as i32,
|
||||
theta,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("rope_backward sync failed");
|
||||
dx
|
||||
}
|
||||
|
||||
/// Row-wise safe softmax over the last dim. `self`:[rows,cols].
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn softmax(&self) -> Self {
|
||||
assert_eq!(self.ndim(), 2, "softmax requires 2D input");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_softmax_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("softmax sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// Softmax backward (Jacobian): `dx[r,c] = y[r,c]*(dy[r,c] - sum_c'(dy*y))`.
|
||||
/// Inputs are the forward output `y` and upstream `dy`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn softmax_backward(y: &Tensor, dy: &Tensor) -> Self {
|
||||
let (rows, cols) = (y.shape[0], y.shape[1]);
|
||||
let dx = Tensor::zeros(&y.shape, DType::F32, y.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_softmax_dx_f32(
|
||||
y.data_ptr() as *const f32,
|
||||
dy.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("softmax_backward sync failed");
|
||||
dx
|
||||
}
|
||||
|
||||
/// Cross-entropy forward over logits `self`:[rows,cols] with one I32 target
|
||||
/// per row. Returns `(probs, loss)` where `probs`:[rows,cols] is the softmax
|
||||
/// (cached for backward) and `loss`:[rows] is the per-row negative log-likelihood.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn cross_entropy(&self, target: &Tensor) -> (Tensor, Tensor) {
|
||||
assert_eq!(self.ndim(), 2, "cross_entropy requires 2D logits");
|
||||
assert_eq!(target.dtype, DType::I32, "target must be I32");
|
||||
assert_eq!(target.numel(), self.shape[0], "one target per row");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let probs = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let loss = Tensor::zeros(&[rows], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_cross_entropy_fwd_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
target.data_ptr() as *const i32,
|
||||
probs.data_ptr() as *mut f32,
|
||||
loss.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("cross_entropy sync failed");
|
||||
(probs, loss)
|
||||
}
|
||||
|
||||
/// Cross-entropy backward: `dx = scale * (probs - onehot(target))`. With
|
||||
/// `scale = upstream / rows`, this is the gradient of the mean per-row loss.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn cross_entropy_backward(probs: &Tensor, target: &Tensor, scale: f32) -> Self {
|
||||
let (rows, cols) = (probs.shape[0], probs.shape[1]);
|
||||
let dx = Tensor::zeros(&probs.shape, DType::F32, probs.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_cross_entropy_dx_f32(
|
||||
probs.data_ptr() as *const f32,
|
||||
target.data_ptr() as *const i32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("cross_entropy_backward sync failed");
|
||||
dx
|
||||
}
|
||||
|
||||
// Shared validation for same-shape binary elementwise ops.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn check_binary(&self, other: &Tensor, op: &str) {
|
||||
assert_eq!(self.dtype, DType::F32, "{op} only supports F32");
|
||||
assert_eq!(other.dtype, DType::F32, "{op} only supports F32");
|
||||
assert_eq!(self.shape(), other.shape(), "{op} shape mismatch");
|
||||
assert_eq!(self.device(), other.device(), "{op} device mismatch");
|
||||
assert!(
|
||||
self.is_contiguous() && other.is_contiguous(),
|
||||
"{op} requires contiguous tensors"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
|
||||
Reference in New Issue
Block a user