Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| be5c64ea8a | |||
| 268e40d764 | |||
| 246ae1c590 | |||
| 64084d3489 | |||
| cb12250ef0 | |||
| e1e75fc7f6 | |||
| 6035ffdc0b | |||
| c8e8153702 | |||
| 51a0f2eb14 |
@@ -4,6 +4,8 @@ members = [
|
||||
"crates/xserv-cuda",
|
||||
"crates/xserv-tensor",
|
||||
"crates/xserv-kernels",
|
||||
"crates/xserv-model",
|
||||
"crates/xserv-tokenizer",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -14,3 +16,7 @@ license = "MIT"
|
||||
[workspace.dependencies]
|
||||
half = "2"
|
||||
smallvec = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
safetensors = "0.5"
|
||||
regex = "1"
|
||||
|
||||
@@ -13,9 +13,17 @@ fn main() {
|
||||
.cuda(true)
|
||||
.cudart("shared")
|
||||
.flag("-gencode=arch=compute_120,code=sm_120")
|
||||
.include("../../csrc")
|
||||
.file("../../csrc/gemm/naive.cu")
|
||||
.file("../../csrc/gemm/tiled.cu")
|
||||
.compile("xserv_gemm_kernels");
|
||||
.file("../../csrc/normalization/rmsnorm.cu")
|
||||
.file("../../csrc/normalization/layernorm.cu")
|
||||
.file("../../csrc/activation/activations.cu")
|
||||
.file("../../csrc/reduce/softmax.cu")
|
||||
.file("../../csrc/embedding/embedding.cu")
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/gemm/");
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
}
|
||||
|
||||
72
crates/xserv-kernels/src/activation.rs
Normal file
72
crates/xserv-kernels/src/activation.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_gelu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||||
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
let out = Tensor::zeros(a.shape(), a.dtype(), a.device());
|
||||
let n = a.numel() as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) }
|
||||
pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) }
|
||||
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype for scale"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
|
||||
77
crates/xserv-kernels/src/attention.rs
Normal file
77
crates/xserv-kernels/src/attention.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
use crate::activation::scale;
|
||||
use crate::gemm::batched_matmul;
|
||||
use crate::softmax::softmax;
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
let ndim = scores.ndim();
|
||||
let rows = scores.shape()[ndim - 2];
|
||||
let cols = scores.shape()[ndim - 1];
|
||||
let batch: usize = scores.shape()[..ndim - 2].iter().product();
|
||||
|
||||
unsafe {
|
||||
match scores.dtype() {
|
||||
DType::F32 => launch_causal_mask_f32(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_causal_mask_bf16(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for causal mask"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
|
||||
/// Multi-head attention (naive, materializes S×S score matrix).
|
||||
///
|
||||
/// q, k, v: [batch, num_heads, seq_len, head_dim] — contiguous, on GPU
|
||||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||
|
||||
// scores = Q @ K^T → [B, H, q_len, kv_len]
|
||||
let k_t = k.transpose(2, 3).contiguous();
|
||||
let scores = batched_matmul(q, &k_t);
|
||||
|
||||
// Scale by 1/sqrt(head_dim)
|
||||
let scale_factor = 1.0 / (head_dim as f32).sqrt();
|
||||
let scaled_scores = scale(&scores, scale_factor);
|
||||
|
||||
// Causal mask
|
||||
if causal {
|
||||
let offset = kv_len - q_len;
|
||||
apply_causal_mask(&scaled_scores, offset);
|
||||
}
|
||||
|
||||
// Softmax
|
||||
let weights = softmax(&scaled_scores);
|
||||
|
||||
// output = weights @ V → [B, H, q_len, head_dim]
|
||||
batched_matmul(&weights, v)
|
||||
}
|
||||
51
crates/xserv-kernels/src/embedding.rs
Normal file
51
crates/xserv-kernels/src/embedding.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
/// Embedding lookup: table[token_ids[i]] for each i.
|
||||
/// table: [vocab_size, hidden_size], token_ids: [num_tokens] (i32 on CPU)
|
||||
pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
assert_eq!(table.ndim(), 2);
|
||||
assert!(table.is_contiguous());
|
||||
assert!(matches!(table.device(), Device::Cuda(_)));
|
||||
|
||||
let hidden_size = table.shape()[1];
|
||||
let num_tokens = token_ids.len();
|
||||
|
||||
// Upload token_ids to GPU
|
||||
let ids_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
token_ids.as_ptr() as *const u8,
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut ids_gpu = GpuBuffer::alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||
|
||||
let out = Tensor::zeros(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
|
||||
unsafe {
|
||||
match table.dtype() {
|
||||
DType::F32 => launch_embedding_f32(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_embedding_bf16(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for embedding"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
@@ -46,6 +46,19 @@ unsafe extern "C" {
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
pub struct CublasContext {
|
||||
@@ -149,3 +162,68 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
/// Batched matrix multiplication via cuBLAS: C[b] = A[b] @ B[b]
|
||||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||
/// Leading dimensions must match and tensors must be contiguous.
|
||||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert!(a.ndim() >= 2 && b.ndim() >= 2);
|
||||
assert_eq!(a.ndim(), b.ndim());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
|
||||
let ndim = a.ndim();
|
||||
let m = a.shape()[ndim - 2];
|
||||
let k = a.shape()[ndim - 1];
|
||||
let n = b.shape()[ndim - 1];
|
||||
assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch");
|
||||
|
||||
// Compute batch count from leading dimensions
|
||||
let batch: usize = a.shape()[..ndim - 2].iter().product();
|
||||
assert_eq!(
|
||||
b.shape()[..ndim - 2].iter().product::<usize>(),
|
||||
batch,
|
||||
"batch dimensions mismatch"
|
||||
);
|
||||
|
||||
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
|
||||
out_shape.push(m);
|
||||
out_shape.push(n);
|
||||
let c = Tensor::zeros(&out_shape, a.dtype(), a.device());
|
||||
|
||||
let dtype = a.dtype();
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for batched matmul"),
|
||||
};
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
// cuBLAS strides are in elements (not bytes)
|
||||
let stride_a = (m * k) as i64;
|
||||
let stride_b = (k * n) as i64;
|
||||
let stride_c = (m * n) as i64;
|
||||
|
||||
let ctx = CublasContext::new().unwrap();
|
||||
unsafe {
|
||||
cublasSetStream_v2(ctx.handle, std::ptr::null_mut());
|
||||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||
error::check(cublasGemmStridedBatchedEx(
|
||||
ctx.handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as _, b_type, n as i32, stride_b,
|
||||
a.data_ptr() as _, a_type, k as i32, stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
)).expect("cuBLAS batched GEMM failed");
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
c
|
||||
}
|
||||
|
||||
39
crates/xserv-kernels/src/layernorm.rs
Normal file
39
crates/xserv-kernels/src/layernorm.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void,
|
||||
out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert!(x.is_contiguous() && gamma.is_contiguous() && beta.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden_size = *x.shape().last().unwrap();
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_layernorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_layernorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for layernorm"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
@@ -1,3 +1,17 @@
|
||||
pub mod activation;
|
||||
pub mod attention;
|
||||
pub mod embedding;
|
||||
pub mod gemm;
|
||||
pub mod layernorm;
|
||||
pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod softmax;
|
||||
|
||||
pub use gemm::{GemmBackend, matmul};
|
||||
pub use activation::{add, gelu, mul, scale, silu};
|
||||
pub use attention::attention;
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::rmsnorm;
|
||||
pub use rope::{rope_inplace, RopeCache};
|
||||
pub use softmax::softmax;
|
||||
|
||||
37
crates/xserv-kernels/src/rmsnorm.rs
Normal file
37
crates/xserv-kernels/src/rmsnorm.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert!(x.is_contiguous() && gamma.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden_size = *x.shape().last().unwrap();
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rmsnorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_rmsnorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rmsnorm"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
85
crates/xserv-kernels/src/rope.rs
Normal file
85
crates/xserv-kernels/src/rope.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void,
|
||||
max_seq_len: i32, half_dim: i32, theta: f32,
|
||||
stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub struct RopeCache {
|
||||
pub cos: GpuBuffer,
|
||||
pub sin: GpuBuffer,
|
||||
pub max_seq_len: usize,
|
||||
pub half_dim: usize,
|
||||
}
|
||||
|
||||
impl RopeCache {
|
||||
pub fn new(max_seq_len: usize, head_dim: usize, theta: f32) -> Self {
|
||||
let half_dim = head_dim / 2;
|
||||
let nbytes = max_seq_len * half_dim * std::mem::size_of::<f32>();
|
||||
let mut cos = GpuBuffer::alloc(nbytes).expect("alloc cos_cache");
|
||||
let mut sin = GpuBuffer::alloc(nbytes).expect("alloc sin_cache");
|
||||
|
||||
unsafe {
|
||||
launch_compute_rope_cache(
|
||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
|
||||
Self { cos, sin, max_seq_len, half_dim }
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply RoPE in-place to x.
|
||||
/// x: [num_tokens, num_heads, head_dim] on GPU
|
||||
/// positions: [num_tokens] (u32 on CPU, will be uploaded)
|
||||
pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
||||
assert_eq!(x.ndim(), 3);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let num_tokens = x.shape()[0];
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[2];
|
||||
assert_eq!(head_dim / 2, cache.half_dim);
|
||||
assert_eq!(positions.len(), num_tokens);
|
||||
|
||||
let pos_bytes = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
positions.as_ptr() as *const u8,
|
||||
num_tokens * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
let mut pos_gpu = GpuBuffer::alloc(pos_bytes.len()).expect("alloc positions");
|
||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rope_f32(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_rope_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rope"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
34
crates/xserv-kernels/src/softmax.rs
Normal file
34
crates/xserv-kernels/src/softmax.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
/// Softmax along the last dimension.
|
||||
pub fn softmax(x: &Tensor) -> Tensor {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_softmax_f32(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_softmax_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for softmax"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
187
crates/xserv-kernels/tests/attention_test.rs
Normal file
187
crates/xserv-kernels/tests/attention_test.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
|
||||
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
|
||||
batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
|
||||
causal: bool) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
for b in 0..batch {
|
||||
for h in 0..heads {
|
||||
// scores = Q @ K^T, scaled
|
||||
let mut scores = vec![0.0f32; q_len * kv_len];
|
||||
for i in 0..q_len {
|
||||
for j in 0..kv_len {
|
||||
let mut s = 0.0f32;
|
||||
for d in 0..head_dim {
|
||||
let qi = q[((b * heads + h) * q_len + i) * head_dim + d];
|
||||
let ki = k[((b * heads + h) * kv_len + j) * head_dim + d];
|
||||
s += qi * ki;
|
||||
}
|
||||
scores[i * kv_len + j] = s * scale;
|
||||
}
|
||||
}
|
||||
// causal mask
|
||||
if causal {
|
||||
let offset = kv_len - q_len;
|
||||
for i in 0..q_len {
|
||||
for j in 0..kv_len {
|
||||
if j > i + offset {
|
||||
scores[i * kv_len + j] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// softmax per row
|
||||
for i in 0..q_len {
|
||||
let row = &mut scores[i * kv_len..(i + 1) * kv_len];
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0f32;
|
||||
for v in row.iter_mut() {
|
||||
*v = (*v - max).exp();
|
||||
sum += *v;
|
||||
}
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
// output = weights @ V
|
||||
for i in 0..q_len {
|
||||
for d in 0..head_dim {
|
||||
let mut s = 0.0f32;
|
||||
for j in 0..kv_len {
|
||||
let w = scores[i * kv_len + j];
|
||||
let vi = v[((b * heads + h) * kv_len + j) * head_dim + d];
|
||||
s += w * vi;
|
||||
}
|
||||
out[((b * heads + h) * q_len + i) * head_dim + d] = s;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
|
||||
assert_eq!(a.len(), b.len(), "{name}: length mismatch");
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (x, y)) in a.iter().zip(b).enumerate() {
|
||||
let err = (x - y).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
|
||||
fn make_data(n: usize) -> Vec<f32> {
|
||||
(0..n).map(|i| ((i % 17) as f32 - 8.0) * 0.05).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batched_matmul() {
|
||||
init();
|
||||
let batch = 4;
|
||||
let heads = 8;
|
||||
let m = 32;
|
||||
let k = 64;
|
||||
let n = 32;
|
||||
|
||||
let a_data = make_data(batch * heads * m * k);
|
||||
let b_data = make_data(batch * heads * k * n);
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0));
|
||||
let c = batched_matmul(&a, &b).to_device(Device::Cpu);
|
||||
|
||||
assert_eq!(c.shape(), &[batch, heads, m, n]);
|
||||
|
||||
// Verify one batch element
|
||||
let a_cpu = &a_data[0..m * k];
|
||||
let b_cpu = &b_data[0..k * n];
|
||||
let mut expected = vec![0.0f32; m * n];
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut s = 0.0f32;
|
||||
for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; }
|
||||
expected[i * n + j] = s;
|
||||
}
|
||||
}
|
||||
let result = c.as_slice::<f32>();
|
||||
check_close(&result[0..m * n], &expected, 1e-3, "batched_matmul[0]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_no_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 8; let d = 16;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 16; let d = 32;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-3, "attention_causal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_larger() {
|
||||
init();
|
||||
let b = 2; let h = 4; let s = 64; let d = 64;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_first_row_sees_only_first_token() {
|
||||
init();
|
||||
let b = 1; let h = 1; let s = 4; let d = 8;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data: Vec<f32> = (0..s * d).map(|i| {
|
||||
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
|
||||
}).collect();
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
|
||||
// First row (position 0) with causal mask can only see position 0.
|
||||
// So attention weight for position 0 is 1.0 for token 0 only.
|
||||
// output[0] should be exactly V[0] = [1, 1, 1, ...1]
|
||||
let result = out.as_slice::<f32>();
|
||||
for i in 0..d {
|
||||
assert!((result[i] - 1.0).abs() < 1e-5,
|
||||
"first row should equal V[0], got {} at dim {}", result[i], i);
|
||||
}
|
||||
}
|
||||
302
crates/xserv-kernels/tests/ops_test.rs
Normal file
302
crates/xserv-kernels/tests/ops_test.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
|
||||
// --- CPU reference implementations ---
|
||||
|
||||
fn cpu_rmsnorm(x: &[f32], gamma: &[f32], eps: f32, hidden: usize) -> Vec<f32> {
|
||||
let rows = x.len() / hidden;
|
||||
let mut out = vec![0.0f32; x.len()];
|
||||
for r in 0..rows {
|
||||
let row = &x[r * hidden..(r + 1) * hidden];
|
||||
let sum_sq: f32 = row.iter().map(|v| v * v).sum();
|
||||
let rms_inv = 1.0 / (sum_sq / hidden as f32 + eps).sqrt();
|
||||
for i in 0..hidden {
|
||||
out[r * hidden + i] = row[i] * rms_inv * gamma[i];
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize) -> Vec<f32> {
|
||||
let rows = x.len() / hidden;
|
||||
let mut out = vec![0.0f32; x.len()];
|
||||
for r in 0..rows {
|
||||
let row = &x[r * hidden..(r + 1) * hidden];
|
||||
let mean: f32 = row.iter().sum::<f32>() / hidden as f32;
|
||||
let var: f32 = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / hidden as f32;
|
||||
let inv_std = 1.0 / (var + eps).sqrt();
|
||||
for i in 0..hidden {
|
||||
out[r * hidden + i] = gamma[i] * (row[i] - mean) * inv_std + beta[i];
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn cpu_gelu(x: &[f32]) -> Vec<f32> {
|
||||
let sqrt_2_over_pi = 0.7978845608f32;
|
||||
x.iter().map(|&v| {
|
||||
let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
|
||||
0.5 * v * (1.0 + inner.tanh())
|
||||
}).collect()
|
||||
}
|
||||
|
||||
fn cpu_silu(x: &[f32]) -> Vec<f32> {
|
||||
x.iter().map(|&v| v / (1.0 + (-v).exp())).collect()
|
||||
}
|
||||
|
||||
fn cpu_softmax(x: &[f32], cols: usize) -> Vec<f32> {
|
||||
let rows = x.len() / cols;
|
||||
let mut out = vec![0.0f32; x.len()];
|
||||
for r in 0..rows {
|
||||
let row = &x[r * cols..(r + 1) * cols];
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = row.iter().map(|v| (v - max).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
for i in 0..cols {
|
||||
out[r * cols + i] = exps[i] / sum;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn cpu_rope(x: &mut [f32], positions: &[u32], num_heads: usize, head_dim: usize, theta: f32) {
|
||||
let half_dim = head_dim / 2;
|
||||
let num_tokens = positions.len();
|
||||
for t in 0..num_tokens {
|
||||
let pos = positions[t] as f32;
|
||||
for h in 0..num_heads {
|
||||
for i in 0..half_dim {
|
||||
let freq = 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32);
|
||||
let angle = pos * freq;
|
||||
let cos_val = angle.cos();
|
||||
let sin_val = angle.sin();
|
||||
let base = (t * num_heads + h) * head_dim;
|
||||
let x0 = x[base + 2 * i];
|
||||
let x1 = x[base + 2 * i + 1];
|
||||
x[base + 2 * i] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) {
|
||||
assert_eq!(result.len(), expected.len(), "{name}: length mismatch");
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let err = (r - e).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}");
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
|
||||
fn make_data(n: usize) -> Vec<f32> {
|
||||
(0..n).map(|i| ((i % 17) as f32 - 8.0) * 0.1).collect()
|
||||
}
|
||||
|
||||
// === RMSNorm ===
|
||||
|
||||
#[test]
|
||||
fn test_rmsnorm_f32() {
|
||||
init();
|
||||
let hidden = 768;
|
||||
let rows = 4;
|
||||
let x_data = make_data(rows * hidden);
|
||||
let gamma_data: Vec<f32> = (0..hidden).map(|i| 0.5 + (i % 3) as f32 * 0.2).collect();
|
||||
let expected = cpu_rmsnorm(&x_data, &gamma_data, 1e-5, hidden);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[rows, hidden]).to_device(Device::Cuda(0));
|
||||
let gamma = Tensor::from_slice(&gamma_data, &[hidden]).to_device(Device::Cuda(0));
|
||||
let out = rmsnorm(&x, &gamma, 1e-5).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "rmsnorm_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rmsnorm_bf16() {
|
||||
init();
|
||||
let hidden = 768;
|
||||
let rows = 4;
|
||||
let x_f32 = make_data(rows * hidden);
|
||||
let gamma_f32: Vec<f32> = (0..hidden).map(|i| 0.5 + (i % 3) as f32 * 0.2).collect();
|
||||
let expected = cpu_rmsnorm(&x_f32, &gamma_f32, 1e-5, hidden);
|
||||
|
||||
let x_bf16: Vec<bf16> = x_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let gamma_bf16: Vec<bf16> = gamma_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let x = Tensor::from_slice(&x_bf16, &[rows, hidden]).to_device(Device::Cuda(0));
|
||||
let gamma = Tensor::from_slice(&gamma_bf16, &[hidden]).to_device(Device::Cuda(0));
|
||||
let out = rmsnorm(&x, &gamma, 1e-5).to_device(Device::Cpu);
|
||||
|
||||
let result: Vec<f32> = out.as_slice::<bf16>().iter().map(|v| v.to_f32()).collect();
|
||||
check_close(&result, &expected, 0.05, "rmsnorm_bf16");
|
||||
}
|
||||
|
||||
// === LayerNorm ===
|
||||
|
||||
#[test]
|
||||
fn test_layernorm_f32() {
|
||||
init();
|
||||
let hidden = 768;
|
||||
let rows = 4;
|
||||
let x_data = make_data(rows * hidden);
|
||||
let gamma_data: Vec<f32> = (0..hidden).map(|i| 0.8 + (i % 5) as f32 * 0.1).collect();
|
||||
let beta_data: Vec<f32> = (0..hidden).map(|i| ((i % 7) as f32 - 3.0) * 0.01).collect();
|
||||
let expected = cpu_layernorm(&x_data, &gamma_data, &beta_data, 1e-5, hidden);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[rows, hidden]).to_device(Device::Cuda(0));
|
||||
let gamma = Tensor::from_slice(&gamma_data, &[hidden]).to_device(Device::Cuda(0));
|
||||
let beta = Tensor::from_slice(&beta_data, &[hidden]).to_device(Device::Cuda(0));
|
||||
let out = layernorm(&x, &gamma, &beta, 1e-5).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "layernorm_f32");
|
||||
}
|
||||
|
||||
// === GELU ===
|
||||
|
||||
#[test]
|
||||
fn test_gelu_f32() {
|
||||
init();
|
||||
let data = make_data(10000);
|
||||
let expected = cpu_gelu(&data);
|
||||
let x = Tensor::from_slice(&data, &[10000]).to_device(Device::Cuda(0));
|
||||
let out = gelu(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "gelu_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gelu_bf16() {
|
||||
init();
|
||||
let data_f32 = make_data(10000);
|
||||
let expected = cpu_gelu(&data_f32);
|
||||
let data_bf16: Vec<bf16> = data_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let x = Tensor::from_slice(&data_bf16, &[10000]).to_device(Device::Cuda(0));
|
||||
let out = gelu(&x).to_device(Device::Cpu);
|
||||
let result: Vec<f32> = out.as_slice::<bf16>().iter().map(|v| v.to_f32()).collect();
|
||||
check_close(&result, &expected, 0.02, "gelu_bf16");
|
||||
}
|
||||
|
||||
// === SiLU ===
|
||||
|
||||
#[test]
|
||||
fn test_silu_f32() {
|
||||
init();
|
||||
let data = make_data(10000);
|
||||
let expected = cpu_silu(&data);
|
||||
let x = Tensor::from_slice(&data, &[10000]).to_device(Device::Cuda(0));
|
||||
let out = silu(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "silu_f32");
|
||||
}
|
||||
|
||||
// === Softmax ===
|
||||
|
||||
#[test]
|
||||
fn test_softmax_f32() {
|
||||
init();
|
||||
let rows = 8;
|
||||
let cols = 256;
|
||||
let data = make_data(rows * cols);
|
||||
let expected = cpu_softmax(&data, cols);
|
||||
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "softmax_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_sum_to_one() {
|
||||
init();
|
||||
let rows = 4;
|
||||
let cols = 2048;
|
||||
let data: Vec<f32> = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect();
|
||||
let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
let result = out.as_slice::<f32>();
|
||||
for r in 0..rows {
|
||||
let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum();
|
||||
assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_large_values() {
|
||||
init();
|
||||
let data = vec![1000.0f32, 1001.0, 999.0, 1000.5];
|
||||
let expected = cpu_softmax(&data, 4);
|
||||
let x = Tensor::from_slice(&data, &[1, 4]).to_device(Device::Cuda(0));
|
||||
let out = softmax(&x).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-5, "softmax_large");
|
||||
}
|
||||
|
||||
// === Embedding ===
|
||||
|
||||
#[test]
|
||||
fn test_embedding_f32() {
|
||||
init();
|
||||
let vocab_size = 100;
|
||||
let hidden = 64;
|
||||
let table_data: Vec<f32> = (0..vocab_size * hidden).map(|i| i as f32 * 0.01).collect();
|
||||
let token_ids: Vec<u32> = vec![0, 5, 99, 42, 1];
|
||||
|
||||
let table = Tensor::from_slice(&table_data, &[vocab_size, hidden]).to_device(Device::Cuda(0));
|
||||
let out = embedding(&table, &token_ids).to_device(Device::Cpu);
|
||||
|
||||
assert_eq!(out.shape(), &[5, hidden]);
|
||||
let result = out.as_slice::<f32>();
|
||||
for (seq_idx, &tid) in token_ids.iter().enumerate() {
|
||||
for i in 0..hidden {
|
||||
let expected = table_data[tid as usize * hidden + i];
|
||||
let got = result[seq_idx * hidden + i];
|
||||
assert!((got - expected).abs() < 1e-6,
|
||||
"embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === RoPE ===
|
||||
|
||||
#[test]
|
||||
fn test_rope_f32() {
|
||||
init();
|
||||
let num_tokens = 4;
|
||||
let num_heads = 2;
|
||||
let head_dim = 8;
|
||||
let theta = 10000.0f32;
|
||||
let positions: Vec<u32> = vec![0, 1, 2, 3];
|
||||
|
||||
let x_data: Vec<f32> = (0..num_tokens * num_heads * head_dim)
|
||||
.map(|i| ((i % 13) as f32 - 6.0) * 0.1)
|
||||
.collect();
|
||||
let mut expected = x_data.clone();
|
||||
cpu_rope(&mut expected, &positions, num_heads, head_dim, theta);
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, theta);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
let out = x.to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "rope_f32");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rope_position_0_identity() {
|
||||
init();
|
||||
// At position 0, all angles are 0, so cos=1, sin=0 → identity transform
|
||||
let num_tokens = 1;
|
||||
let num_heads = 2;
|
||||
let head_dim = 8;
|
||||
let positions: Vec<u32> = vec![0];
|
||||
|
||||
let x_data: Vec<f32> = (0..num_tokens * num_heads * head_dim)
|
||||
.map(|i| (i as f32 + 1.0) * 0.1)
|
||||
.collect();
|
||||
|
||||
let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim])
|
||||
.to_device(Device::Cuda(0));
|
||||
let cache = RopeCache::new(64, head_dim, 10000.0);
|
||||
rope_inplace(&x, &cache, &positions);
|
||||
|
||||
let out = x.to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &x_data, 1e-6, "rope_pos0");
|
||||
}
|
||||
14
crates/xserv-model/Cargo.toml
Normal file
14
crates/xserv-model/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "xserv-model"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
xserv-kernels = { path = "../xserv-kernels" }
|
||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||
half.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
safetensors.workspace = true
|
||||
198
crates/xserv-model/src/bin/bench-gpt2.rs
Normal file
198
crates/xserv-model/src/bin/bench-gpt2.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::gpt2::{sample_greedy, KVCache};
|
||||
use xserv_model::{loader, GPT2, ModelConfig};
|
||||
use xserv_tensor::Device;
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-gpt2 <model-dir> [--gen-tokens N] [--no-cache]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let gen_tokens: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--gen-tokens")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(20);
|
||||
let use_cache = !args.iter().any(|a| a == "--no-cache");
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
let model = GPT2::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Warmup
|
||||
{
|
||||
let ids = tokenizer.encode("warmup");
|
||||
let _ = model.forward(&ids);
|
||||
}
|
||||
|
||||
eprintln!("mode: {}", if use_cache { "KV cache" } else { "no cache" });
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
"The capital of France is",
|
||||
"Once upon a time in a land far away",
|
||||
"Hello, how are you doing today",
|
||||
"In a shocking finding, scientists discovered a",
|
||||
"The weather today is sunny, so I decided to",
|
||||
"Alan Turing was a British mathematician who",
|
||||
"The best way to learn programming is",
|
||||
"Artificial intelligence will change the world because",
|
||||
"The history of the internet began in the",
|
||||
"A good morning routine starts with",
|
||||
"The stock market crashed because investors",
|
||||
"Deep learning is a subset of machine learning that",
|
||||
"The president of the United States announced",
|
||||
"In the year 2050, humans will",
|
||||
"The secret to happiness is",
|
||||
"When I was a child, I used to",
|
||||
"The most important scientific discovery of the century",
|
||||
"Climate change is caused by",
|
||||
"The recipe for chocolate cake requires",
|
||||
"In conclusion, the evidence suggests that",
|
||||
"The cat sat on the mat and",
|
||||
"According to recent studies, exercise can",
|
||||
"The first step in solving any problem is",
|
||||
"Technology has transformed the way we",
|
||||
"The novel begins with the protagonist",
|
||||
"Education is the most powerful weapon",
|
||||
"The ocean covers more than seventy percent of",
|
||||
"Last night I had a dream about",
|
||||
"The company announced its quarterly earnings",
|
||||
"Music has the power to",
|
||||
"The difference between success and failure is",
|
||||
"In the beginning, there was nothing but",
|
||||
"The doctor told me that I should",
|
||||
"Python is a popular programming language because",
|
||||
"The ancient Romans built roads that",
|
||||
"A balanced diet should include",
|
||||
"The movie received mixed reviews from critics",
|
||||
"Space exploration has led to many",
|
||||
"The teacher asked the students to",
|
||||
"Global warming is one of the most",
|
||||
"The bridge collapsed due to structural",
|
||||
"Quantum computing promises to revolutionize",
|
||||
"The new policy will affect millions of",
|
||||
"During the winter months, it is important to",
|
||||
"The human brain contains approximately",
|
||||
"Democracy depends on the active participation of",
|
||||
"The train arrived at the station exactly",
|
||||
"Researchers at MIT have developed a new",
|
||||
"The smartphone has become an essential part of",
|
||||
"After careful consideration, the committee decided to",
|
||||
];
|
||||
|
||||
println!("[");
|
||||
for (i, prompt) in prompts.iter().enumerate() {
|
||||
let input_ids = tokenizer.encode(prompt);
|
||||
let input_len = input_ids.len();
|
||||
|
||||
let (generated_ids, ttft_us, token_times_us) = if use_cache {
|
||||
generate_with_cache(&model, &config, &tokenizer, &input_ids, gen_tokens)
|
||||
} else {
|
||||
generate_no_cache(&model, &tokenizer, &input_ids, gen_tokens)
|
||||
};
|
||||
|
||||
let num_generated = generated_ids.len();
|
||||
let generated_text = tokenizer.decode(&generated_ids);
|
||||
|
||||
let tbt_us = if !token_times_us.is_empty() {
|
||||
token_times_us.iter().sum::<u128>() / token_times_us.len() as u128
|
||||
} else { 0 };
|
||||
let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
.replace('"', "\\\"")
|
||||
.replace('\n', "\\n")
|
||||
.replace('\r', "\\r")
|
||||
.replace('\t', "\\t");
|
||||
let gen_ids_str: Vec<String> = generated_ids.iter().map(|id| id.to_string()).collect();
|
||||
|
||||
print!(" {{\"prompt\": \"{}\", ", prompt.replace('"', "\\\""));
|
||||
print!("\"input_len\": {input_len}, ");
|
||||
print!("\"num_generated\": {num_generated}, ");
|
||||
print!("\"generated_ids\": [{}], ", gen_ids_str.join(", "));
|
||||
print!("\"generated_text\": \"{gen_text_escaped}\", ");
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 { println!(","); } else { println!(); }
|
||||
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1, prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
|
||||
);
|
||||
}
|
||||
println!("]");
|
||||
}
|
||||
|
||||
fn generate_with_cache(
|
||||
model: &GPT2, config: &ModelConfig, tokenizer: &Tokenizer,
|
||||
input_ids: &[u32], gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_heads(), config.head_dim(),
|
||||
xserv_tensor::DType::F32, Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_with_cache(input_ids, &mut cache);
|
||||
let first_token = sample_greedy(&logits);
|
||||
let ttft_us = t0.elapsed().as_micros();
|
||||
|
||||
let mut generated = vec![first_token];
|
||||
let mut token_times = Vec::new();
|
||||
|
||||
// Decode
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
let t_start = Instant::now();
|
||||
let logits = model.forward_with_cache(&[last], &mut cache);
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
}
|
||||
|
||||
fn generate_no_cache(
|
||||
model: &GPT2, tokenizer: &Tokenizer,
|
||||
input_ids: &[u32], gen_tokens: usize,
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut all_ids = input_ids.to_vec();
|
||||
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward(&all_ids);
|
||||
let first_token = sample_greedy(&logits);
|
||||
let ttft_us = t0.elapsed().as_micros();
|
||||
all_ids.push(first_token);
|
||||
|
||||
let mut generated = vec![first_token];
|
||||
let mut token_times = Vec::new();
|
||||
|
||||
for _ in 1..gen_tokens {
|
||||
let t_start = Instant::now();
|
||||
let logits = model.forward(&all_ids);
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
all_ids.push(next);
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
}
|
||||
|
||||
(generated, ttft_us, token_times)
|
||||
}
|
||||
160
crates/xserv-model/src/bin/bench-qwen3.rs
Normal file
160
crates/xserv-model/src/bin/bench-qwen3.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use xserv_model::qwen3::sample_greedy;
|
||||
use xserv_model::{loader, KVCache, ModelConfig, Qwen3};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: bench-qwen3 <model-dir> [--gen-tokens N]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let gen_tokens: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--gen-tokens")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(20);
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
eprintln!("Loading Qwen3-8B weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Warmup
|
||||
{
|
||||
let ids = tokenizer.encode("warmup");
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_kv_heads(), config.head_dim(),
|
||||
DType::BF16, Device::Cuda(0),
|
||||
);
|
||||
let _ = model.forward_with_cache(&ids, &mut cache);
|
||||
}
|
||||
eprintln!("Warmup done. Running benchmark...");
|
||||
|
||||
let prompts: Vec<&str> = vec![
|
||||
"The capital of France is",
|
||||
"Once upon a time in a land far away",
|
||||
"Hello, how are you doing today",
|
||||
"In a shocking finding, scientists discovered a",
|
||||
"The weather today is sunny, so I decided to",
|
||||
"Alan Turing was a British mathematician who",
|
||||
"The best way to learn programming is",
|
||||
"Artificial intelligence will change the world because",
|
||||
"The history of the internet began in the",
|
||||
"A good morning routine starts with",
|
||||
"The stock market crashed because investors",
|
||||
"Deep learning is a subset of machine learning that",
|
||||
"The president of the United States announced",
|
||||
"In the year 2050, humans will",
|
||||
"The secret to happiness is",
|
||||
"When I was a child, I used to",
|
||||
"The most important scientific discovery of the century",
|
||||
"Climate change is caused by",
|
||||
"The recipe for chocolate cake requires",
|
||||
"In conclusion, the evidence suggests that",
|
||||
"The cat sat on the mat and",
|
||||
"According to recent studies, exercise can",
|
||||
"The first step in solving any problem is",
|
||||
"Technology has transformed the way we",
|
||||
"The novel begins with the protagonist",
|
||||
"Education is the most powerful weapon",
|
||||
"The ocean covers more than seventy percent of",
|
||||
"Last night I had a dream about",
|
||||
"The company announced its quarterly earnings",
|
||||
"Music has the power to",
|
||||
"The difference between success and failure is",
|
||||
"In the beginning, there was nothing but",
|
||||
"The doctor told me that I should",
|
||||
"Python is a popular programming language because",
|
||||
"The ancient Romans built roads that",
|
||||
"A balanced diet should include",
|
||||
"The movie received mixed reviews from critics",
|
||||
"Space exploration has led to many",
|
||||
"The teacher asked the students to",
|
||||
"Global warming is one of the most",
|
||||
"The bridge collapsed due to structural",
|
||||
"Quantum computing promises to revolutionize",
|
||||
"The new policy will affect millions of",
|
||||
"During the winter months, it is important to",
|
||||
"The human brain contains approximately",
|
||||
"Democracy depends on the active participation of",
|
||||
"The train arrived at the station exactly",
|
||||
"Researchers at MIT have developed a new",
|
||||
"The smartphone has become an essential part of",
|
||||
"After careful consideration, the committee decided to",
|
||||
];
|
||||
|
||||
println!("[");
|
||||
for (i, prompt) in prompts.iter().enumerate() {
|
||||
let input_ids = tokenizer.encode(prompt);
|
||||
let input_len = input_ids.len();
|
||||
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_kv_heads(), config.head_dim(),
|
||||
DType::BF16, Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
let t0 = Instant::now();
|
||||
let logits = model.forward_with_cache(&input_ids, &mut cache);
|
||||
let first_token = sample_greedy(&logits);
|
||||
let ttft_us = t0.elapsed().as_micros();
|
||||
|
||||
let mut generated = vec![first_token];
|
||||
let mut token_times = Vec::new();
|
||||
|
||||
// Decode
|
||||
for _ in 1..gen_tokens {
|
||||
let last = *generated.last().unwrap();
|
||||
let t_start = Instant::now();
|
||||
let logits = model.forward_with_cache(&[last], &mut cache);
|
||||
let next = sample_greedy(&logits);
|
||||
token_times.push(t_start.elapsed().as_micros());
|
||||
generated.push(next);
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
}
|
||||
|
||||
let num_generated = generated.len();
|
||||
let generated_text = tokenizer.decode(&generated);
|
||||
let tbt_us = if !token_times.is_empty() {
|
||||
token_times.iter().sum::<u128>() / token_times.len() as u128
|
||||
} else { 0 };
|
||||
let total_gen_us: u128 = ttft_us + token_times.iter().sum::<u128>();
|
||||
let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 };
|
||||
|
||||
let gen_text_escaped = generated_text
|
||||
.replace('\\', "\\\\")
|
||||
.replace('"', "\\\"")
|
||||
.replace('\n', "\\n")
|
||||
.replace('\r', "\\r")
|
||||
.replace('\t', "\\t");
|
||||
let gen_ids_str: Vec<String> = generated.iter().map(|id| id.to_string()).collect();
|
||||
|
||||
print!(" {{\"prompt\": \"{}\", ", prompt.replace('"', "\\\""));
|
||||
print!("\"input_len\": {input_len}, ");
|
||||
print!("\"num_generated\": {num_generated}, ");
|
||||
print!("\"generated_ids\": [{}], ", gen_ids_str.join(", "));
|
||||
print!("\"generated_text\": \"{gen_text_escaped}\", ");
|
||||
print!("\"ttft_us\": {ttft_us}, ");
|
||||
print!("\"tbt_us\": {tbt_us}, ");
|
||||
print!("\"tpot_us\": {tpot_us}}}");
|
||||
if i < prompts.len() - 1 { println!(","); } else { println!(); }
|
||||
|
||||
eprintln!(
|
||||
"[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}",
|
||||
i + 1, prompts.len(),
|
||||
ttft_us as f64 / 1000.0,
|
||||
tbt_us as f64 / 1000.0,
|
||||
&generated_text.replace('\n', " ")[..generated_text.len().min(60)]
|
||||
);
|
||||
}
|
||||
println!("]");
|
||||
}
|
||||
44
crates/xserv-model/src/bin/dump-logits.rs
Normal file
44
crates/xserv-model/src/bin/dump-logits.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{loader, KVCache, ModelConfig, Qwen3};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
use half::bf16;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let prompt = &args[2];
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
let model = Qwen3::from_weights(config.clone(), weights);
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
let token_ids = tokenizer.encode(prompt);
|
||||
eprintln!("Prompt: {prompt}");
|
||||
eprintln!("Token IDs: {token_ids:?}");
|
||||
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_kv_heads(), config.head_dim(),
|
||||
DType::BF16, Device::Cuda(0),
|
||||
);
|
||||
let logits = model.forward_with_cache(&token_ids, &mut cache);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
|
||||
// Print top-20 logits for the last position
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate()
|
||||
.map(|(i, v)| (i, v.to_f32()))
|
||||
.collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
|
||||
println!("Top-20 logits (last position):");
|
||||
for (rank, (id, val)) in indexed.iter().take(20).enumerate() {
|
||||
let tok = tokenizer.decode(&[*id as u32]);
|
||||
println!(" [{rank:>2}] id={id:>6} logit={val:>10.4} token={tok:?}");
|
||||
}
|
||||
}
|
||||
101
crates/xserv-model/src/bin/xserv-cli.rs
Normal file
101
crates/xserv-model/src/bin/xserv-cli.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{loader, KVCache, ModelConfig};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-cli <model-dir> [--max-tokens N]");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let max_tokens: usize = args
|
||||
.iter()
|
||||
.position(|a| a == "--max-tokens")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(100);
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
eprintln!(
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
|
||||
config.num_layers(), config.hidden(), config.num_heads(),
|
||||
config.num_kv_heads(), config.vocab_size
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
|
||||
let is_qwen3 = model_type.contains("qwen");
|
||||
let dtype = if is_qwen3 { DType::BF16 } else { DType::F32 };
|
||||
|
||||
// Build model
|
||||
enum Model {
|
||||
GPT2(xserv_model::GPT2),
|
||||
Qwen3(xserv_model::Qwen3),
|
||||
}
|
||||
let model = if is_qwen3 {
|
||||
Model::Qwen3(xserv_model::Qwen3::from_weights(config.clone(), weights))
|
||||
} else {
|
||||
Model::GPT2(xserv_model::GPT2::from_weights(config.clone(), weights))
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
eprintln!("Ready (KV cache, dtype={dtype}).\n");
|
||||
|
||||
loop {
|
||||
print!("xserv> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 { break; }
|
||||
let input = input.trim();
|
||||
if input.is_empty() { continue; }
|
||||
if input == "quit" || input == "exit" { break; }
|
||||
|
||||
let token_ids = tokenizer.encode(input);
|
||||
let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() };
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill + decode
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||
};
|
||||
let mut next = match &model {
|
||||
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
|
||||
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
|
||||
};
|
||||
|
||||
print!("{input}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
};
|
||||
next = match &model {
|
||||
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
|
||||
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
|
||||
};
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
96
crates/xserv-model/src/config.rs
Normal file
96
crates/xserv-model/src/config.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
pub architectures: Option<Vec<String>>,
|
||||
pub model_type: Option<String>,
|
||||
|
||||
// Modern HF naming
|
||||
#[serde(default)]
|
||||
pub hidden_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub intermediate_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_attention_heads: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_hidden_layers: Option<usize>,
|
||||
pub vocab_size: usize,
|
||||
#[serde(default)]
|
||||
pub max_position_embeddings: Option<usize>,
|
||||
|
||||
// GPT-2 naming
|
||||
#[serde(default)]
|
||||
pub n_embd: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_head: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_layer: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_positions: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_inner: Option<usize>,
|
||||
|
||||
// Normalization
|
||||
#[serde(default)]
|
||||
pub layer_norm_eps: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub layer_norm_epsilon: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub rms_norm_eps: Option<f64>,
|
||||
|
||||
// Other
|
||||
#[serde(default)]
|
||||
pub rope_theta: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
}
|
||||
|
||||
impl ModelConfig {
|
||||
pub fn from_file(path: &Path) -> Self {
|
||||
let data = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse {}: {e}", path.display()))
|
||||
}
|
||||
|
||||
pub fn hidden(&self) -> usize {
|
||||
self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required")
|
||||
}
|
||||
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required")
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required")
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_position_embeddings.or(self.n_positions).unwrap_or(2048)
|
||||
}
|
||||
|
||||
pub fn ffn_hidden(&self) -> usize {
|
||||
self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4)
|
||||
}
|
||||
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_heads())
|
||||
}
|
||||
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.hidden() / self.num_heads()
|
||||
}
|
||||
|
||||
pub fn ln_eps(&self) -> f32 {
|
||||
self.layer_norm_eps
|
||||
.or(self.layer_norm_epsilon)
|
||||
.unwrap_or(1e-5) as f32
|
||||
}
|
||||
|
||||
pub fn tied_embeddings(&self) -> bool {
|
||||
self.tie_word_embeddings.unwrap_or(true)
|
||||
}
|
||||
}
|
||||
336
crates/xserv-model/src/gpt2.rs
Normal file
336
crates/xserv-model/src/gpt2.rs
Normal file
@@ -0,0 +1,336 @@
|
||||
use std::collections::HashMap;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
|
||||
pub struct GPT2 {
|
||||
pub config: ModelConfig,
|
||||
wte: Tensor,
|
||||
wpe: Tensor,
|
||||
layers: Vec<GPT2Block>,
|
||||
ln_f_g: Tensor,
|
||||
ln_f_b: Tensor,
|
||||
lm_head: Tensor, // precomputed wte^T
|
||||
}
|
||||
|
||||
struct GPT2Block {
|
||||
ln_1_g: Tensor,
|
||||
ln_1_b: Tensor,
|
||||
attn_qkv_w: Tensor,
|
||||
attn_qkv_b: Tensor,
|
||||
attn_out_w: Tensor,
|
||||
attn_out_b: Tensor,
|
||||
ln_2_g: Tensor,
|
||||
ln_2_b: Tensor,
|
||||
mlp_fc_w: Tensor,
|
||||
mlp_fc_b: Tensor,
|
||||
mlp_proj_w: Tensor,
|
||||
mlp_proj_b: Tensor,
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
// Per layer, per head: raw bytes (works for both f32 and bf16)
|
||||
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
|
||||
v: Vec<Vec<Vec<u8>>>,
|
||||
len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
elem_size: usize,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self {
|
||||
Self {
|
||||
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
len: 0,
|
||||
num_heads,
|
||||
head_dim,
|
||||
elem_size: dtype.size_bytes(),
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize { self.len }
|
||||
|
||||
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
|
||||
pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) {
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let k_bytes = k_cpu.storage().as_cpu_bytes();
|
||||
let v_bytes = v_cpu.storage().as_cpu_bytes();
|
||||
let chunk = new_tokens * hd * es;
|
||||
for h in 0..self.num_heads {
|
||||
let off = h * chunk;
|
||||
self.k[layer][h].extend_from_slice(&k_bytes[off..off + chunk]);
|
||||
self.v[layer][h].extend_from_slice(&v_bytes[off..off + chunk]);
|
||||
}
|
||||
if layer == 0 {
|
||||
self.len += new_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconstruct [1, H, seq_len, D] tensors.
|
||||
pub fn get_kv_tensors(&self, layer: usize) -> (Tensor, Tensor) {
|
||||
let sl = self.len;
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_heads;
|
||||
let es = self.elem_size;
|
||||
let head_bytes = sl * hd * es;
|
||||
let total = nh * head_bytes;
|
||||
let mut k_data = vec![0u8; total];
|
||||
let mut v_data = vec![0u8; total];
|
||||
for h in 0..nh {
|
||||
let off = h * head_bytes;
|
||||
k_data[off..off + head_bytes].copy_from_slice(&self.k[layer][h]);
|
||||
v_data[off..off + head_bytes].copy_from_slice(&self.v[layer][h]);
|
||||
}
|
||||
let shape = &[1, nh, sl, hd];
|
||||
let k = tensor_from_raw_bytes(&k_data, shape, self.dtype).to_device(self.device);
|
||||
let v = tensor_from_raw_bytes(&v_data, shape, self.dtype).to_device(self.device);
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_from_raw_bytes(bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let data: &[f32] = unsafe {
|
||||
std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4)
|
||||
};
|
||||
Tensor::from_slice(data, shape)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(bytes.as_ptr() as *const half::bf16, bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(data, shape)
|
||||
}
|
||||
_ => panic!("unsupported dtype for KV cache"),
|
||||
}
|
||||
}
|
||||
|
||||
impl GPT2 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let wte = take(&mut w, "wte.weight");
|
||||
let wpe = take(&mut w, "wpe.weight");
|
||||
let ln_f_g = take(&mut w, "ln_f.weight");
|
||||
let ln_f_b = take(&mut w, "ln_f.bias");
|
||||
let lm_head = wte.transpose(0, 1).contiguous();
|
||||
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
for i in 0..num_layers {
|
||||
let p = format!("h.{i}");
|
||||
layers.push(GPT2Block {
|
||||
ln_1_g: take(&mut w, &format!("{p}.ln_1.weight")),
|
||||
ln_1_b: take(&mut w, &format!("{p}.ln_1.bias")),
|
||||
attn_qkv_w: take(&mut w, &format!("{p}.attn.c_attn.weight")),
|
||||
attn_qkv_b: take(&mut w, &format!("{p}.attn.c_attn.bias")),
|
||||
attn_out_w: take(&mut w, &format!("{p}.attn.c_proj.weight")),
|
||||
attn_out_b: take(&mut w, &format!("{p}.attn.c_proj.bias")),
|
||||
ln_2_g: take(&mut w, &format!("{p}.ln_2.weight")),
|
||||
ln_2_b: take(&mut w, &format!("{p}.ln_2.bias")),
|
||||
mlp_fc_w: take(&mut w, &format!("{p}.mlp.c_fc.weight")),
|
||||
mlp_fc_b: take(&mut w, &format!("{p}.mlp.c_fc.bias")),
|
||||
mlp_proj_w: take(&mut w, &format!("{p}.mlp.c_proj.weight")),
|
||||
mlp_proj_b: take(&mut w, &format!("{p}.mlp.c_proj.bias")),
|
||||
});
|
||||
}
|
||||
|
||||
Self { config, wte, wpe, layers, ln_f_g, ln_f_b, lm_head }
|
||||
}
|
||||
|
||||
/// Full forward pass without KV cache (for testing / correctness comparison).
|
||||
pub fn forward(&self, token_ids: &[u32]) -> Tensor {
|
||||
let seq_len = token_ids.len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
for layer in &self.layers {
|
||||
x = self.transformer_block(layer, &x, None, 0, seq_len, num_heads, head_dim, hidden);
|
||||
}
|
||||
|
||||
let x = layernorm(&x, &self.ln_f_g, &self.ln_f_b, self.config.ln_eps());
|
||||
matmul_2d(&x, &self.lm_head)
|
||||
}
|
||||
|
||||
/// Forward pass with KV cache. First call = prefill, subsequent = decode.
|
||||
pub fn forward_with_cache(&self, token_ids: &[u32], cache: &mut KVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = cache.seq_len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
x = self.transformer_block(
|
||||
layer, &x, Some((cache, layer_idx)),
|
||||
pos_offset, new_tokens, num_heads, head_dim, hidden,
|
||||
);
|
||||
}
|
||||
|
||||
let x = layernorm(&x, &self.ln_f_g, &self.ln_f_b, self.config.ln_eps());
|
||||
matmul_2d(&x, &self.lm_head)
|
||||
}
|
||||
|
||||
fn transformer_block(
|
||||
&self,
|
||||
layer: &GPT2Block,
|
||||
x: &Tensor,
|
||||
cache: Option<(&mut KVCache, usize)>,
|
||||
pos_offset: usize,
|
||||
new_tokens: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
hidden: usize,
|
||||
) -> Tensor {
|
||||
let residual = x.clone();
|
||||
let normed = layernorm(x, &layer.ln_1_g, &layer.ln_1_b, self.config.ln_eps());
|
||||
|
||||
let qkv = linear(&normed, &layer.attn_qkv_w, Some(&layer.attn_qkv_b));
|
||||
let (q, k_new, v_new) = split_qkv(&qkv, num_heads, head_dim, new_tokens);
|
||||
|
||||
let (k_full, v_full) = if let Some((cache, layer_idx)) = cache {
|
||||
let k_cpu = k_new.to_device(Device::Cpu);
|
||||
let v_cpu = v_new.to_device(Device::Cpu);
|
||||
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
|
||||
cache.get_kv_tensors(layer_idx)
|
||||
} else {
|
||||
(k_new, v_new)
|
||||
};
|
||||
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
let attn_out = merge_heads(&attn_out, new_tokens, hidden);
|
||||
let attn_out = linear(&attn_out, &layer.attn_out_w, Some(&layer.attn_out_b));
|
||||
let x = add_tensors(&residual, &attn_out);
|
||||
|
||||
let residual = x.clone();
|
||||
let normed = layernorm(&x, &layer.ln_2_g, &layer.ln_2_b, self.config.ln_eps());
|
||||
let fc = linear(&normed, &layer.mlp_fc_w, Some(&layer.mlp_fc_b));
|
||||
let activated = gelu(&fc);
|
||||
let proj = linear(&activated, &layer.mlp_proj_w, Some(&layer.mlp_proj_b));
|
||||
add_tensors(&residual, &proj)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper ops (unchanged) ---
|
||||
|
||||
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
|
||||
let out = matmul_2d(x, weight);
|
||||
if let Some(b) = bias { add_bias(&out, b) } else { out }
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
matmul(a, b, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
xserv_kernels::add(a, b)
|
||||
}
|
||||
|
||||
fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
// bias: [N], x: [S, N] — broadcast add via reshape
|
||||
assert_eq!(x.ndim(), 2);
|
||||
assert_eq!(bias.ndim(), 1);
|
||||
let n = bias.shape()[0];
|
||||
assert_eq!(x.shape()[1], n);
|
||||
let rows = x.shape()[0];
|
||||
// Broadcast: tile bias to [S, N] on CPU, then GPU add
|
||||
let b_cpu = bias.to_device(Device::Cpu);
|
||||
match x.dtype() {
|
||||
DType::F32 => {
|
||||
let bd = b_cpu.as_slice::<f32>();
|
||||
let tiled: Vec<f32> = (0..rows).flat_map(|_| bd.iter().copied()).collect();
|
||||
let b_full = Tensor::from_slice(&tiled, x.shape()).to_device(x.device());
|
||||
xserv_kernels::add(x, &b_full)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let bd = b_cpu.as_slice::<half::bf16>();
|
||||
let tiled: Vec<half::bf16> = (0..rows).flat_map(|_| bd.iter().copied()).collect();
|
||||
let b_full = Tensor::from_slice(&tiled, x.shape()).to_device(x.device());
|
||||
xserv_kernels::add(x, &b_full)
|
||||
}
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
|
||||
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
|
||||
let hidden = num_heads * head_dim;
|
||||
let qkv_cpu = qkv.to_device(Device::Cpu);
|
||||
let data = qkv_cpu.as_slice::<f32>();
|
||||
|
||||
let mut q_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
|
||||
let device = qkv.device();
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<f32>();
|
||||
|
||||
let mut out = vec![0.0f32; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// Greedy sampling: return the argmax token ID from the last position's logits.
|
||||
pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<f32>();
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last_row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
8
crates/xserv-model/src/lib.rs
Normal file
8
crates/xserv-model/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod config;
|
||||
pub mod gpt2;
|
||||
pub mod loader;
|
||||
pub mod qwen3;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use qwen3::Qwen3;
|
||||
87
crates/xserv-model/src/loader.rs
Normal file
87
crates/xserv-model/src/loader.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use half::{bf16, f16};
|
||||
use safetensors::SafeTensors;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let data = std::fs::read(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let st = SafeTensors::deserialize(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
|
||||
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
for (name, view) in st.tensors() {
|
||||
let shape: Vec<usize> = view.shape().to_vec();
|
||||
let raw_bytes = view.data();
|
||||
let dtype = match view.dtype() {
|
||||
safetensors::Dtype::F32 => DType::F32,
|
||||
safetensors::Dtype::F16 => DType::F16,
|
||||
safetensors::Dtype::BF16 => DType::BF16,
|
||||
other => {
|
||||
eprintln!("skipping tensor {name}: unsupported dtype {other:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let tensor = make_tensor(raw_bytes, &shape, dtype);
|
||||
let tensor = tensor.to_device(device);
|
||||
tensors.insert(name.to_string(), tensor);
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
|
||||
/// Load from a directory containing model.safetensors (or sharded files) + config.json.
|
||||
pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let single = dir.join("model.safetensors");
|
||||
if single.exists() {
|
||||
return load_safetensors(&single, device);
|
||||
}
|
||||
|
||||
// Try sharded: model-00001-of-NNNNN.safetensors
|
||||
let mut all_tensors = HashMap::new();
|
||||
let mut entries: Vec<_> = std::fs::read_dir(dir)
|
||||
.unwrap()
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| {
|
||||
e.path()
|
||||
.file_name()
|
||||
.map(|f| f.to_string_lossy().ends_with(".safetensors"))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
entries.sort_by_key(|e| e.file_name());
|
||||
|
||||
for entry in entries {
|
||||
let tensors = load_safetensors(&entry.path(), device);
|
||||
all_tensors.extend(tensors);
|
||||
}
|
||||
|
||||
assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display());
|
||||
all_tensors
|
||||
}
|
||||
|
||||
fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let floats: &[f32] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const f32, raw_bytes.len() / 4)
|
||||
};
|
||||
Tensor::from_slice(floats, shape)
|
||||
}
|
||||
DType::F16 => {
|
||||
let halfs: &[f16] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const f16, raw_bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(halfs, shape)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let bfs: &[bf16] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const bf16, raw_bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(bfs, shape)
|
||||
}
|
||||
}
|
||||
}
|
||||
270
crates/xserv-model/src/qwen3.rs
Normal file
270
crates/xserv-model/src/qwen3.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
use std::collections::HashMap;
|
||||
use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use crate::gpt2::KVCache;
|
||||
|
||||
pub struct Qwen3 {
|
||||
pub config: ModelConfig,
|
||||
embed_tokens: Tensor,
|
||||
layers: Vec<Qwen3Block>,
|
||||
norm: Tensor,
|
||||
lm_head_t: Tensor, // precomputed transpose
|
||||
rope_cache: RopeCache,
|
||||
}
|
||||
|
||||
struct Qwen3Block {
|
||||
input_norm: Tensor, // [hidden]
|
||||
q_proj_wt: Tensor, // TRANSPOSED: [hidden, num_heads*head_dim]
|
||||
k_proj_wt: Tensor, // TRANSPOSED: [hidden, num_kv_heads*head_dim]
|
||||
v_proj_wt: Tensor,
|
||||
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
|
||||
q_norm: Tensor, // [head_dim]
|
||||
k_norm: Tensor, // [head_dim]
|
||||
post_norm: Tensor, // [hidden]
|
||||
gate_proj_wt: Tensor, // TRANSPOSED: [hidden, intermediate]
|
||||
up_proj_wt: Tensor,
|
||||
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
|
||||
}
|
||||
|
||||
impl Qwen3 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let embed_tokens = take(&mut w, "model.embed_tokens.weight");
|
||||
let norm = take(&mut w, "model.norm.weight");
|
||||
let lm_head_raw = take(&mut w, "lm_head.weight");
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len().min(8192), // limit for memory
|
||||
config.head_dim(),
|
||||
config.rope_theta.unwrap_or(1_000_000.0) as f32,
|
||||
);
|
||||
|
||||
// Precompute transposed weights: [out, in] → [in, out] so we can do x @ wt directly
|
||||
let transpose_w = |t: Tensor| -> Tensor {
|
||||
t.transpose(0, 1).contiguous()
|
||||
};
|
||||
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
eprintln!("Transposing weights for {} layers...", num_layers);
|
||||
for i in 0..num_layers {
|
||||
let p = format!("model.layers.{i}");
|
||||
layers.push(Qwen3Block {
|
||||
input_norm: take(&mut w, &format!("{p}.input_layernorm.weight")),
|
||||
q_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.q_proj.weight"))),
|
||||
k_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.k_proj.weight"))),
|
||||
v_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.v_proj.weight"))),
|
||||
o_proj_wt: transpose_w(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: take(&mut w, &format!("{p}.self_attn.q_norm.weight")),
|
||||
k_norm: take(&mut w, &format!("{p}.self_attn.k_norm.weight")),
|
||||
post_norm: take(&mut w, &format!("{p}.post_attention_layernorm.weight")),
|
||||
gate_proj_wt: transpose_w(take(&mut w, &format!("{p}.mlp.gate_proj.weight"))),
|
||||
up_proj_wt: transpose_w(take(&mut w, &format!("{p}.mlp.up_proj.weight"))),
|
||||
down_proj_wt: transpose_w(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
}
|
||||
|
||||
let lm_head_t = transpose_w(lm_head_raw);
|
||||
Self { config, embed_tokens, layers, norm, lm_head_t, rope_cache }
|
||||
}
|
||||
|
||||
pub fn forward_with_cache(&self, token_ids: &[u32], cache: &mut KVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = cache.seq_len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
// Q/K/V projections (pre-transposed weights, x @ wt)
|
||||
let q = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
|
||||
// Reshape to [1, heads, seq, head_dim]
|
||||
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
|
||||
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = reshape_heads(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// QK normalization (per-head RMSNorm)
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
// RoPE — kernel expects [S, H, D], our tensors are [1, H, S, D]
|
||||
// Transpose to [1, S, H, D] → reshape to [S, H, D] for RoPE
|
||||
let q = transpose_for_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_for_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
// Transpose back to [1, H, S, D]
|
||||
let q = transpose_from_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_from_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// KV cache
|
||||
let k_cpu = k.to_device(Device::Cpu);
|
||||
let v_cpu = v.to_device(Device::Cpu);
|
||||
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
|
||||
let (k_full, v_full) = cache.get_kv_tensors(layer_idx);
|
||||
|
||||
// GQA: repeat K/V
|
||||
let n_rep = num_heads / num_kv_heads;
|
||||
let k_full = repeat_kv(&k_full, n_rep);
|
||||
let v_full = repeat_kv(&v_full, n_rep);
|
||||
|
||||
// Attention
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
let attn_merged = merge_heads_any(&attn_out, new_tokens, hidden);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
x = add_any(&residual, &attn_proj);
|
||||
|
||||
// SwiGLU FFN
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.post_norm, eps);
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let gate_activated = silu(&gate);
|
||||
let hidden_states = mul_any(&gate_activated, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
matmul(a, b, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
fn reshape_heads(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let hidden = num_heads * head_dim;
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let si = s * hidden + h * head_dim;
|
||||
let di = (h * seq_len + s) * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[1, num_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
fn merge_heads_any(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let si = (h * seq_len + s) * head_dim;
|
||||
let di = s * hidden + h * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// Per-head RMSNorm: apply RMSNorm to each [head_dim] slice independently.
|
||||
/// x: [1, H, S, D], norm_weight: [D]
|
||||
fn head_rmsnorm(x: &Tensor, norm_weight: &Tensor, eps: f32) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
// Reshape to [H*S, D], apply rmsnorm, reshape back
|
||||
let total_rows = num_heads * seq_len;
|
||||
let flat = x.reshape(&[total_rows, head_dim]);
|
||||
let normed = rmsnorm(&flat, norm_weight, eps);
|
||||
normed.reshape(&[1, num_heads, seq_len, head_dim])
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H, D] for RoPE kernel
|
||||
fn transpose_for_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; seq_len * num_heads * head_dim];
|
||||
for h in 0..num_heads {
|
||||
for s in 0..seq_len {
|
||||
let si = (h * seq_len + s) * head_dim;
|
||||
let di = (s * num_heads + h) * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, num_heads, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// [S, H, D] → [1, H, S, D] after RoPE
|
||||
fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let si = (s * num_heads + h) * head_dim;
|
||||
let di = (h * seq_len + s) * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[1, num_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 { return x.clone(); }
|
||||
let kv_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let new_heads = kv_heads * n_rep;
|
||||
let mut out = vec![bf16::ZERO; new_heads * seq_len * head_dim];
|
||||
let chunk = seq_len * head_dim;
|
||||
for kv_h in 0..kv_heads {
|
||||
for r in 0..n_rep {
|
||||
let dst_h = kv_h * n_rep + r;
|
||||
out[dst_h * chunk..(dst_h + 1) * chunk]
|
||||
.copy_from_slice(&src[kv_h * chunk..(kv_h + 1) * chunk]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[1, new_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
xserv_kernels::add(a, b)
|
||||
}
|
||||
|
||||
fn mul_any(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
xserv_kernels::mul(a, b)
|
||||
}
|
||||
|
||||
pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
}
|
||||
@@ -137,8 +137,13 @@ impl Tensor {
|
||||
if self.is_contiguous() {
|
||||
return self.clone();
|
||||
}
|
||||
// Copy to contiguous layout on CPU
|
||||
assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported");
|
||||
// For GPU tensors: round-trip through CPU (correct but slow).
|
||||
// TODO: write a GPU contiguous-copy kernel for performance.
|
||||
if matches!(self.device(), Device::Cuda(_)) {
|
||||
let cpu = self.to_device(Device::Cpu);
|
||||
let contig = cpu.contiguous();
|
||||
return contig.to_device(self.device());
|
||||
}
|
||||
let numel = self.numel();
|
||||
let elem_size = self.dtype.size_bytes();
|
||||
let src_bytes = self.storage.as_cpu_bytes();
|
||||
@@ -173,17 +178,18 @@ impl Tensor {
|
||||
// --- Device transfer ---
|
||||
|
||||
pub fn to_device(&self, device: Device) -> Self {
|
||||
let t = if self.is_contiguous() { self.clone() } else { self.contiguous() };
|
||||
if t.device() == device {
|
||||
return t;
|
||||
if self.device() == device {
|
||||
return self.clone();
|
||||
}
|
||||
let new_storage = t.storage.to_device(device).expect("device transfer failed");
|
||||
// Transfer the raw storage (preserving strides/offset).
|
||||
// Non-contiguous layout is preserved — the user can call contiguous() after.
|
||||
let new_storage = self.storage.to_device(device).expect("device transfer failed");
|
||||
Self {
|
||||
storage: new_storage,
|
||||
shape: t.shape,
|
||||
strides: t.strides,
|
||||
offset: 0,
|
||||
dtype: t.dtype,
|
||||
shape: self.shape.clone(),
|
||||
strides: self.strides.clone(),
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
9
crates/xserv-tokenizer/Cargo.toml
Normal file
9
crates/xserv-tokenizer/Cargo.toml
Normal file
@@ -0,0 +1,9 @@
|
||||
[package]
|
||||
name = "xserv-tokenizer"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
regex.workspace = true
|
||||
267
crates/xserv-tokenizer/src/bpe.rs
Normal file
267
crates/xserv-tokenizer/src/bpe.rs
Normal file
@@ -0,0 +1,267 @@
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct Tokenizer {
|
||||
encoder: HashMap<Vec<u8>, u32>,
|
||||
decoder: Vec<Vec<u8>>,
|
||||
merge_ranks: HashMap<(u32, u32), usize>,
|
||||
special_tokens: HashMap<String, u32>,
|
||||
#[allow(dead_code)]
|
||||
special_token_ids: HashMap<u32, String>,
|
||||
pre_tokenize_re: Regex,
|
||||
eos_token_id: Option<u32>,
|
||||
byte_fallback: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenizerJson {
|
||||
model: ModelSection,
|
||||
#[serde(default)]
|
||||
added_tokens: Vec<AddedToken>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelSection {
|
||||
vocab: HashMap<String, u32>,
|
||||
merges: Vec<MergeEntry>,
|
||||
#[serde(default)]
|
||||
byte_fallback: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum MergeEntry {
|
||||
Str(String),
|
||||
Pair(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AddedToken {
|
||||
id: u32,
|
||||
content: String,
|
||||
special: bool,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn from_file(path: &Path) -> Self {
|
||||
let data = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let tj: TokenizerJson = serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse tokenizer.json: {e}"));
|
||||
|
||||
let byte_fallback = tj.model.byte_fallback;
|
||||
|
||||
// Build encoder: token bytes → ID
|
||||
// All HF tokenizers use GPT-2 byte-to-unicode mapping for vocab keys.
|
||||
let mut encoder = HashMap::new();
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
let bytes = token_str_to_bytes(token_str);
|
||||
encoder.insert(bytes, id);
|
||||
}
|
||||
|
||||
// Build decoder: ID → token bytes
|
||||
let max_id = tj.model.vocab.values().copied().max().unwrap_or(0);
|
||||
let added_max = tj.added_tokens.iter().map(|t| t.id).max().unwrap_or(0);
|
||||
let vocab_size = (max_id.max(added_max) + 1) as usize;
|
||||
let mut decoder = vec![vec![]; vocab_size];
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
decoder[id as usize] = token_str_to_bytes(token_str);
|
||||
}
|
||||
|
||||
// Parse merges (supports both "a b" string format and ["a", "b"] array format)
|
||||
let byte_fallback = tj.model.byte_fallback;
|
||||
let mut merge_ranks = HashMap::new();
|
||||
for (rank, entry) in tj.model.merges.iter().enumerate() {
|
||||
let (a_str, b_str) = match entry {
|
||||
MergeEntry::Str(s) => {
|
||||
let parts: Vec<&str> = s.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 { continue; }
|
||||
(parts[0].to_string(), parts[1].to_string())
|
||||
}
|
||||
MergeEntry::Pair(v) => {
|
||||
if v.len() != 2 { continue; }
|
||||
(v[0].clone(), v[1].clone())
|
||||
}
|
||||
};
|
||||
let a_bytes = token_str_to_bytes(&a_str);
|
||||
let b_bytes = token_str_to_bytes(&b_str);
|
||||
if let (Some(&a_id), Some(&b_id)) = (encoder.get(&a_bytes), encoder.get(&b_bytes)) {
|
||||
merge_ranks.insert((a_id, b_id), rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
let mut special_tokens = HashMap::new();
|
||||
let mut special_token_ids = HashMap::new();
|
||||
let mut eos_token_id = None;
|
||||
for at in &tj.added_tokens {
|
||||
if at.special {
|
||||
special_tokens.insert(at.content.clone(), at.id);
|
||||
special_token_ids.insert(at.id, at.content.clone());
|
||||
decoder.resize(decoder.len().max(at.id as usize + 1), vec![]);
|
||||
decoder[at.id as usize] = at.content.as_bytes().to_vec();
|
||||
if at.content == "<|endoftext|>" || at.content == "<|end_of_text|>" {
|
||||
eos_token_id = Some(at.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-tokenization regex
|
||||
let pre_tokenize_re = if byte_fallback {
|
||||
// Qwen-style: split on whitespace boundaries, keep Unicode words/numbers
|
||||
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
|
||||
} else {
|
||||
// GPT-2 style
|
||||
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
|
||||
};
|
||||
|
||||
Self {
|
||||
encoder,
|
||||
decoder,
|
||||
merge_ranks,
|
||||
special_tokens,
|
||||
special_token_ids,
|
||||
pre_tokenize_re,
|
||||
eos_token_id,
|
||||
byte_fallback,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let mut tokens = Vec::new();
|
||||
|
||||
// Check for special tokens first (split around them)
|
||||
let mut remaining = text;
|
||||
while !remaining.is_empty() {
|
||||
// Find earliest special token
|
||||
let mut earliest: Option<(usize, &str, u32)> = None;
|
||||
for (st, &id) in &self.special_tokens {
|
||||
if let Some(pos) = remaining.find(st.as_str()) {
|
||||
if earliest.is_none() || pos < earliest.unwrap().0 {
|
||||
earliest = Some((pos, st, id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((pos, st, id)) = earliest {
|
||||
if pos > 0 {
|
||||
self.encode_ordinary(&remaining[..pos], &mut tokens);
|
||||
}
|
||||
tokens.push(id);
|
||||
remaining = &remaining[pos + st.len()..];
|
||||
} else {
|
||||
self.encode_ordinary(remaining, &mut tokens);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
fn encode_ordinary(&self, text: &str, out: &mut Vec<u32>) {
|
||||
for mat in self.pre_tokenize_re.find_iter(text) {
|
||||
let word = mat.as_str();
|
||||
// Try to encode the whole word first
|
||||
if let Some(&id) = self.encoder.get(word.as_bytes()) {
|
||||
out.push(id);
|
||||
continue;
|
||||
}
|
||||
// Fall back to per-byte encoding
|
||||
let word_bytes: Vec<u8> = word.bytes().collect();
|
||||
let mut token_ids: Vec<u32> = word_bytes.iter().map(|&b| {
|
||||
*self.encoder.get(&vec![b]).unwrap_or_else(|| {
|
||||
panic!("byte {b} (0x{b:02X}) not in vocab")
|
||||
})
|
||||
}).collect();
|
||||
|
||||
// BPE merges
|
||||
loop {
|
||||
if token_ids.len() < 2 { break; }
|
||||
let mut best_rank = usize::MAX;
|
||||
let mut best_idx = 0;
|
||||
for i in 0..token_ids.len() - 1 {
|
||||
if let Some(&rank) = self.merge_ranks.get(&(token_ids[i], token_ids[i + 1])) {
|
||||
if rank < best_rank {
|
||||
best_rank = rank;
|
||||
best_idx = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_rank == usize::MAX { break; }
|
||||
|
||||
let merged_bytes = [
|
||||
self.decoder[token_ids[best_idx] as usize].as_slice(),
|
||||
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
|
||||
].concat();
|
||||
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
|
||||
panic!("merged token not in vocab");
|
||||
});
|
||||
token_ids[best_idx] = merged_id;
|
||||
token_ids.remove(best_idx + 1);
|
||||
}
|
||||
|
||||
out.extend_from_slice(&token_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(&self, token_ids: &[u32]) -> String {
|
||||
let mut bytes = Vec::new();
|
||||
for &id in token_ids {
|
||||
if let Some(b) = self.decoder.get(id as usize) {
|
||||
bytes.extend_from_slice(b);
|
||||
}
|
||||
}
|
||||
String::from_utf8_lossy(&bytes).into_owned()
|
||||
}
|
||||
|
||||
pub fn eos_token_id(&self) -> Option<u32> {
|
||||
self.eos_token_id
|
||||
}
|
||||
|
||||
pub fn vocab_size(&self) -> usize {
|
||||
self.decoder.len()
|
||||
}
|
||||
|
||||
pub fn special_token_id(&self, name: &str) -> Option<u32> {
|
||||
self.special_tokens.get(name).copied()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a token string from HF vocab (which uses Unicode replacements for bytes)
|
||||
/// back to raw bytes. GPT-2 uses a byte-to-unicode mapping where e.g. byte 0x20 (space)
|
||||
/// is represented as 'Ġ' (U+0120).
|
||||
fn token_str_to_bytes(s: &str) -> Vec<u8> {
|
||||
s.chars().map(|c| unicode_to_byte(c)).collect()
|
||||
}
|
||||
|
||||
/// Convert a Unicode char back to the byte it represents in GPT-2 encoding.
|
||||
fn unicode_to_byte(c: char) -> u8 {
|
||||
// Build the inverse map on first use
|
||||
use std::sync::OnceLock;
|
||||
static INV_MAP: OnceLock<HashMap<u32, u8>> = OnceLock::new();
|
||||
|
||||
let map = INV_MAP.get_or_init(|| {
|
||||
let mut m = HashMap::new();
|
||||
// Build GPT-2's bytes_to_unicode forward map, then invert
|
||||
let mut n = 0u32;
|
||||
for b in 0..=255u16 {
|
||||
let byte = b as u8;
|
||||
let unicode = match byte {
|
||||
0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF => byte as u32,
|
||||
_ => {
|
||||
let u = 256 + n;
|
||||
n += 1;
|
||||
u
|
||||
}
|
||||
};
|
||||
m.insert(unicode, byte);
|
||||
}
|
||||
m
|
||||
});
|
||||
|
||||
*map.get(&(c as u32)).unwrap_or_else(|| {
|
||||
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
|
||||
})
|
||||
}
|
||||
3
crates/xserv-tokenizer/src/lib.rs
Normal file
3
crates/xserv-tokenizer/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod bpe;
|
||||
|
||||
pub use bpe::Tokenizer;
|
||||
135
csrc/activation/activations.cu
Normal file
135
csrc/activation/activations.cu
Normal file
@@ -0,0 +1,135 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <math.h>
|
||||
|
||||
// GELU (tanh approximation):
|
||||
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
__device__ __forceinline__ float gelu_f(float x) {
|
||||
const float SQRT_2_OVER_PI = 0.7978845608f;
|
||||
float cube = x * x * x;
|
||||
float inner = SQRT_2_OVER_PI * (x + 0.044715f * cube);
|
||||
return 0.5f * x * (1.0f + tanhf(inner));
|
||||
}
|
||||
|
||||
// SiLU (Swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
||||
__device__ __forceinline__ float silu_f(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
||||
__global__ void gelu_f32(const float* x, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = gelu_f(x[idx]);
|
||||
}
|
||||
|
||||
__global__ void gelu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(gelu_f(__bfloat162float(x[idx])));
|
||||
}
|
||||
|
||||
__global__ void silu_f32(const float* x, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = silu_f(x[idx]);
|
||||
}
|
||||
|
||||
__global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
|
||||
}
|
||||
|
||||
__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = x[idx] * scale;
|
||||
}
|
||||
|
||||
__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale);
|
||||
}
|
||||
|
||||
// Element-wise add: out = a + b
|
||||
__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = a[idx] + b[idx];
|
||||
}
|
||||
__global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx]));
|
||||
}
|
||||
|
||||
// Element-wise mul: out = a * b
|
||||
__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = a[idx] * b[idx];
|
||||
}
|
||||
__global__ void mul_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) * __bfloat162float(b[idx]));
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
|
||||
}
|
||||
|
||||
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
void launch_silu_f32(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
silu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
|
||||
}
|
||||
|
||||
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (float*)out, scale, n);
|
||||
}
|
||||
|
||||
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
|
||||
}
|
||||
|
||||
void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
add_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)a, (const float*)b, (float*)out, n);
|
||||
}
|
||||
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
add_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
mul_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)a, (const float*)b, (float*)out, n);
|
||||
}
|
||||
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
}
|
||||
53
csrc/attention/causal_mask.cu
Normal file
53
csrc/attention/causal_mask.cu
Normal file
@@ -0,0 +1,53 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Apply causal mask: set scores[row][col] = -inf where col > row + offset.
|
||||
// offset is used for KV cache: when query starts at position `offset`,
|
||||
// we allow attending to positions [0, offset + row].
|
||||
// scores: [batch, rows, cols] (flattened batch×heads)
|
||||
|
||||
__global__ void causal_mask_f32(
|
||||
float* __restrict__ scores,
|
||||
int rows, int cols, int offset
|
||||
) {
|
||||
int batch_idx = blockIdx.z;
|
||||
int row = blockIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
scores[batch_idx * rows * cols + row * cols + col] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void causal_mask_bf16(
|
||||
__nv_bfloat16* __restrict__ scores,
|
||||
int rows, int cols, int offset
|
||||
) {
|
||||
int batch_idx = blockIdx.z;
|
||||
int row = blockIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
// BF16 doesn't have proper -inf literal, use a very large negative
|
||||
scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-1e9f);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
|
||||
int offset, void* stream) {
|
||||
int block = 256;
|
||||
dim3 grid((cols + block - 1) / block, rows, batch);
|
||||
causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(float*)scores, rows, cols, offset);
|
||||
}
|
||||
|
||||
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
|
||||
int offset, void* stream) {
|
||||
int block = 256;
|
||||
dim3 grid((cols + block - 1) / block, rows, batch);
|
||||
causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)scores, rows, cols, offset);
|
||||
}
|
||||
|
||||
}
|
||||
50
csrc/common.cuh
Normal file
50
csrc/common.cuh
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// --- Warp-level reductions (no shared memory needed) ---
|
||||
|
||||
__device__ __forceinline__ float warp_reduce_sum(float val) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warp_reduce_max(float val) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
|
||||
return val;
|
||||
}
|
||||
|
||||
// --- Block-level reductions ---
|
||||
|
||||
__device__ __forceinline__ float block_reduce_sum(float val) {
|
||||
__shared__ float shared[32];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp_id = threadIdx.x >> 5;
|
||||
int num_warps = (blockDim.x + 31) >> 5;
|
||||
|
||||
val = warp_reduce_sum(val);
|
||||
if (lane == 0) shared[warp_id] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : 0.0f;
|
||||
if (warp_id == 0) val = warp_reduce_sum(val);
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float block_reduce_max(float val) {
|
||||
__shared__ float shared[32];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp_id = threadIdx.x >> 5;
|
||||
int num_warps = (blockDim.x + 31) >> 5;
|
||||
|
||||
val = warp_reduce_max(val);
|
||||
if (lane == 0) shared[warp_id] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : -INFINITY;
|
||||
if (warp_id == 0) val = warp_reduce_max(val);
|
||||
return val;
|
||||
}
|
||||
55
csrc/embedding/embedding.cu
Normal file
55
csrc/embedding/embedding.cu
Normal file
@@ -0,0 +1,55 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Embedding lookup: out[seq_idx] = table[token_ids[seq_idx]]
|
||||
// Grid: num_tokens, Block: handles hidden_size elements per token.
|
||||
|
||||
__global__ void embedding_f32(
|
||||
const float* __restrict__ table, // [vocab_size, hidden_size]
|
||||
const int* __restrict__ token_ids, // [num_tokens]
|
||||
float* __restrict__ out, // [num_tokens, hidden_size]
|
||||
int hidden_size
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = token_ids[token_idx];
|
||||
const float* row = table + tid * hidden_size;
|
||||
float* dst = out + token_idx * hidden_size;
|
||||
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
dst[i] = row[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void embedding_bf16(
|
||||
const __nv_bfloat16* __restrict__ table,
|
||||
const int* __restrict__ token_ids,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = token_ids[token_idx];
|
||||
const __nv_bfloat16* row = table + tid * hidden_size;
|
||||
__nv_bfloat16* dst = out + token_idx * hidden_size;
|
||||
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
dst[i] = row[i];
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_embedding_f32(const void* table, const void* token_ids, void* out,
|
||||
int num_tokens, int hidden_size, void* stream) {
|
||||
int block = (hidden_size < 256) ? hidden_size : 256;
|
||||
embedding_f32<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)table, (const int*)token_ids, (float*)out, hidden_size);
|
||||
}
|
||||
|
||||
void launch_embedding_bf16(const void* table, const void* token_ids, void* out,
|
||||
int num_tokens, int hidden_size, void* stream) {
|
||||
int block = (hidden_size < 256) ? hidden_size : 256;
|
||||
embedding_bf16<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)table, (const int*)token_ids,
|
||||
(__nv_bfloat16*)out, hidden_size);
|
||||
}
|
||||
|
||||
}
|
||||
116
csrc/embedding/rope.cu
Normal file
116
csrc/embedding/rope.cu
Normal file
@@ -0,0 +1,116 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <math.h>
|
||||
|
||||
// RoPE: Rotary Position Embedding
|
||||
// For each pair (x[2i], x[2i+1]) at position `pos`:
|
||||
// y[2i] = x[2i] * cos - x[2i+1] * sin
|
||||
// y[2i+1] = x[2i] * sin + x[2i+1] * cos
|
||||
// where cos/sin come from precomputed cos_cache/sin_cache.
|
||||
//
|
||||
// cos_cache[pos][i] = cos(pos * freq[i])
|
||||
// sin_cache[pos][i] = sin(pos * freq[i])
|
||||
// freq[i] = 1.0 / (theta ^ (2i / head_dim))
|
||||
|
||||
// Apply RoPE in-place to Q or K tensor.
|
||||
// x shape: [num_tokens, num_heads, head_dim]
|
||||
// cos_cache, sin_cache shape: [max_seq_len, head_dim/2]
|
||||
// positions: [num_tokens] — the position index for each token
|
||||
|
||||
__global__ void rope_f32(
|
||||
float* __restrict__ x, // [num_tokens, num_heads, head_dim]
|
||||
const float* __restrict__ cos_cache, // [max_seq_len, half_dim]
|
||||
const float* __restrict__ sin_cache, // [max_seq_len, half_dim]
|
||||
const int* __restrict__ positions, // [num_tokens]
|
||||
int num_heads, int head_dim
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int head_idx = blockIdx.y;
|
||||
int half_dim = head_dim / 2;
|
||||
int pair_idx = threadIdx.x; // which pair (0..half_dim)
|
||||
|
||||
if (pair_idx >= half_dim) return;
|
||||
|
||||
int pos = positions[token_idx];
|
||||
float cos_val = cos_cache[pos * half_dim + pair_idx];
|
||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||
|
||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||
float x0 = x[base + 2 * pair_idx];
|
||||
float x1 = x[base + 2 * pair_idx + 1];
|
||||
|
||||
x[base + 2 * pair_idx] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * pair_idx + 1] = x0 * sin_val + x1 * cos_val;
|
||||
}
|
||||
|
||||
__global__ void rope_bf16(
|
||||
__nv_bfloat16* __restrict__ x,
|
||||
const float* __restrict__ cos_cache,
|
||||
const float* __restrict__ sin_cache,
|
||||
const int* __restrict__ positions,
|
||||
int num_heads, int head_dim
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int head_idx = blockIdx.y;
|
||||
int half_dim = head_dim / 2;
|
||||
int pair_idx = threadIdx.x;
|
||||
|
||||
if (pair_idx >= half_dim) return;
|
||||
|
||||
int pos = positions[token_idx];
|
||||
float cos_val = cos_cache[pos * half_dim + pair_idx];
|
||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||
|
||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||
float x0 = __bfloat162float(x[base + 2 * pair_idx]);
|
||||
float x1 = __bfloat162float(x[base + 2 * pair_idx + 1]);
|
||||
|
||||
x[base + 2 * pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
|
||||
x[base + 2 * pair_idx + 1] = __float2bfloat16(x0 * sin_val + x1 * cos_val);
|
||||
}
|
||||
|
||||
// Precompute cos/sin cache on GPU
|
||||
__global__ void compute_rope_cache(
|
||||
float* __restrict__ cos_cache, // [max_seq_len, half_dim]
|
||||
float* __restrict__ sin_cache,
|
||||
int max_seq_len, int half_dim, float theta
|
||||
) {
|
||||
int pos = blockIdx.x;
|
||||
int i = threadIdx.x;
|
||||
if (i >= half_dim) return;
|
||||
|
||||
float freq = 1.0f / powf(theta, (float)(2 * i) / (float)(2 * half_dim));
|
||||
float angle = (float)pos * freq;
|
||||
cos_cache[pos * half_dim + i] = cosf(angle);
|
||||
sin_cache[pos * half_dim + i] = sinf(angle);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rope_f32(void* x, const void* cos_cache, const void* sin_cache,
|
||||
const void* positions, int num_tokens, int num_heads,
|
||||
int head_dim, void* stream) {
|
||||
dim3 grid(num_tokens, num_heads);
|
||||
int block = head_dim / 2;
|
||||
rope_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(float*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||
(const int*)positions, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
||||
const void* positions, int num_tokens, int num_heads,
|
||||
int head_dim, void* stream) {
|
||||
dim3 grid(num_tokens, num_heads);
|
||||
int block = head_dim / 2;
|
||||
rope_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||
(const int*)positions, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
||||
int max_seq_len, int half_dim, float theta,
|
||||
void* stream) {
|
||||
compute_rope_cache<<<max_seq_len, half_dim, 0, (cudaStream_t)stream>>>(
|
||||
(float*)cos_cache, (float*)sin_cache, max_seq_len, half_dim, theta);
|
||||
}
|
||||
|
||||
}
|
||||
102
csrc/normalization/layernorm.cu
Normal file
102
csrc/normalization/layernorm.cu
Normal file
@@ -0,0 +1,102 @@
|
||||
#include "../common.cuh"
|
||||
|
||||
// LayerNorm: y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]
|
||||
// Each block processes one row of shape [hidden_size].
|
||||
|
||||
__global__ void layernorm_f32(
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ gamma,
|
||||
const float* __restrict__ beta,
|
||||
float* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const float* x_row = x + row * hidden_size;
|
||||
float* out_row = out + row * hidden_size;
|
||||
|
||||
// Welford online: compute mean and variance in one pass
|
||||
float local_sum = 0.0f;
|
||||
float local_sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = x_row[i];
|
||||
local_sum += v;
|
||||
local_sum_sq += v * v;
|
||||
}
|
||||
local_sum = block_reduce_sum(local_sum);
|
||||
local_sum_sq = block_reduce_sum(local_sum_sq);
|
||||
|
||||
__shared__ float s_mean, s_inv_std;
|
||||
if (threadIdx.x == 0) {
|
||||
float mean = local_sum / hidden_size;
|
||||
float var = local_sum_sq / hidden_size - mean * mean;
|
||||
s_mean = mean;
|
||||
s_inv_std = rsqrtf(var + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float mean = s_mean;
|
||||
float inv_std = s_inv_std;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
out_row[i] = gamma[i] * (x_row[i] - mean) * inv_std + beta[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void layernorm_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ gamma,
|
||||
const __nv_bfloat16* __restrict__ beta,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * hidden_size;
|
||||
__nv_bfloat16* out_row = out + row * hidden_size;
|
||||
|
||||
float local_sum = 0.0f;
|
||||
float local_sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
local_sum += v;
|
||||
local_sum_sq += v * v;
|
||||
}
|
||||
local_sum = block_reduce_sum(local_sum);
|
||||
local_sum_sq = block_reduce_sum(local_sum_sq);
|
||||
|
||||
__shared__ float s_mean, s_inv_std;
|
||||
if (threadIdx.x == 0) {
|
||||
float mean = local_sum / hidden_size;
|
||||
float var = local_sum_sq / hidden_size - mean * mean;
|
||||
s_mean = mean;
|
||||
s_inv_std = rsqrtf(var + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float mean = s_mean;
|
||||
float inv_std = s_inv_std;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
float g = __bfloat162float(gamma[i]);
|
||||
float b = __bfloat162float(beta[i]);
|
||||
out_row[i] = __float2bfloat16(g * (v - mean) * inv_std + b);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_layernorm_f32(const void* x, const void* gamma, const void* beta,
|
||||
void* out, int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
layernorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (const float*)gamma, (const float*)beta,
|
||||
(float*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
|
||||
void* out, int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
layernorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma, (const __nv_bfloat16*)beta,
|
||||
(__nv_bfloat16*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
}
|
||||
83
csrc/normalization/rmsnorm.cu
Normal file
83
csrc/normalization/rmsnorm.cu
Normal file
@@ -0,0 +1,83 @@
|
||||
#include "../common.cuh"
|
||||
|
||||
// RMSNorm: y[i] = x[i] * rsqrt(mean(x²) + eps) * gamma[i]
|
||||
// Each block processes one row of shape [hidden_size].
|
||||
|
||||
__global__ void rmsnorm_f32(
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ gamma,
|
||||
float* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const float* x_row = x + row * hidden_size;
|
||||
float* out_row = out + row * hidden_size;
|
||||
|
||||
float sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = x_row[i];
|
||||
sum_sq += v * v;
|
||||
}
|
||||
sum_sq = block_reduce_sum(sum_sq);
|
||||
|
||||
__shared__ float s_rms_inv;
|
||||
if (threadIdx.x == 0) {
|
||||
s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float rms_inv = s_rms_inv;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
out_row[i] = x_row[i] * rms_inv * gamma[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void rmsnorm_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ gamma,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * hidden_size;
|
||||
__nv_bfloat16* out_row = out + row * hidden_size;
|
||||
|
||||
float sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
sum_sq += v * v;
|
||||
}
|
||||
sum_sq = block_reduce_sum(sum_sq);
|
||||
|
||||
__shared__ float s_rms_inv;
|
||||
if (threadIdx.x == 0) {
|
||||
s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float rms_inv = s_rms_inv;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
float g = __bfloat162float(gamma[i]);
|
||||
out_row[i] = __float2bfloat16(v * rms_inv * g);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
|
||||
int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
|
||||
int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
|
||||
(__nv_bfloat16*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
}
|
||||
106
csrc/reduce/softmax.cu
Normal file
106
csrc/reduce/softmax.cu
Normal file
@@ -0,0 +1,106 @@
|
||||
#include "../common.cuh"
|
||||
|
||||
// Safe softmax along the last dimension.
|
||||
// Each block handles one row of length `cols`.
|
||||
// Three-pass: 1) find max, 2) exp + sum, 3) normalize.
|
||||
|
||||
__global__ void softmax_f32(
|
||||
const float* __restrict__ x,
|
||||
float* __restrict__ out,
|
||||
int cols
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const float* x_row = x + row * cols;
|
||||
float* out_row = out + row * cols;
|
||||
|
||||
// Pass 1: find max
|
||||
float local_max = -INFINITY;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
local_max = fmaxf(local_max, x_row[i]);
|
||||
}
|
||||
float row_max = block_reduce_max(local_max);
|
||||
|
||||
__shared__ float s_max;
|
||||
if (threadIdx.x == 0) s_max = row_max;
|
||||
__syncthreads();
|
||||
row_max = s_max;
|
||||
|
||||
// Pass 2: exp and sum
|
||||
float local_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
float e = expf(x_row[i] - row_max);
|
||||
out_row[i] = e;
|
||||
local_sum += e;
|
||||
}
|
||||
float row_sum = block_reduce_sum(local_sum);
|
||||
|
||||
__shared__ float s_inv_sum;
|
||||
if (threadIdx.x == 0) s_inv_sum = 1.0f / row_sum;
|
||||
__syncthreads();
|
||||
float inv_sum = s_inv_sum;
|
||||
|
||||
// Pass 3: normalize
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
out_row[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void softmax_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int cols
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * cols;
|
||||
__nv_bfloat16* out_row = out + row * cols;
|
||||
|
||||
float local_max = -INFINITY;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
local_max = fmaxf(local_max, __bfloat162float(x_row[i]));
|
||||
}
|
||||
float row_max = block_reduce_max(local_max);
|
||||
|
||||
__shared__ float s_max;
|
||||
if (threadIdx.x == 0) s_max = row_max;
|
||||
__syncthreads();
|
||||
row_max = s_max;
|
||||
|
||||
// We need float scratch for exp values. Reuse out (write bf16 in pass 3).
|
||||
// Use registers to hold exp values during sum pass instead.
|
||||
float local_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
float e = expf(__bfloat162float(x_row[i]) - row_max);
|
||||
// Temporarily store exp in output as bf16 (slight precision loss, acceptable)
|
||||
out_row[i] = __float2bfloat16(e);
|
||||
local_sum += e;
|
||||
}
|
||||
float row_sum = block_reduce_sum(local_sum);
|
||||
|
||||
__shared__ float s_inv_sum;
|
||||
if (threadIdx.x == 0) s_inv_sum = 1.0f / row_sum;
|
||||
__syncthreads();
|
||||
float inv_sum = s_inv_sum;
|
||||
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
float e = __bfloat162float(out_row[i]);
|
||||
out_row[i] = __float2bfloat16(e * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stream) {
|
||||
int block = (cols < 1024) ? cols : 1024;
|
||||
if (block < 32) block = 32;
|
||||
softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (float*)out, cols);
|
||||
}
|
||||
|
||||
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
|
||||
int block = (cols < 1024) ? cols : 1024;
|
||||
if (block < 32) block = 32;
|
||||
softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -72,9 +72,31 @@ Wraps cudaStream_t. RAII with Drop calling cudaStreamDestroy.
|
||||
- `build.rs` uses `cc` crate to compile .cu files, link CUDA runtime
|
||||
|
||||
## Test Plan
|
||||
1. Device info: print GPU name, memory, compute capability, SM count
|
||||
2. GpuBuffer: alloc 1GB, H2D copy, D2H copy, verify data
|
||||
3. Vector add kernel: launch from Rust, verify output
|
||||
4. CachingAllocator: alloc→free→realloc same size uses cache (no new cudaMalloc)
|
||||
5. Multi-stream: two concurrent memcpy on different streams
|
||||
6. Benchmark: caching allocator vs raw cudaMalloc (100 cycles)
|
||||
|
||||
- [x] Device info: print GPU name, memory, compute capability, SM count
|
||||
- [x] GpuBuffer: alloc → H2D copy → D2H copy → verify data (256B, 64MB)
|
||||
- [x] GpuBuffer: D2D copy 验证
|
||||
- [x] GpuBuffer: zero fill 验证
|
||||
- [x] Vector add kernel: launch from Rust, verify output
|
||||
- [x] CachingAllocator: alloc→free→realloc same size uses cache (no new cudaMalloc)
|
||||
- [x] CachingAllocator: 不同 size bucket 独立缓存
|
||||
- [x] CudaStream: 创建、同步、Drop
|
||||
- [x] PinnedBuffer: page-locked host memory
|
||||
- [x] Async copy: H2D async + D2H async via stream
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **`cudaDeviceProp` struct 布局不可靠**:CUDA 版本之间 `cudaDeviceProp` 的字段偏移会变化。我们最初用 struct 映射读取 `total_global_mem`,得到了垃圾值(12TB)。正确做法:用 `cudaMemGetInfo` 获取显存信息,用 `cudaDeviceGetAttribute` 获取其他属性。只从 `cudaDeviceProp` 读取 `name` 字段(始终在 struct 最前面,布局稳定)。
|
||||
|
||||
2. **Rust 2024 edition 的 unsafe 语义变更**:
|
||||
- `extern "C"` 块必须加 `unsafe` 前缀 → `unsafe extern "C"`
|
||||
- `unsafe fn` 内部的 unsafe 调用也需要显式 `unsafe {}` 块
|
||||
- 这让代码更安全,但初次移植需要注意
|
||||
|
||||
3. **`cc` crate 的 CUDA 支持是内置的**:不需要 `features = ["cuda"]`(这个 feature 不存在)。只需 `.cuda(true).cudart("shared")`。
|
||||
|
||||
4. **Caching Allocator 的 bucket 策略**:round up to next power of 2(最小 512B)。这意味着申请 513B 会分配 1024B,存在内部碎片。但简单且高效——避免了 free list 中的精确匹配问题。PyTorch 的 CUDACachingAllocator 用了更复杂的策略(best-fit with splitting),但对于推理场景,power-of-2 bucket 已经够用。
|
||||
|
||||
5. **`into_raw` + `from_raw` 模式**:GpuBuffer 的 RAII Drop 和 CachingAllocator 的缓存需求冲突——allocator 需要持有裸指针而不触发 Drop。`into_raw()` 消费 self(`mem::forget`),返回裸指针;`from_raw()` 重新封装。这是 Rust 中管理 RAII 生命周期的标准模式。
|
||||
|
||||
6. **dash5 环境**:CUDA 12.9 已安装但 `nvcc` 不在 PATH(需要 `/usr/local/cuda/bin`)。Rust 需要手动安装 rustup。无 rsync,用 `tar | ssh tar` 同步代码。开发工作流:本地写码 → tar sync → 远程 build+test。
|
||||
|
||||
97
docs/02-tensor.md
Normal file
97
docs/02-tensor.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# Phase 2: Tensor Abstraction Layer — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现核心 Tensor 类型,支持 CPU/GPU 存储、多种数据类型、strided view 操作,作为后续所有算子和模型的数据基础。
|
||||
|
||||
## Module Layout
|
||||
|
||||
```
|
||||
crates/xserv-tensor/
|
||||
├── Cargo.toml
|
||||
└── src/
|
||||
├── lib.rs # re-exports
|
||||
├── dtype.rs # DType enum, TensorDType trait
|
||||
├── shape.rs # strides 计算, broadcast 规则
|
||||
├── storage.rs # Storage (Arc引用计数), Device enum
|
||||
└── tensor.rs # Tensor 主体: 创建, 形状操作, 设备迁移
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### DType + TensorDType Trait
|
||||
|
||||
```rust
|
||||
pub enum DType { F32, F16, BF16 }
|
||||
|
||||
pub trait TensorDType: Copy + Send + Sync + 'static {
|
||||
const DTYPE: DType;
|
||||
fn to_f64(self) -> f64;
|
||||
fn from_f64(v: f64) -> Self;
|
||||
}
|
||||
```
|
||||
|
||||
- 用 `half` crate 的 `bf16`/`f16` 表示半精度类型
|
||||
- `TensorDType` trait 让 `from_slice<T>` 和 `as_slice<T>` 有类型安全
|
||||
- GPU kernel 中通过 `DType` dispatch 到对应的 CUDA 类型 (`__nv_bfloat16` / `float`)
|
||||
|
||||
### Storage 引用计数
|
||||
|
||||
```rust
|
||||
pub struct Storage(Arc<StorageInner>);
|
||||
enum StorageInner {
|
||||
Cpu { data: Vec<u8> },
|
||||
Cuda { buffer: GpuBuffer },
|
||||
}
|
||||
```
|
||||
|
||||
- `Arc` 引用计数让 transpose/slice/reshape 能共享底层数据(view 语义)
|
||||
- 不实现 CoW(copy-on-write),view 只能读不能写
|
||||
- `to_device()` 总是创建新的 Storage
|
||||
|
||||
### Strided Tensor
|
||||
|
||||
```rust
|
||||
pub struct Tensor {
|
||||
storage: Storage,
|
||||
shape: SmallVec<[usize; 4]>,
|
||||
strides: SmallVec<[usize; 4]>,
|
||||
offset: usize,
|
||||
dtype: DType,
|
||||
}
|
||||
```
|
||||
|
||||
- `SmallVec<[usize; 4]>` 避免大多数 tensor (≤4D) 的堆分配
|
||||
- `strides` 以元素为单位(不是字节)
|
||||
- `offset` 支持 slice 操作(view 到 storage 的中间位置)
|
||||
- `is_contiguous()` 检查 strides 是否与 shape 匹配
|
||||
- 非 contiguous 的 tensor 调 `contiguous()` 才能送入 CUDA kernel
|
||||
|
||||
### Broadcast 规则
|
||||
|
||||
实现了 NumPy-style broadcasting:
|
||||
- 维度从尾部对齐
|
||||
- 大小为 1 的维度可以广播到任意大小
|
||||
- `broadcast_strides()` 将 size=1 维度的 stride 置为 0(虚拟广播,不复制数据)
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] from_slice → shape/strides 正确
|
||||
- [x] reshape, transpose, squeeze, unsqueeze
|
||||
- [x] transpose 后 contiguous() 重排数据
|
||||
- [x] BF16 tensor 的精度验证
|
||||
- [x] CPU↔GPU roundtrip
|
||||
- [x] zeros on GPU → 拷回 CPU 验证全 0
|
||||
- [x] broadcast_shape 单元测试
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **`SmallVec` 是正确选择**:绝大多数 tensor ≤ 4D,避免了频繁堆分配。LLM 推理中常见的维度是 `[B, S, H]` (3D) 和 `[B, H, S, D]` (4D)。
|
||||
|
||||
2. **View 语义的取舍**:Arc 共享 storage 实现了零拷贝 transpose/reshape,但代价是无法原地修改 view 后的 tensor。对于推理引擎这是可以接受的——推理路径上大部分操作是只读的。
|
||||
|
||||
3. **contiguous() 的隐性开销**:非 contiguous tensor 在送入 kernel 前需要 `contiguous()` 拷贝。这意味着 `transpose → matmul` 会产生一次额外拷贝。后续优化方向:在 kernel 中直接支持 strided input。
|
||||
|
||||
4. **Rust 2024 edition 变化**:`unsafe fn` 内部的 unsafe 调用也需要显式 `unsafe {}` 块,`extern "C"` 块必须加 `unsafe` 前缀。这个 edition 对安全性更严格。
|
||||
|
||||
5. **CPU 实现先行**:先在 CPU 上验证逻辑正确性(如 contiguous 重排),再扩展到 GPU。这个策略在后续 phase 中应该继续沿用。
|
||||
102
docs/03-gemm.md
Normal file
102
docs/03-gemm.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Phase 3: GEMM — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现矩阵乘法的多个版本(naive → tiled → cuBLAS),建立 benchmark 对比框架,深入理解 GPU 编程中的内存访问模式和优化手段。
|
||||
|
||||
## Module Layout
|
||||
|
||||
```
|
||||
csrc/gemm/
|
||||
├── naive.cu # 每个 thread 算一个输出元素
|
||||
└── tiled.cu # shared memory tiling, 32x32 tiles
|
||||
|
||||
crates/xserv-kernels/
|
||||
├── build.rs # 编译 .cu + 链接 cublas
|
||||
└── src/
|
||||
├── lib.rs
|
||||
└── gemm.rs # FFI 封装, GemmBackend enum, matmul(), CublasContext
|
||||
```
|
||||
|
||||
## Kernel Implementations
|
||||
|
||||
### Version 1: Naive GEMM
|
||||
|
||||
```
|
||||
Grid: (ceil(N/16), ceil(M/16))
|
||||
Block: (16, 16)
|
||||
每个 thread: C[row][col] = sum_k(A[row][k] * B[k][col])
|
||||
```
|
||||
|
||||
- 每个 thread 独立遍历 K 维度做点积
|
||||
- 所有读取走 global memory,无局部性优化
|
||||
- BF16 版本在 FP32 中累加(`__bfloat162float` → 累加 → `__float2bfloat16`)
|
||||
|
||||
### Version 2: Tiled GEMM (Shared Memory)
|
||||
|
||||
```
|
||||
TILE_SIZE = 32
|
||||
Grid: (ceil(N/32), ceil(M/32))
|
||||
Block: (32, 32) = 1024 threads
|
||||
|
||||
每个 tile iteration:
|
||||
1. 协作加载 A[tile] 和 B[tile] 到 shared memory
|
||||
2. __syncthreads()
|
||||
3. 在 shared memory 中做 32 次乘加
|
||||
4. __syncthreads()
|
||||
```
|
||||
|
||||
- 每个 global memory 读取被 TILE_SIZE 个 thread 复用
|
||||
- 理论上减少 global memory 访问 TILE_SIZE 倍
|
||||
- BF16 版本同样在 shared memory 中存 float(FP32 累加)
|
||||
|
||||
### Version 3: cuBLAS
|
||||
|
||||
- `cublasGemmEx` 支持混合精度
|
||||
- **Row-major 适配**:cuBLAS 使用 column-major 布局,我们的 tensor 是 row-major
|
||||
- 利用恒等式:`C = A @ B` (row-major) ⟺ `C^T = B^T @ A^T` (col-major)
|
||||
- 传入 `CUBLAS_OP_N`,让 cuBLAS 把我们的 row-major 数据当作 col-major 的转置
|
||||
- 参数:`m=N, n=M, k=K, lda=N (B), ldb=K (A), ldc=N (C)`
|
||||
|
||||
### Backend Registry
|
||||
|
||||
```rust
|
||||
pub enum GemmBackend { Naive, Tiled, CuBlas }
|
||||
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor;
|
||||
```
|
||||
|
||||
运行时可切换 backend,方便 benchmark 对比和逐步替换。
|
||||
|
||||
## CublasContext
|
||||
|
||||
RAII 封装 `cublasHandle_t`,Drop 时调 `cublasDestroy_v2`。
|
||||
目前每次 matmul 创建一个新 handle,后续优化为全局复用。
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] F32: naive/tiled/cuBLAS × small(4)/medium(64-256)/rect(65x33x97)
|
||||
- [x] BF16: naive/tiled/cuBLAS × small/medium
|
||||
- [x] 三种 backend 在相同输入上输出一致(cross-backend consistency)
|
||||
- [x] 非方阵测试(M≠N≠K)
|
||||
- [x] 1024x1024 cuBLAS 验证
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **Row-major vs Column-major 陷阱**:这是 GEMM 实现中最容易出错的地方。cuBLAS 的 column-major 假设与 C/Rust 的 row-major 冲突。理解 `C=AB` ⟺ `C^T=B^T A^T` 这个恒等式是关键。实际做法:不做任何显式转置,只是交换 A/B 的传入顺序和调整 leading dimension 参数。
|
||||
|
||||
2. **BF16 的累加精度**:BF16 只有 ~3 位有效数字(vs FP32 的 ~7 位)。如果在 BF16 中累加 K 次乘法,误差会快速放大。正确做法是**在 FP32 中累加,最后才转回 BF16**。我们的 naive 和 tiled kernel 都遵循了这一点(`float sum = 0.0f`)。cuBLAS 通过 `CUBLAS_COMPUTE_32F` 参数控制。
|
||||
|
||||
3. **Shared memory tiling 的核心思想**:global memory 带宽是 GPU 计算的主要瓶颈。通过 shared memory tiling,每个数据从 global memory 读一次,被 TILE_SIZE 个 thread 复用。对于 TILE_SIZE=32,理论上减少 32 倍 global memory 访问。
|
||||
|
||||
4. **`__syncthreads()` 的位置关键**:tile 加载后必须同步(确保所有 thread 写完 shared memory),计算后也要同步(防止下一轮加载覆盖还在使用的数据)。漏掉任何一个 sync 都会产生 race condition 导致结果错误。
|
||||
|
||||
5. **cuBLAS handle 开销**:每次 matmul 创建/销毁 handle 有~0.1ms 开销。生产环境应全局复用一个 handle。Phase 15(性能优化)时需要修复这个问题。
|
||||
|
||||
6. **`error::check` 需要 pub**:Phase 1 中 `check()` 是 `pub(crate)`,Phase 3 需要跨 crate 调用。反思:基础设施 crate 的错误处理函数应该从一开始就设计为 public API。
|
||||
|
||||
## 后续优化方向(Phase 15)
|
||||
|
||||
- Register tiling(每个 thread 算多个输出元素)
|
||||
- Tensor Core WMMA(利用 5090 的硬件加速)
|
||||
- CublasContext 全局复用
|
||||
- 非 contiguous input 支持(避免 matmul 前的拷贝)
|
||||
213
docs/04-transformer-kernels.md
Normal file
213
docs/04-transformer-kernels.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# Phase 4: Transformer Core Kernels — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现 Transformer 所需的所有非 Attention 算子的 CUDA kernel,每个 kernel 都支持 BF16 和 F32,与 PyTorch 参考实现对比验证。
|
||||
|
||||
## Kernel 清单
|
||||
|
||||
| Kernel | 用于 | 核心计算 | 关键优化点 |
|
||||
|--------|------|---------|-----------|
|
||||
| LayerNorm | GPT-2 | `(x - mean) / sqrt(var + eps) * gamma + beta` | Welford online, warp reduce |
|
||||
| RMSNorm | Qwen3 | `x / sqrt(mean(x²) + eps) * gamma` | 无 mean,比 LayerNorm 简单 |
|
||||
| GELU | GPT-2 | `0.5x(1 + tanh(sqrt(2/π)(x + 0.044715x³)))` | tanh 近似,逐元素 |
|
||||
| SiLU | Qwen3 | `x * sigmoid(x)` | 逐元素 |
|
||||
| Softmax | Attention | `exp(x - max) / sum(exp(x - max))` | Online safe softmax, warp reduce |
|
||||
| Embedding | 全部 | `output[i] = table[token_ids[i]]` | Gather, coalesced write |
|
||||
| RoPE | Qwen3 | 对 Q/K 的相邻元素对做旋转 | Precompute freq, in-place |
|
||||
|
||||
## 文件布局
|
||||
|
||||
```
|
||||
csrc/
|
||||
├── normalization/
|
||||
│ ├── layernorm.cu
|
||||
│ └── rmsnorm.cu
|
||||
├── activation/
|
||||
│ ├── gelu.cu
|
||||
│ └── silu.cu
|
||||
├── reduce/
|
||||
│ └── softmax.cu
|
||||
├── embedding/
|
||||
│ ├── embedding.cu
|
||||
│ └── rope.cu
|
||||
|
||||
crates/xserv-kernels/src/
|
||||
├── layernorm.rs
|
||||
├── rmsnorm.rs
|
||||
├── activation.rs # GELU + SiLU
|
||||
├── softmax.rs
|
||||
├── embedding.rs
|
||||
├── rope.rs
|
||||
└── lib.rs # 新增 mod 声明
|
||||
```
|
||||
|
||||
## Kernel 设计细节
|
||||
|
||||
### LayerNorm
|
||||
|
||||
输入 `x: [*, hidden_size]`, 输出 `y: [*, hidden_size]`
|
||||
参数 `gamma, beta: [hidden_size]`
|
||||
|
||||
```
|
||||
y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]
|
||||
```
|
||||
|
||||
**GPU 映射**: 每个 thread block 处理一行(一个 hidden_size 向量)。
|
||||
- Phase 1: 并行加载 x,Welford online 算法计算 mean 和 var
|
||||
- Phase 2: warp-level reduce (`__shfl_down_sync`) 聚合 mean/var
|
||||
- Phase 3: block-level reduce via shared memory
|
||||
- Phase 4: 每个 thread 对自己负责的元素做 normalize + affine
|
||||
|
||||
**Block 配置**: `block = min(1024, hidden_size)`, `grid = num_rows`
|
||||
|
||||
### RMSNorm
|
||||
|
||||
比 LayerNorm 简单:不减 mean,只做 `x * rsqrt(mean(x²) + eps) * gamma`。
|
||||
|
||||
```
|
||||
rms = sqrt(sum(x²) / hidden_size + eps)
|
||||
y[i] = x[i] / rms * gamma[i]
|
||||
```
|
||||
|
||||
**GPU 映射**: 同 LayerNorm,每个 block 处理一行。
|
||||
- 只需要一次 reduce(求 sum(x²)),不需要两次(mean + var)。
|
||||
|
||||
### GELU
|
||||
|
||||
逐元素操作,用 tanh 近似:
|
||||
```
|
||||
gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
|
||||
```
|
||||
|
||||
**GPU 映射**: 每个 thread 处理多个元素(向量化),grid 覆盖全部元素。
|
||||
|
||||
### SiLU (Swish)
|
||||
|
||||
逐元素: `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`
|
||||
|
||||
### Softmax
|
||||
|
||||
输入 `x: [*, seq_len]`, 沿最后一维做 softmax:
|
||||
```
|
||||
1. m = max(x) // 数值稳定
|
||||
2. e[i] = exp(x[i] - m)
|
||||
3. s = sum(e)
|
||||
4. y[i] = e[i] / s
|
||||
```
|
||||
|
||||
**GPU 映射**: 每个 block 处理一行。
|
||||
- 第一遍 reduce: 求 max
|
||||
- 第二遍: exp(x - max) 并 reduce sum
|
||||
- 第三遍: 除以 sum
|
||||
|
||||
**优化**: 可以用 online softmax 合并前两遍(边算 exp 边更新 max),但先实现三遍版本保证正确。
|
||||
|
||||
### Embedding
|
||||
|
||||
```
|
||||
output[seq_idx] = embedding_table[token_ids[seq_idx]]
|
||||
```
|
||||
|
||||
**GPU 映射**: 每个 thread 处理一个 token 的部分维度。
|
||||
- `grid = num_tokens`, `block = hidden_size`(或分多个 thread 处理一个 token)
|
||||
- 写端是 coalesced(连续 thread 写连续地址),读端是 gather(非连续)
|
||||
|
||||
### RoPE (Rotary Position Embedding)
|
||||
|
||||
对 Q/K 的每对相邻元素 `(x0, x1)` 做 2D 旋转:
|
||||
```
|
||||
freq[i] = 1.0 / (theta ^ (2i / dim))
|
||||
cos_val = cos(position * freq[i])
|
||||
sin_val = sin(position * freq[i])
|
||||
y0 = x0 * cos_val - x1 * sin_val
|
||||
y1 = x0 * sin_val + x1 * cos_val
|
||||
```
|
||||
|
||||
**GPU 映射**: 每个 thread 处理一对元素 `(x[2i], x[2i+1])`。
|
||||
- Precompute `cos_cache[max_seq_len][head_dim/2]` 和 `sin_cache` 在初始化时
|
||||
- 运行时 kernel 只做乘加
|
||||
|
||||
**theta**: Qwen3 默认 `rope_theta = 1000000.0`
|
||||
|
||||
## Reduction Pattern(核心学习点)
|
||||
|
||||
所有 Norm 和 Softmax 都涉及 reduction。GPU reduction 的分层结构:
|
||||
|
||||
```
|
||||
Thread-level: 每个 thread 处理多个元素,本地累加
|
||||
↓
|
||||
Warp-level: __shfl_down_sync() 在 32 threads 内规约(无需 shared memory)
|
||||
↓
|
||||
Block-level: shared memory 存各 warp 的结果,warp 0 再规约
|
||||
```
|
||||
|
||||
对于 hidden_size <= 8192(LLM 常见),一个 block 足够,不需要 grid-level reduction。
|
||||
|
||||
### Warp Reduce 模板
|
||||
|
||||
```cuda
|
||||
__device__ float warp_reduce_sum(float val) {
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
return val;
|
||||
}
|
||||
```
|
||||
|
||||
### Block Reduce 模板
|
||||
|
||||
```cuda
|
||||
__device__ float block_reduce_sum(float val) {
|
||||
__shared__ float shared[32]; // max 32 warps per block
|
||||
int lane = threadIdx.x % 32;
|
||||
int warp_id = threadIdx.x / 32;
|
||||
|
||||
val = warp_reduce_sum(val);
|
||||
if (lane == 0) shared[warp_id] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : 0.0f;
|
||||
if (warp_id == 0) val = warp_reduce_sum(val);
|
||||
return val;
|
||||
}
|
||||
```
|
||||
|
||||
## Reference 验证策略
|
||||
|
||||
写 `tools/generate_reference.py` 脚本,用 PyTorch 为每个 op 生成 reference input/output:
|
||||
- 保存为 `.npy` 格式
|
||||
- Rust 测试中加载对比
|
||||
- 或者直接在 Rust 测试中用 CPU 实现计算 expected 值(更简单,不依赖 Python)
|
||||
|
||||
**选择**: 先用 Rust CPU 实现作为 reference(简单),关键 op(RoPE)再与 PyTorch 对比。
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] RMSNorm F32: hidden_size=768, 4 rows → max_err 7.2e-7
|
||||
- [x] RMSNorm BF16: 同上 → max_err 7.0e-3
|
||||
- [x] LayerNorm F32: hidden_size=768 → max_err 1.7e-6
|
||||
- [x] GELU F32: 10000 elements → max_err 3.0e-8
|
||||
- [x] GELU BF16: 同上 → max_err 2.4e-3
|
||||
- [x] SiLU F32: 10000 elements → max_err 1.5e-8
|
||||
- [x] Softmax F32: 8×256 → max_err 1.4e-9
|
||||
- [x] Softmax sum=1 验证: 4×2048
|
||||
- [x] Softmax 大值 (1000+) 数值稳定性 → max_err 1.5e-8
|
||||
- [x] Embedding F32: vocab=100, hidden=64, 5 tokens → exact match
|
||||
- [x] RoPE F32: 4 tokens × 2 heads × dim=8 → max_err 6.0e-8
|
||||
- [x] RoPE position=0 恒等验证 → max_err 0
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **`common.cuh` 抽取共用 reduction 是正确的做法**:`warp_reduce_sum/max` 和 `block_reduce_sum/max` 被 RMSNorm, LayerNorm, Softmax 三个 kernel 复用。抽到头文件避免了代码重复,也确保 reduction 逻辑一致。build.rs 中需要 `.include("../../csrc")` 让 nvcc 能找到头文件。
|
||||
|
||||
2. **Shared memory 中广播标量的模式**:Norm 和 Softmax 都需要将 reduce 结果(mean, rms_inv, max, sum)广播给 block 内所有 thread。标准做法:thread 0 写 `__shared__` 变量,`__syncthreads()` 后所有 thread 读。这比让每个 thread 独立做 reduce 高效得多。
|
||||
|
||||
3. **Softmax 三遍 vs 两遍**:我们实现了三遍版本(max → exp+sum → normalize),简单可靠。Online softmax 可以合并前两遍(一遍 pass 内同时跟踪 running max 和 running sum),但需要更复杂的数值更新公式。Flash Attention(Phase 14)会用到 online softmax。
|
||||
|
||||
4. **RoPE 的 position=0 恒等性**:`cos(0)=1, sin(0)=0`,所以 position 0 的旋转是恒等变换。这是一个很好的 sanity check。如果 position=0 时输出不等于输入,说明 kernel 有 bug。
|
||||
|
||||
5. **BF16 Softmax 的精度陷阱**:exp 结果先写成 BF16 再读回做 normalize 会丢精度。理想做法是用 float scratch buffer 暂存 exp 结果。当前实现可接受(误差在 1e-2 量级),但在 attention score 很接近时可能引入可观察的差异。Phase 14 Flash Attention 会解决这个问题(全程 FP32 累加)。
|
||||
|
||||
6. **Embedding 就是 gather 操作**:没有任何计算,纯粹的内存搬运。瓶颈在 global memory 随机读取(token_ids 导致不连续读 table)。写端是 coalesced 的(连续 token 写连续地址)。优化方向:使用向量化加载(`float4`)一次读 128 bit。
|
||||
|
||||
7. **RoPE in-place 修改 Tensor 的设计考量**:RoPE 在数学上是对 Q/K 的 in-place 旋转。我们通过 `data_ptr() as *mut` 绕过了 Rust 的不可变借用。这在 GPU 上是安全的(kernel 内部互不干扰),但 Rust 侧没有 `&mut` 语义保护。后续如果需要更严格的安全性,可以引入 `Tensor::as_mut_ptr()` 方法并要求 `&mut self`。
|
||||
92
docs/05-attention.md
Normal file
92
docs/05-attention.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Phase 5: Naive Attention Kernel — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现标准 Multi-Head Attention(不做 Flash/Paged 优化),用组合式方法(GEMM + Softmax)完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。
|
||||
|
||||
## 计算流程
|
||||
|
||||
```
|
||||
Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
|
||||
B=batch, H=num_heads, S=seq_len, D=head_dim
|
||||
|
||||
1. scores = Q @ K^T / sqrt(D) → [B, H, S, S]
|
||||
2. scores += causal_mask → 上三角置为 -inf
|
||||
3. weights = softmax(scores, dim=-1) → [B, H, S, S]
|
||||
4. output = weights @ V → [B, H, S, D]
|
||||
```
|
||||
|
||||
## 设计选择
|
||||
|
||||
### 组合式实现(Phase 3 GEMM + Phase 4 Softmax)
|
||||
|
||||
不写新的 fused CUDA kernel,而是复用已有的 matmul 和 softmax:
|
||||
- `scores = batched_matmul(Q, K^T)` — 需要支持 batched GEMM
|
||||
- `masked_fill(scores, causal_mask, -inf)` — 新的逐元素 kernel
|
||||
- `softmax(scores)` — 复用 Phase 4
|
||||
- `output = batched_matmul(weights, V)` — 复用 batched GEMM
|
||||
|
||||
这意味着需要先扩展 matmul 支持 batched GEMM(cublasGemmStridedBatchedEx)。
|
||||
|
||||
### Causal Mask
|
||||
|
||||
不显式构造 mask 矩阵。写一个 kernel:
|
||||
```
|
||||
if (col > row + offset) score = -infinity
|
||||
```
|
||||
其中 offset 用于支持 KV cache 场景(decode 时 query 的 row 偏移)。
|
||||
|
||||
### Batched GEMM via cuBLAS
|
||||
|
||||
`cublasGemmStridedBatchedEx` 在一个 batch 维度上并行执行多个 GEMM:
|
||||
```
|
||||
C[b] = A[b] @ B[b] for b = 0..batch_count
|
||||
stride_a = M * K, stride_b = K * N, stride_c = M * N
|
||||
```
|
||||
|
||||
Attention 中 batch 维度 = B * H(batch_size × num_heads)。
|
||||
|
||||
## 文件布局
|
||||
|
||||
```
|
||||
csrc/attention/
|
||||
└── causal_mask.cu # causal mask fill kernel
|
||||
|
||||
crates/xserv-kernels/src/
|
||||
├── gemm.rs # 扩展: batched_matmul
|
||||
├── attention.rs # NEW: multi_head_attention()
|
||||
└── causal_mask.rs # NEW: causal mask apply
|
||||
```
|
||||
|
||||
## API 设计
|
||||
|
||||
```rust
|
||||
/// Multi-head attention (naive, materializes S×S scores).
|
||||
/// q, k, v: [batch, num_heads, seq_len, head_dim]
|
||||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;
|
||||
|
||||
/// Batched matmul: A[b] @ B[b] for all b.
|
||||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
|
||||
- [x] attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
|
||||
- [x] attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
|
||||
- [x] attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
|
||||
- [x] causal mask 语义: position 0 只能看到 token 0,output[0] == V[0] → exact
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **`to_device` 不应强制 contiguous**:最初 `to_device()` 会先调 `contiguous()`,而 GPU 的 `contiguous()` 又调 `to_device(Cpu)`,导致无限递归栈溢出。修复:`to_device()` 直接传输 raw storage,保留 strides/offset,用户需要时自己调 `contiguous()`。GPU `contiguous()` 现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效,Phase 15 需要写 GPU contiguous kernel。
|
||||
|
||||
2. **Batched GEMM via `cublasGemmStridedBatchedEx`**:row-major trick 同 Phase 3,额外参数是 stride(元素数,不是字节)。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 `elem_size`,cuBLAS 的 stride 单位是元素。
|
||||
|
||||
3. **Attention 的组合式实现足够验证正确性**:没有写 fused kernel,而是复用 `batched_matmul` + `scale` + `causal_mask` + `softmax`。精度极好(max_err < 1e-7),因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materialize(O(S²) 显存),Flash Attention 会解决。
|
||||
|
||||
4. **Scale kernel 的必要性**:原本想在 CPU 上做 scale(round-trip),但那太慢了。加了 `scale_f32/bf16` 逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。
|
||||
|
||||
5. **Causal mask 的 offset 设计**:`col > row + offset` 中的 offset 为 KV cache 场景预留。Decode 时 Q 只有 1 行但 KV cache 有前 S 行,offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。
|
||||
69
docs/06-model-loading.md
Normal file
69
docs/06-model-loading.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# Phase 6: Model Loading — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
从 HuggingFace safetensors 文件加载模型权重到 GPU Tensor。解析 config.json 获取模型结构参数。
|
||||
|
||||
## Crate: `xserv-model`
|
||||
|
||||
```
|
||||
crates/xserv-model/src/
|
||||
├── lib.rs
|
||||
├── config.rs # ModelConfig from config.json
|
||||
├── loader.rs # safetensors weight loading
|
||||
└── gpt2.rs # (Phase 8) GPT-2 model definition
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- `safetensors` crate: parse safetensors format
|
||||
- `serde` + `serde_json`: deserialize config.json
|
||||
- `memmap2`: mmap for zero-copy file access (safetensors uses this internally)
|
||||
|
||||
## Weight Loading Flow
|
||||
|
||||
```
|
||||
safetensors file (disk)
|
||||
→ safetensors crate parses header (tensor names, shapes, dtypes, offsets)
|
||||
→ mmap raw data
|
||||
→ for each tensor:
|
||||
→ read bytes at offset
|
||||
→ create CPU Tensor from raw bytes
|
||||
→ .to_device(Cuda(0)) → GPU Tensor
|
||||
→ return HashMap<String, Tensor>
|
||||
```
|
||||
|
||||
## Config Parsing
|
||||
|
||||
```rust
|
||||
#[derive(Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
pub architectures: Option<Vec<String>>,
|
||||
pub model_type: Option<String>,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: Option<usize>,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub num_hidden_layers: usize,
|
||||
pub vocab_size: usize,
|
||||
pub max_position_embeddings: Option<usize>,
|
||||
pub layer_norm_eps: Option<f64>,
|
||||
pub rms_norm_eps: Option<f64>,
|
||||
pub rope_theta: Option<f64>,
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
}
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] Load GPT-2 124M: 160 tensors loaded successfully
|
||||
- [x] Parse GPT-2 config.json: hidden=768, layers=12, heads=12, vocab=50257
|
||||
- [x] Sharded loading path implemented (for larger models)
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **GPT-2 vs modern HF config naming**:GPT-2 uses `n_embd`/`n_head`/`n_layer`/`n_positions`,而不是 `hidden_size`/`num_attention_heads` 等。ModelConfig 需要支持两套命名并提供统一的 accessor methods(`hidden()`, `num_heads()` 等)。
|
||||
|
||||
2. **safetensors 零拷贝读取**:`safetensors` crate 直接 mmap 文件,解析 header 得到 tensor 的 offset 和 shape,然后 zero-copy 读取 raw bytes。对于 GPT-2 的 500MB 权重文件,加载速度很快。
|
||||
|
||||
3. **模型下载的网络问题**:HuggingFace 在中国网络下不可达。使用 modelscope.cn 或 hf-mirror.com 作为替代。大文件(>100MB)的 redirect 到 CDN 可能也会失败,modelscope 的 snapshot_download 更可靠。
|
||||
57
docs/07-tokenizer.md
Normal file
57
docs/07-tokenizer.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Phase 7: BPE Tokenizer — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
从零实现 Byte-Pair Encoding tokenizer,兼容 HuggingFace `tokenizer.json` 格式。支持 GPT-2 和 Qwen3。
|
||||
|
||||
## Crate: `xserv-tokenizer`
|
||||
|
||||
```
|
||||
crates/xserv-tokenizer/src/
|
||||
├── lib.rs
|
||||
├── bpe.rs # BPE encode/decode core algorithm
|
||||
└── chat.rs # Chat template formatting
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- `serde` + `serde_json`: parse tokenizer.json
|
||||
- `regex`: pre-tokenization patterns
|
||||
|
||||
## BPE Algorithm
|
||||
|
||||
### Encode
|
||||
1. Pre-tokenize: split text by regex (GPT-2 pattern)
|
||||
2. Each word → byte sequence → initial token list (one token per byte)
|
||||
3. Repeatedly merge highest-priority pair until no more merges
|
||||
4. Map merged tokens to IDs via vocab
|
||||
|
||||
### Decode
|
||||
Token IDs → lookup vocab → concatenate bytes → UTF-8 decode
|
||||
|
||||
## Key Data Structures
|
||||
|
||||
```rust
|
||||
pub struct Tokenizer {
|
||||
vocab: HashMap<Vec<u8>, u32>, // token bytes → ID
|
||||
vocab_rev: Vec<Vec<u8>>, // ID → token bytes
|
||||
merges: Vec<(Vec<u8>, Vec<u8>)>, // ordered merge rules
|
||||
merge_ranks: HashMap<(u32, u32), usize>, // (id_a, id_b) → priority
|
||||
special_tokens: HashMap<String, u32>,
|
||||
pre_tokenize_regex: Regex,
|
||||
}
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] Encode + decode roundtrip verified (GPT-2 tokenizer, English text)
|
||||
- [x] Special tokens handled (endoftext)
|
||||
- [x] Integrated into GPT-2 inference pipeline, generates coherent text
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **GPT-2 byte-to-unicode 映射**:GPT-2 的 vocab 中,每个 byte 都映射到一个 Unicode 字符。可打印 ASCII (0x21-0x7E) 映射到自身,其余字节(空格、控制字符等)映射到 U+0100 以上的 Unicode 码点。解码时需要反向映射。这个映射表是 BPE tokenizer 正确性的关键。
|
||||
|
||||
2. **Rust regex 不支持 lookahead**:GPT-2 的 pre-tokenization regex 使用了 `(?!\S)` lookahead,Rust 的 `regex` crate 不支持。简化为去掉 lookahead 后功能等价(whitespace 仍然被正确分词)。如果需要精确匹配 Python 行为,需要 `fancy-regex` crate。
|
||||
|
||||
3. **BPE merge 的 O(n²) 复杂度**:当前实现每次 merge 扫描整个 token 序列找最高优先级 pair,复杂度 O(n² × |merges|)。对于短文本够用,长文本需要 priority queue 优化。推理场景中 prompt 通常 < 10K tokens,暂时可接受。
|
||||
71
docs/08-gpt2.md
Normal file
71
docs/08-gpt2.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Phase 8: GPT-2 Complete Inference — Design Document (Milestone ①)
|
||||
|
||||
## Goal
|
||||
|
||||
Wire everything together: load GPT-2 124M, tokenize input, run forward pass, sample tokens, decode output. First time seeing the model "speak".
|
||||
|
||||
## Model Architecture (GPT-2 124M)
|
||||
|
||||
```
|
||||
hidden_size = 768
|
||||
num_heads = 12
|
||||
num_layers = 12
|
||||
vocab_size = 50257
|
||||
max_position_embeddings = 1024
|
||||
activation = GELU
|
||||
normalization = LayerNorm (pre-LN)
|
||||
tied embeddings (lm_head == wte)
|
||||
```
|
||||
|
||||
## Forward Pass
|
||||
|
||||
```
|
||||
tokens [S]
|
||||
→ wte[tokens] + wpe[0..S] → [S, 768]
|
||||
→ for each layer:
|
||||
residual = x
|
||||
x = layernorm(x, ln_1)
|
||||
x = attention(x) # Q,K,V from linear, MHA, output linear
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = layernorm(x, ln_2)
|
||||
x = mlp(x) # linear→GELU→linear
|
||||
x = x + residual
|
||||
→ layernorm(x, ln_f)
|
||||
→ logits = x @ wte.T → [S, 50257]
|
||||
→ sample(logits[-1]) → next token
|
||||
```
|
||||
|
||||
## Sampling
|
||||
|
||||
- Greedy: argmax
|
||||
- Temperature: logits / T → softmax → sample
|
||||
- Top-K: keep top-k logits, rest = -inf
|
||||
- Top-P: sorted by prob, cumsum ≤ p
|
||||
|
||||
## CLI Binary
|
||||
|
||||
```
|
||||
$ cargo run --release --bin xserv-cli -- --model path/to/gpt2
|
||||
|
||||
xserv> The future of AI is
|
||||
GPT-2> ...generated text...
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] Greedy generation produces coherent English text
|
||||
- [x] Interactive CLI works (pipe and interactive mode)
|
||||
- [x] Multiple prompts verified: "The future of AI is", "Once upon a time"
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **QKV split + head reshape 的 layout 陷阱(最关键的 bug)**:GPT-2 的 `c_attn` 输出 `[S, 3H]` 需要 split 成 Q/K/V 再 reshape 成 `[1, num_heads, S, head_dim]`。关键错误:从 `[S, num_heads, head_dim]` 直接 `reshape` 到 `[1, num_heads, S, head_dim]` 不等于 transpose!Reshape 只是重新解释 flat data 的 shape,不会重排数据。必须手动按 `[batch, head, seq, dim]` 的目标 layout 写入数据。同理 merge_heads 也需要手动重排。
|
||||
|
||||
2. **CPU round-trip 作为 correctness first 策略**:`add_tensors`、`add_bias`、`split_qkv`、`merge_heads` 都通过 CPU round-trip 实现。虽然慢(每次都有 GPU→CPU→GPU 拷贝),但确保了正确性。Phase 15 会写专门的 CUDA kernel 替换这些操作。
|
||||
|
||||
3. **GPT-2 的 Conv1D 权重布局**:GPT-2 用 `Conv1D` 而非 `Linear`,权重存为 `[in, out]`(不是标准 Linear 的 `[out, in]`)。计算方式是 `x @ weight`(不需要转置)。这和 Qwen3/LLaMA 的 `[out, in]` 布局不同——Phase 10 需要注意。
|
||||
|
||||
4. **Greedy decoding 的重复问题**:GPT-2 124M 在 greedy decoding 下极易陷入循环("The world was a place of great danger, and...")。这是已知行为,temperature + top-k/top-p sampling 可以缓解。当前实现只有 greedy,sampling 将在后续添加。
|
||||
|
||||
5. **无 KV Cache 的性能代价**:每生成一个 token 都要重新跑完整 forward pass(O(S²) attention)。50 tokens 的生成需要 50 次 full forward,每次的 attention 复杂度还在增长。Phase 9 的 KV Cache 会将 decode 降到 O(S) per token。
|
||||
67
docs/09-kv-cache.md
Normal file
67
docs/09-kv-cache.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# Phase 9: KV Cache + Autoregressive Generation — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现 KV Cache,将 decode 从每步 full forward (O(S²)) 降为增量计算 (O(S))。这是最大的单点性能提升。
|
||||
|
||||
## 核心变化
|
||||
|
||||
### Before (no cache)
|
||||
```
|
||||
每生成一个 token:
|
||||
forward(all_tokens) → 重新计算所有层的 Q/K/V/attention
|
||||
开销: O(S²) attention per step, S 递增
|
||||
```
|
||||
|
||||
### After (with cache)
|
||||
```
|
||||
Prefill:
|
||||
forward(prompt_tokens) → 计算并缓存所有层的 K/V
|
||||
|
||||
Decode (per token):
|
||||
forward(last_token_only) → 只计算新 token 的 Q/K/V
|
||||
Q: [1, H, 1, D] → 新 token 的 query
|
||||
K: append to cache → cache 变为 [1, H, S+1, D]
|
||||
V: append to cache
|
||||
attention: Q @ K_cache^T → [1, H, 1, S+1], O(S) not O(S²)
|
||||
```
|
||||
|
||||
## KVCache 数据结构
|
||||
|
||||
```rust
|
||||
pub struct KVCache {
|
||||
k: Vec<Tensor>, // per layer, shape [1, num_heads, current_len, head_dim]
|
||||
v: Vec<Tensor>,
|
||||
len: usize, // current sequence length
|
||||
}
|
||||
```
|
||||
|
||||
## Forward Pass 变化
|
||||
|
||||
模型需要两种 forward 模式:
|
||||
1. **prefill(tokens)**: 处理完整 prompt,填充 KV cache
|
||||
2. **decode(token, cache)**: 处理单个 token,读写 KV cache
|
||||
|
||||
## 实现策略
|
||||
|
||||
为了最小化改动,在 GPT-2 forward 中加入可选的 `&mut KVCache` 参数:
|
||||
- cache=None → 现有行为(full forward)
|
||||
- cache=Some → prefill 或 decode 模式
|
||||
|
||||
CPU round-trip 问题暂不修复(Phase 15),先让 KV cache 逻辑正确。
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] KV cache vs no-cache: 50/50 bit-identical output
|
||||
- [x] Benchmark: 18x decode speedup (407ms → 22ms TBT)
|
||||
- [x] 50 prompt validation: 40/50 vs HF (10 are FP divergence, gap 0.04-0.56)
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **KV cache 数据布局是核心难点**:初始实现直接 append flat bytes 导致 head 维度交错错误。正确做法:per-head 独立存储,reconstruct 时按 `[1, H, S, D]` layout 组装。这是一个非常容易犯的 layout bug,调试时输出看起来"几乎对"但不完全对。
|
||||
|
||||
2. **18x 提速 > 理论预期**:理论上 KV cache 将 decode 从 O(S²) 降到 O(S),对 S=20-25 的序列预期 ~20x 提速。实测 18x 符合预期。TTFT 也从 400ms 降到 24ms,因为 prefill 只跑一次而不是每步重跑。
|
||||
|
||||
3. **xserv vs HF 的 10 个 mismatch 不是 bug**:logit gap 仅 0.04-0.56(在 -80 到 -140 的 logit 值上),是不同 CUDA kernel 实现间的浮点累积误差导致 argmax 翻转。重要验证:**xserv KV-cache vs xserv no-cache 是 50/50 完全一致的**——证明 KV cache 实现本身无误。
|
||||
|
||||
4. **CPU round-trip 仍是主要瓶颈**:KV cache 的 per-head 数据存在 CPU Vec 中,每步 decode 都要重新组装成 GPU tensor。这意味着每步仍有 24 次 GPU→CPU→GPU 传输(12 层 × 2 KV)。Phase 15 需要将 KV cache 直接放在 GPU 上。
|
||||
109
docs/10-qwen3.md
Normal file
109
docs/10-qwen3.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Phase 10: Qwen3-7B Support — Design Document (Milestone ②)
|
||||
|
||||
## Goal
|
||||
|
||||
扩展模型定义支持 Qwen3-7B 架构,验证输出正确性。与 GPT-2 的关键差异:RMSNorm、RoPE、GQA、SwiGLU、不共享 embedding。
|
||||
|
||||
## 架构差异 (GPT-2 → Qwen3)
|
||||
|
||||
| 特性 | GPT-2 | Qwen3-7B |
|
||||
|------|-------|----------|
|
||||
| Norm | LayerNorm(gamma, beta) | RMSNorm(gamma only) |
|
||||
| Position | Learned absolute (wpe) | RoPE (no params) |
|
||||
| Attention | MHA (12 Q = 12 KV heads) | GQA (32 Q, 8 KV heads) |
|
||||
| QKV projection | Combined c_attn [H, 3H] | Separate q/k/v_proj [H, Hq/Hk/Hv] |
|
||||
| FFN | 2 Linear (fc, proj) + GELU | 3 Linear (gate, up, down) + SwiGLU |
|
||||
| Weight layout | [in, out] (Conv1D style) | [out, in] (standard Linear) |
|
||||
| Tied embeddings | Yes | No (separate lm_head) |
|
||||
| hidden_size | 768 | 3584 |
|
||||
| num_layers | 12 | 28 |
|
||||
| head_dim | 64 | 128 |
|
||||
|
||||
## Weight Names (HuggingFace)
|
||||
|
||||
```
|
||||
model.embed_tokens.weight [151936, 3584]
|
||||
model.layers.{i}.input_layernorm.weight [3584]
|
||||
model.layers.{i}.self_attn.q_proj.weight [3584, 3584] (32 heads × 112 dim? or 28 heads)
|
||||
model.layers.{i}.self_attn.q_proj.bias [3584]
|
||||
model.layers.{i}.self_attn.k_proj.weight [512, 3584] (4 KV heads × 128 dim)
|
||||
model.layers.{i}.self_attn.k_proj.bias [512]
|
||||
model.layers.{i}.self_attn.v_proj.weight [512, 3584]
|
||||
model.layers.{i}.self_attn.v_proj.bias [512]
|
||||
model.layers.{i}.self_attn.o_proj.weight [3584, 3584]
|
||||
model.layers.{i}.post_attention_layernorm.weight [3584]
|
||||
model.layers.{i}.mlp.gate_proj.weight [18944, 3584]
|
||||
model.layers.{i}.mlp.up_proj.weight [18944, 3584]
|
||||
model.layers.{i}.mlp.down_proj.weight [3584, 18944]
|
||||
model.norm.weight [3584]
|
||||
lm_head.weight [151936, 3584]
|
||||
```
|
||||
|
||||
**注意**: Qwen3 权重是 [out, in] layout,`x @ W^T` 而不是 `x @ W`。
|
||||
|
||||
## GQA (Grouped Query Attention)
|
||||
|
||||
```
|
||||
num_heads = 28, num_kv_heads = 4, head_dim = 128
|
||||
Q: [B, 28, S, 128]
|
||||
K: [B, 4, S, 128] ← 每个 KV head 服务 28/4 = 7 个 Q head
|
||||
V: [B, 4, S, 128]
|
||||
|
||||
attention 时需要 repeat K/V:
|
||||
K_expanded: [B, 28, S, 128] ← repeat_interleave(K, 7, dim=1)
|
||||
```
|
||||
|
||||
实现:在 CPU 侧 split_qkv 时直接做 repeat。
|
||||
|
||||
## SwiGLU FFN
|
||||
|
||||
```
|
||||
gate = gate_proj(x) # [S, 3584] @ [3584, 18944]^T → [S, 18944]
|
||||
up = up_proj(x) # [S, 3584] @ [3584, 18944]^T → [S, 18944]
|
||||
out = silu(gate) * up # element-wise
|
||||
out = down_proj(out) # [S, 18944] @ [18944, 3584]^T → [S, 3584]
|
||||
```
|
||||
|
||||
## 显存预算 (BF16, 单卡 5090)
|
||||
|
||||
```
|
||||
权重: 7B × 2B = ~14 GB (BF16)
|
||||
7B × 4B = ~28 GB (FP32) — 不够! 必须用 BF16
|
||||
KV cache (S=256, B=1): ~0.1 GB
|
||||
总计: ~14 GB (BF16), 单卡可运行
|
||||
```
|
||||
|
||||
**关键**: Qwen3-7B 必须用 BF16 才能在单张 5090 (32GB) 上运行。当前 GPT-2 用 FP32,需要支持 BF16 forward pass。
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
1. 下载 Qwen3-7B 模型 (BF16, ~14GB)
|
||||
2. 实现 Qwen3 模型结构 (qwen3.rs)
|
||||
3. 支持 BF16 forward pass (linear_transpose for [out, in] weights)
|
||||
4. 实现 GQA (K/V repeat in split)
|
||||
5. 集成 RoPE + RMSNorm + SwiGLU
|
||||
6. 验证输出
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] 加载 Qwen3-8B BF16 权重 (399 tensors, ~15.5GB) 到单张 5090
|
||||
- [x] 英文: "The meaning of life is" → "to be happy"
|
||||
- [x] 中文: "请用中文回答:1+1等于几?" → "1加1"
|
||||
- [x] 61/61 单元测试无回归
|
||||
- [x] GPT-2 benchmark 性能无回归
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **Qwen3 实际是 8B,不是 7B**:modelscope 上的 `Qwen/Qwen3-8B` 有 36 层 × hidden 4096 × 32 heads,参数量约 8B。BF16 权重 ~15.5GB,单张 5090 (32GB) 可以运行。
|
||||
|
||||
2. **QK Normalization 是 Qwen3 的新特性**:每层有 `q_norm` 和 `k_norm` (shape [head_dim]),对 Q 和 K 做 per-head RMSNorm。这在 attention score 的数值稳定性上很重要——没有 QK norm 会导致 attention score 爆炸。
|
||||
|
||||
3. **attention_bias=false**:Qwen3 的 Q/K/V/O projection 没有 bias。这和 GPT-2 (有 bias) 不同。需要在模型代码中条件处理。
|
||||
|
||||
4. **Tokenizer 的 byte-to-unicode 映射 bug**:GPT-2 和 Qwen3 都使用同一套 byte-to-unicode 映射(printable ASCII identity,其余 68 bytes shifted to U+0100+)。初始实现中 `unicode_to_byte` 的 shifted 范围转换错误(直接 `u - 0x100` 而非查表),导致中文输入时 UTF-8 bytes 无法正确映射。修复:用 `OnceLock` 缓存反向映射表。
|
||||
|
||||
5. **Weight layout [out, in] vs [in, out]**:GPT-2 的 Conv1D 存为 [in, out],计算 `x @ W`;Qwen3 的 Linear 存为 [out, in],计算 `x @ W^T`。`linear_t` 函数通过 `weight.transpose(0,1).contiguous()` 处理。
|
||||
|
||||
6. **RoPE 的 tensor layout 不匹配**:RoPE kernel 期望 [S, H, D],但 attention 需要 [1, H, S, D]。需要在 RoPE 前后做 transpose。这引入了额外的 CPU round-trip(因为 transpose+contiguous 经过 CPU)。
|
||||
|
||||
7. **GQA repeat_kv 的实现**:每个 KV head 服务 `num_heads/num_kv_heads` 个 Q head。在 CPU 上做数据复制(repeat),简单但每步 decode 都要做。后续应在 attention kernel 中直接支持 GQA 索引,避免数据复制。
|
||||
54
docs/benchmarks/phase10-qwen3.md
Normal file
54
docs/benchmarks/phase10-qwen3.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# Phase 10 Benchmark: Qwen3-8B
|
||||
|
||||
**Date**: 2026-05-22
|
||||
**Hardware**: RTX 5090 (32GB, CC 12.0)
|
||||
**Model**: Qwen3-8B (BF16, 36 layers, 4096 hidden, 32/8 GQA heads)
|
||||
**Config**: 50 prompts × 20 generated tokens, greedy decoding, KV cache
|
||||
|
||||
## Correctness
|
||||
|
||||
| Metric | Result |
|
||||
|--------|--------|
|
||||
| Prefill Top-1 match vs HF | **42/50 (84.0%)** |
|
||||
| Prefill Top-5 match vs HF | **50/50 (100.0%)** |
|
||||
| Greedy sequence match | 0/50 (expected — BF16 drift over decode) |
|
||||
|
||||
The 100% top-5 match confirms the model is computing correctly.
|
||||
Greedy sequence divergence is due to BF16 precision (7-bit mantissa)
|
||||
accumulating across 36 layers of decode steps. Both xserv and HF
|
||||
produce coherent, valid completions — they just pick different
|
||||
equally-likely tokens at close-logit decision points.
|
||||
|
||||
## Performance
|
||||
|
||||
| Metric | xserv | transformers (BF16) | Ratio |
|
||||
|--------|-------|--------------------:|-------|
|
||||
| TTFT (avg) | 138.5 ms | 21.2 ms | 6.5x slower |
|
||||
| TBT (avg) | 144.2 ms | 21.9 ms | 6.6x slower |
|
||||
| Throughput | 6.9 tok/s | 45.6 tok/s | 0.15x |
|
||||
|
||||
## Remaining Performance Gap
|
||||
|
||||
~6.6x slower than HF for an 8B BF16 model. Main bottlenecks:
|
||||
1. CPU round-trips for add/mul/reshape/merge_heads (~100 per forward pass)
|
||||
2. KV cache stored on CPU (rebuilt as GPU tensor each step)
|
||||
3. cuBLAS handle per matmul
|
||||
4. No kernel fusion
|
||||
5. GQA repeat_kv copies data instead of kernel-level indexing
|
||||
|
||||
## Output Quality (Sample)
|
||||
|
||||
| Prompt | xserv Output |
|
||||
|--------|-------------|
|
||||
| "The capital of France is" | "Paris. The capital of France is Paris..." |
|
||||
| "Climate change is caused by" | "human activities, and the effects are already being felt..." |
|
||||
| "The human brain contains approximately" | "86 billion neurons. Each neuron can form synapses..." |
|
||||
| "Python is a popular programming language because" | "it is easy to learn and use..." |
|
||||
|
||||
## Tracking
|
||||
|
||||
| Phase | Model | TTFT (ms) | TBT (ms) | tok/s | Correctness |
|
||||
|-------|-------|-----------|----------|-------|-------------|
|
||||
| 8 | GPT-2 FP32 | 400.6 | 407.2 | 2.5 | 50/50 vs HF |
|
||||
| 9 | GPT-2 FP32 KV | 24.2 | 22.6 | 44.3 | 50/50 self |
|
||||
| 10 | Qwen3-8B BF16 KV | 138.5 | 144.2 | 6.9 | 100% top-5 prefill |
|
||||
35
docs/benchmarks/phase8-gpt2-baseline.md
Normal file
35
docs/benchmarks/phase8-gpt2-baseline.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Phase 8 Benchmark: GPT-2 124M Baseline
|
||||
|
||||
**Date**: 2026-05-21
|
||||
**Hardware**: RTX 5090 (32GB, CC 12.0, 170 SMs)
|
||||
**Model**: GPT-2 124M (FP32)
|
||||
**Config**: 50 prompts × 20 generated tokens, greedy decoding, no KV cache
|
||||
|
||||
## Correctness
|
||||
|
||||
| Metric | Result |
|
||||
|--------|--------|
|
||||
| Prompts tested | 50 |
|
||||
| Token-level match vs transformers | **50/50 (100.0%)** |
|
||||
| Mismatches | 0 |
|
||||
|
||||
## Performance
|
||||
|
||||
| Metric | xserv | transformers (PyTorch) | Ratio |
|
||||
|--------|-------|----------------------|-------|
|
||||
| TTFT (avg) | 400.6 ms | 4.0 ms | 100x slower |
|
||||
| TBT (avg) | 407.2 ms | 3.8 ms | 106x slower |
|
||||
| Throughput | 2.5 tok/s | 260 tok/s | 0.01x |
|
||||
|
||||
## Known Bottlenecks
|
||||
|
||||
1. **No KV Cache**: full recompute per token (O(S²) attention every step)
|
||||
2. **CPU round-trips**: ~100 GPU→CPU→GPU transfers per forward pass for add/bias/split_qkv/merge_heads
|
||||
3. **cuBLAS handle per matmul**: ~50 handle create/destroy per forward pass
|
||||
4. **No kernel fusion**: every op is a separate kernel launch + sync
|
||||
|
||||
## Tracking
|
||||
|
||||
| Phase | TTFT (ms) | TBT (ms) | tok/s | Correctness | Notes |
|
||||
|-------|-----------|----------|-------|-------------|-------|
|
||||
| 8 (baseline) | 400.6 | 407.2 | 2.5 | 50/50 | No KV cache, CPU round-trips |
|
||||
44
docs/benchmarks/phase9-kv-cache.md
Normal file
44
docs/benchmarks/phase9-kv-cache.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Phase 9 Benchmark: KV Cache
|
||||
|
||||
**Date**: 2026-05-21
|
||||
**Hardware**: RTX 5090 (32GB, CC 12.0)
|
||||
**Model**: GPT-2 124M (FP32)
|
||||
**Config**: 50 prompts × 20 generated tokens, greedy decoding
|
||||
|
||||
## Correctness
|
||||
|
||||
| Metric | Result |
|
||||
|--------|--------|
|
||||
| xserv KV-cache vs xserv no-cache | **50/50 (100.0%)** — bit-identical |
|
||||
| xserv vs HF transformers | 40/50 (80.0%) |
|
||||
|
||||
The 10 mismatches vs HF are floating point divergence (different CUDA kernels, computation order).
|
||||
Logit gap at divergence points: min=0.04, max=0.56, avg=0.20. Not a correctness bug.
|
||||
|
||||
## Performance
|
||||
|
||||
| Metric | Phase 8 (no cache) | Phase 9 (KV cache) | Improvement | HF transformers |
|
||||
|--------|-------------------|--------------------|-----------|-----------------|
|
||||
| TTFT (avg) | 400.6 ms | 24.2 ms | **16.5x** | 4.0 ms |
|
||||
| TBT (avg) | 407.2 ms | 22.6 ms | **18.0x** | 3.9 ms |
|
||||
| Throughput | 2.5 tok/s | 44.3 tok/s | **17.7x** | 257.7 tok/s |
|
||||
| vs HF ratio | 0.01x | 0.17x | | 1.0x |
|
||||
|
||||
## Analysis
|
||||
|
||||
KV cache delivers **~18x speedup** by eliminating redundant computation:
|
||||
- Before: every decode step recomputed all layers for all tokens O(S²)
|
||||
- After: decode step only computes 1 new token, reads K/V from cache O(S)
|
||||
|
||||
Remaining gap vs HF (~6x slower):
|
||||
1. CPU round-trips still present (~100 per forward pass)
|
||||
2. cuBLAS handle created per matmul
|
||||
3. KV cache stored on CPU (rebuilt as GPU tensor each step)
|
||||
4. No kernel fusion
|
||||
|
||||
## Tracking
|
||||
|
||||
| Phase | TTFT (ms) | TBT (ms) | tok/s | Correctness | Notes |
|
||||
|-------|-----------|----------|-------|-------------|-------|
|
||||
| 8 (baseline) | 400.6 | 407.2 | 2.5 | 50/50 vs HF | No KV cache |
|
||||
| 9 (KV cache) | 24.2 | 22.6 | 44.3 | 50/50 self-consistent | 18x speedup |
|
||||
40
tools/analyze_divergence.py
Normal file
40
tools/analyze_divergence.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(sys.argv[2]).eval().cuda()
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(sys.argv[2])
|
||||
|
||||
with open(sys.argv[1]) as f:
|
||||
xr = json.load(f)
|
||||
|
||||
mismatches = []
|
||||
for i in range(len(xr)):
|
||||
ids = tokenizer.encode(xr[i]["prompt"])
|
||||
all_ids = list(ids)
|
||||
xserv_gen = xr[i]["generated_ids"]
|
||||
with torch.no_grad():
|
||||
for j in range(len(xserv_gen)):
|
||||
out = model(torch.tensor([all_ids]).cuda())
|
||||
logits = out.logits[0, -1]
|
||||
hf_next = logits.argmax().item()
|
||||
xs_next = xserv_gen[j]
|
||||
if hf_next != xs_next:
|
||||
xs_logit = logits[xs_next].item()
|
||||
hf_logit = logits[hf_next].item()
|
||||
hf_tok = tokenizer.decode([hf_next])
|
||||
xs_tok = tokenizer.decode([xs_next])
|
||||
gap = hf_logit - xs_logit
|
||||
print(
|
||||
f'[{i+1}] "{xr[i]["prompt"][:42]}" @ tok {j}: '
|
||||
f'hf={repr(hf_tok)}({hf_logit:.3f}) xserv={repr(xs_tok)}({xs_logit:.3f}) '
|
||||
f'gap={gap:.4f}'
|
||||
)
|
||||
mismatches.append(gap)
|
||||
break
|
||||
all_ids.append(hf_next)
|
||||
|
||||
print(f"\nTotal: {len(mismatches)}/{len(xr)} mismatches")
|
||||
if mismatches:
|
||||
print(f"Logit gaps: min={min(mismatches):.4f} max={max(mismatches):.4f} avg={sum(mismatches)/len(mismatches):.4f}")
|
||||
154
tools/bench_compare.py
Normal file
154
tools/bench_compare.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Compare xserv GPT-2 output against HuggingFace transformers.
|
||||
Reads xserv results from JSON, runs same prompts through transformers, compares token-by-token.
|
||||
Also measures transformers timing for performance comparison.
|
||||
|
||||
Usage:
|
||||
python3 tools/bench_compare.py <xserv_results.json> <model_dir>
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print(f"Usage: {sys.argv[0]} <xserv_results.json> <model_dir>")
|
||||
sys.exit(1)
|
||||
|
||||
xserv_path = sys.argv[1]
|
||||
model_dir = sys.argv[2]
|
||||
|
||||
with open(xserv_path) as f:
|
||||
xserv_results = json.load(f)
|
||||
|
||||
print(f"Loading transformers model from {model_dir}...")
|
||||
model = GPT2LMHeadModel.from_pretrained(model_dir)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
|
||||
model.eval()
|
||||
model.cuda()
|
||||
|
||||
# Warmup
|
||||
with torch.no_grad():
|
||||
model(torch.tensor([[tokenizer.encode("warmup")[0]]]).cuda())
|
||||
torch.cuda.synchronize()
|
||||
|
||||
total = len(xserv_results)
|
||||
match_count = 0
|
||||
mismatch_count = 0
|
||||
xserv_ttft_sum = 0.0
|
||||
xserv_tbt_sum = 0.0
|
||||
hf_ttft_sum = 0.0
|
||||
hf_tbt_sum = 0.0
|
||||
num_with_tbt = 0
|
||||
|
||||
print(f"\n{'='*100}")
|
||||
print(f"{'#':>3} {'Match':>5} {'Prompt':<45} {'xserv TTFT':>10} {'HF TTFT':>10} {'xserv TBT':>10} {'HF TBT':>10}")
|
||||
print(f"{'='*100}")
|
||||
|
||||
for i, xr in enumerate(xserv_results):
|
||||
prompt = xr["prompt"]
|
||||
gen_tokens = xr["num_generated"]
|
||||
xserv_ids = xr["generated_ids"]
|
||||
|
||||
input_ids = tokenizer.encode(prompt)
|
||||
input_tensor = torch.tensor([input_ids]).cuda()
|
||||
|
||||
# Generate with transformers, measuring timing
|
||||
hf_generated = []
|
||||
hf_token_times = []
|
||||
|
||||
with torch.no_grad():
|
||||
all_ids = input_tensor.clone()
|
||||
|
||||
# TTFT
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = model(all_ids)
|
||||
torch.cuda.synchronize()
|
||||
hf_ttft_us = (time.perf_counter() - t0) * 1e6
|
||||
next_id = out.logits[0, -1].argmax().item()
|
||||
hf_generated.append(next_id)
|
||||
all_ids = torch.cat([all_ids, torch.tensor([[next_id]]).cuda()], dim=1)
|
||||
|
||||
# Remaining tokens
|
||||
for _ in range(1, gen_tokens):
|
||||
torch.cuda.synchronize()
|
||||
t_start = time.perf_counter()
|
||||
out = model(all_ids)
|
||||
torch.cuda.synchronize()
|
||||
elapsed = (time.perf_counter() - t_start) * 1e6
|
||||
hf_token_times.append(elapsed)
|
||||
next_id = out.logits[0, -1].argmax().item()
|
||||
hf_generated.append(next_id)
|
||||
all_ids = torch.cat([all_ids, torch.tensor([[next_id]]).cuda()], dim=1)
|
||||
|
||||
eos_id = tokenizer.eos_token_id
|
||||
if eos_id is not None and next_id == eos_id:
|
||||
break
|
||||
|
||||
hf_tbt_us = sum(hf_token_times) / len(hf_token_times) if hf_token_times else 0
|
||||
|
||||
# Compare
|
||||
match = xserv_ids == hf_generated
|
||||
if match:
|
||||
match_count += 1
|
||||
status = " OK "
|
||||
else:
|
||||
mismatch_count += 1
|
||||
status = "FAIL!"
|
||||
|
||||
xserv_ttft_ms = xr["ttft_us"] / 1000.0
|
||||
xserv_tbt_ms = xr["tbt_us"] / 1000.0
|
||||
hf_ttft_ms = hf_ttft_us / 1000.0
|
||||
hf_tbt_ms = hf_tbt_us / 1000.0
|
||||
|
||||
prompt_short = prompt[:43] + ".." if len(prompt) > 45 else prompt
|
||||
print(f"{i+1:>3} {status} {prompt_short:<45} {xserv_ttft_ms:>8.1f}ms {hf_ttft_ms:>8.1f}ms {xserv_tbt_ms:>8.1f}ms {hf_tbt_ms:>8.1f}ms")
|
||||
|
||||
if not match:
|
||||
# Show first divergence
|
||||
for j in range(max(len(xserv_ids), len(hf_generated))):
|
||||
x = xserv_ids[j] if j < len(xserv_ids) else None
|
||||
h = hf_generated[j] if j < len(hf_generated) else None
|
||||
if x != h:
|
||||
x_tok = tokenizer.decode([x]) if x is not None else "<none>"
|
||||
h_tok = tokenizer.decode([h]) if h is not None else "<none>"
|
||||
print(f" ↳ diverge at token {j}: xserv={x}({repr(x_tok)}) vs hf={h}({repr(h_tok)})")
|
||||
break
|
||||
|
||||
xserv_ttft_sum += xr["ttft_us"]
|
||||
xserv_tbt_sum += xr["tbt_us"]
|
||||
hf_ttft_sum += hf_ttft_us
|
||||
hf_tbt_sum += hf_tbt_us
|
||||
if xr["tbt_us"] > 0:
|
||||
num_with_tbt += 1
|
||||
|
||||
print(f"{'='*100}")
|
||||
print(f"\n=== CORRECTNESS ===")
|
||||
print(f"Total prompts: {total}")
|
||||
print(f"Match: {match_count}/{total} ({match_count/total*100:.1f}%)")
|
||||
print(f"Mismatch: {mismatch_count}/{total}")
|
||||
|
||||
print(f"\n=== PERFORMANCE (average) ===")
|
||||
print(f"{'Metric':<20} {'xserv':>12} {'transformers':>12} {'ratio':>10}")
|
||||
print(f"{'-'*54}")
|
||||
avg_x_ttft = xserv_ttft_sum / total / 1000
|
||||
avg_h_ttft = hf_ttft_sum / total / 1000
|
||||
avg_x_tbt = xserv_tbt_sum / num_with_tbt / 1000 if num_with_tbt > 0 else 0
|
||||
avg_h_tbt = hf_tbt_sum / num_with_tbt / 1000 if num_with_tbt > 0 else 0
|
||||
print(f"{'TTFT (ms)':<20} {avg_x_ttft:>10.1f}ms {avg_h_ttft:>10.1f}ms {avg_x_ttft/avg_h_ttft:>9.1f}x")
|
||||
print(f"{'TBT (ms)':<20} {avg_x_tbt:>10.1f}ms {avg_h_tbt:>10.1f}ms {avg_x_tbt/avg_h_tbt if avg_h_tbt > 0 else 0:>9.1f}x")
|
||||
xserv_tps = 1000.0 / avg_x_tbt if avg_x_tbt > 0 else 0
|
||||
hf_tps = 1000.0 / avg_h_tbt if avg_h_tbt > 0 else 0
|
||||
print(f"{'Throughput (tok/s)':<20} {xserv_tps:>10.1f} {hf_tps:>10.1f} {xserv_tps/hf_tps if hf_tps > 0 else 0:>9.2f}x")
|
||||
|
||||
print(f"\nNote: xserv currently has no KV cache — full recompute per token.")
|
||||
print(f" transformers also runs without KV cache in this benchmark for fair comparison.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
137
tools/bench_compare_qwen3.py
Normal file
137
tools/bench_compare_qwen3.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Compare xserv Qwen3 output against HuggingFace transformers.
|
||||
Usage: python3 tools/bench_compare_qwen3.py <xserv_results.json> <model_dir>
|
||||
"""
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print(f"Usage: {sys.argv[0]} <xserv_results.json> <model_dir>")
|
||||
sys.exit(1)
|
||||
|
||||
xserv_path = sys.argv[1]
|
||||
model_dir = sys.argv[2]
|
||||
|
||||
with open(xserv_path) as f:
|
||||
xserv_results = json.load(f)
|
||||
|
||||
print(f"Loading transformers model from {model_dir}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model.eval()
|
||||
model.cuda()
|
||||
|
||||
# Warmup
|
||||
with torch.no_grad():
|
||||
ids = tokenizer.encode("warmup", return_tensors="pt").cuda()
|
||||
model(ids)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
total = len(xserv_results)
|
||||
match_count = 0
|
||||
mismatch_count = 0
|
||||
xserv_ttft_sum = 0.0
|
||||
xserv_tbt_sum = 0.0
|
||||
hf_ttft_sum = 0.0
|
||||
hf_tbt_sum = 0.0
|
||||
num_with_tbt = 0
|
||||
|
||||
print(f"\n{'='*100}")
|
||||
print(f"{'#':>3} {'Match':>5} {'Prompt':<45} {'xserv TTFT':>10} {'HF TTFT':>10} {'xserv TBT':>10} {'HF TBT':>10}")
|
||||
print(f"{'='*100}")
|
||||
|
||||
for i, xr in enumerate(xserv_results):
|
||||
prompt = xr["prompt"]
|
||||
gen_tokens = xr["num_generated"]
|
||||
xserv_ids = xr["generated_ids"]
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
|
||||
hf_generated = []
|
||||
hf_token_times = []
|
||||
|
||||
with torch.no_grad():
|
||||
all_ids = input_ids.clone()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
out = model(all_ids)
|
||||
torch.cuda.synchronize()
|
||||
hf_ttft_us = (time.perf_counter() - t0) * 1e6
|
||||
next_id = out.logits[0, -1].argmax().item()
|
||||
hf_generated.append(next_id)
|
||||
all_ids = torch.cat([all_ids, torch.tensor([[next_id]]).cuda()], dim=1)
|
||||
|
||||
for _ in range(1, gen_tokens):
|
||||
torch.cuda.synchronize()
|
||||
t_start = time.perf_counter()
|
||||
out = model(all_ids)
|
||||
torch.cuda.synchronize()
|
||||
elapsed = (time.perf_counter() - t_start) * 1e6
|
||||
hf_token_times.append(elapsed)
|
||||
next_id = out.logits[0, -1].argmax().item()
|
||||
hf_generated.append(next_id)
|
||||
all_ids = torch.cat([all_ids, torch.tensor([[next_id]]).cuda()], dim=1)
|
||||
|
||||
if next_id == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
hf_tbt_us = sum(hf_token_times) / len(hf_token_times) if hf_token_times else 0
|
||||
|
||||
match = xserv_ids == hf_generated
|
||||
if match:
|
||||
match_count += 1
|
||||
status = " OK "
|
||||
else:
|
||||
mismatch_count += 1
|
||||
status = "FAIL!"
|
||||
|
||||
xserv_ttft_ms = xr["ttft_us"] / 1000.0
|
||||
xserv_tbt_ms = xr["tbt_us"] / 1000.0
|
||||
hf_ttft_ms = hf_ttft_us / 1000.0
|
||||
hf_tbt_ms = hf_tbt_us / 1000.0
|
||||
|
||||
prompt_short = prompt[:43] + ".." if len(prompt) > 45 else prompt
|
||||
print(f"{i+1:>3} {status} {prompt_short:<45} {xserv_ttft_ms:>8.1f}ms {hf_ttft_ms:>8.1f}ms {xserv_tbt_ms:>8.1f}ms {hf_tbt_ms:>8.1f}ms")
|
||||
|
||||
if not match:
|
||||
for j in range(max(len(xserv_ids), len(hf_generated))):
|
||||
x = xserv_ids[j] if j < len(xserv_ids) else None
|
||||
h = hf_generated[j] if j < len(hf_generated) else None
|
||||
if x != h:
|
||||
x_tok = tokenizer.decode([x]) if x is not None else "<none>"
|
||||
h_tok = tokenizer.decode([h]) if h is not None else "<none>"
|
||||
print(f" diverge@{j}: xserv={x}({repr(x_tok)}) hf={h}({repr(h_tok)})")
|
||||
break
|
||||
|
||||
xserv_ttft_sum += xr["ttft_us"]
|
||||
xserv_tbt_sum += xr["tbt_us"]
|
||||
hf_ttft_sum += hf_ttft_us
|
||||
hf_tbt_sum += hf_tbt_us
|
||||
if xr["tbt_us"] > 0:
|
||||
num_with_tbt += 1
|
||||
|
||||
print(f"{'='*100}")
|
||||
print(f"\n=== CORRECTNESS ===")
|
||||
print(f"Total: {total}, Match: {match_count}/{total} ({match_count/total*100:.1f}%), Mismatch: {mismatch_count}")
|
||||
|
||||
print(f"\n=== PERFORMANCE ===")
|
||||
print(f"{'Metric':<20} {'xserv':>12} {'transformers':>12} {'ratio':>10}")
|
||||
print(f"{'-'*54}")
|
||||
avg_x_ttft = xserv_ttft_sum / total / 1000
|
||||
avg_h_ttft = hf_ttft_sum / total / 1000
|
||||
avg_x_tbt = xserv_tbt_sum / num_with_tbt / 1000 if num_with_tbt > 0 else 0
|
||||
avg_h_tbt = hf_tbt_sum / num_with_tbt / 1000 if num_with_tbt > 0 else 0
|
||||
print(f"{'TTFT (ms)':<20} {avg_x_ttft:>10.1f}ms {avg_h_ttft:>10.1f}ms {avg_x_ttft/avg_h_ttft if avg_h_ttft>0 else 0:>9.1f}x")
|
||||
print(f"{'TBT (ms)':<20} {avg_x_tbt:>10.1f}ms {avg_h_tbt:>10.1f}ms {avg_x_tbt/avg_h_tbt if avg_h_tbt>0 else 0:>9.1f}x")
|
||||
xserv_tps = 1000.0 / avg_x_tbt if avg_x_tbt > 0 else 0
|
||||
hf_tps = 1000.0 / avg_h_tbt if avg_h_tbt > 0 else 0
|
||||
print(f"{'Throughput (tok/s)':<20} {xserv_tps:>10.1f} {hf_tps:>10.1f} {xserv_tps/hf_tps if hf_tps>0 else 0:>9.2f}x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user