//! Low-level kernel dispatchers for CUDA Graph capture.
//!
//! These functions write to pre-allocated output buffers and accept an explicit stream.
|
|
|
|
use std::ffi::c_void;
|
|
|
|
// Kernel launcher entry points, re-declared here (same as in the individual modules).
//
// Every item is an unsafe foreign function: each declaration must agree with the signature
// exported by the native library, and callers are responsible for the pointers they pass.
unsafe extern "C" {
    // --- Normalization ---

    fn launch_rmsnorm_bf16(
        x: *const c_void,
        gamma: *const c_void,
        out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    fn launch_add_rmsnorm_bf16(
        x: *const c_void,
        residual: *const c_void,
        gamma: *const c_void,
        normed_out: *mut c_void,
        sum_out: *mut c_void,
        rows: i32,
        hidden_size: i32,
        eps: f32,
        stream: *mut c_void,
    );

    // --- Element-wise ---

    fn launch_silu_mul_bf16(
        gate: *const c_void,
        up: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    fn launch_add_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    // --- Embedding ---

    fn launch_embedding_bf16(
        table: *const c_void,
        token_ids: *const c_void,
        out: *mut c_void,
        num_tokens: i32,
        hidden_size: i32,
        vocab_size: i32,
        stream: *mut c_void,
    );

    // --- Head layout (reshape / merge / transpose) ---

    fn launch_reshape_heads_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_merge_heads_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_transpose_hsd_to_shd_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    fn launch_transpose_shd_to_hsd_bf16(
        inp: *const c_void,
        out: *mut c_void,
        seq_len: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    // --- RoPE: `x` is the only mutable pointer (no separate output argument) ---

    fn launch_rope_bf16(
        x: *mut c_void,
        cos_cache: *const c_void,
        sin_cache: *const c_void,
        positions: *const c_void,
        num_tokens: i32,
        num_heads: i32,
        head_dim: i32,
        stream: *mut c_void,
    );

    // --- GEMV: takes both a bf16 output and an fp32 buffer ---

    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );

    // --- Decode attention ---

    fn launch_decode_attention_bf16(
        q: *const c_void,
        k: *const c_void,
        v: *const c_void,
        o: *mut c_void,
        batch: i32,
        num_q_heads: i32,
        num_kv_heads: i32,
        kv_len: i32,
        head_dim: i32,
        scale: f32,
        causal: i32,
        stream: *mut c_void,
    );
}
|
|
|
|
/// Raw rmsnorm dispatch: writes to pre-allocated `out`.
|
|
pub unsafe fn rmsnorm_bf16(
|
|
x: *const c_void,
|
|
gamma: *const c_void,
|
|
out: *mut c_void,
|
|
rows: i32,
|
|
hidden_size: i32,
|
|
eps: f32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream);
|
|
}
|
|
|
|
/// Raw add_rmsnorm dispatch.
|
|
pub unsafe fn add_rmsnorm_bf16(
|
|
x: *const c_void,
|
|
residual: *const c_void,
|
|
gamma: *const c_void,
|
|
normed_out: *mut c_void,
|
|
sum_out: *mut c_void,
|
|
rows: i32,
|
|
hidden_size: i32,
|
|
eps: f32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_add_rmsnorm_bf16(
|
|
x,
|
|
residual,
|
|
gamma,
|
|
normed_out,
|
|
sum_out,
|
|
rows,
|
|
hidden_size,
|
|
eps,
|
|
stream,
|
|
);
|
|
}
|
|
|
|
/// Raw silu_mul dispatch.
|
|
pub unsafe fn silu_mul_bf16(
|
|
gate: *const c_void,
|
|
up: *const c_void,
|
|
out: *mut c_void,
|
|
n: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_silu_mul_bf16(gate, up, out, n, stream);
|
|
}
|
|
|
|
/// Raw add dispatch.
|
|
pub unsafe fn add_bf16(
|
|
a: *const c_void,
|
|
b: *const c_void,
|
|
out: *mut c_void,
|
|
n: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_add_bf16(a, b, out, n, stream);
|
|
}
|
|
|
|
/// Raw embedding dispatch.
|
|
pub unsafe fn embedding_bf16(
|
|
table: *const c_void,
|
|
token_ids: *const c_void,
|
|
out: *mut c_void,
|
|
num_tokens: i32,
|
|
hidden_size: i32,
|
|
vocab_size: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_embedding_bf16(
|
|
table,
|
|
token_ids,
|
|
out,
|
|
num_tokens,
|
|
hidden_size,
|
|
vocab_size,
|
|
stream,
|
|
);
|
|
}
|
|
|
|
/// Raw reshape_heads dispatch.
|
|
pub unsafe fn reshape_heads_bf16(
|
|
inp: *const c_void,
|
|
out: *mut c_void,
|
|
seq_len: i32,
|
|
num_heads: i32,
|
|
head_dim: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
|
}
|
|
|
|
/// Raw merge_heads dispatch.
|
|
pub unsafe fn merge_heads_bf16(
|
|
inp: *const c_void,
|
|
out: *mut c_void,
|
|
seq_len: i32,
|
|
num_heads: i32,
|
|
head_dim: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
|
}
|
|
|
|
/// Raw transpose HSD->SHD dispatch.
|
|
pub unsafe fn transpose_hsd_to_shd_bf16(
|
|
inp: *const c_void,
|
|
out: *mut c_void,
|
|
seq_len: i32,
|
|
num_heads: i32,
|
|
head_dim: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
|
}
|
|
|
|
/// Raw transpose SHD->HSD dispatch.
|
|
pub unsafe fn transpose_shd_to_hsd_bf16(
|
|
inp: *const c_void,
|
|
out: *mut c_void,
|
|
seq_len: i32,
|
|
num_heads: i32,
|
|
head_dim: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream);
|
|
}
|
|
|
|
/// Raw RoPE dispatch (in-place).
|
|
pub unsafe fn rope_bf16(
|
|
x: *mut c_void,
|
|
cos_cache: *const c_void,
|
|
sin_cache: *const c_void,
|
|
positions: *const c_void,
|
|
num_tokens: i32,
|
|
num_heads: i32,
|
|
head_dim: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_rope_bf16(
|
|
x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream,
|
|
);
|
|
}
|
|
|
|
/// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer.
|
|
pub unsafe fn gemv_bf16(
|
|
x: *const c_void,
|
|
w: *const c_void,
|
|
y_bf16: *mut c_void,
|
|
y_fp32_buf: *mut c_void,
|
|
k: i32,
|
|
n: i32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream);
|
|
}
|
|
|
|
/// Raw decode attention dispatch.
|
|
pub unsafe fn decode_attention_bf16(
|
|
q: *const c_void,
|
|
k: *const c_void,
|
|
v: *const c_void,
|
|
o: *mut c_void,
|
|
batch: i32,
|
|
num_q_heads: i32,
|
|
num_kv_heads: i32,
|
|
kv_len: i32,
|
|
head_dim: i32,
|
|
scale: f32,
|
|
stream: *mut c_void,
|
|
) {
|
|
launch_decode_attention_bf16(
|
|
q,
|
|
k,
|
|
v,
|
|
o,
|
|
batch,
|
|
num_q_heads,
|
|
num_kv_heads,
|
|
kv_len,
|
|
head_dim,
|
|
scale,
|
|
1,
|
|
stream,
|
|
);
|
|
}
|
|
|
|
// cuBLAS FFI

/// cuBLAS handle as it crosses the FFI boundary: an untyped raw pointer.
pub type CublasHandle = *mut c_void;

unsafe extern "C" {
    // Returns a cuBLAS status code as a plain integer.
    fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
}

/// Set cuBLAS stream. Must be called before any cuBLAS operations during graph capture.
///
/// Forwards `handle` and `stream` to `cublasSetStream_v2`. The function keeps its `()` return
/// type, so the status cannot be reported to the caller; instead a non-zero status trips a
/// debug assertion. Release builds behave exactly as before (the status is dropped).
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `cublasSetStream_v2` returns a
/// non-zero status.
///
/// # Safety
///
/// `handle` must be a cuBLAS handle that `cublasSetStream_v2` accepts (presumably one
/// obtained from `cublasCreate` and not yet destroyed), and `stream` must be a stream handle
/// it accepts — confirm against the cuBLAS documentation.
pub unsafe fn set_cublas_stream(handle: CublasHandle, stream: *mut c_void) {
    // SAFETY: the caller guarantees `handle` and `stream` are acceptable to cuBLAS.
    let status = unsafe { cublasSetStream_v2(handle, stream) };

    // 0 is CUBLAS_STATUS_SUCCESS. A failure here would leave cuBLAS bound to its previous
    // stream, so surface it loudly during development instead of discarding it.
    debug_assert_eq!(status, 0, "cublasSetStream_v2 failed with status {status}");
}
|