phase 5: naive multi-head attention
- Batched GEMM via cublasGemmStridedBatchedEx - Causal mask CUDA kernel (F32 + BF16) - Element-wise scale CUDA kernel (F32 + BF16) - attention() composing: batched_matmul + scale + causal_mask + softmax - Fixed to_device/contiguous infinite recursion (GPU contiguous via CPU round-trip) - 5 attention tests passing (max_err < 3e-7 F32) - Total: 61 tests passing across all crates Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -22,6 +22,7 @@ fn main() {
|
|||||||
.file("../../csrc/reduce/softmax.cu")
|
.file("../../csrc/reduce/softmax.cu")
|
||||||
.file("../../csrc/embedding/embedding.cu")
|
.file("../../csrc/embedding/embedding.cu")
|
||||||
.file("../../csrc/embedding/rope.cu")
|
.file("../../csrc/embedding/rope.cu")
|
||||||
|
.file("../../csrc/attention/causal_mask.cu")
|
||||||
.compile("xserv_kernels");
|
.compile("xserv_kernels");
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=../../csrc/");
|
println!("cargo:rerun-if-changed=../../csrc/");
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ unsafe extern "C" {
|
|||||||
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||||
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||||
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||||
|
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||||
|
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gelu(x: &Tensor) -> Tensor {
|
pub fn gelu(x: &Tensor) -> Tensor {
|
||||||
@@ -39,3 +41,19 @@ pub fn silu(x: &Tensor) -> Tensor {
|
|||||||
xserv_cuda::device::synchronize().unwrap();
|
xserv_cuda::device::synchronize().unwrap();
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||||
|
assert!(x.is_contiguous());
|
||||||
|
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||||
|
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||||
|
let n = x.numel() as i32;
|
||||||
|
unsafe {
|
||||||
|
match x.dtype() {
|
||||||
|
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||||
|
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||||
|
_ => panic!("unsupported dtype for scale"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xserv_cuda::device::synchronize().unwrap();
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|||||||
77
crates/xserv-kernels/src/attention.rs
Normal file
77
crates/xserv-kernels/src/attention.rs
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
use std::ffi::c_void;
|
||||||
|
use xserv_tensor::{DType, Tensor};
|
||||||
|
|
||||||
|
use crate::activation::scale;
|
||||||
|
use crate::gemm::batched_matmul;
|
||||||
|
use crate::softmax::softmax;
|
||||||
|
|
||||||
|
// Launchers for the causal-mask kernels (see csrc/attention/causal_mask.cu).
// Both overwrite `scores` in place; `stream` is passed through as an opaque
// CUDA stream pointer.
unsafe extern "C" {
    /// F32 variant: `scores` points to `batch * rows * cols` floats.
    fn launch_causal_mask_f32(
        scores: *mut c_void,
        batch: i32,
        rows: i32,
        cols: i32,
        offset: i32,
        stream: *mut c_void,
    );

    /// BF16 variant: same layout, elements are bfloat16.
    fn launch_causal_mask_bf16(
        scores: *mut c_void,
        batch: i32,
        rows: i32,
        cols: i32,
        offset: i32,
        stream: *mut c_void,
    );
}
|
||||||
|
|
||||||
|
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||||
|
let ndim = scores.ndim();
|
||||||
|
let rows = scores.shape()[ndim - 2];
|
||||||
|
let cols = scores.shape()[ndim - 1];
|
||||||
|
let batch: usize = scores.shape()[..ndim - 2].iter().product();
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
match scores.dtype() {
|
||||||
|
DType::F32 => launch_causal_mask_f32(
|
||||||
|
scores.data_ptr() as *mut c_void,
|
||||||
|
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
),
|
||||||
|
DType::BF16 => launch_causal_mask_bf16(
|
||||||
|
scores.data_ptr() as *mut c_void,
|
||||||
|
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
),
|
||||||
|
_ => panic!("unsupported dtype for causal mask"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xserv_cuda::device::synchronize().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multi-head attention (naive, materializes S×S score matrix).
|
||||||
|
///
|
||||||
|
/// q, k, v: [batch, num_heads, seq_len, head_dim] — contiguous, on GPU
|
||||||
|
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||||
|
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor {
|
||||||
|
assert_eq!(q.ndim(), 4);
|
||||||
|
assert_eq!(k.ndim(), 4);
|
||||||
|
assert_eq!(v.ndim(), 4);
|
||||||
|
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||||
|
|
||||||
|
let batch = q.shape()[0];
|
||||||
|
let num_heads = q.shape()[1];
|
||||||
|
let q_len = q.shape()[2];
|
||||||
|
let head_dim = q.shape()[3];
|
||||||
|
let kv_len = k.shape()[2];
|
||||||
|
|
||||||
|
assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||||
|
assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]);
|
||||||
|
|
||||||
|
// scores = Q @ K^T → [B, H, q_len, kv_len]
|
||||||
|
let k_t = k.transpose(2, 3).contiguous();
|
||||||
|
let scores = batched_matmul(q, &k_t);
|
||||||
|
|
||||||
|
// Scale by 1/sqrt(head_dim)
|
||||||
|
let scale_factor = 1.0 / (head_dim as f32).sqrt();
|
||||||
|
let scaled_scores = scale(&scores, scale_factor);
|
||||||
|
|
||||||
|
// Causal mask
|
||||||
|
if causal {
|
||||||
|
let offset = kv_len - q_len;
|
||||||
|
apply_causal_mask(&scaled_scores, offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Softmax
|
||||||
|
let weights = softmax(&scaled_scores);
|
||||||
|
|
||||||
|
// output = weights @ V → [B, H, q_len, head_dim]
|
||||||
|
batched_matmul(&weights, v)
|
||||||
|
}
|
||||||
@@ -46,6 +46,19 @@ unsafe extern "C" {
|
|||||||
compute_type: i32,
|
compute_type: i32,
|
||||||
algo: i32,
|
algo: i32,
|
||||||
) -> i32;
|
) -> i32;
|
||||||
|
fn cublasGemmStridedBatchedEx(
|
||||||
|
handle: CublasHandle,
|
||||||
|
transa: i32, transb: i32,
|
||||||
|
m: i32, n: i32, k: i32,
|
||||||
|
alpha: *const c_void,
|
||||||
|
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
|
||||||
|
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
|
||||||
|
beta: *const c_void,
|
||||||
|
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
|
||||||
|
batch_count: i32,
|
||||||
|
compute_type: i32,
|
||||||
|
algo: i32,
|
||||||
|
) -> i32;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CublasContext {
|
pub struct CublasContext {
|
||||||
@@ -149,3 +162,68 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
|||||||
|
|
||||||
c
|
c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Batched matrix multiplication via cuBLAS: C[b] = A[b] @ B[b]
|
||||||
|
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||||
|
/// Leading dimensions must match and tensors must be contiguous.
|
||||||
|
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||||
|
assert!(a.ndim() >= 2 && b.ndim() >= 2);
|
||||||
|
assert_eq!(a.ndim(), b.ndim());
|
||||||
|
assert!(a.is_contiguous() && b.is_contiguous());
|
||||||
|
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||||
|
assert_eq!(a.dtype(), b.dtype());
|
||||||
|
|
||||||
|
let ndim = a.ndim();
|
||||||
|
let m = a.shape()[ndim - 2];
|
||||||
|
let k = a.shape()[ndim - 1];
|
||||||
|
let n = b.shape()[ndim - 1];
|
||||||
|
assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch");
|
||||||
|
|
||||||
|
// Compute batch count from leading dimensions
|
||||||
|
let batch: usize = a.shape()[..ndim - 2].iter().product();
|
||||||
|
assert_eq!(
|
||||||
|
b.shape()[..ndim - 2].iter().product::<usize>(),
|
||||||
|
batch,
|
||||||
|
"batch dimensions mismatch"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
|
||||||
|
out_shape.push(m);
|
||||||
|
out_shape.push(n);
|
||||||
|
let c = Tensor::zeros(&out_shape, a.dtype(), a.device());
|
||||||
|
|
||||||
|
let dtype = a.dtype();
|
||||||
|
let (a_type, b_type, c_type) = match dtype {
|
||||||
|
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||||
|
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||||
|
_ => panic!("unsupported dtype for batched matmul"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let alpha = 1.0f32;
|
||||||
|
let beta = 0.0f32;
|
||||||
|
// cuBLAS strides are in elements (not bytes)
|
||||||
|
let stride_a = (m * k) as i64;
|
||||||
|
let stride_b = (k * n) as i64;
|
||||||
|
let stride_c = (m * n) as i64;
|
||||||
|
|
||||||
|
let ctx = CublasContext::new().unwrap();
|
||||||
|
unsafe {
|
||||||
|
cublasSetStream_v2(ctx.handle, std::ptr::null_mut());
|
||||||
|
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||||
|
error::check(cublasGemmStridedBatchedEx(
|
||||||
|
ctx.handle,
|
||||||
|
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||||
|
n as i32, m as i32, k as i32,
|
||||||
|
&alpha as *const f32 as *const c_void,
|
||||||
|
b.data_ptr() as _, b_type, n as i32, stride_b,
|
||||||
|
a.data_ptr() as _, a_type, k as i32, stride_a,
|
||||||
|
&beta as *const f32 as *const c_void,
|
||||||
|
c.data_ptr() as *mut c_void, c_type, n as i32, stride_c,
|
||||||
|
batch as i32,
|
||||||
|
CUBLAS_COMPUTE_32F,
|
||||||
|
-1,
|
||||||
|
)).expect("cuBLAS batched GEMM failed");
|
||||||
|
}
|
||||||
|
xserv_cuda::device::synchronize().unwrap();
|
||||||
|
c
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
pub mod activation;
|
pub mod activation;
|
||||||
|
pub mod attention;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod gemm;
|
pub mod gemm;
|
||||||
pub mod layernorm;
|
pub mod layernorm;
|
||||||
@@ -6,9 +7,10 @@ pub mod rmsnorm;
|
|||||||
pub mod rope;
|
pub mod rope;
|
||||||
pub mod softmax;
|
pub mod softmax;
|
||||||
|
|
||||||
pub use activation::{gelu, silu};
|
pub use activation::{gelu, scale, silu};
|
||||||
|
pub use attention::attention;
|
||||||
pub use embedding::embedding;
|
pub use embedding::embedding;
|
||||||
pub use gemm::{matmul, GemmBackend};
|
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||||
pub use layernorm::layernorm;
|
pub use layernorm::layernorm;
|
||||||
pub use rmsnorm::rmsnorm;
|
pub use rmsnorm::rmsnorm;
|
||||||
pub use rope::{rope_inplace, RopeCache};
|
pub use rope::{rope_inplace, RopeCache};
|
||||||
|
|||||||
187
crates/xserv-kernels/tests/attention_test.rs
Normal file
187
crates/xserv-kernels/tests/attention_test.rs
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
use xserv_kernels::*;
|
||||||
|
use xserv_tensor::{Device, Tensor};
|
||||||
|
|
||||||
|
/// Selects CUDA device 0 for the current test; panics if that fails.
fn init() {
    xserv_cuda::device::set_device(0).unwrap();
}
|
||||||
|
|
||||||
|
/// Reference attention on the CPU, used to validate the GPU implementation.
///
/// `q` is laid out as `[batch, heads, q_len, head_dim]`, `k` and `v` as
/// `[batch, heads, kv_len, head_dim]`, all row-major. Returns
/// `softmax(Q @ K^T / sqrt(head_dim)) @ V` flattened as
/// `[batch, heads, q_len, head_dim]`. With `causal`, entry `(i, j)` is masked
/// when `j > i + (kv_len - q_len)`.
fn cpu_attention(q: &[f32], k: &[f32], v: &[f32],
                 batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize,
                 causal: bool) -> Vec<f32> {
    let inv_sqrt_dim = 1.0 / (head_dim as f32).sqrt();
    let mut result = vec![0.0f32; batch * heads * q_len * head_dim];

    // (batch, head) pairs are independent; walk them as one flat index.
    for head in 0..batch * heads {
        let q_base = head * q_len * head_dim;
        let kv_base = head * kv_len * head_dim;

        // Scaled dot-product scores, row-major [q_len, kv_len].
        let mut weights = vec![0.0f32; q_len * kv_len];
        for i in 0..q_len {
            for j in 0..kv_len {
                let dot = (0..head_dim).fold(0.0f32, |acc, d| {
                    acc + q[q_base + i * head_dim + d] * k[kv_base + j * head_dim + d]
                });
                weights[i * kv_len + j] = dot * inv_sqrt_dim;
            }
        }

        // Causal mask: hide every key that lies after the query's position.
        if causal {
            let offset = kv_len - q_len;
            for i in 0..q_len {
                let row = &mut weights[i * kv_len..(i + 1) * kv_len];
                for (j, w) in row.iter_mut().enumerate() {
                    if j > i + offset {
                        *w = f32::NEG_INFINITY;
                    }
                }
            }
        }

        // Row-wise softmax, subtracting the row maximum first.
        for i in 0..q_len {
            let row = &mut weights[i * kv_len..(i + 1) * kv_len];
            let peak = row.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
            let mut denom = 0.0f32;
            for w in row.iter_mut() {
                *w = (*w - peak).exp();
                denom += *w;
            }
            row.iter_mut().for_each(|w| *w /= denom);
        }

        // Weighted sum of the value rows.
        for i in 0..q_len {
            for d in 0..head_dim {
                result[q_base + i * head_dim + d] = (0..kv_len).fold(0.0f32, |acc, j| {
                    acc + weights[i * kv_len + j] * v[kv_base + j * head_dim + d]
                });
            }
        }
    }
    result
}
|
||||||
|
|
||||||
|
/// Asserts that `a` and `b` have equal length and agree element-wise within
/// the absolute tolerance `atol`, then prints the largest error seen.
///
/// `name` prefixes every failure message and the printed summary.
fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) {
    assert_eq!(a.len(), b.len(), "{name}: length mismatch");

    let max_err = a
        .iter()
        .zip(b)
        .enumerate()
        .fold(0.0f32, |worst, (i, (x, y))| {
            let err = (x - y).abs();
            assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}");
            worst.max(err)
        });

    println!("{name}: max_err = {max_err:.6e}");
}
|
||||||
|
|
||||||
|
/// Builds `n` deterministic test values that cycle with period 17 through
/// `-0.4, -0.35, ..., 0.4`.
fn make_data(n: usize) -> Vec<f32> {
    let mut data = Vec::with_capacity(n);
    for i in 0..n {
        let phase = (i % 17) as f32;
        data.push((phase - 8.0) * 0.05);
    }
    data
}
|
||||||
|
|
||||||
|
/// Batched GEMM against a CPU reference, checked for every (batch, head)
/// matrix. Checking only the first one would not notice wrong batch strides.
#[test]
fn test_batched_matmul() {
    init();
    let batch = 4;
    let heads = 8;
    let m = 32;
    let k = 64;
    let n = 32;

    let a_data = make_data(batch * heads * m * k);
    let b_data = make_data(batch * heads * k * n);

    let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0));
    let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0));
    let c = batched_matmul(&a, &b).to_device(Device::Cpu);

    assert_eq!(c.shape(), &[batch, heads, m, n]);

    let result = c.as_slice::<f32>();
    for bh in 0..batch * heads {
        // Row-major sub-matrices belonging to this batch entry.
        let a_cpu = &a_data[bh * m * k..(bh + 1) * m * k];
        let b_cpu = &b_data[bh * k * n..(bh + 1) * k * n];

        let mut expected = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut s = 0.0f32;
                for kk in 0..k {
                    s += a_cpu[i * k + kk] * b_cpu[kk * n + j];
                }
                expected[i * n + j] = s;
            }
        }

        check_close(
            &result[bh * m * n..(bh + 1) * m * n],
            &expected,
            1e-3,
            &format!("batched_matmul[{bh}]"),
        );
    }
}
|
||||||
|
|
||||||
|
/// Unmasked attention on a small problem, compared with the CPU reference.
#[test]
fn test_attention_no_causal() {
    init();
    let (b, h, s, d) = (1, 2, 8, 16);
    let numel = b * h * s * d;
    let dims = [b, h, s, d];

    let q_data = make_data(numel);
    let k_data = make_data(numel);
    let v_data = make_data(numel);
    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false);

    // Upload the three operands with the same shape.
    let [q, k, v] = [&q_data, &k_data, &v_data]
        .map(|host| Tensor::from_slice(host, &dims).to_device(Device::Cuda(0)));

    let out = attention(&q, &k, &v, false).to_device(Device::Cpu);
    check_close(out.as_slice::<f32>(), &expected, 1e-4, "attention_no_causal");
}
|
||||||
|
|
||||||
|
/// Causal attention on a small problem, compared with the CPU reference.
#[test]
fn test_attention_causal() {
    init();
    let (b, h, s, d) = (1, 2, 16, 32);
    let numel = b * h * s * d;
    let dims = [b, h, s, d];

    let q_data = make_data(numel);
    let k_data = make_data(numel);
    let v_data = make_data(numel);
    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);

    // Upload the three operands with the same shape.
    let [q, k, v] = [&q_data, &k_data, &v_data]
        .map(|host| Tensor::from_slice(host, &dims).to_device(Device::Cuda(0)));

    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
    check_close(out.as_slice::<f32>(), &expected, 1e-3, "attention_causal");
}
|
||||||
|
|
||||||
|
/// Causal attention with several batches and heads, compared with the CPU
/// reference.
#[test]
fn test_attention_causal_larger() {
    init();
    let (b, h, s, d) = (2, 4, 64, 64);
    let numel = b * h * s * d;
    let dims = [b, h, s, d];

    let q_data = make_data(numel);
    let k_data = make_data(numel);
    let v_data = make_data(numel);
    let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true);

    // Upload the three operands with the same shape.
    let [q, k, v] = [&q_data, &k_data, &v_data]
        .map(|host| Tensor::from_slice(host, &dims).to_device(Device::Cuda(0)));

    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);
    check_close(out.as_slice::<f32>(), &expected, 1e-2, "attention_causal_larger");
}
|
||||||
|
|
||||||
|
/// Semantic check of the causal mask: query position 0 may only attend to key
/// position 0, so the first output row must reproduce the first row of V.
#[test]
fn test_attention_causal_first_row_sees_only_first_token() {
    init();
    let (b, h, s, d) = (1, 1, 4, 8);
    let dims = [b, h, s, d];

    let q_data = make_data(b * h * s * d);
    let k_data = make_data(b * h * s * d);
    // Only the first V row is nonzero (all ones).
    let mut v_data = vec![0.0f32; s * d];
    v_data[..d].fill(1.0);

    let [q, k, v] = [&q_data, &k_data, &v_data]
        .map(|host| Tensor::from_slice(host, &dims).to_device(Device::Cuda(0)));
    let out = attention(&q, &k, &v, true).to_device(Device::Cpu);

    // Row 0 puts all of its attention weight on token 0, hence
    // output[0] should be V[0] = [1, 1, ..., 1].
    let result = out.as_slice::<f32>();
    for (i, &got) in result[..d].iter().enumerate() {
        assert!((got - 1.0).abs() < 1e-5,
            "first row should equal V[0], got {} at dim {}", got, i);
    }
}
|
||||||
@@ -137,8 +137,13 @@ impl Tensor {
|
|||||||
if self.is_contiguous() {
|
if self.is_contiguous() {
|
||||||
return self.clone();
|
return self.clone();
|
||||||
}
|
}
|
||||||
// Copy to contiguous layout on CPU
|
// For GPU tensors: round-trip through CPU (correct but slow).
|
||||||
assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported");
|
// TODO: write a GPU contiguous-copy kernel for performance.
|
||||||
|
if matches!(self.device(), Device::Cuda(_)) {
|
||||||
|
let cpu = self.to_device(Device::Cpu);
|
||||||
|
let contig = cpu.contiguous();
|
||||||
|
return contig.to_device(self.device());
|
||||||
|
}
|
||||||
let numel = self.numel();
|
let numel = self.numel();
|
||||||
let elem_size = self.dtype.size_bytes();
|
let elem_size = self.dtype.size_bytes();
|
||||||
let src_bytes = self.storage.as_cpu_bytes();
|
let src_bytes = self.storage.as_cpu_bytes();
|
||||||
@@ -173,17 +178,18 @@ impl Tensor {
|
|||||||
// --- Device transfer ---
|
// --- Device transfer ---
|
||||||
|
|
||||||
pub fn to_device(&self, device: Device) -> Self {
|
pub fn to_device(&self, device: Device) -> Self {
|
||||||
let t = if self.is_contiguous() { self.clone() } else { self.contiguous() };
|
if self.device() == device {
|
||||||
if t.device() == device {
|
return self.clone();
|
||||||
return t;
|
|
||||||
}
|
}
|
||||||
let new_storage = t.storage.to_device(device).expect("device transfer failed");
|
// Transfer the raw storage (preserving strides/offset).
|
||||||
|
// Non-contiguous layout is preserved — the user can call contiguous() after.
|
||||||
|
let new_storage = self.storage.to_device(device).expect("device transfer failed");
|
||||||
Self {
|
Self {
|
||||||
storage: new_storage,
|
storage: new_storage,
|
||||||
shape: t.shape,
|
shape: self.shape.clone(),
|
||||||
strides: t.strides,
|
strides: self.strides.clone(),
|
||||||
offset: 0,
|
offset: self.offset,
|
||||||
dtype: t.dtype,
|
dtype: self.dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,16 @@ __global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
|||||||
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
|
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Element-wise scale: out[i] = x[i] * scale for i in [0, n).
// One thread per element; surplus threads in the last block do nothing.
__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) {
    // Keep the index unsigned: when n is within one block of INT_MAX the global
    // thread index can exceed INT_MAX, and narrowing it to int would wrap
    // negative and slip past the bounds check.
    const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (n > 0 && idx < (unsigned int)n) out[idx] = x[idx] * scale;
}
|
||||||
|
|
||||||
|
// Element-wise scale for bfloat16: each element is widened to float,
// multiplied by `scale`, and rounded back to bfloat16.
__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) {
    // Keep the index unsigned: when n is within one block of INT_MAX the global
    // thread index can exceed INT_MAX, and narrowing it to int would wrap
    // negative and slip past the bounds check.
    const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (n > 0 && idx < (unsigned int)n) {
        out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale);
    }
}
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
|
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
|
||||||
@@ -63,4 +73,18 @@ void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
|||||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Launches scale_f32_kernel over n floats on `stream` (an opaque cudaStream_t).
// Does nothing for n <= 0.
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration.
    if (n <= 0) return;
    const int block = 256;
    // Ceiling division written so that n close to INT_MAX cannot overflow.
    const int grid = n / block + (n % block != 0 ? 1 : 0);
    scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (float*)out, scale, n);
}
|
||||||
|
|
||||||
|
// Launches scale_bf16_kernel over n bfloat16 values on `stream` (an opaque
// cudaStream_t). Does nothing for n <= 0.
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
    // A zero-sized grid is an invalid launch configuration.
    if (n <= 0) return;
    const int block = 256;
    // Ceiling division written so that n close to INT_MAX cannot overflow.
    const int grid = n / block + (n % block != 0 ? 1 : 0);
    scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
53
csrc/attention/causal_mask.cu
Normal file
53
csrc/attention/causal_mask.cu
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#include <cuda_bf16.h>
|
||||||
|
|
||||||
|
// Apply causal mask: overwrite scores[b][row][col] where col > row + offset.
// offset is used for KV cache: when the query starts at position `offset`,
// row r is allowed to attend to positions [0, offset + r].
// scores: dense row-major [batch, rows, cols] (batch = flattened batch x heads)
//
// Thread mapping: x covers columns; y and z stride over rows and batch entries,
// so sizes beyond the 65535 limit of gridDim.y / gridDim.z are still covered.
// The flat element index is computed in size_t because batch * rows * cols can
// exceed the range of a 32-bit int.

__global__ void causal_mask_f32(
    float* __restrict__ scores,
    int batch, int rows, int cols, int offset
) {
    const long long col = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= cols) return;

    for (unsigned int b = blockIdx.z; b < (unsigned int)batch; b += gridDim.z) {
        for (unsigned int row = blockIdx.y; row < (unsigned int)rows; row += gridDim.y) {
            if (col > (long long)row + offset) {
                const size_t idx = ((size_t)b * rows + row) * cols + (size_t)col;
                scores[idx] = -INFINITY;
            }
        }
    }
}

__global__ void causal_mask_bf16(
    __nv_bfloat16* __restrict__ scores,
    int batch, int rows, int cols, int offset
) {
    const long long col = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= cols) return;

    for (unsigned int b = blockIdx.z; b < (unsigned int)batch; b += gridDim.z) {
        for (unsigned int row = blockIdx.y; row < (unsigned int)rows; row += gridDim.y) {
            if (col > (long long)row + offset) {
                const size_t idx = ((size_t)b * rows + row) * cols + (size_t)col;
                // The BF16 path keeps the original large finite negative value
                // rather than -inf.
                // NOTE(review): bfloat16 can encode -inf; confirm the BF16
                // softmax handles it before switching to match the F32 path.
                scores[idx] = __float2bfloat16(-1e9f);
            }
        }
    }
}

// Threads per block along the column axis.
static const int kCausalMaskBlock = 256;
// Upper bound for gridDim.y and gridDim.z.
static const int kCausalMaskMaxGridYZ = 65535;

// Builds the launch grid: enough x-blocks to cover all columns, with y/z
// clamped to the hardware limit (the kernels stride over the remainder).
static dim3 causal_mask_grid(int batch, int rows, int cols) {
    const unsigned int gx =
        (unsigned int)(cols / kCausalMaskBlock + (cols % kCausalMaskBlock != 0 ? 1 : 0));
    const unsigned int gy =
        (unsigned int)(rows < kCausalMaskMaxGridYZ ? rows : kCausalMaskMaxGridYZ);
    const unsigned int gz =
        (unsigned int)(batch < kCausalMaskMaxGridYZ ? batch : kCausalMaskMaxGridYZ);
    return dim3(gx, gy, gz);
}

extern "C" {

// Masks an F32 score buffer in place. Does nothing if any dimension is <= 0.
void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
                            int offset, void* stream) {
    // A grid with a zero dimension is an invalid launch configuration.
    if (batch <= 0 || rows <= 0 || cols <= 0) return;
    causal_mask_f32<<<causal_mask_grid(batch, rows, cols), kCausalMaskBlock, 0,
                      (cudaStream_t)stream>>>(
        (float*)scores, batch, rows, cols, offset);
}

// Masks a BF16 score buffer in place. Does nothing if any dimension is <= 0.
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
                             int offset, void* stream) {
    // A grid with a zero dimension is an invalid launch configuration.
    if (batch <= 0 || rows <= 0 || cols <= 0) return;
    causal_mask_bf16<<<causal_mask_grid(batch, rows, cols), kCausalMaskBlock, 0,
                       (cudaStream_t)stream>>>(
        (__nv_bfloat16*)scores, batch, rows, cols, offset);
}

}
|
||||||
92
docs/05-attention.md
Normal file
92
docs/05-attention.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# Phase 5: Naive Attention Kernel — Design Document
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
实现标准 Multi-Head Attention(不做 Flash/Paged 优化),用组合式方法(GEMM + Softmax)完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。
|
||||||
|
|
||||||
|
## 计算流程
|
||||||
|
|
||||||
|
```
|
||||||
|
Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
|
||||||
|
B=batch, H=num_heads, S=seq_len, D=head_dim
|
||||||
|
|
||||||
|
1. scores = Q @ K^T / sqrt(D) → [B, H, S, S]
|
||||||
|
2. scores += causal_mask → 上三角置为 -inf
|
||||||
|
3. weights = softmax(scores, dim=-1) → [B, H, S, S]
|
||||||
|
4. output = weights @ V → [B, H, S, D]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 设计选择
|
||||||
|
|
||||||
|
### 组合式实现(Phase 3 GEMM + Phase 4 Softmax)
|
||||||
|
|
||||||
|
不写新的 fused CUDA kernel,而是复用已有的 matmul 和 softmax:
|
||||||
|
- `scores = batched_matmul(Q, K^T)` — 需要支持 batched GEMM
|
||||||
|
- `masked_fill(scores, causal_mask, -inf)` — 新的逐元素 kernel
|
||||||
|
- `softmax(scores)` — 复用 Phase 4
|
||||||
|
- `output = batched_matmul(weights, V)` — 复用 batched GEMM
|
||||||
|
|
||||||
|
这意味着需要先扩展 matmul 支持 batched GEMM(cublasGemmStridedBatchedEx)。
|
||||||
|
|
||||||
|
### Causal Mask
|
||||||
|
|
||||||
|
不显式构造 mask 矩阵。写一个 kernel:
|
||||||
|
```
|
||||||
|
if (col > row + offset) score = -infinity
|
||||||
|
```
|
||||||
|
其中 offset 用于支持 KV cache 场景(decode 时 query 的 row 偏移)。
|
||||||
|
|
||||||
|
### Batched GEMM via cuBLAS
|
||||||
|
|
||||||
|
`cublasGemmStridedBatchedEx` 在一个 batch 维度上并行执行多个 GEMM:
|
||||||
|
```
|
||||||
|
C[b] = A[b] @ B[b] for b = 0..batch_count
|
||||||
|
stride_a = M * K, stride_b = K * N, stride_c = M * N
|
||||||
|
```
|
||||||
|
|
||||||
|
Attention 中 batch 维度 = B * H(batch_size × num_heads)。
|
||||||
|
|
||||||
|
## 文件布局
|
||||||
|
|
||||||
|
```
|
||||||
|
csrc/attention/
|
||||||
|
└── causal_mask.cu # causal mask fill kernel
|
||||||
|
|
||||||
|
crates/xserv-kernels/src/
|
||||||
|
├── gemm.rs # 扩展: batched_matmul
|
||||||
|
├── attention.rs # NEW: attention() + apply_causal_mask()(causal mask 的 Rust 封装)
|
||||||
|
└── activation.rs # 扩展: scale()(逐元素 scale kernel 的 Rust 封装)
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 设计
|
||||||
|
|
||||||
|
```rust
|
||||||
|
/// Multi-head attention (naive, materializes S×S scores).
|
||||||
|
/// q, k, v: [batch, num_heads, seq_len, head_dim]
|
||||||
|
/// Returns: [batch, num_heads, seq_len, head_dim]
|
||||||
|
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;
|
||||||
|
|
||||||
|
/// Batched matmul: A[b] @ B[b] for all b.
|
||||||
|
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||||||
|
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test Plan
|
||||||
|
|
||||||
|
- [x] batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
|
||||||
|
- [x] attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
|
||||||
|
- [x] attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
|
||||||
|
- [x] attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
|
||||||
|
- [x] causal mask 语义: position 0 只能看到 token 0,output[0] == V[0] → exact
|
||||||
|
|
||||||
|
## Takeaways
|
||||||
|
|
||||||
|
1. **`to_device` 不应强制 contiguous**:最初 `to_device()` 会先调 `contiguous()`,而 GPU 的 `contiguous()` 又调 `to_device(Cpu)`,导致无限递归栈溢出。修复:`to_device()` 直接传输 raw storage,保留 strides/offset,用户需要时自己调 `contiguous()`。GPU `contiguous()` 现在走 GPU→CPU→(在 CPU 上做 contiguous 拷贝)→GPU 的路径——正确但低效,Phase 15 需要写 GPU contiguous kernel。
|
||||||
|
|
||||||
|
2. **Batched GEMM via `cublasGemmStridedBatchedEx`**:row-major trick 同 Phase 3,额外参数是 stride(元素数,不是字节)。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 `elem_size`,cuBLAS 的 stride 单位是元素。
|
||||||
|
|
||||||
|
3. **Attention 的组合式实现足够验证正确性**:没有写 fused kernel,而是复用 `batched_matmul` + `scale` + `causal_mask` + `softmax`。精度极好(max_err < 1e-7),因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materialize(O(S²) 显存),Flash Attention 会解决。
|
||||||
|
|
||||||
|
4. **Scale kernel 的必要性**:原本想在 CPU 上做 scale(round-trip),但那太慢了。加了 `scale_f32/bf16` 逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。
|
||||||
|
|
||||||
|
5. **Causal mask 的 offset 设计**:`col > row + offset` 中的 offset 为 KV cache 场景预留。Prefill 时 q_len == kv_len,offset = 0,退化为标准的上三角 mask。Decode 时 Q 只有 1 行但 KV cache 有前 S 行,offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。该设计要求 kv_len ≥ q_len。
|
||||||
Reference in New Issue
Block a user