Files
xserv/crates/xserv-kernels/src/activation.rs
Gahow Wang 4088f49b7d cuda: infrastructure for whole-step CUDA graph capture
- Thread-local launch stream (xserv_cuda::stream): every kernel
  wrapper, cublasSetStream, and NCCL collective now launches on
  current_stream_raw() — the legacy null stream by default (behavior
  unchanged), or the capture stream installed via push_stream during
  graph capture. Capture is impossible on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are
  quarantined (RetainedBlocks) instead of pooled, so an instantiated
  graph keeps exclusive ownership of every intermediate buffer it
  references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads
  must not poison each other's captures with their own cudaMallocs.
- embedding_device_ids / rope_inplace_device_pos: variants reading
  token ids / positions from persistent device buffers, replacing the
  per-call host upload that a captured region cannot contain.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 20:12:37 +08:00

154 lines
7.2 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
// Kernel launchers with C linkage; their definitions are not in this file.
// Names follow `launch_<op>_<dtype>`. In every signature the trailing `stream`
// argument is the stream handle to launch on (the wrappers in this file always
// pass `xserv_cuda::current_stream_raw()`), and the `i32` count arguments are
// element counts computed by the wrappers from tensor shapes.
// Pointers are untyped (`c_void`); the wrappers pass `Tensor::data_ptr()` values
// of tensors they have asserted to be contiguous.
unsafe extern "C" {
// Unary ops, argument order: (x, out, n, stream).
fn launch_gelu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
// Scale: like the unary ops, with the `f32` scalar factor passed before the count.
fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void);
// Binary ops, argument order: (a, b, out, n, stream). The wrappers require
// both inputs to have the same shape and dtype.
fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
// Fused silu(gate) * up; only a BF16 variant is declared.
fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void);
// gpt-oss GLU; the wrapper passes `n_elements` = rows * (cols / 2), i.e. the
// number of OUTPUT elements, not the size of `gate_up`.
fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32,
alpha: f32, limit: f32, stream: *mut c_void);
// Row-broadcast bias add over a [rows, cols] input with a [cols] bias.
fn launch_bias_add_2d_bf16(x: *const c_void, bias: *const c_void, out: *mut c_void,
rows: i32, cols: i32, stream: *mut c_void);
}
/// Shared driver for the one-input kernels (`gelu`, `silu`).
///
/// Allocates an output tensor with the same shape, dtype and device as `x`,
/// picks `f32_fn` or `bf16_fn` according to `x.dtype()`, and launches it on
/// `xserv_cuda::current_stream_raw()`.
///
/// # Panics
/// - if `x` is not contiguous or not on a CUDA device;
/// - if `x.numel()` does not fit in an `i32`;
/// - if the dtype is neither `F32` nor `BF16`.
fn dispatch_unary(
    x: &Tensor,
    f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
    bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
) -> Tensor {
    assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));

    let result = Tensor::empty(x.shape(), x.dtype(), x.device());

    // The kernels take their element count as a C `int`.
    let n = x.numel();
    assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
    let count = n as i32;

    // Resolve the dtype-specific entry point first, then launch once.
    let kernel = match x.dtype() {
        DType::F32 => f32_fn,
        DType::BF16 => bf16_fn,
        _ => panic!("unsupported dtype"),
    };

    // SAFETY: `x` was asserted contiguous and CUDA-resident above, `result` was
    // allocated with x's shape/dtype/device, and `count` is x's element count.
    unsafe {
        let src = x.data_ptr() as *const c_void;
        let dst = result.data_ptr() as *mut c_void;
        let stream = xserv_cuda::current_stream_raw();
        kernel(src, dst, count, stream);
    }

    result
}
/// Shared driver for the two-input kernels (`add`, `mul`).
///
/// Allocates an output tensor with the shape, dtype and device of `a`, picks
/// `f32_fn` or `bf16_fn` according to the (common) dtype, and launches it on
/// `xserv_cuda::current_stream_raw()`.
///
/// # Panics
/// - if the shapes or dtypes of `a` and `b` differ;
/// - if either input is not contiguous;
/// - if either input is not on a CUDA device;
/// - if the element count does not fit in an `i32`;
/// - if the dtype is neither `F32` nor `BF16`.
fn dispatch_binary(
    a: &Tensor,
    b: &Tensor,
    f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
    bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void),
) -> Tensor {
    assert_eq!(a.shape(), b.shape());
    assert!(a.is_contiguous() && b.is_contiguous());
    assert!(matches!(a.device(), Device::Cuda(_)));
    // `b`'s pointer is handed to the kernel just like `a`'s, so it must be
    // validated too; a host tensor here would otherwise be passed as a kernel input.
    // NOTE(review): this does not verify that both tensors share the same CUDA
    // ordinal — confirm whether callers can mix devices.
    assert!(
        matches!(b.device(), Device::Cuda(_)),
        "second operand must be on a CUDA device"
    );
    assert_eq!(a.dtype(), b.dtype());

    let out = Tensor::empty(a.shape(), a.dtype(), a.device());

    // The kernels take their element count as a C `int`.
    let n = a.numel();
    assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
    let n = n as i32;

    // SAFETY: both inputs were asserted contiguous, CUDA-resident, and of equal
    // shape and dtype; `out` was allocated with the same shape/dtype, and `n`
    // is their common element count.
    unsafe {
        match a.dtype() {
            DType::F32 => f32_fn(
                a.data_ptr() as _,
                b.data_ptr() as _,
                out.data_ptr() as *mut c_void,
                n,
                xserv_cuda::current_stream_raw(),
            ),
            DType::BF16 => bf16_fn(
                a.data_ptr() as _,
                b.data_ptr() as _,
                out.data_ptr() as *mut c_void,
                n,
                xserv_cuda::current_stream_raw(),
            ),
            _ => panic!("unsupported dtype"),
        }
    }
    out
}
/// Runs the GELU kernel over `x` and returns the result as a new tensor.
///
/// Supports `F32` and `BF16`; see `dispatch_unary` for the preconditions
/// (contiguous, CUDA-resident, element count within `i32`).
pub fn gelu(x: &Tensor) -> Tensor {
    dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16)
}
/// Runs the SiLU kernel over `x` and returns the result as a new tensor.
///
/// Supports `F32` and `BF16`; see `dispatch_unary` for the preconditions
/// (contiguous, CUDA-resident, element count within `i32`).
pub fn silu(x: &Tensor) -> Tensor {
    dispatch_unary(x, launch_silu_f32, launch_silu_bf16)
}
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel();
assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
let n = n as i32;
unsafe {
match x.dtype() {
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
_ => panic!("unsupported dtype for scale"),
}
}
out
}
/// Runs the add kernel over `a` and `b` and returns the result as a new tensor.
///
/// Supports `F32` and `BF16`; see `dispatch_binary` for the preconditions
/// (matching shape and dtype, contiguous inputs, element count within `i32`).
pub fn add(a: &Tensor, b: &Tensor) -> Tensor {
    dispatch_binary(a, b, launch_add_f32, launch_add_bf16)
}
/// Runs the mul kernel over `a` and `b` and returns the result as a new tensor.
///
/// Supports `F32` and `BF16`; see `dispatch_binary` for the preconditions
/// (matching shape and dtype, contiguous inputs, element count within `i32`).
pub fn mul(a: &Tensor, b: &Tensor) -> Tensor {
    dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16)
}
/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only).
///
/// `x` must be a contiguous 2-D BF16 tensor on a CUDA device and `bias` a
/// contiguous 1-D BF16 tensor on a CUDA device whose length equals the number
/// of columns of `x`. Returns a new `[rows, cols]` BF16 tensor on `x`'s device.
/// The kernel is launched on `xserv_cuda::current_stream_raw()`.
///
/// # Panics
/// - on any rank, dtype, contiguity or device violation listed above;
/// - if `bias.shape()[0] != cols`;
/// - if `rows`, `cols` or `rows * cols` does not fit in an `i32`.
pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
    assert_eq!(x.ndim(), 2);
    assert_eq!(bias.ndim(), 1);
    assert_eq!(x.dtype(), DType::BF16);
    assert_eq!(bias.dtype(), DType::BF16);
    assert!(x.is_contiguous() && bias.is_contiguous());
    assert!(matches!(x.device(), Device::Cuda(_)));
    // `bias` is passed to the kernel as well, so it must live on a CUDA device too.
    assert!(
        matches!(bias.device(), Device::Cuda(_)),
        "bias must be on a CUDA device"
    );

    let rows = x.shape()[0];
    let cols = x.shape()[1];
    assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]);

    // Guard the `as i32` casts below. The product is computed with
    // `checked_mul` so it cannot wrap in release builds, and each dimension is
    // checked separately because a zero-sized dimension would make the product
    // pass regardless of how large the other dimension is.
    let n = rows
        .checked_mul(cols)
        .expect("rows * cols overflows usize");
    assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
    assert!(
        rows <= i32::MAX as usize && cols <= i32::MAX as usize,
        "dimension too large for i32 kernel param (rows {rows}, cols {cols})"
    );

    let out = Tensor::empty(&[rows, cols], DType::BF16, x.device());

    // SAFETY: `x` and `bias` were asserted contiguous, BF16 and CUDA-resident,
    // `bias` has `cols` elements, `out` was allocated as [rows, cols] BF16, and
    // both dimensions were checked to fit in `i32`.
    unsafe {
        launch_bias_add_2d_bf16(
            x.data_ptr() as _,
            bias.data_ptr() as _,
            out.data_ptr() as *mut c_void,
            rows as i32,
            cols as i32,
            xserv_cuda::current_stream_raw(),
        );
    }
    out
}
/// Fused SiLU×Mul: out = silu(gate) * up (BF16 only)
/// Saves one HBM read + one HBM write compared to separate silu + mul.
///
/// `gate` and `up` must have the same shape, both be contiguous BF16 tensors,
/// and both live on a CUDA device. Returns a new tensor with `gate`'s shape,
/// dtype and device. The kernel is launched on
/// `xserv_cuda::current_stream_raw()`.
///
/// # Panics
/// - if the shapes differ or either input is not contiguous;
/// - if either input is not BF16 or not on a CUDA device;
/// - if the element count does not fit in an `i32`.
pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
    assert_eq!(gate.shape(), up.shape());
    assert!(gate.is_contiguous() && up.is_contiguous());
    assert!(matches!(gate.device(), Device::Cuda(_)));
    // `up` is read by the same BF16 kernel, so it needs the same device and
    // dtype guarantees as `gate`; a same-shaped F32 or host tensor would
    // otherwise be passed to the BF16 kernel as if it were BF16 device memory.
    assert!(
        matches!(up.device(), Device::Cuda(_)),
        "silu_mul requires `up` on a CUDA device"
    );
    assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
    assert_eq!(up.dtype(), DType::BF16, "silu_mul requires BF16");

    let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());

    // The kernel takes its element count as a C `int`.
    let n = gate.numel();
    assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
    let n = n as i32;

    // SAFETY: both inputs were asserted contiguous, BF16, CUDA-resident and of
    // equal shape; `out` was allocated with the same shape/dtype, and `n` is
    // their common element count.
    unsafe {
        launch_silu_mul_bf16(
            gate.data_ptr() as *const c_void,
            up.data_ptr() as *const c_void,
            out.data_ptr() as *mut c_void,
            n,
            xserv_cuda::current_stream_raw(),
        );
    }
    out
}
/// gpt-oss fused GLU activation (BF16 only).
/// Input: gate_up [rows, 2*D] with interleaved columns (gate=even, up=odd).
/// Output: [rows, D]
/// Computes: gate.clamp(max=limit) * sigmoid(gate * alpha) * (up.clamp(-limit,limit) + 1)
///
/// The kernel is launched on `xserv_cuda::current_stream_raw()` and receives
/// the number of OUTPUT elements (`rows * D`) as its element count.
///
/// # Panics
/// - if `gate_up` is not a contiguous 2-D BF16 tensor on a CUDA device;
/// - if its column count is odd;
/// - if `rows * D` does not fit in an `i32`.
pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
    assert!(gate_up.is_contiguous());
    assert!(matches!(gate_up.device(), Device::Cuda(_)));
    assert_eq!(gate_up.dtype(), DType::BF16, "gpt_oss_glu requires BF16");
    assert_eq!(gate_up.ndim(), 2);

    let rows = gate_up.shape()[0];
    let cols = gate_up.shape()[1];
    assert_eq!(cols % 2, 0);
    let d = cols / 2;

    // Every other wrapper in this file rejects counts that do not fit the
    // kernel's `i32` parameter; do the same here instead of letting the
    // `as i32` cast silently truncate.
    let n = rows.checked_mul(d).expect("rows * d overflows usize");
    assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)");
    let n_elements = n as i32;

    let out = Tensor::empty(&[rows, d], gate_up.dtype(), gate_up.device());

    // SAFETY: `gate_up` was asserted contiguous, BF16, CUDA-resident and of
    // shape [rows, 2*d]; `out` was allocated as [rows, d] with the same dtype
    // and device, and `n_elements` is out's element count.
    unsafe {
        launch_gpt_oss_glu_bf16(
            gate_up.data_ptr() as *const c_void,
            out.data_ptr() as *mut c_void,
            n_elements,
            alpha,
            limit,
            xserv_cuda::current_stream_raw(),
        );
    }
    out
}