- Batched GEMM via cublasGemmStridedBatchedEx
- Causal mask CUDA kernel (F32 + BF16)
- Element-wise scale CUDA kernel (F32 + BF16)
- attention() composing: batched_matmul + scale + causal_mask + softmax
- Fixed to_device/contiguous infinite recursion (GPU contiguous via CPU round-trip)
- 5 attention tests passing (max_err < 3e-7 F32)
- Total: 61 tests passing across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
188 lines
7.0 KiB
Rust
188 lines
7.0 KiB
Rust
use xserv_kernels::*;
|
|
use xserv_tensor::{Device, Tensor};
|
|
|
|
/// Selects CUDA device 0 before a test touches the GPU.
///
/// # Panics
///
/// Panics if `xserv_cuda::device::set_device(0)` reports a failure.
fn init() {
    // `expect` instead of `unwrap` so a failing test names the step that broke.
    xserv_cuda::device::set_device(0).expect("failed to select CUDA device 0");
}
|
|
|
|
/// CPU reference implementation of scaled dot-product attention.
///
/// `q` is read as a contiguous row-major `[batch, heads, q_len, head_dim]`
/// buffer and `k` / `v` as `[batch, heads, kv_len, head_dim]` buffers. For every
/// `(batch, head)` pair this computes `softmax(Q @ K^T / sqrt(head_dim)) @ V`
/// and returns the result as a `[batch, heads, q_len, head_dim]` buffer.
///
/// When `causal` is true the mask is aligned to the last key: query row `i`
/// only attends to key positions `0..=i + (kv_len - q_len)`.
///
/// # Panics
///
/// Panics if `causal` is true and `kv_len < q_len`. The mask offset would be
/// negative, which previously overflowed `usize` (a panic in debug builds, a
/// silently wrong mask in release builds). Also panics on out-of-bounds
/// indexing if an input slice is shorter than its implied shape.
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
                 batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
                 causal: bool) -> Vec<f32> {
    assert!(
        !causal || kv_len >= q_len,
        "cpu_attention: causal mask requires kv_len >= q_len (got kv_len = {kv_len}, q_len = {q_len})"
    );
    // Only meaningful (and only computed) for the causal case; see the assert above.
    let offset = if causal { kv_len - q_len } else { 0 };

    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
    // Attention weights of the query row currently being processed; every
    // entry is overwritten for each row, so one buffer can be reused.
    let mut weights = vec![0.0f32; kv_len];

    for bh in 0..batch * heads {
        let q_base = bh * q_len * head_dim;
        let kv_base = bh * kv_len * head_dim;

        for i in 0..q_len {
            let q_row = &q[q_base + i * head_dim..q_base + (i + 1) * head_dim];

            // scores = Q @ K^T, scaled; masked positions become -inf so they
            // receive zero weight after the softmax.
            for (j, score) in weights.iter_mut().enumerate() {
                *score = if causal && j > i + offset {
                    f32::NEG_INFINITY
                } else {
                    let k_row = &k[kv_base + j * head_dim..kv_base + (j + 1) * head_dim];
                    q_row.iter().zip(k_row).fold(0.0f32, |acc, (a, b)| acc + a * b) * scale
                };
            }

            // Softmax over the row, shifted by the row maximum before exponentiating.
            let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0f32;
            for w in weights.iter_mut() {
                *w = (*w - max).exp();
                sum += *w;
            }
            for w in weights.iter_mut() {
                *w /= sum;
            }

            // output = weights @ V
            let out_row = &mut out[q_base + i * head_dim..q_base + (i + 1) * head_dim];
            for (d, o) in out_row.iter_mut().enumerate() {
                *o = weights
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (j, w)| acc + w * v[kv_base + j * head_dim + d]);
            }
        }
    }
    out
}
|
|
|
|
/// Asserts that `a` and `b` agree element-wise within the absolute tolerance `atol`.
///
/// `name` prefixes every message. On success the largest absolute difference
/// seen is printed to stdout.
///
/// # Panics
///
/// Panics if the slices differ in length, or at the first index whose absolute
/// difference is not `<= atol` (a NaN difference fails this comparison too).
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
    assert_eq!(a.len(), b.len(), "{name}: length mismatch");

    let max_err = a
        .iter()
        .zip(b.iter())
        .enumerate()
        .fold(0.0f32, |worst, (i, (&x, &y))| {
            let err = (x - y).abs();
            assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
            // `err` passed the check above, so it is not NaN here.
            worst.max(err)
        });

    println!("{name}: max_err = {max_err:.6e}");
}
|
|
|
|
/// Builds `n` deterministic test values.
///
/// The values repeat with period 17 and lie in `[-0.4, 0.4]`:
/// element `i` is `((i % 17) - 8) * 0.05`.
fn make_data(n: usize) -> Vec<f32> {
    let mut data = Vec::with_capacity(n);
    for i in 0..n {
        let centered = (i % 17) as f32 - 8.0;
        data.push(centered * 0.05);
    }
    data
}
|
|
|
|
/// Batched GPU matmul must match a naive CPU matmul for every (batch, head) matrix.
///
/// Checking only the first matrix would let wrong batch strides or offsets go
/// unnoticed, so each of the `batch * heads` products is verified.
#[test]
fn test_batched_matmul() {
    init();
    let batch = 4;
    let heads = 8;
    let m = 32;
    let k = 64;
    let n = 32;

    let a_data = make_data(batch * heads * m * k);
    let b_data = make_data(batch * heads * k * n);

    let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0));
    let c = batched_matmul(&a, &b).to_device(Device::Cpu);

    assert_eq!(c.shape(), &[batch, heads, m, n]);

    let result = c.as_slice::<f32>();

    // Verify every batch element against a naive row-major matmul.
    for idx in 0..batch * heads {
        let a_cpu = &a_data[idx * m * k..(idx + 1) * m * k];
        let b_cpu = &b_data[idx * k * n..(idx + 1) * k * n];

        let mut expected = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
                }
                expected[i * n + j] = s;
            }
        }

        check_close(
            &result[idx * m * n..(idx + 1) * m * n],
            &expected,
            1e-3,
            &format!("batched_matmul[{idx}]"),
        );
    }
}
|
|
|
|
/// Unmasked GPU attention must match the CPU reference within 1e-4.
#[test]
fn test_attention_no_causal() {
    init();

    // Dimensions: batch, heads, sequence length (queries == keys), head dim.
    let (b, h, s, d) = (1, 2, 8, 16);
    let shape = [b, h, s, d];
    let numel = b * h * s * d;

    let q_data = make_data(numel);
    let k_data = make_data(numel);
    let v_data = make_data(numel);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, false).to_device(Device::Cpu);

    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);
    check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
}
|
|
|
|
/// Causally masked GPU attention must match the CPU reference within 1e-3.
#[test]
fn test_attention_causal() {
    init();

    // Dimensions: batch, heads, sequence length (queries == keys), head dim.
    let (b, h, s, d) = (1, 2, 16, 32);
    let shape = [b, h, s, d];
    let numel = b * h * s * d;

    let q_data = make_data(numel);
    let k_data = make_data(numel);
    let v_data = make_data(numel);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
    check_close(out.as_slice::<f32>(), &expected, 1e-3, "attention_causal");
}
|
|
|
|
/// Causal attention on a larger multi-batch problem must match the CPU
/// reference within 1e-2.
#[test]
fn test_attention_causal_larger() {
    init();

    // Dimensions: batch, heads, sequence length (queries == keys), head dim.
    let (b, h, s, d) = (2, 4, 64, 64);
    let shape = [b, h, s, d];
    let numel = b * h * s * d;

    let q_data = make_data(numel);
    let k_data = make_data(numel);
    let v_data = make_data(numel);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
    check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
}
|
|
|
|
/// With a causal mask, query position 0 can only attend to key position 0,
/// so its output row must reproduce the first row of V.
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
    init();

    // Dimensions: batch, heads, sequence length, head dim.
    let (b, h, s, d) = (1, 1, 4, 8);
    let shape = [b, h, s, d];

    let q_data = make_data(b * h * s * d);
    let k_data = make_data(b * h * s * d);

    // Only the first V row is nonzero (all ones); every other row is zero.
    let mut v_data = vec![0.0f32; s * d];
    v_data[..d].fill(1.0);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    // Row 0 puts all of its softmax weight on token 0, so the output must be
    // V[0] = [1, 1, ..., 1]. A leaking mask would shift weight onto the zero
    // rows and pull these values below 1.
    let result = out.as_slice::<f32>();
    for (dim, &value) in result[..d].iter().enumerate() {
        assert!((value - 1.0).abs() < 1e-5,
            "first row should equal V[0], got {} at dim {}", value, dim);
    }
}
|