Files
xserv/crates/xserv-kernels/tests/attention_test.rs

233 lines
7.3 KiB
Rust

use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
/// Shared setup for every test in this file: selects device `0` through
/// `xserv_cuda::device::set_device`.
///
/// # Panics
///
/// Panics if `set_device` returns an error. The message names the failing step so
/// a missing or unusable GPU is easy to tell apart from a kernel mismatch.
fn init() {
    xserv_cuda::device::set_device(0).expect("failed to select CUDA device 0 for the test run");
}
/// CPU reference implementation of scaled dot-product attention.
///
/// `q` is indexed as a row-major `[batch, heads, q_len, head_dim]` buffer, and `k` / `v`
/// as row-major `[batch, heads, kv_len, head_dim]` buffers. For every query row the
/// scores `q . k * (1 / sqrt(head_dim))` are computed, passed through a numerically
/// stable softmax (row maximum subtracted before `exp`), and used to weight the rows
/// of `v`.
///
/// When `causal` is true the mask is aligned to the end of the key sequence: query
/// row `i` may attend to keys `0..=i + kv_len - q_len`. The bound is evaluated with
/// saturating arithmetic, so `q_len > kv_len` no longer underflows; query rows that
/// end up with no visible key produce an all-zero output row.
/// NOTE(review): zero output for fully masked rows is a convention chosen here;
/// confirm it matches the kernel before testing shapes with `q_len > kv_len`.
///
/// Returns a row-major `[batch, heads, q_len, head_dim]` buffer.
///
/// # Panics
///
/// Panics if an input slice is shorter than the accessed index range.
fn cpu_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    batch: usize,
    heads: usize,
    q_len: usize,
    kv_len: usize,
    head_dim: usize,
    causal: bool,
) -> Vec<f32> {
    let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Scratch space for one row of scores / softmax weights, reused across rows.
    let mut row_buf = vec![0.0f32; kv_len];

    for b in 0..batch {
        for h in 0..heads {
            let head = b * heads + h;
            let q_base = head * q_len * head_dim;
            let kv_base = head * kv_len * head_dim;

            for i in 0..q_len {
                // Number of leading keys this query row is allowed to see.
                // Equivalent to masking every `j > i + (kv_len - q_len)`, but written
                // so that `q_len > kv_len` cannot underflow.
                let visible = if causal {
                    (i + kv_len + 1).saturating_sub(q_len).min(kv_len)
                } else {
                    kv_len
                };
                if visible == 0 {
                    // Nothing to attend to: the output row keeps its zero initialisation.
                    continue;
                }

                let q_row = q_base + i * head_dim;
                let weights = &mut row_buf[..visible];

                // scores = (q_i . k_j) * scale
                for (j, w) in weights.iter_mut().enumerate() {
                    let k_row = kv_base + j * head_dim;
                    let mut dot = 0.0f32;
                    for d in 0..head_dim {
                        dot += q[q_row + d] * k[k_row + d];
                    }
                    *w = dot * scale;
                }

                // Softmax over the visible keys; subtracting the maximum keeps `exp`
                // from overflowing.
                let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for w in weights.iter_mut() {
                    *w = (*w - max).exp();
                    sum += *w;
                }
                for w in weights.iter_mut() {
                    *w /= sum;
                }

                // output_i = sum_j weights_j * v_j
                for d in 0..head_dim {
                    let mut acc = 0.0f32;
                    for (j, &w) in weights.iter().enumerate() {
                        acc += w * v[kv_base + j * head_dim + d];
                    }
                    out[q_row + d] = acc;
                }
            }
        }
    }
    out
}
/// Asserts that `a` and `b` have the same length and that every pair of
/// corresponding elements differs by at most `atol` in absolute value.
///
/// `name` prefixes all diagnostics. When every element passes, the largest
/// observed absolute error is printed to stdout.
///
/// # Panics
///
/// Panics on a length mismatch, or at the first element whose error is not
/// `<= atol` (a NaN error fails that comparison as well).
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
    assert_eq!(a.len(), b.len(), "{name}: length mismatch");

    // Walk the pairs once, failing fast on the first offender and carrying the
    // running maximum error through the fold.
    let max_err = a
        .iter()
        .zip(b)
        .enumerate()
        .fold(0.0f32, |worst, (i, (x, y))| {
            let err = (x - y).abs();
            assert!(
                err <= atol,
                "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
            );
            worst.max(err)
        });

    println!("{name}: max_err = {max_err:.6e}");
}
/// Builds `n` deterministic test values.
///
/// Element `i` is `((i % 17) - 8) * 0.05`, so the sequence repeats every 17
/// elements and ranges from `-0.4` to `0.4`.
fn make_data(n: usize) -> Vec<f32> {
    let mut data = Vec::with_capacity(n);
    for idx in 0..n {
        // Centre the 0..17 ramp on zero before scaling it down.
        let centred = (idx % 17) as f32 - 8.0;
        data.push(centred * 0.05);
    }
    data
}
/// Runs `batched_matmul` on `[4, 8, 32, 64] x [4, 8, 64, 32]` inputs, checks the
/// output shape, and compares the leading `m * n` output elements against a naive
/// CPU product of the leading `m * k` / `k * n` input elements.
#[test]
fn test_batched_matmul() {
    init();
    let (batch, heads) = (4, 8);
    let (m, k, n) = (32, 64, 32);

    let lhs_host = make_data(batch * heads * m * k);
    let rhs_host = make_data(batch * heads * k * n);
    let lhs = Tensor::from_slice(&lhs_host, &[batch, heads, m, k]).to_device(Device::Cuda(0));
    let rhs = Tensor::from_slice(&rhs_host, &[batch, heads, k, n]).to_device(Device::Cuda(0));

    let product = batched_matmul(&lhs, &rhs).to_device(Device::Cpu);
    assert_eq!(product.shape(), &[batch, heads, m, n]);

    // Only the first matrix of the batch is verified on the CPU.
    let lhs0 = &lhs_host[..m * k];
    let rhs0 = &rhs_host[..k * n];
    let expected: Vec<f32> = (0..m * n)
        .map(|flat| {
            let (row, col) = (flat / n, flat % n);
            // Accumulate left to right starting from zero.
            (0..k).fold(0.0f32, |acc, kk| acc + lhs0[row * k + kk] * rhs0[kk * n + col])
        })
        .collect();

    let actual = product.as_slice::<f32>();
    check_close(&actual[..m * n], &expected, 1e-3, "batched_matmul[0]");
}
/// Compares `attention` without a causal mask against `cpu_attention` for a
/// `[1, 2, 8, 16]` problem, with an absolute tolerance of `1e-4`.
#[test]
fn test_attention_no_causal() {
    init();
    let (b, h, s, d) = (1, 2, 8, 16);
    let shape = [b, h, s, d];
    let elems = b * h * s * d;

    // Q, K and V all use the same deterministic pattern.
    let q_data = make_data(elems);
    let k_data = make_data(elems);
    let v_data = make_data(elems);

    let reference = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, false).to_device(Device::Cpu);

    check_close(out.as_slice::<f32>(), &reference, 1e-4, "attention_no_causal");
}
/// Compares `attention` with the causal mask enabled against `cpu_attention` for
/// a `[1, 2, 16, 32]` problem, with an absolute tolerance of `1e-3`.
#[test]
fn test_attention_causal() {
    init();
    let (b, h, s, d) = (1, 2, 16, 32);
    let shape = [b, h, s, d];
    let elems = b * h * s * d;

    let q_data = make_data(elems);
    let k_data = make_data(elems);
    let v_data = make_data(elems);

    let reference = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    check_close(
        out.as_slice::<f32>(),
        &reference,
        1e-3,
        "attention_causal",
    );
}
/// Causal `attention` check on a larger `[2, 4, 64, 64]` problem, compared with
/// `cpu_attention` at an absolute tolerance of `1e-2`.
#[test]
fn test_attention_causal_larger() {
    init();
    let (b, h, s, d) = (2, 4, 64, 64);
    let shape = [b, h, s, d];
    let elems = b * h * s * d;

    let q_data = make_data(elems);
    let k_data = make_data(elems);
    let v_data = make_data(elems);

    let reference = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    check_close(out.as_slice::<f32>(), &reference, 1e-2, "attention_causal_larger");
}
/// With the causal mask enabled, query position 0 can attend only to key
/// position 0, so the first output row must reproduce the first row of V.
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
    init();
    let (b, h, s, d) = (1, 1, 4, 8);
    let shape = [b, h, s, d];

    let q_data = make_data(b * h * s * d);
    let k_data = make_data(b * h * s * d);
    // V is zero everywhere except its first row, which is all ones.
    let mut v_data: Vec<f32> = vec![0.0; s * d];
    v_data[..d].fill(1.0);

    let q = Tensor::from_slice(&q_data, &shape).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &shape).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &shape).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    // Row 0 puts its entire softmax weight on token 0, hence the expected
    // output is V[0] = [1, 1, ..., 1].
    let values = out.as_slice::<f32>();
    for dim in 0..d {
        let got = values[dim];
        assert!(
            (got - 1.0).abs() < 1e-5,
            "first row should equal V[0], got {} at dim {}",
            got,
            dim
        );
    }
}