kernels/cuda: paged-attention kernel, dispatch, pinned host memory
CUDA layer for the paged-KV + swap work:
- csrc: new paged_attention.cu, plus updates across the attention/gemm/norm/activation/embedding/reduce kernels and common.cuh.
- xserv-kernels: new dispatch module and kernel-binding updates.
- xserv-cuda: cudaMallocHost/cudaFreeHost bindings + PinnedBuffer (host swap pool backing), and offset-aware D2H/H2D copies used to move KV blocks between the GPU pool and pinned host memory.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -39,6 +39,7 @@ unsafe extern "C" {
|
|||||||
stream: CudaStream,
|
stream: CudaStream,
|
||||||
) -> i32;
|
) -> i32;
|
||||||
pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32;
|
pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32;
|
||||||
|
pub fn cudaMemsetAsync(devptr: *mut u8, value: i32, count: usize, stream: CudaStream) -> i32;
|
||||||
|
|
||||||
// --- Stream ---
|
// --- Stream ---
|
||||||
pub fn cudaStreamCreate(stream: *mut CudaStream) -> i32;
|
pub fn cudaStreamCreate(stream: *mut CudaStream) -> i32;
|
||||||
|
|||||||
@@ -116,6 +116,56 @@ impl GpuBuffer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
|
||||||
|
pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> {
|
||||||
|
assert!(src_offset + count <= src.len);
|
||||||
|
assert!(dst_offset + count <= self.len);
|
||||||
|
error::check(unsafe {
|
||||||
|
ffi::cudaMemcpyAsync(
|
||||||
|
self.ptr.add(dst_offset),
|
||||||
|
src.ptr.add(src_offset),
|
||||||
|
count,
|
||||||
|
ffi::CUDA_MEMCPY_D2D,
|
||||||
|
stream.as_raw(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy `count` bytes from this GPU buffer at `src_offset` to a host slice (D2H).
|
||||||
|
pub fn copy_to_host_at(&self, dst: &mut [u8], src_offset: usize, count: usize) -> Result<()> {
|
||||||
|
assert!(src_offset + count <= self.len, "src range out of bounds");
|
||||||
|
assert!(count <= dst.len(), "host dst too small");
|
||||||
|
error::check(unsafe {
|
||||||
|
ffi::cudaMemcpy(
|
||||||
|
dst.as_mut_ptr(),
|
||||||
|
self.ptr.add(src_offset),
|
||||||
|
count,
|
||||||
|
ffi::CUDA_MEMCPY_D2H,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Copy `count` bytes from a host slice to this GPU buffer at `dst_offset` (H2D).
|
||||||
|
pub fn copy_from_host_at(&mut self, src: &[u8], dst_offset: usize, count: usize) -> Result<()> {
|
||||||
|
assert!(dst_offset + count <= self.len, "dst range out of bounds");
|
||||||
|
assert!(count <= src.len(), "host src too small");
|
||||||
|
error::check(unsafe {
|
||||||
|
ffi::cudaMemcpy(
|
||||||
|
self.ptr.add(dst_offset),
|
||||||
|
src.as_ptr(),
|
||||||
|
count,
|
||||||
|
ffi::CUDA_MEMCPY_H2D,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Async zero fill on stream.
|
||||||
|
pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
|
||||||
|
error::check(unsafe {
|
||||||
|
ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
|
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
|
||||||
/// Caller is responsible for eventually calling cudaFree.
|
/// Caller is responsible for eventually calling cudaFree.
|
||||||
pub fn into_raw(self) -> (*mut u8, usize) {
|
pub fn into_raw(self) -> (*mut u8, usize) {
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ fn main() {
|
|||||||
.file("../../csrc/attention/causal_mask.cu")
|
.file("../../csrc/attention/causal_mask.cu")
|
||||||
.file("../../csrc/embedding/transpose.cu")
|
.file("../../csrc/embedding/transpose.cu")
|
||||||
.file("../../csrc/attention/flash_attention.cu")
|
.file("../../csrc/attention/flash_attention.cu")
|
||||||
|
.file("../../csrc/attention/paged_attention.cu")
|
||||||
.compile("xserv_kernels");
|
.compile("xserv_kernels");
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=../../csrc/");
|
println!("cargo:rerun-if-changed=../../csrc/");
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c
|
|||||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||||
let n = x.numel() as i32;
|
let n = x.numel();
|
||||||
|
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||||
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||||
@@ -38,7 +40,9 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
|
|||||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||||
assert_eq!(a.dtype(), b.dtype());
|
assert_eq!(a.dtype(), b.dtype());
|
||||||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||||||
let n = a.numel() as i32;
|
let n = a.numel();
|
||||||
|
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||||
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
match a.dtype() {
|
match a.dtype() {
|
||||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||||
@@ -55,7 +59,9 @@ pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_si
|
|||||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||||
let n = x.numel() as i32;
|
let n = x.numel();
|
||||||
|
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||||
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||||
@@ -77,7 +83,9 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
|||||||
assert!(matches!(gate.device(), Device::Cuda(_)));
|
assert!(matches!(gate.device(), Device::Cuda(_)));
|
||||||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||||||
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
||||||
let n = gate.numel() as i32;
|
let n = gate.numel();
|
||||||
|
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||||
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
launch_silu_mul_bf16(
|
launch_silu_mul_bf16(
|
||||||
gate.data_ptr() as *const c_void,
|
gate.data_ptr() as *const c_void,
|
||||||
|
|||||||
@@ -22,6 +22,17 @@ unsafe extern "C" {
|
|||||||
kv_len: i32, head_dim: i32,
|
kv_len: i32, head_dim: i32,
|
||||||
scale: f32, causal: i32, stream: *mut c_void,
|
scale: f32, causal: i32, stream: *mut c_void,
|
||||||
);
|
);
|
||||||
|
fn launch_paged_decode_attention_bf16(
|
||||||
|
q: *const c_void,
|
||||||
|
k_cache: *const c_void,
|
||||||
|
v_cache: *const c_void,
|
||||||
|
o: *mut c_void,
|
||||||
|
block_tables: *const i32,
|
||||||
|
context_lens: *const i32,
|
||||||
|
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||||
|
head_dim: i32, max_blocks_per_seq: i32,
|
||||||
|
scale: f32, stream: *mut c_void,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||||
@@ -192,3 +203,58 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
|||||||
|
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Paged decode attention.
|
||||||
|
///
|
||||||
|
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||||||
|
/// k_cache_ptr / v_cache_ptr: pointers to [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16 pools
|
||||||
|
/// block_tables_ptr: i32 [batch, max_blocks_per_seq] (rows already arranged for this batch)
|
||||||
|
/// context_lens_ptr: i32 [batch]
|
||||||
|
///
|
||||||
|
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn paged_decode_attention(
|
||||||
|
q: &Tensor,
|
||||||
|
k_cache_ptr: *const c_void,
|
||||||
|
v_cache_ptr: *const c_void,
|
||||||
|
block_tables_ptr: *const i32,
|
||||||
|
context_lens_ptr: *const i32,
|
||||||
|
batch: usize,
|
||||||
|
num_q_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
max_blocks_per_seq: usize,
|
||||||
|
) -> Tensor {
|
||||||
|
assert_eq!(q.ndim(), 4);
|
||||||
|
assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
|
||||||
|
assert_eq!(q.dtype(), DType::BF16);
|
||||||
|
assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
|
||||||
|
assert!(head_dim <= 128);
|
||||||
|
|
||||||
|
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||||
|
let output = Tensor::empty(
|
||||||
|
&[batch, num_q_heads, 1, head_dim],
|
||||||
|
DType::BF16,
|
||||||
|
q.device(),
|
||||||
|
);
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
launch_paged_decode_attention_bf16(
|
||||||
|
q.data_ptr() as *const c_void,
|
||||||
|
k_cache_ptr,
|
||||||
|
v_cache_ptr,
|
||||||
|
output.data_ptr() as *mut c_void,
|
||||||
|
block_tables_ptr,
|
||||||
|
context_lens_ptr,
|
||||||
|
batch as i32,
|
||||||
|
num_q_heads as i32,
|
||||||
|
num_kv_heads as i32,
|
||||||
|
head_dim as i32,
|
||||||
|
max_blocks_per_seq as i32,
|
||||||
|
scale,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|||||||
118
crates/xserv-kernels/src/dispatch.rs
Normal file
118
crates/xserv-kernels/src/dispatch.rs
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
//! Low-level kernel dispatchers for CUDA Graph capture.
|
||||||
|
//! These functions write to pre-allocated output buffers and accept an explicit stream.
|
||||||
|
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
// FFI declarations for the CUDA launchers wrapped below. They repeat the
// declarations found in the individual kernel modules.
unsafe extern "C" {
    fn launch_rmsnorm_bf16(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    fn launch_add_rmsnorm_bf16(
        x: *const c_void,
        residual: *const c_void,
        gamma: *const c_void,
        normed_out: *mut c_void,
        sum_out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    fn launch_silu_mul_bf16(
        gate: *const c_void,
        up: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    fn launch_add_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    fn launch_embedding_bf16(
        table: *const c_void,
        token_ids: *const c_void,
        out: *mut c_void,
        num_tokens: i32,
        hidden_size: i32,
        vocab_size: i32,
        stream: *mut c_void,
    );

    fn launch_reshape_heads_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_merge_heads_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_transpose_hsd_to_shd_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_transpose_shd_to_hsd_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_rope_bf16(
        x: *mut c_void,
        cos_cache: *const c_void,
        sin_cache: *const c_void,
        positions: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );

    fn launch_decode_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );
}
|
||||||
|
|
||||||
|
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
|
||||||
|
pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||||
|
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
|
||||||
|
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw add_rmsnorm dispatch.
|
||||||
|
pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||||
|
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||||
|
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
|
||||||
|
launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw silu_mul dispatch.
|
||||||
|
pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
|
||||||
|
launch_silu_mul_bf16(gate, up, out, n, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw add dispatch.
|
||||||
|
pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
|
||||||
|
launch_add_bf16(a, b, out, n, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw embedding dispatch.
|
||||||
|
pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||||
|
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) {
|
||||||
|
launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw reshape_heads dispatch.
|
||||||
|
pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void,
|
||||||
|
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||||
|
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw merge_heads dispatch.
|
||||||
|
pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void,
|
||||||
|
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||||
|
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw transpose HSD->SHD dispatch.
|
||||||
|
pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void,
|
||||||
|
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||||
|
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw transpose SHD->HSD dispatch.
|
||||||
|
pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void,
|
||||||
|
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||||
|
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw RoPE dispatch (in-place).
|
||||||
|
pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||||
|
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||||
|
head_dim: i32, stream: *mut c_void) {
|
||||||
|
launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
|
||||||
|
pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void,
|
||||||
|
y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) {
|
||||||
|
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Raw decode attention dispatch.
|
||||||
|
pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||||
|
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||||
|
kv_len: i32, head_dim: i32,
|
||||||
|
scale: f32, stream: *mut c_void) {
|
||||||
|
launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
// cuBLAS FFI
/// Opaque cuBLAS handle, passed across the FFI boundary as a raw pointer.
pub type CublasHandle = *mut c_void;

unsafe extern "C" {
    fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
}

/// Set cuBLAS stream. Must be called before any cuBLAS operations during graph capture.
///
/// # Panics
/// Panics if `cublasSetStream_v2` returns a non-zero status. Ignoring the status
/// would leave later cuBLAS calls on whatever stream was bound before.
///
/// # Safety
/// `handle` must be a valid cuBLAS handle and `stream` a valid CUDA stream (or null
/// for the default stream); both are passed to cuBLAS unchanged.
pub unsafe fn set_cublas_stream(handle: CublasHandle, stream: *mut c_void) {
    let status = unsafe { cublasSetStream_v2(handle, stream) };
    // 0 is CUBLAS_STATUS_SUCCESS.
    assert_eq!(status, 0, "cublasSetStream_v2 failed with status {status}");
}
|
||||||
@@ -4,9 +4,9 @@ use xserv_tensor::{DType, Device, Tensor};
|
|||||||
|
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Embedding lookup: table[token_ids[i]] for each i.
|
/// Embedding lookup: table[token_ids[i]] for each i.
|
||||||
@@ -18,6 +18,9 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
|||||||
|
|
||||||
let hidden_size = table.shape()[1];
|
let hidden_size = table.shape()[1];
|
||||||
let num_tokens = token_ids.len();
|
let num_tokens = token_ids.len();
|
||||||
|
let vocab_size = table.shape()[0];
|
||||||
|
assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param");
|
||||||
|
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||||
|
|
||||||
// Upload token_ids to GPU
|
// Upload token_ids to GPU
|
||||||
let ids_bytes = unsafe {
|
let ids_bytes = unsafe {
|
||||||
@@ -29,6 +32,10 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
|||||||
let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
|
let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||||
|
|
||||||
|
for &tid in token_ids {
|
||||||
|
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
|
||||||
|
}
|
||||||
|
|
||||||
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
@@ -36,12 +43,12 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
|||||||
DType::F32 => launch_embedding_f32(
|
DType::F32 => launch_embedding_f32(
|
||||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_embedding_bf16(
|
DType::BF16 => launch_embedding_bf16(
|
||||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for embedding"),
|
_ => panic!("unsupported dtype for embedding"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ unsafe extern "C" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// --- FFI: cuBLAS ---
|
// --- FFI: cuBLAS ---
|
||||||
type CublasHandle = *mut c_void;
|
pub type CublasHandle = *mut c_void;
|
||||||
|
|
||||||
#[allow(non_upper_case_globals)]
|
#[allow(non_upper_case_globals)]
|
||||||
const CUBLAS_OP_N: i32 = 0;
|
const CUBLAS_OP_N: i32 = 0;
|
||||||
@@ -100,6 +100,13 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the thread-local cuBLAS handle for use with dispatch module.
|
||||||
|
pub fn cublas_handle() -> CublasHandle {
|
||||||
|
CUBLAS_CTX.with(|cell| {
|
||||||
|
cell.borrow().handle
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Matrix multiplication: C = A @ B
|
/// Matrix multiplication: C = A @ B
|
||||||
/// A: [M, K], B: [K, N], C: [M, N]
|
/// A: [M, K], B: [K, N], C: [M, N]
|
||||||
/// All tensors must be contiguous and on the same GPU.
|
/// All tensors must be contiguous and on the same GPU.
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
|||||||
assert_eq!(beta.shape(), &[hidden_size]);
|
assert_eq!(beta.shape(), &[hidden_size]);
|
||||||
|
|
||||||
let rows = x.numel() / hidden_size;
|
let rows = x.numel() / hidden_size;
|
||||||
|
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||||
|
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod activation;
|
pub mod activation;
|
||||||
pub mod attention;
|
pub mod attention;
|
||||||
|
pub mod dispatch;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod gemm;
|
pub mod gemm;
|
||||||
pub mod layernorm;
|
pub mod layernorm;
|
||||||
@@ -10,7 +11,7 @@ pub mod transpose;
|
|||||||
|
|
||||||
pub use activation::{add, gelu, mul, scale, silu, silu_mul};
|
pub use activation::{add, gelu, mul, scale, silu, silu_mul};
|
||||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||||
pub use attention::{attention, decode_attention, flash_attention};
|
pub use attention::{attention, decode_attention, flash_attention, paged_decode_attention};
|
||||||
pub use embedding::embedding;
|
pub use embedding::embedding;
|
||||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||||
pub use layernorm::layernorm;
|
pub use layernorm::layernorm;
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
|||||||
assert_eq!(x.dtype(), gamma.dtype());
|
assert_eq!(x.dtype(), gamma.dtype());
|
||||||
|
|
||||||
let rows = x.numel() / hidden_size;
|
let rows = x.numel() / hidden_size;
|
||||||
|
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||||
|
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
@@ -54,6 +56,8 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
|||||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||||
|
|
||||||
let rows = x.numel() / hidden_size;
|
let rows = x.numel() / hidden_size;
|
||||||
|
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||||
|
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||||
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||||
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
|||||||
|
|
||||||
let cols = *x.shape().last().unwrap();
|
let cols = *x.shape().last().unwrap();
|
||||||
let rows = x.numel() / cols;
|
let rows = x.numel() / cols;
|
||||||
|
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||||
|
assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param");
|
||||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
|
|||||||
@@ -74,10 +74,10 @@ fn cpu_rope(x: &mut [f32], positions: &[u32], num_heads: usize, head_dim: usize,
|
|||||||
let cos_val = angle.cos();
|
let cos_val = angle.cos();
|
||||||
let sin_val = angle.sin();
|
let sin_val = angle.sin();
|
||||||
let base = (t * num_heads + h) * head_dim;
|
let base = (t * num_heads + h) * head_dim;
|
||||||
let x0 = x[base + 2 * i];
|
let x0 = x[base + i];
|
||||||
let x1 = x[base + 2 * i + 1];
|
let x1 = x[base + i + half_dim];
|
||||||
x[base + 2 * i] = x0 * cos_val - x1 * sin_val;
|
x[base + i] = x0 * cos_val - x1 * sin_val;
|
||||||
x[base + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
|
x[base + i + half_dim] = x1 * cos_val + x0 * sin_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// GELU (tanh approximation):
|
// GELU (tanh approximation):
|
||||||
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||||
@@ -83,6 +84,7 @@ void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
|
|||||||
int block = 256;
|
int block = 256;
|
||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
|
gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
||||||
@@ -90,12 +92,14 @@ void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
|||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Host launcher for silu_f32 over `n` float elements.
// Launch layout: 1-D grid of 256-thread blocks, no dynamic shared memory, on `stream`
// (passed as void* and cast to cudaStream_t).
// `x` / `out` are forwarded to the kernel as float pointers (presumably device memory).
void launch_silu_f32(const void* x, void* out, int n, void* stream) {
    // Nothing to do for an empty tensor; a zero-block grid is an invalid launch configuration.
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div without forming `n + block - 1`, which overflows int when n is near INT_MAX.
    const int grid = n / block + (n % block != 0 ? 1 : 0);
    silu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
|
|
||||||
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
||||||
@@ -103,6 +107,7 @@ void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
|||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
|
void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream) {
|
||||||
@@ -110,6 +115,7 @@ void launch_scale_f32(const void* x, void* out, float scale, int n, void* stream
|
|||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
scale_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)x, (float*)out, scale, n);
|
(const float*)x, (float*)out, scale, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
|
void launch_scale_bf16(const void* x, void* out, float scale, int n, void* stream) {
|
||||||
@@ -117,6 +123,7 @@ void launch_scale_bf16(const void* x, void* out, float scale, int n, void* strea
|
|||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
scale_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
|
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) {
|
void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) {
|
||||||
@@ -124,24 +131,28 @@ void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream
|
|||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
add_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
add_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)a, (const float*)b, (float*)out, n);
|
(const float*)a, (const float*)b, (float*)out, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
// Launches add_bf16_kernel on `stream` with ceil(n / 256) blocks of 256 threads.
//   a, b   : device pointers, reinterpreted as const __nv_bfloat16*
//   out    : device pointer, reinterpreted as __nv_bfloat16*
//   n      : element count forwarded to the kernel
//   stream : cudaStream_t passed as an opaque pointer
// The launch is asynchronous; launch errors are reported via CUDA_CHECK_LAST_ERROR().
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* stream) {
    // An empty input would yield grid == 0, which CUDA rejects as an invalid
    // launch configuration; there is nothing to do in that case anyway.
    if (n <= 0) return;

    int block = 256;
    int grid = (n + block - 1) / block;
    add_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
// Launches mul_f32_kernel on `stream` with ceil(n / 256) blocks of 256 threads.
//   a, b   : device pointers, reinterpreted as const float*
//   out    : device pointer, reinterpreted as float*
//   n      : element count forwarded to the kernel
//   stream : cudaStream_t passed as an opaque pointer
// The launch is asynchronous; launch errors are reported via CUDA_CHECK_LAST_ERROR().
void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) {
    // An empty input would yield grid == 0, which CUDA rejects as an invalid
    // launch configuration; there is nothing to do in that case anyway.
    if (n <= 0) return;

    int block = 256;
    int grid = (n + block - 1) / block;
    mul_f32_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const float*)a, (const float*)b, (float*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
// Launches mul_bf16_kernel on `stream` with ceil(n / 256) blocks of 256 threads.
//   a, b   : device pointers, reinterpreted as const __nv_bfloat16*
//   out    : device pointer, reinterpreted as __nv_bfloat16*
//   n      : element count forwarded to the kernel
//   stream : cudaStream_t passed as an opaque pointer
// The launch is asynchronous; launch errors are reported via CUDA_CHECK_LAST_ERROR().
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* stream) {
    // An empty input would yield grid == 0, which CUDA rejects as an invalid
    // launch configuration; there is nothing to do in that case anyway.
    if (n <= 0) return;

    int block = 256;
    int grid = (n + block - 1) / block;
    mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
|
|
||||||
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
|
void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, void* stream) {
|
||||||
@@ -149,6 +160,7 @@ void launch_silu_mul_bf16(const void* gate, const void* up, void* out, int n, vo
|
|||||||
int grid = (n + block - 1) / block;
|
int grid = (n + block - 1) / block;
|
||||||
silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
silu_mul_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
|
(const __nv_bfloat16*)gate, (const __nv_bfloat16*)up, (__nv_bfloat16*)out, n);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Apply causal mask: set scores[row][col] = -inf where col > row + offset.
|
// Apply causal mask: set scores[row][col] = -inf where col > row + offset.
|
||||||
// offset is used for KV cache: when query starts at position `offset`,
|
// offset is used for KV cache: when query starts at position `offset`,
|
||||||
@@ -39,6 +40,7 @@ void launch_causal_mask_f32(void* scores, int batch, int rows, int cols,
|
|||||||
dim3 grid((cols + block - 1) / block, rows, batch);
|
dim3 grid((cols + block - 1) / block, rows, batch);
|
||||||
causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
causal_mask_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(float*)scores, rows, cols, offset);
|
(float*)scores, rows, cols, offset);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
|
void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
|
||||||
@@ -47,6 +49,7 @@ void launch_causal_mask_bf16(void* scores, int batch, int rows, int cols,
|
|||||||
dim3 grid((cols + block - 1) / block, rows, batch);
|
dim3 grid((cols + block - 1) / block, rows, batch);
|
||||||
causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
causal_mask_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(__nv_bfloat16*)scores, rows, cols, offset);
|
(__nv_bfloat16*)scores, rows, cols, offset);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
|
// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
|
||||||
//
|
//
|
||||||
@@ -391,6 +392,7 @@ void launch_flash_attention_bf16(
|
|||||||
q_len, kv_len, head_dim,
|
q_len, kv_len, head_dim,
|
||||||
scale, causal
|
scale, causal
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_decode_attention_bf16(
|
void launch_decode_attention_bf16(
|
||||||
@@ -411,6 +413,7 @@ void launch_decode_attention_bf16(
|
|||||||
kv_len, head_dim,
|
kv_len, head_dim,
|
||||||
scale
|
scale
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
215
csrc/attention/paged_attention.cu
Normal file
215
csrc/attention/paged_attention.cu
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
|
// Paged decode attention kernel for BF16 with FP32 accumulation.
|
||||||
|
//
|
||||||
|
// Reads K/V from a paged pool indexed by a per-sequence block table.
|
||||||
|
// One CUDA block per (sequence, q_head). Each block streams over the
|
||||||
|
// sequence's KV positions and accumulates attention output via online
|
||||||
|
// softmax.
|
||||||
|
//
|
||||||
|
// Layouts:
|
||||||
|
// Q [batch, num_q_heads, 1, head_dim] BF16
|
||||||
|
// K_cache [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16
|
||||||
|
// V_cache same
|
||||||
|
// block_tables [max_seqs, max_blocks_per_seq] int32
|
||||||
|
// — the i-th sequence in this launch reads row
|
||||||
|
// block_tables[seq_slot[i] * stride + ...].
|
||||||
|
// For simplicity the launch passes a packed row table
|
||||||
|
// [batch, max_blocks_per_seq] (already gathered for the
|
||||||
|
//                   active batch) so we just index by the sequence index (blockIdx.y).
|
||||||
|
// context_lens [batch] int32 — number of valid tokens per sequence.
|
||||||
|
//
|
||||||
|
// One CUDA block: 256 threads, head_dim <= 128.
|
||||||
|
|
||||||
|
#define PAGED_BLOCK_SIZE 16
|
||||||
|
#define PAGED_THREADS 256
|
||||||
|
#define PAGED_HEAD_DIM_MAX 128
|
||||||
|
|
||||||
|
// Paged decode attention for one query vector per (sequence, q_head).
// BF16 inputs/outputs, FP32 accumulation, online softmax.
//
// Launch contract (matches launch_paged_decode_attention_bf16):
//   grid  = (num_q_heads, batch): blockIdx.x is the query head, blockIdx.y the sequence.
//   block = PAGED_THREADS threads; the reductions below assume exactly that many.
//   Shared memory: static only (smem_max, smem_sum, smem_O).
//
// Preconditions: 0 < head_dim <= PAGED_HEAD_DIM_MAX and num_kv_heads > 0. If they are
// violated the whole block returns without writing O rather than overrunning the
// fixed-size per-thread and shared arrays.
//
// NOTE(review): each thread keeps q_reg[] and local_O[] (2 * PAGED_HEAD_DIM_MAX floats);
// this probably spills to local memory — worth profiling.
__global__ void paged_decode_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,   // [batch, max_blocks_per_seq]
    const int* __restrict__ context_lens,   // [batch]
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale
) {
    int seq_idx = blockIdx.y;   // batch dim
    int q_head  = blockIdx.x;   // 0 .. num_q_heads-1
    int tid     = threadIdx.x;

    // Defensive guards. The conditions depend only on kernel arguments, so the
    // whole block takes the same path and returning before any __syncthreads()
    // is safe. Without them q_reg/local_O/smem_O would be indexed out of bounds
    // and the GQA mapping below could divide by zero.
    if (head_dim <= 0 || head_dim > PAGED_HEAD_DIM_MAX || num_kv_heads <= 0) {
        return;
    }

    // Pointers
    const long long qo_offset = ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    const __nv_bfloat16* Q_ptr = Q + qo_offset;
    __nv_bfloat16* O_ptr = O + qo_offset;
    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;

    // Never walk past the end of this sequence's block-table row: a context
    // length larger than the table can describe is clamped to the table capacity.
    int kv_len = context_lens[seq_idx];
    const long long table_capacity = (long long)max_blocks_per_seq * PAGED_BLOCK_SIZE;
    if ((long long)kv_len > table_capacity) kv_len = (int)table_capacity;

    if (kv_len <= 0) {
        // Nothing to attend over; zero output for safety.
        for (int d = tid; d < head_dim; d += PAGED_THREADS) {
            O_ptr[d] = __float2bfloat16(0.0f);
        }
        return;
    }

    // GQA mapping: consecutive groups of q heads share one kv head. The clamps
    // keep the mapping in range when num_q_heads is not a multiple of
    // num_kv_heads (or is smaller than it); valid configurations are unaffected.
    int heads_per_group = num_q_heads / num_kv_heads;
    if (heads_per_group < 1) heads_per_group = 1;
    int kv_head = q_head / heads_per_group;
    if (kv_head >= num_kv_heads) kv_head = num_kv_heads - 1;

    // Load Q vector into a per-thread array (every thread needs the full vector).
    float q_reg[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }

    // Per-thread online softmax state.
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;

    // Element strides of the [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] cache.
    int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
    int kv_stride_head  = PAGED_BLOCK_SIZE * head_dim;

    // Each thread handles positions tid, tid+PAGED_THREADS, ...
    for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
        int logical_blk = pos / PAGED_BLOCK_SIZE;
        int slot_in_blk = pos % PAGED_BLOCK_SIZE;
        int phys_blk    = bt[logical_blk];

        // K and V share the same layout, so one offset serves both.
        const long long kv_offset = (long long)phys_blk * kv_stride_block
                                  + kv_head * kv_stride_head
                                  + slot_in_blk * head_dim;
        const __nv_bfloat16* K_pos = K_cache + kv_offset;
        const __nv_bfloat16* V_pos = V_cache + kv_offset;

        // dot(Q, K[pos]) * scale
        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        }
        float s = dot * scale;

        // Online softmax update: rescale what was accumulated under the old
        // running max, then add this position's contribution.
        float new_max    = fmaxf(local_max, s);
        float correction = expf(local_max - new_max);
        float p          = expf(s - new_max);

        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] *= correction;
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }

        local_max = new_max;
    }

    // ---- Block-level online softmax reduction ----
    __shared__ float smem_max[32];
    __shared__ float smem_sum[32];
    __shared__ float smem_O[PAGED_HEAD_DIM_MAX];

    int lane      = tid & 31;
    int warp_id   = tid >> 5;
    int num_warps = PAGED_THREADS >> 5;

    // Step 1: block-wide max
    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();

    float global_max;
    if (tid == 0) {
        global_max = smem_max[0];
        for (int i = 1; i < num_warps; i++)
            global_max = fmaxf(global_max, smem_max[i]);
        smem_max[0] = global_max;
    }
    __syncthreads();
    global_max = smem_max[0];

    // Step 2: rescale local state to global_max. Threads that processed no
    // position still have local_max == -INFINITY and contribute nothing.
    float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;

    // Step 3: reduce sum
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();

    float global_sum;
    if (tid == 0) {
        global_sum = 0.0f;
        for (int i = 0; i < num_warps; i++) global_sum += smem_sum[i];
        smem_sum[0] = global_sum;
    }
    __syncthreads();
    global_sum = smem_sum[0];

    // Step 4: reduce O across block, dim by dim. Each warp reduces with
    // shuffles, then lane 0 adds the warp's partial into shared memory.
    for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
    __syncthreads();

    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) atomicAdd(&smem_O[d], val);
    }
    __syncthreads();

    // Normalize by the softmax denominator and write the BF16 result.
    float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
    for (int d = tid; d < head_dim; d += PAGED_THREADS) {
        O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
    }
}
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
// Host entry point for paged_decode_attention_bf16_kernel.
//
//   Q, K_cache, V_cache, O : device pointers, reinterpreted as __nv_bfloat16
//   block_tables           : device pointer, packed [batch, max_blocks_per_seq] int32
//   context_lens           : device pointer, [batch] int32
//   batch, num_q_heads     : grid dimensions (grid = (num_q_heads, batch))
//   head_dim               : must satisfy 0 < head_dim <= PAGED_HEAD_DIM_MAX
//   num_kv_heads           : must be > 0 and divide num_q_heads (GQA grouping)
//   stream                 : cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous. Unsupported configurations are rejected without
// launching, because the kernel's fixed-size arrays and head mapping would
// otherwise be indexed out of bounds. Like CUDA_CHECK_LAST_ERROR, the diagnostic
// is printed in debug builds only.
void launch_paged_decode_attention_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    float scale, void* stream
) {
    // An empty grid is an invalid launch configuration; nothing to compute.
    if (batch <= 0 || num_q_heads <= 0) return;

    if (head_dim <= 0 || head_dim > PAGED_HEAD_DIM_MAX ||
        num_kv_heads <= 0 || (num_q_heads % num_kv_heads) != 0) {
#ifndef NDEBUG
        fprintf(stderr,
                "launch_paged_decode_attention_bf16: unsupported config "
                "(head_dim=%d, max %d; num_q_heads=%d, num_kv_heads=%d)\n",
                head_dim, PAGED_HEAD_DIM_MAX, num_q_heads, num_kv_heads);
#endif
        return;
    }

    dim3 grid(num_q_heads, batch);
    int block = PAGED_THREADS;

    paged_decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -48,3 +48,17 @@ __device__ __forceinline__ float block_reduce_max(float val) {
|
|||||||
if (warp_id == 0) val = warp_reduce_max(val);
|
if (warp_id == 0) val = warp_reduce_max(val);
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Launch error checking (debug builds only) ---
// Place CUDA_CHECK_LAST_ERROR() directly after a kernel launch. In debug builds
// it calls cudaGetLastError() (which also clears the pending error) and prints
// any failure to stderr; execution continues either way. With NDEBUG defined it
// expands to a no-op, so the pending error is left untouched.
#ifdef NDEBUG
#define CUDA_CHECK_LAST_ERROR() ((void)0)
#else
#include <cstdio>
// The local uses a reserved-looking name so it cannot shadow a caller's `err`.
#define CUDA_CHECK_LAST_ERROR() do { \
    cudaError_t cuda_check_last_err_ = cudaGetLastError(); \
    if (cuda_check_last_err_ != cudaSuccess) { \
        fprintf(stderr, "CUDA kernel launch error at %s:%d: %s\n", \
                __FILE__, __LINE__, cudaGetErrorString(cuda_check_last_err_)); \
    } \
} while(0)
#endif
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Embedding lookup: out[seq_idx] = table[token_ids[seq_idx]]
|
// Embedding lookup: out[seq_idx] = table[token_ids[seq_idx]]
|
||||||
// Grid: num_tokens, Block: handles hidden_size elements per token.
|
// Grid: num_tokens, Block: handles hidden_size elements per token.
|
||||||
@@ -7,10 +8,12 @@ __global__ void embedding_f32(
|
|||||||
const float* __restrict__ table, // [vocab_size, hidden_size]
|
const float* __restrict__ table, // [vocab_size, hidden_size]
|
||||||
const int* __restrict__ token_ids, // [num_tokens]
|
const int* __restrict__ token_ids, // [num_tokens]
|
||||||
float* __restrict__ out, // [num_tokens, hidden_size]
|
float* __restrict__ out, // [num_tokens, hidden_size]
|
||||||
int hidden_size
|
int hidden_size,
|
||||||
|
int vocab_size
|
||||||
) {
|
) {
|
||||||
int token_idx = blockIdx.x;
|
int token_idx = blockIdx.x;
|
||||||
int tid = token_ids[token_idx];
|
int tid = token_ids[token_idx];
|
||||||
|
if (tid < 0 || tid >= vocab_size) return;
|
||||||
const float* row = table + tid * hidden_size;
|
const float* row = table + tid * hidden_size;
|
||||||
float* dst = out + token_idx * hidden_size;
|
float* dst = out + token_idx * hidden_size;
|
||||||
|
|
||||||
@@ -23,10 +26,12 @@ __global__ void embedding_bf16(
|
|||||||
const __nv_bfloat16* __restrict__ table,
|
const __nv_bfloat16* __restrict__ table,
|
||||||
const int* __restrict__ token_ids,
|
const int* __restrict__ token_ids,
|
||||||
__nv_bfloat16* __restrict__ out,
|
__nv_bfloat16* __restrict__ out,
|
||||||
int hidden_size
|
int hidden_size,
|
||||||
|
int vocab_size
|
||||||
) {
|
) {
|
||||||
int token_idx = blockIdx.x;
|
int token_idx = blockIdx.x;
|
||||||
int tid = token_ids[token_idx];
|
int tid = token_ids[token_idx];
|
||||||
|
if (tid < 0 || tid >= vocab_size) return;
|
||||||
const __nv_bfloat16* row = table + tid * hidden_size;
|
const __nv_bfloat16* row = table + tid * hidden_size;
|
||||||
__nv_bfloat16* dst = out + token_idx * hidden_size;
|
__nv_bfloat16* dst = out + token_idx * hidden_size;
|
||||||
|
|
||||||
@@ -38,18 +43,20 @@ __global__ void embedding_bf16(
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
// Launches embedding_f32 on `stream`: one block per token (grid = num_tokens),
// min(hidden_size, 256) threads per block.
//   table       : device pointer, reinterpreted as const float*
//   token_ids   : device pointer, reinterpreted as const int*
//   out         : device pointer, reinterpreted as float*
//   vocab_size  : forwarded to the kernel, which skips token ids outside [0, vocab_size)
//   stream      : cudaStream_t passed as an opaque pointer
// The launch is asynchronous; launch errors are reported via CUDA_CHECK_LAST_ERROR().
void launch_embedding_f32(const void* table, const void* token_ids, void* out,
                          int num_tokens, int hidden_size, int vocab_size, void* stream) {
    // Zero tokens gives grid == 0 and a zero hidden size gives block == 0;
    // both are invalid launch configurations and there is nothing to copy.
    if (num_tokens <= 0 || hidden_size <= 0) return;

    int block = (hidden_size < 256) ? hidden_size : 256;
    embedding_f32<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const float*)table, (const int*)token_ids, (float*)out, hidden_size, vocab_size);
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
|
|
||||||
// Launches embedding_bf16 on `stream`: one block per token (grid = num_tokens),
// min(hidden_size, 256) threads per block.
//   table       : device pointer, reinterpreted as const __nv_bfloat16*
//   token_ids   : device pointer, reinterpreted as const int*
//   out         : device pointer, reinterpreted as __nv_bfloat16*
//   vocab_size  : forwarded to the kernel, which skips token ids outside [0, vocab_size)
//   stream      : cudaStream_t passed as an opaque pointer
// The launch is asynchronous; launch errors are reported via CUDA_CHECK_LAST_ERROR().
void launch_embedding_bf16(const void* table, const void* token_ids, void* out,
                           int num_tokens, int hidden_size, int vocab_size, void* stream) {
    // Zero tokens gives grid == 0 and a zero hidden size gives block == 0;
    // both are invalid launch configurations and there is nothing to copy.
    if (num_tokens <= 0 || hidden_size <= 0) return;

    int block = (hidden_size < 256) ? hidden_size : 256;
    embedding_bf16<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)table, (const int*)token_ids,
        (__nv_bfloat16*)out, hidden_size, vocab_size);
    CUDA_CHECK_LAST_ERROR();
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// RoPE: Rotary Position Embedding
|
// RoPE: Rotary Position Embedding, using the Qwen/Llama rotate_half layout.
|
||||||
// For each pair (x[2i], x[2i+1]) at position `pos`:
|
// For each dimension i in the first half at position `pos`:
|
||||||
// y[2i] = x[2i] * cos - x[2i+1] * sin
|
// y[i] = x[i] * cos - x[i + half_dim] * sin
|
||||||
// y[2i+1] = x[2i] * sin + x[2i+1] * cos
|
// y[i + half_dim] = x[i + half_dim] * cos + x[i] * sin
|
||||||
// where cos/sin come from precomputed cos_cache/sin_cache.
|
// where cos/sin come from precomputed cos_cache/sin_cache.
|
||||||
//
|
//
|
||||||
// cos_cache[pos][i] = cos(pos * freq[i])
|
// cos_cache[pos][i] = cos(pos * freq[i])
|
||||||
@@ -35,11 +36,11 @@ __global__ void rope_f32(
|
|||||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||||
|
|
||||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||||
float x0 = x[base + 2 * pair_idx];
|
float x0 = x[base + pair_idx];
|
||||||
float x1 = x[base + 2 * pair_idx + 1];
|
float x1 = x[base + pair_idx + half_dim];
|
||||||
|
|
||||||
x[base + 2 * pair_idx] = x0 * cos_val - x1 * sin_val;
|
x[base + pair_idx] = x0 * cos_val - x1 * sin_val;
|
||||||
x[base + 2 * pair_idx + 1] = x0 * sin_val + x1 * cos_val;
|
x[base + pair_idx + half_dim] = x1 * cos_val + x0 * sin_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void rope_bf16(
|
__global__ void rope_bf16(
|
||||||
@@ -61,11 +62,11 @@ __global__ void rope_bf16(
|
|||||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||||
|
|
||||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||||
float x0 = __bfloat162float(x[base + 2 * pair_idx]);
|
float x0 = __bfloat162float(x[base + pair_idx]);
|
||||||
float x1 = __bfloat162float(x[base + 2 * pair_idx + 1]);
|
float x1 = __bfloat162float(x[base + pair_idx + half_dim]);
|
||||||
|
|
||||||
x[base + 2 * pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
|
x[base + pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
|
||||||
x[base + 2 * pair_idx + 1] = __float2bfloat16(x0 * sin_val + x1 * cos_val);
|
x[base + pair_idx + half_dim] = __float2bfloat16(x1 * cos_val + x0 * sin_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Precompute cos/sin cache on GPU
|
// Precompute cos/sin cache on GPU
|
||||||
@@ -94,6 +95,7 @@ void launch_rope_f32(void* x, const void* cos_cache, const void* sin_cache,
|
|||||||
rope_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
rope_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(float*)x, (const float*)cos_cache, (const float*)sin_cache,
|
(float*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||||
(const int*)positions, num_heads, head_dim);
|
(const int*)positions, num_heads, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
||||||
@@ -104,6 +106,7 @@ void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
|||||||
rope_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
rope_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(__nv_bfloat16*)x, (const float*)cos_cache, (const float*)sin_cache,
|
(__nv_bfloat16*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||||
(const int*)positions, num_heads, head_dim);
|
(const int*)positions, num_heads, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
||||||
@@ -111,6 +114,7 @@ void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
|||||||
void* stream) {
|
void* stream) {
|
||||||
compute_rope_cache<<<max_seq_len, half_dim, 0, (cudaStream_t)stream>>>(
|
compute_rope_cache<<<max_seq_len, half_dim, 0, (cudaStream_t)stream>>>(
|
||||||
(float*)cos_cache, (float*)sin_cache, max_seq_len, half_dim, theta);
|
(float*)cos_cache, (float*)sin_cache, max_seq_len, half_dim, theta);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
|
// Transpose between [S, H, D] and [H, S, D] layouts (used for RoPE and attention).
|
||||||
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
|
// Also handles [S, H*D] → [H, S, D] (reshape_heads) and reverse (merge_heads).
|
||||||
@@ -169,6 +170,7 @@ void launch_reshape_heads_bf16(const void* in, void* out,
|
|||||||
int grid = (total + block - 1) / block;
|
int grid = (total + block - 1) / block;
|
||||||
reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
reshape_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_merge_heads_bf16(const void* in, void* out,
|
void launch_merge_heads_bf16(const void* in, void* out,
|
||||||
@@ -178,6 +180,7 @@ void launch_merge_heads_bf16(const void* in, void* out,
|
|||||||
int grid = (total + block - 1) / block;
|
int grid = (total + block - 1) / block;
|
||||||
merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
merge_heads_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
|
void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
|
||||||
@@ -187,6 +190,7 @@ void launch_transpose_hsd_to_shd_bf16(const void* in, void* out,
|
|||||||
int grid = (total + block - 1) / block;
|
int grid = (total + block - 1) / block;
|
||||||
transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
transpose_hsd_to_shd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
|
void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
|
||||||
@@ -196,6 +200,7 @@ void launch_transpose_shd_to_hsd_bf16(const void* in, void* out,
|
|||||||
int grid = (total + block - 1) / block;
|
int grid = (total + block - 1) / block;
|
||||||
transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
transpose_shd_to_hsd_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, seq_len, num_heads, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_repeat_kv_bf16(const void* in, void* out,
|
void launch_repeat_kv_bf16(const void* in, void* out,
|
||||||
@@ -205,6 +210,7 @@ void launch_repeat_kv_bf16(const void* in, void* out,
|
|||||||
int grid = (total + block - 1) / block;
|
int grid = (total + block - 1) / block;
|
||||||
repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
repeat_kv_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
|
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, kv_heads, n_rep, seq_len, head_dim);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
|
void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
|
||||||
@@ -217,6 +223,7 @@ void launch_strided_copy_bf16(const void* in, void* out, int numel, int ndim,
|
|||||||
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, numel, ndim,
|
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, numel, ndim,
|
||||||
shape0, shape1, shape2, shape3,
|
shape0, shape1, shape2, shape3,
|
||||||
in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
|
in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
|
void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
|
||||||
@@ -229,6 +236,7 @@ void launch_strided_copy_f32(const void* in, void* out, int numel, int ndim,
|
|||||||
(const float*)in, (float*)out, numel, ndim,
|
(const float*)in, (float*)out, numel, ndim,
|
||||||
shape0, shape1, shape2, shape3,
|
shape0, shape1, shape2, shape3,
|
||||||
in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
|
in_stride0, in_stride1, in_stride2, in_stride3, in_offset);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Custom GEMV kernel for M=1 decode step (BF16):
|
// Custom GEMV kernel for M=1 decode step (BF16):
|
||||||
// y[n] = sum_k x[k] * W[k * N + n]
|
// y[n] = sum_k x[k] * W[k * N + n]
|
||||||
@@ -88,6 +89,7 @@ void launch_gemv_bf16(
|
|||||||
(float*)y_fp32_buf,
|
(float*)y_fp32_buf,
|
||||||
K, N
|
K, N
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
|
||||||
// Convert FP32 -> BF16
|
// Convert FP32 -> BF16
|
||||||
int conv_block = 256;
|
int conv_block = 256;
|
||||||
@@ -97,6 +99,7 @@ void launch_gemv_bf16(
|
|||||||
(__nv_bfloat16*)y_bf16,
|
(__nv_bfloat16*)y_bf16,
|
||||||
N
|
N
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Naive GEMM: each thread computes one element of C.
|
// Naive GEMM: each thread computes one element of C.
|
||||||
// C[i][j] = sum_k A[i][k] * B[k][j]
|
// C[i][j] = sum_k A[i][k] * B[k][j]
|
||||||
@@ -46,6 +47,7 @@ void launch_gemm_naive_bf16(
|
|||||||
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_gemm_naive_f32(
|
void launch_gemm_naive_f32(
|
||||||
@@ -57,6 +59,7 @@ void launch_gemm_naive_f32(
|
|||||||
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
// Tiled GEMM using shared memory.
|
// Tiled GEMM using shared memory.
|
||||||
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
|
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
|
||||||
@@ -100,6 +101,7 @@ void launch_gemm_tiled_f32(
|
|||||||
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_gemm_tiled_bf16(
|
void launch_gemm_tiled_bf16(
|
||||||
@@ -111,6 +113,7 @@ void launch_gemm_tiled_bf16(
|
|||||||
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||||
);
|
);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ void launch_layernorm_f32(const void* x, const void* gamma, const void* beta,
|
|||||||
layernorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
layernorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)x, (const float*)gamma, (const float*)beta,
|
(const float*)x, (const float*)gamma, (const float*)beta,
|
||||||
(float*)out, hidden_size, eps);
|
(float*)out, hidden_size, eps);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
|
void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
|
||||||
@@ -114,6 +115,7 @@ void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
|
|||||||
layernorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
layernorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma, (const __nv_bfloat16*)beta,
|
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma, (const __nv_bfloat16*)beta,
|
||||||
(__nv_bfloat16*)out, hidden_size, eps);
|
(__nv_bfloat16*)out, hidden_size, eps);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
|
|||||||
if (block < 32) block = 32;
|
if (block < 32) block = 32;
|
||||||
rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
|
(const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
|
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
|
||||||
@@ -120,6 +121,7 @@ void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
|
|||||||
rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
|
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
|
||||||
(__nv_bfloat16*)out, hidden_size, eps);
|
(__nv_bfloat16*)out, hidden_size, eps);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
|
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
|
||||||
@@ -132,6 +134,7 @@ void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* ga
|
|||||||
(const __nv_bfloat16*)gamma,
|
(const __nv_bfloat16*)gamma,
|
||||||
(__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
|
(__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
|
||||||
hidden_size, eps);
|
hidden_size, eps);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stre
|
|||||||
if (block < 32) block = 32;
|
if (block < 32) block = 32;
|
||||||
softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const float*)x, (float*)out, cols);
|
(const float*)x, (float*)out, cols);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
|
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
|
||||||
@@ -101,6 +102,7 @@ void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* str
|
|||||||
if (block < 32) block = 32;
|
if (block < 32) block = 32;
|
||||||
softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
|
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
|
||||||
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user