diff --git a/crates/xserv-model/src/kv_cache.rs b/crates/xserv-model/src/kv_cache.rs
index 3ed49b9..11de9f5 100644
--- a/crates/xserv-model/src/kv_cache.rs
+++ b/crates/xserv-model/src/kv_cache.rs
@@ -139,3 +139,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
         dtype,
     )
 }
+
+/// Public wrapper around `tensor_from_gpu_buffer` for other modules (e.g., batched decode concat).
+///
+/// # Safety
+/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
+pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
+    tensor_from_gpu_buffer(buf, shape, dtype, device)
+}
diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs
index ef9405e..e1e32c9 100644
--- a/crates/xserv-model/src/qwen3.rs
+++ b/crates/xserv-model/src/qwen3.rs
@@ -148,6 +148,116 @@ impl Qwen3 {
         matmul_2d(&x, &self.lm_head_t)
     }
 
+    /// Batched decode: process one token per sequence simultaneously.
+    /// All compute-heavy ops (projections, FFN) operate on [B, hidden] tensors.
+    /// Per-sequence ops (RoPE, KV cache, attention) are handled individually.
+    ///
+    /// tokens: one token per sequence (len = batch_size)
+    /// positions: position offset for each sequence (len = batch_size)
+    /// caches: one mutable KV cache per sequence (len = batch_size)
+    ///
+    /// Returns logits: [batch_size, vocab_size]
+    ///
+    /// # Panics
+    /// Panics if `tokens` is empty or if `positions`/`caches` differ in length from `tokens`.
+    pub fn forward_decode_batch(
+        &self,
+        tokens: &[u32],
+        positions: &[usize],
+        caches: &mut [&mut GpuKVCache],
+    ) -> Tensor {
+        let batch = tokens.len();
+        assert_eq!(positions.len(), batch);
+        assert_eq!(caches.len(), batch);
+        assert!(batch > 0);
+
+        let num_heads = self.config.num_heads();
+        let num_kv_heads = self.config.num_kv_heads();
+        let head_dim = self.config.head_dim();
+        let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
+
+        // Batched embedding: [B, hidden]
+        let mut x = embedding(&self.embed_tokens, tokens);
+
+        for (layer_idx, layer) in self.layers.iter().enumerate() {
+            let residual = x.clone();
+            let normed = rmsnorm(&x, &layer.input_norm, eps); // [B, hidden]
+
+            // Batched projections: [B, hidden] × [hidden, X] = [B, X]
+            let q_all = matmul_2d(&normed, &layer.q_proj_wt); // [B, num_heads*head_dim]
+            let k_all = matmul_2d(&normed, &layer.k_proj_wt); // [B, num_kv_heads*head_dim]
+            let v_all = matmul_2d(&normed, &layer.v_proj_wt); // [B, num_kv_heads*head_dim]
+
+            // Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
+            let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
+            for b in 0..batch {
+                // Extract row b: [1, X] — view into contiguous [B, X]
+                let q_row = row_view(&q_all, b); // [1, num_heads*head_dim]
+                let k_row = row_view(&k_all, b); // [1, num_kv_heads*head_dim]
+                let v_row = row_view(&v_all, b); // [1, num_kv_heads*head_dim]
+
+                // GPU reshape: [1, H*D] → [1, H, 1, D]
+                let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
+                let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
+                let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
+
+                // QK norm
+                let q = head_rmsnorm(&q, &layer.q_norm, eps);
+                let k = head_rmsnorm(&k, &layer.k_norm, eps);
+
+                // GPU transpose for RoPE: [1, H, 1, D] → [1, H, D]
+                let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
+                let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
+
+                // RoPE with per-sequence position
+                let pos = [positions[b] as u32];
+                rope_inplace(&q, &self.rope_cache, &pos);
+                rope_inplace(&k, &self.rope_cache, &pos);
+
+                // Transpose back: [1, H, D] → [1, H, 1, D]
+                let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
+                let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
+
+                // KV cache: append and get full cache
+                let pos_b = positions[b];
+                caches[b].append(layer_idx, &k, &v, 1, pos_b);
+                let (k_full, v_full) = caches[b].get_kv_len(layer_idx, pos_b + 1);
+
+                // Decode attention (uses native GQA, no repeat_kv needed)
+                let attn_out = flash_attention(&q, &k_full, &v_full, true);
+
+                // Merge heads: [1, H, 1, D] → [1, hidden]
+                let merged = xserv_kernels::merge_heads_gpu(&attn_out, 1, num_heads, head_dim);
+                attn_outputs.push(merged);
+            }
+
+            // Concat attention outputs: [B, hidden]
+            let attn_merged = concat_rows(&attn_outputs);
+
+            // Batched O projection: [B, hidden] × [hidden, hidden] = [B, hidden]
+            let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
+
+            // Fused add + rmsnorm
+            let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
+            let residual = x_new.clone();
+
+            // Batched FFN: all projections on [B, hidden]
+            let gate = matmul_2d(&normed, &layer.gate_proj_wt);
+            let up = matmul_2d(&normed, &layer.up_proj_wt);
+            let hidden_states = xserv_kernels::silu_mul(&gate, &up);
+            let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
+            x = add_any(&residual, &down);
+        }
+
+        // Advance KV cache seq_len once per sequence, after every layer has appended its K/V
+        for cache in caches.iter_mut() {
+            cache.advance_seq_len(1);
+        }
+
+        let x = rmsnorm(&x, &self.norm, eps);
+        matmul_2d(&x, &self.lm_head_t) // [B, vocab_size]
+    }
+
     /// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
     pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
         let new_tokens = token_ids.len();
@@ -317,6 +427,60 @@ fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
     Tensor::from_slice(&out, &[1, new_heads, seq_len, head_dim]).to_device(x.device())
 }
 
+/// Extract row `b` from a contiguous 2D tensor [B, cols] as a [1, cols] view.
+/// Zero-copy: shares storage with the original tensor.
+fn row_view(t: &Tensor, row: usize) -> Tensor {
+    assert_eq!(t.ndim(), 2);
+    assert!(t.is_contiguous());
+    let cols = t.shape()[1];
+    assert!(row < t.shape()[0]);
+    let new_offset = t.offset() + row * cols;
+    Tensor::from_storage(
+        t.storage().clone(),
+        smallvec::SmallVec::from_slice(&[1, cols]),
+        xserv_tensor::shape::contiguous_strides(&[1, cols]),
+        new_offset,
+        t.dtype(),
+    )
+}
+
+/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
+///
+/// # Panics
+/// Panics if `rows` is empty, if any row is not a contiguous `[1, cols]` tensor with the
+/// same element size as the first row, or if the first row is not on a CUDA device.
+fn concat_rows(rows: &[Tensor]) -> Tensor {
+    assert!(!rows.is_empty());
+    let batch = rows.len();
+    let cols = rows[0].shape()[1];
+    let dtype = rows[0].dtype();
+    // Resolve the device before allocating so a non-CUDA input fails before any GPU work.
+    let device_id = match rows[0].device() { Device::Cuda(id) => id, _ => panic!("expected CUDA device") };
+    let elem_size = dtype.size_bytes();
+    let row_bytes = cols * elem_size;
+
+    // Allocate output [B, cols] and copy each row into it
+    let total_bytes = batch * row_bytes;
+    let mut out_buf = xserv_cuda::GpuBuffer::alloc(total_bytes).expect("alloc concat_rows");
+
+    for (b, row) in rows.iter().enumerate() {
+        assert_eq!(row.shape(), &[1, cols]);
+        // `row_bytes` comes from the first row's dtype; a different element size would mis-size the copy.
+        assert_eq!(row.dtype().size_bytes(), elem_size);
+        assert!(row.is_contiguous());
+        let src_buf = row.storage().gpu_buffer();
+        let src_offset = row.offset() * elem_size;
+        let dst_offset = b * row_bytes;
+        out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).expect("D2D copy in concat_rows");
+    }
+
+    // SAFETY: `out_buf` was allocated above with exactly `batch * cols * dtype.size_bytes()`
+    // bytes, which is the size `tensor_from_gpu_buffer_pub` requires for shape `[batch, cols]`.
+    unsafe {
+        crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
+    }
+}
+
 fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
     xserv_kernels::add(a, b)
 }
diff --git a/crates/xserv-server/src/engine.rs b/crates/xserv-server/src/engine.rs
index 2211738..69712d9 100644
--- a/crates/xserv-server/src/engine.rs
+++ b/crates/xserv-server/src/engine.rs
@@ -104,28 +104,71 @@ impl Engine {
             }
         }
 
-        // Step 4b: Process decode (one token per sequence)
-        // Currently per-sequence (each has different KV cache length).
-        // TODO(Phase 14): With Flash Attention, batch all decode tokens into
-        // one forward pass — batch the compute-heavy ops (projections, FFN)
-        // and use FlashDecoding for per-seq variable-length attention.
-        let decode_count = running.iter()
-            .filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
-            .count();
-        if decode_count > 0 {
+        // Step 4b: Batched decode — batch all decode-ready sequences into one forward pass.
+        // Projections and FFN run as [B, hidden] matmuls; attention remains per-seq.
+        let decode_indices: Vec<usize> = running.iter().enumerate()
+            .filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
+            .map(|(i, _)| i)
+            .collect();
+
+        if !decode_indices.is_empty() {
             static LOG_ONCE: Once = Once::new();
             LOG_ONCE.call_once(|| {
-                eprintln!("[scheduler] decode batching active (per-seq until Flash Attention)");
+                eprintln!("[scheduler] batched decode active");
             });
-            eprintln!("[scheduler] decode batch_size={}", decode_count);
-        }
-        for seq in running.iter_mut() {
-            if seq.prefilled && !newly_prefilled.contains(&seq.id) {
-                let last = *seq.generated_tokens.last().unwrap();
-                let logits = self.model.forward_gpu_cache(&[last], &mut seq.kv_cache);
-                let next = sample(&logits, &seq.sampling);
-                seq.generated_tokens.push(next);
-                self.emit_token(seq, next);
+            eprintln!("[scheduler] decode batch_size={}", decode_indices.len());
+
+            if decode_indices.len() == 1 {
+                // Single sequence: use per-seq path (no batching overhead)
+                let i = decode_indices[0];
+                let last = *running[i].generated_tokens.last().unwrap();
+                let logits = self.model.forward_gpu_cache(&[last], &mut running[i].kv_cache);
+                let next = sample(&logits, &running[i].sampling);
+                running[i].generated_tokens.push(next);
+                self.emit_token(&running[i], next);
+            } else {
+                // Batched decode: extract tokens and positions
+                let tokens: Vec<u32> = decode_indices.iter()
+                    .map(|&i| *running[i].generated_tokens.last().unwrap())
+                    .collect();
+                let positions: Vec<usize> = decode_indices.iter()
+                    .map(|&i| running[i].kv_cache.seq_len())
+                    .collect();
+
+                // Borrow each decode-ready sequence's cache mutably, in the same order as
+                // `decode_indices` (same predicate over the same iteration order). `iter_mut`
+                // hands out disjoint `&mut` borrows, so no placeholder caches are allocated.
+                let mut cache_refs: Vec<&mut GpuKVCache> = running.iter_mut()
+                    .filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
+                    .map(|s| &mut s.kv_cache)
+                    .collect();
+
+                let logits = self.model.forward_decode_batch(&tokens, &positions, &mut cache_refs);
+                // Release the mutable borrows before `running` is indexed again below.
+                drop(cache_refs);
+
+                // Sample per-sequence from batched logits [B, vocab_size]
+                let vocab_size = logits.shape()[1];
+                let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
+                // NOTE(review): the element type argument is missing in the reviewed patch; bf16 is
+                // assumed because rows call `to_f32()` — confirm against the model's logits dtype.
+                let data = logits_cpu.as_slice::<half::bf16>();
+                for (j, &i) in decode_indices.iter().enumerate() {
+                    let row_start = j * vocab_size;
+                    let row_logits = &data[row_start..row_start + vocab_size];
+                    let next = if running[i].sampling.temperature == 0.0 {
+                        // Greedy: argmax
+                        row_logits.iter().enumerate()
+                            .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
+                            .map(|(idx, _)| idx as u32).unwrap()
+                    } else {
+                        // Use the row as a single-row tensor for full sampling
+                        let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
+                        sample(&row_tensor, &running[i].sampling)
+                    };
+                    running[i].generated_tokens.push(next);
+                    self.emit_token(&running[i], next);
+                }
             }
         }
 