233 lines
7.3 KiB
Rust
233 lines
7.3 KiB
Rust
use xserv_kernels::*;
|
|
use xserv_tensor::{Device, Tensor};
|
|
|
|
fn init() {
|
|
xserv_cuda::device::set_device(0).unwrap();
|
|
}
|
|
|
|
/// CPU reference implementation of scaled dot-product attention, used to validate the GPU
/// `attention` kernel.
///
/// `q` is indexed as a contiguous `[batch, heads, q_len, head_dim]` buffer and `k` / `v` as
/// contiguous `[batch, heads, kv_len, head_dim]` buffers; the returned buffer has `q`'s layout.
/// Each output row is `softmax(q_i · Kᵀ / sqrt(head_dim)) · V`.
///
/// When `causal` is true, query row `i` may attend only to key rows
/// `j <= i + (kv_len - q_len)`, i.e. the mask is aligned so the last query row sees every key.
///
/// # Panics
/// * if `q`, `k` or `v` does not hold exactly the number of elements implied by the dimensions;
/// * if `causal` is true and `q_len > kv_len` (leading query rows would see no key at all).
fn cpu_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    batch: usize,
    heads: usize,
    q_len: usize,
    kv_len: usize,
    head_dim: usize,
    causal: bool,
) -> Vec<f32> {
    let n_heads = batch * heads;
    let q_head = q_len * head_dim; // elements per (batch, head) slice of `q` / output
    let kv_head = kv_len * head_dim; // elements per (batch, head) slice of `k` / `v`
    assert_eq!(q.len(), n_heads * q_head, "cpu_attention: q length does not match dims");
    assert_eq!(k.len(), n_heads * kv_head, "cpu_attention: k length does not match dims");
    assert_eq!(v.len(), n_heads * kv_head, "cpu_attention: v length does not match dims");

    // Causal alignment offset. A plain `kv_len - q_len` underflows when q_len > kv_len
    // (overflow panic in debug, a wrapped and effectively disabled mask in release).
    let offset = if causal {
        kv_len
            .checked_sub(q_len)
            .expect("cpu_attention: causal mask requires q_len <= kv_len")
    } else {
        0
    };

    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; n_heads * q_head];
    // Attention weights of a single query row; reused for every row.
    let mut weights = vec![0.0f32; kv_len];

    for bh in 0..n_heads {
        let q_base = bh * q_head;
        let kv_base = bh * kv_head;
        for i in 0..q_len {
            let q_row = &q[q_base + i * head_dim..q_base + (i + 1) * head_dim];

            // Scaled scores q_i · k_j; keys past the causal horizon get -inf so that the
            // softmax below assigns them zero weight.
            for (j, w) in weights.iter_mut().enumerate() {
                *w = if causal && j > i + offset {
                    f32::NEG_INFINITY
                } else {
                    let k_row = &k[kv_base + j * head_dim..kv_base + (j + 1) * head_dim];
                    let dot = q_row
                        .iter()
                        .zip(k_row)
                        .fold(0.0f32, |acc, (a, b)| acc + a * b);
                    dot * scale
                };
            }

            // Numerically stable softmax: subtract the row max before exponentiating.
            let row_max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0f32;
            for w in weights.iter_mut() {
                *w = (*w - row_max).exp();
                sum += *w;
            }
            for w in weights.iter_mut() {
                *w /= sum;
            }

            // out_i[d] = sum_j w_j * v_j[d]
            let out_row = &mut out[q_base + i * head_dim..q_base + (i + 1) * head_dim];
            for (d, o) in out_row.iter_mut().enumerate() {
                *o = weights
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (j, &w)| acc + w * v[kv_base + j * head_dim + d]);
            }
        }
    }

    out
}
|
|
|
|
/// Asserts that `a` and `b` agree element-wise within the absolute tolerance `atol`, then
/// prints the largest absolute difference observed, labelled with `name`.
///
/// # Panics
/// If the slices differ in length, or at the first index whose absolute difference is not
/// `<= atol` (a NaN difference therefore also fails).
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
    assert_eq!(a.len(), b.len(), "{name}: length mismatch");

    let max_err = a
        .iter()
        .zip(b)
        .enumerate()
        .fold(0.0f32, |worst, (i, (x, y))| {
            let err = (x - y).abs();
            assert!(
                err <= atol,
                "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"
            );
            // `err` is a non-NaN, non-negative value here, so `max` is a plain comparison.
            worst.max(err)
        });

    println!("{name}: max_err = {max_err:.6e}");
}
|
|
|
|
/// Deterministic test data of length `n`: the 17-value ramp `(r - 8) * 0.05` for
/// `r = 0..17` (roughly -0.40 to 0.40 in steps of 0.05), repeated as often as needed.
fn make_data(n: usize) -> Vec<f32> {
    const PERIOD: usize = 17;
    (0..PERIOD)
        .cycle()
        .take(n)
        .map(|r| (r as f32 - 8.0) * 0.05)
        .collect()
}
|
|
|
|
/// `batched_matmul` of `[batch, heads, m, k]` x `[batch, heads, k, n]` CUDA tensors matches a
/// naive CPU matmul for every (batch, head) slice.
#[test]
fn test_batched_matmul() {
    init();
    let batch = 4;
    let heads = 8;
    let m = 32;
    let k = 64;
    let n = 32;

    let a_data = make_data(batch * heads * m * k);
    let b_data = make_data(batch * heads * k * n);

    let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0));
    let c = batched_matmul(&a, &b).to_device(Device::Cpu);

    assert_eq!(c.shape(), &[batch, heads, m, n]);

    // Reference for *every* (batch, head) slice, not only the first: checking slice 0 alone
    // would not catch a kernel that gets the batch/head stride wrong. `make_data` has period
    // 17, so consecutive slices hold different values.
    let mut expected = Vec::with_capacity(batch * heads * m * n);
    for (a_mat, b_mat) in a_data.chunks_exact(m * k).zip(b_data.chunks_exact(k * n)) {
        for a_row in a_mat.chunks_exact(k) {
            for j in 0..n {
                let dot = a_row
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (kk, &x)| acc + x * b_mat[kk * n + j]);
                expected.push(dot);
            }
        }
    }

    check_close(c.as_slice::<f32>(), &expected, 1e-3, "batched_matmul");
}
|
|
|
|
/// Non-causal GPU attention matches the CPU reference `cpu_attention`.
#[test]
fn test_attention_no_causal() {
    init();
    let b = 1;
    let h = 2;
    let s = 8;
    let d = 16;
    let n = b * h * s * d;

    // Q, K and V are deliberately distinct. With identical tensors Q·Kᵀ is symmetric and
    // V == K, so a kernel that transposes the score matrix or reads K in place of V would
    // still match the reference.
    let q_data = make_data(n);
    let k_data: Vec<f32> = make_data(n).into_iter().rev().collect();
    let mut v_data = make_data(n);
    v_data.rotate_left(5);
    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);

    let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, false).to_device(Device::Cpu);

    check_close(
        out.as_slice::<f32>(),
        &expected,
        1e-4,
        "attention_no_causal",
    );
}
|
|
|
|
/// Causal GPU attention matches the CPU reference `cpu_attention`.
#[test]
fn test_attention_causal() {
    init();
    let b = 1;
    let h = 2;
    let s = 16;
    let d = 32;
    let n = b * h * s * d;

    // Q, K and V are deliberately distinct. With identical tensors Q·Kᵀ is symmetric and
    // V == K, so a kernel that transposes the score matrix or reads K in place of V would
    // still match the reference.
    let q_data = make_data(n);
    let k_data: Vec<f32> = make_data(n).into_iter().rev().collect();
    let mut v_data = make_data(n);
    v_data.rotate_left(5);
    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);

    let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    check_close(out.as_slice::<f32>(), &expected, 1e-3, "attention_causal");
}
|
|
|
|
/// Causal GPU attention matches the CPU reference on a larger, multi-batch, multi-head case.
#[test]
fn test_attention_causal_larger() {
    init();
    let b = 2;
    let h = 4;
    let s = 64;
    let d = 64;
    let n = b * h * s * d;

    // Q, K and V are deliberately distinct. With identical tensors Q·Kᵀ is symmetric and
    // V == K, so a kernel that transposes the score matrix or reads K in place of V would
    // still match the reference.
    let q_data = make_data(n);
    let k_data: Vec<f32> = make_data(n).into_iter().rev().collect();
    let mut v_data = make_data(n);
    v_data.rotate_left(5);
    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);

    let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    check_close(
        out.as_slice::<f32>(),
        &expected,
        1e-2,
        "attention_causal_larger",
    );
}
|
|
|
|
/// Under a causal mask, query position 0 can attend only to key position 0. Its softmax
/// weight on that key is therefore 1, and its output row must equal V[0].
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
    init();
    let b = 1;
    let h = 1;
    let s = 4;
    let d = 8;
    let q_data = make_data(b * h * s * d);
    let k_data = make_data(b * h * s * d);
    // In every (batch, head) slice only the first V row (token 0) is nonzero, and it is all
    // ones. Sized from all four dims so the test stays valid if `b` or `h` is changed.
    let v_data: Vec<f32> = (0..b * h * s * d)
        .map(|i| if i % (s * d) < d { 1.0 } else { 0.0 })
        .collect();

    let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    assert_eq!(out.shape(), &[b, h, s, d]);

    // For each head, output row 0 should be V[0] = [1, 1, ..., 1].
    let result = out.as_slice::<f32>();
    for head in 0..b * h {
        let first_row = &result[head * s * d..][..d];
        for (i, &x) in first_row.iter().enumerate() {
            assert!(
                (x - 1.0).abs() < 1e-5,
                "first row should equal V[0], got {} at dim {} (head {})",
                x,
                i,
                head
            );
        }
    }
}
|