autograd: batch dim for ops (flatten linears, batched attention)

Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE
already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only
attention + RoPE need sequence awareness:

- RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e.
  the per-sequence position on a flattened batch (passing period == tokens
  reproduces the previous single-sequence behaviour).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward`
  + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh
  (sequence,head) blocks (new sgemm_strided_batched binding) and a causal
  softmax kernel (scale + per-row causal mask inline) — the whole attention
  forward is 3 launches regardless of B*nh, with no per-head/per-seq loop and
  no host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.

grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all
pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside
the existing 12.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 00:44:15 +08:00
parent d2a585c5cb
commit 7821bd9c34
9 changed files with 629 additions and 21 deletions

View File

@@ -120,15 +120,17 @@ pub fn swiglu(gate: &Var, up: &Var) -> Var {
mul(&silu(gate), up)
}
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]`. Orthogonal map, so the
/// backward is the inverse rotation of `dy` — no cached forward values needed.
pub fn rope(x: &Var, theta: f32) -> Var {
let out = x.value().rope(theta);
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]` with per-sequence position
/// `row % period` (`period` = sequence length; `period == tokens` for a single
/// sequence). Orthogonal map, so the backward is the inverse rotation of `dy` — no
/// cached forward values needed.
pub fn rope(x: &Var, theta: f32, period: usize) -> Var {
let out = x.value().rope(theta, period);
Var::from_op(
out,
vec![x.clone()],
Box::new(move |dy, parents| {
Var::push_grad(&parents[0], Tensor::rope_backward(dy, theta));
Var::push_grad(&parents[0], Tensor::rope_backward(dy, theta, period));
}),
)
}
@@ -190,6 +192,20 @@ pub fn transpose_3d01(x: &Var) -> Var {
)
}
/// Autograd node for the 4D axis-(1,2) transpose `[a,b,c,d] -> [a,c,b,d]`.
///
/// Swapping axes 1 and 2 twice restores the original layout, so the backward
/// pass simply applies the same transpose to the incoming gradient. Used to
/// move between the `[B,S,nh,hd]` and `[B,nh,S,hd]` layouts of batched
/// multi-head attention.
pub fn transpose_4d12(x: &Var) -> Var {
    let swapped = x.value().transpose_4d12();
    Var::from_op(
        swapped,
        vec![x.clone()],
        Box::new(|grad_out, inputs| {
            // The gradient arrives in the transposed layout; undo the swap.
            let grad_in = grad_out.transpose_4d12();
            Var::push_grad(&inputs[0], grad_in);
        }),
    )
}
/// 2D transpose `[r,c] -> [c,r]` as an autograd node (backward transposes the
/// grad back). Used for `Kᵀ` in attention scores.
pub fn transpose_2d(x: &Var) -> Var {
@@ -266,6 +282,29 @@ pub fn merge_heads(heads_v: &[Var]) -> Var {
)
}
/// Fused batched causal scaled-dot-product attention as a single autograd node.
///
/// `q`, `k` and `v` are each `[bh, seq, head_dim]` (bh = batch·n_heads) and the
/// result has the same shape. The forward pass is 2 batched GEMMs plus one
/// causal-softmax kernel; the backward pass is 4 batched GEMMs plus the
/// row-wise softmax backward followed by a multiplication by `scale`. Nothing
/// loops over (batch, head) pairs, so the launch count does not grow with `bh`.
///
/// The softmax probabilities produced by the forward pass are moved into the
/// backward closure and reused there.
pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
    let (context, probs) = q.value().attention(&k.value(), &v.value(), scale);
    let inputs = vec![q.clone(), k.clone(), v.clone()];
    Var::from_op(
        context,
        inputs,
        Box::new(move |dout, parents| {
            let q_val = parents[0].value();
            let k_val = parents[1].value();
            let v_val = parents[2].value();
            let (dq, dk, dv) =
                Tensor::attention_backward(&q_val, &k_val, &v_val, &probs, dout, scale);
            // Parents were registered in q, k, v order; hand each its gradient.
            for (parent, grad) in parents.iter().zip([dq, dk, dv]) {
                Var::push_grad(parent, grad);
            }
        }),
    )
}
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// scaled by the upstream scalar grad.

View File

@@ -327,12 +327,12 @@ fn rope_bwd() {
let w = fill(n, 82);
let x = Var::leaf(cuda(&x_h, &[tokens, heads, head_dim]));
let out = ops::rope(&x, theta);
let out = ops::rope(&x, theta, tokens);
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let wf = w.clone();
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta), &wf);
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta, tokens), &wf);
report(
"rope dX",
&grad_check(
@@ -345,6 +345,38 @@ fn rope_bwd() {
);
}
// ---- rope batched (per-sequence position = row % period) ----
// The token axis holds B sequences of length S back to back and `period` is S,
// so the second and third sequences must restart at position 0. The analytic
// gradient from the autograd node is compared against finite differences of the
// raw tensor op called with the same period.
#[test]
fn rope_batched_bwd() {
    require_gpu();
    let (batch, seq_len, heads, head_dim) = (3, 4, 2, 8);
    let tokens = batch * seq_len;
    let shape = [tokens, heads, head_dim];
    let numel = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_host = fill(numel, 83);
    let weights = fill(numel, 84);

    // Analytic gradient through the autograd op.
    let x = Var::leaf(cuda(&x_host, &shape));
    let rotated = ops::rope(&x, theta, seq_len);
    scalar_loss(&rotated, &weights).backward();
    let analytic = x.grad().unwrap().to_device(Device::Cpu);

    // Scalar loss used by the finite-difference reference.
    let loss_weights = weights.clone();
    let loss = move |v: &[f32], sh: &[usize]| {
        weighted_sum(&cuda(v, sh).rope(theta, seq_len), &loss_weights)
    };

    let result = grad_check(
        &x_host,
        &shape,
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("rope batched dX", &result);
}
// ---- softmax ----
#[test]
fn softmax_bwd() {
@@ -501,6 +533,98 @@ fn attention_composed_bwd() {
);
}
// ---- transpose_4d12 ([a,b,c,d] -> [a,c,b,d]) ----
// All four extents differ so that a mixed-up axis in the backward pass shows up
// as a gradient mismatch.
#[test]
fn transpose_4d12_bwd() {
    require_gpu();
    let shape = [2, 3, 4, 5];
    let numel = shape[0] * shape[1] * shape[2] * shape[3];
    let x_host = fill(numel, 131);
    let weights = fill(numel, 132);

    // Analytic gradient through the autograd op.
    let x = Var::leaf(cuda(&x_host, &shape));
    let swapped = ops::transpose_4d12(&x);
    scalar_loss(&swapped, &weights).backward();
    let analytic = x.grad().unwrap().to_device(Device::Cpu);

    // Finite-difference reference on the raw tensor op.
    let loss_weights = weights.clone();
    let loss = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).transpose_4d12(), &loss_weights)
    };

    let result = grad_check(
        &x_host,
        &shape,
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("transpose_4d12 dX", &result);
}
// ---- fused batched causal attention (the T10 op) ----
// q, k, v are [bh, seq, hd]. Each of dq/dk/dv from the autograd node is checked
// against finite differences of L = sum(W∘out), perturbing one input while the
// other two stay fixed. bh = 2 makes the batched GEMM step over a non-zero
// stride; the causal mask is applied inside the op itself.
#[test]
fn attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let shape = [bh, seq, hd];
    let numel = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_host = fill(numel, 141);
    let k_host = fill(numel, 142);
    let v_host = fill(numel, 143);
    let w = fill(numel, 144);

    // Analytic gradients through the autograd op.
    let q = Var::leaf(cuda(&q_host, &shape));
    let k = Var::leaf(cuda(&k_host, &shape));
    let v = Var::leaf(cuda(&v_host, &shape));
    let out = ops::attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);

    // Scalar loss of the raw forward op, shared by the three checks below.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let (o, _) = cuda(qh, &shape).attention(&cuda(kh, &shape), &cuda(vh, &shape), scale);
        weighted_sum(&o, &w)
    };

    // dQ: perturb q, hold k and v.
    let loss_q = {
        let (k_fix, v_fix, f) = (k_host.clone(), v_host.clone(), fwd.clone());
        move |x: &[f32], _s: &[usize]| f(x, &k_fix, &v_fix)
    };
    let check_q = grad_check(
        &q_host,
        &shape,
        &loss_q,
        dq.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn(batched) dQ", &check_q);

    // dK: perturb k, hold q and v.
    let loss_k = {
        let (q_fix, v_fix, f) = (q_host.clone(), v_host.clone(), fwd.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fix, x, &v_fix)
    };
    let check_k = grad_check(
        &k_host,
        &shape,
        &loss_k,
        dk.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn(batched) dK", &check_k);

    // dV: perturb v, hold q and k. The output is linear in V, hence the linear
    // tolerance configuration. This is the last use of `fwd`, so it is moved.
    let loss_v = {
        let (q_fix, k_fix, f) = (q_host.clone(), k_host.clone(), fwd);
        move |x: &[f32], _s: &[usize]| f(&q_fix, &k_fix, x)
    };
    let check_v = grad_check(
        &v_host,
        &shape,
        &loss_v,
        dv.as_slice::<f32>(),
        cfg_linear(),
    );
    report("attn(batched) dV", &check_v);
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We

View File

@@ -35,6 +35,7 @@ fn main() {
.file("../../csrc/ops/nn.cu")
.file("../../csrc/ops/model.cu")
.file("../../csrc/ops/optim.cu")
.file("../../csrc/ops/attention.cu")
.compile("xtrain_cuda_kernels");
}

View File

@@ -93,3 +93,69 @@ pub fn sgemm(
assert_eq!(status, 0, "cublasSgemm failed: {status}");
});
}
/// Strided-batched row-major SGEMM: for each `i` in `0..batch`,
/// `C_i[m,n] = alpha·opA(A_i)·opB(B_i) + beta·C_i`, where `A_i`/`B_i`/`C_i` are
/// consecutive matrices laid `stride_*` elements apart in one contiguous buffer.
///
/// cuBLAS is column-major, so this uses the same trick as [`sgemm`]: a row-major
/// `C[m,n]` is the column-major `Cᵀ[n,m] = opB(B)ᵀ·opA(A)ᵀ`, which is why `b` is
/// passed before `a` and `n` before `m` in the underlying call. Used for the
/// batched attention `QKᵀ` / `PV` GEMMs (and their backwards), so each product
/// is a single launch instead of one GEMM per (batch, head) pair.
///
/// `a`/`b`/`c` are device pointers to the first matrix of each batch; strides
/// are in ELEMENTS, not bytes. The caller is responsible for the pointers being
/// valid for `batch` matrices at those strides.
///
/// Returns without calling cuBLAS when the output is empty (`batch`, `m` or `n`
/// is zero): there is nothing to write, and zero leading dimensions would
/// otherwise be handed to cuBLAS.
///
/// # Panics
/// Panics if a dimension, leading dimension or `batch` does not fit in `i32`,
/// if a stride does not fit in `i64`, or if cuBLAS reports a non-zero status.
#[allow(clippy::too_many_arguments)] // mirrors the cuBLAS argument list
pub fn sgemm_strided_batched(
    trans_a: bool,
    trans_b: bool,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: *const f32,
    stride_a: usize,
    b: *const f32,
    stride_b: usize,
    beta: f32,
    c: *mut f32,
    stride_c: usize,
    batch: usize,
) {
    if batch == 0 || m == 0 || n == 0 {
        return;
    }

    // cuBLAS takes 32-bit dims and 64-bit strides; a plain `as` cast would
    // silently wrap and run a GEMM with the wrong geometry.
    let to_i32 = |value: usize, what: &str| -> i32 {
        i32::try_from(value).unwrap_or_else(|_| {
            panic!("sgemm_strided_batched: {what} = {value} does not fit in i32")
        })
    };
    let to_i64 = |value: usize, what: &str| -> i64 {
        i64::try_from(value).unwrap_or_else(|_| {
            panic!("sgemm_strided_batched: {what} = {value} does not fit in i64")
        })
    };

    // Row-major leading dimensions: the row length of each matrix as stored.
    let lda = to_i32(if trans_a { m } else { k }, "lda");
    let ldb = to_i32(if trans_b { k } else { n }, "ldb");
    let ldc = to_i32(n, "ldc");
    let (m_i, n_i, k_i) = (to_i32(m, "m"), to_i32(n, "n"), to_i32(k, "k"));
    let (sa, sb, sc) = (
        to_i64(stride_a, "stride_a"),
        to_i64(stride_b, "stride_b"),
        to_i64(stride_c, "stride_c"),
    );
    let batch_i = to_i32(batch, "batch");

    let op_a = if trans_a {
        ffi::CUBLAS_OP_T
    } else {
        ffi::CUBLAS_OP_N
    };
    let op_b = if trans_b {
        ffi::CUBLAS_OP_T
    } else {
        ffi::CUBLAS_OP_N
    };

    with_handle(|handle| {
        // SAFETY: relies on the caller's contract that `a`, `b` and `c` are
        // device pointers valid for `batch` matrices at the given strides;
        // this wrapper cannot verify that.
        let status = unsafe {
            ffi::cublasSgemmStridedBatched(
                handle, op_b, op_a, n_i, m_i, k_i, &alpha, b, ldb, sb, a, lda, sa, &beta, c,
                ldc, sc, batch_i,
            )
        };
        assert_eq!(status, 0, "cublasSgemmStridedBatched failed: {status}");
    });
}

View File

@@ -125,7 +125,9 @@ unsafe extern "C" {
pub fn launch_silu_f32(x: *const f32, y: *mut f32, n: i32, s: CudaStream);
pub fn launch_silu_dx_f32(x: *const f32, dy: *const f32, dx: *mut f32, n: i32, s: CudaStream);
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = token index.
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = (token index %
// period). `period` = sequence length, so a flattened batch of sequences gets
// per-sequence positions; period == tokens reproduces the single-sequence case.
pub fn launch_rope_f32(
x: *const f32,
y: *mut f32,
@@ -133,6 +135,7 @@ unsafe extern "C" {
heads: i32,
head_dim: i32,
theta: f32,
period: i32,
s: CudaStream,
);
pub fn launch_rope_dx_f32(
@@ -142,6 +145,7 @@ unsafe extern "C" {
heads: i32,
head_dim: i32,
theta: f32,
period: i32,
s: CudaStream,
);
@@ -211,6 +215,31 @@ unsafe extern "C" {
c: i32,
s: CudaStream,
);
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
// The kernel reads one element of `input` and writes one element of `out` per
// index in 0..a*b*c*d, so both buffers must hold a*b*c*d f32 values. `s` is the
// CUDA stream to launch on; the Tensor wrapper in this crate passes null.
pub fn launch_transpose_4d12_f32(
input: *const f32,
out: *mut f32,
a: i32,
b: i32,
c: i32,
d: i32,
s: CudaStream,
);
}
// Batched attention helper (csrc/ops/attention.cu): causal row-wise softmax over
// score rows [rows, seq] with query position = (row % seq); scales logits by
// `scale` (= 1/sqrt(head_dim)) and masks future columns to probability 0.
// `x` (input scores) and `y` (output probabilities) each hold rows*seq f32
// values; the launcher starts one thread block per row on stream `s`.
// NOTE(review): the kernel computes `row % seq`, so `seq` must be non-zero —
// confirm every caller guarantees this.
#[cfg(not(no_cuda))]
unsafe extern "C" {
pub fn launch_softmax_causal_f32(
x: *const f32,
y: *mut f32,
rows: i32,
seq: i32,
scale: f32,
s: CudaStream,
);
}
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
@@ -267,6 +296,27 @@ unsafe extern "C" {
c: *mut f32,
ldc: i32,
) -> i32;
// cuBLAS strided-batched SGEMM. Operates on column-major matrices; the
// `stride_*` arguments are distances in elements between consecutive matrices
// of the batch. Returns the cuBLAS status code — the safe wrapper
// `cublas::sgemm_strided_batched` asserts that it is 0.
#[allow(clippy::too_many_arguments)]
pub fn cublasSgemmStridedBatched(
handle: CublasHandle,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const f32,
a: *const f32,
lda: i32,
stride_a: i64,
b: *const f32,
ldb: i32,
stride_b: i64,
beta: *const f32,
c: *mut f32,
ldc: i32,
stride_c: i64,
batch_count: i32,
) -> i32;
}
#[cfg(not(no_cuda))]

View File

@@ -454,13 +454,20 @@ impl Tensor {
dx
}
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; the position
/// of each token is its row index. Returns the rotated tensor.
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; each token's
/// position is `row % period`. `period` = sequence length, so a flattened
/// batch `[B*S,heads,head_dim]` gets per-sequence positions (pass `period=S`);
/// pass `period=tokens` for a single sequence (position = row). Returns the
/// rotated tensor.
#[cfg(not(no_cuda))]
pub fn rope(&self, theta: f32) -> Self {
pub fn rope(&self, theta: f32, period: usize) -> Self {
assert_eq!(self.ndim(), 3, "rope requires [tokens,heads,head_dim]");
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
assert_eq!(head_dim % 2, 0, "head_dim must be even");
assert!(
period > 0 && tokens % period == 0,
"tokens must be a multiple of period"
);
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_rope_f32(
@@ -470,6 +477,7 @@ impl Tensor {
heads as i32,
head_dim as i32,
theta,
period as i32,
std::ptr::null_mut(),
);
}
@@ -477,9 +485,9 @@ impl Tensor {
}
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
/// orthogonal map, so it needs no cached forward values, only `theta`.
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
#[cfg(not(no_cuda))]
pub fn rope_backward(dy: &Tensor, theta: f32) -> Self {
pub fn rope_backward(dy: &Tensor, theta: f32, period: usize) -> Self {
let (tokens, heads, head_dim) = (dy.shape[0], dy.shape[1], dy.shape[2]);
let dx = Tensor::zeros(&dy.shape, DType::F32, dy.device());
unsafe {
@@ -490,6 +498,7 @@ impl Tensor {
heads as i32,
head_dim as i32,
theta,
period as i32,
std::ptr::null_mut(),
);
}
@@ -667,6 +676,202 @@ impl Tensor {
out
}
// --- Batched attention (the T10 fused op) ---
/// Batched causal scaled-dot-product attention. `self`=Q, `k`, `v` are each
/// `[bh, seq, head_dim]` (bh = batch·n_heads), contiguous F32.
/// Computes, per batch element, `out = softmax(causal(Q·Kᵀ · scale)) · V`. The
/// two GEMMs run as `cublasSgemmStridedBatched` and the softmax+scale+causal
/// mask is one kernel, so the forward is 3 launches regardless of bh.
/// Returns `(out, probs)` where `probs`:[bh,seq,seq] is cached for backward.
///
/// # Panics
/// Panics if Q is not 3D, if K or V differ from Q in shape, if any input is not
/// contiguous F32, or if `bh*seq` or `seq` does not fit in `i32`. The raw data
/// pointers are handed straight to cuBLAS with strides derived from the shape,
/// so these conditions are checked up front instead of producing silently
/// wrong results or out-of-bounds device reads.
#[cfg(not(no_cuda))]
pub fn attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> (Tensor, Tensor) {
    assert_eq!(self.ndim(), 3, "attention Q must be [bh,seq,head_dim]");
    assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
    assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
    for (name, t) in [("Q", self), ("K", k), ("V", v)] {
        assert_eq!(t.dtype, DType::F32, "attention {name} only supports F32");
        assert!(t.is_contiguous(), "attention {name} requires contiguous");
    }
    let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
    // The softmax launcher takes 32-bit counts; refuse to wrap silently.
    let rows_i32 = i32::try_from(bh * seq).expect("attention: bh*seq does not fit in i32");
    let seq_i32 = i32::try_from(seq).expect("attention: seq does not fit in i32");
    let dev = self.device();

    // scores[bh,seq,seq] = Q[bh,seq,hd] · Kᵀ[bh,hd,seq]
    let scores = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
    xtrain_cuda::cublas::sgemm_strided_batched(
        false,
        true,
        seq,
        seq,
        hd,
        1.0,
        self.data_ptr() as *const f32,
        seq * hd,
        k.data_ptr() as *const f32,
        seq * hd,
        0.0,
        scores.data_ptr() as *mut f32,
        seq * seq,
        bh,
    );

    // probs = softmax(causal(scores · scale)), one block per [bh·seq] row.
    let probs = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
    // SAFETY: `scores` and `probs` were both allocated above as [bh,seq,seq]
    // F32, i.e. exactly rows*seq elements each, which is what the kernel
    // reads and writes.
    unsafe {
        xtrain_cuda::ffi::launch_softmax_causal_f32(
            scores.data_ptr() as *const f32,
            probs.data_ptr() as *mut f32,
            rows_i32,
            seq_i32,
            scale,
            std::ptr::null_mut(),
        );
    }

    // out[bh,seq,hd] = probs[bh,seq,seq] · V[bh,seq,hd]
    let out = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
    xtrain_cuda::cublas::sgemm_strided_batched(
        false,
        false,
        seq,
        hd,
        seq,
        1.0,
        probs.data_ptr() as *const f32,
        seq * seq,
        v.data_ptr() as *const f32,
        seq * hd,
        0.0,
        out.data_ptr() as *mut f32,
        seq * hd,
        bh,
    );
    (out, probs)
}
/// Backward of [`attention`](Self::attention). Inputs: forward `q`,`k`,`v`,
/// the cached `probs`, the upstream `dout` (all batched `[bh,seq,*]`), and the
/// same `scale`. Returns `(dq, dk, dv)`, each `[bh,seq,head_dim]`.
///
///   dP      = dOut · Vᵀ          ;  dV = Pᵀ · dOut
///   dScores = softmax_jacobian(P, dP) · scale   (scale folded back in)
///   dQ      = dScores · K        ;  dK = dScoresᵀ · Q
///
/// Masked (future) entries of P are 0, so the softmax Jacobian zeros their
/// gradient — the causal mask needs no special handling here.
///
/// # Panics
/// Panics if `q` is not 3D, if `k`, `v` or `dout` differ from `q` in shape, if
/// `probs` is not `[bh,seq,seq]`, or if any input is not contiguous F32. Every
/// tensor is passed to cuBLAS as a raw pointer with strides derived from `q`'s
/// shape, so a mismatch would otherwise read or write out of bounds on the
/// device.
#[cfg(not(no_cuda))]
pub fn attention_backward(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    probs: &Tensor,
    dout: &Tensor,
    scale: f32,
) -> (Tensor, Tensor, Tensor) {
    assert_eq!(q.ndim(), 3, "attention_backward q must be [bh,seq,head_dim]");
    let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
    assert_eq!(q.shape(), k.shape(), "attention_backward Q/K shape mismatch");
    assert_eq!(q.shape(), v.shape(), "attention_backward Q/V shape mismatch");
    assert_eq!(
        q.shape(),
        dout.shape(),
        "attention_backward Q/dOut shape mismatch"
    );
    assert!(
        probs.ndim() == 3
            && probs.shape[0] == bh
            && probs.shape[1] == seq
            && probs.shape[2] == seq,
        "attention_backward probs must be [bh,seq,seq]"
    );
    for (name, t) in [("Q", q), ("K", k), ("V", v), ("probs", probs), ("dOut", dout)] {
        assert_eq!(
            t.dtype,
            DType::F32,
            "attention_backward {name} only supports F32"
        );
        assert!(
            t.is_contiguous(),
            "attention_backward {name} requires contiguous"
        );
    }
    let dev = q.device();

    // dP[bh,seq,seq] = dOut[bh,seq,hd] · Vᵀ[bh,hd,seq]
    let dp = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
    xtrain_cuda::cublas::sgemm_strided_batched(
        false,
        true,
        seq,
        seq,
        hd,
        1.0,
        dout.data_ptr() as *const f32,
        seq * hd,
        v.data_ptr() as *const f32,
        seq * hd,
        0.0,
        dp.data_ptr() as *mut f32,
        seq * seq,
        bh,
    );

    // dV[bh,seq,hd] = Pᵀ[bh,seq,seq] · dOut[bh,seq,hd]
    let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
    xtrain_cuda::cublas::sgemm_strided_batched(
        true,
        false,
        seq,
        hd,
        seq,
        1.0,
        probs.data_ptr() as *const f32,
        seq * seq,
        dout.data_ptr() as *const f32,
        seq * hd,
        0.0,
        dv.data_ptr() as *mut f32,
        seq * hd,
        bh,
    );

    // dScores = softmax Jacobian (per row) applied to dP, then ×scale.
    // Reuse the row-wise softmax backward over the flattened [bh·seq, seq].
    let dscores = Tensor::softmax_backward(
        &probs.reshape(&[bh * seq, seq]),
        &dp.reshape(&[bh * seq, seq]),
    )
    .reshape(&[bh, seq, seq]);
    let dscores = dscores.scale(scale);

    // dQ[bh,seq,hd] = dScores[bh,seq,seq] · K[bh,seq,hd]
    let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
    xtrain_cuda::cublas::sgemm_strided_batched(
        false,
        false,
        seq,
        hd,
        seq,
        1.0,
        dscores.data_ptr() as *const f32,
        seq * seq,
        k.data_ptr() as *const f32,
        seq * hd,
        0.0,
        dq.data_ptr() as *mut f32,
        seq * hd,
        bh,
    );

    // dK[bh,seq,hd] = dScoresᵀ[bh,seq,seq] · Q[bh,seq,hd]
    let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
    xtrain_cuda::cublas::sgemm_strided_batched(
        true,
        false,
        seq,
        hd,
        seq,
        1.0,
        dscores.data_ptr() as *const f32,
        seq * seq,
        q.data_ptr() as *const f32,
        seq * hd,
        0.0,
        dk.data_ptr() as *mut f32,
        seq * hd,
        bh,
    );
    (dq, dk, dv)
}
/// Swap axes 1 and 2 of a 4D tensor: `self`:[a,b,c,d] → [a,c,b,d], with
/// `out[i,k,j,l] = self[i,j,k,l]`. This is the layout change between
/// `[B,S,nh,hd]` and `[B,nh,S,hd]` for batched multi-head attention; applying
/// it to the gradient is also its backward.
///
/// Requires a contiguous 4D F32 tensor and panics otherwise.
#[cfg(not(no_cuda))]
pub fn transpose_4d12(&self) -> Self {
    assert_eq!(self.dtype, DType::F32, "transpose_4d12 only supports F32");
    assert_eq!(self.ndim(), 4, "transpose_4d12 requires a 4D tensor");
    assert!(self.is_contiguous(), "transpose_4d12 requires contiguous");

    let [a, b, c, d] = [self.shape[0], self.shape[1], self.shape[2], self.shape[3]];
    let swapped = Tensor::zeros(&[a, c, b, d], DType::F32, self.device());
    let src = self.data_ptr() as *const f32;
    let dst = swapped.data_ptr() as *mut f32;
    unsafe {
        xtrain_cuda::ffi::launch_transpose_4d12_f32(
            src,
            dst,
            a as i32,
            b as i32,
            c as i32,
            d as i32,
            std::ptr::null_mut(),
        );
    }
    swapped
}
// Shared validation for same-shape binary elementwise ops.
#[cfg(not(no_cuda))]
fn check_binary(&self, other: &Tensor, op: &str) {

93
csrc/ops/attention.cu Normal file
View File

@@ -0,0 +1,93 @@
// Batched scaled-dot-product attention helpers (Phase T10).
//
// The QKᵀ and PV matmuls run as cublasSgemmStridedBatched in Rust; the only
// kernel attention needs of its own is a CAUSAL row-wise softmax over the score
// rows. Scores are [B*nh, S, S] flattened to rows of length S; for a flat row r
// the query position within its sequence is `r % S`, so columns j > r%S are
// future positions and get probability 0 (no additive -1e9 mask tensor needed).
//
// The forward also folds in the 1/sqrt(head_dim) scale (applied to logits before
// the max/exp) so we don't need a separate scale pass. Backward is the ordinary
// softmax Jacobian (csrc/ops/nn.cu launch_softmax_dx_f32): masked entries have
// y=0, so their contribution vanishes — no causal-specific backward needed.
//
// All F32, row-major, contiguous. Reduction helpers mirror nn.cu (inlined so the
// file is self-contained, matching the csrc/ layout).
#include <math.h>
extern "C" {
// Warp-level sum via a shuffle-down tree. After the last step lane 0 holds the
// sum over all 32 lanes; the other lanes hold partial results only.
__device__ __forceinline__ float att_warp_sum(float v) {
    const unsigned full_mask = 0xffffffffu;
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        float other = __shfl_down_sync(full_mask, v, delta);
        v += other;
    }
    return v;
}
// Warp-level max via a shuffle-down tree. After the last step lane 0 holds the
// maximum over all 32 lanes; the other lanes hold partial results only.
__device__ __forceinline__ float att_warp_max(float v) {
    const unsigned full_mask = 0xffffffffu;
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        float other = __shfl_down_sync(full_mask, v, delta);
        v = fmaxf(v, other);
    }
    return v;
}
// Block-wide sum: every thread passes its partial `v` and every thread gets the
// block total back. Contains __syncthreads(), so ALL threads of the block must
// call it (no early return before the call).
// NOTE(review): att_warp_sum shuffles with a full-warp mask, so this presumably
// expects blockDim.x to be a multiple of 32 — confirm against the launcher.
__device__ __forceinline__ float att_block_sum(float v) {
// One slot per warp; 32 slots cover a block of up to 1024 threads.
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
// Stage 1: reduce within each warp; lane 0 publishes the warp's partial sum.
v = att_warp_sum(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
// Stage 2: the first `nwarps` threads pick up one partial each (others use the
// additive identity) and warp 0 reduces them.
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
if (warp == 0) v = att_warp_sum(v);
// Broadcast thread 0's result to the whole block through shared memory.
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
// Block-wide max: every thread passes its partial `v` and every thread gets the
// block maximum back. Contains __syncthreads(), so ALL threads of the block
// must call it (no early return before the call).
// NOTE(review): att_warp_max shuffles with a full-warp mask, so this presumably
// expects blockDim.x to be a multiple of 32 — confirm against the launcher.
__device__ __forceinline__ float att_block_max(float v) {
// One slot per warp; 32 slots cover a block of up to 1024 threads.
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
// Stage 1: reduce within each warp; lane 0 publishes the warp's partial max.
v = att_warp_max(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
// Stage 2: the first `nwarps` threads pick up one partial each (others use
// -INFINITY, the identity for max) and warp 0 reduces them.
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
if (warp == 0) v = att_warp_max(v);
// Broadcast thread 0's result to the whole block through shared memory.
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
// One block per score row. rows = B*nh*S total; the query position within its
// sequence is (blockIdx.x % seq). Logits are scaled by `scale` (= 1/sqrt(hd))
// before softmax; columns j > qpos are masked to probability 0.
//
// No thread returns early: threads with no unmasked column still take part in
// both block reductions (which contain barriers), contributing -INFINITY to the
// max and 0 to the sum.
__global__ void softmax_causal_k(const float* x, float* y, int seq, float scale) {
int r = blockIdx.x;
int qpos = r % seq;
const float* xr = x + (size_t)r * seq;
float* yr = y + (size_t)r * seq;
int valid = qpos + 1; // attend to columns [0, qpos]
// Pass 1: maximum of the scaled, unmasked logits (for a stable exp).
float m = -INFINITY;
for (int c = threadIdx.x; c < valid; c += blockDim.x)
m = fmaxf(m, xr[c] * scale);
m = att_block_max(m);
// Pass 2: store the unnormalised exponentials and accumulate their sum.
float sum = 0.0f;
for (int c = threadIdx.x; c < valid; c += blockDim.x) {
float e = expf(xr[c] * scale - m);
yr[c] = e;
sum += e;
}
sum = att_block_sum(sum);
// valid >= 1, and for finite inputs the max element contributes exp(0) = 1,
// so the sum is at least 1 and the reciprocal is finite.
float inv = 1.0f / sum;
// Pass 3: normalise. Each thread revisits exactly the columns it wrote in
// pass 2 (same start and stride), and writes an exact 0 to masked columns.
for (int c = threadIdx.x; c < seq; c += blockDim.x)
yr[c] = (c < valid) ? yr[c] * inv : 0.0f;
}
// Host launcher for softmax_causal_k: one block per score row, on stream `s`.
//
// The block size is `seq` rounded UP to a whole number of warps and capped at
// 1024. The block reductions shuffle with a full-warp mask, so a partially
// filled last warp (e.g. seq = 40 -> 40 threads) would read values from
// inactive lanes, which are undefined. The kernel's column loops are strided
// and bounds-checked, so the extra threads just contribute the reduction
// identity.
//
// Empty problems return without launching: rows == 0 would be an invalid
// launch configuration, and seq == 0 would make the kernel compute `r % 0`.
void launch_softmax_causal_f32(const float* x, float* y, int rows, int seq,
                               float scale, void* s) {
    if (rows <= 0 || seq <= 0) return;
    int blk = (seq >= 1024) ? 1024 : ((seq + 31) / 32) * 32;
    softmax_causal_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, seq, scale);
}
} // extern "C"

View File

@@ -63,4 +63,26 @@ void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c,
transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
}
// =====================================================================
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
// Lays out batched multi-head attention: [B,S,nh,hd] <-> [B,nh,S,hd], so a
// flattened [B*nh, S, hd] view feeds the strided-batched-GEMM attention. Its own
// backward is the same op (swap b,c), so one kernel suffices.
// =====================================================================
// One thread per input element. The flat input index is peeled into its
// (i, j, k, l) coordinates from the fastest axis outward, then the element is
// written to out[i, k, j, l].
__global__ void transpose_4d12_k(const float* in, float* out, int a, int b, int c, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;  // over a*b*c*d
    if (idx >= a * b * c * d) return;
    int rest = idx;
    int l = rest % d;
    rest /= d;
    int k = rest % c;
    rest /= c;
    int j = rest % b;
    int i = rest / b;
    // Destination offset of out[i,k,j,l] in an [a,c,b,d] row-major buffer.
    int dst = ((i * c + k) * b + j) * d + l;
    out[dst] = in[idx];
}
// Host launcher for transpose_4d12_k: one thread per element, 256 threads per
// block, on stream `s`. An empty tensor (any extent 0) returns without
// launching, because a grid of 0 blocks is an invalid launch configuration.
void launch_transpose_4d12_f32(const float* in, float* out, int a, int b, int c, int d, void* s) {
    int n = a * b * c * d;
    if (n <= 0) return;
    int blk = 256, grid = (n + blk - 1) / blk;
    transpose_4d12_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c, d);
}
} // extern "C"

View File

@@ -215,14 +215,20 @@ void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void*
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
// =====================================================================
__global__ void rope_k(const float* x, float* y, int heads, int head_dim, float theta) {
// `period` is the sequence length: a flattened batch lays B sequences end to end
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
// own sequence, `tok % period`. With period == tokens (single sequence) this is
// the original position = row.
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
float theta, int period) {
int tok = blockIdx.x;
int head = blockIdx.y;
int half = head_dim / 2;
int i = threadIdx.x;
if (i >= half) return;
int pos = tok % period;
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
float angle = (float)tok * freq;
float angle = (float)pos * freq;
float c = cosf(angle), sn = sinf(angle);
int base = (tok * heads + head) * head_dim;
float x0 = x[base + i], x1 = x[base + i + half];
@@ -230,20 +236,22 @@ __global__ void rope_k(const float* x, float* y, int heads, int head_dim, float
y[base + i + half] = x1 * c + x0 * sn;
}
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
int head_dim, float theta, void* s) {
int head_dim, float theta, int period, void* s) {
dim3 grid(tokens, heads);
int blk = head_dim / 2;
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta);
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, float theta) {
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
float theta, int period) {
int tok = blockIdx.x;
int head = blockIdx.y;
int half = head_dim / 2;
int i = threadIdx.x;
if (i >= half) return;
int pos = tok % period;
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
float angle = (float)tok * freq;
float angle = (float)pos * freq;
float c = cosf(angle), sn = sinf(angle);
int base = (tok * heads + head) * head_dim;
float d0 = dy[base + i], d1 = dy[base + i + half];
@@ -251,10 +259,10 @@ __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, f
dx[base + i + half] = d1 * c - d0 * sn;
}
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
int head_dim, float theta, void* s) {
int head_dim, float theta, int period, void* s) {
dim3 grid(tokens, heads);
int blk = head_dim / 2;
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta);
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
}
// =====================================================================