Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e1e75fc7f6 | |||
| 6035ffdc0b |
@@ -4,6 +4,8 @@ members = [
|
||||
"crates/xserv-cuda",
|
||||
"crates/xserv-tensor",
|
||||
"crates/xserv-kernels",
|
||||
"crates/xserv-model",
|
||||
"crates/xserv-tokenizer",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -14,3 +16,7 @@ license = "MIT"
|
||||
[workspace.dependencies]
|
||||
half = "2"
|
||||
smallvec = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
safetensors = "0.5"
|
||||
regex = "1"
|
||||
|
||||
@@ -22,6 +22,7 @@ fn main() {
|
||||
.file("../../csrc/reduce/softmax.cu")
|
||||
.file("../../csrc/embedding/embedding.cu")
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
|
||||
@@ -6,6 +6,8 @@ unsafe extern "C" {
|
||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn gelu(x: &Tensor) -> Tensor {
|
||||
@@ -39,3 +41,19 @@ pub fn silu(x: &Tensor) -> Tensor {
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
_ => panic!("unsupported dtype for scale"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
out
|
||||
}
|
||||
|
||||
77
crates/xserv-kernels/src/attention.rs
Normal file
77
crates/xserv-kernels/src/attention.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
|
||||
use crate::activation::scale;
|
||||
use crate::gemm::batched_matmul;
|
||||
use crate::softmax::softmax;
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
let ndim = scores.ndim();
|
||||
let rows = scores.shape()[ndim - 2];
|
||||
let cols = scores.shape()[ndim - 1];
|
||||
let batch: usize = scores.shape()[..ndim - 2].iter().product();
|
||||
|
||||
unsafe {
|
||||
match scores.dtype() {
|
||||
DType::F32 => launch_causal_mask_f32(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_causal_mask_bf16(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for causal mask"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
|
||||
/// Multi-head attention (naive, materializes S×S score matrix).
|
||||
///
|
||||
/// q, k, v: [batch, num_heads, seq_len, head_dim] — contiguous, on GPU
|
||||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||
|
||||
// scores = Q @ K^T → [B, H, q_len, kv_len]
|
||||
let k_t = k.transpose(2, 3).contiguous();
|
||||
let scores = batched_matmul(q, &k_t);
|
||||
|
||||
// Scale by 1/sqrt(head_dim)
|
||||
let scale_factor = 1.0 / (head_dim as f32).sqrt();
|
||||
let scaled_scores = scale(&scores, scale_factor);
|
||||
|
||||
// Causal mask
|
||||
if causal {
|
||||
let offset = kv_len - q_len;
|
||||
apply_causal_mask(&scaled_scores, offset);
|
||||
}
|
||||
|
||||
// Softmax
|
||||
let weights = softmax(&scaled_scores);
|
||||
|
||||
// output = weights @ V → [B, H, q_len, head_dim]
|
||||
batched_matmul(&weights, v)
|
||||
}
|
||||
@@ -46,6 +46,19 @@ unsafe extern "C" {
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
pub struct CublasContext {
|
||||
@@ -149,3 +162,68 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
/// Batched matrix multiplication via cuBLAS: C[b] = A[b] @ B[b]
|
||||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||
/// Leading dimensions must match and tensors must be contiguous.
|
||||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert!(a.ndim() >= 2 && b.ndim() >= 2);
|
||||
assert_eq!(a.ndim(), b.ndim());
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
|
||||
let ndim = a.ndim();
|
||||
let m = a.shape()[ndim - 2];
|
||||
let k = a.shape()[ndim - 1];
|
||||
let n = b.shape()[ndim - 1];
|
||||
assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch");
|
||||
|
||||
// Compute batch count from leading dimensions
|
||||
let batch: usize = a.shape()[..ndim - 2].iter().product();
|
||||
assert_eq!(
|
||||
b.shape()[..ndim - 2].iter().product::<usize>(),
|
||||
batch,
|
||||
"batch dimensions mismatch"
|
||||
);
|
||||
|
||||
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
|
||||
out_shape.push(m);
|
||||
out_shape.push(n);
|
||||
let c = Tensor::zeros(&out_shape, a.dtype(), a.device());
|
||||
|
||||
let dtype = a.dtype();
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for batched matmul"),
|
||||
};
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
// cuBLAS strides are in elements (not bytes)
|
||||
let stride_a = (m * k) as i64;
|
||||
let stride_b = (k * n) as i64;
|
||||
let stride_c = (m * n) as i64;
|
||||
|
||||
let ctx = CublasContext::new().unwrap();
|
||||
unsafe {
|
||||
cublasSetStream_v2(ctx.handle, std::ptr::null_mut());
|
||||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||
error::check(cublasGemmStridedBatchedEx(
|
||||
ctx.handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b.data_ptr() as _, b_type, n as i32, stride_b,
|
||||
a.data_ptr() as _, a_type, k as i32, stride_a,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
|
||||
batch as i32,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1,
|
||||
)).expect("cuBLAS batched GEMM failed");
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
c
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod activation;
|
||||
pub mod attention;
|
||||
pub mod embedding;
|
||||
pub mod gemm;
|
||||
pub mod layernorm;
|
||||
@@ -6,9 +7,10 @@ pub mod rmsnorm;
|
||||
pub mod rope;
|
||||
pub mod softmax;
|
||||
|
||||
pub use activation::{gelu, silu};
|
||||
pub use activation::{gelu, scale, silu};
|
||||
pub use attention::attention;
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{matmul, GemmBackend};
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::rmsnorm;
|
||||
pub use rope::{rope_inplace, RopeCache};
|
||||
|
||||
187
crates/xserv-kernels/tests/attention_test.rs
Normal file
187
crates/xserv-kernels/tests/attention_test.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn init() { xserv_cuda::device::set_device(0).unwrap(); }
|
||||
|
||||
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
|
||||
batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
|
||||
causal: bool) -> Vec<f32> {
|
||||
let mut out = vec![0.0f32; batch * heads * q_len * head_dim];
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
for b in 0..batch {
|
||||
for h in 0..heads {
|
||||
// scores = Q @ K^T, scaled
|
||||
let mut scores = vec![0.0f32; q_len * kv_len];
|
||||
for i in 0..q_len {
|
||||
for j in 0..kv_len {
|
||||
let mut s = 0.0f32;
|
||||
for d in 0..head_dim {
|
||||
let qi = q[((b * heads + h) * q_len + i) * head_dim + d];
|
||||
let ki = k[((b * heads + h) * kv_len + j) * head_dim + d];
|
||||
s += qi * ki;
|
||||
}
|
||||
scores[i * kv_len + j] = s * scale;
|
||||
}
|
||||
}
|
||||
// causal mask
|
||||
if causal {
|
||||
let offset = kv_len - q_len;
|
||||
for i in 0..q_len {
|
||||
for j in 0..kv_len {
|
||||
if j > i + offset {
|
||||
scores[i * kv_len + j] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// softmax per row
|
||||
for i in 0..q_len {
|
||||
let row = &mut scores[i * kv_len..(i + 1) * kv_len];
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut sum = 0.0f32;
|
||||
for v in row.iter_mut() {
|
||||
*v = (*v - max).exp();
|
||||
sum += *v;
|
||||
}
|
||||
for v in row.iter_mut() {
|
||||
*v /= sum;
|
||||
}
|
||||
}
|
||||
// output = weights @ V
|
||||
for i in 0..q_len {
|
||||
for d in 0..head_dim {
|
||||
let mut s = 0.0f32;
|
||||
for j in 0..kv_len {
|
||||
let w = scores[i * kv_len + j];
|
||||
let vi = v[((b * heads + h) * kv_len + j) * head_dim + d];
|
||||
s += w * vi;
|
||||
}
|
||||
out[((b * heads + h) * q_len + i) * head_dim + d] = s;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
|
||||
assert_eq!(a.len(), b.len(), "{name}: length mismatch");
|
||||
let mut max_err = 0.0f32;
|
||||
for (i, (x, y)) in a.iter().zip(b).enumerate() {
|
||||
let err = (x - y).abs();
|
||||
if err > max_err { max_err = err; }
|
||||
assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
|
||||
}
|
||||
println!("{name}: max_err = {max_err:.6e}");
|
||||
}
|
||||
|
||||
fn make_data(n: usize) -> Vec<f32> {
|
||||
(0..n).map(|i| ((i % 17) as f32 - 8.0) * 0.05).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batched_matmul() {
|
||||
init();
|
||||
let batch = 4;
|
||||
let heads = 8;
|
||||
let m = 32;
|
||||
let k = 64;
|
||||
let n = 32;
|
||||
|
||||
let a_data = make_data(batch * heads * m * k);
|
||||
let b_data = make_data(batch * heads * k * n);
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0));
|
||||
let c = batched_matmul(&a, &b).to_device(Device::Cpu);
|
||||
|
||||
assert_eq!(c.shape(), &[batch, heads, m, n]);
|
||||
|
||||
// Verify one batch element
|
||||
let a_cpu = &a_data[0..m * k];
|
||||
let b_cpu = &b_data[0..k * n];
|
||||
let mut expected = vec![0.0f32; m * n];
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut s = 0.0f32;
|
||||
for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; }
|
||||
expected[i * n + j] = s;
|
||||
}
|
||||
}
|
||||
let result = c.as_slice::<f32>();
|
||||
check_close(&result[0..m * n], &expected, 1e-3, "batched_matmul[0]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_no_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 8; let d = 16;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal() {
|
||||
init();
|
||||
let b = 1; let h = 2; let s = 16; let d = 32;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-3, "attention_causal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_larger() {
|
||||
init();
|
||||
let b = 2; let h = 4; let s = 64; let d = 64;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data = make_data(b * h * s * d);
|
||||
let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_causal_first_row_sees_only_first_token() {
|
||||
init();
|
||||
let b = 1; let h = 1; let s = 4; let d = 8;
|
||||
let q_data = make_data(b * h * s * d);
|
||||
let k_data = make_data(b * h * s * d);
|
||||
let v_data: Vec<f32> = (0..s * d).map(|i| {
|
||||
if i < d { 1.0 } else { 0.0 } // only first V row is nonzero
|
||||
}).collect();
|
||||
|
||||
let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0));
|
||||
let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
|
||||
|
||||
// First row (position 0) with causal mask can only see position 0.
|
||||
// So attention weight for position 0 is 1.0 for token 0 only.
|
||||
// output[0] should be exactly V[0] = [1, 1, 1, ...1]
|
||||
let result = out.as_slice::<f32>();
|
||||
for i in 0..d {
|
||||
assert!((result[i] - 1.0).abs() < 1e-5,
|
||||
"first row should equal V[0], got {} at dim {}", result[i], i);
|
||||
}
|
||||
}
|
||||
14
crates/xserv-model/Cargo.toml
Normal file
14
crates/xserv-model/Cargo.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[package]
|
||||
name = "xserv-model"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
xserv-kernels = { path = "../xserv-kernels" }
|
||||
xserv-tokenizer = { path = "../xserv-tokenizer" }
|
||||
half.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
safetensors.workspace = true
|
||||
78
crates/xserv-model/src/bin/xserv-cli.rs
Normal file
78
crates/xserv-model/src/bin/xserv-cli.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::{GPT2, ModelConfig};
|
||||
use xserv_model::loader;
|
||||
use xserv_model::gpt2::sample_greedy;
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
use xserv_tensor::Device;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: xserv-cli <model-dir> [--max-tokens N]");
|
||||
eprintln!(" model-dir: path to HF model directory (containing model.safetensors, config.json, tokenizer.json)");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let model_dir = PathBuf::from(&args[1]);
|
||||
let max_tokens: usize = args.iter()
|
||||
.position(|a| a == "--max-tokens")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(100);
|
||||
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
let info = xserv_cuda::device::device_info(0).unwrap();
|
||||
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
|
||||
|
||||
// Load config
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
eprintln!("Model: {:?}, layers={}, hidden={}, heads={}, vocab={}",
|
||||
config.model_type, config.num_layers(), config.hidden(),
|
||||
config.num_heads(), config.vocab_size);
|
||||
|
||||
// Load weights
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
|
||||
// GPT-2 uses weight names without "model." prefix
|
||||
let model = GPT2::from_weights(config, weights);
|
||||
|
||||
// Load tokenizer
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
eprintln!("Tokenizer loaded (vocab_size={})", tokenizer.vocab_size());
|
||||
eprintln!("Ready.\n");
|
||||
|
||||
// Interactive loop
|
||||
loop {
|
||||
print!("xserv> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||
break;
|
||||
}
|
||||
let input = input.trim();
|
||||
if input.is_empty() { continue; }
|
||||
if input == "quit" || input == "exit" { break; }
|
||||
|
||||
let mut token_ids = tokenizer.encode(input);
|
||||
print!("{input}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
let logits = model.forward(&token_ids);
|
||||
let next = sample_greedy(&logits);
|
||||
token_ids.push(next);
|
||||
|
||||
let text = tokenizer.decode(&[next]);
|
||||
print!("{text}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
96
crates/xserv-model/src/config.rs
Normal file
96
crates/xserv-model/src/config.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
pub architectures: Option<Vec<String>>,
|
||||
pub model_type: Option<String>,
|
||||
|
||||
// Modern HF naming
|
||||
#[serde(default)]
|
||||
pub hidden_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub intermediate_size: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_attention_heads: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub num_hidden_layers: Option<usize>,
|
||||
pub vocab_size: usize,
|
||||
#[serde(default)]
|
||||
pub max_position_embeddings: Option<usize>,
|
||||
|
||||
// GPT-2 naming
|
||||
#[serde(default)]
|
||||
pub n_embd: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_head: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_layer: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_positions: Option<usize>,
|
||||
#[serde(default)]
|
||||
pub n_inner: Option<usize>,
|
||||
|
||||
// Normalization
|
||||
#[serde(default)]
|
||||
pub layer_norm_eps: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub layer_norm_epsilon: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub rms_norm_eps: Option<f64>,
|
||||
|
||||
// Other
|
||||
#[serde(default)]
|
||||
pub rope_theta: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
}
|
||||
|
||||
impl ModelConfig {
|
||||
pub fn from_file(path: &Path) -> Self {
|
||||
let data = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse {}: {e}", path.display()))
|
||||
}
|
||||
|
||||
pub fn hidden(&self) -> usize {
|
||||
self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required")
|
||||
}
|
||||
|
||||
pub fn num_heads(&self) -> usize {
|
||||
self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required")
|
||||
}
|
||||
|
||||
pub fn num_layers(&self) -> usize {
|
||||
self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required")
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_position_embeddings.or(self.n_positions).unwrap_or(2048)
|
||||
}
|
||||
|
||||
pub fn ffn_hidden(&self) -> usize {
|
||||
self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4)
|
||||
}
|
||||
|
||||
pub fn num_kv_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_heads())
|
||||
}
|
||||
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.hidden() / self.num_heads()
|
||||
}
|
||||
|
||||
pub fn ln_eps(&self) -> f32 {
|
||||
self.layer_norm_eps
|
||||
.or(self.layer_norm_epsilon)
|
||||
.unwrap_or(1e-5) as f32
|
||||
}
|
||||
|
||||
pub fn tied_embeddings(&self) -> bool {
|
||||
self.tie_word_embeddings.unwrap_or(true)
|
||||
}
|
||||
}
|
||||
224
crates/xserv-model/src/gpt2.rs
Normal file
224
crates/xserv-model/src/gpt2.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
use std::collections::HashMap;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
|
||||
pub struct GPT2 {
|
||||
pub config: ModelConfig,
|
||||
wte: Tensor, // [vocab_size, hidden]
|
||||
wpe: Tensor, // [max_pos, hidden]
|
||||
layers: Vec<GPT2Block>,
|
||||
ln_f_g: Tensor, // [hidden]
|
||||
ln_f_b: Tensor, // [hidden]
|
||||
}
|
||||
|
||||
struct GPT2Block {
|
||||
ln_1_g: Tensor,
|
||||
ln_1_b: Tensor,
|
||||
// Attention: combined QKV weight + bias, output weight + bias
|
||||
attn_qkv_w: Tensor, // [hidden, 3*hidden]
|
||||
attn_qkv_b: Tensor, // [3*hidden]
|
||||
attn_out_w: Tensor, // [hidden, hidden]
|
||||
attn_out_b: Tensor, // [hidden]
|
||||
ln_2_g: Tensor,
|
||||
ln_2_b: Tensor,
|
||||
mlp_fc_w: Tensor, // [hidden, 4*hidden]
|
||||
mlp_fc_b: Tensor, // [4*hidden]
|
||||
mlp_proj_w: Tensor, // [4*hidden, hidden]
|
||||
mlp_proj_b: Tensor, // [hidden]
|
||||
}
|
||||
|
||||
impl GPT2 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let wte = take(&mut w, "wte.weight");
|
||||
let wpe = take(&mut w, "wpe.weight");
|
||||
let ln_f_g = take(&mut w, "ln_f.weight");
|
||||
let ln_f_b = take(&mut w, "ln_f.bias");
|
||||
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
for i in 0..num_layers {
|
||||
let p = format!("h.{i}");
|
||||
layers.push(GPT2Block {
|
||||
ln_1_g: take(&mut w, &format!("{p}.ln_1.weight")),
|
||||
ln_1_b: take(&mut w, &format!("{p}.ln_1.bias")),
|
||||
attn_qkv_w: take(&mut w, &format!("{p}.attn.c_attn.weight")),
|
||||
attn_qkv_b: take(&mut w, &format!("{p}.attn.c_attn.bias")),
|
||||
attn_out_w: take(&mut w, &format!("{p}.attn.c_proj.weight")),
|
||||
attn_out_b: take(&mut w, &format!("{p}.attn.c_proj.bias")),
|
||||
ln_2_g: take(&mut w, &format!("{p}.ln_2.weight")),
|
||||
ln_2_b: take(&mut w, &format!("{p}.ln_2.bias")),
|
||||
mlp_fc_w: take(&mut w, &format!("{p}.mlp.c_fc.weight")),
|
||||
mlp_fc_b: take(&mut w, &format!("{p}.mlp.c_fc.bias")),
|
||||
mlp_proj_w: take(&mut w, &format!("{p}.mlp.c_proj.weight")),
|
||||
mlp_proj_b: take(&mut w, &format!("{p}.mlp.c_proj.bias")),
|
||||
});
|
||||
}
|
||||
|
||||
Self { config, wte, wpe, layers, ln_f_g, ln_f_b }
|
||||
}
|
||||
|
||||
/// Full forward pass, returns logits [seq_len, vocab_size].
|
||||
pub fn forward(&self, token_ids: &[u32]) -> Tensor {
|
||||
let seq_len = token_ids.len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
|
||||
// Token + position embedding
|
||||
let tok_emb = embedding(&self.wte, token_ids);
|
||||
let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
|
||||
let pos_emb = embedding(&self.wpe, &pos_ids);
|
||||
let mut x = add_tensors(&tok_emb, &pos_emb);
|
||||
|
||||
// Transformer layers
|
||||
for layer in &self.layers {
|
||||
// Pre-LN attention
|
||||
let residual = x.clone();
|
||||
let normed = layernorm(&x, &layer.ln_1_g, &layer.ln_1_b, self.config.ln_eps());
|
||||
|
||||
// QKV projection: [S, H] @ [H, 3H] + [3H] → [S, 3H]
|
||||
let qkv = linear(&normed, &layer.attn_qkv_w, Some(&layer.attn_qkv_b));
|
||||
// Split into Q, K, V and reshape for multi-head
|
||||
let (q, k, v) = split_qkv(&qkv, num_heads, head_dim, seq_len);
|
||||
// Attention: [1, H, S, D]
|
||||
let attn_out = attention(&q, &k, &v, true);
|
||||
// Merge heads: [1, H, S, D] → [S, hidden]
|
||||
let attn_out = merge_heads(&attn_out, seq_len, hidden);
|
||||
// Output projection
|
||||
let attn_out = linear(&attn_out, &layer.attn_out_w, Some(&layer.attn_out_b));
|
||||
x = add_tensors(&residual, &attn_out);
|
||||
|
||||
// Pre-LN MLP
|
||||
let residual = x.clone();
|
||||
let normed = layernorm(&x, &layer.ln_2_g, &layer.ln_2_b, self.config.ln_eps());
|
||||
let fc = linear(&normed, &layer.mlp_fc_w, Some(&layer.mlp_fc_b));
|
||||
let activated = gelu(&fc);
|
||||
let proj = linear(&activated, &layer.mlp_proj_w, Some(&layer.mlp_proj_b));
|
||||
x = add_tensors(&residual, &proj);
|
||||
}
|
||||
|
||||
// Final layer norm
|
||||
let x = layernorm(&x, &self.ln_f_g, &self.ln_f_b, self.config.ln_eps());
|
||||
|
||||
// LM head (tied with wte): [S, H] @ [H, V] → [S, V]
|
||||
// wte is [V, H], so we need wte^T
|
||||
let lm_head = self.wte.transpose(0, 1).contiguous();
|
||||
matmul_2d(&x, &lm_head)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper ops ---
|
||||
|
||||
fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
|
||||
// GPT-2 stores weights as [in, out] (not transposed), so x @ w
|
||||
let out = matmul_2d(x, weight);
|
||||
if let Some(b) = bias {
|
||||
add_bias(&out, b)
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
// a: [S, K], b: [K, N] → [S, N]
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
matmul(a, b, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
// Element-wise add on GPU via a simple approach: scale(a, 1.0) + scale(b, 1.0)
|
||||
// TODO: proper add kernel. For now, go through CPU.
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
assert_eq!(a.dtype(), DType::F32);
|
||||
let a_cpu = a.to_device(Device::Cpu);
|
||||
let b_cpu = b.to_device(Device::Cpu);
|
||||
let a_data = a_cpu.as_slice::<f32>();
|
||||
let b_data = b_cpu.as_slice::<f32>();
|
||||
let sum: Vec<f32> = a_data.iter().zip(b_data).map(|(x, y)| x + y).collect();
|
||||
Tensor::from_slice(&sum, a.shape()).to_device(a.device())
|
||||
}
|
||||
|
||||
fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
// x: [S, N], bias: [N] → broadcast add
|
||||
assert_eq!(x.ndim(), 2);
|
||||
assert_eq!(bias.ndim(), 1);
|
||||
assert_eq!(x.shape()[1], bias.shape()[0]);
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let b_cpu = bias.to_device(Device::Cpu);
|
||||
let x_data = x_cpu.as_slice::<f32>();
|
||||
let b_data = b_cpu.as_slice::<f32>();
|
||||
let n = bias.shape()[0];
|
||||
let result: Vec<f32> = x_data.iter().enumerate().map(|(i, &v)| v + b_data[i % n]).collect();
|
||||
Tensor::from_slice(&result, x.shape()).to_device(x.device())
|
||||
}
|
||||
|
||||
fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) {
|
||||
// qkv: [S, 3*H] → Q, K, V each [1, num_heads, S, head_dim]
|
||||
let hidden = num_heads * head_dim;
|
||||
let qkv_cpu = qkv.to_device(Device::Cpu);
|
||||
let data = qkv_cpu.as_slice::<f32>();
|
||||
|
||||
// Split into Q, K, V and directly write in [1, num_heads, S, head_dim] layout
|
||||
let mut q_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut k_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
let mut v_data = vec![0.0f32; num_heads * seq_len * head_dim];
|
||||
|
||||
for s in 0..seq_len {
|
||||
let row = &data[s * 3 * hidden..(s + 1) * 3 * hidden];
|
||||
for h in 0..num_heads {
|
||||
let src_off = h * head_dim;
|
||||
let dst_off = (h * seq_len + s) * head_dim;
|
||||
q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]);
|
||||
k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]);
|
||||
v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
|
||||
let device = qkv.device();
|
||||
let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device);
|
||||
(q, k, v)
|
||||
}
|
||||
|
||||
fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
// [1, num_heads, S, head_dim] → [S, hidden]
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<f32>();
|
||||
|
||||
// src layout: [1][num_heads][seq_len][head_dim]
|
||||
// dst layout: [seq_len][hidden] where hidden = num_heads * head_dim
|
||||
let mut out = vec![0.0f32; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let src_off = (h * seq_len + s) * head_dim;
|
||||
let dst_off = s * hidden + h * head_dim;
|
||||
out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// Greedy sampling: return the argmax token ID from the last position's logits.
|
||||
pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2); // [S, V]
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<f32>();
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last_row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.map(|(idx, _)| idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
6
crates/xserv-model/src/lib.rs
Normal file
6
crates/xserv-model/src/lib.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod config;
|
||||
pub mod gpt2;
|
||||
pub mod loader;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use gpt2::GPT2;
|
||||
87
crates/xserv-model/src/loader.rs
Normal file
87
crates/xserv-model/src/loader.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use half::{bf16, f16};
|
||||
use safetensors::SafeTensors;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let data = std::fs::read(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let st = SafeTensors::deserialize(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display()));
|
||||
|
||||
let mut tensors = HashMap::new();
|
||||
|
||||
for (name, view) in st.tensors() {
|
||||
let shape: Vec<usize> = view.shape().to_vec();
|
||||
let raw_bytes = view.data();
|
||||
let dtype = match view.dtype() {
|
||||
safetensors::Dtype::F32 => DType::F32,
|
||||
safetensors::Dtype::F16 => DType::F16,
|
||||
safetensors::Dtype::BF16 => DType::BF16,
|
||||
other => {
|
||||
eprintln!("skipping tensor {name}: unsupported dtype {other:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let tensor = make_tensor(raw_bytes, &shape, dtype);
|
||||
let tensor = tensor.to_device(device);
|
||||
tensors.insert(name.to_string(), tensor);
|
||||
}
|
||||
|
||||
tensors
|
||||
}
|
||||
|
||||
/// Load from a directory containing model.safetensors (or sharded files) + config.json.
|
||||
pub fn load_model_dir(dir: &Path, device: Device) -> HashMap<String, Tensor> {
|
||||
let single = dir.join("model.safetensors");
|
||||
if single.exists() {
|
||||
return load_safetensors(&single, device);
|
||||
}
|
||||
|
||||
// Try sharded: model-00001-of-NNNNN.safetensors
|
||||
let mut all_tensors = HashMap::new();
|
||||
let mut entries: Vec<_> = std::fs::read_dir(dir)
|
||||
.unwrap()
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| {
|
||||
e.path()
|
||||
.file_name()
|
||||
.map(|f| f.to_string_lossy().ends_with(".safetensors"))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
entries.sort_by_key(|e| e.file_name());
|
||||
|
||||
for entry in entries {
|
||||
let tensors = load_safetensors(&entry.path(), device);
|
||||
all_tensors.extend(tensors);
|
||||
}
|
||||
|
||||
assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display());
|
||||
all_tensors
|
||||
}
|
||||
|
||||
fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let floats: &[f32] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const f32, raw_bytes.len() / 4)
|
||||
};
|
||||
Tensor::from_slice(floats, shape)
|
||||
}
|
||||
DType::F16 => {
|
||||
let halfs: &[f16] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const f16, raw_bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(halfs, shape)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let bfs: &[bf16] = unsafe {
|
||||
std::slice::from_raw_parts(raw_bytes.as_ptr() as *const bf16, raw_bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(bfs, shape)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -137,8 +137,13 @@ impl Tensor {
|
||||
if self.is_contiguous() {
|
||||
return self.clone();
|
||||
}
|
||||
// Copy to contiguous layout on CPU
|
||||
assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported");
|
||||
// For GPU tensors: round-trip through CPU (correct but slow).
|
||||
// TODO: write a GPU contiguous-copy kernel for performance.
|
||||
if matches!(self.device(), Device::Cuda(_)) {
|
||||
let cpu = self.to_device(Device::Cpu);
|
||||
let contig = cpu.contiguous();
|
||||
return contig.to_device(self.device());
|
||||
}
|
||||
let numel = self.numel();
|
||||
let elem_size = self.dtype.size_bytes();
|
||||
let src_bytes = self.storage.as_cpu_bytes();
|
||||
@@ -173,17 +178,18 @@ impl Tensor {
|
||||
// --- Device transfer ---
|
||||
|
||||
pub fn to_device(&self, device: Device) -> Self {
|
||||
let t = if self.is_contiguous() { self.clone() } else { self.contiguous() };
|
||||
if t.device() == device {
|
||||
return t;
|
||||
if self.device() == device {
|
||||
return self.clone();
|
||||
}
|
||||
let new_storage = t.storage.to_device(device).expect("device transfer failed");
|
||||
// Transfer the raw storage (preserving strides/offset).
|
||||
// Non-contiguous layout is preserved — the user can call contiguous() after.
|
||||
let new_storage = self.storage.to_device(device).expect("device transfer failed");
|
||||
Self {
|
||||
storage: new_storage,
|
||||
shape: t.shape,
|
||||
strides: t.strides,
|
||||
offset: 0,
|
||||
dtype: t.dtype,
|
||||
shape: self.shape.clone(),
|
||||
strides: self.strides.clone(),
|
||||
offset: self.offset,
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
9
crates/xserv-tokenizer/Cargo.toml
Normal file
9
crates/xserv-tokenizer/Cargo.toml
Normal file
@@ -0,0 +1,9 @@
|
||||
[package]
|
||||
name = "xserv-tokenizer"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
regex.workspace = true
|
||||
251
crates/xserv-tokenizer/src/bpe.rs
Normal file
251
crates/xserv-tokenizer/src/bpe.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
use regex::Regex;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct Tokenizer {
|
||||
encoder: HashMap<Vec<u8>, u32>,
|
||||
decoder: Vec<Vec<u8>>,
|
||||
merge_ranks: HashMap<(u32, u32), usize>,
|
||||
special_tokens: HashMap<String, u32>,
|
||||
special_token_ids: HashMap<u32, String>,
|
||||
pre_tokenize_re: Regex,
|
||||
eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenizerJson {
|
||||
model: ModelSection,
|
||||
#[serde(default)]
|
||||
added_tokens: Vec<AddedToken>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ModelSection {
|
||||
vocab: HashMap<String, u32>,
|
||||
merges: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AddedToken {
|
||||
id: u32,
|
||||
content: String,
|
||||
special: bool,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
pub fn from_file(path: &Path) -> Self {
|
||||
let data = std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
|
||||
let tj: TokenizerJson = serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse tokenizer.json: {e}"));
|
||||
|
||||
// Build encoder: token bytes → ID
|
||||
let mut encoder = HashMap::new();
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
let bytes = token_str_to_bytes(token_str);
|
||||
encoder.insert(bytes, id);
|
||||
}
|
||||
|
||||
// Build decoder: ID → token bytes
|
||||
let max_id = tj.model.vocab.values().copied().max().unwrap_or(0);
|
||||
let added_max = tj.added_tokens.iter().map(|t| t.id).max().unwrap_or(0);
|
||||
let vocab_size = (max_id.max(added_max) + 1) as usize;
|
||||
let mut decoder = vec![vec![]; vocab_size];
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
decoder[id as usize] = token_str_to_bytes(token_str);
|
||||
}
|
||||
|
||||
// Parse merges
|
||||
let mut merge_ranks = HashMap::new();
|
||||
for (rank, merge_line) in tj.model.merges.iter().enumerate() {
|
||||
let parts: Vec<&str> = merge_line.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 { continue; }
|
||||
let a_bytes = token_str_to_bytes(parts[0]);
|
||||
let b_bytes = token_str_to_bytes(parts[1]);
|
||||
if let (Some(&a_id), Some(&b_id)) = (encoder.get(&a_bytes), encoder.get(&b_bytes)) {
|
||||
merge_ranks.insert((a_id, b_id), rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Special tokens
|
||||
let mut special_tokens = HashMap::new();
|
||||
let mut special_token_ids = HashMap::new();
|
||||
let mut eos_token_id = None;
|
||||
for at in &tj.added_tokens {
|
||||
if at.special {
|
||||
special_tokens.insert(at.content.clone(), at.id);
|
||||
special_token_ids.insert(at.id, at.content.clone());
|
||||
decoder.resize(decoder.len().max(at.id as usize + 1), vec![]);
|
||||
decoder[at.id as usize] = at.content.as_bytes().to_vec();
|
||||
if at.content == "<|endoftext|>" || at.content == "<|end_of_text|>" {
|
||||
eos_token_id = Some(at.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPT-2 pre-tokenization regex.
|
||||
// The original uses (?!\S) lookahead which Rust regex doesn't support.
|
||||
// Simplified: collapse trailing whitespace into one match. Functionally equivalent
|
||||
// for BPE since each whitespace chunk gets encoded independently anyway.
|
||||
let pre_tokenize_re = Regex::new(
|
||||
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"
|
||||
).unwrap();
|
||||
|
||||
Self {
|
||||
encoder,
|
||||
decoder,
|
||||
merge_ranks,
|
||||
special_tokens,
|
||||
special_token_ids,
|
||||
pre_tokenize_re,
|
||||
eos_token_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||
let mut tokens = Vec::new();
|
||||
|
||||
// Check for special tokens first (split around them)
|
||||
let mut remaining = text;
|
||||
while !remaining.is_empty() {
|
||||
// Find earliest special token
|
||||
let mut earliest: Option<(usize, &str, u32)> = None;
|
||||
for (st, &id) in &self.special_tokens {
|
||||
if let Some(pos) = remaining.find(st.as_str()) {
|
||||
if earliest.is_none() || pos < earliest.unwrap().0 {
|
||||
earliest = Some((pos, st, id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((pos, st, id)) = earliest {
|
||||
if pos > 0 {
|
||||
self.encode_ordinary(&remaining[..pos], &mut tokens);
|
||||
}
|
||||
tokens.push(id);
|
||||
remaining = &remaining[pos + st.len()..];
|
||||
} else {
|
||||
self.encode_ordinary(remaining, &mut tokens);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
fn encode_ordinary(&self, text: &str, out: &mut Vec<u32>) {
|
||||
for mat in self.pre_tokenize_re.find_iter(text) {
|
||||
let word = mat.as_str();
|
||||
let word_bytes: Vec<u8> = word.bytes().collect();
|
||||
let mut token_ids: Vec<u32> = word_bytes.iter().map(|&b| {
|
||||
*self.encoder.get(&vec![b]).unwrap_or_else(|| {
|
||||
panic!("byte {b} not in vocab")
|
||||
})
|
||||
}).collect();
|
||||
|
||||
// BPE merges
|
||||
loop {
|
||||
if token_ids.len() < 2 { break; }
|
||||
let mut best_rank = usize::MAX;
|
||||
let mut best_idx = 0;
|
||||
for i in 0..token_ids.len() - 1 {
|
||||
if let Some(&rank) = self.merge_ranks.get(&(token_ids[i], token_ids[i + 1])) {
|
||||
if rank < best_rank {
|
||||
best_rank = rank;
|
||||
best_idx = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
if best_rank == usize::MAX { break; }
|
||||
|
||||
let merged_bytes = [
|
||||
self.decoder[token_ids[best_idx] as usize].as_slice(),
|
||||
self.decoder[token_ids[best_idx + 1] as usize].as_slice(),
|
||||
].concat();
|
||||
let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| {
|
||||
panic!("merged token not in vocab");
|
||||
});
|
||||
token_ids[best_idx] = merged_id;
|
||||
token_ids.remove(best_idx + 1);
|
||||
}
|
||||
|
||||
out.extend_from_slice(&token_ids);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(&self, token_ids: &[u32]) -> String {
|
||||
let mut bytes = Vec::new();
|
||||
for &id in token_ids {
|
||||
if let Some(b) = self.decoder.get(id as usize) {
|
||||
bytes.extend_from_slice(b);
|
||||
}
|
||||
}
|
||||
String::from_utf8_lossy(&bytes).into_owned()
|
||||
}
|
||||
|
||||
pub fn eos_token_id(&self) -> Option<u32> {
|
||||
self.eos_token_id
|
||||
}
|
||||
|
||||
pub fn vocab_size(&self) -> usize {
|
||||
self.decoder.len()
|
||||
}
|
||||
|
||||
pub fn special_token_id(&self, name: &str) -> Option<u32> {
|
||||
self.special_tokens.get(name).copied()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a token string from HF vocab (which uses Unicode replacements for bytes)
|
||||
/// back to raw bytes. GPT-2 uses a byte-to-unicode mapping where e.g. byte 0x20 (space)
|
||||
/// is represented as 'Ġ' (U+0120).
|
||||
fn token_str_to_bytes(s: &str) -> Vec<u8> {
|
||||
s.chars().map(|c| unicode_to_byte(c)).collect()
|
||||
}
|
||||
|
||||
fn unicode_to_byte(c: char) -> u8 {
|
||||
let u = c as u32;
|
||||
// GPT-2 byte encoder: maps bytes 0-255 to specific Unicode code points.
|
||||
// Printable ASCII bytes map to themselves. Others are shifted to 256+.
|
||||
match u {
|
||||
0x21..=0x7E => u as u8, // '!' to '~'
|
||||
0xA1..=0xAC => u as u8, // '¡' to '¬'
|
||||
0xAE..=0xFF => u as u8, // '®' to 'ÿ'
|
||||
// Shifted bytes: 0x100 + original_byte for bytes not in the above ranges
|
||||
0x100..=0x1FF => (u - 0x100) as u8 + {
|
||||
// The shift mapping: byte values 0..=32, 127..=160, 173
|
||||
// are shifted to 256..=288, 289+, etc.
|
||||
0
|
||||
},
|
||||
_ => {
|
||||
// Fallback: for the GPT-2 byte encoder, specific mappings
|
||||
byte_from_unicode_gpt2(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn byte_from_unicode_gpt2(c: char) -> u8 {
|
||||
// Build the inverse of GPT-2's bytes_to_unicode mapping.
|
||||
// The mapping assigns printable chars to themselves and shifts unprintable bytes.
|
||||
let u = c as u32;
|
||||
// Direct ASCII printable + Latin-1 supplement printable ranges map identity
|
||||
if (0x21..=0x7E).contains(&u) { return u as u8; }
|
||||
if (0xA1..=0xAC).contains(&u) { return u as u8; }
|
||||
if (0xAE..=0xFF).contains(&u) { return u as u8; }
|
||||
|
||||
// Shifted range: the remaining 68 bytes (0-32, 127-160, 173) get mapped to 256..=323
|
||||
static SHIFTED_BYTES: &[u8] = &[
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
|
||||
137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
|
||||
154, 155, 156, 157, 158, 159, 160, 173,
|
||||
];
|
||||
let shifted_start = 256u32;
|
||||
if u >= shifted_start && u < shifted_start + SHIFTED_BYTES.len() as u32 {
|
||||
return SHIFTED_BYTES[(u - shifted_start) as usize];
|
||||
}
|
||||
|
||||
// Shouldn't reach here for valid GPT-2 tokenizer
|
||||
c as u8
|
||||
}
|
||||
3
crates/xserv-tokenizer/src/lib.rs
Normal file
3
crates/xserv-tokenizer/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod bpe;
|
||||
|
||||
pub use bpe::Tokenizer;
|
||||
@@ -35,6 +35,16 @@ __global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
||||
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
|
||||
}
|
||||
|
||||
__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = x[idx] * scale;
|
||||
}
|
||||
|
||||
__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
|
||||
@@ -63,4 +73,18 @@ void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (float*)out, scale, n);
|
||||
}
|
||||
|
||||
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
53
csrc/attention/causal_mask.cu
Normal file
53
csrc/attention/causal_mask.cu
Normal file
@@ -0,0 +1,53 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Apply causal mask: set scores[row][col] = -inf where col > row + offset.
|
||||
// offset is used for KV cache: when query starts at position `offset`,
|
||||
// we allow attending to positions [0, offset + row].
|
||||
// scores: [batch, rows, cols] (flattened batch×heads)
|
||||
|
||||
__global__ void causal_mask_f32(
|
||||
float* __restrict__ scores,
|
||||
int rows, int cols, int offset
|
||||
) {
|
||||
int batch_idx = blockIdx.z;
|
||||
int row = blockIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
scores[batch_idx * rows * cols + row * cols + col] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void causal_mask_bf16(
|
||||
__nv_bfloat16* __restrict__ scores,
|
||||
int rows, int cols, int offset
|
||||
) {
|
||||
int batch_idx = blockIdx.z;
|
||||
int row = blockIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
// BF16 doesn't have proper -inf literal, use a very large negative
|
||||
scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-1e9f);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
|
||||
int offset, void* stream) {
|
||||
int block = 256;
|
||||
dim3 grid((cols + block - 1) / block, rows, batch);
|
||||
causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(float*)scores, rows, cols, offset);
|
||||
}
|
||||
|
||||
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
|
||||
int offset, void* stream) {
|
||||
int block = 256;
|
||||
dim3 grid((cols + block - 1) / block, rows, batch);
|
||||
causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)scores, rows, cols, offset);
|
||||
}
|
||||
|
||||
}
|
||||
92
docs/05-attention.md
Normal file
92
docs/05-attention.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Phase 5: Naive Attention Kernel — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
实现标准 Multi-Head Attention(不做 Flash/Paged 优化),用组合式方法(GEMM + Softmax)完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。
|
||||
|
||||
## 计算流程
|
||||
|
||||
```
|
||||
Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
|
||||
B=batch, H=num_heads, S=seq_len, D=head_dim
|
||||
|
||||
1. scores = Q @ K^T / sqrt(D) → [B, H, S, S]
|
||||
2. scores += causal_mask → 上三角置为 -inf
|
||||
3. weights = softmax(scores, dim=-1) → [B, H, S, S]
|
||||
4. output = weights @ V → [B, H, S, D]
|
||||
```
|
||||
|
||||
## 设计选择
|
||||
|
||||
### 组合式实现(Phase 3 GEMM + Phase 4 Softmax)
|
||||
|
||||
不写新的 fused CUDA kernel,而是复用已有的 matmul 和 softmax:
|
||||
- `scores = batched_matmul(Q, K^T)` — 需要支持 batched GEMM
|
||||
- `masked_fill(scores, causal_mask, -inf)` — 新的逐元素 kernel
|
||||
- `softmax(scores)` — 复用 Phase 4
|
||||
- `output = batched_matmul(weights, V)` — 复用 batched GEMM
|
||||
|
||||
这意味着需要先扩展 matmul 支持 batched GEMM(cublasGemmStridedBatchedEx)。
|
||||
|
||||
### Causal Mask
|
||||
|
||||
不显式构造 mask 矩阵。写一个 kernel:
|
||||
```
|
||||
if (col > row + offset) score = -infinity
|
||||
```
|
||||
其中 offset 用于支持 KV cache 场景(decode 时 query 的 row 偏移)。
|
||||
|
||||
### Batched GEMM via cuBLAS
|
||||
|
||||
`cublasGemmStridedBatchedEx` 在一个 batch 维度上并行执行多个 GEMM:
|
||||
```
|
||||
C[b] = A[b] @ B[b] for b = 0..batch_count
|
||||
stride_a = M * K, stride_b = K * N, stride_c = M * N
|
||||
```
|
||||
|
||||
Attention 中 batch 维度 = B * H(batch_size × num_heads)。
|
||||
|
||||
## 文件布局
|
||||
|
||||
```
|
||||
csrc/attention/
|
||||
└── causal_mask.cu # causal mask fill kernel
|
||||
|
||||
crates/xserv-kernels/src/
|
||||
├── gemm.rs # 扩展: batched_matmul
|
||||
├── attention.rs # NEW: multi_head_attention()
|
||||
└── causal_mask.rs # NEW: causal mask apply
|
||||
```
|
||||
|
||||
## API 设计
|
||||
|
||||
```rust
|
||||
/// Multi-head attention (naive, materializes S×S scores).
|
||||
/// q, k, v: [batch, num_heads, seq_len, head_dim]
|
||||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;
|
||||
|
||||
/// Batched matmul: A[b] @ B[b] for all b.
|
||||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
|
||||
- [x] attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
|
||||
- [x] attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
|
||||
- [x] attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
|
||||
- [x] causal mask 语义: position 0 只能看到 token 0,output[0] == V[0] → exact
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **`to_device` 不应强制 contiguous**:最初 `to_device()` 会先调 `contiguous()`,而 GPU 的 `contiguous()` 又调 `to_device(Cpu)`,导致无限递归栈溢出。修复:`to_device()` 直接传输 raw storage,保留 strides/offset,用户需要时自己调 `contiguous()`。GPU `contiguous()` 现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效,Phase 15 需要写 GPU contiguous kernel。
|
||||
|
||||
2. **Batched GEMM via `cublasGemmStridedBatchedEx`**:row-major trick 同 Phase 3,额外参数是 stride(元素数,不是字节)。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 `elem_size`,cuBLAS 的 stride 单位是元素。
|
||||
|
||||
3. **Attention 的组合式实现足够验证正确性**:没有写 fused kernel,而是复用 `batched_matmul` + `scale` + `causal_mask` + `softmax`。精度极好(max_err < 1e-7),因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materialize(O(S²) 显存),Flash Attention 会解决。
|
||||
|
||||
4. **Scale kernel 的必要性**:原本想在 CPU 上做 scale(round-trip),但那太慢了。加了 `scale_f32/bf16` 逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。
|
||||
|
||||
5. **Causal mask 的 offset 设计**:`col > row + offset` 中的 offset 为 KV cache 场景预留。Decode 时 Q 只有 1 行但 KV cache 有前 S 行,offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。
|
||||
69
docs/06-model-loading.md
Normal file
69
docs/06-model-loading.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# Phase 6: Model Loading — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
从 HuggingFace safetensors 文件加载模型权重到 GPU Tensor。解析 config.json 获取模型结构参数。
|
||||
|
||||
## Crate: `xserv-model`
|
||||
|
||||
```
|
||||
crates/xserv-model/src/
|
||||
├── lib.rs
|
||||
├── config.rs # ModelConfig from config.json
|
||||
├── loader.rs # safetensors weight loading
|
||||
└── gpt2.rs # (Phase 8) GPT-2 model definition
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- `safetensors` crate: parse safetensors format
|
||||
- `serde` + `serde_json`: deserialize config.json
|
||||
- `memmap2`: mmap for zero-copy file access (safetensors uses this internally)
|
||||
|
||||
## Weight Loading Flow
|
||||
|
||||
```
|
||||
safetensors file (disk)
|
||||
→ safetensors crate parses header (tensor names, shapes, dtypes, offsets)
|
||||
→ mmap raw data
|
||||
→ for each tensor:
|
||||
→ read bytes at offset
|
||||
→ create CPU Tensor from raw bytes
|
||||
→ .to_device(Cuda(0)) → GPU Tensor
|
||||
→ return HashMap<String, Tensor>
|
||||
```
|
||||
|
||||
## Config Parsing
|
||||
|
||||
```rust
|
||||
#[derive(Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
pub architectures: Option<Vec<String>>,
|
||||
pub model_type: Option<String>,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: Option<usize>,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub num_hidden_layers: usize,
|
||||
pub vocab_size: usize,
|
||||
pub max_position_embeddings: Option<usize>,
|
||||
pub layer_norm_eps: Option<f64>,
|
||||
pub rms_norm_eps: Option<f64>,
|
||||
pub rope_theta: Option<f64>,
|
||||
pub tie_word_embeddings: Option<bool>,
|
||||
}
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] Load GPT-2 124M: 160 tensors loaded successfully
|
||||
- [x] Parse GPT-2 config.json: hidden=768, layers=12, heads=12, vocab=50257
|
||||
- [x] Sharded loading path implemented (for larger models)
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **GPT-2 vs modern HF config naming**:GPT-2 uses `n_embd`/`n_head`/`n_layer`/`n_positions`,而不是 `hidden_size`/`num_attention_heads` 等。ModelConfig 需要支持两套命名并提供统一的 accessor methods(`hidden()`, `num_heads()` 等)。
|
||||
|
||||
2. **safetensors 零拷贝读取**:`safetensors` crate 直接 mmap 文件,解析 header 得到 tensor 的 offset 和 shape,然后 zero-copy 读取 raw bytes。对于 GPT-2 的 500MB 权重文件,加载速度很快。
|
||||
|
||||
3. **模型下载的网络问题**:HuggingFace 在中国网络下不可达。使用 modelscope.cn 或 hf-mirror.com 作为替代。大文件(>100MB)的 redirect 到 CDN 可能也会失败,modelscope 的 snapshot_download 更可靠。
|
||||
57
docs/07-tokenizer.md
Normal file
57
docs/07-tokenizer.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Phase 7: BPE Tokenizer — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
从零实现 Byte-Pair Encoding tokenizer,兼容 HuggingFace `tokenizer.json` 格式。支持 GPT-2 和 Qwen3。
|
||||
|
||||
## Crate: `xserv-tokenizer`
|
||||
|
||||
```
|
||||
crates/xserv-tokenizer/src/
|
||||
├── lib.rs
|
||||
├── bpe.rs # BPE encode/decode core algorithm
|
||||
└── chat.rs # Chat template formatting
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
- `serde` + `serde_json`: parse tokenizer.json
|
||||
- `regex`: pre-tokenization patterns
|
||||
|
||||
## BPE Algorithm
|
||||
|
||||
### Encode
|
||||
1. Pre-tokenize: split text by regex (GPT-2 pattern)
|
||||
2. Each word → byte sequence → initial token list (one token per byte)
|
||||
3. Repeatedly merge highest-priority pair until no more merges
|
||||
4. Map merged tokens to IDs via vocab
|
||||
|
||||
### Decode
|
||||
Token IDs → lookup vocab → concatenate bytes → UTF-8 decode
|
||||
|
||||
## Key Data Structures
|
||||
|
||||
```rust
|
||||
pub struct Tokenizer {
|
||||
vocab: HashMap<Vec<u8>, u32>, // token bytes → ID
|
||||
vocab_rev: Vec<Vec<u8>>, // ID → token bytes
|
||||
merges: Vec<(Vec<u8>, Vec<u8>)>, // ordered merge rules
|
||||
merge_ranks: HashMap<(u32, u32), usize>, // (id_a, id_b) → priority
|
||||
special_tokens: HashMap<String, u32>,
|
||||
pre_tokenize_regex: Regex,
|
||||
}
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] Encode + decode roundtrip verified (GPT-2 tokenizer, English text)
|
||||
- [x] Special tokens handled (endoftext)
|
||||
- [x] Integrated into GPT-2 inference pipeline, generates coherent text
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **GPT-2 byte-to-unicode 映射**:GPT-2 的 vocab 中,每个 byte 都映射到一个 Unicode 字符。可打印 ASCII (0x21-0x7E) 映射到自身,其余字节(空格、控制字符等)映射到 U+0100 以上的 Unicode 码点。解码时需要反向映射。这个映射表是 BPE tokenizer 正确性的关键。
|
||||
|
||||
2. **Rust regex 不支持 lookahead**:GPT-2 的 pre-tokenization regex 使用了 `(?!\S)` lookahead,Rust 的 `regex` crate 不支持。简化为去掉 lookahead 后功能等价(whitespace 仍然被正确分词)。如果需要精确匹配 Python 行为,需要 `fancy-regex` crate。
|
||||
|
||||
3. **BPE merge 的 O(n²) 复杂度**:当前实现每次 merge 扫描整个 token 序列找最高优先级 pair,复杂度 O(n² × |merges|)。对于短文本够用,长文本需要 priority queue 优化。推理场景中 prompt 通常 < 10K tokens,暂时可接受。
|
||||
71
docs/08-gpt2.md
Normal file
71
docs/08-gpt2.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Phase 8: GPT-2 Complete Inference — Design Document (Milestone ①)
|
||||
|
||||
## Goal
|
||||
|
||||
Wire everything together: load GPT-2 124M, tokenize input, run forward pass, sample tokens, decode output. First time seeing the model "speak".
|
||||
|
||||
## Model Architecture (GPT-2 124M)
|
||||
|
||||
```
|
||||
hidden_size = 768
|
||||
num_heads = 12
|
||||
num_layers = 12
|
||||
vocab_size = 50257
|
||||
max_position_embeddings = 1024
|
||||
activation = GELU
|
||||
normalization = LayerNorm (pre-LN)
|
||||
tied embeddings (lm_head == wte)
|
||||
```
|
||||
|
||||
## Forward Pass
|
||||
|
||||
```
|
||||
tokens [S]
|
||||
→ wte[tokens] + wpe[0..S] → [S, 768]
|
||||
→ for each layer:
|
||||
residual = x
|
||||
x = layernorm(x, ln_1)
|
||||
x = attention(x) # Q,K,V from linear, MHA, output linear
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = layernorm(x, ln_2)
|
||||
x = mlp(x) # linear→GELU→linear
|
||||
x = x + residual
|
||||
→ layernorm(x, ln_f)
|
||||
→ logits = x @ wte.T → [S, 50257]
|
||||
→ sample(logits[-1]) → next token
|
||||
```
|
||||
|
||||
## Sampling
|
||||
|
||||
- Greedy: argmax
|
||||
- Temperature: logits / T → softmax → sample
|
||||
- Top-K: keep top-k logits, rest = -inf
|
||||
- Top-P: sorted by prob, cumsum ≤ p
|
||||
|
||||
## CLI Binary
|
||||
|
||||
```
|
||||
$ cargo run --release --bin xserv-cli -- --model path/to/gpt2
|
||||
|
||||
xserv> The future of AI is
|
||||
GPT-2> ...generated text...
|
||||
```
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] Greedy generation produces coherent English text
|
||||
- [x] Interactive CLI works (pipe and interactive mode)
|
||||
- [x] Multiple prompts verified: "The future of AI is", "Once upon a time"
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **QKV split + head reshape 的 layout 陷阱(最关键的 bug)**:GPT-2 的 `c_attn` 输出 `[S, 3H]` 需要 split 成 Q/K/V 再 reshape 成 `[1, num_heads, S, head_dim]`。关键错误:从 `[S, num_heads, head_dim]` 直接 `reshape` 到 `[1, num_heads, S, head_dim]` 不等于 transpose!Reshape 只是重新解释 flat data 的 shape,不会重排数据。必须手动按 `[batch, head, seq, dim]` 的目标 layout 写入数据。同理 merge_heads 也需要手动重排。
|
||||
|
||||
2. **CPU round-trip 作为 correctness first 策略**:`add_tensors`、`add_bias`、`split_qkv`、`merge_heads` 都通过 CPU round-trip 实现。虽然慢(每次都有 GPU→CPU→GPU 拷贝),但确保了正确性。Phase 15 会写专门的 CUDA kernel 替换这些操作。
|
||||
|
||||
3. **GPT-2 的 Conv1D 权重布局**:GPT-2 用 `Conv1D` 而非 `Linear`,权重存为 `[in, out]`(不是标准 Linear 的 `[out, in]`)。计算方式是 `x @ weight`(不需要转置)。这和 Qwen3/LLaMA 的 `[out, in]` 布局不同——Phase 10 需要注意。
|
||||
|
||||
4. **Greedy decoding 的重复问题**:GPT-2 124M 在 greedy decoding 下极易陷入循环("The world was a place of great danger, and...")。这是已知行为,temperature + top-k/top-p sampling 可以缓解。当前实现只有 greedy,sampling 将在后续添加。
|
||||
|
||||
5. **无 KV Cache 的性能代价**:每生成一个 token 都要重新跑完整 forward pass(O(S²) attention)。50 tokens 的生成需要 50 次 full forward,每次的 attention 复杂度还在增长。Phase 9 的 KV Cache 会将 decode 降到 O(S) per token。
|
||||
Reference in New Issue
Block a user