diff --git a/crates/xserv-kernels/build.rs b/crates/xserv-kernels/build.rs index aa99297..23298c3 100644 --- a/crates/xserv-kernels/build.rs +++ b/crates/xserv-kernels/build.rs @@ -22,6 +22,7 @@ fn main() { .file("../../csrc/reduce/softmax.cu") .file("../../csrc/embedding/embedding.cu") .file("../../csrc/embedding/rope.cu") + .file("../../csrc/attention/causal_mask.cu") .compile("xserv_kernels"); println!("cargo:rerun-if-changed=../../csrc/"); diff --git a/crates/xserv-kernels/src/activation.rs b/crates/xserv-kernels/src/activation.rs index ad08f41..09d8e2b 100644 --- a/crates/xserv-kernels/src/activation.rs +++ b/crates/xserv-kernels/src/activation.rs @@ -6,6 +6,8 @@ unsafe extern "C" { fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); + fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void); + fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void); } pub fn gelu(x: &Tensor) -> Tensor { @@ -39,3 +41,19 @@ pub fn silu(x: &Tensor) -> Tensor { xserv_cuda::device::synchronize().unwrap(); out } + +pub fn scale(x: &Tensor, scale_val: f32) -> Tensor { + assert!(x.is_contiguous()); + assert!(matches!(x.device(), Device::Cuda(_))); + let out = Tensor::zeros(x.shape(), x.dtype(), x.device()); + let n = x.numel() as i32; + unsafe { + match x.dtype() { + DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()), + DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()), + _ => panic!("unsupported dtype for scale"), + } + } + xserv_cuda::device::synchronize().unwrap(); + out +} diff --git a/crates/xserv-kernels/src/attention.rs b/crates/xserv-kernels/src/attention.rs new 
file mode 100644 index 0000000..3d1fb48 --- /dev/null +++ b/crates/xserv-kernels/src/attention.rs @@ -0,0 +1,77 @@ +use std::ffi::c_void; +use xserv_tensor::{DType, Tensor}; + +use crate::activation::scale; +use crate::gemm::batched_matmul; +use crate::softmax::softmax; + +unsafe extern "C" { + fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32, + offset: i32, stream: *mut c_void); + fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32, + offset: i32, stream: *mut c_void); +} + +fn apply_causal_mask(scores: &Tensor, offset: usize) { + let ndim = scores.ndim(); + let rows = scores.shape()[ndim - 2]; + let cols = scores.shape()[ndim - 1]; + let batch: usize = scores.shape()[..ndim - 2].iter().product(); + + unsafe { + match scores.dtype() { + DType::F32 => launch_causal_mask_f32( + scores.data_ptr() as *mut c_void, + batch as i32, rows as i32, cols as i32, offset as i32, + std::ptr::null_mut(), + ), + DType::BF16 => launch_causal_mask_bf16( + scores.data_ptr() as *mut c_void, + batch as i32, rows as i32, cols as i32, offset as i32, + std::ptr::null_mut(), + ), + _ => panic!("unsupported dtype for causal mask"), + } + } + xserv_cuda::device::synchronize().unwrap(); +} + +/// Multi-head attention (naive, materializes S×S score matrix). 
+/// +/// q, k, v: [batch, num_heads, seq_len, head_dim] — contiguous, on GPU +/// Returns: [batch, num_heads, seq_len, head_dim] +pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor { + assert_eq!(q.ndim(), 4); + assert_eq!(k.ndim(), 4); + assert_eq!(v.ndim(), 4); + assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous()); + + let batch = q.shape()[0]; + let num_heads = q.shape()[1]; + let q_len = q.shape()[2]; + let head_dim = q.shape()[3]; + let kv_len = k.shape()[2]; + + assert_eq!(k.shape(), &[batch, num_heads, kv_len, head_dim]); + assert_eq!(v.shape(), &[batch, num_heads, kv_len, head_dim]); + + // scores = Q @ K^T → [B, H, q_len, kv_len] + let k_t = k.transpose(2, 3).contiguous(); + let scores = batched_matmul(q, &k_t); + + // Scale by 1/sqrt(head_dim) + let scale_factor = 1.0 / (head_dim as f32).sqrt(); + let scaled_scores = scale(&scores, scale_factor); + + // Causal mask + if causal { + let offset = kv_len - q_len; + apply_causal_mask(&scaled_scores, offset); + } + + // Softmax + let weights = softmax(&scaled_scores); + + // output = weights @ V → [B, H, q_len, head_dim] + batched_matmul(&weights, v) +} diff --git a/crates/xserv-kernels/src/gemm.rs b/crates/xserv-kernels/src/gemm.rs index 533b2bc..7bc5a94 100644 --- a/crates/xserv-kernels/src/gemm.rs +++ b/crates/xserv-kernels/src/gemm.rs @@ -46,6 +46,19 @@ unsafe extern "C" { compute_type: i32, algo: i32, ) -> i32; + fn cublasGemmStridedBatchedEx( + handle: CublasHandle, + transa: i32, transb: i32, + m: i32, n: i32, k: i32, + alpha: *const c_void, + a: *const c_void, a_type: i32, lda: i32, stride_a: i64, + b: *const c_void, b_type: i32, ldb: i32, stride_b: i64, + beta: *const c_void, + c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64, + batch_count: i32, + compute_type: i32, + algo: i32, + ) -> i32; } pub struct CublasContext { @@ -149,3 +162,68 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { c } + +/// Batched matrix multiplication via 
cuBLAS: C[b] = A[b] @ B[b] +/// a: [..., M, K], b: [..., K, N] → [..., M, N] +/// Leading dimensions must match and tensors must be contiguous. +pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor { + assert!(a.ndim() >= 2 && b.ndim() >= 2); + assert_eq!(a.ndim(), b.ndim()); + assert!(a.is_contiguous() && b.is_contiguous()); + assert!(matches!(a.device(), Device::Cuda(_))); + assert_eq!(a.dtype(), b.dtype()); + + let ndim = a.ndim(); + let m = a.shape()[ndim - 2]; + let k = a.shape()[ndim - 1]; + let n = b.shape()[ndim - 1]; + assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch"); + + // Compute batch count from leading dimensions + let batch: usize = a.shape()[..ndim - 2].iter().product(); + assert_eq!( + b.shape()[..ndim - 2].iter().product::(), + batch, + "batch dimensions mismatch" + ); + + let mut out_shape: Vec = a.shape()[..ndim - 2].to_vec(); + out_shape.push(m); + out_shape.push(n); + let c = Tensor::zeros(&out_shape, a.dtype(), a.device()); + + let dtype = a.dtype(); + let (a_type, b_type, c_type) = match dtype { + DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F), + DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF), + _ => panic!("unsupported dtype for batched matmul"), + }; + + let alpha = 1.0f32; + let beta = 0.0f32; + // cuBLAS strides are in elements (not bytes) + let stride_a = (m * k) as i64; + let stride_b = (k * n) as i64; + let stride_c = (m * n) as i64; + + let ctx = CublasContext::new().unwrap(); + unsafe { + cublasSetStream_v2(ctx.handle, std::ptr::null_mut()); + // Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major) + error::check(cublasGemmStridedBatchedEx( + ctx.handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n as i32, m as i32, k as i32, + &alpha as *const f32 as *const c_void, + b.data_ptr() as _, b_type, n as i32, stride_b, + a.data_ptr() as _, a_type, k as i32, stride_a, + &beta as *const f32 as *const c_void, + c.data_ptr() as *mut c_void, c_type, n as i32, stride_c, + batch as i32, + CUBLAS_COMPUTE_32F, + -1, + 
)).expect("cuBLAS batched GEMM failed"); + } + xserv_cuda::device::synchronize().unwrap(); + c +} diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index 44a9db0..20cab23 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -1,4 +1,5 @@ pub mod activation; +pub mod attention; pub mod embedding; pub mod gemm; pub mod layernorm; @@ -6,9 +7,10 @@ pub mod rmsnorm; pub mod rope; pub mod softmax; -pub use activation::{gelu, silu}; +pub use activation::{gelu, scale, silu}; +pub use attention::attention; pub use embedding::embedding; -pub use gemm::{matmul, GemmBackend}; +pub use gemm::{batched_matmul, matmul, GemmBackend}; pub use layernorm::layernorm; pub use rmsnorm::rmsnorm; pub use rope::{rope_inplace, RopeCache}; diff --git a/crates/xserv-kernels/tests/attention_test.rs b/crates/xserv-kernels/tests/attention_test.rs new file mode 100644 index 0000000..744b58e --- /dev/null +++ b/crates/xserv-kernels/tests/attention_test.rs @@ -0,0 +1,187 @@ +use xserv_kernels::*; +use xserv_tensor::{Device, Tensor}; + +fn init() { xserv_cuda::device::set_device(0).unwrap(); } + +fn cpu_attention(q: &[f32], k: &[f32], v: &[f32], + batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize, + causal: bool) -> Vec { + let mut out = vec![0.0f32; batch * heads * q_len * head_dim]; + let scale = 1.0 / (head_dim as f32).sqrt(); + + for b in 0..batch { + for h in 0..heads { + // scores = Q @ K^T, scaled + let mut scores = vec![0.0f32; q_len * kv_len]; + for i in 0..q_len { + for j in 0..kv_len { + let mut s = 0.0f32; + for d in 0..head_dim { + let qi = q[((b * heads + h) * q_len + i) * head_dim + d]; + let ki = k[((b * heads + h) * kv_len + j) * head_dim + d]; + s += qi * ki; + } + scores[i * kv_len + j] = s * scale; + } + } + // causal mask + if causal { + let offset = kv_len - q_len; + for i in 0..q_len { + for j in 0..kv_len { + if j > i + offset { + scores[i * kv_len + j] = f32::NEG_INFINITY; + } + } + } + } + // 
softmax per row + for i in 0..q_len { + let row = &mut scores[i * kv_len..(i + 1) * kv_len]; + let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let mut sum = 0.0f32; + for v in row.iter_mut() { + *v = (*v - max).exp(); + sum += *v; + } + for v in row.iter_mut() { + *v /= sum; + } + } + // output = weights @ V + for i in 0..q_len { + for d in 0..head_dim { + let mut s = 0.0f32; + for j in 0..kv_len { + let w = scores[i * kv_len + j]; + let vi = v[((b * heads + h) * kv_len + j) * head_dim + d]; + s += w * vi; + } + out[((b * heads + h) * q_len + i) * head_dim + d] = s; + } + } + } + } + out +} + +fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) { + assert_eq!(a.len(), b.len(), "{name}: length mismatch"); + let mut max_err = 0.0f32; + for (i, (x, y)) in a.iter().zip(b).enumerate() { + let err = (x - y).abs(); + if err > max_err { max_err = err; } + assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"); + } + println!("{name}: max_err = {max_err:.6e}"); +} + +fn make_data(n: usize) -> Vec { + (0..n).map(|i| ((i % 17) as f32 - 8.0) * 0.05).collect() +} + +#[test] +fn test_batched_matmul() { + init(); + let batch = 4; + let heads = 8; + let m = 32; + let k = 64; + let n = 32; + + let a_data = make_data(batch * heads * m * k); + let b_data = make_data(batch * heads * k * n); + + let a = Tensor::from_slice(&a_data, &[batch, heads, m, k]).to_device(Device::Cuda(0)); + let b = Tensor::from_slice(&b_data, &[batch, heads, k, n]).to_device(Device::Cuda(0)); + let c = batched_matmul(&a, &b).to_device(Device::Cpu); + + assert_eq!(c.shape(), &[batch, heads, m, n]); + + // Verify one batch element + let a_cpu = &a_data[0..m * k]; + let b_cpu = &b_data[0..k * n]; + let mut expected = vec![0.0f32; m * n]; + for i in 0..m { + for j in 0..n { + let mut s = 0.0f32; + for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; } + expected[i * n + j] = s; + } + } + let result = c.as_slice::(); + check_close(&result[0..m * n], 
&expected, 1e-3, "batched_matmul[0]"); +} + +#[test] +fn test_attention_no_causal() { + init(); + let b = 1; let h = 2; let s = 8; let d = 16; + let q_data = make_data(b * h * s * d); + let k_data = make_data(b * h * s * d); + let v_data = make_data(b * h * s * d); + let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, false); + + let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let out = attention(&q, &k, &v, false).to_device(Device::Cpu); + check_close(out.as_slice::(), &expected, 1e-4, "attention_no_causal"); +} + +#[test] +fn test_attention_causal() { + init(); + let b = 1; let h = 2; let s = 16; let d = 32; + let q_data = make_data(b * h * s * d); + let k_data = make_data(b * h * s * d); + let v_data = make_data(b * h * s * d); + let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true); + + let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let out = attention(&q, &k, &v, true).to_device(Device::Cpu); + check_close(out.as_slice::(), &expected, 1e-3, "attention_causal"); +} + +#[test] +fn test_attention_causal_larger() { + init(); + let b = 2; let h = 4; let s = 64; let d = 64; + let q_data = make_data(b * h * s * d); + let k_data = make_data(b * h * s * d); + let v_data = make_data(b * h * s * d); + let expected = cpu_attention(&q_data, &k_data, &v_data, b, h, s, s, d, true); + + let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let out = attention(&q, 
&k, &v, true).to_device(Device::Cpu); + check_close(out.as_slice::(), &expected, 1e-2, "attention_causal_larger"); +} + +#[test] +fn test_attention_causal_first_row_sees_only_first_token() { + init(); + let b = 1; let h = 1; let s = 4; let d = 8; + let q_data = make_data(b * h * s * d); + let k_data = make_data(b * h * s * d); + let v_data: Vec = (0..s * d).map(|i| { + if i < d { 1.0 } else { 0.0 } // only first V row is nonzero + }).collect(); + + let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0)); + let out = attention(&q, &k, &v, true).to_device(Device::Cpu); + + // First row (position 0) with causal mask can only see position 0. + // So attention weight for position 0 is 1.0 for token 0 only. + // output[0] should be exactly V[0] = [1, 1, 1, ...1] + let result = out.as_slice::(); + for i in 0..d { + assert!((result[i] - 1.0).abs() < 1e-5, + "first row should equal V[0], got {} at dim {}", result[i], i); + } +} diff --git a/crates/xserv-tensor/src/tensor.rs b/crates/xserv-tensor/src/tensor.rs index e94b9f4..b8dad63 100644 --- a/crates/xserv-tensor/src/tensor.rs +++ b/crates/xserv-tensor/src/tensor.rs @@ -137,8 +137,13 @@ impl Tensor { if self.is_contiguous() { return self.clone(); } - // Copy to contiguous layout on CPU - assert_eq!(self.device(), Device::Cpu, "contiguous() on GPU not yet supported"); + // For GPU tensors: round-trip through CPU (correct but slow). + // TODO: write a GPU contiguous-copy kernel for performance. 
+ if matches!(self.device(), Device::Cuda(_)) { + let cpu = self.to_device(Device::Cpu); + let contig = cpu.contiguous(); + return contig.to_device(self.device()); + } let numel = self.numel(); let elem_size = self.dtype.size_bytes(); let src_bytes = self.storage.as_cpu_bytes(); @@ -173,17 +178,18 @@ impl Tensor { // --- Device transfer --- pub fn to_device(&self, device: Device) -> Self { - let t = if self.is_contiguous() { self.clone() } else { self.contiguous() }; - if t.device() == device { - return t; + if self.device() == device { + return self.clone(); } - let new_storage = t.storage.to_device(device).expect("device transfer failed"); + // Transfer the raw storage (preserving strides/offset). + // Non-contiguous layout is preserved — the user can call contiguous() after. + let new_storage = self.storage.to_device(device).expect("device transfer failed"); Self { storage: new_storage, - shape: t.shape, - strides: t.strides, - offset: 0, - dtype: t.dtype, + shape: self.shape.clone(), + strides: self.strides.clone(), + offset: self.offset, + dtype: self.dtype, } } diff --git a/csrc/activation/activations.cu b/csrc/activation/activations.cu index 8231a4a..3593dc1 100644 --- a/csrc/activation/activations.cu +++ b/csrc/activation/activations.cu @@ -35,6 +35,16 @@ __global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) { if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx]))); } +__global__ void scale_f32_kernel(const float* x, float* out, float scale, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) out[idx] = x[idx] * scale; +} + +__global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, float scale, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale); +} + extern "C" { void launch_gelu_f32(const void* x, void* out, int n, void* stream) { @@ -63,4 +73,18 @@ void launch_silu_bf16(const void* x, 
void* out, int n, void* stream) { (const __nv_bfloat16*)x, (__nv_bfloat16*)out, n); } +void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) { + int block = 256; + int grid = (n + block - 1) / block; + scale_f32_kernel<<>>( + (const float*)x, (float*)out, scale, n); +} + +void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) { + int block = 256; + int grid = (n + block - 1) / block; + scale_bf16_kernel<<>>( + (const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n); +} + } diff --git a/csrc/attention/causal_mask.cu b/csrc/attention/causal_mask.cu new file mode 100644 index 0000000..e6a60a3 --- /dev/null +++ b/csrc/attention/causal_mask.cu @@ -0,0 +1,53 @@ +#include + +// Apply causal mask: set scores[row][col] = -inf where col > row + offset. +// offset is used for KV cache: when query starts at position `offset`, +// we allow attending to positions [0, offset + row]. +// scores: [batch, rows, cols] (flattened batch×heads) + +__global__ void causal_mask_f32( + float* __restrict__ scores, + int rows, int cols, int offset +) { + int batch_idx = blockIdx.z; + int row = blockIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (col < cols && col > row + offset) { + scores[batch_idx * rows * cols + row * cols + col] = -INFINITY; + } +} + +__global__ void causal_mask_bf16( + __nv_bfloat16* __restrict__ scores, + int rows, int cols, int offset +) { + int batch_idx = blockIdx.z; + int row = blockIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (col < cols && col > row + offset) { + // BF16 doesn't have proper -inf literal, use a very large negative + scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-1e9f); + } +} + +extern "C" { + +void launch_causal_mask_f32(void* scores, int batch, int rows, int cols, + int offset, void* stream) { + int block = 256; + dim3 grid((cols + block - 1) / block, rows, batch); + causal_mask_f32<<>>( + (float*)scores, rows, cols, offset); +} + 
+void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols, + int offset, void* stream) { + int block = 256; + dim3 grid((cols + block - 1) / block, rows, batch); + causal_mask_bf16<<>>( + (__nv_bfloat16*)scores, rows, cols, offset); +} + +} diff --git a/docs/05-attention.md b/docs/05-attention.md new file mode 100644 index 0000000..69266d1 --- /dev/null +++ b/docs/05-attention.md @@ -0,0 +1,92 @@ +# Phase 5: Naive Attention Kernel — Design Document + +## Goal + +实现标准 Multi-Head Attention(不做 Flash/Paged 优化),用组合式方法(GEMM + Softmax)完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。 + +## 计算流程 + +``` +Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D] + B=batch, H=num_heads, S=seq_len, D=head_dim + +1. scores = Q @ K^T / sqrt(D) → [B, H, S, S] +2. scores += causal_mask → 上三角置为 -inf +3. weights = softmax(scores, dim=-1) → [B, H, S, S] +4. output = weights @ V → [B, H, S, D] +``` + +## 设计选择 + +### 组合式实现(Phase 3 GEMM + Phase 4 Softmax) + +不写新的 fused CUDA kernel,而是复用已有的 matmul 和 softmax: +- `scores = batched_matmul(Q, K^T)` — 需要支持 batched GEMM +- `masked_fill(scores, causal_mask, -inf)` — 新的逐元素 kernel +- `softmax(scores)` — 复用 Phase 4 +- `output = batched_matmul(weights, V)` — 复用 batched GEMM + +这意味着需要先扩展 matmul 支持 batched GEMM(cublasGemmStridedBatchedEx)。 + +### Causal Mask + +不显式构造 mask 矩阵。写一个 kernel: +``` +if (col > row + offset) score = -infinity +``` +其中 offset 用于支持 KV cache 场景(decode 时 query 的 row 偏移)。 + +### Batched GEMM via cuBLAS + +`cublasGemmStridedBatchedEx` 在一个 batch 维度上并行执行多个 GEMM: +``` +C[b] = A[b] @ B[b] for b = 0..batch_count +stride_a = M * K, stride_b = K * N, stride_c = M * N +``` + +Attention 中 batch 维度 = B * H(batch_size × num_heads)。 + +## 文件布局 + +``` +csrc/attention/ +└── causal_mask.cu # causal mask fill kernel + +crates/xserv-kernels/src/ +├── activation.rs # 扩展: scale +├── gemm.rs # 扩展: batched_matmul +└── attention.rs # NEW: attention(), apply_causal_mask() +``` + +## API 设计 + +```rust +/// Multi-head attention
(naive, materializes S×S scores). +/// q, k, v: [batch, num_heads, seq_len, head_dim] +/// Returns: [batch, num_heads, seq_len, head_dim] +pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor; + +/// Batched matmul: A[b] @ B[b] for all b. +/// a: [..., M, K], b: [..., K, N] → [..., M, N] +pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor; +``` + +## Test Plan + +- [x] batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7 +- [x] attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8 +- [x] attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8 +- [x] attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8 +- [x] causal mask 语义: position 0 只能看到 token 0,output[0] == V[0] → exact + +## Takeaways + +1. **`to_device` 不应强制 contiguous**:最初 `to_device()` 会先调 `contiguous()`,而 GPU 的 `contiguous()` 又调 `to_device(Cpu)`,导致无限递归栈溢出。修复:`to_device()` 直接传输 raw storage,保留 strides/offset,用户需要时自己调 `contiguous()`。GPU `contiguous()` 现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效,Phase 15 需要写 GPU contiguous kernel。 + +2. **Batched GEMM via `cublasGemmStridedBatchedEx`**:row-major trick 同 Phase 3,额外参数是 stride(元素数,不是字节)。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 `elem_size`,cuBLAS 的 stride 单位是元素。 + +3. **Attention 的组合式实现足够验证正确性**:没有写 fused kernel,而是复用 `batched_matmul` + `scale` + `causal_mask` + `softmax`。精度极好(max_err < 1e-7),因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materialize(O(S²) 显存),Flash Attention 会解决。 + +4. **Scale kernel 的必要性**:原本想在 CPU 上做 scale(round-trip),但那太慢了。加了 `scale_f32/bf16` 逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。 + +5. **Causal mask 的 offset 设计**:`col > row + offset` 中的 offset 为 KV cache 场景预留。Decode 时 Q 只有 1 行但 KV cache 有前 S 行,offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。