phase 15: decode attention kernel + fused silu_mul + fused add_rmsnorm

Three performance optimizations targeting decode throughput:

1. Decode Attention Kernel (csrc/attention/flash_attention.cu):
   - Specialized kernel for Q_len=1 (decode step)
   - 256 threads parallelize across KV sequence dimension
   - Online softmax with block-level warp-shuffle reduction
   - Replaces FA2 kernel which wasted 63/64 threads for decode
   - flash_attention() auto-dispatches when q_len==1

2. Fused SiLU×Mul (csrc/activation/activations.cu):
   - Single kernel: out = silu(gate) * up
   - Saves 1 HBM read + 1 HBM write per FFN layer (N elements)
   - Eliminates intermediate tensor allocation

3. Fused Add+RMSNorm (csrc/normalization/rmsnorm.cu):
   - Single kernel: (normed, sum) = (rmsnorm(x+residual), x+residual)
   - Saves 1 full HBM round-trip per attention block
   - Eliminates separate add + rmsnorm kernel pair

Performance analysis:
- At current short sequences (max 79 tokens), these optimizations provide
  marginal benefit because the bottleneck is cuBLAS GEMV overhead:
  252 weight matrix reads × ~32MB each = 15.5 GB per decode step.
  Theoretical minimum at 1.79 TB/s = 8.7ms, actual ~78ms (9x gap).
- The fused kernels and decode attention will show larger gains at
  longer sequences where attention and element-wise ops dominate.
- Next optimization target: CUDA Graphs to eliminate kernel launch
  overhead, or custom GEMV kernels to replace cuBLAS for M=1.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 19:40:56 +08:00
parent 6cc1c9332d
commit 9783fcf410
8 changed files with 387 additions and 8 deletions

View File

@@ -12,6 +12,7 @@ unsafe extern "C" {
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
}
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
@@ -67,3 +68,24 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
/// Saves one HBM read + one HBM write compared to separate silu + mul.
pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
assert_eq!(gate.shape(), up.shape());
assert!(gate.is_contiguous() && up.is_contiguous());
assert!(matches!(gate.device(), Device::Cuda(_)));
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
let out = Tensor::zeros(gate.shape(), gate.dtype(), gate.device());
let n = gate.numel() as i32;
unsafe {
launch_silu_mul_bf16(
gate.data_ptr() as *const c_void,
up.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
n,
std::ptr::null_mut(),
);
}
out
}

View File

@@ -16,6 +16,12 @@ unsafe extern "C" {
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
);
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
);
}
fn apply_causal_mask(scores: &Tensor, offset: usize) {
@@ -81,7 +87,52 @@ pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
batched_matmul(&weights, v)
}
/// Decode Attention — optimized for single-token decode (q_len=1).
///
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
///
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
assert_eq!(q.ndim(), 4);
assert_eq!(q.shape()[2], 1, "decode_attention requires q_len == 1");
let batch = q.shape()[0];
let num_q_heads = q.shape()[1];
let head_dim = q.shape()[3];
let num_kv_heads = k.shape()[1];
let kv_len = k.shape()[2];
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::zeros(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
);
unsafe {
launch_decode_attention_bf16(
q.data_ptr() as *const c_void,
k.data_ptr() as *const c_void,
v.data_ptr() as *const c_void,
output.data_ptr() as *mut c_void,
batch as i32,
num_q_heads as i32,
num_kv_heads as i32,
kv_len as i32,
head_dim as i32,
scale,
1, // causal (always 1 for decode)
std::ptr::null_mut(),
);
}
output
}
/// Flash Attention 2 — O(1) extra memory, supports GQA natively.
/// Auto-dispatches to decode_attention when q_len == 1.
///
/// q: [batch, num_q_heads, q_len, head_dim] BF16, contiguous, GPU
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
@@ -109,6 +160,11 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
// Dispatch to specialized decode kernel for single-token generation
if q_len == 1 {
return decode_attention(q, k, v);
}
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::zeros(
&[batch, num_q_heads, q_len, head_dim],

View File

@@ -8,13 +8,13 @@ pub mod rope;
pub mod softmax;
pub mod transpose;
pub use activation::{add, gelu, mul, scale, silu};
pub use activation::{add, gelu, mul, scale, silu, silu_mul};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::{attention, flash_attention};
pub use attention::{attention, decode_attention, flash_attention};
pub use embedding::embedding;
pub use gemm::{batched_matmul, matmul, GemmBackend};
pub use layernorm::layernorm;
pub use rmsnorm::rmsnorm;
pub use rmsnorm::{add_rmsnorm, rmsnorm};
pub use rope::{rope_inplace, RopeCache};
pub use softmax::softmax;

View File

@@ -6,6 +6,9 @@ unsafe extern "C" {
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
normed_out: *mut c_void, sum_out: *mut c_void,
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
}
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
@@ -34,3 +37,39 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
}
out
}
/// Fused Add + RMSNorm: computes sum = x + residual, then normed = rmsnorm(sum, gamma, eps).
/// Returns (normed, sum). BF16 only.
/// Saves one kernel launch and one full HBM round-trip per layer.
pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (Tensor, Tensor) {
assert!(x.ndim() >= 1);
assert_eq!(x.shape(), residual.shape());
assert!(x.is_contiguous() && residual.is_contiguous() && gamma.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
assert_eq!(x.dtype(), DType::BF16, "add_rmsnorm requires BF16");
assert_eq!(residual.dtype(), DType::BF16);
assert_eq!(gamma.dtype(), DType::BF16);
let hidden_size = *x.shape().last().unwrap();
assert_eq!(gamma.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
let normed_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
let sum_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
unsafe {
launch_add_rmsnorm_bf16(
x.data_ptr() as *const c_void,
residual.data_ptr() as *const c_void,
gamma.data_ptr() as *const c_void,
normed_out.data_ptr() as *mut c_void,
sum_out.data_ptr() as *mut c_void,
rows as i32,
hidden_size as i32,
eps,
std::ptr::null_mut(),
);
}
(normed_out, sum_out)
}

View File

@@ -196,14 +196,15 @@ impl Qwen3 {
// GPU merge_heads: [1, H, S, D] → [S, H*D]
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
x = add_any(&residual, &attn_proj);
let residual = x.clone();
let normed = rmsnorm(&x, &layer.post_norm, eps);
// Fused add + rmsnorm: (normed, x) where x = residual + attn_proj
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
// Fused SiLU×Mul
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
let up = matmul_2d(&normed, &layer.up_proj_wt);
let gate_activated = silu(&gate);
let hidden_states = mul_any(&gate_activated, &up);
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
x = add_any(&residual, &down);
}

View File

@@ -45,6 +45,18 @@ __global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, fl
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale);
}
// Fused SiLU×Mul: out = silu(gate) * up
__global__ void silu_mul_bf16_kernel(const __nv_bfloat16* gate, const __nv_bfloat16* up,
__nv_bfloat16* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
float g = __bfloat162float(gate[idx]);
float u = __bfloat162float(up[idx]);
float silu_g = g / (1.0f + expf(-g));
out[idx] = __float2bfloat16(silu_g * u);
}
}
// Element-wise add: out = a + b
__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -132,4 +144,11 @@ void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* strea
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
}
}

View File

@@ -196,6 +196,177 @@ __global__ void flash_attention_bf16_kernel(
}
}
// ============================================================
// Decode Attention kernel: optimized for Q_len=1 (single-token decode).
// Parallelizes across KV sequence dimension instead of Q rows.
//
// Grid: (batch * num_q_heads, 1) — one block per Q head
// Block: 256 threads — each thread handles ceil(kv_len / 256) KV positions
// Uses online softmax reduction across threads.
// ============================================================
#define DECODE_THREADS 256
#define HEAD_DIM_MAX 128
__global__ void decode_attention_bf16_kernel(
const __nv_bfloat16* __restrict__ Q,
const __nv_bfloat16* __restrict__ K,
const __nv_bfloat16* __restrict__ V,
__nv_bfloat16* __restrict__ O,
int num_q_heads, int num_kv_heads,
int kv_len, int head_dim,
float scale
) {
int bh = blockIdx.x;
int batch_idx = bh / num_q_heads;
int q_head = bh % num_q_heads;
// GQA mapping
int heads_per_group = num_q_heads / num_kv_heads;
int kv_head = q_head / heads_per_group;
int tid = threadIdx.x;
// Pointers to this batch/head's data
// Q: [batch, num_q_heads, 1, head_dim]
const __nv_bfloat16* Q_ptr = Q + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
// K/V: [batch, num_kv_heads, kv_len, head_dim]
const __nv_bfloat16* K_base = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
const __nv_bfloat16* V_base = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
__nv_bfloat16* O_ptr = O + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
// Load Q vector into registers (head_dim <= 128)
float q_reg[HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) {
q_reg[d] = __bfloat162float(Q_ptr[d]);
}
// Each thread processes a chunk of KV positions
// Thread tid handles positions: tid, tid+DECODE_THREADS, tid+2*DECODE_THREADS, ...
float local_max = -INFINITY;
float local_sum = 0.0f;
float local_O[HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) {
local_O[d] = 0.0f;
}
for (int pos = tid; pos < kv_len; pos += DECODE_THREADS) {
// Compute dot(Q, K[pos]) * scale
const __nv_bfloat16* K_pos = K_base + pos * head_dim;
float dot = 0.0f;
for (int d = 0; d < head_dim; d++) {
dot += q_reg[d] * __bfloat162float(K_pos[d]);
}
float s = dot * scale;
// Online softmax update
float new_max = fmaxf(local_max, s);
float correction = expf(local_max - new_max);
float p = expf(s - new_max);
// Rescale running sum and O
local_sum = local_sum * correction + p;
for (int d = 0; d < head_dim; d++) {
local_O[d] = local_O[d] * correction;
}
// Accumulate V[pos] weighted by p
const __nv_bfloat16* V_pos = V_base + pos * head_dim;
for (int d = 0; d < head_dim; d++) {
local_O[d] += p * __bfloat162float(V_pos[d]);
}
local_max = new_max;
}
// --- Block-level online softmax reduction ---
// We need to combine (local_max, local_sum, local_O) across all threads.
// Strategy: reduce max, then each thread rescales, then reduce sum and O.
// Shared memory for reduction
__shared__ float smem_max[32]; // one per warp
__shared__ float smem_sum[32];
__shared__ float smem_O[HEAD_DIM_MAX]; // final output accumulator
// Step 1: Block-wide max reduction
int lane = tid & 31;
int warp_id = tid >> 5;
int num_warps = DECODE_THREADS >> 5; // 8 warps
float warp_max = local_max;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
if (lane == 0) smem_max[warp_id] = warp_max;
__syncthreads();
float global_max;
if (tid == 0) {
global_max = smem_max[0];
for (int i = 1; i < num_warps; i++)
global_max = fmaxf(global_max, smem_max[i]);
smem_max[0] = global_max;
}
__syncthreads();
global_max = smem_max[0];
// Step 2: Each thread rescales its local_sum and local_O with global_max
float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
local_sum *= rescale;
for (int d = 0; d < head_dim; d++) {
local_O[d] *= rescale;
}
// Step 3: Reduce sum across block
float warp_sum = local_sum;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
if (lane == 0) smem_sum[warp_id] = warp_sum;
__syncthreads();
float global_sum;
if (tid == 0) {
global_sum = 0.0f;
for (int i = 0; i < num_warps; i++)
global_sum += smem_sum[i];
smem_sum[0] = global_sum;
}
__syncthreads();
global_sum = smem_sum[0];
// Step 4: Reduce O across block (dimension by dimension using shared mem)
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
// Process head_dim in chunks: each iteration reduces one dimension
// Use shared memory accumulator: each warp contributes via warp reduction + atomic
// Actually simpler: iterate over dimensions, warp reduce each, then lane0 atomicAdd to smem_O
// Initialize smem_O
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
smem_O[d] = 0.0f;
}
__syncthreads();
// Each thread adds its local_O contributions via warp reduction + atomicAdd
for (int d = 0; d < head_dim; d++) {
float val = local_O[d];
// Warp-level reduction
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) {
atomicAdd(&smem_O[d], val);
}
}
__syncthreads();
// Thread 0..head_dim-1 write final output
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
}
}
extern "C" {
void launch_flash_attention_bf16(
@@ -222,4 +393,24 @@ void launch_flash_attention_bf16(
);
}
void launch_decode_attention_bf16(
const void* Q, const void* K, const void* V, void* O,
int batch, int num_q_heads, int num_kv_heads,
int kv_len, int head_dim,
float scale, int causal, void* stream
) {
int grid = batch * num_q_heads;
int block = DECODE_THREADS;
decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)Q,
(const __nv_bfloat16*)K,
(const __nv_bfloat16*)V,
(__nv_bfloat16*)O,
num_q_heads, num_kv_heads,
kv_len, head_dim,
scale
);
}
}

View File

@@ -63,6 +63,46 @@ __global__ void rmsnorm_bf16(
}
}
// Fused Add + RMSNorm: sum_out = x + residual, normed_out = rmsnorm(sum_out, gamma, eps)
// Each block handles one row of [hidden_size].
__global__ void add_rmsnorm_bf16(
const __nv_bfloat16* __restrict__ x,
const __nv_bfloat16* __restrict__ residual,
const __nv_bfloat16* __restrict__ gamma,
__nv_bfloat16* __restrict__ normed_out,
__nv_bfloat16* __restrict__ sum_out,
int hidden_size, float eps
) {
int row = blockIdx.x;
const __nv_bfloat16* x_row = x + row * hidden_size;
const __nv_bfloat16* res_row = residual + row * hidden_size;
__nv_bfloat16* sum_row = sum_out + row * hidden_size;
__nv_bfloat16* norm_row = normed_out + row * hidden_size;
// Pass 1: compute sum = x + residual, and accumulate sum_sq
float sum_sq = 0.0f;
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float s = __bfloat162float(x_row[i]) + __bfloat162float(res_row[i]);
sum_row[i] = __float2bfloat16(s);
sum_sq += s * s;
}
sum_sq = block_reduce_sum(sum_sq);
__shared__ float s_rms_inv;
if (threadIdx.x == 0) {
s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
}
__syncthreads();
// Pass 2: normed_out = sum * rms_inv * gamma
float rms_inv = s_rms_inv;
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float s = __bfloat162float(sum_row[i]);
float g = __bfloat162float(gamma[i]);
norm_row[i] = __float2bfloat16(s * rms_inv * g);
}
}
extern "C" {
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
@@ -80,4 +120,15 @@ void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
(__nv_bfloat16*)out, hidden_size, eps);
}
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
void* normed_out, void* sum_out,
int rows, int hidden_size, float eps, void* stream) {
int block = (hidden_size < 1024) ? hidden_size : 1024;
add_rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)residual,
(const __nv_bfloat16*)gamma,
(__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
hidden_size, eps);
}
}