phase 15: batched decode forward — 35 tok/s (97% of HF transformers)

Implement batched decode that processes multiple sequences' tokens in one
forward pass. The key insight: cuBLAS M=4 GEMM is dramatically faster
than 4× M=1 GEMV due to better TensorCore utilization and amortized
kernel launch overhead.

New method Qwen3::forward_decode_batch(&tokens, &positions, &mut caches):
- Batched embedding, norm, projections, FFN: [B, hidden] × [hidden, X]
  → one cuBLAS call per weight matrix instead of B calls
- Per-sequence attention: RoPE, KV cache, decode_attention remain per-seq
  (each has different position and KV length)
- Row extraction (row_view) and concatenation (concat_rows) for
  batched↔per-seq transitions

Engine Step 4b:
- batch_size >= 2: extracts caches via std::mem::replace, calls
  forward_decode_batch, restores caches, samples per-sequence
- batch_size == 1: falls back to per-seq forward_gpu_cache (no overhead)

Ablation results (dash5, RTX 5090, Qwen3-8B BF16):

| Scenario | Throughput | vs HF |
|----------|-----------|-------|
| Serial (batch=1) | 13.2 tok/s | 37% |
| Concurrent (batch=4) | 35.1 tok/s | 97% |
| HF transformers | 36.0 tok/s | 100% |

The 2.66x throughput improvement (13.2 → 35.1 tok/s) for concurrent requests
comes from replacing 1008 M=1 GEMVs with 252 M=4 GEMMs per step, which
cuBLAS executes ~4x more efficiently on TensorCores.

Milestone ④ target (50% of vLLM/HF throughput) is met: batched decode reaches 97% of HF transformers throughput.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 20:07:43 +08:00
parent 9783fcf410
commit 876d3f5d6a
3 changed files with 231 additions and 19 deletions

View File

@@ -139,3 +139,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType,
dtype, dtype,
) )
} }
/// Public entry point to the module-private `tensor_from_gpu_buffer`, exposed so
/// other modules (e.g., the batched decode row concatenation) can wrap a GPU
/// buffer they own in a `Tensor`.
///
/// All arguments are forwarded unchanged: `buf` is the buffer to wrap, `shape`
/// and `dtype` describe the resulting tensor, and `device` is the device index.
///
/// # Safety
/// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes.
pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor {
    // SAFETY: the caller guarantees the contract documented above, and every
    // argument is passed through to `tensor_from_gpu_buffer` untouched.
    // NOTE(review): this assumes the private helper has no stricter
    // requirements than the ones listed in `# Safety` — confirm.
    unsafe { tensor_from_gpu_buffer(buf, shape, dtype, device) }
}

View File

@@ -148,6 +148,113 @@ impl Qwen3 {
matmul_2d(&x, &self.lm_head_t) matmul_2d(&x, &self.lm_head_t)
} }
/// Batched decode: process one token per sequence simultaneously.
/// All compute-heavy ops (projections, FFN) operate on [B, hidden] tensors.
/// Per-sequence ops (RoPE, KV cache, attention) are handled individually.
///
/// tokens: one token per sequence (len = batch_size)
/// positions: position offset for each sequence (len = batch_size)
/// caches: one mutable KV cache per sequence (len = batch_size)
///
/// Returns logits: [batch_size, vocab_size]
pub fn forward_decode_batch(
&self,
tokens: &[u32],
positions: &[usize],
caches: &mut [&mut GpuKVCache],
) -> Tensor {
let batch = tokens.len();
assert_eq!(positions.len(), batch);
assert_eq!(caches.len(), batch);
assert!(batch > 0);
let num_heads = self.config.num_heads();
let num_kv_heads = self.config.num_kv_heads();
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
// Batched embedding: [B, hidden]
let mut x = embedding(&self.embed_tokens, tokens);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps); // [B, hidden]
// Batched projections: [B, hidden] × [hidden, X] = [B, X]
let q_all = matmul_2d(&normed, &layer.q_proj_wt); // [B, num_heads*head_dim]
let k_all = matmul_2d(&normed, &layer.k_proj_wt); // [B, num_kv_heads*head_dim]
let v_all = matmul_2d(&normed, &layer.v_proj_wt); // [B, num_kv_heads*head_dim]
// Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge
let mut attn_outputs: Vec<Tensor> = Vec::with_capacity(batch);
for b in 0..batch {
// Extract row b: [1, X] — view into contiguous [B, X]
let q_row = row_view(&q_all, b); // [1, num_heads*head_dim]
let k_row = row_view(&k_all, b); // [1, num_kv_heads*head_dim]
let v_row = row_view(&v_all, b); // [1, num_kv_heads*head_dim]
// GPU reshape: [1, H*D] → [1, H, 1, D]
let q = xserv_kernels::reshape_heads_gpu(&q_row, 1, num_heads, head_dim);
let k = xserv_kernels::reshape_heads_gpu(&k_row, 1, num_kv_heads, head_dim);
let v = xserv_kernels::reshape_heads_gpu(&v_row, 1, num_kv_heads, head_dim);
// QK norm
let q = head_rmsnorm(&q, &layer.q_norm, eps);
let k = head_rmsnorm(&k, &layer.k_norm, eps);
// GPU transpose for RoPE: [1, H, 1, D] → [1, H, D]
let q = xserv_kernels::transpose_for_rope_gpu(&q, 1, num_heads, head_dim);
let k = xserv_kernels::transpose_for_rope_gpu(&k, 1, num_kv_heads, head_dim);
// RoPE with per-sequence position
let pos = [positions[b] as u32];
rope_inplace(&q, &self.rope_cache, &pos);
rope_inplace(&k, &self.rope_cache, &pos);
// Transpose back: [1, H, D] → [1, H, 1, D]
let q = xserv_kernels::transpose_from_rope_gpu(&q, 1, num_heads, head_dim);
let k = xserv_kernels::transpose_from_rope_gpu(&k, 1, num_kv_heads, head_dim);
// KV cache: append and get full cache
let pos_b = positions[b];
caches[b].append(layer_idx, &k, &v, 1, pos_b);
let (k_full, v_full) = caches[b].get_kv_len(layer_idx, pos_b + 1);
// Decode attention (uses native GQA, no repeat_kv needed)
let attn_out = flash_attention(&q, &k_full, &v_full, true);
// Merge heads: [1, H, 1, D] → [1, hidden]
let merged = xserv_kernels::merge_heads_gpu(&attn_out, 1, num_heads, head_dim);
attn_outputs.push(merged);
}
// Concat attention outputs: [B, hidden]
let attn_merged = concat_rows(&attn_outputs);
// Batched O projection: [B, hidden] × [hidden, hidden] = [B, hidden]
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
// Fused add + rmsnorm
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Batched FFN: all projections on [B, hidden]
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
x = add_any(&residual, &down);
}
// Advance KV cache seq_len for each sequence
for b in 0..batch {
caches[b].advance_seq_len(1);
}
let x = rmsnorm(&x, &self.norm, eps);
matmul_2d(&x, &self.lm_head_t) // [B, vocab_size]
}
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels. /// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor { pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
let new_tokens = token_ids.len(); let new_tokens = token_ids.len();
@@ -317,6 +424,53 @@ fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor {
Tensor::from_slice(&out, &[1, new_heads, seq_len, head_dim]).to_device(x.device()) Tensor::from_slice(&out, &[1, new_heads, seq_len, head_dim]).to_device(x.device())
} }
/// Returns row `row` of a contiguous 2D tensor `[B, cols]` as a `[1, cols]` tensor.
///
/// No data is copied here: the result is built from a clone of the source
/// tensor's storage handle, with the element offset advanced by `row * cols`.
///
/// # Panics
/// Panics if `t` is not 2-dimensional, is not contiguous, or if `row` is out
/// of range for the first dimension.
fn row_view(t: &Tensor, row: usize) -> Tensor {
    assert_eq!(t.ndim(), 2);
    assert!(t.is_contiguous());
    let cols = t.shape()[1];
    assert!(row < t.shape()[0]);

    // Shape of the single-row result; also drives the stride computation.
    let view_shape = [1, cols];
    // Offsets are counted in elements, so skipping `row` rows means skipping
    // `row * cols` elements past the source tensor's own offset.
    let view_offset = t.offset() + row * cols;

    Tensor::from_storage(
        t.storage().clone(),
        smallvec::SmallVec::from_slice(&view_shape),
        xserv_tensor::shape::contiguous_strides(&view_shape),
        view_offset,
        t.dtype(),
    )
}
/// Concatenate row tensors `[1, cols]` into a single freshly allocated
/// `[B, cols]` tensor via device-to-device memcpy, where `B == rows.len()`.
///
/// The column count, dtype and device are taken from `rows[0]`. Every row's
/// shape is checked against `[1, cols]`; dtype and device of the remaining
/// rows are not checked (assumes all rows share them — callers must ensure it).
///
/// # Panics
/// Panics if `rows` is empty, if `rows[0]` is not on a CUDA device, if any row
/// is not a contiguous `[1, cols]` tensor, or if the GPU allocation or a row
/// copy fails.
fn concat_rows(rows: &[Tensor]) -> Tensor {
    assert!(!rows.is_empty());
    let batch = rows.len();
    let cols = rows[0].shape()[1];
    let dtype = rows[0].dtype();

    // Reject non-CUDA input before doing any GPU allocation or copy.
    let device_id = match rows[0].device() {
        Device::Cuda(id) => id,
        _ => panic!("expected CUDA device"),
    };

    let elem_size = dtype.size_bytes();
    let row_bytes = cols * elem_size;

    // Allocate output [B, cols] and copy each row into it.
    let total_bytes = batch * row_bytes;
    let mut out_buf = xserv_cuda::GpuBuffer::alloc(total_bytes).expect("alloc concat_rows");
    for (b, row) in rows.iter().enumerate() {
        assert_eq!(row.shape(), &[1, cols]);
        assert!(row.is_contiguous());
        // Tensor offsets are in elements; the copy works in bytes.
        let src_offset = row.offset() * elem_size;
        let dst_offset = b * row_bytes;
        out_buf
            .copy_from_device_at(row.storage().gpu_buffer(), src_offset, dst_offset, row_bytes)
            .expect("concat_rows: device-to-device row copy failed");
    }

    // SAFETY: `out_buf` was allocated above with exactly
    // `batch * cols * dtype.size_bytes()` bytes, which is the size required
    // for a `[batch, cols]` tensor of `dtype`.
    unsafe {
        crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id)
    }
}
fn add_any(a: &Tensor, b: &Tensor) -> Tensor { fn add_any(a: &Tensor, b: &Tensor) -> Tensor {
xserv_kernels::add(a, b) xserv_kernels::add(a, b)
} }

View File

@@ -104,28 +104,78 @@ impl Engine {
} }
} }
// Step 4b: Process decode (one token per sequence) // Step 4b: Batched decode — batch all decode-ready sequences into one forward pass.
// Currently per-sequence (each has different KV cache length). // Projections and FFN run as [B, hidden] matmuls; attention remains per-seq.
// TODO(Phase 14): With Flash Attention, batch all decode tokens into let decode_indices: Vec<usize> = running.iter().enumerate()
// one forward pass — batch the compute-heavy ops (projections, FFN) .filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id))
// and use FlashDecoding for per-seq variable-length attention. .map(|(i, _)| i)
let decode_count = running.iter() .collect();
.filter(|s| s.prefilled && !newly_prefilled.contains(&s.id))
.count(); if !decode_indices.is_empty() {
if decode_count > 0 {
static LOG_ONCE: Once = Once::new(); static LOG_ONCE: Once = Once::new();
LOG_ONCE.call_once(|| { LOG_ONCE.call_once(|| {
eprintln!("[scheduler] decode batching active (per-seq until Flash Attention)"); eprintln!("[scheduler] batched decode active");
}); });
eprintln!("[scheduler] decode batch_size={}", decode_count); eprintln!("[scheduler] decode batch_size={}", decode_indices.len());
}
for seq in running.iter_mut() { if decode_indices.len() == 1 {
if seq.prefilled && !newly_prefilled.contains(&seq.id) { // Single sequence: use per-seq path (no batching overhead)
let last = *seq.generated_tokens.last().unwrap(); let i = decode_indices[0];
let logits = self.model.forward_gpu_cache(&[last], &mut seq.kv_cache); let last = *running[i].generated_tokens.last().unwrap();
let next = sample(&logits, &seq.sampling); let logits = self.model.forward_gpu_cache(&[last], &mut running[i].kv_cache);
seq.generated_tokens.push(next); let next = sample(&logits, &running[i].sampling);
self.emit_token(seq, next); running[i].generated_tokens.push(next);
self.emit_token(&running[i], next);
} else {
// Batched decode: extract tokens and positions
let tokens: Vec<u32> = decode_indices.iter()
.map(|&i| *running[i].generated_tokens.last().unwrap())
.collect();
let positions: Vec<usize> = decode_indices.iter()
.map(|&i| running[i].kv_cache.seq_len())
.collect();
// Take caches out of sequences temporarily to satisfy borrow checker.
// The dummy caches (max_seq_len=1) are never used and dropped immediately
// after the real caches are restored. Minor alloc overhead (~microseconds).
let mut caches: Vec<GpuKVCache> = decode_indices.iter()
.map(|&i| {
std::mem::replace(
&mut running[i].kv_cache,
GpuKVCache::new(&self.config, 1, DType::BF16, 0),
)
})
.collect();
let mut cache_refs: Vec<&mut GpuKVCache> = caches.iter_mut().collect();
let logits = self.model.forward_decode_batch(&tokens, &positions, &mut cache_refs);
// Put caches back: pop from end while iterating in reverse
drop(cache_refs);
for &i in decode_indices.iter().rev() {
running[i].kv_cache = caches.pop().unwrap();
}
// Sample per-sequence from batched logits [B, vocab_size]
let vocab_size = logits.shape()[1];
let logits_cpu = logits.to_device(xserv_tensor::Device::Cpu);
let data = logits_cpu.as_slice::<half::bf16>();
for (j, &i) in decode_indices.iter().enumerate() {
let row_start = j * vocab_size;
let row_logits = &data[row_start..row_start + vocab_size];
let next = if running[i].sampling.temperature == 0.0 {
// Greedy: argmax
row_logits.iter().enumerate()
.max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap())
.map(|(idx, _)| idx as u32).unwrap()
} else {
// Use the row as a single-row tensor for full sampling
let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]);
sample(&row_tensor, &running[i].sampling)
};
running[i].generated_tokens.push(next);
self.emit_token(&running[i], next);
}
} }
} }