Files
xtrain/csrc/ops/attention.cu
Gahow Wang 7821bd9c34 autograd: batch dim for ops (flatten linears, batched attention)
Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE
already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only
attention + RoPE need sequence awareness:

- RoPE: the kernel takes a `period` (= seq len) so that position = row % period,
  i.e. the per-sequence position on a flattened batch (setting period == total
  tokens recovers the single-sequence case).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward`
  + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh
  (sequence,head) blocks (new sgemm_strided_batched binding) and a causal
  softmax kernel (scale + per-row causal mask inline) — the whole attention is
  3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.

grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all
pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside
the existing 12.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:15 +08:00

94 lines
3.4 KiB
Plaintext

// Batched scaled-dot-product attention helpers (Phase T10).
//
// The QKᵀ and PV matmuls run as cublasSgemmStridedBatched in Rust; the only
// kernel attention needs of its own is a CAUSAL row-wise softmax over the score
// rows. Scores are [B*nh, S, S] flattened to rows of length S; for a flat row r
// the query position within its sequence is `r % S`, so columns j > r%S are
// future positions and get probability 0 (no additive -1e9 mask tensor needed).
//
// The forward also folds in the 1/sqrt(head_dim) scale (applied to logits before
// the max/exp) so we don't need a separate scale pass. Backward is the ordinary
// softmax Jacobian (csrc/ops/nn.cu launch_softmax_dx_f32): masked entries have
// y=0, so their contribution vanishes — no causal-specific backward needed.
//
// All F32, row-major, contiguous. Reduction helpers mirror nn.cu (inlined so the
// file is self-contained, matching the csrc/ layout).
#include <math.h>
extern "C" {
// Warp-wide sum using a shuffle-down tree. After the five halving steps, lane 0
// holds the total over all 32 lanes; the other lanes hold partial sums.
// The shuffles name the full 0xffffffff mask, so every one of the warp's 32 lanes
// must exist and make this call (i.e. no partial warps).
__device__ __forceinline__ float att_warp_sum(float v) {
#pragma unroll
  for (int delta = 16; delta >= 1; delta /= 2) {
    const float other = __shfl_down_sync(0xffffffffu, v, delta);
    v = v + other;
  }
  return v;
}
// Warp-wide maximum using a shuffle-down tree. After the five halving steps,
// lane 0 holds the max over all 32 lanes; the other lanes hold partial maxima.
// The shuffles name the full 0xffffffff mask, so every one of the warp's 32 lanes
// must exist and make this call (i.e. no partial warps).
__device__ __forceinline__ float att_warp_max(float v) {
#pragma unroll
  for (int delta = 16; delta >= 1; delta /= 2) {
    const float other = __shfl_down_sync(0xffffffffu, v, delta);
    v = fmaxf(v, other);
  }
  return v;
}
// Block-wide sum; every thread of the block receives the result.
//
// Contains __syncthreads(), so all threads of the block must call it. Each warp
// first reduces with att_warp_sum, and lane 0 parks that partial in shared memory.
// Warp 0 then reduces the partials (padding unused slots with 0), and thread 0
// publishes the total through a shared broadcast slot.
// Supports at most 32 warps (1024 threads), and every warp must be full because
// att_warp_sum shuffles with the full mask.
__device__ __forceinline__ float att_block_sum(float v) {
  __shared__ float partial[32];  // one slot per warp
  __shared__ float result;       // broadcast slot written by thread 0
  const int lane = threadIdx.x % 32;
  const int warpId = threadIdx.x / 32;
  const int warpCount = (blockDim.x + 31) / 32;

  float acc = att_warp_sum(v);
  if (lane == 0) partial[warpId] = acc;
  __syncthreads();

  acc = 0.0f;
  if (threadIdx.x < warpCount) acc = partial[threadIdx.x];
  if (warpId == 0) acc = att_warp_sum(acc);

  if (threadIdx.x == 0) result = acc;
  __syncthreads();
  return result;
}
// Block-wide maximum; every thread of the block receives the result.
//
// Contains __syncthreads(), so all threads of the block must call it. Each warp
// first reduces with att_warp_max, and lane 0 parks that partial in shared memory.
// Warp 0 then reduces the partials (padding unused slots with -INFINITY), and
// thread 0 publishes the maximum through a shared broadcast slot.
// Supports at most 32 warps (1024 threads), and every warp must be full because
// att_warp_max shuffles with the full mask.
__device__ __forceinline__ float att_block_max(float v) {
  __shared__ float partial[32];  // one slot per warp
  __shared__ float result;       // broadcast slot written by thread 0
  const int lane = threadIdx.x % 32;
  const int warpId = threadIdx.x / 32;
  const int warpCount = (blockDim.x + 31) / 32;

  float best = att_warp_max(v);
  if (lane == 0) partial[warpId] = best;
  __syncthreads();

  best = -INFINITY;
  if (threadIdx.x < warpCount) best = partial[threadIdx.x];
  if (warpId == 0) best = att_warp_max(best);

  if (threadIdx.x == 0) result = best;
  __syncthreads();
  return result;
}
// Causal, scaled softmax over one attention score row per block.
//
// Launch layout:
//   - gridDim.x  = number of score rows (B*nh*S).
//   - blockDim.x <= 1024 and a multiple of 32 (the block reductions shuffle with
//     the full warp mask).
//   - No dynamic shared memory.
//
// Row `row` holds the logits of query position q = row % seq against keys
// 0..seq-1. Columns 0..q receive softmax(scale * logit), where scale is expected
// to be 1/sqrt(head_dim). Columns q+1..seq-1 (future keys) are written as exactly 0.
//
// `y` doubles as scratch between the exponent and normalization passes. Each
// element is read back only by the thread that wrote it.
__global__ void softmax_causal_k(const float* x, float* y, int seq, float scale) {
  const int row = blockIdx.x;
  const int keep = row % seq + 1;  // visible columns: [0, row % seq]
  const float* in = x + (size_t)row * seq;
  float* out = y + (size_t)row * seq;

  // Pass 1: block-wide maximum of the scaled, visible logits.
  float rowMax = -INFINITY;
  for (int col = threadIdx.x; col < keep; col += blockDim.x) {
    rowMax = fmaxf(rowMax, in[col] * scale);
  }
  rowMax = att_block_max(rowMax);

  // Pass 2: stash exp(scaled - max) in `out` and accumulate the denominator.
  float denom = 0.0f;
  for (int col = threadIdx.x; col < keep; col += blockDim.x) {
    const float ex = expf(in[col] * scale - rowMax);
    out[col] = ex;
    denom += ex;
  }
  denom = att_block_sum(denom);

  // Pass 3: normalize visible columns; force masked future columns to 0.
  const float recip = 1.0f / denom;
  for (int col = threadIdx.x; col < seq; col += blockDim.x) {
    if (col < keep) {
      out[col] = out[col] * recip;
    } else {
      out[col] = 0.0f;
    }
  }
}
// Host launcher for softmax_causal_k on stream `s` (a cudaStream_t passed as
// void*), with one block per score row.
//   x, y  : device buffers of rows*seq floats (scores in, probabilities out).
//   rows  : number of score rows (B*nh*S); seq : row length / sequence length.
//   scale : logit scale applied before the softmax (1/sqrt(head_dim)).
// Empty or invalid sizes are a no-op. Otherwise a zero-sized grid would be an
// invalid launch configuration, and seq == 0 would divide by zero in the kernel.
void launch_softmax_causal_f32(const float* x, float* y, int rows, int seq,
                               float scale, void* s) {
  if (rows <= 0 || seq <= 0) return;
  // The kernel's reductions shuffle with the full 0xffffffff mask. A partial last
  // warp (e.g. seq = 48 -> 48 threads) would therefore read undefined values from
  // nonexistent lanes and corrupt the row max/sum. To prevent this, round the
  // thread count up to a whole number of warps, staying within the 1024-thread /
  // 32-warp limit of the shared scratch. Surplus threads (index >= seq) skip every
  // loop and feed the reduction identities (-INF / 0).
  int blk = seq < 1024 ? seq : 1024;
  blk = (blk + 31) & ~31;
  softmax_causal_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, seq, scale);
}
} // extern "C"