kernels/cuda: paged-attention kernel, dispatch, pinned host memory
CUDA layer for the paged-KV + swap work: - csrc: new paged_attention.cu plus updates across attention/gemm/norm/ activation/embedding/reduce kernels and common.cuh. - xserv-kernels: new dispatch module and kernel-binding updates. - xserv-cuda: cudaMallocHost/FreeHost bindings + PinnedBuffer (host swap pool backing) and offset-aware D2H/H2D copies used to move KV blocks between the GPU pool and pinned host memory. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -39,6 +39,7 @@ unsafe extern "C" {
|
||||
stream: CudaStream,
|
||||
) -> i32;
|
||||
pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32;
|
||||
pub fn cudaMemsetAsync(devptr: *mut u8, value: i32, count: usize, stream: CudaStream) -> i32;
|
||||
|
||||
// --- Stream ---
|
||||
pub fn cudaStreamCreate(stream: *mut CudaStream) -> i32;
|
||||
|
||||
@@ -116,6 +116,56 @@ impl GpuBuffer {
|
||||
})
|
||||
}
|
||||
|
||||
/// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`.
|
||||
pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> {
|
||||
assert!(src_offset + count <= src.len);
|
||||
assert!(dst_offset + count <= self.len);
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpyAsync(
|
||||
self.ptr.add(dst_offset),
|
||||
src.ptr.add(src_offset),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_D2D,
|
||||
stream.as_raw(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Copy `count` bytes from this GPU buffer at `src_offset` to a host slice (D2H).
|
||||
pub fn copy_to_host_at(&self, dst: &mut [u8], src_offset: usize, count: usize) -> Result<()> {
|
||||
assert!(src_offset + count <= self.len, "src range out of bounds");
|
||||
assert!(count <= dst.len(), "host dst too small");
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(
|
||||
dst.as_mut_ptr(),
|
||||
self.ptr.add(src_offset),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_D2H,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Copy `count` bytes from a host slice to this GPU buffer at `dst_offset` (H2D).
|
||||
pub fn copy_from_host_at(&mut self, src: &[u8], dst_offset: usize, count: usize) -> Result<()> {
|
||||
assert!(dst_offset + count <= self.len, "dst range out of bounds");
|
||||
assert!(count <= src.len(), "host src too small");
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemcpy(
|
||||
self.ptr.add(dst_offset),
|
||||
src.as_ptr(),
|
||||
count,
|
||||
ffi::CUDA_MEMCPY_H2D,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Async zero fill on stream.
|
||||
pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
error::check(unsafe {
|
||||
ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw())
|
||||
})
|
||||
}
|
||||
|
||||
/// Consume the buffer without freeing GPU memory. Returns the raw pointer and length.
|
||||
/// Caller is responsible for eventually calling cudaFree.
|
||||
pub fn into_raw(self) -> (*mut u8, usize) {
|
||||
|
||||
@@ -26,6 +26,7 @@ fn main() {
|
||||
.file("../../csrc/attention/causal_mask.cu")
|
||||
.file("../../csrc/embedding/transpose.cu")
|
||||
.file("../../csrc/attention/flash_attention.cu")
|
||||
.file("../../csrc/attention/paged_attention.cu")
|
||||
.compile("xserv_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/");
|
||||
|
||||
@@ -19,7 +19,9 @@ fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
let n = x.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
@@ -38,7 +40,9 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||||
let n = a.numel() as i32;
|
||||
let n = a.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
@@ -55,7 +59,9 @@ pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_si
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
let n = x.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
@@ -77,7 +83,9 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||||
assert!(matches!(gate.device(), Device::Cuda(_)));
|
||||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||||
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
||||
let n = gate.numel() as i32;
|
||||
let n = gate.numel();
|
||||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
launch_silu_mul_bf16(
|
||||
gate.data_ptr() as *const c_void,
|
||||
|
||||
@@ -22,6 +22,17 @@ unsafe extern "C" {
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
);
|
||||
fn launch_paged_decode_attention_bf16(
|
||||
q: *const c_void,
|
||||
k_cache: *const c_void,
|
||||
v_cache: *const c_void,
|
||||
o: *mut c_void,
|
||||
block_tables: *const i32,
|
||||
context_lens: *const i32,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
head_dim: i32, max_blocks_per_seq: i32,
|
||||
scale: f32, stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
@@ -192,3 +203,58 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Paged decode attention.
|
||||
///
|
||||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||||
/// k_cache_ptr / v_cache_ptr: pointers to [num_blocks, num_kv_heads, BLOCK_SIZE, head_dim] BF16 pools
|
||||
/// block_tables_ptr: i32 [batch, max_blocks_per_seq] (rows already arranged for this batch)
|
||||
/// context_lens_ptr: i32 [batch]
|
||||
///
|
||||
/// Returns: [batch, num_q_heads, 1, head_dim] BF16
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_decode_attention(
|
||||
q: &Tensor,
|
||||
k_cache_ptr: *const c_void,
|
||||
v_cache_ptr: *const c_void,
|
||||
block_tables_ptr: *const i32,
|
||||
context_lens_ptr: *const i32,
|
||||
batch: usize,
|
||||
num_q_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1");
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads");
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, 1, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
);
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
output.data_ptr() as *mut c_void,
|
||||
block_tables_ptr,
|
||||
context_lens_ptr,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
118
crates/xserv-kernels/src/dispatch.rs
Normal file
118
crates/xserv-kernels/src/dispatch.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
//! Low-level kernel dispatchers for CUDA Graph capture.
|
||||
//! These functions write to pre-allocated output buffers and accept an explicit stream.
|
||||
|
||||
use std::ffi::c_void;
|
||||
|
||||
// Re-declare the extern functions we need (same as in the individual modules)
|
||||
unsafe extern "C" {
|
||||
fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void);
|
||||
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||
fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void);
|
||||
fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void);
|
||||
fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void,
|
||||
k: i32, n: i32, stream: *mut c_void);
|
||||
fn launch_decode_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
|
||||
pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
|
||||
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
|
||||
}
|
||||
|
||||
/// Raw add_rmsnorm dispatch.
|
||||
pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void,
|
||||
normed_out: *mut c_void, sum_out: *mut c_void,
|
||||
rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) {
|
||||
launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream);
|
||||
}
|
||||
|
||||
/// Raw silu_mul dispatch.
|
||||
pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
|
||||
launch_silu_mul_bf16(gate, up, out, n, stream);
|
||||
}
|
||||
|
||||
/// Raw add dispatch.
|
||||
pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) {
|
||||
launch_add_bf16(a, b, out, n, stream);
|
||||
}
|
||||
|
||||
/// Raw embedding dispatch.
|
||||
pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) {
|
||||
launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream);
|
||||
}
|
||||
|
||||
/// Raw reshape_heads dispatch.
|
||||
pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw merge_heads dispatch.
|
||||
pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw transpose HSD->SHD dispatch.
|
||||
pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw transpose SHD->HSD dispatch.
|
||||
pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void,
|
||||
seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) {
|
||||
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw RoPE dispatch (in-place).
|
||||
pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void,
|
||||
positions: *const c_void, num_tokens: i32, num_heads: i32,
|
||||
head_dim: i32, stream: *mut c_void) {
|
||||
launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream);
|
||||
}
|
||||
|
||||
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
|
||||
pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void,
|
||||
y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) {
|
||||
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
|
||||
}
|
||||
|
||||
/// Raw decode attention dispatch.
|
||||
pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
kv_len: i32, head_dim: i32,
|
||||
scale: f32, stream: *mut c_void) {
|
||||
launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream);
|
||||
}
|
||||
|
||||
// cuBLAS FFI
|
||||
pub type CublasHandle = *mut c_void;
|
||||
|
||||
unsafe extern "C" {
|
||||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||||
}
|
||||
|
||||
/// Set cuBLAS stream. Must be called before any cuBLAS operations during graph capture.
|
||||
pub unsafe fn set_cublas_stream(handle: CublasHandle, stream: *mut c_void) {
|
||||
cublasSetStream_v2(handle, stream);
|
||||
}
|
||||
@@ -4,9 +4,9 @@ use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
unsafe extern "C" {
|
||||
fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||
fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void,
|
||||
num_tokens: i32, hidden_size: i32, stream: *mut c_void);
|
||||
num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
/// Embedding lookup: table[token_ids[i]] for each i.
|
||||
@@ -18,6 +18,9 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
|
||||
let hidden_size = table.shape()[1];
|
||||
let num_tokens = token_ids.len();
|
||||
let vocab_size = table.shape()[0];
|
||||
assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
|
||||
// Upload token_ids to GPU
|
||||
let ids_bytes = unsafe {
|
||||
@@ -29,6 +32,10 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||
|
||||
for &tid in token_ids {
|
||||
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
|
||||
}
|
||||
|
||||
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
|
||||
unsafe {
|
||||
@@ -36,12 +43,12 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
DType::F32 => launch_embedding_f32(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
||||
),
|
||||
DType::BF16 => launch_embedding_bf16(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, std::ptr::null_mut(),
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for embedding"),
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ unsafe extern "C" {
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
type CublasHandle = *mut c_void;
|
||||
pub type CublasHandle = *mut c_void;
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
const CUBLAS_OP_N: i32 = 0;
|
||||
@@ -100,6 +100,13 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the thread-local cuBLAS handle for use with dispatch module.
|
||||
pub fn cublas_handle() -> CublasHandle {
|
||||
CUBLAS_CTX.with(|cell| {
|
||||
cell.borrow().handle
|
||||
})
|
||||
}
|
||||
|
||||
/// Matrix multiplication: C = A @ B
|
||||
/// A: [M, K], B: [K, N], C: [M, N]
|
||||
/// All tensors must be contiguous and on the same GPU.
|
||||
|
||||
@@ -17,6 +17,8 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod activation;
|
||||
pub mod attention;
|
||||
pub mod dispatch;
|
||||
pub mod embedding;
|
||||
pub mod gemm;
|
||||
pub mod layernorm;
|
||||
@@ -10,7 +11,7 @@ pub mod transpose;
|
||||
|
||||
pub use activation::{add, gelu, mul, scale, silu, silu_mul};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::{attention, decode_attention, flash_attention};
|
||||
pub use attention::{attention, decode_attention, flash_attention, paged_decode_attention};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
|
||||
@@ -20,6 +20,8 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
@@ -54,6 +56,8 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param");
|
||||
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param");
|
||||
assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param");
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
|
||||
@@ -74,10 +74,10 @@ fn cpu_rope(x: &mut [f32], positions: &[u32], num_heads: usize, head_dim: usize,
|
||||
let cos_val = angle.cos();
|
||||
let sin_val = angle.sin();
|
||||
let base = (t * num_heads + h) * head_dim;
|
||||
let x0 = x[base + 2 * i];
|
||||
let x1 = x[base + 2 * i + 1];
|
||||
x[base + 2 * i] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
|
||||
let x0 = x[base + i];
|
||||
let x1 = x[base + i + half_dim];
|
||||
x[base + i] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + i + half_dim] = x1 * cos_val + x0 * sin_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user