autograd: batch dim for ops (flatten linears, batched attention)

Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE
already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only
attention + RoPE need sequence awareness:

- RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e.
  per-sequence position on a flattened batch (period == tokens = single seq).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward`
  + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh
  (sequence,head) blocks (new sgemm_strided_batched binding) and a causal
  softmax kernel (scale + per-row causal mask inline) — the whole attention is
  3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.

grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all
pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside
the existing 12.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 00:44:15 +08:00
parent d2a585c5cb
commit 7821bd9c34
9 changed files with 629 additions and 21 deletions

93
csrc/ops/attention.cu Normal file
View File

@@ -0,0 +1,93 @@
// Batched scaled-dot-product attention helpers (Phase T10).
//
// The QKᵀ and PV matmuls run as cublasSgemmStridedBatched in Rust; the only
// kernel attention needs of its own is a CAUSAL row-wise softmax over the score
// rows. Scores are [B*nh, S, S] flattened to rows of length S; for a flat row r
// the query position within its sequence is `r % S`, so columns j > r%S are
// future positions and get probability 0 (no additive -1e9 mask tensor needed).
//
// The forward also folds in the 1/sqrt(head_dim) scale (applied to logits before
// the max/exp) so we don't need a separate scale pass. Backward is the ordinary
// softmax Jacobian (csrc/ops/nn.cu launch_softmax_dx_f32): masked entries have
// y=0, so their contribution vanishes — no causal-specific backward needed.
//
// All F32, row-major, contiguous. Reduction helpers mirror nn.cu (inlined so the
// file is self-contained, matching the csrc/ layout).
#include <math.h>
extern "C" {
__device__ __forceinline__ float att_warp_sum(float v) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1)
v += __shfl_down_sync(0xffffffff, v, off);
return v;
}
__device__ __forceinline__ float att_warp_max(float v) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1)
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
return v;
}
__device__ __forceinline__ float att_block_sum(float v) {
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = att_warp_sum(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
if (warp == 0) v = att_warp_sum(v);
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
__device__ __forceinline__ float att_block_max(float v) {
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = att_warp_max(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
if (warp == 0) v = att_warp_max(v);
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
// One block per score row. rows = B*nh*S total; the query position within its
// sequence is (blockIdx.x % seq). Logits are scaled by `scale` (= 1/sqrt(hd))
// before softmax; columns j > qpos are masked to probability 0.
__global__ void softmax_causal_k(const float* x, float* y, int seq, float scale) {
int r = blockIdx.x;
int qpos = r % seq;
const float* xr = x + (size_t)r * seq;
float* yr = y + (size_t)r * seq;
int valid = qpos + 1; // attend to columns [0, qpos]
float m = -INFINITY;
for (int c = threadIdx.x; c < valid; c += blockDim.x)
m = fmaxf(m, xr[c] * scale);
m = att_block_max(m);
float sum = 0.0f;
for (int c = threadIdx.x; c < valid; c += blockDim.x) {
float e = expf(xr[c] * scale - m);
yr[c] = e;
sum += e;
}
sum = att_block_sum(sum);
float inv = 1.0f / sum;
for (int c = threadIdx.x; c < seq; c += blockDim.x)
yr[c] = (c < valid) ? yr[c] * inv : 0.0f;
}
void launch_softmax_causal_f32(const float* x, float* y, int rows, int seq,
float scale, void* s) {
int blk = seq < 1024 ? seq : 1024;
if (blk < 32) blk = 32;
softmax_causal_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, seq, scale);
}
} // extern "C"

View File

@@ -63,4 +63,26 @@ void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c,
transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
}
// =====================================================================
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
// Lays out batched multi-head attention: [B,S,nh,hd] <-> [B,nh,S,hd], so a
// flattened [B*nh, S, hd] view feeds the strided-batched-GEMM attention. Its own
// backward is the same op (swap b,c), so one kernel suffices.
// =====================================================================
__global__ void transpose_4d12_k(const float* in, float* out, int a, int b, int c, int d) {
int idx = blockIdx.x * blockDim.x + threadIdx.x; // over a*b*c*d
if (idx >= a * b * c * d) return;
int l = idx % d;
int k = (idx / d) % c;
int j = (idx / (d * c)) % b;
int i = idx / (d * c * b);
// out[i,k,j,l] at ((i*c + k)*b + j)*d + l
out[(((i * c + k) * b) + j) * d + l] = in[idx];
}
void launch_transpose_4d12_f32(const float* in, float* out, int a, int b, int c, int d, void* s) {
int n = a * b * c * d, blk = 256, grid = (n + blk - 1) / blk;
transpose_4d12_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c, d);
}
} // extern "C"

View File

@@ -215,14 +215,20 @@ void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void*
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
// =====================================================================
__global__ void rope_k(const float* x, float* y, int heads, int head_dim, float theta) {
// `period` is the sequence length: a flattened batch lays B sequences end to end
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
// own sequence, `tok % period`. With period == tokens (single sequence) this is
// the original position = row.
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
float theta, int period) {
int tok = blockIdx.x;
int head = blockIdx.y;
int half = head_dim / 2;
int i = threadIdx.x;
if (i >= half) return;
int pos = tok % period;
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
float angle = (float)tok * freq;
float angle = (float)pos * freq;
float c = cosf(angle), sn = sinf(angle);
int base = (tok * heads + head) * head_dim;
float x0 = x[base + i], x1 = x[base + i + half];
@@ -230,20 +236,22 @@ __global__ void rope_k(const float* x, float* y, int heads, int head_dim, float
y[base + i + half] = x1 * c + x0 * sn;
}
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
int head_dim, float theta, void* s) {
int head_dim, float theta, int period, void* s) {
dim3 grid(tokens, heads);
int blk = head_dim / 2;
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta);
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, float theta) {
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
float theta, int period) {
int tok = blockIdx.x;
int head = blockIdx.y;
int half = head_dim / 2;
int i = threadIdx.x;
if (i >= half) return;
int pos = tok % period;
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
float angle = (float)tok * freq;
float angle = (float)pos * freq;
float c = cosf(angle), sn = sinf(angle);
int base = (tok * heads + head) * head_dim;
float d0 = dy[base + i], d1 = dy[base + i + half];
@@ -251,10 +259,10 @@ __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, f
dx[base + i + half] = d1 * c - d0 * sn;
}
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
int head_dim, float theta, void* s) {
int head_dim, float theta, int period, void* s) {
dim3 grid(tokens, heads);
int blk = head_dim / 2;
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta);
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
}
// =====================================================================