phase 14: Flash Attention 2 for SM120 (RTX 5090)

Implement Flash Attention 2 forward kernel targeting SM120 (CC 12.0).
FA4 requires TMEM (only on data-center Blackwell SM100), so FA2 is the
correct target for consumer Blackwell GPUs like the RTX 5090.

CUDA kernel (csrc/attention/flash_attention.cu):
- Online softmax with tiled Q/K/V — O(1) extra memory, no S×S matrix
- Tile sizes: BR=BC=64, head_dim up to 128 (runtime parameter)
- BF16 input, FP32 accumulation, BF16 output
- Native GQA: kv_head = q_head / (num_q_heads / num_kv_heads)
- Causal mask with tile-level skip optimization
- Shared memory: 32 KB (Q_tile 16KB + KV_tile 16KB, fits in 48KB default)
- Grid: (q_tiles, batch × num_q_heads), Block: 128 threads

Integration:
- flash_attention() Rust wrapper in xserv-kernels with shape/dtype validation
- Qwen3 forward_gpu_cache uses flash_attention directly (no repeat_kv_gpu)
- Eliminates repeat_kv memory allocation + copy per layer per step
- Naive attention() preserved for testing/comparison

Validated on dash5 (RTX 5090, CUDA 12.9):
- Correctness: 9/10 top-1 match vs HF (identical to pre-FA baseline)
- Throughput: 12.9 tok/s (up from 10.3, +25% improvement)
- Now at 35% of HF transformers baseline (up from 30%)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 18:27:39 +08:00
parent ee68d3565d
commit d67dda404e
5 changed files with 291 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ fn main() {
.file("../../csrc/embedding/rope.cu")
.file("../../csrc/attention/causal_mask.cu")
.file("../../csrc/embedding/transpose.cu")
.file("../../csrc/attention/flash_attention.cu")
.compile("xserv_kernels");
println!("cargo:rerun-if-changed=../../csrc/");

View File

@@ -10,6 +10,12 @@ unsafe extern "C" {
offset: i32, stream: *mut c_void);
fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32,
offset: i32, stream: *mut c_void);
fn launch_flash_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
);
}
fn apply_causal_mask(scores: &Tensor, offset: usize) {
@@ -74,3 +80,59 @@ pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
// output = weights @ V → [B, H, q_len, head_dim]
batched_matmul(&weights, v)
}
/// Flash Attention 2 — O(1) extra memory, supports GQA natively.
///
/// q: [batch, num_q_heads, q_len, head_dim] BF16, contiguous, GPU
/// k: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
/// v: [batch, num_kv_heads, kv_len, head_dim] BF16, contiguous, GPU
///
/// Returns: [batch, num_q_heads, q_len, head_dim] BF16
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
assert_eq!(q.ndim(), 4);
assert_eq!(k.ndim(), 4);
assert_eq!(v.ndim(), 4);
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
assert_eq!(q.dtype(), DType::BF16, "flash_attention requires BF16");
assert_eq!(k.dtype(), DType::BF16);
assert_eq!(v.dtype(), DType::BF16);
let batch = q.shape()[0];
let num_q_heads = q.shape()[1];
let q_len = q.shape()[2];
let head_dim = q.shape()[3];
let num_kv_heads = k.shape()[1];
let kv_len = k.shape()[2];
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads");
assert!(head_dim <= 128, "flash_attention supports head_dim up to 128");
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::zeros(
&[batch, num_q_heads, q_len, head_dim],
DType::BF16,
q.device(),
);
unsafe {
launch_flash_attention_bf16(
q.data_ptr() as *const c_void,
k.data_ptr() as *const c_void,
v.data_ptr() as *const c_void,
output.data_ptr() as *mut c_void,
batch as i32,
num_q_heads as i32,
num_kv_heads as i32,
q_len as i32,
kv_len as i32,
head_dim as i32,
scale,
if causal { 1 } else { 0 },
std::ptr::null_mut(),
);
}
output
}

View File

@@ -10,7 +10,7 @@ pub mod transpose;
pub use activation::{add, gelu, mul, scale, silu};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::attention;
pub use attention::{attention, flash_attention};
pub use embedding::embedding;
pub use gemm::{batched_matmul, matmul, GemmBackend};
pub use layernorm::layernorm;

View File

@@ -191,12 +191,8 @@ impl Qwen3 {
cache.append(layer_idx, &k, &v, new_tokens, pos_offset);
let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens);
// GPU repeat KV for GQA
let n_rep = num_heads / num_kv_heads;
let k_full = xserv_kernels::repeat_kv_gpu(&k_full, n_rep);
let v_full = xserv_kernels::repeat_kv_gpu(&v_full, n_rep);
let attn_out = attention(&q, &k_full, &v_full, true);
// Flash Attention with native GQA (no repeat_kv needed)
let attn_out = flash_attention(&q, &k_full, &v_full, true);
// GPU merge_heads: [1, H, S, D] → [S, H*D]
let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);