phase 4: transformer core kernels
CUDA kernels (csrc/): - common.cuh: shared warp_reduce_sum/max, block_reduce_sum/max - normalization/rmsnorm.cu: RMSNorm (F32 + BF16) - normalization/layernorm.cu: LayerNorm with Welford (F32 + BF16) - activation/activations.cu: GELU tanh-approx + SiLU (F32 + BF16) - reduce/softmax.cu: safe softmax, 3-pass (F32 + BF16) - embedding/embedding.cu: gather lookup (F32 + BF16) - embedding/rope.cu: RoPE in-place + precomputed cos/sin cache (F32 + BF16) Rust wrappers (xserv-kernels/src/): - rmsnorm.rs, layernorm.rs, activation.rs, softmax.rs, embedding.rs, rope.rs - RopeCache struct with GPU-side precomputation Tests: 12 new tests (ops_test.rs), all passing with good precision: - F32: max_err 1e-6 ~ 1e-9 - BF16: max_err 2e-3 ~ 7e-3 Total: 29 kernel tests + 27 prior = 56 tests passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -13,9 +13,16 @@ fn main() {
|
||||
.cuda(true)
|
||||
.cudart("shared")
|
||||
.flag("-gencode=arch=compute_120,code=sm_120")
|
||||
.include("../../csrc")
|
||||
.file("../../csrc/gemm/naive.cu")
|
||||
.file("../../csrc/gemm/tiled.cu")
|
||||
.compile("xserv_gemm_kernels");
|
||||
.file("../../csrc/normalization/rmsnorm.cu")
|
||||
.file("../../csrc/normalization/layernorm.cu")
|
||||
.file("../../csrc/activation/activations.cu")
|
||||
.file("../../csrc/reduce/softmax.cu")
|
||||
.file("../../csrc/embedding/embedding.cu")
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/gemm/");
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
}
|
||||
|
||||
41
crates/xserv-kernels/src/activation.rs
Normal file
41
crates/xserv-kernels/src/activation.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_gelu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn gelu(x: &Tensor) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_gelu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_gelu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype for gelu"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
pub fn silu(x: &Tensor) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_silu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_silu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype for silu"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
51
crates/xserv-kernels/src/embedding.rs
Normal file
51
crates/xserv-kernels/src/embedding.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
/// Embedding lookup: table[token_ids[i]] for each i.
|
||||
/// table: [vocab_size, hidden_size], token_ids: [num_tokens] (i32 on CPU)
|
||||
pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
assert_eq!(table.ndim(), 2);
|
||||
assert!(table.is_contiguous());
|
||||
assert!(matches!(table.device(), Device::Cuda(_)));
|
||||
|
||||
let hidden_size = table.shape()[1];
|
||||
let num_tokens = token_ids.len();
|
||||
|
||||
// Upload token_ids to GPU
|
||||
let ids_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
token_ids.as_ptr() as *const u8,
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut ids_gpu = GpuBuffer::alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||
|
||||
let out = Tensor::zeros(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
|
||||
unsafe {
|
||||
match table.dtype() {
|
||||
DType::F32 => launch_embedding_f32(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_embedding_bf16(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for embedding"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
39
crates/xserv-kernels/src/layernorm.rs
Normal file
39
crates/xserv-kernels/src/layernorm.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert!(x.is_contiguous() && gamma.is_contiguous() && beta.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden_size = *x.shape().last().unwrap();
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_layernorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_layernorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for layernorm"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
@@ -1,3 +1,15 @@
|
||||
pub mod activation;
|
||||
pub mod embedding;
|
||||
pub mod gemm;
|
||||
pub mod layernorm;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod softmax;
|
||||
|
||||
pub use gemm::{GemmBackend, matmul};
|
||||
pub use activation::{gelu, silu};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::rmsnorm;
|
||||
pub use rope::{rope_inplace, RopeCache};
|
||||
pub use softmax::softmax;
|
||||
|
||||
37
crates/xserv-kernels/src/rmsnorm.rs
Normal file
37
crates/xserv-kernels/src/rmsnorm.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert!(x.is_contiguous() && gamma.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden_size = *x.shape().last().unwrap();
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rmsnorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_rmsnorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rmsnorm"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
85
crates/xserv-kernels/src/rope.rs
Normal file
85
crates/xserv-kernels/src/rope.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
|
||||
max_seq_len: i32, half_dim: i32, theta: f32,
|
||||
stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub struct RopeCache {
|
||||
pub cos: GpuBuffer,
|
||||
pub sin: GpuBuffer,
|
||||
pub max_seq_len: usize,
|
||||
pub half_dim: usize,
|
||||
}
|
||||
|
||||
impl RopeCache {
|
||||
pub fn new(max_seq_len: usize, head_dim: usize, theta: f32) -> Self {
|
||||
let half_dim = head_dim / 2;
|
||||
let nbytes = max_seq_len * half_dim * std::mem::size_of::<f32>();
|
||||
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc cos_cache");
|
||||
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc sin_cache");
|
||||
|
||||
unsafe {
|
||||
launch_compute_rope_cache(
|
||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply RoPE in-place to x.
|
||||
/// x: [num_tokens, num_heads, head_dim] on GPU
|
||||
/// positions: [num_tokens] (u32 on CPU, will be uploaded)
|
||||
pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
||||
assert_eq!(x.ndim(), 3);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let num_tokens = x.shape()[0];
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[2];
|
||||
assert_eq!(head_dim / 2, cache.half_dim);
|
||||
assert_eq!(positions.len(), num_tokens);
|
||||
|
||||
let pos_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
positions.as_ptr() as *const u8,
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut pos_gpu = GpuBuffer::alloc(pos_bytes.len()).expect("alloc positions");
|
||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rope_f32(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_rope_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rope"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
34
crates/xserv-kernels/src/softmax.rs
Normal file
34
crates/xserv-kernels/src/softmax.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
/// Softmax along the last dimension.
|
||||
pub fn softmax(x: &Tensor) -> Tensor {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_softmax_f32(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_softmax_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for softmax"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
302
crates/xserv-kernels/tests/ops_test.rs
Normal file
302
crates/xserv-kernels/tests/ops_test.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
|
||||
// --- CPU reference implementations ---
|
||||
|
||||
fn cpu_rmsnorm(x: &[f32], gamma: &[f32], eps: f32, hidden: usize) -> Vec<f32> {
|
||||
let rows = x.len() / hidden;
|
||||
let mut out = vec![0.0f32; x.len()];
|
||||
for r in 0..rows {
|
||||
let row = &x[r * hidden..(r + 1) * hidden];
|
||||
let sum_sq: f32 = row.iter().map(|v| v * v).sum();
|
||||
let rms_inv = 1.0 / (sum_sq / hidden as f32 + eps).sqrt();
|
||||
for i in 0..hidden {
|
||||
out[r * hidden + i] = row[i] * rms_inv * gamma[i];
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize) -> Vec<f32> {
|
||||
let rows = x.len() / hidden;
|
||||
let mut out = vec![0.0f32; x.len()];
|
||||
for r in 0..rows {
|
||||
let row = &x[r * hidden..(r + 1) * hidden];
|
||||
let mean: f32 = row.iter().sum::<f32>() / hidden as f32;
|
||||
let var: f32 = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / hidden as f32;
|
||||
let inv_std = 1.0 / (var + eps).sqrt();
|
||||
for i in 0..hidden {
|
||||
out[r * hidden + i] = gamma[i] * (row[i] - mean) * inv_std + beta[i];
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
|
||||
let sqrt_2_over_pi = 0.7978845608f32;
|
||||
x.iter().map(|&v| {
|
||||
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
|
||||
0.5 * v * (1.0 + inner.tanh())
|
||||
}).collect()
|
||||
}
|
||||
|
||||
fn cpu_silu(x: &[f32]) -> Vec<f32> {
|
||||
x.iter().map(|&v| v / (1.0 + (-v).exp())).collect()
|
||||
}
|
||||
|
||||
fn cpu_softmax(x: &[f32], cols: usize) -> Vec<f32> {
|
||||
let rows = x.len() / cols;
|
||||
let mut out = vec![0.0f32; x.len()];
|
||||
for r in 0..rows {
|
||||
let row = &x[r * cols..(r + 1) * cols];
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = row.iter().map(|v| (v - max).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
for i in 0..cols {
|
||||
out[r * cols + i] = exps[i] / sum;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn cpu_rope(x: &mut [f32], positions: &[u32], num_heads: usize, head_dim: usize, theta: f32) {
|
||||
let half_dim = head_dim / 2;
|
||||
let num_tokens = positions.len();
|
||||
for t in 0..num_tokens {
|
||||
let pos = positions[t] as f32;
|
||||
for h in 0..num_heads {
|
||||
for i in 0..half_dim {
|
||||
let freq = 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32);
|
||||
let angle = pos * freq;
|
||||
let cos_val = angle.cos();
|
||||
let sin_val = angle.sin();
|
||||
let base = (t * num_heads + h) * head_dim;
|
||||
let x0 = x[base + 2 * i];
|
||||
let x1 = x[base + 2 * i + 1];
|
||||
x[base + 2 * i] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
|
||||
assert_eq!(result.len(), expected.len(), "{name}: length mismatch");
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let err = (r - e).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
|
||||
fn make_data(n: usize) -> Vec<f32> {
|
||||
(0..n).map(|i| ((i % 17) as f32 - 8.0) * 0.1).collect()
|
||||
}
|
||||
|
||||
// === RMSNorm ===
|
||||
|
||||
#[test]
|
||||
fn test_rmsnorm_f32() {
|
||||
init();
|
||||
let hidden = 768;
|
||||
let rows = 4;
|
||||
let x_data = make_data(rows * hidden);
|
||||
let gamma_data: Vec<f32> = (0..hidden).map(|i| 0.5 + (i % 3) as f32 * 0.2).collect();
|
||||
let expected = cpu_rmsnorm(&x_data, &gamma_data, 1e-5, hidden);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[rows, hidden]).to_device(Device::Cuda(0));
|
||||
let gamma = Tensor::from_slice(&gamma_data, &[hidden]).to_device(Device::Cuda(0));
|
||||
let out = rmsnorm(&x, &gamma, 1e-5).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "rmsnorm_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rmsnorm_bf16() {
|
||||
init();
|
||||
let hidden = 768;
|
||||
let rows = 4;
|
||||
let x_f32 = make_data(rows * hidden);
|
||||
let gamma_f32: Vec<f32> = (0..hidden).map(|i| 0.5 + (i % 3) as f32 * 0.2).collect();
|
||||
let expected = cpu_rmsnorm(&x_f32, &gamma_f32, 1e-5, hidden);
|
||||
|
||||
let x_bf16: Vec<bf16> = x_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let gamma_bf16: Vec<bf16> = gamma_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let x = Tensor::from_slice(&x_bf16, &[rows, hidden]).to_device(Device::Cuda(0));
|
||||
let gamma = Tensor::from_slice(&gamma_bf16, &[hidden]).to_device(Device::Cuda(0));
|
||||
let out = rmsnorm(&x, &gamma, 1e-5).to_device(Device::Cpu);
|
||||
|
||||
let result: Vec<f32> = out.as_slice::<bf16>().iter().map(|v| v.to_f32()).collect();
|
||||
check_close(&result, &expected, 0.05, "rmsnorm_bf16");
|
||||
}
|
||||
|
||||
// === LayerNorm ===
|
||||
|
||||
#[test]
|
||||
fn test_layernorm_f32() {
|
||||
init();
|
||||
let hidden = 768;
|
||||
let rows = 4;
|
||||
let x_data = make_data(rows * hidden);
|
||||
let gamma_data: Vec<f32> = (0..hidden).map(|i| 0.8 + (i % 5) as f32 * 0.1).collect();
|
||||
let beta_data: Vec<f32> = (0..hidden).map(|i| ((i % 7) as f32 - 3.0) * 0.01).collect();
|
||||
let expected = cpu_layernorm(&x_data, &gamma_data, &beta_data, 1e-5, hidden);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[rows, hidden]).to_device(Device::Cuda(0));
|
||||
let gamma = Tensor::from_slice(&gamma_data, &[hidden]).to_device(Device::Cuda(0));
|
||||
let beta = Tensor::from_slice(&beta_data, &[hidden]).to_device(Device::Cuda(0));
|
||||
let out = layernorm(&x, &gamma, &beta, 1e-5).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "layernorm_f32");
|
||||
}
|
||||
|
||||
// === GELU ===
|
||||
|
||||
#[test]
|
||||
fn test_gelu_f32() {
|
||||
init();
|
||||
let data = make_data(10000);
|
||||
let expected = cpu_gelu(&data);
|
||||
let x = Tensor::from_slice(&data, &[10000]).to_device(Device::Cuda(0));
|
||||
let out = gelu(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "gelu_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gelu_bf16() {
|
||||
init();
|
||||
let data_f32 = make_data(10000);
|
||||
let expected = cpu_gelu(&data_f32);
|
||||
let data_bf16: Vec<bf16> = data_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let x = Tensor::from_slice(&data_bf16, &[10000]).to_device(Device::Cuda(0));
|
||||
let out = gelu(&x).to_device(Device::Cpu);
|
||||
let result: Vec<f32> = out.as_slice::<bf16>().iter().map(|v| v.to_f32()).collect();
|
||||
check_close(&result, &expected, 0.02, "gelu_bf16");
|
||||
}
|
||||
|
||||
// === SiLU ===
|
||||
|
||||
#[test]
|
||||
fn test_silu_f32() {
|
||||
init();
|
||||
let data = make_data(10000);
|
||||
let expected = cpu_silu(&data);
|
||||
let x = Tensor::from_slice(&data, &[10000]).to_device(Device::Cuda(0));
|
||||
let out = silu(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "silu_f32");
|
||||
}
|
||||
|
||||
// === Softmax ===
|
||||
|
||||
#[test]
|
||||
fn test_softmax_f32() {
|
||||
init();
|
||||
let rows = 8;
|
||||
let cols = 256;
|
||||
let data = make_data(rows * cols);
|
||||
let expected = cpu_softmax(&data, cols);
|
||||
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "softmax_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_sum_to_one() {
|
||||
init();
|
||||
let rows = 4;
|
||||
let cols = 2048;
|
||||
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
|
||||
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
let result = out.as_slice::<f32>();
|
||||
for r in 0..rows {
|
||||
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
|
||||
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_large_values() {
|
||||
init();
|
||||
let data = vec![1000.0f32, 1001.0, 999.0, 1000.5];
|
||||
let expected = cpu_softmax(&data, 4);
|
||||
let x = Tensor::from_slice(&data, &[1, 4]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "softmax_large");
|
||||
}
|
||||
|
||||
// === Embedding ===
|
||||
|
||||
#[test]
|
||||
fn test_embedding_f32() {
|
||||
init();
|
||||
let vocab_size = 100;
|
||||
let hidden = 64;
|
||||
let table_data: Vec<f32> = (0..vocab_size * hidden).map(|i| i as f32 * 0.01).collect();
|
||||
let token_ids: Vec<u32> = vec![0, 5, 99, 42, 1];
|
||||
|
||||
let table = Tensor::from_slice(&table_data, &[vocab_size, hidden]).to_device(Device::Cuda(0));
|
||||
let out = embedding(&table, &token_ids).to_device(Device::Cpu);
|
||||
|
||||
assert_eq!(out.shape(), &[5, hidden]);
|
||||
let result = out.as_slice::<f32>();
|
||||
for (seq_idx, &tid) in token_ids.iter().enumerate() {
|
||||
for i in 0..hidden {
|
||||
let expected = table_data[tid as usize * hidden + i];
|
||||
let got = result[seq_idx * hidden + i];
|
||||
assert!((got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === RoPE ===
|
||||
|
||||
#[test]
|
||||
fn test_rope_f32() {
|
||||
init();
|
||||
let num_tokens = 4;
|
||||
let num_heads = 2;
|
||||
let head_dim = 8;
|
||||
let theta = 10000.0f32;
|
||||
let positions: Vec<u32> = vec![0, 1, 2, 3];
|
||||
|
||||
let x_data: Vec<f32> = (0..num_tokens * num_heads * head_dim)
|
||||
.map(|i| ((i % 13) as f32 - 6.0) * 0.1)
|
||||
.collect();
|
||||
let mut expected = x_data.clone();
|
||||
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, theta);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
let out = x.to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "rope_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rope_position_0_identity() {
|
||||
init();
|
||||
// At position 0, all angles are 0, so cos=1, sin=0 → identity transform
|
||||
let num_tokens = 1;
|
||||
let num_heads = 2;
|
||||
let head_dim = 8;
|
||||
let positions: Vec<u32> = vec![0];
|
||||
|
||||
let x_data: Vec<f32> = (0..num_tokens * num_heads * head_dim)
|
||||
.map(|i| (i as f32 + 1.0) * 0.1)
|
||||
.collect();
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, 10000.0);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
let out = x.to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &x_data, 1e-6, "rope_pos0");
|
||||
}
|
||||
Reference in New Issue
Block a user