ops: embedding/reshape/transpose/split-merge-heads fwd+bwd

Phase T5 structural ops on top of the T4 set, needed to assemble the
tiny transformer:
- embedding: gather rows by I32 ids (CUDA kernel) / scatter-add backward
  (atomic, so repeated ids accumulate). csrc/ops/model.cu + ffi.
- reshape: contiguous metadata-only view (Tensor::reshape), no kernel.
- transpose_3d01: [a,b,c]->[b,a,c] for the multi-head layout (kernel).
- autograd nodes: embedding/reshape/transpose_3d01/transpose_2d, plus
  split_heads (->Vec<Var>) / merge_heads for per-head attention.
- tape: Var::zero_grad + set_value so a hand-written GD step can update
  params and clear grads between steps.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:05:09 +08:00
parent 777f3c7949
commit 7fb1a29057
6 changed files with 327 additions and 0 deletions

66
csrc/ops/model.cu Normal file
View File

@@ -0,0 +1,66 @@
// Structural ops the tiny transformer (Phase T5) needs on top of the T4 op set:
// token embedding (gather forward / scatter-add backward) and a 3D axis-(0,1)
// transpose used to lay out multi-head attention ([seq,heads,hd] <-> [heads,seq,hd]).
//
// reshape is a pure metadata change (no data movement) and so has no kernel — it
// lives entirely in the Rust Tensor layer. All kernels here are F32 row-major
// contiguous; ids are I32. Each launcher matches the existing csrc/ style.
extern "C" {
// =====================================================================
// Embedding: gather rows of a table by integer ids.
// table:[vocab, dim], ids:[seq] (I32) -> out[s,:] = table[ids[s], :]
// Backward (scatter-add): dtable[ids[s], :] += dout[s, :]. Multiple positions
// may map to the same id, so the accumulation must be atomic.
// =====================================================================
__global__ void embedding_fwd_k(const float* table, const int* ids, float* out,
int seq, int dim) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // over seq*dim
if (i >= seq * dim) return;
int s = i / dim, c = i % dim;
out[i] = table[ids[s] * dim + c];
}
void launch_embedding_fwd_f32(const float* table, const int* ids, float* out,
int seq, int dim, void* s) {
int n = seq * dim, blk = 256, grid = (n + blk - 1) / blk;
embedding_fwd_k<<<grid, blk, 0, (cudaStream_t)s>>>(table, ids, out, seq, dim);
}
// dtable is assumed pre-zeroed (Tensor::zeros). Scatter-add with atomics so
// repeated ids accumulate correctly.
__global__ void embedding_bwd_k(const float* dout, const int* ids, float* dtable,
int seq, int dim) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // over seq*dim
if (i >= seq * dim) return;
int s = i / dim, c = i % dim;
atomicAdd(&dtable[ids[s] * dim + c], dout[i]);
}
void launch_embedding_bwd_f32(const float* dout, const int* ids, float* dtable,
int seq, int dim, void* s) {
int n = seq * dim, blk = 256, grid = (n + blk - 1) / blk;
embedding_bwd_k<<<grid, blk, 0, (cudaStream_t)s>>>(dout, ids, dtable, seq, dim);
}
// =====================================================================
// 3D axis-(0,1) transpose: in:[a,b,c] -> out:[b,a,c] (last dim contiguous).
// out[j, i, k] = in[i, j, k]
// Its own backward is the same op with (a,b) swapped, so one kernel suffices.
// =====================================================================
__global__ void transpose_3d01_k(const float* in, float* out, int a, int b, int c) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; // over a*b*c
if (idx >= a * b * c) return;
int k = idx % c;
int j = (idx / c) % b;
int i = idx / (b * c);
// out index: ((j*a) + i)*c + k
out[(j * a + i) * c + k] = in[idx];
}
void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c, void* s) {
int n = a * b * c, blk = 256, grid = (n + blk - 1) / blk;
transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
}
} // extern "C"