Two forward-only Tensor primitives the KV-cache decode engine is built on, each gated by an isolated correctness test: - rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no modulo) for a single decode token, vs the training rope_k (pos = row % period) left untouched. New forward-only CUDA kernel, no training-path risk. Gate: bit-identical to the full-sequence rope's corresponding row. - decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from the existing strided batched GEMM + plain (non-causal) softmax — no new kernel. Gate: equals the full causal attention's last query row (max |Δ| 6e-8). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
162 lines
5.9 KiB
Rust
162 lines
5.9 KiB
Rust
// GPU integration tests for the tensor abstraction. All of them require nvcc + a
// GPU, so they are gated behind `not(no_cuda)`. On a GPU-less machine build.rs sets
// the `no_cuda` cfg and these compile out, keeping host `cargo check` green.
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_cuda::device;
|
|
use xtrain_tensor::{Device, Tensor};
|
|
|
|
/// (a) A tensor copied host → device → host comes back with exactly the data it
/// started with.
#[test]
fn host_device_roundtrip() {
    let visible_gpus = device::device_count().expect("device count");
    assert!(visible_gpus > 0, "no CUDA device");
    device::set_device(0).unwrap();

    // Source data: 1024 values stepping by one half (0.0, 0.5, 1.0, ...).
    let samples: Vec<f32> = (0..1024).map(|step| step as f32 * 0.5).collect();
    let on_host = Tensor::from_slice(&samples, &[1024]);

    // Upload, then confirm the reported device and shape.
    let on_gpu = on_host.to_device(Device::Cuda(0));
    assert_eq!(on_gpu.device(), Device::Cuda(0));
    assert_eq!(on_gpu.shape(), &[1024]);

    // Download and compare element-for-element with what was sent.
    let returned = on_gpu.to_device(Device::Cpu);
    assert_eq!(returned.device(), Device::Cpu);
    assert_eq!(returned.as_slice::<f32>(), &samples[..]);
    println!("roundtrip OK: {} elems preserved", samples.len());
}
|
|
|
|
/// (b) Scaling a device tensor with `scale` matches multiplying every element by
/// the same factor on the host.
#[test]
fn elementwise_scale_kernel() {
    let visible_gpus = device::device_count().expect("device count");
    assert!(visible_gpus > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let alpha = 3.0f32;
    let input: Vec<f32> = (0..2048).map(|i| i as f32).collect();
    // Host-side reference the kernel output is compared against.
    let expected: Vec<f32> = input.iter().map(|&value| value * alpha).collect();

    // Upload → scale on the device → download.
    let result = Tensor::from_slice(&input, &[2048])
        .to_device(Device::Cuda(0))
        .scale(alpha)
        .to_device(Device::Cpu);

    assert_eq!(result.shape(), &[2048]);
    assert_eq!(result.as_slice::<f32>(), &expected[..]);

    // Print a few sample points for the test log.
    let r = result.as_slice::<f32>();
    let count = r.len();
    let (first, mid, last) = (r[0], r[count / 2], r[count - 1]);
    println!(
        "scale OK (alpha={alpha}): first={} mid={} last={} ({} elems)",
        first, mid, last, count
    );
}
|
|
|
|
/// (c) `rope_at` applied to a single row at absolute position `t` must be
/// bit-identical to row `t` of the full-sequence `rope` output. The decode
/// KV-cache depends on this: a new token RoPE'd on its own at position `t` has to
/// equal what the full-sequence forward would have produced at that row, so that
/// cached post-RoPE K matches the full-recompute path (token-identical decode).
#[test]
fn rope_at_matches_full_rope_row() {
    let visible_gpus = device::device_count().expect("device count");
    assert!(visible_gpus > 0, "no CUDA device");
    device::set_device(0).unwrap();

    let n: usize = 7;
    let heads: usize = 3;
    let hd: usize = 8;
    let theta = 10000.0f32;
    let row_len = heads * hd;

    // Deterministic pseudo-random fill in [-1, 1).
    let host: Vec<f32> = (0..n * heads * hd)
        .map(|i| {
            let bucket = (i * 37) % 101;
            bucket as f32 / 50.0 - 1.0
        })
        .collect();

    // Reference: rope over the whole sequence (period = n → row r gets position r).
    let full = Tensor::from_slice(&host, &[n, heads, hd]).to_device(Device::Cuda(0));
    let roped_on_gpu = full.rope(theta, n);
    let roped_on_host = roped_on_gpu.to_device(Device::Cpu);
    let roped_full = roped_on_host.as_slice::<f32>().to_vec();

    // Rope each row on its own at its absolute position and compare exactly.
    for (t, row) in host.chunks_exact(row_len).enumerate() {
        let single = Tensor::from_slice(row, &[1, heads, hd]).to_device(Device::Cuda(0));
        let roped_single = single.rope_at(theta, t).to_device(Device::Cpu);
        let roped_row = roped_single.as_slice::<f32>().to_vec();

        let expect = &roped_full[t * row_len..][..row_len];
        assert_eq!(
            roped_row.as_slice(),
            expect,
            "rope_at(pos0={t}) != full rope row {t}"
        );
    }
    println!("rope_at OK: bit-identical to full rope across {n} positions");
}
|
|
|
|
/// (d) `decode_attention` (single query vs cached K/V, no mask) equals the LAST
/// query row of the full causal `attention`. This is the core decode-engine
/// invariant: the incremental path must reproduce what the full-recompute forward
/// computes for the final position, so KV-cache greedy decode is token-identical.
/// Tolerance is fp rounding (different softmax kernel + reduction order), not bits.
///
/// Every compared value must also be finite. `f32::max` discards a NaN operand, so
/// without that guard a NaN-producing kernel would leave the running maximum at
/// zero and slip through the tolerance check.
#[test]
fn decode_attention_matches_full_attention_last_row() {
    assert!(
        device::device_count().expect("device count") > 0,
        "no CUDA device"
    );
    device::set_device(0).unwrap();

    let (bh, t, hd) = (6usize, 5usize, 8usize);
    let scale = 1.0 / (hd as f32).sqrt();
    let n = bh * t * hd;

    // Deterministic fills; Q, K and V each use a different multiplier/modulus pair.
    let qh: Vec<f32> = (0..n).map(|i| ((i * 31 % 97) as f32 / 48.0) - 1.0).collect();
    let kh: Vec<f32> = (0..n).map(|i| ((i * 53 % 89) as f32 / 44.0) - 1.0).collect();
    let vh: Vec<f32> = (0..n).map(|i| ((i * 17 % 83) as f32 / 41.0) - 1.0).collect();
    let q = Tensor::from_slice(&qh, &[bh, t, hd]).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&kh, &[bh, t, hd]).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&vh, &[bh, t, hd]).to_device(Device::Cuda(0));

    // Reference: full causal attention, take each head's last query row.
    let (full, _) = q.attention(&k, &v, scale);
    let full_h = full.to_device(Device::Cpu).as_slice::<f32>().to_vec();

    // Decode: build Q_last [bh,1,hd] from each head's last row, attend to all K/V.
    let mut ql = vec![0f32; bh * hd];
    for b in 0..bh {
        let src = (b * t + (t - 1)) * hd;
        ql[b * hd..(b + 1) * hd].copy_from_slice(&qh[src..src + hd]);
    }
    let q_last = Tensor::from_slice(&ql, &[bh, 1, hd]).to_device(Device::Cuda(0));
    let dec = q_last
        .decode_attention(&k, &v, scale)
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();
    assert_eq!(dec.len(), bh * hd, "decode out shape");

    let mut max_abs = 0f32;
    for b in 0..bh {
        // Offset of this slice's last query row inside the full attention output.
        let last_row = (b * t + (t - 1)) * hd;
        for d in 0..hd {
            let got = dec[b * hd + d];
            let exp = full_h[last_row + d];
            // Reject NaN/inf up front: `f32::max` below would silently drop a NaN
            // difference and let a broken kernel pass.
            assert!(
                got.is_finite() && exp.is_finite(),
                "non-finite attention output at bh index {b}, dim {d}: decode={got}, full={exp}"
            );
            max_abs = max_abs.max((got - exp).abs());
        }
    }
    assert!(
        max_abs < 1e-4,
        "decode_attention vs full last-row max abs diff {max_abs} exceeds 1e-4"
    );
    println!("decode_attention OK: matches full causal last row (bh={bh}, t={t}, max|Δ|={max_abs:.2e})");
}
|