Files
xserv/crates/xserv-kernels/src/dispatch.rs

317 lines
7.0 KiB
Rust

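//! Low-level kernel dispatchers for CUDA Graph capture.
//! These functions write to pre-allocated output buffers and accept an explicit stream.
//! Each one is a thin `unsafe` wrapper that forwards its arguments unchanged to a C launcher
//! (or to cuBLAS). No argument validation happens on the Rust side, so callers are
//! responsible for pointer validity and buffer sizes.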
use std::ffi::c_void;
// FFI declarations for the C kernel launchers wrapped by this module.
// They re-declare the same symbols the individual kernel modules bind (same as in those
// modules). Rust cannot check these against the C definitions: any signature change on the
// C side must be mirrored here by hand, because a mismatch is undefined behavior rather
// than a compile error.
// None of these launchers returns a value, so launch failures are not reported through them.
unsafe extern "C" {
fn launch_rmsnorm_bf16(
x: *const c_void,
gamma: *const c_void,
out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_add_rmsnorm_bf16(
x: *const c_void,
residual: *const c_void,
gamma: *const c_void,
normed_out: *mut c_void,
sum_out: *mut c_void,
rows: i32,
hidden_size: i32,
eps: f32,
stream: *mut c_void,
);
fn launch_silu_mul_bf16(
gate: *const c_void,
up: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
stream: *mut c_void,
);
fn launch_embedding_bf16(
table: *const c_void,
token_ids: *const c_void,
out: *mut c_void,
num_tokens: i32,
hidden_size: i32,
vocab_size: i32,
stream: *mut c_void,
);
fn launch_reshape_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_merge_heads_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_hsd_to_shd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
fn launch_transpose_shd_to_hsd_bf16(
inp: *const c_void,
out: *mut c_void,
seq_len: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
// In-place: `x` is the only output pointer (note the `*mut`).
fn launch_rope_bf16(
x: *mut c_void,
cos_cache: *const c_void,
sin_cache: *const c_void,
positions: *const c_void,
num_tokens: i32,
num_heads: i32,
head_dim: i32,
stream: *mut c_void,
);
// `y_fp32_buf` is a caller-provided fp32 accumulator buffer (see `gemv_bf16`).
fn launch_gemv_bf16(
x: *const c_void,
w: *const c_void,
y_bf16: *mut c_void,
y_fp32_buf: *mut c_void,
k: i32,
n: i32,
stream: *mut c_void,
);
// `causal` is passed as an integer flag; `decode_attention_bf16` always passes 1.
fn launch_decode_attention_bf16(
q: *const c_void,
k: *const c_void,
v: *const c_void,
o: *mut c_void,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
kv_len: i32,
head_dim: i32,
scale: f32,
causal: i32,
stream: *mut c_void,
);
}
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
///
/// Every argument is forwarded unchanged to `launch_rmsnorm_bf16`. Nothing is checked on
/// the Rust side.
///
/// # Safety
///
/// The caller must ensure that:
/// - `x` and `gamma` are readable and `out` is writable over the extent the kernel touches
///   for `rows` x `hidden_size`. Presumably `x`/`out` hold `rows * hidden_size` bf16 values
///   and `gamma` holds `hidden_size`; confirm against the kernel source.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn rmsnorm_bf16(
    x: *const c_void,
    gamma: *const c_void,
    out: *mut c_void,
    rows: i32,
    hidden_size: i32,
    eps: f32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream) }
}
/// Raw add_rmsnorm dispatch.
///
/// Forwards every argument unchanged to `launch_add_rmsnorm_bf16`, which reads `x`,
/// `residual` and `gamma` and writes two outputs, `normed_out` and `sum_out`. Judging by
/// the names, `sum_out` presumably receives `x + residual` and `normed_out` its RMS-normalized
/// form; confirm against the kernel source.
///
/// # Safety
///
/// The caller must ensure that:
/// - `x`, `residual` and `gamma` are readable and `normed_out`/`sum_out` are writable over
///   the extent the kernel touches for `rows` x `hidden_size`. Presumably that is
///   `rows * hidden_size` bf16 values per activation buffer and `hidden_size` for `gamma`.
/// - Outputs do not alias inputs unless the kernel is known to support it. This is not
///   visible from here.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
// The signature mirrors the C launcher one-to-one; bundling arguments into a struct would
// only add indirection to a raw FFI shim.
#[allow(clippy::too_many_arguments)]
pub unsafe fn add_rmsnorm_bf16(
    x: *const c_void,
    residual: *const c_void,
    gamma: *const c_void,
    normed_out: *mut c_void,
    sum_out: *mut c_void,
    rows: i32,
    hidden_size: i32,
    eps: f32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe {
        launch_add_rmsnorm_bf16(
            x,
            residual,
            gamma,
            normed_out,
            sum_out,
            rows,
            hidden_size,
            eps,
            stream,
        );
    }
}
/// Raw silu_mul dispatch.
///
/// Forwards every argument unchanged to `launch_silu_mul_bf16`, which reads `gate` and `up`
/// and writes `out`. Presumably it computes `silu(gate) * up` element-wise over `n`
/// elements; confirm against the kernel source.
///
/// # Safety
///
/// The caller must ensure that:
/// - `gate` and `up` are readable and `out` is writable over the extent the kernel touches
///   for `n`. Presumably that is `n` bf16 values each.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn silu_mul_bf16(
    gate: *const c_void,
    up: *const c_void,
    out: *mut c_void,
    n: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_silu_mul_bf16(gate, up, out, n, stream) }
}
/// Raw add dispatch.
///
/// Forwards every argument unchanged to `launch_add_bf16`, which reads `a` and `b` and
/// writes `out`. Presumably it computes an element-wise `a + b` over `n` elements; confirm
/// against the kernel source.
///
/// # Safety
///
/// The caller must ensure that:
/// - `a` and `b` are readable and `out` is writable over the extent the kernel touches for
///   `n`. Presumably that is `n` bf16 values each.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn add_bf16(
    a: *const c_void,
    b: *const c_void,
    out: *mut c_void,
    n: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_add_bf16(a, b, out, n, stream) }
}
/// Raw embedding dispatch.
///
/// Forwards every argument unchanged to `launch_embedding_bf16`, which reads `table` and
/// `token_ids` and writes `out`. Presumably it gathers `num_tokens` rows of `hidden_size`
/// from a `vocab_size` x `hidden_size` table. The element type of `token_ids`, and whether
/// out-of-range ids are checked, are defined by the kernel and not visible here.
///
/// # Safety
///
/// The caller must ensure that:
/// - `table` and `token_ids` are readable and `out` is writable over the extent the kernel
///   touches for `num_tokens`, `hidden_size` and `vocab_size`.
/// - `token_ids` contains ids the kernel can safely index with. Validation is not done here.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn embedding_bf16(
    table: *const c_void,
    token_ids: *const c_void,
    out: *mut c_void,
    num_tokens: i32,
    hidden_size: i32,
    vocab_size: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe {
        launch_embedding_bf16(
            table,
            token_ids,
            out,
            num_tokens,
            hidden_size,
            vocab_size,
            stream,
        );
    }
}
/// Raw reshape_heads dispatch.
///
/// Forwards every argument unchanged to `launch_reshape_heads_bf16`, which reads `inp` and
/// writes `out`. The exact layout change for `seq_len` x `num_heads` x `head_dim` is defined
/// by the kernel and is not visible here.
///
/// # Safety
///
/// The caller must ensure that:
/// - `inp` is readable and `out` is writable over the extent the kernel touches. Presumably
///   that is `seq_len * num_heads * head_dim` bf16 values each.
/// - `inp` and `out` do not overlap unless the kernel is known to allow it.
/// - `stream` is a stream handle the launcher accepts.
/// - Both buffers stay valid until the kernel has finished with them. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn reshape_heads_bf16(
    inp: *const c_void,
    out: *mut c_void,
    seq_len: i32,
    num_heads: i32,
    head_dim: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream) }
}
/// Raw merge_heads dispatch.
///
/// Forwards every argument unchanged to `launch_merge_heads_bf16`, which reads `inp` and
/// writes `out`. The exact layout change for `seq_len` x `num_heads` x `head_dim` is defined
/// by the kernel and is not visible here.
///
/// # Safety
///
/// The caller must ensure that:
/// - `inp` is readable and `out` is writable over the extent the kernel touches. Presumably
///   that is `seq_len * num_heads * head_dim` bf16 values each.
/// - `inp` and `out` do not overlap unless the kernel is known to allow it.
/// - `stream` is a stream handle the launcher accepts.
/// - Both buffers stay valid until the kernel has finished with them. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn merge_heads_bf16(
    inp: *const c_void,
    out: *mut c_void,
    seq_len: i32,
    num_heads: i32,
    head_dim: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream) }
}
/// Raw transpose HSD->SHD dispatch.
///
/// Forwards every argument unchanged to `launch_transpose_hsd_to_shd_bf16`, which reads
/// `inp` and writes `out`. By the name, this presumably maps a (heads, seq, dim) layout to
/// (seq, heads, dim); confirm against the kernel source.
///
/// # Safety
///
/// The caller must ensure that:
/// - `inp` is readable and `out` is writable over the extent the kernel touches. Presumably
///   that is `seq_len * num_heads * head_dim` bf16 values each.
/// - `inp` and `out` do not overlap unless the kernel is known to allow it.
/// - `stream` is a stream handle the launcher accepts.
/// - Both buffers stay valid until the kernel has finished with them. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn transpose_hsd_to_shd_bf16(
    inp: *const c_void,
    out: *mut c_void,
    seq_len: i32,
    num_heads: i32,
    head_dim: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream) }
}
/// Raw transpose SHD->HSD dispatch.
///
/// Forwards every argument unchanged to `launch_transpose_shd_to_hsd_bf16`, which reads
/// `inp` and writes `out`. By the name, this presumably maps a (seq, heads, dim) layout to
/// (heads, seq, dim); confirm against the kernel source.
///
/// # Safety
///
/// The caller must ensure that:
/// - `inp` is readable and `out` is writable over the extent the kernel touches. Presumably
///   that is `seq_len * num_heads * head_dim` bf16 values each.
/// - `inp` and `out` do not overlap unless the kernel is known to allow it.
/// - `stream` is a stream handle the launcher accepts.
/// - Both buffers stay valid until the kernel has finished with them. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn transpose_shd_to_hsd_bf16(
    inp: *const c_void,
    out: *mut c_void,
    seq_len: i32,
    num_heads: i32,
    head_dim: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream) }
}
/// Raw RoPE dispatch (in-place).
///
/// Forwards every argument unchanged to `launch_rope_bf16`. `x` is both read and written
/// (it is the only `*mut` pointer). `cos_cache`, `sin_cache` and `positions` are read-only.
///
/// # Safety
///
/// The caller must ensure that:
/// - `x` is readable and writable over the extent the kernel touches for `num_tokens` x
///   `num_heads` x `head_dim`. Presumably that is `num_tokens * num_heads * head_dim` bf16
///   values; confirm against the kernel source.
/// - `cos_cache`, `sin_cache` and `positions` are readable for every position the kernel
///   looks up. The cache sizes and the element type of `positions` are not visible here.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
// The signature mirrors the C launcher one-to-one; bundling arguments into a struct would
// only add indirection to a raw FFI shim.
#[allow(clippy::too_many_arguments)]
pub unsafe fn rope_bf16(
    x: *mut c_void,
    cos_cache: *const c_void,
    sin_cache: *const c_void,
    positions: *const c_void,
    num_tokens: i32,
    num_heads: i32,
    head_dim: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe {
        launch_rope_bf16(
            x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
        );
    }
}
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
///
/// Forwards every argument unchanged to `launch_gemv_bf16`. `x` and `w` are read-only.
/// `y_bf16` and `y_fp32_buf` are both writable; the latter is the fp32 accumulator. The
/// layout of `w` for the given `k` and `n` is defined by the kernel and is not visible here.
///
/// # Safety
///
/// The caller must ensure that:
/// - `x` and `w` are readable, and `y_bf16` and `y_fp32_buf` are writable, over the extent
///   the kernel touches for `k` and `n`. Presumably `x` holds `k` values, `w` holds
///   `k * n`, and each output holds `n`; confirm against the kernel source.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
pub unsafe fn gemv_bf16(
    x: *const c_void,
    w: *const c_void,
    y_bf16: *mut c_void,
    y_fp32_buf: *mut c_void,
    k: i32,
    n: i32,
    stream: *mut c_void,
) {
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe { launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream) }
}
/// Raw decode attention dispatch.
///
/// Forwards every argument unchanged to `launch_decode_attention_bf16`, which reads `q`,
/// `k` and `v` and writes `o`. The launcher's `causal` flag is always set to 1; this wrapper
/// does not expose a non-causal mode.
///
/// # Safety
///
/// The caller must ensure that:
/// - `q`, `k` and `v` are readable and `o` is writable over the extent the kernel touches
///   for `batch`, `num_q_heads`, `num_kv_heads`, `kv_len` and `head_dim`. The exact layouts
///   are defined by the kernel and are not visible here.
/// - `stream` is a stream handle the launcher accepts.
/// - Every buffer stays valid until the kernel has finished with it. The launch is
///   presumably asynchronous on `stream`.
// The signature mirrors the C launcher one-to-one; bundling arguments into a struct would
// only add indirection to a raw FFI shim.
#[allow(clippy::too_many_arguments)]
pub unsafe fn decode_attention_bf16(
    q: *const c_void,
    k: *const c_void,
    v: *const c_void,
    o: *mut c_void,
    batch: i32,
    num_q_heads: i32,
    num_kv_heads: i32,
    kv_len: i32,
    head_dim: i32,
    scale: f32,
    stream: *mut c_void,
) {
    // The launcher takes `causal` as an integer flag. This wrapper always requests the
    // causal variant (presumably causal masking; confirm against the kernel).
    const CAUSAL: i32 = 1;
    // SAFETY: the caller guarantees the pointer/stream requirements listed under
    // `# Safety`; all arguments are passed through unchanged.
    unsafe {
        launch_decode_attention_bf16(
            q,
            k,
            v,
            o,
            batch,
            num_q_heads,
            num_kv_heads,
            kv_len,
            head_dim,
            scale,
            CAUSAL,
            stream,
        );
    }
}
// cuBLAS FFI
// Opaque cuBLAS handle passed across FFI as a raw pointer (presumably a `cublasHandle_t`
// created elsewhere; confirm with the code that owns it).
pub type CublasHandle = *mut c_void;
// `cublasSetStream_v2` returns a raw `cublasStatus_t`; 0 is CUBLAS_STATUS_SUCCESS in the
// cuBLAS headers.
unsafe extern "C" {
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
}
/// Set cuBLAS stream. Must be called before any cuBLAS operations during graph capture.
///
/// Binds `handle` to `stream` via `cublasSetStream_v2`. The status code cuBLAS returns is
/// discarded here, so the signature stays unchanged for existing callers. A failure would
/// go unnoticed, so prefer [`try_set_cublas_stream`] when the caller can react to an error.
///
/// # Safety
///
/// `handle` must be a live cuBLAS handle. `stream` must be a stream handle cuBLAS accepts
/// (presumably a `cudaStream_t`, or null for the default stream; confirm).
pub unsafe fn set_cublas_stream(handle: CublasHandle, stream: *mut c_void) {
    // SAFETY: identical contract to `try_set_cublas_stream`, upheld by our caller.
    // The status is deliberately dropped to keep this function's original `()` signature.
    let _ = unsafe { try_set_cublas_stream(handle, stream) };
}

/// Checked variant of [`set_cublas_stream`]: binds `handle` to `stream` and reports the
/// cuBLAS status instead of discarding it.
///
/// # Errors
///
/// Returns `Err(status)` with the raw `cublasStatus_t` value when cuBLAS reports anything
/// other than `CUBLAS_STATUS_SUCCESS` (0).
///
/// # Safety
///
/// Same requirements as [`set_cublas_stream`].
pub unsafe fn try_set_cublas_stream(
    handle: CublasHandle,
    stream: *mut c_void,
) -> Result<(), i32> {
    const CUBLAS_STATUS_SUCCESS: i32 = 0;
    // SAFETY: the caller guarantees `handle`/`stream` validity per `# Safety`.
    let status = unsafe { cublasSetStream_v2(handle, stream) };
    if status == CUBLAS_STATUS_SUCCESS {
        Ok(())
    } else {
        Err(status)
    }
}