Add the batched-forward primitives

Linears/norms/elementwise/embedding/CE already act on flat [rows, dim], so they
work unchanged on [B*S, dim]; only attention and RoPE need sequence awareness:

- RoPE: the kernel takes a `period` (= seq len) so position = row % period, i.e.
  the per-sequence position on a flattened batch (period == tokens for a single
  sequence).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward` +
  ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh
  (sequence, head) blocks (new sgemm_strided_batched binding) and a causal
  softmax kernel (scale + per-row causal mask inline). The whole attention is 3
  launches regardless of B*nh: no per-head/per-seq loop, no host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.

Grad-checks: the new batched-rope, transpose_4d12 and batched-attention
dQ/dK/dV checks all pass finite-difference checking (attn dK 1.5e-2, dQ 7.5e-3,
dV 2.9e-4; the rest are tighter) alongside the existing 12.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
94 lines
3.4 KiB
Plaintext
94 lines
3.4 KiB
Plaintext
// Batched scaled-dot-product attention helpers (Phase T10).
|
|
//
|
|
// The QKᵀ and PV matmuls run as cublasSgemmStridedBatched in Rust; the only
|
|
// kernel attention needs of its own is a CAUSAL row-wise softmax over the score
|
|
// rows. Scores are [B*nh, S, S] flattened to rows of length S; for a flat row r
|
|
// the query position within its sequence is `r % S`, so columns j > r%S are
|
|
// future positions and get probability 0 (no additive -1e9 mask tensor needed).
|
|
//
|
|
// The forward also folds in the 1/sqrt(head_dim) scale (applied to logits before
// the max/exp) so we don't need a separate scale pass. Backward is the ordinary
// softmax Jacobian (csrc/ops/nn.cu launch_softmax_dx_f32): masked entries have
// y=0, so their contribution vanishes — no causal-specific backward needed. By
// the chain rule, the gradient w.r.t. the unscaled scores additionally carries
// the same `scale` factor that the forward folded in.
|
|
//
|
|
// All F32, row-major, contiguous. Reduction helpers mirror nn.cu (inlined so the
|
|
// file is self-contained, matching the csrc/ layout).
|
|
|
|
#include <math.h>
|
|
|
|
extern "C" {
|
|
|
|
// Warp-level tree sum built on shuffle-down. After the loop, lane 0 holds the
// sum of all 32 lanes' inputs; the values left in the other lanes are not a
// meaningful total. The shuffle uses the full 0xffffffff mask, so all 32 lanes
// of the warp must exist and execute this call together.
__device__ __forceinline__ float att_warp_sum(float v) {
    #pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float from_higher_lane = __shfl_down_sync(0xffffffffu, v, delta);
        v += from_higher_lane;
    }
    return v;
}
|
|
// Warp-level tree maximum built on shuffle-down. After the loop, lane 0 holds
// the maximum over all 32 lanes' inputs; only lane 0's value is meaningful.
// The shuffle uses the full 0xffffffff mask, so all 32 lanes of the warp must
// exist and execute this call together.
__device__ __forceinline__ float att_warp_max(float v) {
    #pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float from_higher_lane = __shfl_down_sync(0xffffffffu, v, delta);
        v = fmaxf(v, from_higher_lane);
    }
    return v;
}
|
|
// Block-wide float sum. Every thread of the block must call this, because it
// contains __syncthreads(). Every caller gets back the same block total.
// Preconditions: blockDim.x is a multiple of 32 (att_warp_sum shuffles with a
// full mask) and at most 1024 (one shared slot per warp, 32 slots).
__device__ __forceinline__ float att_block_sum(float v) {
    __shared__ float warp_partials[32];
    __shared__ float result;
    const int warp_id = threadIdx.x >> 5;
    const int lane_id = threadIdx.x & 31;
    const int warp_count = (blockDim.x + 31) >> 5;

    // Stage 1: each warp reduces its own lanes; lane 0 publishes the partial.
    const float partial = att_warp_sum(v);
    if (lane_id == 0) {
        warp_partials[warp_id] = partial;
    }
    __syncthreads();

    // Stage 2: warp 0 folds the per-warp partials (unused slots count as 0).
    float total = 0.0f;
    if (threadIdx.x < warp_count) {
        total = warp_partials[threadIdx.x];
    }
    if (warp_id == 0) {
        total = att_warp_sum(total);
    }

    // Broadcast thread 0's total to the whole block.
    if (threadIdx.x == 0) {
        result = total;
    }
    __syncthreads();
    return result;
}
|
|
// Block-wide float maximum. Every thread of the block must call this, because
// it contains __syncthreads(). Every caller gets back the same block maximum.
// Preconditions: blockDim.x is a multiple of 32 (att_warp_max shuffles with a
// full mask) and at most 1024 (one shared slot per warp, 32 slots).
__device__ __forceinline__ float att_block_max(float v) {
    __shared__ float warp_partials[32];
    __shared__ float result;
    const int warp_id = threadIdx.x >> 5;
    const int lane_id = threadIdx.x & 31;
    const int warp_count = (blockDim.x + 31) >> 5;

    // Stage 1: each warp reduces its own lanes; lane 0 publishes the partial.
    const float partial = att_warp_max(v);
    if (lane_id == 0) {
        warp_partials[warp_id] = partial;
    }
    __syncthreads();

    // Stage 2: warp 0 folds the per-warp partials (unused slots are -inf, the
    // identity for max).
    float best = -INFINITY;
    if (threadIdx.x < warp_count) {
        best = warp_partials[threadIdx.x];
    }
    if (warp_id == 0) {
        best = att_warp_max(best);
    }

    // Broadcast thread 0's maximum to the whole block.
    if (threadIdx.x == 0) {
        result = best;
    }
    __syncthreads();
    return result;
}
|
|
|
|
// Causal, scaled row softmax: block b handles flattened score row b, which is
// `seq` floats long starting at x + b*seq. The query position of that row inside
// its sequence is b % seq, so only columns 0..qpos are visible.
// - Visible columns get softmax(scale * x) over the visible prefix.
// - Future columns (j > qpos) are written as exactly 0.
// Launch with gridDim.x == number of rows. blockDim.x must be a multiple of 32
// and at most 1024, as required by the block reductions.
__global__ void softmax_causal_k(const float* x, float* y, int seq, float scale) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int stride = blockDim.x;
    const size_t base = (size_t)row * seq;
    const float* in = x + base;
    float* out = y + base;
    const int ncols = row % seq + 1;  // visible prefix length: columns [0, qpos]

    // Pass 1: maximum of the scaled visible logits (for numerical stability).
    float row_max = -INFINITY;
    for (int j = tid; j < ncols; j += stride) {
        row_max = fmaxf(row_max, in[j] * scale);
    }
    row_max = att_block_max(row_max);

    // Pass 2: unnormalised exponentials, stashed in y, plus their block-wide sum.
    float denom = 0.0f;
    for (int j = tid; j < ncols; j += stride) {
        const float e = expf(in[j] * scale - row_max);
        out[j] = e;
        denom += e;
    }
    denom = att_block_sum(denom);

    // Pass 3: normalise the visible entries and zero the masked future ones.
    const float rcp = 1.0f / denom;
    for (int j = tid; j < seq; j += stride) {
        if (j < ncols) {
            out[j] = out[j] * rcp;
        } else {
            out[j] = 0.0f;
        }
    }
}
|
|
// Host launcher for softmax_causal_k: one block per score row.
//   x, y  : device pointers to `rows` contiguous rows of `seq` floats each.
//   rows  : number of score rows (per the file header, B*nh*S).
//   seq   : row length (sequence length S).
//   scale : factor applied to the logits before the softmax (1/sqrt(head_dim)).
//   s     : cudaStream_t passed as an opaque pointer; the launch is
//           asynchronous on that stream.
void launch_softmax_causal_f32(const float* x, float* y, int rows, int seq,
                               float scale, void* s) {
    // An empty score tensor means nothing to do. Returning early also avoids a
    // zero-sized grid (invalid launch configuration) and the kernel's
    // `row % seq` with seq == 0.
    if (rows <= 0 || seq <= 0) return;

    // The block reductions shuffle with a full 0xffffffff mask, so every warp
    // must be complete. Round the thread count up to a multiple of 32 (this also
    // gives at least 32 threads). The 1024-thread cap is itself a multiple of 32.
    // Surplus threads simply fall out of the kernel's column loops.
    int blk = seq < 1024 ? seq : 1024;
    blk = ((blk + 31) / 32) * 32;

    softmax_causal_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, seq, scale);
}
|
|
|
|
} // extern "C"
|