model: fuse QKV/gate_up projections, batched decode ops
Weight fusion at load time: - q/k/v_proj → single qkv_proj_wt, GEMV once then narrow() to split - gate/up_proj → single gate_up_proj_wt, same pattern - Reduces GEMV calls from 7 to 4 per layer (36 layers → 108 fewer launches) Batched decode refactor (forward_decode_paged): - Per-head RMSNorm: reshape to [B*H, D], one rmsnorm call - Batched RoPE: one call for all sequences - Batched KV scatter: one reshape_and_cache kernel per layer - Eliminates the per-sequence loop entirely Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,18 +24,31 @@ pub struct Qwen3 {
|
||||
|
||||
struct Qwen3Block {
|
||||
input_norm: Tensor, // [hidden]
|
||||
q_proj_wt: Tensor, // TRANSPOSED: [hidden, num_heads*head_dim]
|
||||
k_proj_wt: Tensor, // TRANSPOSED: [hidden, num_kv_heads*head_dim]
|
||||
v_proj_wt: Tensor,
|
||||
qkv_proj_wt: Tensor, // FUSED: [hidden, (H+2*KV)*D] — Q|K|V columns
|
||||
q_dim: usize, // num_heads * head_dim (Q slice boundary)
|
||||
kv_dim: usize, // num_kv_heads * head_dim (K/V slice size)
|
||||
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
|
||||
q_norm: Tensor, // [head_dim]
|
||||
k_norm: Tensor, // [head_dim]
|
||||
post_norm: Tensor, // [hidden]
|
||||
gate_proj_wt: Tensor, // TRANSPOSED: [hidden, intermediate]
|
||||
up_proj_wt: Tensor,
|
||||
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
|
||||
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
|
||||
}
|
||||
|
||||
impl Qwen3Block {
|
||||
fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) }
|
||||
fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) }
|
||||
fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) }
|
||||
fn gate_proj_wt(&self) -> Tensor {
|
||||
let half = self.gate_up_proj_wt.shape()[1] / 2;
|
||||
self.gate_up_proj_wt.narrow(1, 0, half)
|
||||
}
|
||||
fn up_proj_wt(&self) -> Tensor {
|
||||
let half = self.gate_up_proj_wt.shape()[1] / 2;
|
||||
self.gate_up_proj_wt.narrow(1, half, half)
|
||||
}
|
||||
}
|
||||
|
||||
impl Qwen3 {
|
||||
/// Single-GPU load (weights already on the target GPU). Equivalent to
|
||||
/// `from_weights_tp(.., rank=0, world=1, device=0, tp=None)`.
|
||||
@@ -88,17 +101,28 @@ impl Qwen3 {
|
||||
}
|
||||
for i in 0..num_layers {
|
||||
let p = format!("model.layers.{i}");
|
||||
let q_proj_wt = col(take(&mut w, &format!("{p}.self_attn.q_proj.weight")));
|
||||
let k_proj_wt = col(take(&mut w, &format!("{p}.self_attn.k_proj.weight")));
|
||||
let v_proj_wt = col(take(&mut w, &format!("{p}.self_attn.v_proj.weight")));
|
||||
let q_dim = q_proj_wt.shape()[1];
|
||||
let kv_dim = k_proj_wt.shape()[1];
|
||||
let qkv_proj_wt = cat_cols(&[&q_proj_wt, &k_proj_wt, &v_proj_wt]);
|
||||
drop((q_proj_wt, k_proj_wt, v_proj_wt));
|
||||
let gate_proj_wt = col(take(&mut w, &format!("{p}.mlp.gate_proj.weight")));
|
||||
let up_proj_wt = col(take(&mut w, &format!("{p}.mlp.up_proj.weight")));
|
||||
let gate_up_proj_wt = cat_cols(&[&gate_proj_wt, &up_proj_wt]);
|
||||
drop((gate_proj_wt, up_proj_wt));
|
||||
xserv_cuda::allocator::cached_trim();
|
||||
layers.push(Qwen3Block {
|
||||
input_norm: repl(take(&mut w, &format!("{p}.input_layernorm.weight"))),
|
||||
q_proj_wt: col(take(&mut w, &format!("{p}.self_attn.q_proj.weight"))),
|
||||
k_proj_wt: col(take(&mut w, &format!("{p}.self_attn.k_proj.weight"))),
|
||||
v_proj_wt: col(take(&mut w, &format!("{p}.self_attn.v_proj.weight"))),
|
||||
qkv_proj_wt,
|
||||
q_dim,
|
||||
kv_dim,
|
||||
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
|
||||
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
|
||||
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
|
||||
gate_proj_wt: col(take(&mut w, &format!("{p}.mlp.gate_proj.weight"))),
|
||||
up_proj_wt: col(take(&mut w, &format!("{p}.mlp.up_proj.weight"))),
|
||||
gate_up_proj_wt,
|
||||
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
}
|
||||
@@ -144,52 +168,45 @@ impl Qwen3 {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
// Q/K/V projections (pre-transposed weights, x @ wt)
|
||||
let q = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
|
||||
// Reshape to [1, heads, seq, head_dim]
|
||||
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
|
||||
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = reshape_heads(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// QK normalization (per-head RMSNorm)
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
// RoPE — kernel expects [S, H, D], our tensors are [1, H, S, D]
|
||||
// Transpose to [1, S, H, D] → reshape to [S, H, D] for RoPE
|
||||
let q = transpose_for_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_for_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
// Transpose back to [1, H, S, D]
|
||||
let q = transpose_from_rope(&q, new_tokens, num_heads, head_dim);
|
||||
let k = transpose_from_rope(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// KV cache
|
||||
let k_cpu = k.to_device(Device::Cpu);
|
||||
let v_cpu = v.to_device(Device::Cpu);
|
||||
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
|
||||
let (k_full, v_full) = cache.get_kv_tensors(layer_idx);
|
||||
|
||||
// GQA: repeat K/V
|
||||
let n_rep = num_heads / num_kv_heads;
|
||||
let k_full = repeat_kv(&k_full, n_rep);
|
||||
let v_full = repeat_kv(&v_full, n_rep);
|
||||
|
||||
// Attention
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
let attn_merged = merge_heads_any(&attn_out, new_tokens, hidden);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
x = add_any(&residual, &attn_proj);
|
||||
|
||||
// SwiGLU FFN
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.post_norm, eps);
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let gate_activated = silu(&gate);
|
||||
let hidden_states = mul_any(&gate_activated, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
@@ -232,10 +249,10 @@ impl Qwen3 {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps); // [B, hidden]
|
||||
|
||||
// Batched projections: [B, hidden] × [hidden, X] = [B, X]
|
||||
let q_all = matmul_2d(&normed, &layer.q_proj_wt); // [B, num_heads*head_dim]
|
||||
let k_all = matmul_2d(&normed, &layer.k_proj_wt); // [B, num_kv_heads*head_dim]
|
||||
let v_all = matmul_2d(&normed, &layer.v_proj_wt); // [B, num_kv_heads*head_dim]
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
|
||||
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
|
||||
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
@@ -290,9 +307,10 @@ impl Qwen3 {
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
// Batched FFN: all projections on [B, hidden]
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
@@ -312,6 +330,13 @@ impl Qwen3 {
|
||||
/// tokens: [B] one token per sequence
|
||||
/// positions: [B] current logical position (BEFORE this step) per sequence
|
||||
/// seq_slots: [B] slot ids in `paged_cache`
|
||||
///
|
||||
/// Layout note: for S=1 decode the memory of `[B, H, 1, D]`,
|
||||
/// `[B, H, D]`, and `[B, H*D]` is the same — only shape/strides differ.
|
||||
/// We exploit this to drop every per-sequence kernel: head_rmsnorm and
|
||||
/// RoPE both natively accept the `[B*H, D]` / `[B, H, D]` layouts that
|
||||
/// fall out of the projection matmuls, and the new-token KV scatter is
|
||||
/// one batched `reshape_and_cache` kernel.
|
||||
pub fn forward_decode_paged(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
@@ -343,6 +368,10 @@ impl Qwen3 {
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
// RoPE expects `[num_tokens, H, D]` with `num_tokens` positions —
|
||||
// matches our `[B, H, D]` exactly, so we upload once here.
|
||||
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
|
||||
|
||||
// Batched embedding: [B, hidden]
|
||||
let mut x = embedding(&self.embed_tokens, tokens);
|
||||
|
||||
@@ -350,62 +379,41 @@ impl Qwen3 {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q_all = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k_all = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v_all = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
// Fused QKV projection: one GEMV instead of three.
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D]
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
for b in 0..batch {
|
||||
let q_row = row_view(&q_all, b);
|
||||
let k_row = row_view(&k_all, b);
|
||||
let v_row = row_view(&v_all, b);
|
||||
// Per-head RMSNorm on contiguous copies (narrow views are strided).
|
||||
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
|
||||
let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
|
||||
let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
|
||||
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
|
||||
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]);
|
||||
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
let pos = [positions[b] as u32];
|
||||
rope_inplace(&q, &self.rope_cache, &pos);
|
||||
rope_inplace(&k, &self.rope_cache, &pos);
|
||||
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
paged_cache.append_tokens(seq_slots[b], layer_idx, &k, &v, 1, positions[b]);
|
||||
|
||||
let q_flat = xserv_kernels::merge_heads_gpu(&q, 1, num_heads, head_dim);
|
||||
q_rows.push(q_flat);
|
||||
}
|
||||
|
||||
let q_batched_2d = concat_rows(&q_rows);
|
||||
// q_batched_2d: [B, num_heads * head_dim]. Memory is [B, H, D] —
|
||||
// a plain reshape view to [B, H, 1, D] is what the paged kernel expects.
|
||||
let q_4d = q_batched_2d.reshape(&[batch, num_heads, 1, head_dim]);
|
||||
// Single batched scatter for all sequences in the batch.
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch);
|
||||
|
||||
// Paged attention reads Q as [B, H, 1, D] — a contiguous view
|
||||
// of [B, H, D].
|
||||
let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_4d,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
batch,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
|
||||
batch, num_heads, num_kv_heads, head_dim, max_blocks,
|
||||
);
|
||||
|
||||
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
|
||||
// Plain reshape is a view; merge_heads_gpu would incorrectly swap B<->H.
|
||||
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
|
||||
@@ -413,8 +421,11 @@ impl Qwen3 {
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
// Fused gate+up projection: one GEMV instead of two.
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); // [B, 2*ffn]
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down); // TP: sum partial MLP outputs
|
||||
@@ -459,9 +470,10 @@ impl Qwen3 {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
@@ -477,25 +489,25 @@ impl Qwen3 {
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// Write into paged pool at the original (pre-advance) position.
|
||||
paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
|
||||
// Gather contiguous K/V for the full sequence (seq_len already includes new_tokens).
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down); // TP: sum partial MLP outputs
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
@@ -520,45 +532,39 @@ impl Qwen3 {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
|
||||
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
|
||||
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
|
||||
|
||||
// GPU reshape: [S, H*D] → [1, H, S, D]
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// QK norm (reshape to [H*S, D], rmsnorm, reshape back — stays on GPU)
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
// GPU transpose for RoPE: [1, H, S, D] → [S, H, D]
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
// GPU transpose back: [S, H, D] → [1, H, S, D]
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
// GPU KV cache
|
||||
cache.append(layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
|
||||
|
||||
// Flash Attention with native GQA (no repeat_kv needed)
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
// GPU merge_heads: [1, H, S, D] → [S, H*D]
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
// Fused add + rmsnorm: (normed, x) where x = residual + attn_proj
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
// Fused SiLU×Mul
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
@@ -573,15 +579,15 @@ impl Qwen3 {
|
||||
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
|
||||
self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
|
||||
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
|
||||
q_proj_wt: l.q_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
k_proj_wt: l.k_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
v_proj_wt: l.v_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
|
||||
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
|
||||
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
|
||||
gate_proj_wt: l.gate_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
up_proj_wt: l.up_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
|
||||
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
|
||||
}).collect()
|
||||
}
|
||||
@@ -790,6 +796,42 @@ fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Concatenate 2D GPU tensors along dim=1 (columns). All must share dim 0.
|
||||
fn cat_cols(tensors: &[&Tensor]) -> Tensor {
|
||||
assert!(!tensors.is_empty());
|
||||
let rows = tensors[0].shape()[0];
|
||||
let dtype = tensors[0].dtype();
|
||||
let device = tensors[0].device();
|
||||
let elem = dtype.size_bytes();
|
||||
let total_cols: usize = tensors.iter().map(|t| {
|
||||
assert_eq!(t.ndim(), 2);
|
||||
assert_eq!(t.shape()[0], rows);
|
||||
assert!(t.is_contiguous());
|
||||
t.shape()[1]
|
||||
}).sum();
|
||||
let out = Tensor::empty(&[rows, total_cols], dtype, device);
|
||||
let dst_base = out.data_ptr() as *mut u8;
|
||||
for r in 0..rows {
|
||||
let mut col_off = 0usize;
|
||||
for t in tensors {
|
||||
let cols = t.shape()[1];
|
||||
let src = unsafe { t.data_ptr().add(r * cols * elem) };
|
||||
let dst = unsafe { dst_base.add((r * total_cols + col_off) * elem) };
|
||||
let count = cols * elem;
|
||||
unsafe {
|
||||
xserv_cuda::ffi::cudaMemcpy(
|
||||
dst as *mut u8,
|
||||
src as *const u8,
|
||||
count,
|
||||
2, // cudaMemcpyDeviceToDevice
|
||||
);
|
||||
}
|
||||
col_off += cols;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
xserv_kernels::add(a, b)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user