Files
xserv/crates/xserv-kernels/src/attention.rs
Gahow Wang 9ad91a4a92 phase19: MoE support — gpt-oss-20b end-to-end inference with TP=2
Add Mixture-of-Experts support for the gpt-oss-20b model (20.9B params,
32 experts × top-4 routing). Key additions:

- ModelConfig: MoE fields (num_local_experts, layer_types, sliding_window,
  attention_bias, explicit head_dim, rope_scaling, swiglu_limit)
- YaRN RoPE: RopeCache::new_yarn() with correct frequency interpolation
  and attention_scaling = 0.1*ln(factor)+1
- Custom GLU kernel: gpt_oss_glu_bf16 (clamped sigmoid gate activation)
- Paged attention with sinks + sliding window kernel variant
- GptOss model struct with expert-parallel TP (split 32 experts across ranks)
- bench-gpt-oss binary for TP inference benchmarking

Verified on dash5 with 2x RTX 5090: 63.6 tok/s decode, ~160ms TTFT.
Model generates topically-coherent output (needs chat template for quality).

Known issues:
- Custom GEMV kernel produces NaN with small N (workaround: pad to M=2)
- Prefill doesn't use attention sinks (uses standard flash attention)
- Output quality requires chat template formatting

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-30 15:18:01 +08:00

407 lines
13 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::ffi::c_void;
use xserv_tensor::{DType, Tensor};
use crate::activation::scale;
use crate::gemm::batched_matmul;
use crate::softmax::softmax;
// Raw FFI declarations for the native kernel launchers wrapped by this module.
// The definitions live outside this file, so every signature below has to stay
// in lockstep with the native side (argument order, integer widths, pointer
// mutability). All extents cross the boundary as `i32`; the Rust wrappers in
// this file narrow from `usize` with `as`.
//
// `stream`: the safe wrappers in this file always pass a null pointer, while
// the two `reshape_and_cache*` wrappers forward whatever the caller supplies.
// NOTE(review): presumably these are CUDA launchers and `stream` is a
// `cudaStream_t` (null = default stream) — confirm against the kernel sources.
unsafe extern "C" {
// Causal-mask launchers, one per score dtype. `apply_causal_mask` passes the
// product of the leading dims as `batch`, the last two dims as `rows`/`cols`,
// and the caller's `offset` (`kv_len - q_len` at the only call site). The
// pointer is mutable and the caller keeps using the same tensor afterwards,
// i.e. the mask is applied in place.
fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
// Launcher behind `flash_attention` for the `q_len != 1` path. `causal` is
// passed as 0/1 and `scale` as `1 / sqrt(head_dim)`.
fn launch_flash_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
);
// Launcher behind `decode_attention`. There is no `q_len` argument: the
// wrapper asserts `q_len == 1`, and it always passes `causal = 1`.
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
);
// Launcher behind `paged_decode_attention`. Per the wrapper's docs,
// `block_tables` is int32 `[batch, max_blocks_per_seq]` and `context_lens`
// is int32 `[batch]`; K/V are read from paged pools rather than dense tensors.
fn launch_paged_decode_attention_bf16(
q: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
o: *mut c_void,
block_tables: *const i32,
context_lens: *const i32,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
head_dim: i32, max_blocks_per_seq: i32,
scale: f32, stream: *mut c_void,
);
// Launcher behind `paged_decode_attention_sinks`: same arguments as the
// plain paged launcher plus `sinks` (documented by the wrapper as a
// `[num_q_heads]` BF16 buffer, or null for no sinks) and `window_size`
// (documented by the wrapper as 0 = full attention, >0 = sliding window).
fn launch_paged_decode_attention_sinks_bf16(
q: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
o: *mut c_void,
block_tables: *const i32,
context_lens: *const i32,
sinks: *const c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
head_dim: i32, max_blocks_per_seq: i32,
scale: f32, window_size: i32, stream: *mut c_void,
);
// Launcher behind `reshape_and_cache_bf16` (single sequence). `block_ids` is
// untyped here; the wrapper receives it as `*const i32` and casts.
fn launch_reshape_and_cache_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool: *mut c_void, v_pool: *mut c_void,
block_ids: *const c_void,
num_tokens: i32, num_heads: i32,
head_dim: i32, start_pos: i32, block_size: i32,
stream: *mut c_void,
);
// Launcher behind `reshape_and_cache_batched_bf16` (multi-sequence decode).
// `block_tables` and `kv_lens` are untyped here; the wrapper receives both as
// `*const i32` and casts.
fn launch_reshape_and_cache_batched_bf16(
k_src: *const c_void, v_src: *const c_void,
k_pool: *mut c_void, v_pool: *mut c_void,
block_tables: *const c_void, kv_lens: *const c_void,
batch: i32, num_heads: i32,
head_dim: i32, block_size: i32, max_blocks_per_seq: i32,
stream: *mut c_void,
);
}
/// Copies one sequence's new BF16 K/V rows into its paged cache blocks.
///
/// * `k_src` / `v_src` — source tensors laid out as
///   `[num_kv_heads, num_tokens, head_dim]`.
/// * `k_pool_ptr` / `v_pool_ptr` — a single layer's pool, logically
///   `[num_blocks_total, num_kv_heads, block_size, head_dim]`.
/// * `block_ids_gpu` — the sequence's block table (int32, resident on device).
/// * `start_pos` — forwarded to the kernel together with `num_tokens` and
///   `block_size`; see the safety contract below for how they bound the table.
/// * `stream` — forwarded untouched to the launcher.
///
/// Every extent is narrowed to `i32` with `as` before crossing the FFI
/// boundary, so values above `i32::MAX` are not supported.
///
/// # Safety
/// All pointers must be valid GPU pointers with the layouts listed above and
/// must live on the same GPU as the launching context. `block_ids_gpu` must
/// hold at least `(start_pos + num_tokens + block_size - 1) / block_size`
/// valid physical block ids.
pub unsafe fn reshape_and_cache_bf16(
    k_src: *const c_void,
    v_src: *const c_void,
    k_pool_ptr: *mut c_void,
    v_pool_ptr: *mut c_void,
    block_ids_gpu: *const i32,
    num_tokens: usize,
    num_heads: usize,
    head_dim: usize,
    start_pos: usize,
    block_size: usize,
    stream: *mut c_void,
) {
    // The launcher takes the block table untyped and all extents as `i32`.
    let block_ids = block_ids_gpu.cast::<c_void>();
    let (tokens, heads, dim) = (num_tokens as i32, num_heads as i32, head_dim as i32);
    let (first_pos, blk) = (start_pos as i32, block_size as i32);

    // SAFETY: the caller guarantees the pointer/layout contract documented on
    // this function; nothing is dereferenced on the Rust side.
    unsafe {
        launch_reshape_and_cache_bf16(
            k_src, v_src, k_pool_ptr, v_pool_ptr, block_ids, tokens, heads, dim, first_pos, blk,
            stream,
        );
    }
}
/// Multi-sequence variant of the K/V scatter, used on the decode step.
///
/// Rather than a per-sequence list of block ids, the kernel is given:
///
/// * `block_tables_gpu` — `[batch, max_blocks_per_seq]` int32, the same buffer
///   the paged-attention kernel reads;
/// * `kv_lens_gpu` — `[batch]` int32, current seq_len + 1 (i.e. the index of
///   the just-written token + 1).
///
/// This spares the caller a separate per-step upload of block ids. Extents are
/// narrowed to `i32` with `as` before the FFI call; `stream` is forwarded
/// untouched.
///
/// # Safety
/// All pointers must be on the same GPU. `block_tables_gpu` and `kv_lens_gpu`
/// must already be synced to the device for the active batch.
pub unsafe fn reshape_and_cache_batched_bf16(
    k_src: *const c_void,
    v_src: *const c_void,
    k_pool_ptr: *mut c_void,
    v_pool_ptr: *mut c_void,
    block_tables_gpu: *const i32,
    kv_lens_gpu: *const i32,
    batch: usize,
    num_heads: usize,
    head_dim: usize,
    block_size: usize,
    max_blocks_per_seq: usize,
    stream: *mut c_void,
) {
    // The launcher takes both index buffers untyped and all extents as `i32`.
    let tables = block_tables_gpu.cast::<c_void>();
    let lens = kv_lens_gpu.cast::<c_void>();
    let (n_seqs, heads, dim) = (batch as i32, num_heads as i32, head_dim as i32);
    let (blk, table_width) = (block_size as i32, max_blocks_per_seq as i32);

    // SAFETY: the caller guarantees the pointer contract documented on this
    // function; nothing is dereferenced on the Rust side.
    unsafe {
        launch_reshape_and_cache_batched_bf16(
            k_src, v_src, k_pool_ptr, v_pool_ptr, tables, lens, n_seqs, heads, dim, blk,
            table_width, stream,
        );
    }
}
fn apply_causal_mask(scores: &Tensor, offset: usize) {
let ndim = scores.ndim();
let rows = scores.shape()[ndim - 2];
let cols = scores.shape()[ndim - 1];
let batch: usize = scores.shape()[..ndim - 2].iter().product();
unsafe {
match scores.dtype() {
DType::F32 => launch_causal_mask_f32(
scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32,
std::ptr::null_mut(),
),
DType::BF16 => launch_causal_mask_bf16(
scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32,
std::ptr::null_mut(),
),
_ => panic!("unsupported dtype for causal mask"),
}
}
}
/// Multi-head attention (naive, materializes the `q_len x kv_len` score matrix).
///
/// Pipeline: `scores = Q @ K^T`, scale by `1/sqrt(head_dim)`, optional causal
/// mask, softmax, then `weights @ V`.
///
/// * `q`: `[batch, num_heads, q_len, head_dim]` — contiguous, on GPU
/// * `k`, `v`: `[batch, num_heads, kv_len, head_dim]` — contiguous, on GPU
/// * `causal`: when set, the scores are masked with offset `kv_len - q_len`
///   (presumably aligning the queries with the last `q_len` key positions —
///   confirm against the mask kernel).
///
/// Returns: `[batch, num_heads, q_len, head_dim]`.
///
/// # Panics
/// Panics if any input is not 4-D or not contiguous, if `k`/`v` do not match
/// `q` in batch, head count and head_dim, or if `causal` is set while
/// `q_len > kv_len`.
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
    assert_eq!(q.ndim(), 4);
    assert_eq!(k.ndim(), 4);
    assert_eq!(v.ndim(), 4);
    assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());

    let batch = q.shape()[0];
    let num_heads = q.shape()[1];
    let q_len = q.shape()[2];
    let head_dim = q.shape()[3];
    let kv_len = k.shape()[2];
    assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]);
    assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]);

    // The causal offset below is `kv_len - q_len`. Without this guard a longer
    // query would overflow-panic in debug builds and silently wrap to a bogus
    // offset in release builds, so reject it up front before any GPU work.
    if causal {
        assert!(
            kv_len >= q_len,
            "causal attention requires kv_len >= q_len (got kv_len={kv_len}, q_len={q_len})"
        );
    }

    // scores = Q @ K^T → [B, H, q_len, kv_len]
    let k_t = k.transpose(2, 3).contiguous();
    let scores = batched_matmul(q, &k_t);

    // Scale by 1/sqrt(head_dim)
    let scale_factor = 1.0 / (head_dim as f32).sqrt();
    let scaled_scores = scale(&scores, scale_factor);

    // Causal mask (applied in place on the scaled scores)
    if causal {
        let offset = kv_len - q_len; // cannot underflow: checked above
        apply_causal_mask(&scaled_scores, offset);
    }

    // Softmax
    let weights = softmax(&scaled_scores);

    // output = weights @ V → [B, H, q_len, head_dim]
    batched_matmul(&weights, v)
}
/// Decode Attention — optimized for single-token decode (q_len=1).
///
/// * `q`: `[batch, num_q_heads, 1, head_dim]` BF16, contiguous, GPU
/// * `k`: `[batch, num_kv_heads, kv_len, head_dim]` BF16, contiguous, GPU
/// * `v`: `[batch, num_kv_heads, kv_len, head_dim]` BF16, contiguous, GPU
///
/// The softmax scale is `1/sqrt(head_dim)` and the kernel is always launched
/// with `causal = 1`.
///
/// Returns: `[batch, num_q_heads, 1, head_dim]` BF16
///
/// # Panics
/// Panics if any input is not 4-D, not contiguous or not BF16, if
/// `q_len != 1`, if `k`/`v` disagree with `q` on batch or head_dim or with
/// each other on shape, or if `num_q_heads` is not a multiple of
/// `num_kv_heads`. The kernel receives raw pointers plus these extents, so a
/// mismatch would otherwise make it read the wrong memory.
pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
    assert_eq!(q.ndim(), 4);
    assert_eq!(q.shape()[2], 1, "decode_attention requires q_len == 1");
    // This is a public entry point, so it validates its own inputs rather than
    // relying on `flash_attention` having checked them before dispatching.
    assert_eq!(k.ndim(), 4, "decode_attention requires a 4-D k");
    assert_eq!(v.ndim(), 4, "decode_attention requires a 4-D v");
    assert!(
        q.is_contiguous() && k.is_contiguous() && v.is_contiguous(),
        "decode_attention requires contiguous q, k and v"
    );
    assert_eq!(q.dtype(), DType::BF16, "decode_attention requires BF16 q");
    assert_eq!(k.dtype(), DType::BF16, "decode_attention requires BF16 k");
    assert_eq!(v.dtype(), DType::BF16, "decode_attention requires BF16 v");

    let batch = q.shape()[0];
    let num_q_heads = q.shape()[1];
    let head_dim = q.shape()[3];
    let num_kv_heads = k.shape()[1];
    let kv_len = k.shape()[2];

    assert_eq!(
        k.shape(),
        &[batch, num_kv_heads, kv_len, head_dim],
        "k must be [batch, num_kv_heads, kv_len, head_dim]"
    );
    assert_eq!(
        v.shape(),
        &[batch, num_kv_heads, kv_len, head_dim],
        "v must have the same shape as k"
    );
    assert!(
        num_kv_heads > 0 && num_q_heads % num_kv_heads == 0,
        "num_q_heads must be divisible by num_kv_heads"
    );

    // Named `softmax_scale` so it does not shadow the imported `scale` function.
    let softmax_scale = 1.0 / (head_dim as f32).sqrt();
    let output = Tensor::empty(
        &[batch, num_q_heads, 1, head_dim],
        DType::BF16,
        q.device(),
    );

    // SAFETY: all four buffers come from live tensors whose rank, dtype,
    // contiguity and shapes were checked above (or, for `output`, allocated
    // with the matching shape), and the extents passed are those same values.
    unsafe {
        launch_decode_attention_bf16(
            q.data_ptr() as *const c_void,
            k.data_ptr() as *const c_void,
            v.data_ptr() as *const c_void,
            output.data_ptr() as *mut c_void,
            batch as i32,
            num_q_heads as i32,
            num_kv_heads as i32,
            kv_len as i32,
            head_dim as i32,
            softmax_scale,
            1, // causal (always 1 for decode)
            std::ptr::null_mut(),
        );
    }
    output
}
/// Flash Attention 2 entry point — O(1) extra memory, supports GQA natively.
///
/// When `q_len == 1` the call is handed to [`decode_attention`] instead; on
/// that path the `causal` argument is not forwarded (the decode wrapper
/// always launches with causal = 1).
///
/// * `q`: `[batch, num_q_heads, q_len, head_dim]` BF16, contiguous, GPU
/// * `k`: `[batch, num_kv_heads, kv_len, head_dim]` BF16, contiguous, GPU
/// * `v`: `[batch, num_kv_heads, kv_len, head_dim]` BF16, contiguous, GPU
///
/// Returns: `[batch, num_q_heads, q_len, head_dim]` BF16
///
/// # Panics
/// Panics if an input is not 4-D, not contiguous or not BF16, if `k`/`v` have
/// inconsistent shapes, if `num_q_heads` is not a multiple of `num_kv_heads`,
/// or if `head_dim > 128`.
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
    // Rank, layout and dtype preconditions.
    assert_eq!(q.ndim(), 4);
    assert_eq!(k.ndim(), 4);
    assert_eq!(v.ndim(), 4);
    assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
    assert_eq!(q.dtype(), DType::BF16, "flash_attention requires BF16");
    assert_eq!(k.dtype(), DType::BF16);
    assert_eq!(v.dtype(), DType::BF16);

    // Query extents come from `q`; KV head count and length come from `k`.
    let q_dims = q.shape();
    let (batch, num_q_heads, q_len, head_dim) = (q_dims[0], q_dims[1], q_dims[2], q_dims[3]);
    let (num_kv_heads, kv_len) = (k.shape()[1], k.shape()[2]);

    let expected_kv = [batch, num_kv_heads, kv_len, head_dim];
    assert_eq!(k.shape(), &expected_kv);
    assert_eq!(v.shape(), &expected_kv);
    assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
    assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");

    // Single-token generation goes to the specialized decode kernel.
    if q_len == 1 {
        return decode_attention(q, k, v);
    }

    let softmax_scale = 1.0 / (head_dim as f32).sqrt();
    let out = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());

    let q_ptr = q.data_ptr() as *const c_void;
    let k_ptr = k.data_ptr() as *const c_void;
    let v_ptr = v.data_ptr() as *const c_void;
    let out_ptr = out.data_ptr() as *mut c_void;

    // SAFETY: the inputs were validated above and `out` was allocated with the
    // shape the kernel is told to write.
    unsafe {
        launch_flash_attention_bf16(
            q_ptr,
            k_ptr,
            v_ptr,
            out_ptr,
            batch as i32,
            num_q_heads as i32,
            num_kv_heads as i32,
            q_len as i32,
            kv_len as i32,
            head_dim as i32,
            softmax_scale,
            i32::from(causal),
            std::ptr::null_mut(),
        );
    }
    out
}
/// Single-token decode attention that reads K/V from a paged cache.
///
/// * `q`: `[batch, num_q_heads, 1, head_dim]` BF16, contiguous, GPU
/// * `k_cache_ptr` / `v_cache_ptr`: pointers to
///   `[num_blocks, num_kv_heads, BLOCK_SIZE, head_dim]` BF16 pools
/// * `block_tables_ptr`: i32 `[batch, max_blocks_per_seq]` (rows already
///   arranged for this batch)
/// * `context_lens_ptr`: i32 `[batch]`
///
/// The output is allocated from the `batch`, `num_q_heads` and `head_dim`
/// arguments, not from `q`'s shape, and the softmax scale is
/// `1/sqrt(head_dim)`.
///
/// Returns: `[batch, num_q_heads, 1, head_dim]` BF16
///
/// NOTE(review): this is a safe `fn`, yet the raw pointers are forwarded to
/// the kernel without any validation — callers must pass valid device
/// pointers with the layouts above.
///
/// # Panics
/// Panics if `q` is not 4-D BF16 with `q_len == 1`, if `num_q_heads` is not a
/// multiple of `num_kv_heads`, or if `head_dim > 128`.
#[allow(clippy::too_many_arguments)]
pub fn paged_decode_attention(
    q: &Tensor,
    k_cache_ptr: *const c_void,
    v_cache_ptr: *const c_void,
    block_tables_ptr: *const i32,
    context_lens_ptr: *const i32,
    batch: usize,
    num_q_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_blocks_per_seq: usize,
) -> Tensor {
    assert_eq!(q.ndim(), 4);
    assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
    assert_eq!(q.dtype(), DType::BF16);
    assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
    assert!(head_dim <= 128);

    let softmax_scale = 1.0 / (head_dim as f32).sqrt();
    let out = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());

    let q_ptr = q.data_ptr() as *const c_void;
    let out_ptr = out.data_ptr() as *mut c_void;
    let (n_seqs, q_heads, kv_heads) = (batch as i32, num_q_heads as i32, num_kv_heads as i32);
    let (dim, table_width) = (head_dim as i32, max_blocks_per_seq as i32);

    // SAFETY: `q_ptr`/`out_ptr` come from live tensors; the cache and table
    // pointers are taken on trust from the caller (see the review note above).
    unsafe {
        launch_paged_decode_attention_bf16(
            q_ptr,
            k_cache_ptr,
            v_cache_ptr,
            out_ptr,
            block_tables_ptr,
            context_lens_ptr,
            n_seqs,
            q_heads,
            kv_heads,
            dim,
            table_width,
            softmax_scale,
            std::ptr::null_mut(),
        );
    }
    out
}
/// Paged single-token decode attention with attention sinks and an optional
/// sliding window.
///
/// * `q`: must be 4-D BF16 with `q_len == 1` (asserted below).
/// * `k_cache_ptr` / `v_cache_ptr`, `block_tables_ptr`, `context_lens_ptr`:
///   forwarded unchanged to the kernel.
/// * `sinks_ptr`: pointer to `[num_q_heads]` BF16 on GPU (or null for no sinks)
/// * `window_size`: 0 = full attention, >0 = sliding window
///
/// The output is allocated as `[batch, num_q_heads, 1, head_dim]` BF16 from
/// the scalar arguments, and the softmax scale is `1/sqrt(head_dim)`.
///
/// NOTE(review): this is a safe `fn`, yet the raw pointers are forwarded to
/// the kernel without any validation — callers must pass valid device
/// pointers.
///
/// # Panics
/// Panics if `q` is not 4-D BF16 with `q_len == 1`, if `num_q_heads` is not a
/// multiple of `num_kv_heads`, or if `head_dim > 128`.
#[allow(clippy::too_many_arguments)]
pub fn paged_decode_attention_sinks(
    q: &Tensor,
    k_cache_ptr: *const c_void,
    v_cache_ptr: *const c_void,
    block_tables_ptr: *const i32,
    context_lens_ptr: *const i32,
    sinks_ptr: *const c_void,
    batch: usize,
    num_q_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    max_blocks_per_seq: usize,
    window_size: usize,
) -> Tensor {
    assert_eq!(q.ndim(), 4);
    assert_eq!(q.shape()[2], 1);
    assert_eq!(q.dtype(), DType::BF16);
    assert!(num_q_heads % num_kv_heads == 0);
    assert!(head_dim <= 128);

    let softmax_scale = 1.0 / (head_dim as f32).sqrt();
    let out = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());

    let q_ptr = q.data_ptr() as *const c_void;
    let out_ptr = out.data_ptr() as *mut c_void;
    let (n_seqs, q_heads, kv_heads) = (batch as i32, num_q_heads as i32, num_kv_heads as i32);
    let (dim, table_width) = (head_dim as i32, max_blocks_per_seq as i32);
    let window = window_size as i32;

    // SAFETY: `q_ptr`/`out_ptr` come from live tensors; the cache, table and
    // sink pointers are taken on trust from the caller (see the review note).
    unsafe {
        launch_paged_decode_attention_sinks_bf16(
            q_ptr,
            k_cache_ptr,
            v_cache_ptr,
            out_ptr,
            block_tables_ptr,
            context_lens_ptr,
            sinks_ptr,
            n_seqs,
            q_heads,
            kv_heads,
            dim,
            table_width,
            softmax_scale,
            window,
            std::ptr::null_mut(),
        );
    }
    out
}