cuda: bf16 cuBLAS GemmEx (16BF in/out, fp32 accum) + cast kernels

Add the bf16 compute primitives for T12 mixed precision:
- DType::BF16 (half::bf16 as TensorDType), 2 bytes.
- cublasGemmEx / cublasGemmStridedBatchedEx FFI + CUDA_R_16BF /
  CUBLAS_COMPUTE_32F constants (values per xserv gemm.rs).
- cublas::gemm_ex / gemm_ex_strided_batched: same row-major⟺col-major
  transpose algebra as sgemm, bf16 in/out, fp32 accumulation.
- csrc/ops/cast.cu: f32<->bf16 cast + bf16 elementwise (add/mul/scale/
  silu(+dx)/add_bias/sum_rows), each load->fp32->compute->store bf16.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 14:14:39 +08:00
parent 511ceebbb3
commit d05115ddf3
5 changed files with 413 additions and 3 deletions

View File

@@ -36,6 +36,7 @@ fn main() {
.file("../../csrc/ops/model.cu")
.file("../../csrc/ops/optim.cu")
.file("../../csrc/ops/attention.cu")
.file("../../csrc/ops/cast.cu")
.compile("xtrain_cuda_kernels");
}

View File

@@ -19,6 +19,7 @@
use crate::ffi::{self, CublasHandle};
use std::cell::RefCell;
use std::ffi::c_void;
thread_local! {
static HANDLE: RefCell<Option<CublasHandle>> = const { RefCell::new(None) };
@@ -159,3 +160,131 @@ pub fn sgemm_strided_batched(
assert_eq!(status, 0, "cublasSgemmStridedBatched failed: {status}");
});
}
/// bf16 row-major GEMM `C[m,n] = opA(A)·opB(B)` via `cublasGemmEx`: bf16 in/out,
/// **fp32 accumulation** (`CUBLAS_COMPUTE_32F`) — the standard AMP matmul (Phase
/// T12). `a`/`b`/`c` are device pointers to row-major **bf16** matrices; the
/// row-major⟺col-major transpose algebra is identical to [`sgemm`] (we compute
/// the col-major `Cᵀ`). `alpha`/`beta` are fp32 host scalars (compute is fp32).
#[allow(clippy::too_many_arguments)]
pub fn gemm_ex(
trans_a: bool,
trans_b: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: *const c_void,
b: *const c_void,
beta: f32,
c: *mut c_void,
) {
let lda = if trans_a { m } else { k };
let ldb = if trans_b { k } else { n };
let ldc = n;
let op_a = if trans_a {
ffi::CUBLAS_OP_T
} else {
ffi::CUBLAS_OP_N
};
let op_b = if trans_b {
ffi::CUBLAS_OP_T
} else {
ffi::CUBLAS_OP_N
};
let bf16 = ffi::CUDA_R_16BF;
with_handle(|handle| {
let status = unsafe {
ffi::cublasGemmEx(
handle,
op_b,
op_a,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b,
bf16,
ldb as i32,
a,
bf16,
lda as i32,
&beta as *const f32 as *const c_void,
c,
bf16,
ldc as i32,
ffi::CUBLAS_COMPUTE_32F,
ffi::CUBLAS_GEMM_DEFAULT,
)
};
assert_eq!(status, 0, "cublasGemmEx failed: {status}");
});
}
/// Strided-batched bf16 GEMM (Phase T12) — the [`gemm_ex`] analogue of
/// [`sgemm_strided_batched`] for the batched attention GEMMs. bf16 in/out, fp32
/// accumulation; strides are in ELEMENTS.
#[allow(clippy::too_many_arguments)]
pub fn gemm_ex_strided_batched(
trans_a: bool,
trans_b: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: *const c_void,
stride_a: usize,
b: *const c_void,
stride_b: usize,
beta: f32,
c: *mut c_void,
stride_c: usize,
batch: usize,
) {
let lda = if trans_a { m } else { k };
let ldb = if trans_b { k } else { n };
let ldc = n;
let op_a = if trans_a {
ffi::CUBLAS_OP_T
} else {
ffi::CUBLAS_OP_N
};
let op_b = if trans_b {
ffi::CUBLAS_OP_T
} else {
ffi::CUBLAS_OP_N
};
let bf16 = ffi::CUDA_R_16BF;
with_handle(|handle| {
let status = unsafe {
ffi::cublasGemmStridedBatchedEx(
handle,
op_b,
op_a,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b,
bf16,
ldb as i32,
stride_b as i64,
a,
bf16,
lda as i32,
stride_a as i64,
&beta as *const f32 as *const c_void,
c,
bf16,
ldc as i32,
stride_c as i64,
batch as i32,
ffi::CUBLAS_COMPUTE_32F,
ffi::CUBLAS_GEMM_DEFAULT,
)
};
assert_eq!(status, 0, "cublasGemmStridedBatchedEx failed: {status}");
});
}

View File

@@ -324,3 +324,126 @@ unsafe extern "C" {
pub const CUBLAS_OP_N: i32 = 0;
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_T: i32 = 1;
// --- bf16 mixed precision (Phase T12) ---
//
// cudaDataType / cublasComputeType enum values (same as xserv's gemm.rs). The
// bf16 GEMM uses bf16 in/out with fp32 accumulation (CUBLAS_COMPUTE_32F).
#[cfg(not(no_cuda))]
pub const CUDA_R_32F: i32 = 0;
#[cfg(not(no_cuda))]
pub const CUDA_R_16BF: i32 = 14;
#[cfg(not(no_cuda))]
pub const CUBLAS_COMPUTE_32F: i32 = 68;
/// CUBLAS_GEMM_DEFAULT — let cuBLAS pick the algorithm.
#[cfg(not(no_cuda))]
pub const CUBLAS_GEMM_DEFAULT: i32 = -1;
#[cfg(not(no_cuda))]
unsafe extern "C" {
// General GEMM with explicit in/out + compute types (bf16 path). `alpha`/
// `beta` are fp32 host scalars (compute type is fp32). Pointers are void* so
// the same FFI serves bf16 / fp32.
#[allow(clippy::too_many_arguments)]
pub fn cublasGemmEx(
handle: CublasHandle,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const std::ffi::c_void,
a: *const std::ffi::c_void,
a_type: i32,
lda: i32,
b: *const std::ffi::c_void,
b_type: i32,
ldb: i32,
beta: *const std::ffi::c_void,
c: *mut std::ffi::c_void,
c_type: i32,
ldc: i32,
compute_type: i32,
algo: i32,
) -> i32;
#[allow(clippy::too_many_arguments)]
pub fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const std::ffi::c_void,
a: *const std::ffi::c_void,
a_type: i32,
lda: i32,
stride_a: i64,
b: *const std::ffi::c_void,
b_type: i32,
ldb: i32,
stride_b: i64,
beta: *const std::ffi::c_void,
c: *mut std::ffi::c_void,
c_type: i32,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
) -> i32;
}
// bf16 cast + elementwise kernels (csrc/ops/cast.cu). Pointers are void* (bf16
// buffers); f32 sides are typed. The activation stream flows bf16; the math
// accumulates in fp32 inside each kernel.
#[cfg(not(no_cuda))]
unsafe extern "C" {
pub fn launch_cast_f32_to_bf16(input: *const f32, out: *mut c_void, n: i32, s: CudaStream);
pub fn launch_cast_bf16_to_f32(input: *const c_void, out: *mut f32, n: i32, s: CudaStream);
pub fn launch_add_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
s: CudaStream,
);
pub fn launch_mul_bf16(
a: *const c_void,
b: *const c_void,
out: *mut c_void,
n: i32,
s: CudaStream,
);
pub fn launch_scale_bf16(
input: *const c_void,
out: *mut c_void,
alpha: f32,
n: i32,
s: CudaStream,
);
pub fn launch_silu_bf16(x: *const c_void, y: *mut c_void, n: i32, s: CudaStream);
pub fn launch_silu_dx_bf16(
x: *const c_void,
dy: *const c_void,
dx: *mut c_void,
n: i32,
s: CudaStream,
);
pub fn launch_add_bias_bf16(
x: *const c_void,
bias: *const c_void,
out: *mut c_void,
rows: i32,
cols: i32,
s: CudaStream,
);
pub fn launch_sum_rows_bf16(
dout: *const c_void,
dbias: *mut c_void,
rows: i32,
cols: i32,
s: CudaStream,
);
}

View File

@@ -1,12 +1,16 @@
//! Tensor data types.
//!
//! T2 only needs `F32`, but the enum + `TensorDType` trait are structured so
//! half-precision types (F16/BF16) can be added later (T7 mixed precision)
//! without touching call sites.
//! T2 only needs `F32`; `BF16` was added in T12 for mixed-precision training
//! (bf16 linears / activations, fp32 master weights — see
//! `docs/11-bf16-mixed-precision.md`). The enum + `TensorDType` trait keep call
//! sites dtype-polymorphic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
F32,
/// bfloat16: 1 sign / 8 exponent / 7 mantissa. Same exponent range as f32
/// (so no loss scaling needed), ~2-3 decimal digits. The T12 AMP compute type.
BF16,
/// 32-bit signed integers. Used for cross-entropy targets (token ids).
I32,
}
@@ -15,6 +19,7 @@ impl DType {
pub fn size_bytes(self) -> usize {
match self {
DType::F32 => 4,
DType::BF16 => 2,
DType::I32 => 4,
}
}
@@ -22,6 +27,7 @@ impl DType {
pub fn name(self) -> &'static str {
match self {
DType::F32 => "f32",
DType::BF16 => "bf16",
DType::I32 => "i32",
}
}
@@ -50,6 +56,16 @@ impl TensorDType for f32 {
}
}
impl TensorDType for half::bf16 {
const DTYPE: DType = DType::BF16;
fn to_f64(self) -> f64 {
self.to_f64()
}
fn from_f64(v: f64) -> Self {
half::bf16::from_f64(v)
}
}
impl TensorDType for i32 {
const DTYPE: DType = DType::I32;
fn to_f64(self) -> f64 {

141
csrc/ops/cast.cu Normal file
View File

@@ -0,0 +1,141 @@
// bf16 mixed-precision kernels (Phase T12, KI-2).
//
// Two groups:
// 1. f32 <-> bf16 cast — the bridge between fp32 master weights / fp32
// reductions and the bf16 compute/activation stream.
// 2. bf16 elementwise ops (add / mul / silu / scale + their backwards) — the
// residual-stream ops that flow bf16 activations. Each loads bf16 -> float,
// computes in fp32, stores bf16 (so the math accumulates in fp32 while the
// stored activation is half-size). Matmuls go through cuBLAS GemmEx
// (cublas.rs); norm / softmax / rope / cross-entropy stay fp32 (the Rust
// wrappers upcast around the existing fp32 kernels).
//
// bf16 is __nv_bfloat16; __float2bfloat16 / __bfloat162float round-trip via fp32.
#include <cuda_bf16.h>
extern "C" {
// --- f32 <-> bf16 cast ---
__global__ void cast_f32_to_bf16_k(const float* in, __nv_bfloat16* out, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) out[i] = __float2bfloat16(in[i]);
}
void launch_cast_f32_to_bf16(const float* in, void* out, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
cast_f32_to_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
in, (__nv_bfloat16*)out, n);
}
__global__ void cast_bf16_to_f32_k(const __nv_bfloat16* in, float* out, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) out[i] = __bfloat162float(in[i]);
}
void launch_cast_bf16_to_f32(const void* in, float* out, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
cast_bf16_to_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)in, out, n);
}
// --- bf16 elementwise (load->fp32->compute->store bf16) ---
__global__ void add_bf16_k(const __nv_bfloat16* a, const __nv_bfloat16* b,
__nv_bfloat16* out, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n)
out[i] = __float2bfloat16(__bfloat162float(a[i]) + __bfloat162float(b[i]));
}
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
add_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
__global__ void mul_bf16_k(const __nv_bfloat16* a, const __nv_bfloat16* b,
__nv_bfloat16* out, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n)
out[i] = __float2bfloat16(__bfloat162float(a[i]) * __bfloat162float(b[i]));
}
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
mul_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
__global__ void scale_bf16_k(const __nv_bfloat16* in, __nv_bfloat16* out,
float alpha, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) out[i] = __float2bfloat16(__bfloat162float(in[i]) * alpha);
}
void launch_scale_bf16(const void* in, void* out, float alpha, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
scale_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)in, (__nv_bfloat16*)out, alpha, n);
}
// SiLU: y = x*sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
__global__ void silu_bf16_k(const __nv_bfloat16* x, __nv_bfloat16* y, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float v = __bfloat162float(x[i]);
float sig = 1.0f / (1.0f + expf(-v));
y[i] = __float2bfloat16(v * sig);
}
}
void launch_silu_bf16(const void* x, void* y, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
silu_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)y, n);
}
__global__ void silu_dx_bf16_k(const __nv_bfloat16* x, const __nv_bfloat16* dy,
__nv_bfloat16* dx, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float v = __bfloat162float(x[i]);
float sig = 1.0f / (1.0f + expf(-v));
float g = sig + v * sig * (1.0f - sig);
dx[i] = __float2bfloat16(__bfloat162float(dy[i]) * g);
}
}
void launch_silu_dx_bf16(const void* x, const void* dy, void* dx, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
silu_dx_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)dy, (__nv_bfloat16*)dx, n);
}
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]. x:[rows,cols], bias:[cols].
__global__ void add_bias_bf16_k(const __nv_bfloat16* x, const __nv_bfloat16* bias,
__nv_bfloat16* out, int rows, int cols) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < rows * cols)
out[i] = __float2bfloat16(__bfloat162float(x[i]) +
__bfloat162float(bias[i % cols]));
}
void launch_add_bias_bf16(const void* x, const void* bias, void* out, int rows,
int cols, void* s) {
int blk = 256, grid = (rows * cols + blk - 1) / blk;
add_bias_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out,
rows, cols);
}
// Column-sum over rows: dbias[c] = sum_r dout[r,c] (bias backward), fp32 accum.
__global__ void sum_rows_bf16_k(const __nv_bfloat16* dout, __nv_bfloat16* dbias,
int rows, int cols) {
int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < cols) {
float acc = 0.0f;
for (int r = 0; r < rows; ++r) acc += __bfloat162float(dout[r * cols + c]);
dbias[c] = __float2bfloat16(acc);
}
}
void launch_sum_rows_bf16(const void* dout, void* dbias, int rows, int cols, void* s) {
int blk = 256, grid = (cols + blk - 1) / blk;
sum_rows_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)dout, (__nv_bfloat16*)dbias, rows, cols);
}
} // extern "C"