autograd: batch dim for ops (flatten linears, batched attention)
Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only attention + RoPE need sequence awareness: - RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e. per-sequence position on a flattened batch (period == tokens = single seq). - Fused batched causal attention: new `Tensor::attention`/`attention_backward` + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh (sequence,head) blocks (new sgemm_strided_batched binding) and a causal softmax kernel (scale + per-row causal mask inline) — the whole attention is 3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip. - transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads. grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside the existing 12. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
93
csrc/ops/attention.cu
Normal file
93
csrc/ops/attention.cu
Normal file
@@ -0,0 +1,93 @@
|
||||
// Batched scaled-dot-product attention helpers (Phase T10).
|
||||
//
|
||||
// The QKᵀ and PV matmuls run as cublasSgemmStridedBatched in Rust; the only
|
||||
// kernel attention needs of its own is a CAUSAL row-wise softmax over the score
|
||||
// rows. Scores are [B*nh, S, S] flattened to rows of length S; for a flat row r
|
||||
// the query position within its sequence is `r % S`, so columns j > r%S are
|
||||
// future positions and get probability 0 (no additive -1e9 mask tensor needed).
|
||||
//
|
||||
// The forward also folds in the 1/sqrt(head_dim) scale (applied to logits before
|
||||
// the max/exp) so we don't need a separate scale pass. Backward is the ordinary
|
||||
// softmax Jacobian (csrc/ops/nn.cu launch_softmax_dx_f32): masked entries have
|
||||
// y=0, so their contribution vanishes — no causal-specific backward needed.
|
||||
//
|
||||
// All F32, row-major, contiguous. Reduction helpers mirror nn.cu (inlined so the
|
||||
// file is self-contained, matching the csrc/ layout).
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
__device__ __forceinline__ float att_warp_sum(float v) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1)
|
||||
v += __shfl_down_sync(0xffffffff, v, off);
|
||||
return v;
|
||||
}
|
||||
__device__ __forceinline__ float att_warp_max(float v) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1)
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
|
||||
return v;
|
||||
}
|
||||
__device__ __forceinline__ float att_block_sum(float v) {
|
||||
__shared__ float sh[32];
|
||||
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
v = att_warp_sum(v);
|
||||
if (lane == 0) sh[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
|
||||
if (warp == 0) v = att_warp_sum(v);
|
||||
__shared__ float bc;
|
||||
if (threadIdx.x == 0) bc = v;
|
||||
__syncthreads();
|
||||
return bc;
|
||||
}
|
||||
__device__ __forceinline__ float att_block_max(float v) {
|
||||
__shared__ float sh[32];
|
||||
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
v = att_warp_max(v);
|
||||
if (lane == 0) sh[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
|
||||
if (warp == 0) v = att_warp_max(v);
|
||||
__shared__ float bc;
|
||||
if (threadIdx.x == 0) bc = v;
|
||||
__syncthreads();
|
||||
return bc;
|
||||
}
|
||||
|
||||
// One block per score row. rows = B*nh*S total; the query position within its
|
||||
// sequence is (blockIdx.x % seq). Logits are scaled by `scale` (= 1/sqrt(hd))
|
||||
// before softmax; columns j > qpos are masked to probability 0.
|
||||
__global__ void softmax_causal_k(const float* x, float* y, int seq, float scale) {
|
||||
int r = blockIdx.x;
|
||||
int qpos = r % seq;
|
||||
const float* xr = x + (size_t)r * seq;
|
||||
float* yr = y + (size_t)r * seq;
|
||||
int valid = qpos + 1; // attend to columns [0, qpos]
|
||||
float m = -INFINITY;
|
||||
for (int c = threadIdx.x; c < valid; c += blockDim.x)
|
||||
m = fmaxf(m, xr[c] * scale);
|
||||
m = att_block_max(m);
|
||||
float sum = 0.0f;
|
||||
for (int c = threadIdx.x; c < valid; c += blockDim.x) {
|
||||
float e = expf(xr[c] * scale - m);
|
||||
yr[c] = e;
|
||||
sum += e;
|
||||
}
|
||||
sum = att_block_sum(sum);
|
||||
float inv = 1.0f / sum;
|
||||
for (int c = threadIdx.x; c < seq; c += blockDim.x)
|
||||
yr[c] = (c < valid) ? yr[c] * inv : 0.0f;
|
||||
}
|
||||
void launch_softmax_causal_f32(const float* x, float* y, int rows, int seq,
|
||||
float scale, void* s) {
|
||||
int blk = seq < 1024 ? seq : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
softmax_causal_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, seq, scale);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -63,4 +63,26 @@ void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c,
|
||||
transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
|
||||
// Lays out batched multi-head attention: [B,S,nh,hd] <-> [B,nh,S,hd], so a
|
||||
// flattened [B*nh, S, hd] view feeds the strided-batched-GEMM attention. Its own
|
||||
// backward is the same op (swap b,c), so one kernel suffices.
|
||||
// =====================================================================
|
||||
|
||||
__global__ void transpose_4d12_k(const float* in, float* out, int a, int b, int c, int d) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x; // over a*b*c*d
|
||||
if (idx >= a * b * c * d) return;
|
||||
int l = idx % d;
|
||||
int k = (idx / d) % c;
|
||||
int j = (idx / (d * c)) % b;
|
||||
int i = idx / (d * c * b);
|
||||
// out[i,k,j,l] at ((i*c + k)*b + j)*d + l
|
||||
out[(((i * c + k) * b) + j) * d + l] = in[idx];
|
||||
}
|
||||
void launch_transpose_4d12_f32(const float* in, float* out, int a, int b, int c, int d, void* s) {
|
||||
int n = a * b * c * d, blk = 256, grid = (n + blk - 1) / blk;
|
||||
transpose_4d12_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c, d);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -215,14 +215,20 @@ void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void*
|
||||
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
|
||||
// =====================================================================
|
||||
|
||||
__global__ void rope_k(const float* x, float* y, int heads, int head_dim, float theta) {
|
||||
// `period` is the sequence length: a flattened batch lays B sequences end to end
|
||||
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
|
||||
// own sequence, `tok % period`. With period == tokens (single sequence) this is
|
||||
// the original position = row.
|
||||
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
|
||||
float theta, int period) {
|
||||
int tok = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int half = head_dim / 2;
|
||||
int i = threadIdx.x;
|
||||
if (i >= half) return;
|
||||
int pos = tok % period;
|
||||
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
|
||||
float angle = (float)tok * freq;
|
||||
float angle = (float)pos * freq;
|
||||
float c = cosf(angle), sn = sinf(angle);
|
||||
int base = (tok * heads + head) * head_dim;
|
||||
float x0 = x[base + i], x1 = x[base + i + half];
|
||||
@@ -230,20 +236,22 @@ __global__ void rope_k(const float* x, float* y, int heads, int head_dim, float
|
||||
y[base + i + half] = x1 * c + x0 * sn;
|
||||
}
|
||||
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
|
||||
int head_dim, float theta, void* s) {
|
||||
int head_dim, float theta, int period, void* s) {
|
||||
dim3 grid(tokens, heads);
|
||||
int blk = head_dim / 2;
|
||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta);
|
||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
|
||||
}
|
||||
|
||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, float theta) {
|
||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||
float theta, int period) {
|
||||
int tok = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int half = head_dim / 2;
|
||||
int i = threadIdx.x;
|
||||
if (i >= half) return;
|
||||
int pos = tok % period;
|
||||
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
|
||||
float angle = (float)tok * freq;
|
||||
float angle = (float)pos * freq;
|
||||
float c = cosf(angle), sn = sinf(angle);
|
||||
int base = (tok * heads + head) * head_dim;
|
||||
float d0 = dy[base + i], d1 = dy[base + i + half];
|
||||
@@ -251,10 +259,10 @@ __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, f
|
||||
dx[base + i + half] = d1 * c - d0 * sn;
|
||||
}
|
||||
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
|
||||
int head_dim, float theta, void* s) {
|
||||
int head_dim, float theta, int period, void* s) {
|
||||
dim3 grid(tokens, heads);
|
||||
int blk = head_dim / 2;
|
||||
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta);
|
||||
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
|
||||
Reference in New Issue
Block a user