phase 15: batched decode forward — 35 tok/s (97% of HF transformers)
Implement batched decode that processes multiple sequences' tokens in one forward pass. The key insight: a cuBLAS M=4 GEMM is dramatically faster than 4× M=1 GEMV, due to better TensorCore utilization and amortized kernel-launch overhead.

New method `Qwen3::forward_decode_batch(&tokens, &positions, &mut caches)`:
- Batched embedding, norm, projections, FFN: [B, hidden] × [hidden, X] → one cuBLAS call per weight matrix instead of B calls
- Per-sequence attention: RoPE, KV cache, and decode_attention remain per-seq (each sequence has a different position and KV length)
- Row extraction (`row_view`) and concatenation (`concat_rows`) for batched ↔ per-seq transitions

Engine Step 4b:
- batch_size >= 2: extracts caches via `std::mem::replace`, calls `forward_decode_batch`, restores caches, samples per-sequence
- batch_size == 1: falls back to per-seq `forward_gpu_cache` (no overhead)

Ablation results (dash5, RTX 5090, Qwen3-8B BF16):

| Scenario | Throughput | vs HF |
|----------|-----------|-------|
| Serial (batch=1) | 13.2 tok/s | 37% |
| Concurrent (batch=4) | 35.1 tok/s | 97% |
| HF transformers | 36.0 tok/s | 100% |

The 2.66x throughput improvement (13.2 → 35.1) for concurrent requests comes from cuBLAS going from 1008 M=1 GEMVs to 252 M=4 GEMMs per step, which cuBLAS handles ~4x more efficiently on TensorCores.

Milestone ④ target (50% of vLLM/HF throughput) achieved with 97%.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -139,3 +139,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
|
||||
dtype,
|
||||
)
|
||||
}
|
||||
|
||||
/// Public version for use by other modules (e.g., batched decode concat).
|
||||
///
|
||||
/// # Safety
|
||||
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
|
||||
pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
|
||||
tensor_from_gpu_buffer(buf, shape, dtype, device)
|
||||
}
|
||||
|
||||
@@ -148,6 +148,113 @@ impl Qwen3 {
|
||||
matmul_2d(&x, &self.lm_head_t)
|
||||
}
|
||||
|
||||
/// Batched decode: process one token per sequence simultaneously.
|
||||
/// All compute-heavy ops (projections, FFN) operate on [B, hidden] tensors.
|
||||
/// Per-sequence ops (RoPE, KV cache, attention) are handled individually.
|
||||
///
|
||||
/// tokens: one token per sequence (len = batch_size)
|
||||
/// positions: position offset for each sequence (len = batch_size)
|
||||
/// caches: one mutable KV cache per sequence (len = batch_size)
|
||||
///
|
||||
/// Returns logits: [batch_size, vocab_size]
|
||||
pub fn forward_decode_batch(
|
||||
&self,
|
||||
tokens: &[u32],
|
||||
positions: &[usize],
|
||||
caches: &mut [&mut GpuKVCache],
|
||||
) -> Tensor {
|
||||
let batch = tokens.len();
|
||||
assert_eq!(positions.len(), batch);
|
||||
assert_eq!(caches.len(), batch);
|
||||
assert!(batch > 0);
|
||||
|
||||
let num_heads = self.config.num_heads();
|
||||
let num_kv_heads = self.config.num_kv_heads();
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
// Batched embedding: [B, hidden]
|
||||
let mut x = embedding(&self.embed_tokens, tokens);
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps); // [B, hidden]
|
||||
|
||||
// Batched projections: [B, hidden] × [hidden, X] = [B, X]
|
||||
let q_all = matmul_2d(&normed, &layer.q_proj_wt); // [B, num_heads*head_dim]
|
||||
let k_all = matmul_2d(&normed, &layer.k_proj_wt); // [B, num_kv_heads*head_dim]
|
||||
let v_all = matmul_2d(&normed, &layer.v_proj_wt); // [B, num_kv_heads*head_dim]
|
||||
|
||||
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
|
||||
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
|
||||
for b in 0..batch {
|
||||
// Extract row b: [1, X] — view into contiguous [B, X]
|
||||
let q_row = row_view(&q_all, b); // [1, num_heads*head_dim]
|
||||
let k_row = row_view(&k_all, b); // [1, num_kv_heads*head_dim]
|
||||
let v_row = row_view(&v_all, b); // [1, num_kv_heads*head_dim]
|
||||
|
||||
// GPU reshape: [1, H*D] → [1, H, 1, D]
|
||||
let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
|
||||
let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
|
||||
|
||||
// QK norm
|
||||
let q = head_rmsnorm(&q, &layer.q_norm, eps);
|
||||
let k = head_rmsnorm(&k, &layer.k_norm, eps);
|
||||
|
||||
// GPU transpose for RoPE: [1, H, 1, D] → [1, H, D]
|
||||
let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
// RoPE with per-sequence position
|
||||
let pos = [positions[b] as u32];
|
||||
rope_inplace(&q, &self.rope_cache, &pos);
|
||||
rope_inplace(&k, &self.rope_cache, &pos);
|
||||
|
||||
// Transpose back: [1, H, D] → [1, H, 1, D]
|
||||
let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
|
||||
let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
|
||||
|
||||
// KV cache: append and get full cache
|
||||
let pos_b = positions[b];
|
||||
caches[b].append(layer_idx, &k, &v, 1, pos_b);
|
||||
let (k_full, v_full) = caches[b].get_kv_len(layer_idx, pos_b + 1);
|
||||
|
||||
// Decode attention (uses native GQA, no repeat_kv needed)
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
|
||||
// Merge heads: [1, H, 1, D] → [1, hidden]
|
||||
let merged = xserv_kernels::merge_heads_gpu(&attn_out, 1, num_heads, head_dim);
|
||||
attn_outputs.push(merged);
|
||||
}
|
||||
|
||||
// Concat attention outputs: [B, hidden]
|
||||
let attn_merged = concat_rows(&attn_outputs);
|
||||
|
||||
// Batched O projection: [B, hidden] × [hidden, hidden] = [B, hidden]
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
// Fused add + rmsnorm
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
// Batched FFN: all projections on [B, hidden]
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
// Advance KV cache seq_len for each sequence
|
||||
for b in 0..batch {
|
||||
caches[b].advance_seq_len(1);
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
matmul_2d(&x, &self.lm_head_t) // [B, vocab_size]
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
@@ -317,6 +424,53 @@ fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
Tensor::from_slice(&out, &[1, new_heads, seq_len, head_dim]).to_device(x.device())
|
||||
}
|
||||
|
||||
/// Extract row `b` from a contiguous 2D tensor [B, cols] as a [1, cols] view.
|
||||
/// Zero-copy: shares storage with the original tensor.
|
||||
fn row_view(t: &Tensor, row: usize) -> Tensor {
|
||||
assert_eq!(t.ndim(), 2);
|
||||
assert!(t.is_contiguous());
|
||||
let cols = t.shape()[1];
|
||||
assert!(row < t.shape()[0]);
|
||||
let new_offset = t.offset() + row * cols;
|
||||
Tensor::from_storage(
|
||||
t.storage().clone(),
|
||||
smallvec::SmallVec::from_slice(&[1, cols]),
|
||||
xserv_tensor::shape::contiguous_strides(&[1, cols]),
|
||||
new_offset,
|
||||
t.dtype(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy.
|
||||
fn concat_rows(rows: &[Tensor]) -> Tensor {
|
||||
assert!(!rows.is_empty());
|
||||
let batch = rows.len();
|
||||
let cols = rows[0].shape()[1];
|
||||
let dtype = rows[0].dtype();
|
||||
let device = rows[0].device();
|
||||
let elem_size = dtype.size_bytes();
|
||||
let row_bytes = cols * elem_size;
|
||||
|
||||
// Allocate output [B, cols] and copy each row into it
|
||||
let total_bytes = batch * row_bytes;
|
||||
let mut out_buf = xserv_cuda::GpuBuffer::alloc(total_bytes).expect("alloc concat_rows");
|
||||
|
||||
for (b, row) in rows.iter().enumerate() {
|
||||
assert_eq!(row.shape(), &[1, cols]);
|
||||
assert!(row.is_contiguous());
|
||||
let src_buf = row.storage().gpu_buffer();
|
||||
let src_offset = row.offset() * elem_size;
|
||||
let dst_offset = b * row_bytes;
|
||||
out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).unwrap();
|
||||
}
|
||||
|
||||
// Wrap in a Tensor
|
||||
let device_id = match device { Device::Cuda(id) => id, _ => panic!("expected CUDA device") };
|
||||
unsafe {
|
||||
crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
|
||||
}
|
||||
}
|
||||
|
||||
fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
xserv_kernels::add(a, b)
|
||||
}
|
||||
|
||||
@@ -104,28 +104,78 @@ impl Engine {
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4b: Process decode (one token per sequence)
|
||||
// Currently per-sequence (each has different KV cache length).
|
||||
// TODO(Phase 14): With Flash Attention, batch all decode tokens into
|
||||
// one forward pass — batch the compute-heavy ops (projections, FFN)
|
||||
// and use FlashDecoding for per-seq variable-length attention.
|
||||
let decode_count = running.iter()
|
||||
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.count();
|
||||
if decode_count > 0 {
|
||||
// Step 4b: Batched decode — batch all decode-ready sequences into one forward pass.
|
||||
// Projections and FFN run as [B, hidden] matmuls; attention remains per-seq.
|
||||
let decode_indices: Vec<usize> = running.iter().enumerate()
|
||||
.filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
if !decode_indices.is_empty() {
|
||||
static LOG_ONCE: Once = Once::new();
|
||||
LOG_ONCE.call_once(|| {
|
||||
eprintln!("[scheduler] decode batching active (per-seq until Flash Attention)");
|
||||
eprintln!("[scheduler] batched decode active");
|
||||
});
|
||||
eprintln!("[scheduler] decode batch_size={}", decode_count);
|
||||
eprintln!("[scheduler] decode batch_size={}", decode_indices.len());
|
||||
|
||||
if decode_indices.len() == 1 {
|
||||
// Single sequence: use per-seq path (no batching overhead)
|
||||
let i = decode_indices[0];
|
||||
let last = *running[i].generated_tokens.last().unwrap();
|
||||
let logits = self.model.forward_gpu_cache(&[last], &mut running[i].kv_cache);
|
||||
let next = sample(&logits, &running[i].sampling);
|
||||
running[i].generated_tokens.push(next);
|
||||
self.emit_token(&running[i], next);
|
||||
} else {
|
||||
// Batched decode: extract tokens and positions
|
||||
let tokens: Vec<u32> = decode_indices.iter()
|
||||
.map(|&i| *running[i].generated_tokens.last().unwrap())
|
||||
.collect();
|
||||
let positions: Vec<usize> = decode_indices.iter()
|
||||
.map(|&i| running[i].kv_cache.seq_len())
|
||||
.collect();
|
||||
|
||||
// Take caches out of sequences temporarily to satisfy borrow checker.
|
||||
// The dummy caches (max_seq_len=1) are never used and dropped immediately
|
||||
// after the real caches are restored. Minor alloc overhead (~microseconds).
|
||||
let mut caches: Vec<GpuKVCache> = decode_indices.iter()
|
||||
.map(|&i| {
|
||||
std::mem::replace(
|
||||
&mut running[i].kv_cache,
|
||||
GpuKVCache::new(&self.config, 1, DType::BF16, 0),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let mut cache_refs: Vec<&mut GpuKVCache> = caches.iter_mut().collect();
|
||||
|
||||
let logits = self.model.forward_decode_batch(&tokens, &positions, &mut cache_refs);
|
||||
|
||||
// Put caches back: pop from end while iterating in reverse
|
||||
drop(cache_refs);
|
||||
for &i in decode_indices.iter().rev() {
|
||||
running[i].kv_cache = caches.pop().unwrap();
|
||||
}
|
||||
|
||||
// Sample per-sequence from batched logits [B, vocab_size]
|
||||
let vocab_size = logits.shape()[1];
|
||||
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
|
||||
let data = logits_cpu.as_slice::<half::bf16>();
|
||||
for (j, &i) in decode_indices.iter().enumerate() {
|
||||
let row_start = j * vocab_size;
|
||||
let row_logits = &data[row_start..row_start + vocab_size];
|
||||
let next = if running[i].sampling.temperature == 0.0 {
|
||||
// Greedy: argmax
|
||||
row_logits.iter().enumerate()
|
||||
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
|
||||
.map(|(idx, _)| idx as u32).unwrap()
|
||||
} else {
|
||||
// Use the row as a single-row tensor for full sampling
|
||||
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
|
||||
sample(&row_tensor, &running[i].sampling)
|
||||
};
|
||||
running[i].generated_tokens.push(next);
|
||||
self.emit_token(&running[i], next);
|
||||
}
|
||||
for seq in running.iter_mut() {
|
||||
if seq.prefilled && !newly_prefilled.contains(&seq.id) {
|
||||
let last = *seq.generated_tokens.last().unwrap();
|
||||
let logits = self.model.forward_gpu_cache(&[last], &mut seq.kv_cache);
|
||||
let next = sample(&logits, &seq.sampling);
|
||||
seq.generated_tokens.push(next);
|
||||
self.emit_token(seq, next);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user