Kernel additions:
- add_f32/bf16 and mul_f32/bf16 CUDA kernels (element-wise, on GPU)
- Refactored activation.rs with dispatch_unary/dispatch_binary helpers
- Qwen3 and GPT-2 now use GPU add/mul instead of CPU round-trips; GPT-2 add_bias has also moved to the GPU (broadcast via tile + GPU add)

BF16 precision analysis (docs/benchmarks/phase10-qwen3.md):
- Root cause: the separate attention kernels materialize BF16 intermediates (QK^T→BF16→scale→BF16→mask→BF16→softmax→BF16), whereas HF uses a fused FP32 path
- HF's own SDPA and eager attention paths also differ from each other by ~0.125 logit
- xserv vs. HF: a systematic offset of ~1-2 logits, but the same top-1 token in 84% of cases
- Industry standard for BF16: top-5 overlap (we achieve 100%)
- Fix path: Flash Attention (Phase 14) to fuse attention in FP32

Performance: TTFT 138→119 ms, TBT 144→137 ms (GPU ops are faster than CPU round-trips)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
73 lines
3.8 KiB
Rust
73 lines
3.8 KiB
Rust
use std::ffi::c_void;
|
|
use xserv_tensor::{DType, Device, Tensor};
|
|
|
|
// Launch entry points for the element-wise CUDA kernels, which are compiled
// and linked from outside this file.
//
// Every launcher receives raw pointers to the input buffer(s) and the output
// buffer, the number of elements `n`, and a stream handle. The wrappers in this
// file always pass a null stream.
unsafe extern "C" {
    // Unary activations: one input buffer, one output buffer.
    fn launch_gelu_f32(src: *const c_void, dst: *mut c_void, n: i32, stream: *mut c_void);
    fn launch_gelu_bf16(src: *const c_void, dst: *mut c_void, n: i32, stream: *mut c_void);
    fn launch_silu_f32(src: *const c_void, dst: *mut c_void, n: i32, stream: *mut c_void);
    fn launch_silu_bf16(src: *const c_void, dst: *mut c_void, n: i32, stream: *mut c_void);

    // Scalar scaling: the factor is always passed as an f32, even for BF16 data.
    fn launch_scale_f32(
        src: *const c_void,
        dst: *mut c_void,
        factor: f32,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_scale_bf16(
        src: *const c_void,
        dst: *mut c_void,
        factor: f32,
        n: i32,
        stream: *mut c_void,
    );

    // Binary element-wise ops: two input buffers of equal length, one output buffer.
    fn launch_add_f32(
        lhs: *const c_void,
        rhs: *const c_void,
        dst: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_add_bf16(
        lhs: *const c_void,
        rhs: *const c_void,
        dst: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_mul_f32(
        lhs: *const c_void,
        rhs: *const c_void,
        dst: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_mul_bf16(
        lhs: *const c_void,
        rhs: *const c_void,
        dst: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
}
|
|
|
|
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
|
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
|
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
|
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
|
let n = x.numel() as i32;
|
|
unsafe {
|
|
match x.dtype() {
|
|
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
|
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
|
_ => panic!("unsupported dtype"),
|
|
}
|
|
}
|
|
xserv_cuda::device::synchronize().unwrap();
|
|
out
|
|
}
|
|
|
|
fn dispatch_binary(a: &Tensor, b: &Tensor,
|
|
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
|
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
|
assert_eq!(a.shape(), b.shape());
|
|
assert!(a.is_contiguous() && b.is_contiguous());
|
|
assert!(matches!(a.device(), Device::Cuda(_)));
|
|
assert_eq!(a.dtype(), b.dtype());
|
|
let out = Tensor::zeros(a.shape(), a.dtype(), a.device());
|
|
let n = a.numel() as i32;
|
|
unsafe {
|
|
match a.dtype() {
|
|
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
|
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
|
_ => panic!("unsupported dtype"),
|
|
}
|
|
}
|
|
xserv_cuda::device::synchronize().unwrap();
|
|
out
|
|
}
|
|
|
|
/// Applies GELU element-wise on the GPU and returns the result as a new tensor
/// with the same shape, dtype and device as `x`.
///
/// # Panics
///
/// Panics unless `x` is a contiguous CUDA tensor with dtype F32 or BF16.
pub fn gelu(x: &Tensor) -> Tensor {
    dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
}
|
|
/// Applies SiLU element-wise on the GPU and returns the result as a new tensor
/// with the same shape, dtype and device as `x`.
///
/// # Panics
///
/// Panics unless `x` is a contiguous CUDA tensor with dtype F32 or BF16.
pub fn silu(x: &Tensor) -> Tensor {
    dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
}
|
|
|
|
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
|
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
|
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
|
let n = x.numel() as i32;
|
|
unsafe {
|
|
match x.dtype() {
|
|
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
|
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
|
_ => panic!("unsupported dtype for scale"),
|
|
}
|
|
}
|
|
xserv_cuda::device::synchronize().unwrap();
|
|
out
|
|
}
|
|
|
|
/// Adds `a` and `b` element-wise on the GPU, returning a new tensor with `a`'s
/// shape, dtype and device.
///
/// There is no broadcasting: the operands must have identical shapes.
///
/// # Panics
///
/// Panics on shape or dtype mismatch, on non-contiguous operands, if `a` is not
/// on a CUDA device, or if the dtype is neither F32 nor BF16.
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
    dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
}
|
|
/// Multiplies `a` and `b` element-wise on the GPU (via the `launch_mul_*`
/// kernels), returning a new tensor with `a`'s shape, dtype and device.
///
/// There is no broadcasting: the operands must have identical shapes.
///
/// # Panics
///
/// Panics on shape or dtype mismatch, on non-contiguous operands, if `a` is not
/// on a CUDA device, or if the dtype is neither F32 nor BF16.
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
    dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
}
|