Files
xserv/crates/xserv-kernels/tests/gemm_test.rs

207 lines
5.6 KiB
Rust

use half::bf16;
use xserv_kernels::{GemmBackend, matmul};
use xserv_tensor::{Device, Tensor};
/// Host-side reference matrix multiply used to validate the GPU kernels.
///
/// `a` is read as an `m x k` row-major matrix and `b` as a `k x n` row-major
/// matrix. The returned vector is the `m x n` row-major product. Every output
/// element is accumulated in `f32`, starting from `0.0` and adding the products
/// in ascending order of the shared dimension.
///
/// # Panics
/// Panics on out-of-bounds indexing if `a` or `b` is shorter than the stated
/// dimensions require.
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    (0..m)
        .flat_map(|row| {
            (0..n).map(move |col| {
                (0..k).fold(0.0f32, |acc, inner| acc + a[row * k + inner] * b[inner * n + col])
            })
        })
        .collect()
}
/// Asserts that `result` matches `expected` element-wise within the absolute
/// tolerance `atol`.
///
/// # Panics
/// Panics if the two slices have different lengths. Otherwise it panics on the
/// first index whose absolute difference is not `<= atol`. That includes a NaN
/// difference, because NaN never compares `<=`.
fn check_close_f32(result: &[f32], expected: &[f32], atol: f32) {
    assert_eq!(result.len(), expected.len());
    result
        .iter()
        .zip(expected)
        .enumerate()
        .for_each(|(i, (r, e))| {
            let diff = (r - e).abs();
            assert!(
                diff <= atol,
                "mismatch at index {i}: got {r}, expected {e}, diff {}",
                diff
            );
        });
}
/// Asserts that the `bf16` values in `result`, widened to `f32`, match
/// `expected` element-wise within the absolute tolerance `atol`.
///
/// # Panics
/// Panics if the two slices have different lengths. Otherwise it panics on the
/// first index whose absolute difference (computed in `f32`) is not `<= atol`.
fn check_close_bf16(result: &[bf16], expected: &[f32], atol: f32) {
    assert_eq!(result.len(), expected.len());
    let widened = result.iter().map(|r| r.to_f32());
    for (i, (rv, e)) in widened.zip(expected).enumerate() {
        let diff = (rv - e).abs();
        assert!(
            diff <= atol,
            "mismatch at index {i}: got {rv}, expected {e}, diff {}",
            diff
        );
    }
}
/// Runs one `f32` GEMM with `backend` on CUDA device 0 and checks it against
/// the host reference `cpu_matmul_f32`.
///
/// `A` is `m x k` with values cycling through `-0.3..=0.3` in steps of `0.1`.
/// `B` is `k x n` with values cycling through `-0.5..=0.5`. The device result,
/// copied back to the host, must match within an absolute tolerance of `1e-4`.
fn run_gemm_test_f32(backend: GemmBackend, m: usize, n: usize, k: usize) {
    xserv_cuda::device::set_device(0).unwrap();

    let lhs_values: Vec<f32> = (0..m * k).map(|idx| ((idx % 7) as f32 - 3.0) * 0.1).collect();
    let rhs_values: Vec<f32> = (0..k * n).map(|idx| ((idx % 11) as f32 - 5.0) * 0.1).collect();
    let reference = cpu_matmul_f32(&lhs_values, &rhs_values, m, n, k);

    let lhs = Tensor::from_slice(&lhs_values, &[m, k]).to_device(Device::Cuda(0));
    let rhs = Tensor::from_slice(&rhs_values, &[k, n]).to_device(Device::Cuda(0));
    let product = matmul(&lhs, &rhs, backend);
    let host_product = product.to_device(Device::Cpu);

    check_close_f32(host_product.as_slice::<f32>(), &reference, 1e-4);
}
/// Runs one `bf16` GEMM with `backend` on CUDA device 0 and checks it against
/// an `f32` host reference.
///
/// The operands are generated in `f32` (`A` is `m x k`, `B` is `k x n`), then
/// rounded to `bf16` and uploaded. The host reference is computed from those
/// *rounded* values widened back to `f32`, not from the pre-rounding `f32`
/// data. This way the comparison measures the kernel's own accumulation and
/// output rounding rather than the input quantization, which no kernel can
/// undo. The result must match within an absolute tolerance of `0.1`.
fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
    xserv_cuda::device::set_device(0).unwrap();
    let a_f32: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
    let b_f32: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
    let a_data: Vec<bf16> = a_f32.iter().copied().map(bf16::from_f32).collect();
    let b_data: Vec<bf16> = b_f32.iter().copied().map(bf16::from_f32).collect();

    // Build the reference from exactly the values the device will see.
    let a_seen: Vec<f32> = a_data.iter().map(|v| v.to_f32()).collect();
    let b_seen: Vec<f32> = b_data.iter().map(|v| v.to_f32()).collect();
    let expected = cpu_matmul_f32(&a_seen, &b_seen, m, n, k);

    let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
    let c = matmul(&a, &b, backend);
    let c_cpu = c.to_device(Device::Cpu);
    check_close_bf16(c_cpu.as_slice::<bf16>(), &expected, 0.1);
}
// --- F32 tests ---
/// Naive kernel, `f32`, square 4x4 output with K = 4.
#[test]
fn test_gemm_naive_f32_small() {
    let (m, n, k) = (4, 4, 4);
    run_gemm_test_f32(GemmBackend::Naive, m, n, k);
}

/// Naive kernel, `f32`, square 64x64 output with K = 64.
#[test]
fn test_gemm_naive_f32_medium() {
    let (m, n, k) = (64, 64, 64);
    run_gemm_test_f32(GemmBackend::Naive, m, n, k);
}

/// Naive kernel, `f32`, non-square 32x64 output with K = 48.
#[test]
fn test_gemm_naive_f32_rect() {
    let (m, n, k) = (32, 64, 48);
    run_gemm_test_f32(GemmBackend::Naive, m, n, k);
}

/// Tiled kernel, `f32`, square 4x4 output with K = 4.
#[test]
fn test_gemm_tiled_f32_small() {
    let (m, n, k) = (4, 4, 4);
    run_gemm_test_f32(GemmBackend::Tiled, m, n, k);
}

/// Tiled kernel, `f32`, square 128x128 output with K = 128.
#[test]
fn test_gemm_tiled_f32_medium() {
    let (m, n, k) = (128, 128, 128);
    run_gemm_test_f32(GemmBackend::Tiled, m, n, k);
}

/// Tiled kernel, `f32`, odd-sized 65x33 output with K = 97.
#[test]
fn test_gemm_tiled_f32_rect() {
    let (m, n, k) = (65, 33, 97);
    run_gemm_test_f32(GemmBackend::Tiled, m, n, k);
}

/// cuBLAS backend, `f32`, square 4x4 output with K = 4.
#[test]
fn test_gemm_cublas_f32_small() {
    let (m, n, k) = (4, 4, 4);
    run_gemm_test_f32(GemmBackend::CuBlas, m, n, k);
}

/// cuBLAS backend, `f32`, square 256x256 output with K = 256.
#[test]
fn test_gemm_cublas_f32_medium() {
    let (m, n, k) = (256, 256, 256);
    run_gemm_test_f32(GemmBackend::CuBlas, m, n, k);
}

/// cuBLAS backend, `f32`, odd-sized 65x33 output with K = 97.
#[test]
fn test_gemm_cublas_f32_rect() {
    let (m, n, k) = (65, 33, 97);
    run_gemm_test_f32(GemmBackend::CuBlas, m, n, k);
}
// --- BF16 tests ---
/// Naive kernel, `bf16`, square 4x4 output with K = 4.
#[test]
fn test_gemm_naive_bf16_small() {
    run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4);
}

/// Naive kernel, `bf16`, square 64x64 output with K = 64.
#[test]
fn test_gemm_naive_bf16_medium() {
    run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64);
}

/// Naive kernel, `bf16`, non-square 32x64 output with K = 48.
/// Mirrors `test_gemm_naive_f32_rect` so both dtypes cover the same shapes.
#[test]
fn test_gemm_naive_bf16_rect() {
    run_gemm_test_bf16(GemmBackend::Naive, 32, 64, 48);
}

/// Tiled kernel, `bf16`, square 4x4 output with K = 4.
#[test]
fn test_gemm_tiled_bf16_small() {
    run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4);
}

/// Tiled kernel, `bf16`, square 128x128 output with K = 128.
#[test]
fn test_gemm_tiled_bf16_medium() {
    run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128);
}

/// Tiled kernel, `bf16`, odd-sized 65x33 output with K = 97.
/// Mirrors `test_gemm_tiled_f32_rect`. The other bf16 tiled shapes are all
/// powers of two, so without this case partial-tile edges go untested for bf16.
#[test]
fn test_gemm_tiled_bf16_rect() {
    run_gemm_test_bf16(GemmBackend::Tiled, 65, 33, 97);
}

/// cuBLAS backend, `bf16`, square 4x4 output with K = 4.
#[test]
fn test_gemm_cublas_bf16_small() {
    run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4);
}

/// cuBLAS backend, `bf16`, square 256x256 output with K = 256.
#[test]
fn test_gemm_cublas_bf16_medium() {
    run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256);
}

/// cuBLAS backend, `bf16`, odd-sized 65x33 output with K = 97.
/// Mirrors `test_gemm_cublas_f32_rect`.
#[test]
fn test_gemm_cublas_bf16_rect() {
    run_gemm_test_bf16(GemmBackend::CuBlas, 65, 33, 97);
}
// --- Custom GEMV tests (M=1, BF16 fast path) ---
/// `bf16` matrix-vector case (M = 1) through the cuBLAS backend, N = K = 64.
#[test]
fn test_gemv_bf16_small() {
    let (m, n, k) = (1, 64, 64);
    run_gemm_test_bf16(GemmBackend::CuBlas, m, n, k);
}

/// `bf16` matrix-vector case (M = 1) through the cuBLAS backend, N = K = 256.
#[test]
fn test_gemv_bf16_medium() {
    let (m, n, k) = (1, 256, 256);
    run_gemm_test_bf16(GemmBackend::CuBlas, m, n, k);
}

/// `bf16` matrix-vector case (M = 1) through the cuBLAS backend, N = K = 4096.
#[test]
fn test_gemv_bf16_4096() {
    let (m, n, k) = (1, 4096, 4096);
    run_gemm_test_bf16(GemmBackend::CuBlas, m, n, k);
}

/// `bf16` matrix-vector case (M = 1) through the cuBLAS backend, N = 512, K = 4096.
#[test]
fn test_gemv_bf16_rect() {
    let (m, n, k) = (1, 512, 4096);
    run_gemm_test_bf16(GemmBackend::CuBlas, m, n, k);
}
// --- Larger benchmark-style tests ---
/// cuBLAS backend, `f32`, square 1024x1024 output with K = 1024.
#[test]
fn test_gemm_cublas_f32_1024() {
    let dim = 1024;
    run_gemm_test_f32(GemmBackend::CuBlas, dim, dim, dim);
}
#[test]
fn test_gemm_consistency_all_backends() {
xserv_cuda::device::set_device(0).unwrap();
let m = 64;
let n = 64;
let k = 64;
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c_naive = matmul(&a, &b, GemmBackend::Naive).to_device(Device::Cpu);
let c_tiled = matmul(&a, &b, GemmBackend::Tiled).to_device(Device::Cpu);
let c_cublas = matmul(&a, &b, GemmBackend::CuBlas).to_device(Device::Cpu);
let naive = c_naive.as_slice::<f32>();
let tiled = c_tiled.as_slice::<f32>();
let cublas = c_cublas.as_slice::<f32>();
check_close_f32(naive, cublas, 1e-4);
check_close_f32(tiled, cublas, 1e-4);
}