model: fuse QKV/gate_up projections, batched decode ops

Weight fusion at load time:
- q/k/v_proj → single qkv_proj_wt, GEMV once then narrow() to split
- gate/up_proj → single gate_up_proj_wt, same pattern
- Reduces GEMV calls from 7 to 4 per layer (36 layers → 108 fewer launches)

Batched decode refactor (forward_decode_paged):
- Per-head RMSNorm: reshape to [B*H, D], one rmsnorm call
- Batched RoPE: one call for all sequences
- Batched KV scatter: one reshape_and_cache kernel per layer
- Eliminates the per-sequence loop entirely

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-30 12:50:39 +08:00
parent cc4bd4cfe5
commit c679f618fd

View File

@@ -24,18 +24,31 @@ pub struct Qwen3 {
struct Qwen3Block {
input_norm: Tensor, // [hidden]
q_proj_wt: Tensor, // TRANSPOSED: [hidden, num_heads*head_dim]
k_proj_wt: Tensor, // TRANSPOSED: [hidden, num_kv_heads*head_dim]
v_proj_wt: Tensor,
qkv_proj_wt: Tensor, // FUSED: [hidden, (H+2*KV)*D] — Q|K|V columns
q_dim: usize, // num_heads * head_dim (Q slice boundary)
kv_dim: usize, // num_kv_heads * head_dim (K/V slice size)
o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden]
q_norm: Tensor, // [head_dim]
k_norm: Tensor, // [head_dim]
post_norm: Tensor, // [hidden]
gate_proj_wt: Tensor, // TRANSPOSED: [hidden, intermediate]
up_proj_wt: Tensor,
gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate]
down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden]
}
impl Qwen3Block {
fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) }
fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) }
fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) }
fn gate_proj_wt(&self) -> Tensor {
let half = self.gate_up_proj_wt.shape()[1] / 2;
self.gate_up_proj_wt.narrow(1, 0, half)
}
fn up_proj_wt(&self) -> Tensor {
let half = self.gate_up_proj_wt.shape()[1] / 2;
self.gate_up_proj_wt.narrow(1, half, half)
}
}
impl Qwen3 {
/// Single-GPU load (weights already on the target GPU). Equivalent to
/// `from_weights_tp(.., rank=0, world=1, device=0, tp=None)`.
@@ -88,17 +101,28 @@ impl Qwen3 {
}
for i in 0..num_layers {
let p = format!("model.layers.{i}");
let q_proj_wt = col(take(&mut w, &format!("{p}.self_attn.q_proj.weight")));
let k_proj_wt = col(take(&mut w, &format!("{p}.self_attn.k_proj.weight")));
let v_proj_wt = col(take(&mut w, &format!("{p}.self_attn.v_proj.weight")));
let q_dim = q_proj_wt.shape()[1];
let kv_dim = k_proj_wt.shape()[1];
let qkv_proj_wt = cat_cols(&[&q_proj_wt, &k_proj_wt, &v_proj_wt]);
drop((q_proj_wt, k_proj_wt, v_proj_wt));
let gate_proj_wt = col(take(&mut w, &format!("{p}.mlp.gate_proj.weight")));
let up_proj_wt = col(take(&mut w, &format!("{p}.mlp.up_proj.weight")));
let gate_up_proj_wt = cat_cols(&[&gate_proj_wt, &up_proj_wt]);
drop((gate_proj_wt, up_proj_wt));
xserv_cuda::allocator::cached_trim();
layers.push(Qwen3Block {
input_norm: repl(take(&mut w, &format!("{p}.input_layernorm.weight"))),
q_proj_wt: col(take(&mut w, &format!("{p}.self_attn.q_proj.weight"))),
k_proj_wt: col(take(&mut w, &format!("{p}.self_attn.k_proj.weight"))),
v_proj_wt: col(take(&mut w, &format!("{p}.self_attn.v_proj.weight"))),
qkv_proj_wt,
q_dim,
kv_dim,
o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
gate_proj_wt: col(take(&mut w, &format!("{p}.mlp.gate_proj.weight"))),
up_proj_wt: col(take(&mut w, &format!("{p}.mlp.up_proj.weight"))),
gate_up_proj_wt,
down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
});
}
@@ -144,52 +168,45 @@ impl Qwen3 {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
// Q/K/V projections (pre-transposed weights, x @ wt)
let q = matmul_2d(&normed, &layer.q_proj_wt);
let k = matmul_2d(&normed, &layer.k_proj_wt);
let v = matmul_2d(&normed, &layer.v_proj_wt);
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
// Reshape to [1, heads, seq, head_dim]
let q = reshape_heads(&q, new_tokens, num_heads, head_dim);
let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim);
let v = reshape_heads(&v, new_tokens, num_kv_heads, head_dim);
// QK normalization (per-head RMSNorm)
let q = head_rmsnorm(&q, &layer.q_norm, eps);
let k = head_rmsnorm(&k, &layer.k_norm, eps);
// RoPE — kernel expects [S, H, D], our tensors are [1, H, S, D]
// Transpose to [1, S, H, D] → reshape to [S, H, D] for RoPE
let q = transpose_for_rope(&q, new_tokens, num_heads, head_dim);
let k = transpose_for_rope(&k, new_tokens, num_kv_heads, head_dim);
rope_inplace(&q, &self.rope_cache, &positions);
rope_inplace(&k, &self.rope_cache, &positions);
// Transpose back to [1, H, S, D]
let q = transpose_from_rope(&q, new_tokens, num_heads, head_dim);
let k = transpose_from_rope(&k, new_tokens, num_kv_heads, head_dim);
// KV cache
let k_cpu = k.to_device(Device::Cpu);
let v_cpu = v.to_device(Device::Cpu);
cache.append_kv_tensor(layer_idx, &k_cpu, &v_cpu, new_tokens);
let (k_full, v_full) = cache.get_kv_tensors(layer_idx);
// GQA: repeat K/V
let n_rep = num_heads / num_kv_heads;
let k_full = repeat_kv(&k_full, n_rep);
let v_full = repeat_kv(&v_full, n_rep);
// Attention
let attn_out = attention(&q, &k_full, &v_full, true);
let attn_merged = merge_heads_any(&attn_out, new_tokens, hidden);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
x = add_any(&residual, &attn_proj);
// SwiGLU FFN
let residual = x.clone();
let normed = rmsnorm(&x, &layer.post_norm, eps);
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let gate_activated = silu(&gate);
let hidden_states = mul_any(&gate_activated, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
@@ -232,10 +249,10 @@ impl Qwen3 {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps); // [B, hidden]
// Batched projections: [B, hidden] × [hidden, X] = [B, X]
let q_all = matmul_2d(&normed, &layer.q_proj_wt); // [B, num_heads*head_dim]
let k_all = matmul_2d(&normed, &layer.k_proj_wt); // [B, num_kv_heads*head_dim]
let v_all = matmul_2d(&normed, &layer.v_proj_wt); // [B, num_kv_heads*head_dim]
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
@@ -290,9 +307,10 @@ impl Qwen3 {
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Batched FFN: all projections on [B, hidden]
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
x = add_any(&residual, &down);
@@ -312,6 +330,13 @@ impl Qwen3 {
/// tokens: [B] one token per sequence
/// positions: [B] current logical position (BEFORE this step) per sequence
/// seq_slots: [B] slot ids in `paged_cache`
///
/// Layout note: for S=1 decode the memory of `[B, H, 1, D]`,
/// `[B, H, D]`, and `[B, H*D]` is the same — only shape/strides differ.
/// We exploit this to drop every per-sequence kernel: head_rmsnorm and
/// RoPE both natively accept the `[B*H, D]` / `[B, H, D]` layouts that
/// fall out of the projection matmuls, and the new-token KV scatter is
/// one batched `reshape_and_cache` kernel.
pub fn forward_decode_paged(
&self,
tokens: &[u32],
@@ -343,6 +368,10 @@ impl Qwen3 {
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
let max_blocks = paged_cache.max_blocks_per_seq();
// RoPE expects `[num_tokens, H, D]` with `num_tokens` positions —
// matches our `[B, H, D]` exactly, so we upload once here.
let positions_u32: Vec<u32> = positions.iter().map(|&p| p as u32).collect();
// Batched embedding: [B, hidden]
let mut x = embedding(&self.embed_tokens, tokens);
@@ -350,62 +379,41 @@ impl Qwen3 {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let q_all = matmul_2d(&normed, &layer.q_proj_wt);
let k_all = matmul_2d(&normed, &layer.k_proj_wt);
let v_all = matmul_2d(&normed, &layer.v_proj_wt);
// Fused QKV projection: one GEMV instead of three.
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D]
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view)
let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view)
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
for b in 0..batch {
let q_row = row_view(&q_all, b);
let k_row = row_view(&k_all, b);
let v_row = row_view(&v_all, b);
// Per-head RMSNorm on contiguous copies (narrow views are strided).
let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]);
let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]);
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
let q_3d = q_normed.reshape(&[batch, num_heads, head_dim]);
let k_3d = k_normed.reshape(&[batch, num_kv_heads, head_dim]);
rope_inplace(&q_3d, &self.rope_cache, &positions_u32);
rope_inplace(&k_3d, &self.rope_cache, &positions_u32);
let q = head_rmsnorm(&q, &layer.q_norm, eps);
let k = head_rmsnorm(&k, &layer.k_norm, eps);
let v_3d = v_all.contiguous().reshape(&[batch, num_kv_heads, head_dim]);
let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
let pos = [positions[b] as u32];
rope_inplace(&q, &self.rope_cache, &pos);
rope_inplace(&k, &self.rope_cache, &pos);
let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
paged_cache.append_tokens(seq_slots[b], layer_idx, &k, &v, 1, positions[b]);
let q_flat = xserv_kernels::merge_heads_gpu(&q, 1, num_heads, head_dim);
q_rows.push(q_flat);
}
let q_batched_2d = concat_rows(&q_rows);
// q_batched_2d: [B, num_heads * head_dim]. Memory is [B, H, D] —
// a plain reshape view to [B, H, 1, D] is what the paged kernel expects.
let q_4d = q_batched_2d.reshape(&[batch, num_heads, 1, head_dim]);
// Single batched scatter for all sequences in the batch.
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, batch);
// Paged attention reads Q as [B, H, 1, D] — a contiguous view
// of [B, H, D].
let q_4d = q_3d.reshape(&[batch, num_heads, 1, head_dim]);
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention(
&q_4d,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
batch,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
batch, num_heads, num_kv_heads, head_dim, max_blocks,
);
// attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D].
// Plain reshape is a view; merge_heads_gpu would incorrectly swap B<->H.
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
@@ -413,8 +421,11 @@ impl Qwen3 {
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
// Fused gate+up projection: one GEMV instead of two.
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); // [B, 2*ffn]
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down); // TP: sum partial MLP outputs
@@ -459,9 +470,10 @@ impl Qwen3 {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let q = matmul_2d(&normed, &layer.q_proj_wt);
let k = matmul_2d(&normed, &layer.k_proj_wt);
let v = matmul_2d(&normed, &layer.v_proj_wt);
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
@@ -477,25 +489,25 @@ impl Qwen3 {
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
// Write into paged pool at the original (pre-advance) position.
paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
// Gather contiguous K/V for the full sequence (seq_len already includes new_tokens).
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
let attn_out = flash_attention(&q, &k_full, &v_full, true);
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj); // TP: sum partial attention outputs
self.all_reduce(&attn_proj);
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down); // TP: sum partial MLP outputs
self.all_reduce(&down);
x = add_any(&residual, &down);
}
@@ -520,45 +532,39 @@ impl Qwen3 {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let q = matmul_2d(&normed, &layer.q_proj_wt);
let k = matmul_2d(&normed, &layer.k_proj_wt);
let v = matmul_2d(&normed, &layer.v_proj_wt);
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q = qkv.narrow(1, 0, layer.q_dim).contiguous();
let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous();
let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous();
// GPU reshape: [S, H*D] → [1, H, S, D]
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
let v = xserv_kernels::reshape_heads_gpu(&v, new_tokens, num_kv_heads, head_dim);
// QK norm (reshape to [H*S, D], rmsnorm, reshape back — stays on GPU)
let q = head_rmsnorm(&q, &layer.q_norm, eps);
let k = head_rmsnorm(&k, &layer.k_norm, eps);
// GPU transpose for RoPE: [1, H, S, D] → [S, H, D]
let q = xserv_kernels::transpose_for_rope_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::transpose_for_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
rope_inplace(&q, &self.rope_cache, &positions);
rope_inplace(&k, &self.rope_cache, &positions);
// GPU transpose back: [S, H, D] → [1, H, S, D]
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
// GPU KV cache
cache.append(layer_idx, &k, &v, new_tokens, pos_offset);
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
// Flash Attention with native GQA (no repeat_kv needed)
let attn_out = flash_attention(&q, &k_full, &v_full, true);
// GPU merge_heads: [1, H, S, D] → [S, H*D]
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
// Fused add + rmsnorm: (normed, x) where x = residual + attn_proj
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Fused SiLU×Mul
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
x = add_any(&residual, &down);
@@ -573,15 +579,15 @@ impl Qwen3 {
pub fn layer_weight_ptrs(&self) -> Vec<crate::decode_graph::LayerWeightPtrs> {
self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs {
input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void,
q_proj_wt: l.q_proj_wt.data_ptr() as *const std::ffi::c_void,
k_proj_wt: l.k_proj_wt.data_ptr() as *const std::ffi::c_void,
v_proj_wt: l.v_proj_wt.data_ptr() as *const std::ffi::c_void,
q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void,
k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void,
v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void,
o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void,
q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void,
k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void,
post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void,
gate_proj_wt: l.gate_proj_wt.data_ptr() as *const std::ffi::c_void,
up_proj_wt: l.up_proj_wt.data_ptr() as *const std::ffi::c_void,
gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void,
up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void,
down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void,
}).collect()
}
@@ -790,6 +796,42 @@ fn concat_rows(rows: &[Tensor]) -> Tensor {
}
}
/// Concatenate 2D GPU tensors along dim=1 (columns). All must share dim 0.
fn cat_cols(tensors: &[&Tensor]) -> Tensor {
assert!(!tensors.is_empty());
let rows = tensors[0].shape()[0];
let dtype = tensors[0].dtype();
let device = tensors[0].device();
let elem = dtype.size_bytes();
let total_cols: usize = tensors.iter().map(|t| {
assert_eq!(t.ndim(), 2);
assert_eq!(t.shape()[0], rows);
assert!(t.is_contiguous());
t.shape()[1]
}).sum();
let out = Tensor::empty(&[rows, total_cols], dtype, device);
let dst_base = out.data_ptr() as *mut u8;
for r in 0..rows {
let mut col_off = 0usize;
for t in tensors {
let cols = t.shape()[1];
let src = unsafe { t.data_ptr().add(r * cols * elem) };
let dst = unsafe { dst_base.add((r * total_cols + col_off) * elem) };
let count = cols * elem;
unsafe {
xserv_cuda::ffi::cudaMemcpy(
dst as *mut u8,
src as *const u8,
count,
2, // cudaMemcpyDeviceToDevice
);
}
col_off += cols;
}
}
out
}
fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
xserv_kernels::add(a, b)
}