Add Mixture-of-Experts support for the gpt-oss-20b model (20.9B params, 32 experts × top-4 routing).

Key additions:
- ModelConfig: MoE fields (num_local_experts, layer_types, sliding_window, attention_bias, explicit head_dim, rope_scaling, swiglu_limit)
- YaRN RoPE: RopeCache::new_yarn() with correct frequency interpolation and attention_scaling = 0.1 * ln(factor) + 1
- Custom GLU kernel: gpt_oss_glu_bf16 (clamped sigmoid gate activation)
- Paged attention with sinks + sliding-window kernel variant
- GptOss model struct with expert-parallel TP (the 32 experts are split across ranks)
- bench-gpt-oss binary for TP inference benchmarking

Verified on dash5 with 2x RTX 5090: 63.6 tok/s decode, ~160 ms TTFT. The model generates topically coherent output (it needs a chat template for good quality).

Known issues:
- Custom GEMV kernel produces NaN with small N (workaround: pad to M=2)
- Prefill doesn't use attention sinks (it uses standard flash attention)
- Output quality requires chat-template formatting

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
407 lines
13 KiB
Rust
407 lines
13 KiB
Rust
use std::ffi::c_void;
|
||
use xserv_tensor::{DType, Tensor};
|
||
|
||
use crate::activation::scale;
|
||
use crate::gemm::batched_matmul;
|
||
use crate::softmax::softmax;
|
||
|
||
unsafe extern "C" {
    // ---------------------------------------------------------------
    // Causal masking (used by the naive `attention` path).
    // ---------------------------------------------------------------

    /// Applies the causal-mask kernel in place to `batch` stacked
    /// `rows x cols` F32 score matrices. `offset` is `kv_len - q_len`
    /// at the call site in this file.
    fn launch_causal_mask_f32(
        scores: *mut c_void,
        batch: i32,
        rows: i32,
        cols: i32,
        offset: i32,
        stream: *mut c_void,
    );

    /// BF16 variant of [`launch_causal_mask_f32`].
    fn launch_causal_mask_bf16(
        scores: *mut c_void,
        batch: i32,
        rows: i32,
        cols: i32,
        offset: i32,
        stream: *mut c_void,
    );

    // ---------------------------------------------------------------
    // Dense attention over contiguous BF16 Q/K/V tensors.
    // ---------------------------------------------------------------

    /// Flash attention for `q_len > 1`; `causal` is passed as 0 or 1.
    fn launch_flash_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        q_len: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );

    /// Single-query (`q_len == 1`) attention over dense K/V.
    fn launch_decode_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );

    // ---------------------------------------------------------------
    // Paged attention: K/V live in a block pool addressed by
    // `block_tables` ([batch, max_blocks_per_seq] int32) and
    // `context_lens` ([batch] int32).
    // ---------------------------------------------------------------

    /// Single-query attention reading K/V from the paged pool.
    fn launch_paged_decode_attention_bf16(
        q: *const c_void,
        k_cache: *const c_void,
        v_cache: *const c_void,
        o: *mut c_void,
        block_tables: *const i32,
        context_lens: *const i32,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        head_dim: i32,
        max_blocks_per_seq: i32,
        scale: f32,
        stream: *mut c_void,
    );

    /// Paged decode attention with per-head sink values (`sinks`, may be
    /// null) and a sliding window (`window_size == 0` means full attention).
    fn launch_paged_decode_attention_sinks_bf16(
        q: *const c_void,
        k_cache: *const c_void,
        v_cache: *const c_void,
        o: *mut c_void,
        block_tables: *const i32,
        context_lens: *const i32,
        sinks: *const c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        head_dim: i32,
        max_blocks_per_seq: i32,
        scale: f32,
        window_size: i32,
        stream: *mut c_void,
    );

    // ---------------------------------------------------------------
    // KV-cache writes into the paged pool.
    // ---------------------------------------------------------------

    /// Scatters one sequence's K/V into the pool using its block ids.
    fn launch_reshape_and_cache_bf16(
        k_src: *const c_void,
        v_src: *const c_void,
        k_pool: *mut c_void,
        v_pool: *mut c_void,
        block_ids: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        start_pos: i32,
        block_size: i32,
        stream: *mut c_void,
    );

    /// Batched scatter for a multi-sequence decode step, driven by the same
    /// block tables the paged-attention kernels read.
    fn launch_reshape_and_cache_batched_bf16(
        k_src: *const c_void,
        v_src: *const c_void,
        k_pool: *mut c_void,
        v_pool: *mut c_void,
        block_tables: *const c_void,
        kv_lens: *const c_void,
        batch: i32,
        num_heads: i32,
        head_dim: i32,
        block_size: i32,
        max_blocks_per_seq: i32,
        stream: *mut c_void,
    );
}
|
||
|
||
/// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged
|
||
/// pool for a single sequence whose block table lives at `block_ids_gpu`
|
||
/// (int32, on device).
|
||
///
|
||
/// `k_pool_ptr`/`v_pool_ptr` point to one layer's pool, of logical shape
|
||
/// `[num_blocks_total, num_kv_heads, block_size, head_dim]`.
|
||
///
|
||
/// All pointers must be on the same GPU as the launching context.
|
||
///
|
||
/// # Safety
|
||
/// Pointers must be valid GPU pointers with the documented layouts.
|
||
/// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size`
|
||
/// valid physical block ids.
|
||
pub unsafe fn reshape_and_cache_bf16(
|
||
k_src: *const c_void, v_src: *const c_void,
|
||
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
|
||
block_ids_gpu: *const i32,
|
||
num_tokens: usize, num_heads: usize,
|
||
head_dim: usize, start_pos: usize, block_size: usize,
|
||
stream: *mut c_void,
|
||
) {
|
||
unsafe {
|
||
launch_reshape_and_cache_bf16(
|
||
k_src, v_src,
|
||
k_pool_ptr, v_pool_ptr,
|
||
block_ids_gpu as *const c_void,
|
||
num_tokens as i32, num_heads as i32,
|
||
head_dim as i32, start_pos as i32, block_size as i32,
|
||
stream,
|
||
);
|
||
}
|
||
}
|
||
|
||
/// Batched scatter for the multi-sequence decode step. Reads
|
||
/// `block_tables` (`[batch, max_blocks_per_seq]` int32 — same buffer the
|
||
/// paged-attention kernel reads) and `kv_lens` (`[batch]` int32, current
|
||
/// seq_len + 1 — i.e., the index of the just-written token + 1) so the
|
||
/// caller doesn't need a separate per-step upload of block ids.
|
||
///
|
||
/// # Safety
|
||
/// All pointers must be on the same GPU. `block_tables` and `kv_lens` must
|
||
/// already be synced to the device for the active batch.
|
||
pub unsafe fn reshape_and_cache_batched_bf16(
|
||
k_src: *const c_void, v_src: *const c_void,
|
||
k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void,
|
||
block_tables_gpu: *const i32, kv_lens_gpu: *const i32,
|
||
batch: usize, num_heads: usize,
|
||
head_dim: usize, block_size: usize, max_blocks_per_seq: usize,
|
||
stream: *mut c_void,
|
||
) {
|
||
unsafe {
|
||
launch_reshape_and_cache_batched_bf16(
|
||
k_src, v_src,
|
||
k_pool_ptr, v_pool_ptr,
|
||
block_tables_gpu as *const c_void,
|
||
kv_lens_gpu as *const c_void,
|
||
batch as i32, num_heads as i32,
|
||
head_dim as i32, block_size as i32, max_blocks_per_seq as i32,
|
||
stream,
|
||
);
|
||
}
|
||
}
|
||
|
||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||
let ndim = scores.ndim();
|
||
let rows = scores.shape()[ndim - 2];
|
||
let cols = scores.shape()[ndim - 1];
|
||
let batch: usize = scores.shape()[..ndim - 2].iter().product();
|
||
|
||
unsafe {
|
||
match scores.dtype() {
|
||
DType::F32 => launch_causal_mask_f32(
|
||
scores.data_ptr() as *mut c_void,
|
||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||
std::ptr::null_mut(),
|
||
),
|
||
DType::BF16 => launch_causal_mask_bf16(
|
||
scores.data_ptr() as *mut c_void,
|
||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||
std::ptr::null_mut(),
|
||
),
|
||
_ => panic!("unsupported dtype for causal mask"),
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Multi-head attention (naive, materializes S×S score matrix).
|
||
///
|
||
/// q, k, v: [batch, num_heads, seq_len, head_dim] — contiguous, on GPU
|
||
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||
assert_eq!(q.ndim(), 4);
|
||
assert_eq!(k.ndim(), 4);
|
||
assert_eq!(v.ndim(), 4);
|
||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||
|
||
let batch = q.shape()[0];
|
||
let num_heads = q.shape()[1];
|
||
let q_len = q.shape()[2];
|
||
let head_dim = q.shape()[3];
|
||
let kv_len = k.shape()[2];
|
||
|
||
assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||
assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||
|
||
// scores = Q @ K^T → [B, H, q_len, kv_len]
|
||
let k_t = k.transpose(2, 3).contiguous();
|
||
let scores = batched_matmul(q, &k_t);
|
||
|
||
// Scale by 1/sqrt(head_dim)
|
||
let scale_factor = 1.0 / (head_dim as f32).sqrt();
|
||
let scaled_scores = scale(&scores, scale_factor);
|
||
|
||
// Causal mask
|
||
if causal {
|
||
let offset = kv_len - q_len;
|
||
apply_causal_mask(&scaled_scores, offset);
|
||
}
|
||
|
||
// Softmax
|
||
let weights = softmax(&scaled_scores);
|
||
|
||
// output = weights @ V → [B, H, q_len, head_dim]
|
||
batched_matmul(&weights, v)
|
||
}
|
||
|
||
/// Decode Attention — optimized for single-token decode (q_len=1).
|
||
///
|
||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||
///
|
||
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||
pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
||
assert_eq!(q.ndim(), 4);
|
||
assert_eq!(q.shape()[2], 1, "decode_attention requires q_len == 1");
|
||
|
||
let batch = q.shape()[0];
|
||
let num_q_heads = q.shape()[1];
|
||
let head_dim = q.shape()[3];
|
||
let num_kv_heads = k.shape()[1];
|
||
let kv_len = k.shape()[2];
|
||
|
||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||
let output = Tensor::empty(
|
||
&[batch, num_q_heads, 1, head_dim],
|
||
DType::BF16,
|
||
q.device(),
|
||
);
|
||
|
||
unsafe {
|
||
launch_decode_attention_bf16(
|
||
q.data_ptr() as *const c_void,
|
||
k.data_ptr() as *const c_void,
|
||
v.data_ptr() as *const c_void,
|
||
output.data_ptr() as *mut c_void,
|
||
batch as i32,
|
||
num_q_heads as i32,
|
||
num_kv_heads as i32,
|
||
kv_len as i32,
|
||
head_dim as i32,
|
||
scale,
|
||
1, // causal (always 1 for decode)
|
||
std::ptr::null_mut(),
|
||
);
|
||
}
|
||
|
||
output
|
||
}
|
||
|
||
/// Flash Attention 2 — O(1) extra memory, supports GQA natively.
|
||
/// Auto-dispatches to decode_attention when q_len == 1.
|
||
///
|
||
/// q: [batch, num_q_heads, q_len, head_dim] BF16, contiguous, GPU
|
||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||
///
|
||
/// Returns: [batch, num_q_heads, q_len, head_dim] BF16
|
||
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||
assert_eq!(q.ndim(), 4);
|
||
assert_eq!(k.ndim(), 4);
|
||
assert_eq!(v.ndim(), 4);
|
||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||
assert_eq!(q.dtype(), DType::BF16, "flash_attention requires BF16");
|
||
assert_eq!(k.dtype(), DType::BF16);
|
||
assert_eq!(v.dtype(), DType::BF16);
|
||
|
||
let batch = q.shape()[0];
|
||
let num_q_heads = q.shape()[1];
|
||
let q_len = q.shape()[2];
|
||
let head_dim = q.shape()[3];
|
||
let num_kv_heads = k.shape()[1];
|
||
let kv_len = k.shape()[2];
|
||
|
||
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
|
||
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
|
||
|
||
// Dispatch to specialized decode kernel for single-token generation
|
||
if q_len == 1 {
|
||
return decode_attention(q, k, v);
|
||
}
|
||
|
||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||
let output = Tensor::empty(
|
||
&[batch, num_q_heads, q_len, head_dim],
|
||
DType::BF16,
|
||
q.device(),
|
||
);
|
||
|
||
unsafe {
|
||
launch_flash_attention_bf16(
|
||
q.data_ptr() as *const c_void,
|
||
k.data_ptr() as *const c_void,
|
||
v.data_ptr() as *const c_void,
|
||
output.data_ptr() as *mut c_void,
|
||
batch as i32,
|
||
num_q_heads as i32,
|
||
num_kv_heads as i32,
|
||
q_len as i32,
|
||
kv_len as i32,
|
||
head_dim as i32,
|
||
scale,
|
||
if causal { 1 } else { 0 },
|
||
std::ptr::null_mut(),
|
||
);
|
||
}
|
||
|
||
output
|
||
}
|
||
|
||
/// Paged decode attention.
|
||
///
|
||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||
/// k_cache_ptr / v_cache_ptr: pointers to [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16 pools
|
||
/// block_tables_ptr: i32 [batch, max_blocks_per_seq] (rows already arranged for this batch)
|
||
/// context_lens_ptr: i32 [batch]
|
||
///
|
||
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||
#[allow(clippy::too_many_arguments)]
|
||
pub fn paged_decode_attention(
|
||
q: &Tensor,
|
||
k_cache_ptr: *const c_void,
|
||
v_cache_ptr: *const c_void,
|
||
block_tables_ptr: *const i32,
|
||
context_lens_ptr: *const i32,
|
||
batch: usize,
|
||
num_q_heads: usize,
|
||
num_kv_heads: usize,
|
||
head_dim: usize,
|
||
max_blocks_per_seq: usize,
|
||
) -> Tensor {
|
||
assert_eq!(q.ndim(), 4);
|
||
assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
|
||
assert_eq!(q.dtype(), DType::BF16);
|
||
assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
|
||
assert!(head_dim <= 128);
|
||
|
||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||
let output = Tensor::empty(
|
||
&[batch, num_q_heads, 1, head_dim],
|
||
DType::BF16,
|
||
q.device(),
|
||
);
|
||
|
||
unsafe {
|
||
launch_paged_decode_attention_bf16(
|
||
q.data_ptr() as *const c_void,
|
||
k_cache_ptr,
|
||
v_cache_ptr,
|
||
output.data_ptr() as *mut c_void,
|
||
block_tables_ptr,
|
||
context_lens_ptr,
|
||
batch as i32,
|
||
num_q_heads as i32,
|
||
num_kv_heads as i32,
|
||
head_dim as i32,
|
||
max_blocks_per_seq as i32,
|
||
scale,
|
||
std::ptr::null_mut(),
|
||
);
|
||
}
|
||
|
||
output
|
||
}
|
||
|
||
/// Paged decode attention with attention sinks and optional sliding window.
|
||
///
|
||
/// sinks_ptr: pointer to [num_q_heads] BF16 on GPU (or null for no sinks)
|
||
/// window_size: 0 = full attention, >0 = sliding window
|
||
#[allow(clippy::too_many_arguments)]
|
||
pub fn paged_decode_attention_sinks(
|
||
q: &Tensor,
|
||
k_cache_ptr: *const c_void,
|
||
v_cache_ptr: *const c_void,
|
||
block_tables_ptr: *const i32,
|
||
context_lens_ptr: *const i32,
|
||
sinks_ptr: *const c_void,
|
||
batch: usize,
|
||
num_q_heads: usize,
|
||
num_kv_heads: usize,
|
||
head_dim: usize,
|
||
max_blocks_per_seq: usize,
|
||
window_size: usize,
|
||
) -> Tensor {
|
||
assert_eq!(q.ndim(), 4);
|
||
assert_eq!(q.shape()[2], 1);
|
||
assert_eq!(q.dtype(), DType::BF16);
|
||
assert!(num_q_heads % num_kv_heads == 0);
|
||
assert!(head_dim <= 128);
|
||
|
||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||
let output = Tensor::empty(
|
||
&[batch, num_q_heads, 1, head_dim],
|
||
DType::BF16,
|
||
q.device(),
|
||
);
|
||
|
||
unsafe {
|
||
launch_paged_decode_attention_sinks_bf16(
|
||
q.data_ptr() as *const c_void,
|
||
k_cache_ptr,
|
||
v_cache_ptr,
|
||
output.data_ptr() as *mut c_void,
|
||
block_tables_ptr,
|
||
context_lens_ptr,
|
||
sinks_ptr,
|
||
batch as i32,
|
||
num_q_heads as i32,
|
||
num_kv_heads as i32,
|
||
head_dim as i32,
|
||
max_blocks_per_seq as i32,
|
||
scale,
|
||
window_size as i32,
|
||
std::ptr::null_mut(),
|
||
);
|
||
}
|
||
|
||
output
|
||
}
|