ops: embedding/reshape/transpose/split-merge-heads fwd+bwd
Phase T5 structural ops on top of the T4 set, needed to assemble the tiny transformer:

- embedding: gather rows by I32 ids (CUDA kernel) / scatter-add backward (atomic, so repeated ids accumulate). csrc/ops/model.cu + ffi.
- reshape: contiguous metadata-only view (Tensor::reshape), no kernel.
- transpose_3d01: [a,b,c]->[b,a,c] for the multi-head layout (kernel).
- autograd nodes: embedding/reshape/transpose_3d01/transpose_2d, plus split_heads (->Vec<Var>) / merge_heads for per-head attention.
- tape: Var::zero_grad + set_value so a hand-written GD step can update params and clear grads between steps.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -146,6 +146,126 @@ pub fn softmax(x: &Var) -> Var {
|
||||
)
|
||||
}
|
||||
|
||||
/// Token embedding gather: `out[s,:] = table[ids[s], :]`. `table`:[vocab,dim]
|
||||
/// (a learnable [`Var`]), `ids`:[seq] I32 (a constant index, not a `Var`).
|
||||
/// Backward scatter-adds the upstream grad back into the table rows.
|
||||
pub fn embedding(table: &Var, ids: &Tensor) -> Var {
|
||||
let out = table.value().embedding(ids);
|
||||
let vocab = table.value().shape()[0];
|
||||
let ids = ids.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![table.clone()],
|
||||
Box::new(move |dout, parents| {
|
||||
let dtable = Tensor::embedding_backward(dout, &ids, vocab);
|
||||
Var::push_grad(&parents[0], dtable);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Reshape (contiguous, metadata-only). Backward reshapes the grad back to the
|
||||
/// input shape. Used for the multi-head layout swap `[seq, h*hd] <-> [seq, h, hd]`.
|
||||
pub fn reshape(x: &Var, new_shape: &[usize]) -> Var {
|
||||
let in_shape: Vec<usize> = x.value().shape().to_vec();
|
||||
let out = x.value().reshape(new_shape);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
Var::push_grad(&parents[0], d.reshape(&in_shape));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// 3D axis-(0,1) transpose `[a,b,c] -> [b,a,c]`. Self-inverse structure: the
|
||||
/// backward is the same transpose applied to the grad.
|
||||
pub fn transpose_3d01(x: &Var) -> Var {
|
||||
let out = x.value().transpose_3d01();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(|d, parents| {
|
||||
Var::push_grad(&parents[0], d.transpose_3d01());
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// 2D transpose `[r,c] -> [c,r]` as an autograd node (backward transposes the
|
||||
/// grad back). Used for `Kᵀ` in attention scores.
|
||||
pub fn transpose_2d(x: &Var) -> Var {
|
||||
let out = x.value().transpose_2d();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(|d, parents| {
|
||||
Var::push_grad(&parents[0], d.transpose_2d());
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Split a `[heads, seq, head_dim]` tensor into one `[seq, head_dim]` [`Var`] per
|
||||
/// head. Each head block is contiguous in this layout, so the forward copies the
|
||||
/// head block into its own contiguous tensor; the backward scatters each head's
|
||||
/// grad back into a zero `[heads, seq, head_dim]` grad (the engine then SUMs the
|
||||
/// `heads` contributions on the shared parent — fan-out).
|
||||
pub fn split_heads(x: &Var) -> Vec<Var> {
|
||||
let v = x.value();
|
||||
assert_eq!(v.ndim(), 3, "split_heads requires [heads,seq,head_dim]");
|
||||
let (heads, seq, hd) = (v.shape()[0], v.shape()[1], v.shape()[2]);
|
||||
let dev = v.device();
|
||||
let flat_host = v.to_device(xtrain_tensor::Device::Cpu);
|
||||
let flat = flat_host.as_slice::<f32>();
|
||||
(0..heads)
|
||||
.map(|h| {
|
||||
let base = h * seq * hd;
|
||||
let block = Tensor::from_slice(&flat[base..base + seq * hd], &[seq, hd]).to_device(dev);
|
||||
Var::from_op(
|
||||
block,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
let mut host = vec![0.0f32; heads * seq * hd];
|
||||
let dvals = d.to_device(xtrain_tensor::Device::Cpu);
|
||||
let base = h * seq * hd;
|
||||
host[base..base + seq * hd].copy_from_slice(dvals.as_slice::<f32>());
|
||||
let g = Tensor::from_slice(&host, &[heads, seq, hd]).to_device(dev);
|
||||
Var::push_grad(&parents[0], g);
|
||||
}),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Inverse of [`split_heads`]: stack per-head `[seq, head_dim]` outputs into a
|
||||
/// `[heads, seq, head_dim]` tensor. Backward hands each head its own slice of the
|
||||
/// grad.
|
||||
pub fn merge_heads(heads_v: &[Var]) -> Var {
|
||||
let heads = heads_v.len();
|
||||
let v0 = heads_v[0].value();
|
||||
let (seq, hd) = (v0.shape()[0], v0.shape()[1]);
|
||||
let dev = v0.device();
|
||||
let mut host = vec![0.0f32; heads * seq * hd];
|
||||
for (h, hv) in heads_v.iter().enumerate() {
|
||||
let block = hv.value().to_device(xtrain_tensor::Device::Cpu);
|
||||
let base = h * seq * hd;
|
||||
host[base..base + seq * hd].copy_from_slice(block.as_slice::<f32>());
|
||||
}
|
||||
let out = Tensor::from_slice(&host, &[heads, seq, hd]).to_device(dev);
|
||||
Var::from_op(
|
||||
out,
|
||||
heads_v.to_vec(),
|
||||
Box::new(move |d, parents| {
|
||||
let dhost = d.to_device(xtrain_tensor::Device::Cpu);
|
||||
let dflat = dhost.as_slice::<f32>();
|
||||
for (h, parent) in parents.iter().enumerate() {
|
||||
let base = h * seq * hd;
|
||||
let g =
|
||||
Tensor::from_slice(&dflat[base..base + seq * hd], &[seq, hd]).to_device(dev);
|
||||
Var::push_grad(parent, g);
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
|
||||
@@ -68,6 +68,19 @@ impl Var {
|
||||
self.0.borrow().grad.clone()
|
||||
}
|
||||
|
||||
/// Clear the accumulated gradient. Call on every parameter between training
|
||||
/// steps so the next `backward` accumulates from zero (grads SUM otherwise).
|
||||
pub fn zero_grad(&self) {
|
||||
self.0.borrow_mut().grad = None;
|
||||
}
|
||||
|
||||
/// Overwrite this node's value tensor in place. Used by the optimizer to
|
||||
/// apply a parameter update (`p ← p − lr·grad`) while keeping the leaf's
|
||||
/// identity stable across steps.
|
||||
pub fn set_value(&self, value: Tensor) {
|
||||
self.0.borrow_mut().value = value;
|
||||
}
|
||||
|
||||
/// Pointer identity, used to dedup nodes during the topological sort.
|
||||
fn id(&self) -> *const RefCell<VarNode> {
|
||||
Rc::as_ptr(&self.0)
|
||||
|
||||
@@ -33,6 +33,7 @@ fn main() {
|
||||
.file("../../csrc/ops/elementwise.cu")
|
||||
.file("../../csrc/ops/gemm.cu")
|
||||
.file("../../csrc/ops/nn.cu")
|
||||
.file("../../csrc/ops/model.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -177,6 +177,41 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// Structural ops for the tiny transformer (csrc/ops/model.cu): token embedding
|
||||
// (gather fwd / scatter-add bwd) and a 3D axis-(0,1) transpose for the multi-head
|
||||
// attention layout. F32 values, I32 ids, row-major contiguous.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// Embedding: out[s,:] = table[ids[s], :]. table:[vocab,dim], ids:[seq] (I32).
|
||||
pub fn launch_embedding_fwd_f32(
|
||||
table: *const f32,
|
||||
ids: *const i32,
|
||||
out: *mut f32,
|
||||
seq: i32,
|
||||
dim: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Scatter-add: dtable[ids[s],:] += dout[s,:] (dtable pre-zeroed; atomic).
|
||||
pub fn launch_embedding_bwd_f32(
|
||||
dout: *const f32,
|
||||
ids: *const i32,
|
||||
dtable: *mut f32,
|
||||
seq: i32,
|
||||
dim: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
|
||||
// 3D axis-(0,1) transpose: in:[a,b,c] -> out:[b,a,c]. out[j,i,k]=in[i,j,k].
|
||||
pub fn launch_transpose_3d01_f32(
|
||||
input: *const f32,
|
||||
out: *mut f32,
|
||||
a: i32,
|
||||
b: i32,
|
||||
c: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// cuBLAS — used ONLY as a correctness reference for the hand-written GEMM in
|
||||
// tests. Declared (and linked, see build.rs) only when CUDA is compiled in.
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -563,6 +563,98 @@ impl Tensor {
|
||||
dx
|
||||
}
|
||||
|
||||
// --- Structural / model ops (the T5 kernels) ---
|
||||
|
||||
/// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a
|
||||
/// contiguous tensor — no data movement, shares the same storage. The
|
||||
/// multi-head layout `[seq, n_heads*head_dim] <-> [seq, n_heads, head_dim]`
|
||||
/// is exactly this.
|
||||
pub fn reshape(&self, new_shape: &[usize]) -> Self {
|
||||
assert!(self.is_contiguous(), "reshape requires a contiguous tensor");
|
||||
assert_eq!(
|
||||
shape::num_elements(new_shape),
|
||||
self.numel(),
|
||||
"reshape numel mismatch: {:?} -> {:?}",
|
||||
self.shape.as_slice(),
|
||||
new_shape
|
||||
);
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: Dims::from_slice(new_shape),
|
||||
strides: shape::contiguous_strides(new_shape),
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding gather: `out[s,:] = self[ids[s], :]`. `self`:[vocab,dim] table,
|
||||
/// `ids`:[seq] I32 → out:[seq,dim].
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn embedding(&self, ids: &Tensor) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "embedding table must be F32");
|
||||
assert_eq!(self.ndim(), 2, "embedding table must be [vocab,dim]");
|
||||
assert_eq!(ids.dtype, DType::I32, "embedding ids must be I32");
|
||||
assert_eq!(ids.ndim(), 1, "embedding ids must be 1D");
|
||||
let (seq, dim) = (ids.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&[seq, dim], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_embedding_fwd_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
ids.data_ptr() as *const i32,
|
||||
out.data_ptr() as *mut f32,
|
||||
seq as i32,
|
||||
dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("embedding sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
/// Embedding backward (scatter-add): `dtable[ids[s],:] += dout[s,:]`, where
|
||||
/// `dout`:[seq,dim], `ids`:[seq] I32. `vocab` sizes the output table.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn embedding_backward(dout: &Tensor, ids: &Tensor, vocab: usize) -> Self {
|
||||
let (seq, dim) = (dout.shape[0], dout.shape[1]);
|
||||
let dtable = Tensor::zeros(&[vocab, dim], DType::F32, dout.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_embedding_bwd_f32(
|
||||
dout.data_ptr() as *const f32,
|
||||
ids.data_ptr() as *const i32,
|
||||
dtable.data_ptr() as *mut f32,
|
||||
seq as i32,
|
||||
dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("embedding_backward sync failed");
|
||||
dtable
|
||||
}
|
||||
|
||||
/// 3D axis-(0,1) transpose: `self`:[a,b,c] → [b,a,c], `out[j,i,k]=self[i,j,k]`.
|
||||
/// Lays out multi-head attention (`[seq,heads,hd] <-> [heads,seq,hd]`). Its
|
||||
/// own backward is the same op (swap a,b).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn transpose_3d01(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "transpose_3d01 only supports F32");
|
||||
assert_eq!(self.ndim(), 3, "transpose_3d01 requires a 3D tensor");
|
||||
assert!(self.is_contiguous(), "transpose_3d01 requires contiguous");
|
||||
let (a, b, c) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let out = Tensor::zeros(&[b, a, c], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_transpose_3d01_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
a as i32,
|
||||
b as i32,
|
||||
c as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("transpose_3d01 sync failed");
|
||||
out
|
||||
}
|
||||
|
||||
// Shared validation for same-shape binary elementwise ops.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn check_binary(&self, other: &Tensor, op: &str) {
|
||||
|
||||
66
csrc/ops/model.cu
Normal file
66
csrc/ops/model.cu
Normal file
@@ -0,0 +1,66 @@
|
||||
// Structural ops the tiny transformer (Phase T5) needs on top of the T4 op set:
|
||||
// token embedding (gather forward / scatter-add backward) and a 3D axis-(0,1)
|
||||
// transpose used to lay out multi-head attention ([seq,heads,hd] <-> [heads,seq,hd]).
|
||||
//
|
||||
// reshape is a pure metadata change (no data movement) and so has no kernel — it
|
||||
// lives entirely in the Rust Tensor layer. All kernels here are F32 row-major
|
||||
// contiguous; ids are I32. Each launcher matches the existing csrc/ style.
|
||||
|
||||
extern "C" {
|
||||
|
||||
// =====================================================================
|
||||
// Embedding: gather rows of a table by integer ids.
|
||||
// table:[vocab, dim], ids:[seq] (I32) -> out[s,:] = table[ids[s], :]
|
||||
// Backward (scatter-add): dtable[ids[s], :] += dout[s, :]. Multiple positions
|
||||
// may map to the same id, so the accumulation must be atomic.
|
||||
// =====================================================================
|
||||
|
||||
// Gather kernel: one thread per output element, flat index i over seq*dim.
// Row s of `out` is a copy of row ids[s] of `table`. ids are not range-checked.
__global__ void embedding_fwd_k(const float* table, const int* ids, float* out,
                                int seq, int dim) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;  // over seq*dim
    if (i >= seq * dim) return;
    int s = i / dim, c = i % dim;
    // 64-bit row offset: ids[s]*dim can exceed INT_MAX for a large table even
    // when seq*dim itself fits in int.
    out[i] = table[(long long)ids[s] * dim + c];
}

// Enqueue the gather on stream `s`. Does not synchronize.
void launch_embedding_fwd_f32(const float* table, const int* ids, float* out,
                              int seq, int dim, void* s) {
    // An empty problem would give grid == 0, which is not a valid launch
    // configuration; there is nothing to copy, so skip the launch.
    if (seq <= 0 || dim <= 0) return;
    int n = seq * dim, blk = 256;
    int grid = (n - 1) / blk + 1;  // ceil(n / blk) without overflowing n + blk - 1
    embedding_fwd_k<<<grid, blk, 0, (cudaStream_t)s>>>(table, ids, out, seq, dim);
}
|
||||
|
||||
// dtable is assumed pre-zeroed (Tensor::zeros). Scatter-add with atomics so
// repeated ids accumulate correctly. One thread per dout element, flat index i
// over seq*dim. ids are not range-checked.
__global__ void embedding_bwd_k(const float* dout, const int* ids, float* dtable,
                                int seq, int dim) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;  // over seq*dim
    if (i >= seq * dim) return;
    int s = i / dim, c = i % dim;
    // 64-bit row offset: ids[s]*dim can exceed INT_MAX for a large table even
    // when seq*dim itself fits in int.
    atomicAdd(&dtable[(long long)ids[s] * dim + c], dout[i]);
}

// Enqueue the scatter-add on stream `s`. Does not synchronize.
void launch_embedding_bwd_f32(const float* dout, const int* ids, float* dtable,
                              int seq, int dim, void* s) {
    // An empty problem would give grid == 0, which is not a valid launch
    // configuration; there is nothing to accumulate, so skip the launch.
    if (seq <= 0 || dim <= 0) return;
    int n = seq * dim, blk = 256;
    int grid = (n - 1) / blk + 1;  // ceil(n / blk) without overflowing n + blk - 1
    embedding_bwd_k<<<grid, blk, 0, (cudaStream_t)s>>>(dout, ids, dtable, seq, dim);
}
|
||||
|
||||
// =====================================================================
|
||||
// 3D axis-(0,1) transpose: in:[a,b,c] -> out:[b,a,c] (last dim contiguous).
|
||||
// out[j, i, k] = in[i, j, k]
|
||||
// Its own backward is the same op with (a,b) swapped, so one kernel suffices.
|
||||
// =====================================================================
|
||||
|
||||
// One thread per input element. The flat input index idx is decomposed into
// (i, j, k) over [a, b, c] and the value is written to out[j, i, k].
__global__ void transpose_3d01_k(const float* in, float* out, int a, int b, int c) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // over a*b*c
    if (idx >= a * b * c) return;
    int k = idx % c;
    int j = (idx / c) % b;
    int i = idx / (b * c);
    // out index: ((j*a) + i)*c + k
    out[(j * a + i) * c + k] = in[idx];
}

// Enqueue the transpose on stream `s`. Does not synchronize.
void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c, void* s) {
    // An empty tensor would give grid == 0, which is not a valid launch
    // configuration; there is nothing to move, so skip the launch.
    if (a <= 0 || b <= 0 || c <= 0) return;
    int n = a * b * c, blk = 256;
    int grid = (n - 1) / blk + 1;  // ceil(n / blk) without overflowing n + blk - 1
    transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user