Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only attention + RoPE need sequence awareness: - RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e. per-sequence position on a flattened batch (period == tokens = single seq). - Fused batched causal attention: new `Tensor::attention`/`attention_backward` + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh (sequence,head) blocks (new sgemm_strided_batched binding) and a causal softmax kernel (scale + per-row causal mask inline) — the whole attention is 3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip. - transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads. grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside the existing 12. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
89 lines
4.2 KiB
Plaintext
89 lines
4.2 KiB
Plaintext
// Structural ops the tiny transformer (Phase T5) needs on top of the T4 op set:
|
|
// token embedding (gather forward / scatter-add backward) and a 3D axis-(0,1)
|
|
// transpose used to lay out multi-head attention ([seq,heads,hd] <-> [heads,seq,hd]),
// plus a 4D axis-(1,2) transpose for batched attention ([B,S,nh,hd] <-> [B,nh,S,hd]).
|
|
//
|
|
// reshape is a pure metadata change (no data movement) and so has no kernel — it
|
|
// lives entirely in the Rust Tensor layer. All kernels here are F32 row-major
|
|
// contiguous; ids are I32. Each launcher matches the existing csrc/ style.
|
|
|
|
extern "C" {
|
|
|
|
// =====================================================================
|
|
// Embedding: gather rows of a table by integer ids.
|
|
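// table:[vocab, dim], ids:[seq] (I32) -> out:[seq, dim], out[s,:] = table[ids[s], :]
// `seq` is the flattened row count (B*S for a batch); ids are not range-checked,
// so every id must lie in [0, vocab).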
|
|
// Backward (scatter-add): dtable[ids[s], :] += dout[s, :]. Multiple positions
|
|
// may map to the same id, so the accumulation must be atomic.
|
|
// =====================================================================
|
|
|
|
// Gather kernel: out[r, c] = table[ids[r], c] for r in [0, seq), c in [0, dim).
// Layout: 1D grid, one thread per output element, no shared memory.
// Index math is 64-bit so seq*dim or ids[r]*dim beyond INT_MAX cannot wrap.
// Precondition (unchecked): 0 <= ids[r] < vocab, otherwise `table` is read out of bounds.
// `table`, `ids` and `out` must not alias.
__global__ void embedding_fwd_k(const float* __restrict__ table, const int* __restrict__ ids,
                                float* __restrict__ out, int seq, int dim) {
    const long long n = (long long)seq * dim;
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;  // over seq*dim
    if (i >= n) return;
    const long long row = i / dim;
    const long long col = i % dim;
    out[i] = table[(long long)ids[row] * dim + col];
}

// Enqueue embedding_fwd_k on stream `s` (a cudaStream_t passed as void*).
// An empty input enqueues nothing: a <<<0, ...>>> launch is an invalid configuration
// and would leave a pending error for the next, unrelated cudaGetLastError().
void launch_embedding_fwd_f32(const float* table, const int* ids, float* out,
                              int seq, int dim, void* s) {
    if (seq <= 0 || dim <= 0) return;
    const long long n = (long long)seq * dim;
    const int blk = 256;
    const unsigned int grid = (unsigned int)((n + blk - 1) / blk);
    embedding_fwd_k<<<grid, blk, 0, (cudaStream_t)s>>>(table, ids, out, seq, dim);
}
|
|
|
|
// dtable is assumed pre-zeroed (Tensor::zeros). Scatter-add with atomics so
|
|
// repeated ids accumulate correctly.
|
|
// Scatter-add kernel: dtable[ids[r], c] += dout[r, c] for r in [0, seq), c in [0, dim).
// Layout: 1D grid, one thread per dout element, no shared memory.
// atomicAdd is required because several rows r may carry the same id. Index math is
// 64-bit so seq*dim or ids[r]*dim beyond INT_MAX cannot wrap.
// Precondition (unchecked): 0 <= ids[r] < vocab. `dout`, `ids`, `dtable` must not alias.
__global__ void embedding_bwd_k(const float* __restrict__ dout, const int* __restrict__ ids,
                                float* __restrict__ dtable, int seq, int dim) {
    const long long n = (long long)seq * dim;
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;  // over seq*dim
    if (i >= n) return;
    const long long row = i / dim;
    const long long col = i % dim;
    atomicAdd(&dtable[(long long)ids[row] * dim + col], dout[i]);
}

// Enqueue embedding_bwd_k on stream `s` (a cudaStream_t passed as void*).
// dtable must already be zeroed (or hold the gradient to accumulate into).
// An empty input enqueues nothing, because a <<<0, ...>>> launch is an invalid configuration.
void launch_embedding_bwd_f32(const float* dout, const int* ids, float* dtable,
                              int seq, int dim, void* s) {
    if (seq <= 0 || dim <= 0) return;
    const long long n = (long long)seq * dim;
    const int blk = 256;
    const unsigned int grid = (unsigned int)((n + blk - 1) / blk);
    embedding_bwd_k<<<grid, blk, 0, (cudaStream_t)s>>>(dout, ids, dtable, seq, dim);
}
|
|
|
|
// =====================================================================
|
|
// 3D axis-(0,1) transpose: in:[a,b,c] -> out:[b,a,c] (last dim contiguous).
|
|
// out[j, i, k] = in[i, j, k]
|
|
// Its own backward is the same op with (a,b) swapped, so one kernel suffices.
|
|
// =====================================================================
|
|
|
|
// Axis-(0,1) transpose kernel: in:[a,b,c] -> out:[b,a,c], out[j,i,k] = in[i,j,k].
// Layout: 1D grid, one thread per element, reads coalesced over `in`, no shared memory.
// Index math is 64-bit so a*b*c beyond INT_MAX cannot wrap. Out-of-place only:
// `in` and `out` must not alias.
__global__ void transpose_3d01_k(const float* __restrict__ in, float* __restrict__ out,
                                 int a, int b, int c) {
    const long long n = (long long)a * b * c;
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;  // over a*b*c
    if (idx >= n) return;
    const long long k = idx % c;
    const long long j = (idx / c) % b;
    const long long i = idx / ((long long)b * c);
    out[(j * a + i) * c + k] = in[idx];
}

// Enqueue transpose_3d01_k on stream `s` (a cudaStream_t passed as void*).
// The backward pass is the same call with (a, b) swapped.
// An empty tensor enqueues nothing, because a <<<0, ...>>> launch is an invalid configuration.
void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c, void* s) {
    if (a <= 0 || b <= 0 || c <= 0) return;
    const long long n = (long long)a * b * c;
    const int blk = 256;
    const unsigned int grid = (unsigned int)((n + blk - 1) / blk);
    transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
}
|
|
|
|
// =====================================================================
|
|
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
|
|
// Lays out batched multi-head attention: [B,S,nh,hd] <-> [B,nh,S,hd], so a
|
|
// flattened [B*nh, S, hd] view feeds the strided-batched-GEMM attention. Its own
|
|
// backward is the same op (swap b,c), so one kernel suffices.
|
|
// =====================================================================
|
|
|
|
// Axis-(1,2) transpose kernel: in:[a,b,c,d] -> out:[a,c,b,d], out[i,k,j,l] = in[i,j,k,l].
// Layout: 1D grid, one thread per element, reads coalesced over `in`, no shared memory.
// Index math is 64-bit so a*b*c*d (e.g. B*S*nh*hd) beyond INT_MAX cannot wrap.
// Out-of-place only: `in` and `out` must not alias.
__global__ void transpose_4d12_k(const float* __restrict__ in, float* __restrict__ out,
                                 int a, int b, int c, int d) {
    const long long n = (long long)a * b * c * d;
    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;  // over a*b*c*d
    if (idx >= n) return;
    const long long cd = (long long)c * d;
    const long long l = idx % d;
    const long long k = (idx / d) % c;
    const long long j = (idx / cd) % b;
    const long long i = idx / (cd * b);
    out[((i * c + k) * b + j) * d + l] = in[idx];
}

// Enqueue transpose_4d12_k on stream `s` (a cudaStream_t passed as void*).
// The backward pass is the same call with (b, c) swapped.
// An empty tensor enqueues nothing, because a <<<0, ...>>> launch is an invalid configuration.
void launch_transpose_4d12_f32(const float* in, float* out, int a, int b, int c, int d, void* s) {
    if (a <= 0 || b <= 0 || c <= 0 || d <= 0) return;
    const long long n = (long long)a * b * c * d;
    const int blk = 256;
    const unsigned int grid = (unsigned int)((n + blk - 1) / blk);
    transpose_4d12_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c, d);
}
|
|
|
|
} // extern "C"
|