model: pipeline-parallel Qwen3 (from_weights_pp + stage forward)
Layer-wise split: each stage loads only its contiguous layer range [s*L, (s+1)*L), where L = num_layers / num_stages; stage 0 keeps embed_tokens, and the last stage keeps norm/lm_head (the other stages get a 1x1 placeholder). Heads are NOT split (PP is orthogonal to TP). Adds embed/head and forward_layers_prefill/forward_layers_decode, which take and return the [tokens, hidden] hidden state; the per-stage PagedKVCache is indexed by local layer id. sampling: derive Clone on SamplingParams (carried in the PP command enum). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -20,6 +20,11 @@ pub struct Qwen3 {
|
||||
tp: Option<std::sync::Arc<xserv_distributed::TpContext>>,
|
||||
local_num_heads: usize, // = num_heads / world
|
||||
local_num_kv_heads: usize, // = num_kv_heads / world
|
||||
// Pipeline parallelism (Phase 18): this stage holds a contiguous slice of
|
||||
// layers. `is_first_stage` owns `embed_tokens`; `is_last_stage` owns
|
||||
// `norm`/`lm_head_t`. Both true for single-GPU / TP (the whole model).
|
||||
is_first_stage: bool,
|
||||
is_last_stage: bool,
|
||||
}
|
||||
|
||||
struct Qwen3Block {
|
||||
@@ -113,9 +118,267 @@ impl Qwen3 {
|
||||
lm_head_t,
|
||||
rope_cache,
|
||||
tp,
|
||||
is_first_stage: true,
|
||||
is_last_stage: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Pipeline-parallel load (Phase 18). This stage holds the contiguous layer
|
||||
/// range `[stage*L, (stage+1)*L)` with `L = num_layers / num_stages`; only
|
||||
/// stage 0 keeps `embed_tokens` and only the last stage keeps `norm`/`lm_head`
|
||||
/// (others get a 1x1 placeholder, guarded by the stage flags and never used).
|
||||
/// Heads are NOT split (PP is orthogonal to TP), so each stage runs full
|
||||
/// attention/MLP over its layers and hands off the `[tokens, hidden]` hidden
|
||||
/// state to the next stage (the engine does the NCCL send/recv).
|
||||
pub fn from_weights_pp(
|
||||
config: ModelConfig,
|
||||
mut w: HashMap<String, Tensor>,
|
||||
stage: usize,
|
||||
num_stages: usize,
|
||||
device: u32,
|
||||
) -> Self {
|
||||
crate::init_kernels();
|
||||
let dev = Device::Cuda(device);
|
||||
assert!(num_stages >= 1);
|
||||
let num_layers = config.num_layers();
|
||||
assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}");
|
||||
let per_stage = num_layers / num_stages;
|
||||
let lo = stage * per_stage;
|
||||
let hi = lo + per_stage;
|
||||
let is_first_stage = stage == 0;
|
||||
let is_last_stage = stage == num_stages - 1;
|
||||
|
||||
let take = |w: &mut HashMap<String, Tensor>, name: &str| -> Tensor {
|
||||
w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}"))
|
||||
};
|
||||
let repl = |t: Tensor| -> Tensor { t.to_device(dev) };
|
||||
// Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard).
|
||||
let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() };
|
||||
let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev);
|
||||
|
||||
let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() };
|
||||
let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() };
|
||||
let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() };
|
||||
|
||||
let rope_cache = RopeCache::new(
|
||||
config.max_seq_len(),
|
||||
config.head_dim(),
|
||||
config.rope_theta.unwrap_or(1_000_000.0) as f32,
|
||||
);
|
||||
|
||||
let mut layers = Vec::with_capacity(per_stage);
|
||||
eprintln!(
|
||||
"[pp] stage {stage}/{num_stages}: layers [{lo}, {hi}) {}{}",
|
||||
if is_first_stage { "+embed " } else { "" },
|
||||
if is_last_stage { "+norm+lm_head" } else { "" }
|
||||
);
|
||||
for i in lo..hi {
|
||||
let p = format!("model.layers.{i}");
|
||||
layers.push(Qwen3Block {
|
||||
input_norm: repl(take(&mut w, &format!("{p}.input_layernorm.weight"))),
|
||||
q_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.q_proj.weight"))),
|
||||
k_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.k_proj.weight"))),
|
||||
v_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.v_proj.weight"))),
|
||||
o_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))),
|
||||
q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))),
|
||||
k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))),
|
||||
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
|
||||
gate_proj_wt: wt(take(&mut w, &format!("{p}.mlp.gate_proj.weight"))),
|
||||
up_proj_wt: wt(take(&mut w, &format!("{p}.mlp.up_proj.weight"))),
|
||||
down_proj_wt: wt(take(&mut w, &format!("{p}.mlp.down_proj.weight"))),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
local_num_heads: config.num_heads(),
|
||||
local_num_kv_heads: config.num_kv_heads(),
|
||||
config,
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head_t,
|
||||
rope_cache,
|
||||
tp: None,
|
||||
is_first_stage,
|
||||
is_last_stage,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stage-0 token embedding: `[S]` token ids -> `[S, hidden]` hidden state.
|
||||
pub fn embed(&self, token_ids: &[u32]) -> Tensor {
|
||||
debug_assert!(self.is_first_stage);
|
||||
embedding(&self.embed_tokens, token_ids)
|
||||
}
|
||||
|
||||
/// Last-stage head: `[*, hidden]` -> logits `[*, vocab]`.
|
||||
pub fn head(&self, x: &Tensor) -> Tensor {
|
||||
debug_assert!(self.is_last_stage);
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
let x = rmsnorm(x, &self.norm, eps);
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Whether this model instance is the first pipeline stage (the one that
/// owns `embed_tokens`).
pub fn pp_is_first(&self) -> bool {
    self.is_first_stage
}
|
||||
/// Whether this model instance is the last pipeline stage (the one that
/// owns `norm`/`lm_head_t`).
pub fn pp_is_last(&self) -> bool {
    self.is_last_stage
}
|
||||
|
||||
/// PP prefill over THIS stage's layers. `x` is `[S, hidden]` (stage 0: from
|
||||
/// `embed`; otherwise received from the previous stage). Writes K/V for this
|
||||
/// stage's layers into `paged_cache` (indexed by local layer id) and returns
|
||||
/// the `[S, hidden]` hidden state to hand to the next stage. Same kernels as
|
||||
/// `forward_prefill_paged`, minus embedding and the final norm/lm_head.
|
||||
pub fn forward_layers_prefill(
|
||||
&self,
|
||||
mut x: Tensor,
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let new_tokens = x.shape()[0];
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let positions: Vec<u32> = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
rope_inplace(&q, &self.rope_cache, &positions);
|
||||
rope_inplace(&k, &self.rope_cache, &positions);
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, new_tokens, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, new_tokens, num_kv_heads, head_dim);
|
||||
|
||||
paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
/// PP decode over THIS stage's layers. `x` is `[B, hidden]`. Returns
|
||||
/// `[B, hidden]`. Positions are read from `paged_cache` (all stages advance
|
||||
/// in lockstep, so they agree). Same kernels as `forward_decode_paged`.
|
||||
pub fn forward_layers_decode(
|
||||
&self,
|
||||
mut x: Tensor,
|
||||
seq_slots: &[usize],
|
||||
paged_cache: &mut PagedKVCache,
|
||||
) -> Tensor {
|
||||
let batch = seq_slots.len();
|
||||
assert_eq!(x.shape()[0], batch);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
let positions: Vec<usize> = seq_slots.iter().map(|&s| paged_cache.seq_len(s)).collect();
|
||||
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
|
||||
for (b, &slot) in seq_slots.iter().enumerate() {
|
||||
paged_cache.ensure_capacity(slot, positions[b] + 1);
|
||||
}
|
||||
paged_cache.sync_active_batch_with_lens(seq_slots, &kv_lens);
|
||||
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let q_all = matmul_2d(&normed, &layer.q_proj_wt);
|
||||
let k_all = matmul_2d(&normed, &layer.k_proj_wt);
|
||||
let v_all = matmul_2d(&normed, &layer.v_proj_wt);
|
||||
|
||||
let mut q_rows: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
for b in 0..batch {
|
||||
let q_row = row_view(&q_all, b);
|
||||
let k_row = row_view(&k_all, b);
|
||||
let v_row = row_view(&v_all, b);
|
||||
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
|
||||
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
let pos = [positions[b] as u32];
|
||||
rope_inplace(&q, &self.rope_cache, &pos);
|
||||
rope_inplace(&k, &self.rope_cache, &pos);
|
||||
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
paged_cache.append_tokens(seq_slots[b], layer_idx, &k, &v, 1, positions[b]);
|
||||
|
||||
let q_flat = xserv_kernels::merge_heads_gpu(&q, 1, num_heads, head_dim);
|
||||
q_rows.push(q_flat);
|
||||
}
|
||||
|
||||
let q_batched_2d = concat_rows(&q_rows);
|
||||
let q_4d = q_batched_2d.reshape(&[batch, num_heads, 1, head_dim]);
|
||||
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
|
||||
let attn_out = xserv_kernels::paged_decode_attention(
|
||||
&q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr,
|
||||
batch, num_heads, num_kv_heads, head_dim, max_blocks,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
for &slot in seq_slots {
|
||||
paged_cache.advance_seq_len(slot, 1);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
/// In-place AllReduce(sum) of a partial `[*, hidden]` BF16 activation across
|
||||
/// TP ranks (no-op when not tensor-parallel). Used after o_proj and down_proj.
|
||||
#[inline]
|
||||
|
||||
@@ -2,6 +2,7 @@ use half::bf16;
|
||||
use rand::Rng;
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SamplingParams {
|
||||
pub temperature: f32,
|
||||
pub top_k: usize,
|
||||
|
||||
Reference in New Issue
Block a user