post-train: M2 — decode primitives (rope_at + decode_attention)
Two forward-only Tensor primitives the KV-cache decode engine is built on, each gated by an isolated correctness test: - rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no modulo) for a single decode token, vs the training rope_k (pos = row % period) left untouched. New forward-only CUDA kernel, no training-path risk. Gate: bit-identical to the full-sequence rope's corresponding row. - decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from the existing strided batched GEMM + plain (non-causal) softmax — no new kernel. Gate: equals the full causal attention's last query row (max |Δ| 6e-8). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -139,6 +139,19 @@ unsafe extern "C" {
|
|||||||
period: i32,
|
period: i32,
|
||||||
s: CudaStream,
|
s: CudaStream,
|
||||||
);
|
);
|
||||||
|
// RoPE at an absolute position offset (KV-cache decode, forward only): row
|
||||||
|
// `tok`'s position is `pos0 + tok` (no modulo). For a single decode token
|
||||||
|
// (tokens == 1) the one row sits at absolute position `pos0`.
|
||||||
|
pub fn launch_rope_at_f32(
|
||||||
|
x: *const f32,
|
||||||
|
y: *mut f32,
|
||||||
|
tokens: i32,
|
||||||
|
heads: i32,
|
||||||
|
head_dim: i32,
|
||||||
|
theta: f32,
|
||||||
|
pos0: i32,
|
||||||
|
s: CudaStream,
|
||||||
|
);
|
||||||
pub fn launch_rope_dx_f32(
|
pub fn launch_rope_dx_f32(
|
||||||
dy: *const f32,
|
dy: *const f32,
|
||||||
dx: *mut f32,
|
dx: *mut f32,
|
||||||
|
|||||||
@@ -790,6 +790,38 @@ impl Tensor {
|
|||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// RoPE at an absolute position offset (KV-cache decode, forward only).
|
||||||
|
/// `self`:[tokens,heads,head_dim]; row `r`'s position is `pos0 + r` (no
|
||||||
|
/// modulo). For a single new decode token pass `tokens == 1` → the one row is
|
||||||
|
/// rotated at absolute position `pos0`. Mirrors [`rope`](Self::rope)'s dtype
|
||||||
|
/// handling (bf16 → f32 → bf16); no backward (inference path).
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
pub fn rope_at(&self, theta: f32, pos0: usize) -> Self {
|
||||||
|
assert_eq!(self.ndim(), 3, "rope_at requires [tokens,heads,head_dim]");
|
||||||
|
let (tokens, heads, head_dim) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||||
|
assert_eq!(head_dim % 2, 0, "head_dim must be even");
|
||||||
|
if self.dtype == DType::BF16 {
|
||||||
|
return self
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
.rope_at(theta, pos0)
|
||||||
|
.to_dtype(DType::BF16);
|
||||||
|
}
|
||||||
|
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||||
|
unsafe {
|
||||||
|
xtrain_cuda::ffi::launch_rope_at_f32(
|
||||||
|
self.data_ptr() as *const f32,
|
||||||
|
out.data_ptr() as *mut f32,
|
||||||
|
tokens as i32,
|
||||||
|
heads as i32,
|
||||||
|
head_dim as i32,
|
||||||
|
theta,
|
||||||
|
pos0 as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
/// RoPE backward: apply the inverse (transpose) rotation to `dy`. RoPE is an
|
||||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
@@ -1076,6 +1108,76 @@ impl Tensor {
|
|||||||
(out, probs)
|
(out, probs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Decode-time (incremental) attention: a SINGLE query position against a
|
||||||
|
/// cached K/V of length `t` (KV-cache decode, forward only). `self` = Q
|
||||||
|
/// `[bh,1,head_dim]`; `k`,`v` = `[bh,t,head_dim]`, already repeat_kv-expanded
|
||||||
|
/// to `bh` heads. Returns out `[bh,head_dim]` (= `[bh,1,head_dim]` flattened).
|
||||||
|
///
|
||||||
|
/// No causal mask is needed — the one query sits at the end, so every cached
|
||||||
|
/// key (positions `0..t`) is visible. This is exactly the LAST query row of the
|
||||||
|
/// full causal [`attention`](Self::attention), so KV-cache greedy decode is
|
||||||
|
/// token-identical to full recompute. Softmax is computed in f32 (matching the
|
||||||
|
/// causal path) with `scale` folded in before the exponentials.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
pub fn decode_attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> Self {
|
||||||
|
assert_eq!(self.ndim(), 3, "decode_attention Q must be [bh,1,head_dim]");
|
||||||
|
assert_eq!(self.shape[1], 1, "decode_attention Q seq must be 1");
|
||||||
|
assert_eq!(k.ndim(), 3, "decode_attention K must be [bh,t,head_dim]");
|
||||||
|
assert_eq!(k.shape(), v.shape(), "K/V shape mismatch");
|
||||||
|
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
|
||||||
|
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
|
||||||
|
let (bh, hd) = (self.shape[0], self.shape[2]);
|
||||||
|
assert_eq!(k.shape[0], bh, "Q/K batch-head mismatch");
|
||||||
|
assert_eq!(k.shape[2], hd, "Q/K head_dim mismatch");
|
||||||
|
let t = k.shape[1]; // cached length
|
||||||
|
let dt = self.dtype;
|
||||||
|
let dev = self.device();
|
||||||
|
|
||||||
|
// scores[bh,1,t] = Q[bh,1,hd] · Kᵀ[bh,hd,t] (per-head batched GEMM).
|
||||||
|
// [bh,1,t] is stored identically to [bh,t]; allocate 2D so the rowwise
|
||||||
|
// softmax can run without a reshape.
|
||||||
|
let scores = Tensor::zeros(&[bh, t], dt, dev);
|
||||||
|
strided_batched_gemm(
|
||||||
|
dt,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
1,
|
||||||
|
t,
|
||||||
|
hd,
|
||||||
|
self.data_ptr(),
|
||||||
|
hd,
|
||||||
|
k.data_ptr(),
|
||||||
|
t * hd,
|
||||||
|
scores.data_ptr(),
|
||||||
|
t,
|
||||||
|
bh,
|
||||||
|
);
|
||||||
|
// probs = softmax(scale · scores) over the t keys (f32, like the causal path).
|
||||||
|
let probs = scores
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
.scale(scale)
|
||||||
|
.softmax()
|
||||||
|
.to_dtype(dt);
|
||||||
|
// out[bh,1,hd] = probs[bh,1,t] · V[bh,t,hd].
|
||||||
|
let out = Tensor::zeros(&[bh, hd], dt, dev);
|
||||||
|
strided_batched_gemm(
|
||||||
|
dt,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
1,
|
||||||
|
hd,
|
||||||
|
t,
|
||||||
|
probs.data_ptr(),
|
||||||
|
t,
|
||||||
|
v.data_ptr(),
|
||||||
|
t * hd,
|
||||||
|
out.data_ptr(),
|
||||||
|
hd,
|
||||||
|
bh,
|
||||||
|
);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
/// Backward of [`attention`](Self::attention). Inputs: forward `q`,`k`,`v`,
|
/// Backward of [`attention`](Self::attention). Inputs: forward `q`,`k`,`v`,
|
||||||
/// the cached `probs`, the upstream `dout` (all batched `[bh,seq,*]`), and the
|
/// the cached `probs`, the upstream `dout` (all batched `[bh,seq,*]`), and the
|
||||||
/// same `scale`. Returns `(dq, dk, dv)`.
|
/// same `scale`. Returns `(dq, dk, dv)`.
|
||||||
|
|||||||
@@ -56,3 +56,106 @@ fn elementwise_scale_kernel() {
|
|||||||
r.len()
|
r.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// (c) `rope_at` (KV-cache decode RoPE at an absolute position) is bit-identical
|
||||||
|
/// to the full-sequence `rope`'s corresponding row. This is the invariant the
|
||||||
|
/// decode KV-cache relies on: a single new token RoPE'd at position `t` must equal
|
||||||
|
/// what the full-sequence forward would have produced at row `t` (so cached
|
||||||
|
/// post-RoPE K matches the full-recompute path → token-identical decode).
|
||||||
|
#[test]
|
||||||
|
fn rope_at_matches_full_rope_row() {
|
||||||
|
assert!(
|
||||||
|
device::device_count().expect("device count") > 0,
|
||||||
|
"no CUDA device"
|
||||||
|
);
|
||||||
|
device::set_device(0).unwrap();
|
||||||
|
|
||||||
|
let (n, heads, hd) = (7usize, 3usize, 8usize);
|
||||||
|
let theta = 10000.0f32;
|
||||||
|
// Deterministic pseudo-random fill in [-1, 1).
|
||||||
|
let host: Vec<f32> = (0..n * heads * hd)
|
||||||
|
.map(|i| ((i * 37 % 101) as f32 / 50.0) - 1.0)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Full-sequence rope (period = n → row r gets position r).
|
||||||
|
let full = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0));
|
||||||
|
let roped_full = full
|
||||||
|
.rope(theta, n)
|
||||||
|
.to_device(Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
let row_len = heads * hd;
|
||||||
|
for t in 0..n {
|
||||||
|
let row = &host[t * row_len..(t + 1) * row_len];
|
||||||
|
let roped_row = Tensor::from_slice(row, &[1, heads, hd])
|
||||||
|
.to_device(Device::Cuda(0))
|
||||||
|
.rope_at(theta, t)
|
||||||
|
.to_device(Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.to_vec();
|
||||||
|
let expect = &roped_full[t * row_len..(t + 1) * row_len];
|
||||||
|
assert_eq!(
|
||||||
|
roped_row.as_slice(),
|
||||||
|
expect,
|
||||||
|
"rope_at(pos0={t}) != full rope row {t}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
println!("rope_at OK: bit-identical to full rope across {n} positions");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// (d) `decode_attention` (single query vs cached K/V, no mask) equals the LAST
|
||||||
|
/// query row of the full causal `attention`. This is the core decode-engine
|
||||||
|
/// invariant: the incremental path must reproduce what the full-recompute forward
|
||||||
|
/// computes for the final position, so KV-cache greedy decode is token-identical.
|
||||||
|
/// Tolerance is fp rounding (different softmax kernel + reduction order), not bits.
|
||||||
|
#[test]
|
||||||
|
fn decode_attention_matches_full_attention_last_row() {
|
||||||
|
assert!(
|
||||||
|
device::device_count().expect("device count") > 0,
|
||||||
|
"no CUDA device"
|
||||||
|
);
|
||||||
|
device::set_device(0).unwrap();
|
||||||
|
|
||||||
|
let (bh, t, hd) = (6usize, 5usize, 8usize);
|
||||||
|
let scale = 1.0 / (hd as f32).sqrt();
|
||||||
|
let n = bh * t * hd;
|
||||||
|
let qh: Vec<f32> = (0..n).map(|i| ((i * 31 % 97) as f32 / 48.0) - 1.0).collect();
|
||||||
|
let kh: Vec<f32> = (0..n).map(|i| ((i * 53 % 89) as f32 / 44.0) - 1.0).collect();
|
||||||
|
let vh: Vec<f32> = (0..n).map(|i| ((i * 17 % 83) as f32 / 41.0) - 1.0).collect();
|
||||||
|
let q = Tensor::from_slice(&qh, &[bh, t, hd]).to_device(Device::Cuda(0));
|
||||||
|
let k = Tensor::from_slice(&kh, &[bh, t, hd]).to_device(Device::Cuda(0));
|
||||||
|
let v = Tensor::from_slice(&vh, &[bh, t, hd]).to_device(Device::Cuda(0));
|
||||||
|
|
||||||
|
// Reference: full causal attention, take each head's last query row.
|
||||||
|
let (full, _) = q.attention(&k, &v, scale);
|
||||||
|
let full_h = full.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||||
|
|
||||||
|
// Decode: build Q_last [bh,1,hd] from each head's last row, attend to all K/V.
|
||||||
|
let mut ql = vec![0f32; bh * hd];
|
||||||
|
for b in 0..bh {
|
||||||
|
let src = (b * t + (t - 1)) * hd;
|
||||||
|
ql[b * hd..(b + 1) * hd].copy_from_slice(&qh[src..src + hd]);
|
||||||
|
}
|
||||||
|
let q_last = Tensor::from_slice(&ql, &[bh, 1, hd]).to_device(Device::Cuda(0));
|
||||||
|
let dec = q_last
|
||||||
|
.decode_attention(&k, &v, scale)
|
||||||
|
.to_device(Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.to_vec();
|
||||||
|
assert_eq!(dec.len(), bh * hd, "decode out shape");
|
||||||
|
|
||||||
|
let mut max_abs = 0f32;
|
||||||
|
for b in 0..bh {
|
||||||
|
for d in 0..hd {
|
||||||
|
let got = dec[b * hd + d];
|
||||||
|
let exp = full_h[(b * t + (t - 1)) * hd + d];
|
||||||
|
max_abs = max_abs.max((got - exp).abs());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert!(
|
||||||
|
max_abs < 1e-4,
|
||||||
|
"decode_attention vs full last-row max abs diff {max_abs} exceeds 1e-4"
|
||||||
|
);
|
||||||
|
println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
|
||||||
|
}
|
||||||
|
|||||||
@@ -242,6 +242,33 @@ void launch_rope_f32(const float* x, float* y, int tokens, int heads,
|
|||||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
|
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoPE at an absolute position offset (KV-cache decode-time, forward only). Same
|
||||||
|
// rotate_half as rope_k, but row `tok`'s position is `pos0 + tok` (no modulo) —
|
||||||
|
// a single new decode token sits at absolute position pos0. The training rope_k
|
||||||
|
// (position = tok % period) is left untouched, so this adds no training-path risk.
|
||||||
|
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
|
||||||
|
float theta, int pos0) {
|
||||||
|
int tok = blockIdx.x;
|
||||||
|
int head = blockIdx.y;
|
||||||
|
int half = head_dim / 2;
|
||||||
|
int i = threadIdx.x;
|
||||||
|
if (i >= half) return;
|
||||||
|
int pos = pos0 + tok;
|
||||||
|
float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
|
||||||
|
float angle = (float)pos * freq;
|
||||||
|
float c = cosf(angle), sn = sinf(angle);
|
||||||
|
int base = (tok * heads + head) * head_dim;
|
||||||
|
float x0 = x[base + i], x1 = x[base + i + half];
|
||||||
|
y[base + i] = x0 * c - x1 * sn;
|
||||||
|
y[base + i + half] = x1 * c + x0 * sn;
|
||||||
|
}
|
||||||
|
void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
|
||||||
|
int head_dim, float theta, int pos0, void* s) {
|
||||||
|
dim3 grid(tokens, heads);
|
||||||
|
int blk = head_dim / 2;
|
||||||
|
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
||||||
|
}
|
||||||
|
|
||||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||||
float theta, int period) {
|
float theta, int period) {
|
||||||
int tok = blockIdx.x;
|
int tok = blockIdx.x;
|
||||||
|
|||||||
Reference in New Issue
Block a user