- Thread-local launch stream (`xserv_cuda::stream`): every kernel wrapper, `cublasSetStream`, and NCCL collective now launches on `current_stream_raw()`. By default this is the legacy null stream (behavior unchanged); during graph capture it is the capture stream installed via `push_stream`. This is required because stream capture cannot be performed on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are quarantined (`RetainedBlocks`) instead of being returned to the pool, so an instantiated graph keeps exclusive ownership of every intermediate buffer it references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads must not poison each other's captures with their own `cudaMalloc`s.
- `embedding_device_ids` / `rope_inplace_device_pos`: variants that read token ids / positions from persistent device buffers, replacing the per-call host upload that a captured region cannot contain.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
154 lines
7.2 KiB
Rust
154 lines
7.2 KiB
Rust
use std::ffi::c_void;
|
||
use xserv_tensor::{DType, Device, Tensor};
|
||
|
||
// Foreign launch functions for the elementwise kernels wrapped in this file.
//
// Shared conventions, as exercised by the wrappers below:
// - tensor arguments are raw data pointers of contiguous tensors;
// - element / row / column counts are passed as `i32`;
// - the final argument is the stream to enqueue on, and every wrapper here
//   passes `xserv_cuda::current_stream_raw()`.
unsafe extern "C" {
    // Unary activations, called through `dispatch_unary` (`gelu`, `silu`).
    fn launch_gelu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
    fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
    fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
    fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);

    // Scalar multiplier, called by `scale`.
    fn launch_scale_f32(
        x: *const c_void,
        out: *mut c_void,
        scale: f32,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_scale_bf16(
        x: *const c_void,
        out: *mut c_void,
        scale: f32,
        n: i32,
        stream: *mut c_void,
    );

    // Same-shape binary ops, called through `dispatch_binary` (`add`, `mul`).
    fn launch_add_f32(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_add_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_mul_f32(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_mul_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );

    // BF16-only fused / specialised kernels.
    fn launch_silu_mul_bf16(
        gate: *const c_void,
        up: *const c_void,
        out: *mut c_void,
        n: i32,
        stream: *mut c_void,
    );
    fn launch_gpt_oss_glu_bf16(
        gate_up: *const c_void,
        out: *mut c_void,
        n_elements: i32,
        alpha: f32,
        limit: f32,
        stream: *mut c_void,
    );
    fn launch_bias_add_2d_bf16(
        x: *const c_void,
        bias: *const c_void,
        out: *mut c_void,
        rows: i32,
        cols: i32,
        stream: *mut c_void,
    );
}

fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||
let n = x.numel();
|
||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||
let n = n as i32;
|
||
unsafe {
|
||
match x.dtype() {
|
||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||
_ => panic!("unsupported dtype"),
|
||
}
|
||
}
|
||
out
|
||
}
|
||
|
||
fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||
f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
|
||
bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||
assert_eq!(a.shape(), b.shape());
|
||
assert!(a.is_contiguous() && b.is_contiguous());
|
||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||
assert_eq!(a.dtype(), b.dtype());
|
||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||
let n = a.numel();
|
||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||
let n = n as i32;
|
||
unsafe {
|
||
match a.dtype() {
|
||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||
_ => panic!("unsupported dtype"),
|
||
}
|
||
}
|
||
out
|
||
}
|
||
|
||
pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) }
|
||
pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) }
|
||
|
||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||
let n = x.numel();
|
||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||
let n = n as i32;
|
||
unsafe {
|
||
match x.dtype() {
|
||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||
_ => panic!("unsupported dtype for scale"),
|
||
}
|
||
}
|
||
out
|
||
}
|
||
|
||
pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) }
|
||
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) }
|
||
|
||
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
|
||
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
|
||
assert_eq!(x.ndim(), 2);
|
||
assert_eq!(bias.ndim(), 1);
|
||
assert_eq!(x.dtype(), DType::BF16);
|
||
assert_eq!(bias.dtype(), DType::BF16);
|
||
assert!(x.is_contiguous() && bias.is_contiguous());
|
||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||
let rows = x.shape()[0];
|
||
let cols = x.shape()[1];
|
||
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]);
|
||
assert!(rows * cols <= i32::MAX as usize);
|
||
let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());
|
||
unsafe {
|
||
launch_bias_add_2d_bf16(
|
||
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||
);
|
||
}
|
||
out
|
||
}
|
||
|
||
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
|
||
/// Saves one HBM read + one HBM write compared to separate silu + mul.
|
||
pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||
assert_eq!(gate.shape(), up.shape());
|
||
assert!(gate.is_contiguous() && up.is_contiguous());
|
||
assert!(matches!(gate.device(), Device::Cuda(_)));
|
||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
||
let n = gate.numel();
|
||
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
|
||
let n = n as i32;
|
||
unsafe {
|
||
launch_silu_mul_bf16(
|
||
gate.data_ptr() as *const c_void,
|
||
up.data_ptr() as *const c_void,
|
||
out.data_ptr() as *mut c_void,
|
||
n,
|
||
xserv_cuda::current_stream_raw(),
|
||
);
|
||
}
|
||
out
|
||
}
|
||
|
||
/// gpt-oss fused GLU activation (BF16 only).
|
||
/// Input: gate_up [rows, 2*D] with interleaved columns (gate=even, up=odd).
|
||
/// Output: [rows, D]
|
||
/// Computes: gate.clamp(max=limit) * sigmoid(gate * alpha) * (up.clamp(-limit,limit) + 1)
|
||
pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
|
||
assert!(gate_up.is_contiguous());
|
||
assert!(matches!(gate_up.device(), Device::Cuda(_)));
|
||
assert_eq!(gate_up.dtype(), DType::BF16, "gpt_oss_glu requires BF16");
|
||
assert_eq!(gate_up.ndim(), 2);
|
||
let rows = gate_up.shape()[0];
|
||
let cols = gate_up.shape()[1];
|
||
assert_eq!(cols % 2, 0);
|
||
let d = cols / 2;
|
||
let out = Tensor::empty(&[rows, d], gate_up.dtype(), gate_up.device());
|
||
let n_elements = (rows * d) as i32;
|
||
unsafe {
|
||
launch_gpt_oss_glu_bf16(
|
||
gate_up.data_ptr() as *const c_void,
|
||
out.data_ptr() as *mut c_void,
|
||
n_elements,
|
||
alpha,
|
||
limit,
|
||
xserv_cuda::current_stream_raw(),
|
||
);
|
||
}
|
||
out
|
||
}
|