phase 14: Flash Attention 2 for SM120 (RTX 5090)
Implement Flash Attention 2 forward kernel targeting SM120 (CC 12.0). FA4 requires TMEM (only on data-center Blackwell SM100), so FA2 is the correct target for consumer Blackwell GPUs like the RTX 5090. CUDA kernel (csrc/attention/flash_attention.cu): - Online softmax with tiled Q/K/V — O(1) extra memory, no S×S matrix - Tile sizes: BR=BC=64, head_dim up to 128 (runtime parameter) - BF16 input, FP32 accumulation, BF16 output - Native GQA: kv_head = q_head / (num_q_heads / num_kv_heads) - Causal mask with tile-level skip optimization - Shared memory: 32 KB (Q_tile 16KB + KV_tile 16KB, fits in 48KB default) - Grid: (q_tiles, batch × num_q_heads), Block: 128 threads Integration: - flash_attention() Rust wrapper in xserv-kernels with shape/dtype validation - Qwen3 forward_gpu_cache uses flash_attention directly (no repeat_kv_gpu) - Eliminates repeat_kv memory allocation + copy per layer per step - Naive attention() preserved for testing/comparison Validated on dash5 (RTX 5090, CUDA 12.9): - Correctness: 9/10 top-1 match vs HF (identical to pre-FA baseline) - Throughput: 12.9 tok/s (up from 10.3, +25% improvement) - Now at 35% of HF transformers baseline (up from 30%) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ fn main() {
|
||||
.file("../../csrc/embedding/rope.cu")
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.file("../../csrc/embedding/transpose.cu")
|
||||
.file("../../csrc/attention/flash_attention.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
|
||||
@@ -10,6 +10,12 @@ unsafe extern "C" {
|
||||
offset: i32, stream: *mut c_void);
|
||||
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
|
||||
offset: i32, stream: *mut c_void);
|
||||
fn launch_flash_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
q_len: i32, kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
@@ -74,3 +80,59 @@ pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
// output = weights @ V → [B, H, q_len, head_dim]
|
||||
batched_matmul(&weights, v)
|
||||
}
|
||||
|
||||
/// Flash Attention 2 — O(1) extra memory, supports GQA natively.
|
||||
///
|
||||
/// q: [batch, num_q_heads, q_len, head_dim] BF16, contiguous, GPU
|
||||
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
|
||||
///
|
||||
/// Returns: [batch, num_q_heads, q_len, head_dim] BF16
|
||||
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
assert_eq!(q.dtype(), DType::BF16, "flash_attention requires BF16");
|
||||
assert_eq!(k.dtype(), DType::BF16);
|
||||
assert_eq!(v.dtype(), DType::BF16);
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_q_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let num_kv_heads = k.shape()[1];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
|
||||
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::zeros(
|
||||
&[batch, num_q_heads, q_len, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_flash_attention_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k.data_ptr() as *const c_void,
|
||||
v.data_ptr() as *const c_void,
|
||||
output.data_ptr() as *mut c_void,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
q_len as i32,
|
||||
kv_len as i32,
|
||||
head_dim as i32,
|
||||
scale,
|
||||
if causal { 1 } else { 0 },
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ pub mod transpose;
|
||||
|
||||
pub use activation::{add, gelu, mul, scale, silu};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::attention;
|
||||
pub use attention::{attention, flash_attention};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
|
||||
@@ -191,12 +191,8 @@ impl Qwen3 {
|
||||
cache.append(layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
|
||||
|
||||
// GPU repeat KV for GQA
|
||||
let n_rep = num_heads / num_kv_heads;
|
||||
let k_full = xserv_kernels::repeat_kv_gpu(&k_full, n_rep);
|
||||
let v_full = xserv_kernels::repeat_kv_gpu(&v_full, n_rep);
|
||||
|
||||
let attn_out = attention(&q, &k_full, &v_full, true);
|
||||
// Flash Attention with native GQA (no repeat_kv needed)
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
// GPU merge_heads: [1, H, S, D] → [S, H*D]
|
||||
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
225
csrc/attention/flash_attention.cu
Normal file
225
csrc/attention/flash_attention.cu
Normal file
@@ -0,0 +1,225 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
|
||||
// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
|
||||
//
|
||||
// Algorithm: outer loop over Q tiles (BR rows), inner loop over K/V tiles (BC rows).
|
||||
// Uses online softmax — no O(S^2) memory.
|
||||
//
|
||||
// Layout: Q [batch, num_q_heads, q_len, head_dim]
|
||||
// K [batch, num_kv_heads, kv_len, head_dim]
|
||||
// V [batch, num_kv_heads, kv_len, head_dim]
|
||||
// O [batch, num_q_heads, q_len, head_dim]
|
||||
//
|
||||
// Shared memory (BF16):
|
||||
// smem_q[BR][head_dim] — 64 * 128 * 2 = 16 KB (loaded once per Q tile)
|
||||
// smem_kv[BC][head_dim] — 64 * 128 * 2 = 16 KB (alternates K and V)
|
||||
// Total: 32 KB (fits in default 48 KB shared memory)
|
||||
|
||||
#define BR 64
|
||||
#define BC 64
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
__global__ void flash_attention_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ Q,
|
||||
const __nv_bfloat16* __restrict__ K,
|
||||
const __nv_bfloat16* __restrict__ V,
|
||||
__nv_bfloat16* __restrict__ O,
|
||||
int num_q_heads, int num_kv_heads,
|
||||
int q_len, int kv_len, int head_dim,
|
||||
float scale, int causal
|
||||
) {
|
||||
// Grid: (ceil(q_len / BR), batch * num_q_heads)
|
||||
int q_tile_idx = blockIdx.x;
|
||||
int bh = blockIdx.y;
|
||||
int batch_idx = bh / num_q_heads;
|
||||
int q_head = bh % num_q_heads;
|
||||
|
||||
// GQA: map Q head to KV head
|
||||
int heads_per_group = num_q_heads / num_kv_heads;
|
||||
int kv_head = q_head / heads_per_group;
|
||||
|
||||
int q_tile_start = q_tile_idx * BR;
|
||||
if (q_tile_start >= q_len) return;
|
||||
int q_tile_rows = min(BR, q_len - q_tile_start);
|
||||
|
||||
// Pointers to this batch/head's data
|
||||
const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
|
||||
const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
|
||||
const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
|
||||
__nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
// Dynamic shared memory
|
||||
extern __shared__ __nv_bfloat16 smem[];
|
||||
__nv_bfloat16* smem_q = smem; // BR * head_dim elements
|
||||
__nv_bfloat16* smem_kv = smem + BR * head_dim; // BC * head_dim elements
|
||||
|
||||
// ---- Load Q tile into shared memory (cooperative) ----
|
||||
int q_elems = q_tile_rows * head_dim;
|
||||
for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
|
||||
int row = i / head_dim;
|
||||
int col = i % head_dim;
|
||||
smem_q[row * head_dim + col] = Q_head[(q_tile_start + row) * head_dim + col];
|
||||
}
|
||||
// Zero-pad if q_tile_rows < BR
|
||||
for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
|
||||
smem_q[i] = __float2bfloat16(0.0f);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Thread t (0 <= t < q_tile_rows) owns Q row t
|
||||
bool owns_row = (tid < q_tile_rows);
|
||||
|
||||
// Per-thread FP32 accumulators (head_dim up to 128)
|
||||
float O_acc[128];
|
||||
float m_val = -INFINITY;
|
||||
float l_val = 0.0f;
|
||||
if (owns_row) {
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
O_acc[d] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// kv_offset handles cached KV longer than Q (decode step)
|
||||
int kv_offset = kv_len - q_len;
|
||||
int num_kv_tiles = (kv_len + BC - 1) / BC;
|
||||
|
||||
// ---- Inner loop over K/V tiles ----
|
||||
for (int j = 0; j < num_kv_tiles; j++) {
|
||||
int kv_tile_start = j * BC;
|
||||
int kv_tile_cols = min(BC, kv_len - kv_tile_start);
|
||||
|
||||
// Causal: skip entire tile if all K positions are in the future
|
||||
if (causal) {
|
||||
int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
|
||||
if (kv_tile_start > max_allowed_kv) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Load K tile into smem_kv ----
|
||||
int kv_elems = kv_tile_cols * head_dim;
|
||||
for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
|
||||
int row = i / head_dim;
|
||||
int col = i % head_dim;
|
||||
smem_kv[row * head_dim + col] = K_head[(kv_tile_start + row) * head_dim + col];
|
||||
}
|
||||
for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
|
||||
smem_kv[i] = __float2bfloat16(0.0f);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- Compute S = Q @ K^T * scale, causal mask, online softmax ----
|
||||
float P[BC];
|
||||
|
||||
if (owns_row) {
|
||||
float row_max = -INFINITY;
|
||||
for (int c = 0; c < kv_tile_cols; c++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
dot += __bfloat162float(smem_q[tid * head_dim + d])
|
||||
* __bfloat162float(smem_kv[c * head_dim + d]);
|
||||
}
|
||||
float s = dot * scale;
|
||||
|
||||
if (causal) {
|
||||
int q_pos = q_tile_start + tid;
|
||||
int kv_pos = kv_tile_start + c;
|
||||
if (kv_pos > q_pos + kv_offset) {
|
||||
s = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
P[c] = s; // store score temporarily in P
|
||||
row_max = fmaxf(row_max, s);
|
||||
}
|
||||
|
||||
// Online softmax: m_new, P = exp(S - m_new), l_new
|
||||
float m_new = fmaxf(m_val, row_max);
|
||||
|
||||
float psum = 0.0f;
|
||||
for (int c = 0; c < kv_tile_cols; c++) {
|
||||
P[c] = expf(P[c] - m_new);
|
||||
psum += P[c];
|
||||
}
|
||||
|
||||
// Rescale previous accumulator
|
||||
float correction = expf(m_val - m_new);
|
||||
l_val = correction * l_val + psum;
|
||||
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
O_acc[d] *= correction;
|
||||
}
|
||||
|
||||
m_val = m_new;
|
||||
}
|
||||
|
||||
// Sync before overwriting smem_kv with V tile
|
||||
__syncthreads();
|
||||
|
||||
// ---- Load V tile (reuse smem_kv) ----
|
||||
int v_elems = kv_tile_cols * head_dim;
|
||||
for (int i = tid; i < v_elems; i += THREADS_PER_BLOCK) {
|
||||
int row = i / head_dim;
|
||||
int col = i % head_dim;
|
||||
smem_kv[row * head_dim + col] = V_head[(kv_tile_start + row) * head_dim + col];
|
||||
}
|
||||
for (int i = v_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
|
||||
smem_kv[i] = __float2bfloat16(0.0f);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- Accumulate O += P @ V_tile ----
|
||||
if (owns_row) {
|
||||
for (int c = 0; c < kv_tile_cols; c++) {
|
||||
float p = P[c];
|
||||
if (p != 0.0f) {
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
O_acc[d] += p * __bfloat162float(smem_kv[c * head_dim + d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// ---- Final normalize and write output (convert FP32 → BF16) ----
|
||||
if (owns_row) {
|
||||
float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
|
||||
int global_row = q_tile_start + tid;
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
O_head[global_row * head_dim + d] = __float2bfloat16(O_acc[d] * inv_l);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_flash_attention_bf16(
|
||||
const void* Q, const void* K, const void* V, void* O,
|
||||
int batch, int num_q_heads, int num_kv_heads,
|
||||
int q_len, int kv_len, int head_dim,
|
||||
float scale, int causal, void* stream
|
||||
) {
|
||||
int q_tiles = (q_len + BR - 1) / BR;
|
||||
dim3 grid(q_tiles, batch * num_q_heads);
|
||||
int block = THREADS_PER_BLOCK;
|
||||
|
||||
// Shared memory: smem_q[BR * head_dim] + smem_kv[BC * head_dim], all BF16
|
||||
int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);
|
||||
|
||||
flash_attention_bf16_kernel<<<grid, block, smem_bytes, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)Q,
|
||||
(const __nv_bfloat16*)K,
|
||||
(const __nv_bfloat16*)V,
|
||||
(__nv_bfloat16*)O,
|
||||
num_q_heads, num_kv_heads,
|
||||
q_len, kv_len, head_dim,
|
||||
scale, causal
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user