autograd: batch dim for ops (flatten linears, batched attention)
Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only attention + RoPE need sequence awareness: - RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e. per-sequence position on a flattened batch (period == tokens = single seq). - Fused batched causal attention: new `Tensor::attention`/`attention_backward` + ops node, running QKᵀ and PV as cublasSgemmStridedBatched over the B*nh (sequence,head) blocks (new sgemm_strided_batched binding) and a causal softmax kernel (scale + per-row causal mask inline) — the whole attention is 3 launches regardless of B*nh, no per-head/per-seq loop, no host round-trip. - transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads. grad-checks: new batched-rope, transpose_4d12, batched-attention dQ/dK/dV all pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4; rest tighter) alongside the existing 12. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -120,15 +120,17 @@ pub fn swiglu(gate: &Var, up: &Var) -> Var {
|
||||
mul(&silu(gate), up)
|
||||
}
|
||||
|
||||
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]`. Orthogonal map, so the
|
||||
/// backward is the inverse rotation of `dy` — no cached forward values needed.
|
||||
pub fn rope(x: &Var, theta: f32) -> Var {
|
||||
let out = x.value().rope(theta);
|
||||
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]` with per-sequence position
|
||||
/// `row % period` (`period` = sequence length; `period == tokens` for a single
|
||||
/// sequence). Orthogonal map, so the backward is the inverse rotation of `dy` — no
|
||||
/// cached forward values needed.
|
||||
pub fn rope(x: &Var, theta: f32, period: usize) -> Var {
|
||||
let out = x.value().rope(theta, period);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |dy, parents| {
|
||||
Var::push_grad(&parents[0], Tensor::rope_backward(dy, theta));
|
||||
Var::push_grad(&parents[0], Tensor::rope_backward(dy, theta, period));
|
||||
}),
|
||||
)
|
||||
}
|
||||
@@ -190,6 +192,20 @@ pub fn transpose_3d01(x: &Var) -> Var {
|
||||
)
|
||||
}
|
||||
|
||||
/// 4D axis-(1,2) transpose `[a,b,c,d] -> [a,c,b,d]`. Self-inverse structure: the
|
||||
/// backward is the same transpose applied to the grad. Lays out the batched
|
||||
/// multi-head attention `[B,S,nh,hd] <-> [B,nh,S,hd]`.
|
||||
pub fn transpose_4d12(x: &Var) -> Var {
|
||||
let out = x.value().transpose_4d12();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(|d, parents| {
|
||||
Var::push_grad(&parents[0], d.transpose_4d12());
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// 2D transpose `[r,c] -> [c,r]` as an autograd node (backward transposes the
|
||||
/// grad back). Used for `Kᵀ` in attention scores.
|
||||
pub fn transpose_2d(x: &Var) -> Var {
|
||||
@@ -266,6 +282,29 @@ pub fn merge_heads(heads_v: &[Var]) -> Var {
|
||||
)
|
||||
}
|
||||
|
||||
/// Batched causal scaled-dot-product attention. `q`,`k`,`v` are each
|
||||
/// `[bh, seq, head_dim]` (bh = batch·n_heads). Returns `[bh, seq, head_dim]`.
|
||||
/// One fused op (2 batched GEMMs + 1 causal-softmax kernel forward; 4 batched
|
||||
/// GEMMs + 1 softmax-backward kernel in backward) — replaces the per-(batch,head)
|
||||
/// matmul/softmax loop, so attention is a handful of launches regardless of bh.
|
||||
/// Caches the softmax `probs` for backward.
|
||||
pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
let (out, probs) = q.value().attention(&k.value(), &v.value(), scale);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![q.clone(), k.clone(), v.clone()],
|
||||
Box::new(move |dout, parents| {
|
||||
let q = parents[0].value();
|
||||
let k = parents[1].value();
|
||||
let v = parents[2].value();
|
||||
let (dq, dk, dv) = Tensor::attention_backward(&q, &k, &v, &probs, dout, scale);
|
||||
Var::push_grad(&parents[0], dq);
|
||||
Var::push_grad(&parents[1], dk);
|
||||
Var::push_grad(&parents[2], dv);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
|
||||
@@ -327,12 +327,12 @@ fn rope_bwd() {
|
||||
let w = fill(n, 82);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[tokens, heads, head_dim]));
|
||||
let out = ops::rope(&x, theta);
|
||||
let out = ops::rope(&x, theta, tokens);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta), &wf);
|
||||
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta, tokens), &wf);
|
||||
report(
|
||||
"rope dX",
|
||||
&grad_check(
|
||||
@@ -345,6 +345,38 @@ fn rope_bwd() {
|
||||
);
|
||||
}
|
||||
|
||||
// ---- rope batched: the position restarts every `period` rows ----
// The input is B sequences of length S stacked along the token axis and the op
// is called with period = S, so rows S.. must reuse positions 0..S (the
// kernel's `tok % period`) instead of continuing to count upwards.
#[test]
fn rope_batched_bwd() {
    require_gpu();
    let (batch, seq_len, heads, head_dim) = (3, 4, 2, 8);
    let tokens = batch * seq_len;
    let numel = tokens * heads * head_dim;
    let shape = [tokens, heads, head_dim];
    let theta = 10000.0;
    let x_host = fill(numel, 83);
    let weights = fill(numel, 84);

    // Analytic gradient from the autograd node.
    let x = Var::leaf(cuda(&x_host, &shape));
    let rotated = ops::rope(&x, theta, seq_len);
    scalar_loss(&rotated, &weights).backward();
    let analytic = x.grad().unwrap().to_device(Device::Cpu);

    // Finite-difference reference built on the raw tensor op with the same weights.
    let w_ref = weights.clone();
    let loss = move |v: &[f32], dims: &[usize]| {
        weighted_sum(&cuda(v, dims).rope(theta, seq_len), &w_ref)
    };
    let outcome = grad_check(
        &x_host,
        &shape,
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("rope batched dX", &outcome);
}
|
||||
|
||||
// ---- softmax ----
|
||||
#[test]
|
||||
fn softmax_bwd() {
|
||||
@@ -501,6 +533,98 @@ fn attention_composed_bwd() {
|
||||
);
|
||||
}
|
||||
|
||||
// ---- transpose_4d12: [a,b,c,d] -> [a,c,b,d] ----
// Four distinct extents so that any mix-up between the swapped axes shows up.
#[test]
fn transpose_4d12_bwd() {
    require_gpu();
    let (a, b, c, d) = (2, 3, 4, 5);
    let shape = [a, b, c, d];
    let numel = a * b * c * d;
    let x_host = fill(numel, 131);
    let weights = fill(numel, 132);

    // Analytic gradient from the autograd node.
    let x = Var::leaf(cuda(&x_host, &shape));
    let swapped = ops::transpose_4d12(&x);
    scalar_loss(&swapped, &weights).backward();
    let analytic = x.grad().unwrap().to_device(Device::Cpu);

    // Finite-difference reference on the raw tensor op.
    let w_ref = weights.clone();
    let loss =
        move |v: &[f32], dims: &[usize]| weighted_sum(&cuda(v, dims).transpose_4d12(), &w_ref);
    let outcome = grad_check(
        &x_host,
        &shape,
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("transpose_4d12 dX", &outcome);
}
|
||||
|
||||
// ---- fused batched causal attention (the T10 op) ----
// q, k, v are [bh, seq, hd]. Each of dq/dk/dv is compared against a
// finite-difference gradient of L = sum(W∘out) taken through the raw tensor op.
// bh = 2 (one batch × two heads, or two sequences × one head) makes the batched
// GEMM stride matter; the causal mask lives inside the op itself.
#[test]
fn attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let shape = [bh, seq, hd];
    let numel = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_host = fill(numel, 141);
    let k_host = fill(numel, 142);
    let v_host = fill(numel, 143);
    let weights = fill(numel, 144);

    // Analytic gradients from one autograd backward pass.
    let q = Var::leaf(cuda(&q_host, &shape));
    let k = Var::leaf(cuda(&k_host, &shape));
    let v = Var::leaf(cuda(&v_host, &shape));
    let attended = ops::attention(&q, &k, &v, scale);
    scalar_loss(&attended, &weights).backward();

    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);

    // Scalar loss as a function of host-side q/k/v; cloned once per input below
    // so each finite-difference closure owns its fixed operands.
    let loss_of = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let q_dev = cuda(qh, &shape);
        let k_dev = cuda(kh, &shape);
        let v_dev = cuda(vh, &shape);
        let (out, _probs) = q_dev.attention(&k_dev, &v_dev, scale);
        weighted_sum(&out, &weights)
    };

    // dQ: vary q, hold k and v.
    let (k_fixed, v_fixed, loss) = (k_host.clone(), v_host.clone(), loss_of.clone());
    let wrt_q = move |x: &[f32], _dims: &[usize]| loss(x, &k_fixed, &v_fixed);
    let q_outcome = grad_check(
        &q_host,
        &shape,
        &wrt_q,
        dq.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn(batched) dQ", &q_outcome);

    // dK: vary k, hold q and v.
    let (q_fixed, v_fixed, loss) = (q_host.clone(), v_host.clone(), loss_of.clone());
    let wrt_k = move |x: &[f32], _dims: &[usize]| loss(&q_fixed, x, &v_fixed);
    let k_outcome = grad_check(
        &k_host,
        &shape,
        &wrt_k,
        dk.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn(batched) dK", &k_outcome);

    // dV: vary v, hold q and k. The output is linear in V, hence the linear config.
    let (q_fixed, k_fixed, loss) = (q_host.clone(), k_host.clone(), loss_of.clone());
    let wrt_v = move |x: &[f32], _dims: &[usize]| loss(&q_fixed, &k_fixed, x);
    let v_outcome = grad_check(
        &v_host,
        &shape,
        &wrt_v,
        dv.as_slice::<f32>(),
        cfg_linear(),
    );
    report("attn(batched) dV", &v_outcome);
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||||
|
||||
@@ -35,6 +35,7 @@ fn main() {
|
||||
.file("../../csrc/ops/nn.cu")
|
||||
.file("../../csrc/ops/model.cu")
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -93,3 +93,69 @@ pub fn sgemm(
|
||||
assert_eq!(status, 0, "cublasSgemm failed: {status}");
|
||||
});
|
||||
}
|
||||
|
||||
/// Strided-batched row-major SGEMM: for each `i` in `0..batch`,
|
||||
/// `C_i[m,n] = alpha·opA(A_i)·opB(B_i) + beta·C_i`, where `A_i`/`B_i`/`C_i` are
|
||||
/// consecutive matrices laid `stride_*` elements apart in one contiguous buffer.
|
||||
/// Same row-major⟺col-major trick as [`sgemm`] (compute col-major `Cᵀ`), applied
|
||||
/// per batch element. Used for the batched attention `QKᵀ` / `PV` GEMMs (and their
|
||||
/// backwards), so the whole attention runs as 2 batched-GEMM launches, not a
|
||||
/// per-(batch,head) Python loop. `A`/`B`/`C` are device pointers to the first
|
||||
/// matrix; strides are in ELEMENTS.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn sgemm_strided_batched(
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
stride_a: usize,
|
||||
b: *const f32,
|
||||
stride_b: usize,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
stride_c: usize,
|
||||
batch: usize,
|
||||
) {
|
||||
let lda = if trans_a { m } else { k };
|
||||
let ldb = if trans_b { k } else { n };
|
||||
let ldc = n;
|
||||
let op_a = if trans_a {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
let op_b = if trans_b {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
|
||||
with_handle(|handle| {
|
||||
let status = unsafe {
|
||||
ffi::cublasSgemmStridedBatched(
|
||||
handle,
|
||||
op_b,
|
||||
op_a,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha,
|
||||
b,
|
||||
ldb as i32,
|
||||
stride_b as i64,
|
||||
a,
|
||||
lda as i32,
|
||||
stride_a as i64,
|
||||
&beta,
|
||||
c,
|
||||
ldc as i32,
|
||||
stride_c as i64,
|
||||
batch as i32,
|
||||
)
|
||||
};
|
||||
assert_eq!(status, 0, "cublasSgemmStridedBatched failed: {status}");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -125,7 +125,9 @@ unsafe extern "C" {
|
||||
pub fn launch_silu_f32(x: *const f32, y: *mut f32, n: i32, s: CudaStream);
|
||||
pub fn launch_silu_dx_f32(x: *const f32, dy: *const f32, dx: *mut f32, n: i32, s: CudaStream);
|
||||
|
||||
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = token index.
|
||||
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = (token index %
|
||||
// period). `period` = sequence length, so a flattened batch of sequences gets
|
||||
// per-sequence positions; period == tokens reproduces the single-sequence case.
|
||||
pub fn launch_rope_f32(
|
||||
x: *const f32,
|
||||
y: *mut f32,
|
||||
@@ -133,6 +135,7 @@ unsafe extern "C" {
|
||||
heads: i32,
|
||||
head_dim: i32,
|
||||
theta: f32,
|
||||
period: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_rope_dx_f32(
|
||||
@@ -142,6 +145,7 @@ unsafe extern "C" {
|
||||
heads: i32,
|
||||
head_dim: i32,
|
||||
theta: f32,
|
||||
period: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
|
||||
@@ -211,6 +215,31 @@ unsafe extern "C" {
|
||||
c: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
// `input` and `out` are device pointers to a*b*c*d f32 elements each. The
// kernel behind this launcher indexes with 32-bit ints, so a*b*c*d has to fit
// in i32. `s` is the CUDA stream to launch on; the Rust caller added in this
// change passes a null stream.
|
||||
pub fn launch_transpose_4d12_f32(
|
||||
input: *const f32,
|
||||
out: *mut f32,
|
||||
a: i32,
|
||||
b: i32,
|
||||
c: i32,
|
||||
d: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// Batched attention helper (csrc/ops/attention.cu): causal row-wise softmax over
|
||||
// score rows [rows, seq] with query position = (row % seq); scales logits by
|
||||
// `scale` (= 1/sqrt(head_dim)) and masks future columns to probability 0.
// The launcher starts one thread block per score row, so `rows` is the grid
// size; `x` (logits) and `y` (probabilities) each hold rows*seq f32 values.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
/// Launches the causal softmax kernel on stream `s`. `x` is read, `y` is
/// fully overwritten (masked columns are written as 0). No status is
/// returned, so a failed launch is not reported through this call.
pub fn launch_softmax_causal_f32(
|
||||
x: *const f32,
|
||||
y: *mut f32,
|
||||
rows: i32,
|
||||
seq: i32,
|
||||
scale: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
|
||||
@@ -267,6 +296,27 @@ unsafe extern "C" {
|
||||
c: *mut f32,
|
||||
ldc: i32,
|
||||
) -> i32;
|
||||
/// Raw binding for cuBLAS `cublasSgemmStridedBatched`: `batch_count` SGEMMs
/// whose operands sit `stride_*` ELEMENTS apart in memory. The return value is
/// the cuBLAS status code; the safe wrapper `sgemm_strided_batched` treats 0 as
/// success and panics otherwise. That wrapper also does the operand swap
/// needed to drive this column-major routine from row-major tensors, so
/// prefer it over calling the binding directly.
#[allow(clippy::too_many_arguments)]
|
||||
pub fn cublasSgemmStridedBatched(
|
||||
handle: CublasHandle,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const f32,
|
||||
a: *const f32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const f32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const f32,
|
||||
c: *mut f32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -454,13 +454,20 @@ impl Tensor {
|
||||
dx
|
||||
}
|
||||
|
||||
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; the position
|
||||
/// of each token is its row index. Returns the rotated tensor.
|
||||
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; each token's
|
||||
/// position is `row % period`. `period` = sequence length, so a flattened
|
||||
/// batch `[B*S,heads,head_dim]` gets per-sequence positions (pass `period=S`);
|
||||
/// pass `period=tokens` for a single sequence (position = row). Returns the
|
||||
/// rotated tensor.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope(&self, theta: f32) -> Self {
|
||||
pub fn rope(&self, theta: f32, period: usize) -> Self {
|
||||
assert_eq!(self.ndim(), 3, "rope requires [tokens,heads,head_dim]");
|
||||
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
assert_eq!(head_dim % 2, 0, "head_dim must be even");
|
||||
assert!(
|
||||
period > 0 && tokens % period == 0,
|
||||
"tokens must be a multiple of period"
|
||||
);
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rope_f32(
|
||||
@@ -470,6 +477,7 @@ impl Tensor {
|
||||
heads as i32,
|
||||
head_dim as i32,
|
||||
theta,
|
||||
period as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
@@ -477,9 +485,9 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`.
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope_backward(dy: &Tensor, theta: f32) -> Self {
|
||||
pub fn rope_backward(dy: &Tensor, theta: f32, period: usize) -> Self {
|
||||
let (tokens, heads, head_dim) = (dy.shape[0], dy.shape[1], dy.shape[2]);
|
||||
let dx = Tensor::zeros(&dy.shape, DType::F32, dy.device());
|
||||
unsafe {
|
||||
@@ -490,6 +498,7 @@ impl Tensor {
|
||||
heads as i32,
|
||||
head_dim as i32,
|
||||
theta,
|
||||
period as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
@@ -667,6 +676,202 @@ impl Tensor {
|
||||
out
|
||||
}
|
||||
|
||||
// --- Batched attention (the T10 fused op) ---
|
||||
|
||||
/// Batched causal scaled-dot-product attention. `self`=Q, `k`, `v` are each
|
||||
/// `[bh, seq, head_dim]` (bh = batch·n_heads), contiguous F32 on one GPU.
|
||||
/// Computes, per batch element, `out = softmax(causal(Q·Kᵀ / √hd)) · V`. The
|
||||
/// two GEMMs run as `cublasSgemmStridedBatched` and the softmax+scale+causal
|
||||
/// mask is one kernel, so the whole attention is 3 launches regardless of bh.
|
||||
/// Returns `(out, probs)` where `probs`:[bh,seq,seq] is cached for backward.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> (Tensor, Tensor) {
|
||||
assert_eq!(self.ndim(), 3, "attention Q must be [bh,seq,head_dim]");
|
||||
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
|
||||
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
|
||||
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let dev = self.device();
|
||||
|
||||
// scores[bh,seq,seq] = Q[bh,seq,hd] · Kᵀ[bh,hd,seq]
|
||||
let scores = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
false,
|
||||
true,
|
||||
seq,
|
||||
seq,
|
||||
hd,
|
||||
1.0,
|
||||
self.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
k.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
0.0,
|
||||
scores.data_ptr() as *mut f32,
|
||||
seq * seq,
|
||||
bh,
|
||||
);
|
||||
// probs = softmax(causal(scores · scale)), one block per [bh·seq] row.
|
||||
let probs = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_softmax_causal_f32(
|
||||
scores.data_ptr() as *const f32,
|
||||
probs.data_ptr() as *mut f32,
|
||||
(bh * seq) as i32,
|
||||
seq as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
// out[bh,seq,hd] = probs[bh,seq,seq] · V[bh,seq,hd]
|
||||
let out = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
false,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
probs.data_ptr() as *const f32,
|
||||
seq * seq,
|
||||
v.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
0.0,
|
||||
out.data_ptr() as *mut f32,
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
(out, probs)
|
||||
}
|
||||
|
||||
/// Backward of [`attention`](Self::attention). Inputs: forward `q`,`k`,`v`,
|
||||
/// the cached `probs`, the upstream `dout` (all batched `[bh,seq,*]`), and the
|
||||
/// same `scale`. Returns `(dq, dk, dv)`.
|
||||
///
|
||||
/// dP = dOut · Vᵀ ; dV = Pᵀ · dOut
|
||||
/// dScores = softmax_jacobian(P, dP) · scale (scale folded back in)
|
||||
/// dQ = dScores · K ; dK = dScoresᵀ · Q
|
||||
///
|
||||
/// Masked (future) entries of P are 0, so the softmax Jacobian zeros their
|
||||
/// gradient — the causal mask needs no special handling here.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn attention_backward(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
probs: &Tensor,
|
||||
dout: &Tensor,
|
||||
scale: f32,
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
|
||||
let dev = q.device();
|
||||
|
||||
// dP[bh,seq,seq] = dOut[bh,seq,hd] · Vᵀ[bh,hd,seq]
|
||||
let dp = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
false,
|
||||
true,
|
||||
seq,
|
||||
seq,
|
||||
hd,
|
||||
1.0,
|
||||
dout.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
v.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
0.0,
|
||||
dp.data_ptr() as *mut f32,
|
||||
seq * seq,
|
||||
bh,
|
||||
);
|
||||
// dV[bh,seq,hd] = Pᵀ[bh,seq,seq] · dOut[bh,seq,hd]
|
||||
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
true,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
probs.data_ptr() as *const f32,
|
||||
seq * seq,
|
||||
dout.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
0.0,
|
||||
dv.data_ptr() as *mut f32,
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
// dScores = softmax Jacobian (per row) applied to dP, then ×scale.
|
||||
// Reuse the row-wise softmax backward over the flattened [bh·seq, seq].
|
||||
let dscores = Tensor::softmax_backward(
|
||||
&probs.reshape(&[bh * seq, seq]),
|
||||
&dp.reshape(&[bh * seq, seq]),
|
||||
)
|
||||
.reshape(&[bh, seq, seq]);
|
||||
let dscores = dscores.scale(scale);
|
||||
// dQ[bh,seq,hd] = dScores[bh,seq,seq] · K[bh,seq,hd]
|
||||
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
false,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
dscores.data_ptr() as *const f32,
|
||||
seq * seq,
|
||||
k.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
0.0,
|
||||
dq.data_ptr() as *mut f32,
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
// dK[bh,seq,hd] = dScoresᵀ[bh,seq,seq] · Q[bh,seq,hd]
|
||||
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
true,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
dscores.data_ptr() as *const f32,
|
||||
seq * seq,
|
||||
q.data_ptr() as *const f32,
|
||||
seq * hd,
|
||||
0.0,
|
||||
dk.data_ptr() as *mut f32,
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
(dq, dk, dv)
|
||||
}
|
||||
|
||||
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
|
||||
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
|
||||
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn transpose_4d12(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "transpose_4d12 only supports F32");
|
||||
assert_eq!(self.ndim(), 4, "transpose_4d12 requires a 4D tensor");
|
||||
assert!(self.is_contiguous(), "transpose_4d12 requires contiguous");
|
||||
let (a, b, c, d) = (self.shape[0], self.shape[1], self.shape[2], self.shape[3]);
|
||||
let out = Tensor::zeros(&[a, c, b, d], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_transpose_4d12_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
a as i32,
|
||||
b as i32,
|
||||
c as i32,
|
||||
d as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
// Shared validation for same-shape binary elementwise ops.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn check_binary(&self, other: &Tensor, op: &str) {
|
||||
|
||||
93
csrc/ops/attention.cu
Normal file
93
csrc/ops/attention.cu
Normal file
@@ -0,0 +1,93 @@
|
||||
// Batched scaled-dot-product attention helpers (Phase T10).
|
||||
//
|
||||
// The QKᵀ and PV matmuls run as cublasSgemmStridedBatched in Rust; the only
|
||||
// kernel attention needs of its own is a CAUSAL row-wise softmax over the score
|
||||
// rows. Scores are [B*nh, S, S] flattened to rows of length S; for a flat row r
|
||||
// the query position within its sequence is `r % S`, so columns j > r%S are
|
||||
// future positions and get probability 0 (no additive -1e9 mask tensor needed).
|
||||
//
|
||||
// The forward also folds in the 1/sqrt(head_dim) scale (applied to logits before
|
||||
// the max/exp) so we don't need a separate scale pass. Backward is the ordinary
|
||||
// softmax Jacobian (csrc/ops/nn.cu launch_softmax_dx_f32): masked entries have
|
||||
// y=0, so their contribution vanishes — no causal-specific backward needed.
|
||||
//
|
||||
// All F32, row-major, contiguous. Reduction helpers mirror nn.cu (inlined so the
|
||||
// file is self-contained, matching the csrc/ layout).
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {

// Warp-wide sum by shuffle-down: after the loop lane 0 holds the sum over all
// 32 lanes. The full-warp mask means every lane of the warp must take part.
__device__ __forceinline__ float att_warp_sum(float v) {
#pragma unroll
    for (int off = 16; off > 0; off >>= 1)
        v += __shfl_down_sync(0xffffffff, v, off);
    return v;
}

// Warp-wide max by shuffle-down; same participation requirement as
// att_warp_sum.
__device__ __forceinline__ float att_warp_max(float v) {
#pragma unroll
    for (int off = 16; off > 0; off >>= 1)
        v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
    return v;
}

// Block-wide sum, returned to EVERY thread. Each warp reduces its own values,
// lane 0 of each warp parks the partial in shared memory, warp 0 reduces the
// partials and thread 0 publishes the total. Must be called by all threads of
// the block (it contains __syncthreads), and blockDim.x must be a whole number
// of warps, at most 32 of them.
__device__ __forceinline__ float att_block_sum(float v) {
    __shared__ float sh[32];
    int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
    int nwarps = (blockDim.x + 31) >> 5;
    v = att_warp_sum(v);
    if (lane == 0) sh[warp] = v;
    __syncthreads();
    // Lanes past the number of warps feed the identity element.
    v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
    if (warp == 0) v = att_warp_sum(v);
    __shared__ float bc;
    if (threadIdx.x == 0) bc = v;
    __syncthreads();
    return bc;
}

// Block-wide max, returned to every thread; same structure and requirements as
// att_block_sum, with -INFINITY as the identity element.
__device__ __forceinline__ float att_block_max(float v) {
    __shared__ float sh[32];
    int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
    int nwarps = (blockDim.x + 31) >> 5;
    v = att_warp_max(v);
    if (lane == 0) sh[warp] = v;
    __syncthreads();
    v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
    if (warp == 0) v = att_warp_max(v);
    __shared__ float bc;
    if (threadIdx.x == 0) bc = v;
    __syncthreads();
    return bc;
}

// One block per score row. rows = B*nh*S total; the query position within its
// sequence is (blockIdx.x % seq). Logits are scaled by `scale` (= 1/sqrt(hd))
// before softmax; columns j > qpos are masked to probability 0.
__global__ void softmax_causal_k(const float* x, float* y, int seq, float scale) {
    int r = blockIdx.x;
    int qpos = r % seq;
    const float* xr = x + (size_t)r * seq;
    float* yr = y + (size_t)r * seq;
    int valid = qpos + 1;  // attend to columns [0, qpos]

    // Pass 1: row max over the visible columns, for a numerically stable exp.
    float m = -INFINITY;
    for (int c = threadIdx.x; c < valid; c += blockDim.x)
        m = fmaxf(m, xr[c] * scale);
    m = att_block_max(m);

    // Pass 2: unnormalised exponentials, stashed in the output row.
    float sum = 0.0f;
    for (int c = threadIdx.x; c < valid; c += blockDim.x) {
        float e = expf(xr[c] * scale - m);
        yr[c] = e;
        sum += e;
    }
    sum = att_block_sum(sum);

    // Pass 3: normalise the visible columns and zero the masked ones. Each
    // thread revisits exactly the columns it wrote in pass 2 (same stride), so
    // no extra barrier is needed before reading yr[c] back.
    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < seq; c += blockDim.x)
        yr[c] = (c < valid) ? yr[c] * inv : 0.0f;
}

// Host launcher: `rows` blocks, one per score row of length `seq`, on stream
// `s`. Does nothing for an empty problem.
void launch_softmax_causal_f32(const float* x, float* y, int rows, int seq,
                               float scale, void* s) {
    // A zero-sized grid is an invalid launch configuration, and seq == 0 would
    // make the kernel's `r % seq` a division by zero.
    if (rows <= 0 || seq <= 0) return;
    int blk = seq < 1024 ? seq : 1024;
    // Round up to whole warps (32..1024). The reductions shuffle with the full
    // 0xffffffff mask, so a partially filled last warp would read lanes that
    // do not exist; the extra threads simply contribute the identity element.
    blk = ((blk + 31) / 32) * 32;
    softmax_causal_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, seq, scale);
}

}  // extern "C"
|
||||
@@ -63,4 +63,26 @@ void launch_transpose_3d01_f32(const float* in, float* out, int a, int b, int c,
|
||||
transpose_3d01_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c);
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
|
||||
// Lays out batched multi-head attention: [B,S,nh,hd] <-> [B,nh,S,hd], so a
|
||||
// flattened [B*nh, S, hd] view feeds the strided-batched-GEMM attention. Its own
|
||||
// backward is the same op (swap b,c), so one kernel suffices.
|
||||
// =====================================================================
|
||||
|
||||
// One thread per input element: unravel the flat input index into (i,j,k,l)
// over [a,b,c,d], then scatter the value to position (i,k,j,l) of the
// [a,c,b,d] output.
__global__ void transpose_4d12_k(const float* in, float* out, int a, int b, int c, int d) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;  // flat index over a*b*c*d
    if (idx >= a * b * c * d) return;

    // Peel the coordinates off innermost-first.
    int rest = idx;
    const int l = rest % d;
    rest /= d;
    const int k = rest % c;
    rest /= c;
    const int j = rest % b;
    const int i = rest / b;

    // Row-major offset of (i,k,j,l) in an [a,c,b,d] tensor.
    const int dst = (((i * c + k) * b) + j) * d + l;
    out[dst] = in[idx];
}
|
||||
// Host launcher for transpose_4d12_k: one thread per element, 256 threads per
// block, on stream `s`. The element count a*b*c*d must fit in a 32-bit int
// because the kernel indexes with ints. Does nothing for an empty tensor.
void launch_transpose_4d12_f32(const float* in, float* out, int a, int b, int c, int d, void* s) {
    const int n = a * b * c * d;
    if (n <= 0) return;  // a zero-block grid is an invalid launch configuration
    const int blk = 256;
    // Ceil-divide written so that it cannot overflow when n is close to INT_MAX
    // (the usual (n + blk - 1) / blk form can).
    const int grid = (n - 1) / blk + 1;
    transpose_4d12_k<<<grid, blk, 0, (cudaStream_t)s>>>(in, out, a, b, c, d);
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -215,14 +215,20 @@ void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void*
|
||||
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
|
||||
// =====================================================================
|
||||
|
||||
__global__ void rope_k(const float* x, float* y, int heads, int head_dim, float theta) {
|
||||
// `period` is the sequence length: a flattened batch lays B sequences end to end
|
||||
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
|
||||
// own sequence, `tok % period`. With period == tokens (single sequence) this is
|
||||
// the original position = row.
|
||||
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
|
||||
float theta, int period) {
|
||||
int tok = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int half = head_dim / 2;
|
||||
int i = threadIdx.x;
|
||||
if (i >= half) return;
|
||||
int pos = tok % period;
|
||||
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
|
||||
float angle = (float)tok * freq;
|
||||
float angle = (float)pos * freq;
|
||||
float c = cosf(angle), sn = sinf(angle);
|
||||
int base = (tok * heads + head) * head_dim;
|
||||
float x0 = x[base + i], x1 = x[base + i + half];
|
||||
@@ -230,20 +236,22 @@ __global__ void rope_k(const float* x, float* y, int heads, int head_dim, float
|
||||
y[base + i + half] = x1 * c + x0 * sn;
|
||||
}
|
||||
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
|
||||
int head_dim, float theta, void* s) {
|
||||
int head_dim, float theta, int period, void* s) {
|
||||
dim3 grid(tokens, heads);
|
||||
int blk = head_dim / 2;
|
||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta);
|
||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
|
||||
}
|
||||
|
||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, float theta) {
|
||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||
float theta, int period) {
|
||||
int tok = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int half = head_dim / 2;
|
||||
int i = threadIdx.x;
|
||||
if (i >= half) return;
|
||||
int pos = tok % period;
|
||||
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
|
||||
float angle = (float)tok * freq;
|
||||
float angle = (float)pos * freq;
|
||||
float c = cosf(angle), sn = sinf(angle);
|
||||
int base = (tok * heads + head) * head_dim;
|
||||
float d0 = dy[base + i], d1 = dy[base + i + half];
|
||||
@@ -251,10 +259,10 @@ __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, f
|
||||
dx[base + i + half] = d1 * c - d0 * sn;
|
||||
}
|
||||
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
|
||||
int head_dim, float theta, void* s) {
|
||||
int head_dim, float theta, int period, void* s) {
|
||||
dim3 grid(tokens, heads);
|
||||
int blk = head_dim / 2;
|
||||
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta);
|
||||
rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
|
||||
Reference in New Issue
Block a user