phase 10: Qwen3-8B support (Milestone ②)
Qwen3 model (qwen3.rs): - RMSNorm + QK normalization (per-head q_norm/k_norm) - GQA: 32 Q heads, 8 KV heads, repeat_kv for attention - SwiGLU FFN: gate_proj → SiLU → * up_proj → down_proj - RoPE with transpose for [1,H,S,D] ↔ [S,H,D] layout - BF16 forward pass, [out,in] weight layout via linear_t - No attention bias (attention_bias=false) Tokenizer fixes: - Fixed unicode_to_byte: shifted bytes now use correct inverse lookup table - MergeEntry supports both string and array formats - Both GPT-2 and Qwen3 tokenizers work correctly (English + Chinese) KVCache refactored: - Dtype-agnostic: stores raw bytes per-head, works for F32 and BF16 - append_kv_tensor/get_kv_tensors use Tensor directly CLI updated: - Auto-detects model type from config.json (gpt2 vs qwen3) - Supports both GPT-2 (F32) and Qwen3 (BF16) Verified: Qwen3-8B generates coherent English and Chinese on single RTX 5090. 61/61 tests pass, GPT-2 performance no regression. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -143,7 +143,7 @@ fn generate_with_cache(
|
||||
) -> (Vec<u32>, u128, Vec<u128>) {
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_heads(), config.head_dim(),
|
||||
Device::Cuda(0),
|
||||
xserv_tensor::DType::F32, Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use xserv_model::gpt2::{sample_greedy, KVCache};
|
||||
use xserv_model::{loader, GPT2, ModelConfig};
|
||||
use xserv_tensor::Device;
|
||||
use xserv_model::{loader, KVCache, ModelConfig};
|
||||
use xserv_tensor::{DType, Device};
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
fn main() {
|
||||
@@ -25,43 +24,59 @@ fn main() {
|
||||
eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024);
|
||||
|
||||
let config = ModelConfig::from_file(&model_dir.join("config.json"));
|
||||
let model_type = config.model_type.as_deref().unwrap_or("unknown");
|
||||
eprintln!(
|
||||
"Model: {:?}, layers={}, hidden={}, heads={}, vocab={}",
|
||||
config.model_type,
|
||||
config.num_layers(),
|
||||
config.hidden(),
|
||||
config.num_heads(),
|
||||
config.vocab_size
|
||||
"Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}",
|
||||
config.num_layers(), config.hidden(), config.num_heads(),
|
||||
config.num_kv_heads(), config.vocab_size
|
||||
);
|
||||
|
||||
eprintln!("Loading weights...");
|
||||
let weights = loader::load_model_dir(&model_dir, Device::Cuda(0));
|
||||
eprintln!("Loaded {} tensors", weights.len());
|
||||
|
||||
let model = GPT2::from_weights(config.clone(), weights);
|
||||
let is_qwen3 = model_type.contains("qwen");
|
||||
let dtype = if is_qwen3 { DType::BF16 } else { DType::F32 };
|
||||
|
||||
// Build model
|
||||
enum Model {
|
||||
GPT2(xserv_model::GPT2),
|
||||
Qwen3(xserv_model::Qwen3),
|
||||
}
|
||||
let model = if is_qwen3 {
|
||||
Model::Qwen3(xserv_model::Qwen3::from_weights(config.clone(), weights))
|
||||
} else {
|
||||
Model::GPT2(xserv_model::GPT2::from_weights(config.clone(), weights))
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
eprintln!("Ready (KV cache enabled).\n");
|
||||
eprintln!("Ready (KV cache, dtype={dtype}).\n");
|
||||
|
||||
loop {
|
||||
print!("xserv> ");
|
||||
io::stdout().flush().unwrap();
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 {
|
||||
break;
|
||||
}
|
||||
if io::stdin().read_line(&mut input).unwrap() == 0 { break; }
|
||||
let input = input.trim();
|
||||
if input.is_empty() { continue; }
|
||||
if input == "quit" || input == "exit" { break; }
|
||||
|
||||
let token_ids = tokenizer.encode(input);
|
||||
let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() };
|
||||
let mut cache = KVCache::new(
|
||||
config.num_layers(), config.num_heads(), config.head_dim(),
|
||||
Device::Cuda(0),
|
||||
config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0),
|
||||
);
|
||||
|
||||
// Prefill
|
||||
let logits = model.forward_with_cache(&token_ids, &mut cache);
|
||||
let mut next = sample_greedy(&logits);
|
||||
// Prefill + decode
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||
Model::Qwen3(m) => m.forward_with_cache(&token_ids, &mut cache),
|
||||
};
|
||||
let mut next = match &model {
|
||||
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
|
||||
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
|
||||
};
|
||||
|
||||
print!("{input}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
@@ -72,8 +87,14 @@ fn main() {
|
||||
|
||||
if tokenizer.eos_token_id() == Some(next) { break; }
|
||||
|
||||
let logits = model.forward_with_cache(&[next], &mut cache);
|
||||
next = sample_greedy(&logits);
|
||||
let logits = match &model {
|
||||
Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
Model::Qwen3(m) => m.forward_with_cache(&[next], &mut cache),
|
||||
};
|
||||
next = match &model {
|
||||
Model::GPT2(_) => xserv_model::gpt2::sample_greedy(&logits),
|
||||
Model::Qwen3(_) => xserv_model::qwen3::sample_greedy(&logits),
|
||||
};
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
@@ -30,61 +30,90 @@ struct GPT2Block {
|
||||
}
|
||||
|
||||
pub struct KVCache {
|
||||
// Per layer, per head: k[layer][head] has seq_len * head_dim floats
|
||||
k: Vec<Vec<Vec<f32>>>, // [num_layers][num_heads][seq_len * head_dim]
|
||||
v: Vec<Vec<Vec<f32>>>,
|
||||
// Per layer, per head: raw bytes (works for both f32 and bf16)
|
||||
k: Vec<Vec<Vec<u8>>>, // [num_layers][num_heads][seq_len * head_dim * elem_size]
|
||||
v: Vec<Vec<Vec<u8>>>,
|
||||
len: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
elem_size: usize,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, device: Device) -> Self {
|
||||
pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self {
|
||||
Self {
|
||||
k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(),
|
||||
len: 0,
|
||||
num_heads,
|
||||
head_dim,
|
||||
elem_size: dtype.size_bytes(),
|
||||
dtype,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seq_len(&self) -> usize { self.len }
|
||||
|
||||
/// Append new K/V data. k_new is in [1, H, new_tokens, D] layout (flat).
|
||||
fn append_kv(&mut self, layer: usize, k_new: &[f32], v_new: &[f32], new_tokens: usize) {
|
||||
/// Append from a CPU tensor with shape [1, H, new_tokens, D].
|
||||
pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) {
|
||||
let hd = self.head_dim;
|
||||
let es = self.elem_size;
|
||||
let k_bytes = k_cpu.storage().as_cpu_bytes();
|
||||
let v_bytes = v_cpu.storage().as_cpu_bytes();
|
||||
let chunk = new_tokens * hd * es;
|
||||
for h in 0..self.num_heads {
|
||||
let off = h * new_tokens * hd;
|
||||
self.k[layer][h].extend_from_slice(&k_new[off..off + new_tokens * hd]);
|
||||
self.v[layer][h].extend_from_slice(&v_new[off..off + new_tokens * hd]);
|
||||
let off = h * chunk;
|
||||
self.k[layer][h].extend_from_slice(&k_bytes[off..off + chunk]);
|
||||
self.v[layer][h].extend_from_slice(&v_bytes[off..off + chunk]);
|
||||
}
|
||||
if layer == 0 {
|
||||
self.len += new_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconstruct [1, H, seq_len, D] tensors from per-head cache.
|
||||
fn get_kv_tensors(&self, layer: usize) -> (Tensor, Tensor) {
|
||||
/// Reconstruct [1, H, seq_len, D] tensors.
|
||||
pub fn get_kv_tensors(&self, layer: usize) -> (Tensor, Tensor) {
|
||||
let sl = self.len;
|
||||
let hd = self.head_dim;
|
||||
let nh = self.num_heads;
|
||||
let mut k_data = vec![0.0f32; nh * sl * hd];
|
||||
let mut v_data = vec![0.0f32; nh * sl * hd];
|
||||
let es = self.elem_size;
|
||||
let head_bytes = sl * hd * es;
|
||||
let total = nh * head_bytes;
|
||||
let mut k_data = vec![0u8; total];
|
||||
let mut v_data = vec![0u8; total];
|
||||
for h in 0..nh {
|
||||
let off = h * sl * hd;
|
||||
k_data[off..off + sl * hd].copy_from_slice(&self.k[layer][h]);
|
||||
v_data[off..off + sl * hd].copy_from_slice(&self.v[layer][h]);
|
||||
let off = h * head_bytes;
|
||||
k_data[off..off + head_bytes].copy_from_slice(&self.k[layer][h]);
|
||||
v_data[off..off + head_bytes].copy_from_slice(&self.v[layer][h]);
|
||||
}
|
||||
let shape = &[1, nh, sl, hd];
|
||||
let k = Tensor::from_slice(&k_data, shape).to_device(self.device);
|
||||
let v = Tensor::from_slice(&v_data, shape).to_device(self.device);
|
||||
let k = tensor_from_raw_bytes(&k_data, shape, self.dtype).to_device(self.device);
|
||||
let v = tensor_from_raw_bytes(&v_data, shape, self.dtype).to_device(self.device);
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn tensor_from_raw_bytes(bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
let data: &[f32] = unsafe {
|
||||
std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4)
|
||||
};
|
||||
Tensor::from_slice(data, shape)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data: &[half::bf16] = unsafe {
|
||||
std::slice::from_raw_parts(bytes.as_ptr() as *const half::bf16, bytes.len() / 2)
|
||||
};
|
||||
Tensor::from_slice(data, shape)
|
||||
}
|
||||
_ => panic!("unsupported dtype for KV cache"),
|
||||
}
|
||||
}
|
||||
|
||||
impl GPT2 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
@@ -181,11 +210,10 @@ impl GPT2 {
|
||||
let qkv = linear(&normed, &layer.attn_qkv_w, Some(&layer.attn_qkv_b));
|
||||
let (q, k_new, v_new) = split_qkv(&qkv, num_heads, head_dim, new_tokens);
|
||||
|
||||
// KV cache: append new K/V, use full cached K/V for attention
|
||||
let (k_full, v_full) = if let Some((cache, layer_idx)) = cache {
|
||||
let k_cpu = k_new.to_device(Device::Cpu);
|
||||
let v_cpu = v_new.to_device(Device::Cpu);
|
||||
cache.append_kv(layer_idx, k_cpu.as_slice::<f32>(), v_cpu.as_slice::<f32>(), new_tokens);
|
||||
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
|
||||
cache.get_kv_tensors(layer_idx)
|
||||
} else {
|
||||
(k_new, v_new)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
pub mod config;
|
||||
pub mod gpt2;
|
||||
pub mod loader;
|
||||
pub mod qwen3;
|
||||
|
||||
pub use config::ModelConfig;
|
||||
pub use gpt2::{GPT2, KVCache};
|
||||
pub use qwen3::Qwen3;
|
||||
|
||||
278
crates/xserv-model/src/qwen3.rs
Normal file
278
crates/xserv-model/src/qwen3.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use std::collections::HashMap;
|
||||
use half::bf16;
|
||||
use xserv_kernels::*;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use crate::gpt2::KVCache;
|
||||
|
||||
pub struct Qwen3 {
|
||||
pub config: ModelConfig,
|
||||
embed_tokens: Tensor,
|
||||
layers: Vec<Qwen3Block>,
|
||||
norm: Tensor,
|
||||
lm_head: Tensor,
|
||||
rope_cache: RopeCache,
|
||||
}
|
||||
|
||||
struct Qwen3Block {
|
||||
input_norm: Tensor, // [hidden]
|
||||
q_proj_w: Tensor, // [num_heads*head_dim, hidden]
|
||||
k_proj_w: Tensor, // [num_kv_heads*head_dim, hidden]
|
||||
v_proj_w: Tensor,
|
||||
o_proj_w: Tensor, // [hidden, num_heads*head_dim]
|
||||
q_norm: Tensor, // [head_dim] — per-head QK norm
|
||||
k_norm: Tensor, // [head_dim]
|
||||
post_norm: Tensor, // [hidden]
|
||||
gate_proj_w: Tensor, // [intermediate, hidden]
|
||||
up_proj_w: Tensor,
|
||||
down_proj_w: Tensor, // [hidden, intermediate]
|
||||
}
|
||||
|
||||
impl Qwen3 {
|
||||
pub fn from_weights(config: ModelConfig, mut w: HashMap<String, Tensor>) -> Self {
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
|
||||
let embed_tokens = take(&mut w, "model.embed_tokens.weight");
|
||||
let norm = take(&mut w, "model.norm.weight");
|
||||
let lm_head = take(&mut w, "lm_head.weight");
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len().min(8192), // limit for memory
|
||||
config.head_dim(),
|
||||
config.rope_theta.unwrap_or(1_000_000.0) as f32,
|
||||
);
|
||||
|
||||
let num_layers = config.num_layers();
|
||||
let mut layers = Vec::with_capacity(num_layers);
|
||||
for i in 0..num_layers {
|
||||
let p = format!("model.layers.{i}");
|
||||
layers.push(Qwen3Block {
|
||||
input_norm: take(&mut w, &format!("{p}.input_layernorm.weight")),
|
||||
q_proj_w: take(&mut w, &format!("{p}.self_attn.q_proj.weight")),
|
||||
k_proj_w: take(&mut w, &format!("{p}.self_attn.k_proj.weight")),
|
||||
v_proj_w: take(&mut w, &format!("{p}.self_attn.v_proj.weight")),
|
||||
o_proj_w: take(&mut w, &format!("{p}.self_attn.o_proj.weight")),
|
||||
q_norm: take(&mut w, &format!("{p}.self_attn.q_norm.weight")),
|
||||
k_norm: take(&mut w, &format!("{p}.self_attn.k_norm.weight")),
|
||||
post_norm: take(&mut w, &format!("{p}.post_attention_layernorm.weight")),
|
||||
gate_proj_w: take(&mut w, &format!("{p}.mlp.gate_proj.weight")),
|
||||
up_proj_w: take(&mut w, &format!("{p}.mlp.up_proj.weight")),
|
||||
down_proj_w: take(&mut w, &format!("{p}.mlp.down_proj.weight")),
|
||||
});
|
||||
}
|
||||
|
||||
Self { config, embed_tokens, layers, norm, lm_head, rope_cache }
|
||||
}
|
||||
|
||||
pub fn forward_with_cache(&self, token_ids: &[u32], cache: &mut KVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
let pos_offset = cache.seq_len();
|
||||
let hidden = self.config.hidden();
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
// Q/K/V projections (no bias, weight is [out, in])
|
||||
let q = linear_t(&normed, &layer.q_proj_w);
|
||||
let k = linear_t(&normed, &layer.k_proj_w);
|
||||
let v = linear_t(&normed, &layer.v_proj_w);
|
||||
|
||||
// Reshape to [1, heads, seq, head_dim]
|
||||
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
|
||||
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = reshape_heads(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// QK normalization (per-head RMSNorm)
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
// RoPE — kernel expects [S, H, D], our tensors are [1, H, S, D]
|
||||
// Transpose to [1, S, H, D] → reshape to [S, H, D] for RoPE
|
||||
let q = transpose_for_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_for_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
// Transpose back to [1, H, S, D]
|
||||
let q = transpose_from_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_from_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// KV cache
|
||||
let k_cpu = k.to_device(Device::Cpu);
|
||||
let v_cpu = v.to_device(Device::Cpu);
|
||||
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
|
||||
let (k_full, v_full) = cache.get_kv_tensors(layer_idx);
|
||||
|
||||
// GQA: repeat K/V
|
||||
let n_rep = num_heads / num_kv_heads;
|
||||
let k_full = repeat_kv(&k_full, n_rep);
|
||||
let v_full = repeat_kv(&v_full, n_rep);
|
||||
|
||||
// Attention
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
let attn_merged = merge_heads_any(&attn_out, new_tokens, hidden);
|
||||
let attn_proj = linear_t(&attn_merged, &layer.o_proj_w);
|
||||
x = add_any(&residual, &attn_proj);
|
||||
|
||||
// SwiGLU FFN
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.post_norm, eps);
|
||||
let gate = linear_t(&normed, &layer.gate_proj_w);
|
||||
let up = linear_t(&normed, &layer.up_proj_w);
|
||||
let gate_activated = silu(&gate);
|
||||
let hidden_states = mul_any(&gate_activated, &up);
|
||||
let down = linear_t(&hidden_states, &layer.down_proj_w);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
linear_t(&x, &self.lm_head)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
fn linear_t(x: &Tensor, weight: &Tensor) -> Tensor {
|
||||
let w_t = weight.transpose(0, 1).contiguous();
|
||||
matmul(x, &w_t, GemmBackend::CuBlas)
|
||||
}
|
||||
|
||||
fn reshape_heads(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let hidden = num_heads * head_dim;
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let si = s * hidden + h * head_dim;
|
||||
let di = (h * seq_len + s) * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[1, num_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
fn merge_heads_any(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; seq_len * hidden];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let si = (h * seq_len + s) * head_dim;
|
||||
let di = s * hidden + h * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, hidden]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// Per-head RMSNorm: apply RMSNorm to each [head_dim] slice independently.
|
||||
/// x: [1, H, S, D], norm_weight: [D]
|
||||
fn head_rmsnorm(x: &Tensor, norm_weight: &Tensor, eps: f32) -> Tensor {
|
||||
let num_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
// Reshape to [H*S, D], apply rmsnorm, reshape back
|
||||
let total_rows = num_heads * seq_len;
|
||||
let flat = x.reshape(&[total_rows, head_dim]);
|
||||
let normed = rmsnorm(&flat, norm_weight, eps);
|
||||
normed.reshape(&[1, num_heads, seq_len, head_dim])
|
||||
}
|
||||
|
||||
/// [1, H, S, D] → [S, H, D] for RoPE kernel
|
||||
fn transpose_for_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; seq_len * num_heads * head_dim];
|
||||
for h in 0..num_heads {
|
||||
for s in 0..seq_len {
|
||||
let si = (h * seq_len + s) * head_dim;
|
||||
let di = (s * num_heads + h) * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[seq_len, num_heads, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// [S, H, D] → [1, H, S, D] after RoPE
|
||||
fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let mut out = vec![bf16::ZERO; num_heads * seq_len * head_dim];
|
||||
for s in 0..seq_len {
|
||||
for h in 0..num_heads {
|
||||
let si = (s * num_heads + h) * head_dim;
|
||||
let di = (h * seq_len + s) * head_dim;
|
||||
out[di..di + head_dim].copy_from_slice(&src[si..si + head_dim]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[1, num_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
if n_rep == 1 { return x.clone(); }
|
||||
let kv_heads = x.shape()[1];
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
let x_cpu = x.to_device(Device::Cpu);
|
||||
let src = x_cpu.as_slice::<bf16>();
|
||||
let new_heads = kv_heads * n_rep;
|
||||
let mut out = vec![bf16::ZERO; new_heads * seq_len * head_dim];
|
||||
let chunk = seq_len * head_dim;
|
||||
for kv_h in 0..kv_heads {
|
||||
for r in 0..n_rep {
|
||||
let dst_h = kv_h * n_rep + r;
|
||||
out[dst_h * chunk..(dst_h + 1) * chunk]
|
||||
.copy_from_slice(&src[kv_h * chunk..(kv_h + 1) * chunk]);
|
||||
}
|
||||
}
|
||||
Tensor::from_slice(&out, &[1, new_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
let a_cpu = a.to_device(Device::Cpu);
|
||||
let b_cpu = b.to_device(Device::Cpu);
|
||||
let ad = a_cpu.as_slice::<bf16>();
|
||||
let bd = b_cpu.as_slice::<bf16>();
|
||||
let r: Vec<bf16> = ad.iter().zip(bd)
|
||||
.map(|(x, y)| bf16::from_f32(x.to_f32() + y.to_f32()))
|
||||
.collect();
|
||||
Tensor::from_slice(&r, a.shape()).to_device(a.device())
|
||||
}
|
||||
|
||||
fn mul_any(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.shape(), b.shape());
|
||||
let a_cpu = a.to_device(Device::Cpu);
|
||||
let b_cpu = b.to_device(Device::Cpu);
|
||||
let ad = a_cpu.as_slice::<bf16>();
|
||||
let bd = b_cpu.as_slice::<bf16>();
|
||||
let r: Vec<bf16> = ad.iter().zip(bd)
|
||||
.map(|(x, y)| bf16::from_f32(x.to_f32() * y.to_f32()))
|
||||
.collect();
|
||||
Tensor::from_slice(&r, a.shape()).to_device(a.device())
|
||||
}
|
||||
|
||||
pub fn sample_greedy(logits: &Tensor) -> u32 {
|
||||
assert_eq!(logits.ndim(), 2);
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
let vocab_size = logits.shape()[1];
|
||||
let seq_len = logits.shape()[0];
|
||||
let data = logits_cpu.as_slice::<bf16>();
|
||||
let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size];
|
||||
last.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(i, _)| i as u32).unwrap()
|
||||
}
|
||||
@@ -8,9 +8,11 @@ pub struct Tokenizer {
|
||||
decoder: Vec<Vec<u8>>,
|
||||
merge_ranks: HashMap<(u32, u32), usize>,
|
||||
special_tokens: HashMap<String, u32>,
|
||||
#[allow(dead_code)]
|
||||
special_token_ids: HashMap<u32, String>,
|
||||
pre_tokenize_re: Regex,
|
||||
eos_token_id: Option<u32>,
|
||||
byte_fallback: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -23,7 +25,16 @@ struct TokenizerJson {
|
||||
#[derive(Deserialize)]
|
||||
struct ModelSection {
|
||||
vocab: HashMap<String, u32>,
|
||||
merges: Vec<String>,
|
||||
merges: Vec<MergeEntry>,
|
||||
#[serde(default)]
|
||||
byte_fallback: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum MergeEntry {
|
||||
Str(String),
|
||||
Pair(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -40,7 +51,10 @@ impl Tokenizer {
|
||||
let tj: TokenizerJson = serde_json::from_str(&data)
|
||||
.unwrap_or_else(|e| panic!("failed to parse tokenizer.json: {e}"));
|
||||
|
||||
let byte_fallback = tj.model.byte_fallback;
|
||||
|
||||
// Build encoder: token bytes → ID
|
||||
// All HF tokenizers use GPT-2 byte-to-unicode mapping for vocab keys.
|
||||
let mut encoder = HashMap::new();
|
||||
for (token_str, &id) in &tj.model.vocab {
|
||||
let bytes = token_str_to_bytes(token_str);
|
||||
@@ -56,13 +70,23 @@ impl Tokenizer {
|
||||
decoder[id as usize] = token_str_to_bytes(token_str);
|
||||
}
|
||||
|
||||
// Parse merges
|
||||
// Parse merges (supports both "a b" string format and ["a", "b"] array format)
|
||||
let byte_fallback = tj.model.byte_fallback;
|
||||
let mut merge_ranks = HashMap::new();
|
||||
for (rank, merge_line) in tj.model.merges.iter().enumerate() {
|
||||
let parts: Vec<&str> = merge_line.splitn(2, ' ').collect();
|
||||
for (rank, entry) in tj.model.merges.iter().enumerate() {
|
||||
let (a_str, b_str) = match entry {
|
||||
MergeEntry::Str(s) => {
|
||||
let parts: Vec<&str> = s.splitn(2, ' ').collect();
|
||||
if parts.len() != 2 { continue; }
|
||||
let a_bytes = token_str_to_bytes(parts[0]);
|
||||
let b_bytes = token_str_to_bytes(parts[1]);
|
||||
(parts[0].to_string(), parts[1].to_string())
|
||||
}
|
||||
MergeEntry::Pair(v) => {
|
||||
if v.len() != 2 { continue; }
|
||||
(v[0].clone(), v[1].clone())
|
||||
}
|
||||
};
|
||||
let a_bytes = token_str_to_bytes(&a_str);
|
||||
let b_bytes = token_str_to_bytes(&b_str);
|
||||
if let (Some(&a_id), Some(&b_id)) = (encoder.get(&a_bytes), encoder.get(&b_bytes)) {
|
||||
merge_ranks.insert((a_id, b_id), rank);
|
||||
}
|
||||
@@ -84,13 +108,14 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
// GPT-2 pre-tokenization regex.
|
||||
// The original uses (?!\S) lookahead which Rust regex doesn't support.
|
||||
// Simplified: collapse trailing whitespace into one match. Functionally equivalent
|
||||
// for BPE since each whitespace chunk gets encoded independently anyway.
|
||||
let pre_tokenize_re = Regex::new(
|
||||
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"
|
||||
).unwrap();
|
||||
// Pre-tokenization regex
|
||||
let pre_tokenize_re = if byte_fallback {
|
||||
// Qwen-style: split on whitespace boundaries, keep Unicode words/numbers
|
||||
Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap()
|
||||
} else {
|
||||
// GPT-2 style
|
||||
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap()
|
||||
};
|
||||
|
||||
Self {
|
||||
encoder,
|
||||
@@ -100,6 +125,7 @@ impl Tokenizer {
|
||||
special_token_ids,
|
||||
pre_tokenize_re,
|
||||
eos_token_id,
|
||||
byte_fallback,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,10 +163,16 @@ impl Tokenizer {
|
||||
fn encode_ordinary(&self, text: &str, out: &mut Vec<u32>) {
|
||||
for mat in self.pre_tokenize_re.find_iter(text) {
|
||||
let word = mat.as_str();
|
||||
// Try to encode the whole word first
|
||||
if let Some(&id) = self.encoder.get(word.as_bytes()) {
|
||||
out.push(id);
|
||||
continue;
|
||||
}
|
||||
// Fall back to per-byte encoding
|
||||
let word_bytes: Vec<u8> = word.bytes().collect();
|
||||
let mut token_ids: Vec<u32> = word_bytes.iter().map(|&b| {
|
||||
*self.encoder.get(&vec![b]).unwrap_or_else(|| {
|
||||
panic!("byte {b} not in vocab")
|
||||
panic!("byte {b} (0x{b:02X}) not in vocab")
|
||||
})
|
||||
}).collect();
|
||||
|
||||
@@ -204,48 +236,32 @@ fn token_str_to_bytes(s: &str) -> Vec<u8> {
|
||||
s.chars().map(|c| unicode_to_byte(c)).collect()
|
||||
}
|
||||
|
||||
/// Convert a Unicode char back to the byte it represents in GPT-2 encoding.
|
||||
fn unicode_to_byte(c: char) -> u8 {
|
||||
let u = c as u32;
|
||||
// GPT-2 byte encoder: maps bytes 0-255 to specific Unicode code points.
|
||||
// Printable ASCII bytes map to themselves. Others are shifted to 256+.
|
||||
match u {
|
||||
0x21..=0x7E => u as u8, // '!' to '~'
|
||||
0xA1..=0xAC => u as u8, // '¡' to '¬'
|
||||
0xAE..=0xFF => u as u8, // '®' to 'ÿ'
|
||||
// Shifted bytes: 0x100 + original_byte for bytes not in the above ranges
|
||||
0x100..=0x1FF => (u - 0x100) as u8 + {
|
||||
// The shift mapping: byte values 0..=32, 127..=160, 173
|
||||
// are shifted to 256..=288, 289+, etc.
|
||||
0
|
||||
},
|
||||
// Build the inverse map on first use
|
||||
use std::sync::OnceLock;
|
||||
static INV_MAP: OnceLock<HashMap<u32, u8>> = OnceLock::new();
|
||||
|
||||
let map = INV_MAP.get_or_init(|| {
|
||||
let mut m = HashMap::new();
|
||||
// Build GPT-2's bytes_to_unicode forward map, then invert
|
||||
let mut n = 0u32;
|
||||
for b in 0..=255u16 {
|
||||
let byte = b as u8;
|
||||
let unicode = match byte {
|
||||
0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF => byte as u32,
|
||||
_ => {
|
||||
// Fallback: for the GPT-2 byte encoder, specific mappings
|
||||
byte_from_unicode_gpt2(c)
|
||||
let u = 256 + n;
|
||||
n += 1;
|
||||
u
|
||||
}
|
||||
};
|
||||
m.insert(unicode, byte);
|
||||
}
|
||||
}
|
||||
|
||||
fn byte_from_unicode_gpt2(c: char) -> u8 {
|
||||
// Build the inverse of GPT-2's bytes_to_unicode mapping.
|
||||
// The mapping assigns printable chars to themselves and shifts unprintable bytes.
|
||||
let u = c as u32;
|
||||
// Direct ASCII printable + Latin-1 supplement printable ranges map identity
|
||||
if (0x21..=0x7E).contains(&u) { return u as u8; }
|
||||
if (0xA1..=0xAC).contains(&u) { return u as u8; }
|
||||
if (0xAE..=0xFF).contains(&u) { return u as u8; }
|
||||
|
||||
// Shifted range: the remaining 68 bytes (0-32, 127-160, 173) get mapped to 256..=323
|
||||
static SHIFTED_BYTES: &[u8] = &[
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
|
||||
24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
|
||||
137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
|
||||
154, 155, 156, 157, 158, 159, 160, 173,
|
||||
];
|
||||
let shifted_start = 256u32;
|
||||
if u >= shifted_start && u < shifted_start + SHIFTED_BYTES.len() as u32 {
|
||||
return SHIFTED_BYTES[(u - shifted_start) as usize];
|
||||
}
|
||||
|
||||
// Shouldn't reach here for valid GPT-2 tokenizer
|
||||
c as u8
|
||||
m
|
||||
});
|
||||
|
||||
*map.get(&(c as u32)).unwrap_or_else(|| {
|
||||
panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)
|
||||
})
|
||||
}
|
||||
|
||||
109
docs/10-qwen3.md
Normal file
109
docs/10-qwen3.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Phase 10: Qwen3-7B Support — Design Document (Milestone ②)
|
||||
|
||||
## Goal
|
||||
|
||||
扩展模型定义支持 Qwen3-7B 架构,验证输出正确性。与 GPT-2 的关键差异:RMSNorm、RoPE、GQA、SwiGLU、不共享 embedding。
|
||||
|
||||
## 架构差异 (GPT-2 → Qwen3)
|
||||
|
||||
| 特性 | GPT-2 | Qwen3-7B |
|
||||
|------|-------|----------|
|
||||
| Norm | LayerNorm(gamma, beta) | RMSNorm(gamma only) |
|
||||
| Position | Learned absolute (wpe) | RoPE (no params) |
|
||||
| Attention | MHA (12 Q = 12 KV heads) | GQA (32 Q, 8 KV heads) |
|
||||
| QKV projection | Combined c_attn [H, 3H] | Separate q/k/v_proj [H, Hq/Hk/Hv] |
|
||||
| FFN | 2 Linear (fc, proj) + GELU | 3 Linear (gate, up, down) + SwiGLU |
|
||||
| Weight layout | [in, out] (Conv1D style) | [out, in] (standard Linear) |
|
||||
| Tied embeddings | Yes | No (separate lm_head) |
|
||||
| hidden_size | 768 | 3584 |
|
||||
| num_layers | 12 | 28 |
|
||||
| head_dim | 64 | 128 |
|
||||
|
||||
## Weight Names (HuggingFace)
|
||||
|
||||
```
|
||||
model.embed_tokens.weight [151936, 3584]
|
||||
model.layers.{i}.input_layernorm.weight [3584]
|
||||
model.layers.{i}.self_attn.q_proj.weight [3584, 3584] (32 heads × 112 dim? or 28 heads)
|
||||
model.layers.{i}.self_attn.q_proj.bias [3584]
|
||||
model.layers.{i}.self_attn.k_proj.weight [512, 3584] (4 KV heads × 128 dim)
|
||||
model.layers.{i}.self_attn.k_proj.bias [512]
|
||||
model.layers.{i}.self_attn.v_proj.weight [512, 3584]
|
||||
model.layers.{i}.self_attn.v_proj.bias [512]
|
||||
model.layers.{i}.self_attn.o_proj.weight [3584, 3584]
|
||||
model.layers.{i}.post_attention_layernorm.weight [3584]
|
||||
model.layers.{i}.mlp.gate_proj.weight [18944, 3584]
|
||||
model.layers.{i}.mlp.up_proj.weight [18944, 3584]
|
||||
model.layers.{i}.mlp.down_proj.weight [3584, 18944]
|
||||
model.norm.weight [3584]
|
||||
lm_head.weight [151936, 3584]
|
||||
```
|
||||
|
||||
**注意**: Qwen3 权重是 [out, in] layout,`x @ W^T` 而不是 `x @ W`。
|
||||
|
||||
## GQA (Grouped Query Attention)
|
||||
|
||||
```
|
||||
num_heads = 28, num_kv_heads = 4, head_dim = 128
|
||||
Q: [B, 28, S, 128]
|
||||
K: [B, 4, S, 128] ← 每个 KV head 服务 28/4 = 7 个 Q head
|
||||
V: [B, 4, S, 128]
|
||||
|
||||
attention 时需要 repeat K/V:
|
||||
K_expanded: [B, 28, S, 128] ← repeat_interleave(K, 7, dim=1)
|
||||
```
|
||||
|
||||
实现:在 CPU 侧 split_qkv 时直接做 repeat。
|
||||
|
||||
## SwiGLU FFN
|
||||
|
||||
```
|
||||
gate = gate_proj(x) # [S, 3584] @ [3584, 18944]^T → [S, 18944]
|
||||
up = up_proj(x) # [S, 3584] @ [3584, 18944]^T → [S, 18944]
|
||||
out = silu(gate) * up # element-wise
|
||||
out = down_proj(out) # [S, 18944] @ [18944, 3584]^T → [S, 3584]
|
||||
```
|
||||
|
||||
## 显存预算 (BF16, 单卡 5090)
|
||||
|
||||
```
|
||||
权重: 7B × 2B = ~14 GB (BF16)
|
||||
7B × 4B = ~28 GB (FP32) — 不够! 必须用 BF16
|
||||
KV cache (S=256, B=1): ~0.1 GB
|
||||
总计: ~14 GB (BF16), 单卡可运行
|
||||
```
|
||||
|
||||
**关键**: Qwen3-7B 必须用 BF16 才能在单张 5090 (32GB) 上运行。当前 GPT-2 用 FP32,需要支持 BF16 forward pass。
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
1. 下载 Qwen3-7B 模型 (BF16, ~14GB)
|
||||
2. 实现 Qwen3 模型结构 (qwen3.rs)
|
||||
3. 支持 BF16 forward pass (linear_transpose for [out, in] weights)
|
||||
4. 实现 GQA (K/V repeat in split)
|
||||
5. 集成 RoPE + RMSNorm + SwiGLU
|
||||
6. 验证输出
|
||||
|
||||
## Test Plan
|
||||
|
||||
- [x] 加载 Qwen3-8B BF16 权重 (399 tensors, ~15.5GB) 到单张 5090
|
||||
- [x] 英文: "The meaning of life is" → "to be happy"
|
||||
- [x] 中文: "请用中文回答:1+1等于几?" → "1加1"
|
||||
- [x] 61/61 单元测试无回归
|
||||
- [x] GPT-2 benchmark 性能无回归
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **Qwen3 实际是 8B,不是 7B**:modelscope 上的 `Qwen/Qwen3-8B` 有 36 层 × hidden 4096 × 32 heads,参数量约 8B。BF16 权重 ~15.5GB,单张 5090 (32GB) 可以运行。
|
||||
|
||||
2. **QK Normalization 是 Qwen3 的新特性**:每层有 `q_norm` 和 `k_norm` (shape [head_dim]),对 Q 和 K 做 per-head RMSNorm。这在 attention score 的数值稳定性上很重要——没有 QK norm 会导致 attention score 爆炸。
|
||||
|
||||
3. **attention_bias=false**:Qwen3 的 Q/K/V/O projection 没有 bias。这和 GPT-2 (有 bias) 不同。需要在模型代码中条件处理。
|
||||
|
||||
4. **Tokenizer 的 byte-to-unicode 映射 bug**:GPT-2 和 Qwen3 都使用同一套 byte-to-unicode 映射(printable ASCII identity,其余 68 bytes shifted to U+0100+)。初始实现中 `unicode_to_byte` 的 shifted 范围转换错误(直接 `u - 0x100` 而非查表),导致中文输入时 UTF-8 bytes 无法正确映射。修复:用 `OnceLock` 缓存反向映射表。
|
||||
|
||||
5. **Weight layout [out, in] vs [in, out]**:GPT-2 的 Conv1D 存为 [in, out],计算 `x @ W`;Qwen3 的 Linear 存为 [out, in],计算 `x @ W^T`。`linear_t` 函数通过 `weight.transpose(0,1).contiguous()` 处理。
|
||||
|
||||
6. **RoPE 的 tensor layout 不匹配**:RoPE kernel 期望 [S, H, D],但 attention 需要 [1, H, S, D]。需要在 RoPE 前后做 transpose。这引入了额外的 CPU round-trip(因为 transpose+contiguous 经过 CPU)。
|
||||
|
||||
7. **GQA repeat_kv 的实现**:每个 KV head 服务 `num_heads/num_kv_heads` 个 Q head。在 CPU 上做数据复制(repeat),简单但每步 decode 都要做。后续应在 attention kernel 中直接支持 GQA 索引,避免数据复制。
|
||||
Reference in New Issue
Block a user