model: pipeline-parallel Qwen3 (from_weights_pp + stage forward)

Layer-wise split: each stage s loads only its contiguous layer range
[s*L, (s+1)*L), where L = num_layers / num_stages. Stage 0 keeps
embed_tokens and the last stage keeps norm/lm_head; a stage that does not
own one of these tensors gets a 1x1 placeholder in its place. Heads are NOT
split (PP is orthogonal to TP). Adds embed/head, plus forward_layers_prefill/
forward_layers_decode, which take and return the [tokens, hidden] hidden
state; the per-stage PagedKVCache is indexed by local layer id.

sampling: derive Clone on SamplingParams (carried in the PP command enum).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 18:45:47 +08:00
parent 859c0cc0b6
commit da3aaa134a
2 changed files with 264 additions and 0 deletions

View File

@@ -20,6 +20,11 @@ pub struct Qwen3 {
tp: Option<std::sync::Arc<xserv_distributed::TpContext>>,
local_num_heads: usize, // = num_heads / world
local_num_kv_heads: usize, // = num_kv_heads / world
// Pipeline parallelism (Phase 18): this stage holds a contiguous slice of
// layers. `is_first_stage` owns `embed_tokens`; `is_last_stage` owns
// `norm`/`lm_head_t`. Both true for single-GPU / TP (the whole model).
is_first_stage: bool,
is_last_stage: bool,
}
struct Qwen3Block {
@@ -113,9 +118,267 @@ impl Qwen3 {
lm_head_t,
rope_cache,
tp,
is_first_stage: true,
is_last_stage: true,
}
}
/// Pipeline-parallel load (Phase 18): builds the model slice for one stage.
///
/// This stage holds the contiguous layer range `[stage*L, (stage+1)*L)` with
/// `L = num_layers / num_stages`; only stage 0 keeps `embed_tokens` and only
/// the last stage keeps `norm`/`lm_head` (a stage that does not own one of
/// these gets a 1x1 placeholder, guarded by the stage flags and never used).
/// Heads are NOT split (PP is orthogonal to TP), so `tp` is `None`, the local
/// head counts equal the full counts from `config`, and each stage runs full
/// attention/MLP over its layers and hands off the `[tokens, hidden]` hidden
/// state to the next stage (the engine does the NCCL send/recv).
///
/// Weights are removed from `w` as they are consumed and moved to
/// `Device::Cuda(device)`.
///
/// # Panics
/// - if `num_stages == 0` or `stage >= num_stages`;
/// - if `num_layers` is not divisible by `num_stages`;
/// - if a weight this stage needs is missing from `w`.
pub fn from_weights_pp(
    config: ModelConfig,
    mut w: HashMap<String, Tensor>,
    stage: usize,
    num_stages: usize,
    device: u32,
) -> Self {
    crate::init_kernels();
    let dev = Device::Cuda(device);
    assert!(num_stages >= 1);
    // Fail fast with a precise message: without this check an out-of-range
    // stage would only surface later as a confusing "missing weight" panic
    // for a layer index that does not exist.
    assert!(stage < num_stages, "pp stage {stage} out of range for {num_stages} stages");
    let num_layers = config.num_layers();
    assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}");
    let per_stage = num_layers / num_stages;
    let lo = stage * per_stage;
    let hi = lo + per_stage;
    let is_first_stage = stage == 0;
    let is_last_stage = stage == num_stages - 1;

    // Remove a weight from the map by name, panicking if it is absent.
    let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
        w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
    };
    // Replicated weight: moved to the device as-is.
    let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
    // Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard).
    let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() };
    // Stand-in for tensors this stage does not own.
    let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev);

    let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() };
    let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() };
    let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() };

    let rope_cache = RopeCache::new(
        config.max_seq_len(),
        config.head_dim(),
        config.rope_theta.unwrap_or(1_000_000.0) as f32,
    );

    let mut layers = Vec::with_capacity(per_stage);
    eprintln!(
        "[pp] stage {stage}/{num_stages}: layers [{lo}, {hi}) {}{}",
        if is_first_stage { "+embed " } else { "" },
        if is_last_stage { "+norm+lm_head" } else { "" }
    );
    // Weights are looked up by GLOBAL layer index `i`, but stored at the local
    // index `i - lo` (their position in `layers`).
    for i in lo..hi {
        let p = format!("model.layers.{i}");
        let mut fetch = |suffix: &str| take(&mut w, &format!("{p}.{suffix}"));
        layers.push(Qwen3Block {
            input_norm: repl(fetch("input_layernorm.weight")),
            q_proj_wt: wt(fetch("self_attn.q_proj.weight")),
            k_proj_wt: wt(fetch("self_attn.k_proj.weight")),
            v_proj_wt: wt(fetch("self_attn.v_proj.weight")),
            o_proj_wt: wt(fetch("self_attn.o_proj.weight")),
            q_norm: repl(fetch("self_attn.q_norm.weight")),
            k_norm: repl(fetch("self_attn.k_norm.weight")),
            post_norm: repl(fetch("post_attention_layernorm.weight")),
            gate_proj_wt: wt(fetch("mlp.gate_proj.weight")),
            up_proj_wt: wt(fetch("mlp.up_proj.weight")),
            down_proj_wt: wt(fetch("mlp.down_proj.weight")),
        });
    }

    Self {
        // No tensor parallelism here, so every stage sees all heads.
        local_num_heads: config.num_heads(),
        local_num_kv_heads: config.num_kv_heads(),
        config,
        embed_tokens,
        layers,
        norm,
        lm_head_t,
        rope_cache,
        tp: None,
        is_first_stage,
        is_last_stage,
    }
}
/// Stage-0 token embedding: `[S]` token ids -> `[S, hidden]` hidden state.
///
/// # Panics
/// Panics if this is not the first pipeline stage. Other stages hold only a
/// 1x1 placeholder in `embed_tokens`, so the guard is enforced in release
/// builds too rather than letting a lookup run against the placeholder.
pub fn embed(&self, token_ids: &[u32]) -> Tensor {
    assert!(
        self.is_first_stage,
        "embed() called on a non-first pipeline stage (embed_tokens is a placeholder)"
    );
    embedding(&self.embed_tokens, token_ids)
}
/// Last-stage head: `[*, hidden]` -> logits `[*, vocab]`.
///
/// Applies the final RMSNorm (epsilon from `config.rms_norm_eps`, defaulting
/// to `1e-6`) and then multiplies by the pre-transposed `lm_head_t`.
///
/// # Panics
/// Panics if this is not the last pipeline stage. Other stages hold only 1x1
/// placeholders in `norm`/`lm_head_t`, so the guard is enforced in release
/// builds too rather than letting the kernels run against the placeholders.
pub fn head(&self, x: &Tensor) -> Tensor {
    assert!(
        self.is_last_stage,
        "head() called on a non-last pipeline stage (norm/lm_head are placeholders)"
    );
    let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
    let normed = rmsnorm(x, &self.norm, eps);
    matmul_2d(&normed, &self.lm_head_t)
}
pub fn pp_is_first(&self) -> bool { self.is_first_stage }
pub fn pp_is_last(&self) -> bool { self.is_last_stage }
/// PP prefill over THIS stage's layers.
///
/// `x` is `[S, hidden]` (stage 0: from `embed`; otherwise received from the
/// previous stage). Writes K/V for this stage's layers into `paged_cache`
/// (indexed by local layer id) and returns the `[S, hidden]` hidden state to
/// hand to the next stage. Same kernels as `forward_prefill_paged`, minus
/// embedding and the final norm/lm_head.
///
/// The new tokens are placed at positions
/// `[seq_len(slot), seq_len(slot) + S)`, and the slot's sequence length is
/// advanced by `S` before the layer loop runs.
pub fn forward_layers_prefill(
    &self,
    mut x: Tensor,
    slot: usize,
    paged_cache: &mut PagedKVCache,
) -> Tensor {
    let new_tokens = x.shape()[0];
    let pos_offset = paged_cache.seq_len(slot);
    let num_heads = self.local_num_heads;
    let num_kv_heads = self.local_num_kv_heads;
    let head_dim = self.config.head_dim();
    let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;

    // NOTE(review): the length is advanced BEFORE the layers run, presumably
    // so `gather_kv_contiguous` below already covers the new tokens - confirm
    // against PagedKVCache before reordering.
    paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
    paged_cache.advance_seq_len(slot, new_tokens);
    let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();

    for (layer_idx, layer) in self.layers.iter().enumerate() {
        // ---- Attention ----
        // `x` is only read until it is reassigned at the end of the layer, so
        // it doubles as the residual input; no copy of the tensor is needed.
        let normed = rmsnorm(&x, &layer.input_norm, eps);
        let q = matmul_2d(&normed, &layer.q_proj_wt);
        let k = matmul_2d(&normed, &layer.k_proj_wt);
        let v = matmul_2d(&normed, &layer.v_proj_wt);
        let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
        let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
        let v = xserv_kernels::reshape_heads_gpu(&v, new_tokens, num_kv_heads, head_dim);

        // Per-head Q/K norm, then RoPE (applied in place on the transposed layout).
        let q = head_rmsnorm(&q, &layer.q_norm, eps);
        let k = head_rmsnorm(&k, &layer.k_norm, eps);
        let q = xserv_kernels::transpose_for_rope_gpu(&q, new_tokens, num_heads, head_dim);
        let k = xserv_kernels::transpose_for_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
        rope_inplace(&q, &self.rope_cache, &positions);
        rope_inplace(&k, &self.rope_cache, &positions);
        let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
        let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);

        // `layer_idx` is the LOCAL layer index within this stage's cache.
        paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
        let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
        let attn_out = flash_attention(&q, &k_full, &v_full, true);
        let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
        let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

        // Yields the post-norm activations plus the updated residual stream
        // (going by the names used at this call site).
        let (normed, x_attn) = xserv_kernels::add_rmsnorm(&attn_proj, &x, &layer.post_norm, eps);

        // ---- MLP ----
        let gate = matmul_2d(&normed, &layer.gate_proj_wt);
        let up = matmul_2d(&normed, &layer.up_proj_wt);
        let hidden_states = xserv_kernels::silu_mul(&gate, &up);
        let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
        x = add_any(&x_attn, &down);
    }
    x
}
/// PP decode over THIS stage's layers.
///
/// `x` is `[B, hidden]`, one row per entry of `seq_slots`. Returns
/// `[B, hidden]`. Positions are read from `paged_cache` (all stages advance
/// in lockstep, so they agree). Same kernels as `forward_decode_paged`.
///
/// Each slot's new token is written at its current sequence length; the
/// sequence lengths are advanced by one only after every layer has run.
///
/// # Panics
/// Panics if `x.shape()[0] != seq_slots.len()`.
pub fn forward_layers_decode(
    &self,
    mut x: Tensor,
    seq_slots: &[usize],
    paged_cache: &mut PagedKVCache,
) -> Tensor {
    let batch = seq_slots.len();
    assert_eq!(x.shape()[0], batch);
    let num_heads = self.local_num_heads;
    let num_kv_heads = self.local_num_kv_heads;
    let head_dim = self.config.head_dim();
    let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;

    // Position of the token being decoded for each slot; the attention
    // context length includes that token, hence `+ 1`.
    let positions: Vec<usize> = seq_slots.iter().map(|&s| paged_cache.seq_len(s)).collect();
    let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
    for (&slot, &position) in seq_slots.iter().zip(&positions) {
        paged_cache.ensure_capacity(slot, position + 1);
    }
    paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);

    // Raw device pointers are captured once, after capacity has been ensured
    // and the batch metadata synced, and reused for every layer.
    let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
    let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
    let max_blocks = paged_cache.max_blocks_per_seq();

    for (layer_idx, layer) in self.layers.iter().enumerate() {
        // ---- Attention ----
        // `x` is only read until it is reassigned at the end of the layer, so
        // it doubles as the residual input; no copy of the tensor is needed.
        let normed = rmsnorm(&x, &layer.input_norm, eps);
        let q_all = matmul_2d(&normed, &layer.q_proj_wt);
        let k_all = matmul_2d(&normed, &layer.k_proj_wt);
        let v_all = matmul_2d(&normed, &layer.v_proj_wt);

        // Q/K norm, RoPE and the KV-cache write are done per sequence because
        // every sequence sits at its own position.
        let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
        for (b, (&slot, &position)) in seq_slots.iter().zip(&positions).enumerate() {
            let q_row = row_view(&q_all, b);
            let k_row = row_view(&k_all, b);
            let v_row = row_view(&v_all, b);
            let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
            let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
            let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
            let q = head_rmsnorm(&q, &layer.q_norm, eps);
            let k = head_rmsnorm(&k, &layer.k_norm, eps);
            let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
            let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
            let pos = [position as u32];
            rope_inplace(&q, &self.rope_cache, &pos);
            rope_inplace(&k, &self.rope_cache, &pos);
            let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
            let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
            // `layer_idx` is the LOCAL layer index within this stage's cache.
            paged_cache.append_tokens(slot, layer_idx, &k, &v, 1, position);
            let q_flat = xserv_kernels::merge_heads_gpu(&q, 1, num_heads, head_dim);
            q_rows.push(q_flat);
        }

        // Re-batch the queries and attend over the paged K/V pools in one call.
        let q_batched_2d = concat_rows(&q_rows);
        let q_4d = q_batched_2d.reshape(&[batch, num_heads, 1, head_dim]);
        let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
        let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
        let attn_out = xserv_kernels::paged_decode_attention(
            &q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
            batch, num_heads, num_kv_heads, head_dim, max_blocks,
        );
        let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
        let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

        // Yields the post-norm activations plus the updated residual stream
        // (going by the names used at this call site).
        let (normed, x_attn) = xserv_kernels::add_rmsnorm(&attn_proj, &x, &layer.post_norm, eps);

        // ---- MLP ----
        let gate = matmul_2d(&normed, &layer.gate_proj_wt);
        let up = matmul_2d(&normed, &layer.up_proj_wt);
        let hidden_states = xserv_kernels::silu_mul(&gate, &up);
        let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
        x = add_any(&x_attn, &down);
    }

    // Commit the decoded token: advance each sequence only after all layers ran.
    for &slot in seq_slots {
        paged_cache.advance_seq_len(slot, 1);
    }
    x
}
/// In-place AllReduce(sum) of a partial `[*, hidden]` BF16 activation across
/// TP ranks (no-op when not tensor-parallel). Used after o_proj and down_proj.
#[inline]

View File

@@ -2,6 +2,7 @@ use half::bf16;
use rand::Rng;
use xserv_tensor::{DType, Device, Tensor};
#[derive(Clone)]
pub struct SamplingParams {
pub temperature: f32,
pub top_k: usize,