phase 15: decode attention kernel + fused silu_mul + fused add_rmsnorm
Three performance optimizations targeting decode throughput: 1. Decode Attention Kernel (csrc/attention/flash_attention.cu): - Specialized kernel for Q_len=1 (decode step) - 256 threads parallelize across KV sequence dimension - Online softmax with block-level warp-shuffle reduction - Replaces FA2 kernel which wasted 63/64 threads for decode - flash_attention() auto-dispatches when q_len==1 2. Fused SiLU×Mul (csrc/activation/activations.cu): - Single kernel: out = silu(gate) * up - Saves 1 HBM read + 1 HBM write per FFN layer (N elements) - Eliminates intermediate tensor allocation 3. Fused Add+RMSNorm (csrc/normalization/rmsnorm.cu): - Single kernel: (normed, sum) = (rmsnorm(x+residual), x+residual) - Saves 1 full HBM round-trip per attention block - Eliminates separate add + rmsnorm kernel pair Performance analysis: - At current short sequences (max 79 tokens), these optimizations provide marginal benefit because the bottleneck is cuBLAS GEMV overhead: 252 weight matrix reads × ~32MB each = 15.5 GB per decode step. Theoretical minimum at 1.79 TB/s = 8.7ms, actual ~78ms (9x gap). - The fused kernels and decode attention will show larger gains at longer sequences where attention and element-wise ops dominate. - Next optimization target: CUDA Graphs to eliminate kernel launch overhead, or custom GEMV kernels to replace cuBLAS for M=1. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -12,6 +12,7 @@ unsafe extern "C" {
|
||||
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
@@ -67,3 +68,24 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
|
||||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
|
||||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
|
||||
|
||||
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
|
||||
/// Saves one HBM read + one HBM write compared to separate silu + mul.
|
||||
pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||||
assert_eq!(gate.shape(), up.shape());
|
||||
assert!(gate.is_contiguous() && up.is_contiguous());
|
||||
assert!(matches!(gate.device(), Device::Cuda(_)));
|
||||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||||
let out = Tensor::zeros(gate.shape(), gate.dtype(), gate.device());
|
||||
let n = gate.numel() as i32;
|
||||
unsafe {
|
||||
launch_silu_mul_bf16(
|
||||
gate.data_ptr() as *const c_void,
|
||||
up.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ unsafe extern "C" {
|
||||
q_len: i32, kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
);
|
||||
fn launch_decode_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
@@ -81,7 +87,52 @@ pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
batched_matmul(&weights, v)
|
||||
}
|
||||
|
||||
/// Decode Attention — optimized for single-token decode (q_len=1).
|
||||
///
|
||||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
///
|
||||
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||||
pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1, "decode_attention requires q_len == 1");
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_q_heads = q.shape()[1];
|
||||
let head_dim = q.shape()[3];
|
||||
let num_kv_heads = k.shape()[1];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::zeros(
|
||||
&[batch, num_q_heads, 1, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_decode_attention_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k.data_ptr() as *const c_void,
|
||||
v.data_ptr() as *const c_void,
|
||||
output.data_ptr() as *mut c_void,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
kv_len as i32,
|
||||
head_dim as i32,
|
||||
scale,
|
||||
1, // causal (always 1 for decode)
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Flash Attention 2 — O(1) extra memory, supports GQA natively.
|
||||
/// Auto-dispatches to decode_attention when q_len == 1.
|
||||
///
|
||||
/// q: [batch, num_q_heads, q_len, head_dim] BF16, contiguous, GPU
|
||||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
@@ -109,6 +160,11 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
||||
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
|
||||
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
|
||||
|
||||
// Dispatch to specialized decode kernel for single-token generation
|
||||
if q_len == 1 {
|
||||
return decode_attention(q, k, v);
|
||||
}
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::zeros(
|
||||
&[batch, num_q_heads, q_len, head_dim],
|
||||
|
||||
@@ -8,13 +8,13 @@ pub mod rope;
|
||||
pub mod softmax;
|
||||
pub mod transpose;
|
||||
|
||||
pub use activation::{add, gelu, mul, scale, silu};
|
||||
pub use activation::{add, gelu, mul, scale, silu, silu_mul};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::{attention, flash_attention};
|
||||
pub use attention::{attention, decode_attention, flash_attention};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::rmsnorm;
|
||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||
pub use rope::{rope_inplace, RopeCache};
|
||||
pub use softmax::softmax;
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@ unsafe extern "C" {
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
@@ -34,3 +37,39 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Fused Add + RMSNorm: computes sum = x + residual, then normed = rmsnorm(sum, gamma, eps).
|
||||
/// Returns (normed, sum). BF16 only.
|
||||
/// Saves one kernel launch and one full HBM round-trip per layer.
|
||||
pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (Tensor, Tensor) {
|
||||
assert!(x.ndim() >= 1);
|
||||
assert_eq!(x.shape(), residual.shape());
|
||||
assert!(x.is_contiguous() && residual.is_contiguous() && gamma.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
assert_eq!(x.dtype(), DType::BF16, "add_rmsnorm requires BF16");
|
||||
assert_eq!(residual.dtype(), DType::BF16);
|
||||
assert_eq!(gamma.dtype(), DType::BF16);
|
||||
|
||||
let hidden_size = *x.shape().last().unwrap();
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let normed_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
|
||||
|
||||
unsafe {
|
||||
launch_add_rmsnorm_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
residual.data_ptr() as *const c_void,
|
||||
gamma.data_ptr() as *const c_void,
|
||||
normed_out.data_ptr() as *mut c_void,
|
||||
sum_out.data_ptr() as *mut c_void,
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
|
||||
(normed_out, sum_out)
|
||||
}
|
||||
|
||||
@@ -196,14 +196,15 @@ impl Qwen3 {
|
||||
// GPU merge_heads: [1, H, S, D] → [S, H*D]
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
x = add_any(&residual, &attn_proj);
|
||||
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.post_norm, eps);
|
||||
// Fused add + rmsnorm: (normed, x) where x = residual + attn_proj
|
||||
let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
// Fused SiLU×Mul
|
||||
let gate = matmul_2d(&normed, &layer.gate_proj_wt);
|
||||
let up = matmul_2d(&normed, &layer.up_proj_wt);
|
||||
let gate_activated = silu(&gate);
|
||||
let hidden_states = mul_any(&gate_activated, &up);
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
x = add_any(&residual, &down);
|
||||
}
|
||||
|
||||
@@ -45,6 +45,18 @@ __global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, fl
|
||||
if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale);
|
||||
}
|
||||
|
||||
// Fused SiLU×Mul: out = silu(gate) * up
|
||||
__global__ void silu_mul_bf16_kernel(const __nv_bfloat16* gate, const __nv_bfloat16* up,
|
||||
__nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
float g = __bfloat162float(gate[idx]);
|
||||
float u = __bfloat162float(up[idx]);
|
||||
float silu_g = g / (1.0f + expf(-g));
|
||||
out[idx] = __float2bfloat16(silu_g * u);
|
||||
}
|
||||
}
|
||||
|
||||
// Element-wise add: out = a + b
|
||||
__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
@@ -132,4 +144,11 @@ void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* strea
|
||||
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -196,6 +196,177 @@ __global__ void flash_attention_bf16_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Decode Attention kernel: optimized for Q_len=1 (single-token decode).
|
||||
// Parallelizes across KV sequence dimension instead of Q rows.
|
||||
//
|
||||
// Grid: (batch * num_q_heads, 1) — one block per Q head
|
||||
// Block: 256 threads — each thread handles ceil(kv_len / 256) KV positions
|
||||
// Uses online softmax reduction across threads.
|
||||
// ============================================================
|
||||
|
||||
#define DECODE_THREADS 256
|
||||
#define HEAD_DIM_MAX 128
|
||||
|
||||
__global__ void decode_attention_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ Q,
|
||||
const __nv_bfloat16* __restrict__ K,
|
||||
const __nv_bfloat16* __restrict__ V,
|
||||
__nv_bfloat16* __restrict__ O,
|
||||
int num_q_heads, int num_kv_heads,
|
||||
int kv_len, int head_dim,
|
||||
float scale
|
||||
) {
|
||||
int bh = blockIdx.x;
|
||||
int batch_idx = bh / num_q_heads;
|
||||
int q_head = bh % num_q_heads;
|
||||
|
||||
// GQA mapping
|
||||
int heads_per_group = num_q_heads / num_kv_heads;
|
||||
int kv_head = q_head / heads_per_group;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
// Pointers to this batch/head's data
|
||||
// Q: [batch, num_q_heads, 1, head_dim]
|
||||
const __nv_bfloat16* Q_ptr = Q + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
|
||||
// K/V: [batch, num_kv_heads, kv_len, head_dim]
|
||||
const __nv_bfloat16* K_base = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
|
||||
const __nv_bfloat16* V_base = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
|
||||
__nv_bfloat16* O_ptr = O + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
|
||||
|
||||
// Load Q vector into registers (head_dim <= 128)
|
||||
float q_reg[HEAD_DIM_MAX];
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
q_reg[d] = __bfloat162float(Q_ptr[d]);
|
||||
}
|
||||
|
||||
// Each thread processes a chunk of KV positions
|
||||
// Thread tid handles positions: tid, tid+DECODE_THREADS, tid+2*DECODE_THREADS, ...
|
||||
float local_max = -INFINITY;
|
||||
float local_sum = 0.0f;
|
||||
float local_O[HEAD_DIM_MAX];
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
local_O[d] = 0.0f;
|
||||
}
|
||||
|
||||
for (int pos = tid; pos < kv_len; pos += DECODE_THREADS) {
|
||||
// Compute dot(Q, K[pos]) * scale
|
||||
const __nv_bfloat16* K_pos = K_base + pos * head_dim;
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
dot += q_reg[d] * __bfloat162float(K_pos[d]);
|
||||
}
|
||||
float s = dot * scale;
|
||||
|
||||
// Online softmax update
|
||||
float new_max = fmaxf(local_max, s);
|
||||
float correction = expf(local_max - new_max);
|
||||
float p = expf(s - new_max);
|
||||
|
||||
// Rescale running sum and O
|
||||
local_sum = local_sum * correction + p;
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
local_O[d] = local_O[d] * correction;
|
||||
}
|
||||
|
||||
// Accumulate V[pos] weighted by p
|
||||
const __nv_bfloat16* V_pos = V_base + pos * head_dim;
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
local_O[d] += p * __bfloat162float(V_pos[d]);
|
||||
}
|
||||
|
||||
local_max = new_max;
|
||||
}
|
||||
|
||||
// --- Block-level online softmax reduction ---
|
||||
// We need to combine (local_max, local_sum, local_O) across all threads.
|
||||
// Strategy: reduce max, then each thread rescales, then reduce sum and O.
|
||||
|
||||
// Shared memory for reduction
|
||||
__shared__ float smem_max[32]; // one per warp
|
||||
__shared__ float smem_sum[32];
|
||||
__shared__ float smem_O[HEAD_DIM_MAX]; // final output accumulator
|
||||
|
||||
// Step 1: Block-wide max reduction
|
||||
int lane = tid & 31;
|
||||
int warp_id = tid >> 5;
|
||||
int num_warps = DECODE_THREADS >> 5; // 8 warps
|
||||
|
||||
float warp_max = local_max;
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
|
||||
if (lane == 0) smem_max[warp_id] = warp_max;
|
||||
__syncthreads();
|
||||
|
||||
float global_max;
|
||||
if (tid == 0) {
|
||||
global_max = smem_max[0];
|
||||
for (int i = 1; i < num_warps; i++)
|
||||
global_max = fmaxf(global_max, smem_max[i]);
|
||||
smem_max[0] = global_max;
|
||||
}
|
||||
__syncthreads();
|
||||
global_max = smem_max[0];
|
||||
|
||||
// Step 2: Each thread rescales its local_sum and local_O with global_max
|
||||
float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
|
||||
local_sum *= rescale;
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
local_O[d] *= rescale;
|
||||
}
|
||||
|
||||
// Step 3: Reduce sum across block
|
||||
float warp_sum = local_sum;
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
|
||||
if (lane == 0) smem_sum[warp_id] = warp_sum;
|
||||
__syncthreads();
|
||||
|
||||
float global_sum;
|
||||
if (tid == 0) {
|
||||
global_sum = 0.0f;
|
||||
for (int i = 0; i < num_warps; i++)
|
||||
global_sum += smem_sum[i];
|
||||
smem_sum[0] = global_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
global_sum = smem_sum[0];
|
||||
|
||||
// Step 4: Reduce O across block (dimension by dimension using shared mem)
|
||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||
|
||||
// Process head_dim in chunks: each iteration reduces one dimension
|
||||
// Use shared memory accumulator: each warp contributes via warp reduction + atomic
|
||||
// Actually simpler: iterate over dimensions, warp reduce each, then lane0 atomicAdd to smem_O
|
||||
|
||||
// Initialize smem_O
|
||||
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
|
||||
smem_O[d] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Each thread adds its local_O contributions via warp reduction + atomicAdd
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
float val = local_O[d];
|
||||
// Warp-level reduction
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
if (lane == 0) {
|
||||
atomicAdd(&smem_O[d], val);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Thread 0..head_dim-1 write final output
|
||||
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
|
||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_flash_attention_bf16(
|
||||
@@ -222,4 +393,24 @@ void launch_flash_attention_bf16(
|
||||
);
|
||||
}
|
||||
|
||||
void launch_decode_attention_bf16(
|
||||
const void* Q, const void* K, const void* V, void* O,
|
||||
int batch, int num_q_heads, int num_kv_heads,
|
||||
int kv_len, int head_dim,
|
||||
float scale, int causal, void* stream
|
||||
) {
|
||||
int grid = batch * num_q_heads;
|
||||
int block = DECODE_THREADS;
|
||||
|
||||
decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)Q,
|
||||
(const __nv_bfloat16*)K,
|
||||
(const __nv_bfloat16*)V,
|
||||
(__nv_bfloat16*)O,
|
||||
num_q_heads, num_kv_heads,
|
||||
kv_len, head_dim,
|
||||
scale
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -63,6 +63,46 @@ __global__ void rmsnorm_bf16(
|
||||
}
|
||||
}
|
||||
|
||||
// Fused Add + RMSNorm: sum_out = x + residual, normed_out = rmsnorm(sum_out, gamma, eps)
|
||||
// Each block handles one row of [hidden_size].
|
||||
__global__ void add_rmsnorm_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ residual,
|
||||
const __nv_bfloat16* __restrict__ gamma,
|
||||
__nv_bfloat16* __restrict__ normed_out,
|
||||
__nv_bfloat16* __restrict__ sum_out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * hidden_size;
|
||||
const __nv_bfloat16* res_row = residual + row * hidden_size;
|
||||
__nv_bfloat16* sum_row = sum_out + row * hidden_size;
|
||||
__nv_bfloat16* norm_row = normed_out + row * hidden_size;
|
||||
|
||||
// Pass 1: compute sum = x + residual, and accumulate sum_sq
|
||||
float sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float s = __bfloat162float(x_row[i]) + __bfloat162float(res_row[i]);
|
||||
sum_row[i] = __float2bfloat16(s);
|
||||
sum_sq += s * s;
|
||||
}
|
||||
sum_sq = block_reduce_sum(sum_sq);
|
||||
|
||||
__shared__ float s_rms_inv;
|
||||
if (threadIdx.x == 0) {
|
||||
s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Pass 2: normed_out = sum * rms_inv * gamma
|
||||
float rms_inv = s_rms_inv;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float s = __bfloat162float(sum_row[i]);
|
||||
float g = __bfloat162float(gamma[i]);
|
||||
norm_row[i] = __float2bfloat16(s * rms_inv * g);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
|
||||
@@ -80,4 +120,15 @@ void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
|
||||
(__nv_bfloat16*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
|
||||
void* normed_out, void* sum_out,
|
||||
int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
add_rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)residual,
|
||||
(const __nv_bfloat16*)gamma,
|
||||
(__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
|
||||
hidden_size, eps);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user