phase 4: transformer core kernels

CUDA kernels (csrc/):
- common.cuh: shared warp_reduce_sum/max, block_reduce_sum/max
- normalization/rmsnorm.cu: RMSNorm (F32 + BF16)
- normalization/layernorm.cu: LayerNorm with Welford (F32 + BF16)
- activation/activations.cu: GELU tanh-approx + SiLU (F32 + BF16)
- reduce/softmax.cu: safe softmax, 3-pass (F32 + BF16)
- embedding/embedding.cu: gather lookup (F32 + BF16)
- embedding/rope.cu: RoPE in-place + precomputed cos/sin cache (F32 + BF16)

Rust wrappers (xserv-kernels/src/):
- rmsnorm.rs, layernorm.rs, activation.rs, softmax.rs, embedding.rs, rope.rs
- RopeCache struct with GPU-side precomputation

Tests: 12 new tests (ops_test.rs), all passing with good precision:
- F32: max_err 1e-6 ~ 1e-9
- BF16: max_err 2e-3 ~ 7e-3
Total: 29 kernel tests + 27 prior = 56 tests passing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 21:07:24 +08:00
parent 51a0f2eb14
commit c8e8153702
17 changed files with 1402 additions and 3 deletions

View File

@@ -13,9 +13,16 @@ fn main() {
.cuda(true)
.cudart("shared")
.flag("-gencode=arch=compute_120,code=sm_120")
.include("../../csrc")
.file("../../csrc/gemm/naive.cu")
.file("../../csrc/gemm/tiled.cu")
.compile("xserv_gemm_kernels");
.file("../../csrc/normalization/rmsnorm.cu")
.file("../../csrc/normalization/layernorm.cu")
.file("../../csrc/activation/activations.cu")
.file("../../csrc/reduce/softmax.cu")
.file("../../csrc/embedding/embedding.cu")
.file("../../csrc/embedding/rope.cu")
.compile("xserv_kernels");
println!("cargo:rerun-if-changed=../../csrc/gemm/");
println!("cargo:rerun-if-changed=../../csrc/");
}

View File

@@ -0,0 +1,41 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_gelu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
}
pub fn gelu(x: &Tensor) -> Tensor {
assert!(x.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let n = x.numel() as i32;
unsafe {
match x.dtype() {
DType::F32 => launch_gelu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
DType::BF16 => launch_gelu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
_ => panic!("unsupported dtype for gelu"),
}
}
xserv_cuda::device::synchronize().unwrap();
out
}
pub fn silu(x: &Tensor) -> Tensor {
assert!(x.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let n = x.numel() as i32;
unsafe {
match x.dtype() {
DType::F32 => launch_silu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
DType::BF16 => launch_silu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
_ => panic!("unsupported dtype for silu"),
}
}
xserv_cuda::device::synchronize().unwrap();
out
}

View File

@@ -0,0 +1,51 @@
use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
}
/// Embedding lookup: table[token_ids[i]] for each i.
/// table: [vocab_size, hidden_size], token_ids: [num_tokens] (i32 on CPU)
pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
assert_eq!(table.ndim(), 2);
assert!(table.is_contiguous());
assert!(matches!(table.device(), Device::Cuda(_)));
let hidden_size = table.shape()[1];
let num_tokens = token_ids.len();
// Upload token_ids to GPU
let ids_bytes = unsafe {
std::slice::from_raw_parts(
token_ids.as_ptr() as *const u8,
num_tokens * std::mem::size_of::<u32>(),
)
};
let mut ids_gpu = GpuBuffer::alloc(ids_bytes.len()).expect("alloc token_ids");
ids_gpu.copy_from_host(ids_bytes).unwrap();
let out = Tensor::zeros(&[num_tokens, hidden_size], table.dtype(), table.device());
unsafe {
match table.dtype() {
DType::F32 => launch_embedding_f32(
table.data_ptr() as _, ids_gpu.as_ptr() as _,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
),
DType::BF16 => launch_embedding_bf16(
table.data_ptr() as _, ids_gpu.as_ptr() as _,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
),
_ => panic!("unsupported dtype for embedding"),
}
}
xserv_cuda::device::synchronize().unwrap();
out
}

View File

@@ -0,0 +1,39 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
}
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
assert!(x.ndim() >= 1);
assert!(x.is_contiguous() && gamma.is_contiguous() && beta.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let hidden_size = *x.shape().last().unwrap();
assert_eq!(gamma.shape(), &[hidden_size]);
assert_eq!(beta.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_layernorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
),
DType::BF16 => launch_layernorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
),
_ => panic!("unsupported dtype for layernorm"),
}
}
xserv_cuda::device::synchronize().unwrap();
out
}

View File

@@ -1,3 +1,15 @@
pub mod activation;
pub mod embedding;
pub mod gemm;
pub mod layernorm;
pub mod rmsnorm;
pub mod rope;
pub mod softmax;
pub use gemm::{GemmBackend, matmul};
pub use activation::{gelu, silu};
pub use embedding::embedding;
pub use gemm::{matmul, GemmBackend};
pub use layernorm::layernorm;
pub use rmsnorm::rmsnorm;
pub use rope::{rope_inplace, RopeCache};
pub use softmax::softmax;

View File

@@ -0,0 +1,37 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
}
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
assert!(x.ndim() >= 1);
assert!(x.is_contiguous() && gamma.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let hidden_size = *x.shape().last().unwrap();
assert_eq!(gamma.shape(), &[hidden_size]);
assert_eq!(x.dtype(), gamma.dtype());
let rows = x.numel() / hidden_size;
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_rmsnorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
),
DType::BF16 => launch_rmsnorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
),
_ => panic!("unsupported dtype for rmsnorm"),
}
}
xserv_cuda::device::synchronize().unwrap();
out
}

View File

@@ -0,0 +1,85 @@
use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
positions: *const c_void, num_tokens: i32, num_heads: i32,
head_dim: i32, stream: *mut c_void);
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
max_seq_len: i32, half_dim: i32, theta: f32,
stream: *mut c_void);
}
pub struct RopeCache {
pub cos: GpuBuffer,
pub sin: GpuBuffer,
pub max_seq_len: usize,
pub half_dim: usize,
}
impl RopeCache {
pub fn new(max_seq_len: usize, head_dim: usize, theta: f32) -> Self {
let half_dim = head_dim / 2;
let nbytes = max_seq_len * half_dim * std::mem::size_of::<f32>();
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc cos_cache");
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc sin_cache");
unsafe {
launch_compute_rope_cache(
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
);
}
xserv_cuda::device::synchronize().unwrap();
Self { cos, sin, max_seq_len, half_dim }
}
}
/// Apply RoPE in-place to x.
/// x: [num_tokens, num_heads, head_dim] on GPU
/// positions: [num_tokens] (u32 on CPU, will be uploaded)
pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
assert_eq!(x.ndim(), 3);
assert!(x.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let num_tokens = x.shape()[0];
let num_heads = x.shape()[1];
let head_dim = x.shape()[2];
assert_eq!(head_dim / 2, cache.half_dim);
assert_eq!(positions.len(), num_tokens);
let pos_bytes = unsafe {
std::slice::from_raw_parts(
positions.as_ptr() as *const u8,
num_tokens * std::mem::size_of::<u32>(),
)
};
let mut pos_gpu = GpuBuffer::alloc(pos_bytes.len()).expect("alloc positions");
pos_gpu.copy_from_host(pos_bytes).unwrap();
unsafe {
match x.dtype() {
DType::F32 => launch_rope_f32(
x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
pos_gpu.as_ptr() as _,
num_tokens as i32, num_heads as i32, head_dim as i32,
std::ptr::null_mut(),
),
DType::BF16 => launch_rope_bf16(
x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
pos_gpu.as_ptr() as _,
num_tokens as i32, num_heads as i32, head_dim as i32,
std::ptr::null_mut(),
),
_ => panic!("unsupported dtype for rope"),
}
}
xserv_cuda::device::synchronize().unwrap();
}

View File

@@ -0,0 +1,34 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
unsafe extern "C" {
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
}
/// Softmax along the last dimension.
pub fn softmax(x: &Tensor) -> Tensor {
assert!(x.ndim() >= 1);
assert!(x.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let cols = *x.shape().last().unwrap();
let rows = x.numel() / cols;
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
DType::F32 => launch_softmax_f32(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(),
),
DType::BF16 => launch_softmax_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(),
),
_ => panic!("unsupported dtype for softmax"),
}
}
xserv_cuda::device::synchronize().unwrap();
out
}

View File

@@ -0,0 +1,302 @@
use half::bf16;
use xserv_kernels::*;
use xserv_tensor::{Device, Tensor};
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
// --- CPU reference implementations ---
fn cpu_rmsnorm(x: &[f32], gamma: &[f32], eps: f32, hidden: usize) -> Vec<f32> {
let rows = x.len() / hidden;
let mut out = vec![0.0f32; x.len()];
for r in 0..rows {
let row = &x[r * hidden..(r + 1) * hidden];
let sum_sq: f32 = row.iter().map(|v| v * v).sum();
let rms_inv = 1.0 / (sum_sq / hidden as f32 + eps).sqrt();
for i in 0..hidden {
out[r * hidden + i] = row[i] * rms_inv * gamma[i];
}
}
out
}
fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize) -> Vec<f32> {
let rows = x.len() / hidden;
let mut out = vec![0.0f32; x.len()];
for r in 0..rows {
let row = &x[r * hidden..(r + 1) * hidden];
let mean: f32 = row.iter().sum::<f32>() / hidden as f32;
let var: f32 = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / hidden as f32;
let inv_std = 1.0 / (var + eps).sqrt();
for i in 0..hidden {
out[r * hidden + i] = gamma[i] * (row[i] - mean) * inv_std + beta[i];
}
}
out
}
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
let sqrt_2_over_pi = 0.7978845608f32;
x.iter().map(|&v| {
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
0.5 * v * (1.0 + inner.tanh())
}).collect()
}
fn cpu_silu(x: &[f32]) -> Vec<f32> {
x.iter().map(|&v| v / (1.0 + (-v).exp())).collect()
}
fn cpu_softmax(x: &[f32], cols: usize) -> Vec<f32> {
let rows = x.len() / cols;
let mut out = vec![0.0f32; x.len()];
for r in 0..rows {
let row = &x[r * cols..(r + 1) * cols];
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = row.iter().map(|v| (v - max).exp()).collect();
let sum: f32 = exps.iter().sum();
for i in 0..cols {
out[r * cols + i] = exps[i] / sum;
}
}
out
}
fn cpu_rope(x: &mut [f32], positions: &[u32], num_heads: usize, head_dim: usize, theta: f32) {
let half_dim = head_dim / 2;
let num_tokens = positions.len();
for t in 0..num_tokens {
let pos = positions[t] as f32;
for h in 0..num_heads {
for i in 0..half_dim {
let freq = 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32);
let angle = pos * freq;
let cos_val = angle.cos();
let sin_val = angle.sin();
let base = (t * num_heads + h) * head_dim;
let x0 = x[base + 2 * i];
let x1 = x[base + 2 * i + 1];
x[base + 2 * i] = x0 * cos_val - x1 * sin_val;
x[base + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
}
}
}
}
fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
assert_eq!(result.len(), expected.len(), "{name}: length mismatch");
let mut max_err = 0.0f32;
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
let err = (r - e).abs();
if err > max_err { max_err = err; }
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
}
println!("{name}: max_err = {max_err:.6e}");
}
fn make_data(n: usize) -> Vec<f32> {
(0..n).map(|i| ((i % 17) as f32 - 8.0) * 0.1).collect()
}
// === RMSNorm ===
#[test]
fn test_rmsnorm_f32() {
init();
let hidden = 768;
let rows = 4;
let x_data = make_data(rows * hidden);
let gamma_data: Vec<f32> = (0..hidden).map(|i| 0.5 + (i % 3) as f32 * 0.2).collect();
let expected = cpu_rmsnorm(&x_data, &gamma_data, 1e-5, hidden);
let x = Tensor::from_slice(&x_data, &[rows, hidden]).to_device(Device::Cuda(0));
let gamma = Tensor::from_slice(&gamma_data, &[hidden]).to_device(Device::Cuda(0));
let out = rmsnorm(&x, &gamma, 1e-5).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-4, "rmsnorm_f32");
}
#[test]
fn test_rmsnorm_bf16() {
init();
let hidden = 768;
let rows = 4;
let x_f32 = make_data(rows * hidden);
let gamma_f32: Vec<f32> = (0..hidden).map(|i| 0.5 + (i % 3) as f32 * 0.2).collect();
let expected = cpu_rmsnorm(&x_f32, &gamma_f32, 1e-5, hidden);
let x_bf16: Vec<bf16> = x_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let gamma_bf16: Vec<bf16> = gamma_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let x = Tensor::from_slice(&x_bf16, &[rows, hidden]).to_device(Device::Cuda(0));
let gamma = Tensor::from_slice(&gamma_bf16, &[hidden]).to_device(Device::Cuda(0));
let out = rmsnorm(&x, &gamma, 1e-5).to_device(Device::Cpu);
let result: Vec<f32> = out.as_slice::<bf16>().iter().map(|v| v.to_f32()).collect();
check_close(&result, &expected, 0.05, "rmsnorm_bf16");
}
// === LayerNorm ===
#[test]
fn test_layernorm_f32() {
init();
let hidden = 768;
let rows = 4;
let x_data = make_data(rows * hidden);
let gamma_data: Vec<f32> = (0..hidden).map(|i| 0.8 + (i % 5) as f32 * 0.1).collect();
let beta_data: Vec<f32> = (0..hidden).map(|i| ((i % 7) as f32 - 3.0) * 0.01).collect();
let expected = cpu_layernorm(&x_data, &gamma_data, &beta_data, 1e-5, hidden);
let x = Tensor::from_slice(&x_data, &[rows, hidden]).to_device(Device::Cuda(0));
let gamma = Tensor::from_slice(&gamma_data, &[hidden]).to_device(Device::Cuda(0));
let beta = Tensor::from_slice(&beta_data, &[hidden]).to_device(Device::Cuda(0));
let out = layernorm(&x, &gamma, &beta, 1e-5).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-4, "layernorm_f32");
}
// === GELU ===
#[test]
fn test_gelu_f32() {
init();
let data = make_data(10000);
let expected = cpu_gelu(&data);
let x = Tensor::from_slice(&data, &[10000]).to_device(Device::Cuda(0));
let out = gelu(&x).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-5, "gelu_f32");
}
#[test]
fn test_gelu_bf16() {
init();
let data_f32 = make_data(10000);
let expected = cpu_gelu(&data_f32);
let data_bf16: Vec<bf16> = data_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let x = Tensor::from_slice(&data_bf16, &[10000]).to_device(Device::Cuda(0));
let out = gelu(&x).to_device(Device::Cpu);
let result: Vec<f32> = out.as_slice::<bf16>().iter().map(|v| v.to_f32()).collect();
check_close(&result, &expected, 0.02, "gelu_bf16");
}
// === SiLU ===
#[test]
fn test_silu_f32() {
init();
let data = make_data(10000);
let expected = cpu_silu(&data);
let x = Tensor::from_slice(&data, &[10000]).to_device(Device::Cuda(0));
let out = silu(&x).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-5, "silu_f32");
}
// === Softmax ===
#[test]
fn test_softmax_f32() {
init();
let rows = 8;
let cols = 256;
let data = make_data(rows * cols);
let expected = cpu_softmax(&data, cols);
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
let out = softmax(&x).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-5, "softmax_f32");
}
#[test]
fn test_softmax_sum_to_one() {
init();
let rows = 4;
let cols = 2048;
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
let out = softmax(&x).to_device(Device::Cpu);
let result = out.as_slice::<f32>();
for r in 0..rows {
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
}
}
#[test]
fn test_softmax_large_values() {
init();
let data = vec![1000.0f32, 1001.0, 999.0, 1000.5];
let expected = cpu_softmax(&data, 4);
let x = Tensor::from_slice(&data, &[1, 4]).to_device(Device::Cuda(0));
let out = softmax(&x).to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-5, "softmax_large");
}
// === Embedding ===
#[test]
fn test_embedding_f32() {
init();
let vocab_size = 100;
let hidden = 64;
let table_data: Vec<f32> = (0..vocab_size * hidden).map(|i| i as f32 * 0.01).collect();
let token_ids: Vec<u32> = vec![0, 5, 99, 42, 1];
let table = Tensor::from_slice(&table_data, &[vocab_size, hidden]).to_device(Device::Cuda(0));
let out = embedding(&table, &token_ids).to_device(Device::Cpu);
assert_eq!(out.shape(), &[5, hidden]);
let result = out.as_slice::<f32>();
for (seq_idx, &tid) in token_ids.iter().enumerate() {
for i in 0..hidden {
let expected = table_data[tid as usize * hidden + i];
let got = result[seq_idx * hidden + i];
assert!((got - expected).abs() < 1e-6,
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
}
}
}
// === RoPE ===
#[test]
fn test_rope_f32() {
init();
let num_tokens = 4;
let num_heads = 2;
let head_dim = 8;
let theta = 10000.0f32;
let positions: Vec<u32> = vec![0, 1, 2, 3];
let x_data: Vec<f32> = (0..num_tokens * num_heads * head_dim)
.map(|i| ((i % 13) as f32 - 6.0) * 0.1)
.collect();
let mut expected = x_data.clone();
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
.to_device(Device::Cuda(0));
let cache = RopeCache::new(64, head_dim, theta);
rope_inplace(&x, &cache, &positions);
let out = x.to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &expected, 1e-4, "rope_f32");
}
#[test]
fn test_rope_position_0_identity() {
init();
// At position 0, all angles are 0, so cos=1, sin=0 → identity transform
let num_tokens = 1;
let num_heads = 2;
let head_dim = 8;
let positions: Vec<u32> = vec![0];
let x_data: Vec<f32> = (0..num_tokens * num_heads * head_dim)
.map(|i| (i as f32 + 1.0) * 0.1)
.collect();
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
.to_device(Device::Cuda(0));
let cache = RopeCache::new(64, head_dim, 10000.0);
rope_inplace(&x, &cache, &positions);
let out = x.to_device(Device::Cpu);
check_close(out.as_slice::<f32>(), &x_data, 1e-6, "rope_pos0");
}