use std::ffi::c_void;

use xserv_tensor::{DType, Device, Tensor};

unsafe extern "C" {
    fn launch_rmsnorm_f32(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );
    fn launch_rmsnorm_bf16(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );
}

/// Signature shared by every dtype-specific RMSNorm launcher declared in the extern block.
type RmsNormLauncher =
    unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, i32, f32, *mut c_void);

/// Applies RMSNorm over the last dimension of `x`, scaled by `gamma`, on a CUDA device.
///
/// `x` is treated as `rows = x.numel() / hidden_size` rows of `hidden_size` elements, where
/// `hidden_size` is the size of `x`'s last dimension. The work is delegated to the external
/// launcher matching the dtype (`launch_rmsnorm_f32` / `launch_rmsnorm_bf16`), invoked with a
/// null stream handle. The exact formula lives in that kernel — presumably
/// `x / sqrt(mean(x^2) + eps) * gamma`; TODO confirm against the kernel source.
///
/// Returns a new tensor with the same shape, dtype and device as `x`. An input with zero
/// elements yields an (empty) output without launching the kernel.
///
/// # Panics
///
/// - `x` has rank 0, or `x` / `gamma` is not contiguous.
/// - `x` or `gamma` is not on a CUDA device.
/// - `gamma`'s shape is not `[hidden_size]`, or `x` and `gamma` have different dtypes.
/// - The dtype is neither `F32` nor `BF16`.
/// - The row count or `hidden_size` does not fit in an `i32` (the kernel's index type).
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
    assert!(x.ndim() >= 1);
    assert!(x.is_contiguous() && gamma.is_contiguous());
    assert!(matches!(x.device(), Device::Cuda(_)));
    // The kernel reads gamma through a raw device pointer; a host-resident gamma would be an
    // illegal memory access instead of a clean panic.
    // NOTE(review): this does not verify x and gamma are on the *same* GPU — confirm whether
    // `Device` supports comparison and tighten if so.
    assert!(
        matches!(gamma.device(), Device::Cuda(_)),
        "rmsnorm: gamma must be on a CUDA device"
    );
    let hidden_size = *x.shape().last().expect("ndim >= 1 was asserted above");
    assert_eq!(gamma.shape(), &[hidden_size]);
    assert_eq!(x.dtype(), gamma.dtype());

    // Pick the launcher before allocating so an unsupported dtype fails without a GPU allocation.
    let launch: RmsNormLauncher = match x.dtype() {
        DType::F32 => launch_rmsnorm_f32,
        DType::BF16 => launch_rmsnorm_bf16,
        _ => panic!("unsupported dtype for rmsnorm"),
    };

    let out = Tensor::zeros(x.shape(), x.dtype(), x.device());

    // Nothing to normalize. This also avoids dividing by a zero-sized last dimension below and
    // launching a kernel with zero rows.
    if x.numel() == 0 {
        return out;
    }

    let rows = x.numel() / hidden_size;
    // The kernel takes i32 sizes; a plain `as` cast would silently truncate large tensors.
    let rows_i32 = i32::try_from(rows).expect("rmsnorm: row count exceeds i32::MAX");
    let hidden_i32 = i32::try_from(hidden_size).expect("rmsnorm: hidden_size exceeds i32::MAX");

    // SAFETY: x and gamma are contiguous CUDA tensors (asserted above). gamma holds exactly
    // `hidden_size` elements and x holds `rows * hidden_size` elements, all of the same dtype
    // the selected launcher expects. `out` was just allocated with x's shape, dtype and device,
    // and stays alive for the duration of the call. A null stream is passed, as before.
    unsafe {
        launch(
            x.data_ptr() as _,
            gamma.data_ptr() as _,
            out.data_ptr() as *mut c_void,
            rows_i32,
            hidden_i32,
            eps,
            std::ptr::null_mut(),
        );
    }
    out
}