cuda: bf16 cuBLAS GemmEx (16BF in/out, fp32 accum) + cast kernels
Add the bf16 compute primitives for T12 mixed precision: - DType::BF16 (half::bf16 as TensorDType), 2 bytes. - cublasGemmEx / cublasGemmStridedBatchedEx FFI + CUDA_R_16BF / CUBLAS_COMPUTE_32F constants (values per xserv gemm.rs). - cublas::gemm_ex / gemm_ex_strided_batched: same row-major⟺col-major transpose algebra as sgemm, bf16 in/out, fp32 accumulation. - csrc/ops/cast.cu: f32<->bf16 cast + bf16 elementwise (add/mul/scale/ silu(+dx)/add_bias/sum_rows), each load->fp32->compute->store bf16. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -36,6 +36,7 @@ fn main() {
|
||||
.file("../../csrc/ops/model.cu")
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.file("../../csrc/ops/cast.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
|
||||
use crate::ffi::{self, CublasHandle};
|
||||
use std::cell::RefCell;
|
||||
use std::ffi::c_void;
|
||||
|
||||
thread_local! {
|
||||
static HANDLE: RefCell<Option<CublasHandle>> = const { RefCell::new(None) };
|
||||
@@ -159,3 +160,131 @@ pub fn sgemm_strided_batched(
|
||||
assert_eq!(status, 0, "cublasSgemmStridedBatched failed: {status}");
|
||||
});
|
||||
}
|
||||
|
||||
/// bf16 row-major GEMM `C[m,n] = opA(A)·opB(B)` via `cublasGemmEx`: bf16 in/out,
|
||||
/// **fp32 accumulation** (`CUBLAS_COMPUTE_32F`) — the standard AMP matmul (Phase
|
||||
/// T12). `a`/`b`/`c` are device pointers to row-major **bf16** matrices; the
|
||||
/// row-major⟺col-major transpose algebra is identical to [`sgemm`] (we compute
|
||||
/// the col-major `Cᵀ`). `alpha`/`beta` are fp32 host scalars (compute is fp32).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn gemm_ex(
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
beta: f32,
|
||||
c: *mut c_void,
|
||||
) {
|
||||
let lda = if trans_a { m } else { k };
|
||||
let ldb = if trans_b { k } else { n };
|
||||
let ldc = n;
|
||||
let op_a = if trans_a {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
let op_b = if trans_b {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
let bf16 = ffi::CUDA_R_16BF;
|
||||
|
||||
with_handle(|handle| {
|
||||
let status = unsafe {
|
||||
ffi::cublasGemmEx(
|
||||
handle,
|
||||
op_b,
|
||||
op_a,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b,
|
||||
bf16,
|
||||
ldb as i32,
|
||||
a,
|
||||
bf16,
|
||||
lda as i32,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c,
|
||||
bf16,
|
||||
ldc as i32,
|
||||
ffi::CUBLAS_COMPUTE_32F,
|
||||
ffi::CUBLAS_GEMM_DEFAULT,
|
||||
)
|
||||
};
|
||||
assert_eq!(status, 0, "cublasGemmEx failed: {status}");
|
||||
});
|
||||
}
|
||||
|
||||
/// Strided-batched bf16 GEMM (Phase T12) — the [`gemm_ex`] analogue of
|
||||
/// [`sgemm_strided_batched`] for the batched attention GEMMs. bf16 in/out, fp32
|
||||
/// accumulation; strides are in ELEMENTS.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn gemm_ex_strided_batched(
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const c_void,
|
||||
stride_a: usize,
|
||||
b: *const c_void,
|
||||
stride_b: usize,
|
||||
beta: f32,
|
||||
c: *mut c_void,
|
||||
stride_c: usize,
|
||||
batch: usize,
|
||||
) {
|
||||
let lda = if trans_a { m } else { k };
|
||||
let ldb = if trans_b { k } else { n };
|
||||
let ldc = n;
|
||||
let op_a = if trans_a {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
let op_b = if trans_b {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
let bf16 = ffi::CUDA_R_16BF;
|
||||
|
||||
with_handle(|handle| {
|
||||
let status = unsafe {
|
||||
ffi::cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
op_b,
|
||||
op_a,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b,
|
||||
bf16,
|
||||
ldb as i32,
|
||||
stride_b as i64,
|
||||
a,
|
||||
bf16,
|
||||
lda as i32,
|
||||
stride_a as i64,
|
||||
&beta as *const f32 as *const c_void,
|
||||
c,
|
||||
bf16,
|
||||
ldc as i32,
|
||||
stride_c as i64,
|
||||
batch as i32,
|
||||
ffi::CUBLAS_COMPUTE_32F,
|
||||
ffi::CUBLAS_GEMM_DEFAULT,
|
||||
)
|
||||
};
|
||||
assert_eq!(status, 0, "cublasGemmStridedBatchedEx failed: {status}");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -324,3 +324,126 @@ unsafe extern "C" {
|
||||
pub const CUBLAS_OP_N: i32 = 0;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub const CUBLAS_OP_T: i32 = 1;
|
||||
|
||||
// --- bf16 mixed precision (Phase T12) ---
//
// Raw `cudaDataType` / `cublasComputeType` / `cublasGemmAlgo_t` enum values,
// kept numerically in sync with xserv's gemm.rs. The bf16 GEMM path feeds bf16
// operands and output while accumulating in fp32 (CUBLAS_COMPUTE_32F).

/// `cudaDataType` tag for 32-bit IEEE float buffers.
#[cfg(not(no_cuda))]
pub const CUDA_R_32F: i32 = 0;

/// `cudaDataType` tag for bfloat16 buffers.
#[cfg(not(no_cuda))]
pub const CUDA_R_16BF: i32 = 14;

/// `cublasComputeType_t` selecting fp32 accumulation.
#[cfg(not(no_cuda))]
pub const CUBLAS_COMPUTE_32F: i32 = 68;

/// CUBLAS_GEMM_DEFAULT — let cuBLAS pick the algorithm.
#[cfg(not(no_cuda))]
pub const CUBLAS_GEMM_DEFAULT: i32 = -1;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// General GEMM with explicit in/out + compute types (bf16 path). `alpha`/
|
||||
// `beta` are fp32 host scalars (compute type is fp32). Pointers are void* so
|
||||
// the same FFI serves bf16 / fp32.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const std::ffi::c_void,
|
||||
a: *const std::ffi::c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
b: *const std::ffi::c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
beta: *const std::ffi::c_void,
|
||||
c: *mut std::ffi::c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32,
|
||||
transb: i32,
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
alpha: *const std::ffi::c_void,
|
||||
a: *const std::ffi::c_void,
|
||||
a_type: i32,
|
||||
lda: i32,
|
||||
stride_a: i64,
|
||||
b: *const std::ffi::c_void,
|
||||
b_type: i32,
|
||||
ldb: i32,
|
||||
stride_b: i64,
|
||||
beta: *const std::ffi::c_void,
|
||||
c: *mut std::ffi::c_void,
|
||||
c_type: i32,
|
||||
ldc: i32,
|
||||
stride_c: i64,
|
||||
batch_count: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
// bf16 cast + elementwise kernels (csrc/ops/cast.cu). Pointers are void* (bf16
|
||||
// buffers); f32 sides are typed. The activation stream flows bf16; the math
|
||||
// accumulates in fp32 inside each kernel.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
pub fn launch_cast_f32_to_bf16(input: *const f32, out: *mut c_void, n: i32, s: CudaStream);
|
||||
pub fn launch_cast_bf16_to_f32(input: *const c_void, out: *mut f32, n: i32, s: CudaStream);
|
||||
|
||||
pub fn launch_add_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_mul_bf16(
|
||||
a: *const c_void,
|
||||
b: *const c_void,
|
||||
out: *mut c_void,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_scale_bf16(
|
||||
input: *const c_void,
|
||||
out: *mut c_void,
|
||||
alpha: f32,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_silu_bf16(x: *const c_void, y: *mut c_void, n: i32, s: CudaStream);
|
||||
pub fn launch_silu_dx_bf16(
|
||||
x: *const c_void,
|
||||
dy: *const c_void,
|
||||
dx: *mut c_void,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_add_bias_bf16(
|
||||
x: *const c_void,
|
||||
bias: *const c_void,
|
||||
out: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_sum_rows_bf16(
|
||||
dout: *const c_void,
|
||||
dbias: *mut c_void,
|
||||
rows: i32,
|
||||
cols: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
//! Tensor data types.
|
||||
//!
|
||||
//! T2 only needs `F32`, but the enum + `TensorDType` trait are structured so
|
||||
//! half-precision types (F16/BF16) can be added later (T7 mixed precision)
|
||||
//! without touching call sites.
|
||||
//! T2 only needs `F32`; `BF16` was added in T12 for mixed-precision training
|
||||
//! (bf16 linears / activations, fp32 master weights — see
|
||||
//! `docs/11-bf16-mixed-precision.md`). The enum + `TensorDType` trait keep call
|
||||
//! sites dtype-polymorphic.
|
||||
|
||||
/// Element type tag carried by tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit IEEE float.
    F32,
    /// bfloat16 — 1 sign bit, 8 exponent bits, 7 mantissa bits. It shares the
    /// f32 exponent range (hence no loss scaling) at roughly 2-3 decimal digits
    /// of precision. This is the compute type of the T12 AMP path.
    BF16,
    /// 32-bit signed integer; holds cross-entropy targets (token ids).
    I32,
}
|
||||
@@ -15,6 +19,7 @@ impl DType {
|
||||
pub fn size_bytes(self) -> usize {
|
||||
match self {
|
||||
DType::F32 => 4,
|
||||
DType::BF16 => 2,
|
||||
DType::I32 => 4,
|
||||
}
|
||||
}
|
||||
@@ -22,6 +27,7 @@ impl DType {
|
||||
pub fn name(self) -> &'static str {
|
||||
match self {
|
||||
DType::F32 => "f32",
|
||||
DType::BF16 => "bf16",
|
||||
DType::I32 => "i32",
|
||||
}
|
||||
}
|
||||
@@ -50,6 +56,16 @@ impl TensorDType for f32 {
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for half::bf16 {
|
||||
const DTYPE: DType = DType::BF16;
|
||||
fn to_f64(self) -> f64 {
|
||||
self.to_f64()
|
||||
}
|
||||
fn from_f64(v: f64) -> Self {
|
||||
half::bf16::from_f64(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorDType for i32 {
|
||||
const DTYPE: DType = DType::I32;
|
||||
fn to_f64(self) -> f64 {
|
||||
|
||||
141
csrc/ops/cast.cu
Normal file
141
csrc/ops/cast.cu
Normal file
@@ -0,0 +1,141 @@
|
||||
// bf16 mixed-precision kernels (Phase T12, KI-2).
|
||||
//
|
||||
// Two groups:
|
||||
// 1. f32 <-> bf16 cast — the bridge between fp32 master weights / fp32
|
||||
// reductions and the bf16 compute/activation stream.
|
||||
// 2. bf16 elementwise ops (add / mul / silu / scale + their backwards) — the
|
||||
// residual-stream ops that flow bf16 activations. Each loads bf16 -> float,
|
||||
// computes in fp32, stores bf16 (so the math accumulates in fp32 while the
|
||||
// stored activation is half-size). Matmuls go through cuBLAS GemmEx
|
||||
// (cublas.rs); norm / softmax / rope / cross-entropy stay fp32 (the Rust
|
||||
// wrappers upcast around the existing fp32 kernels).
|
||||
//
|
||||
// bf16 is __nv_bfloat16; __float2bfloat16 / __bfloat162float round-trip via fp32.
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// --- f32 <-> bf16 cast ---
|
||||
|
||||
// Elementwise fp32 -> bf16 conversion: out[i] = bf16(in[i]) for i in [0, n).
// Layout: 1-D grid of 1-D blocks, one thread per element; threads past the
// tail (i >= n) do nothing. No shared memory.
__global__ void cast_f32_to_bf16_k(const float* in, __nv_bfloat16* out, int n) {
    // Widen before multiplying: in the last block the 32-bit index can exceed
    // INT_MAX when n is close to 2^31 and would wrap to a negative int.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = __float2bfloat16(in[i]);
}

// Host launcher: enqueues cast_f32_to_bf16_k on stream `s` (a cudaStream_t
// passed as void*) with 256-thread blocks. `out` is a bf16 device buffer of at
// least n elements. Asynchronous; n <= 0 is a no-op.
void launch_cast_f32_to_bf16(const float* in, void* out, int n, void* s) {
    // A zero-block grid is an invalid launch configuration and would leave a
    // pending CUDA error behind, so an empty tensor skips the launch.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    cast_f32_to_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        in, (__nv_bfloat16*)out, n);
}
|
||||
|
||||
// Elementwise bf16 -> fp32 conversion: out[i] = float(in[i]) for i in [0, n).
// Layout: 1-D grid of 1-D blocks, one thread per element; threads past the
// tail (i >= n) do nothing. No shared memory.
__global__ void cast_bf16_to_f32_k(const __nv_bfloat16* in, float* out, int n) {
    // 64-bit index: the 32-bit product can wrap negative for n near 2^31.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = __bfloat162float(in[i]);
}

// Host launcher: enqueues cast_bf16_to_f32_k on stream `s` (a cudaStream_t
// passed as void*) with 256-thread blocks. `in` is a bf16 device buffer of at
// least n elements. Asynchronous; n <= 0 is a no-op.
void launch_cast_bf16_to_f32(const void* in, float* out, int n, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    cast_bf16_to_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)in, out, n);
}
|
||||
|
||||
// --- bf16 elementwise (load->fp32->compute->store bf16) ---
|
||||
|
||||
// Elementwise bf16 add: out[i] = a[i] + b[i] for i in [0, n). Both operands
// are widened to fp32, summed in fp32, and the sum is stored back as bf16.
// Layout: 1-D grid of 1-D blocks, one thread per element.
__global__ void add_bf16_k(const __nv_bfloat16* a, const __nv_bfloat16* b,
                           __nv_bfloat16* out, int n) {
    // 64-bit index: the 32-bit product can wrap negative for n near 2^31.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        const float lhs = __bfloat162float(a[i]);
        const float rhs = __bfloat162float(b[i]);
        out[i] = __float2bfloat16(lhs + rhs);
    }
}

// Host launcher: enqueues add_bf16_k on stream `s` (a cudaStream_t passed as
// void*) with 256-thread blocks. All pointers are bf16 device buffers of at
// least n elements. Asynchronous; n <= 0 is a no-op.
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    add_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
|
||||
|
||||
// Elementwise bf16 multiply: out[i] = a[i] * b[i] for i in [0, n). Both
// operands are widened to fp32, multiplied in fp32, and the product is stored
// back as bf16. Layout: 1-D grid of 1-D blocks, one thread per element.
__global__ void mul_bf16_k(const __nv_bfloat16* a, const __nv_bfloat16* b,
                           __nv_bfloat16* out, int n) {
    // 64-bit index: the 32-bit product can wrap negative for n near 2^31.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        const float lhs = __bfloat162float(a[i]);
        const float rhs = __bfloat162float(b[i]);
        out[i] = __float2bfloat16(lhs * rhs);
    }
}

// Host launcher: enqueues mul_bf16_k on stream `s` (a cudaStream_t passed as
// void*) with 256-thread blocks. All pointers are bf16 device buffers of at
// least n elements. Asynchronous; n <= 0 is a no-op.
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    mul_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
|
||||
|
||||
// Elementwise bf16 scale: out[i] = in[i] * alpha for i in [0, n). The element
// is widened to fp32, multiplied by the fp32 scalar, and stored back as bf16.
// Layout: 1-D grid of 1-D blocks, one thread per element.
__global__ void scale_bf16_k(const __nv_bfloat16* in, __nv_bfloat16* out,
                             float alpha, int n) {
    // 64-bit index: the 32-bit product can wrap negative for n near 2^31.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = __float2bfloat16(__bfloat162float(in[i]) * alpha);
}

// Host launcher: enqueues scale_bf16_k on stream `s` (a cudaStream_t passed as
// void*) with 256-thread blocks. `in` / `out` are bf16 device buffers of at
// least n elements. Asynchronous; n <= 0 is a no-op.
void launch_scale_bf16(const void* in, void* out, float alpha, int n, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    scale_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, alpha, n);
}
|
||||
|
||||
// SiLU: y = x*sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
//
// Forward kernel: y[i] = x[i] * sigmoid(x[i]) for i in [0, n). The input is
// widened to fp32, the sigmoid and product are evaluated in fp32, and the
// result is stored as bf16. Layout: 1-D grid of 1-D blocks, one thread per
// element.
__global__ void silu_bf16_k(const __nv_bfloat16* x, __nv_bfloat16* y, int n) {
    // 64-bit index: the 32-bit product can wrap negative for n near 2^31.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        const float v = __bfloat162float(x[i]);
        const float sig = 1.0f / (1.0f + expf(-v));
        y[i] = __float2bfloat16(v * sig);
    }
}

// Host launcher: enqueues silu_bf16_k on stream `s` (a cudaStream_t passed as
// void*) with 256-thread blocks. `x` / `y` are bf16 device buffers of at least
// n elements. Asynchronous; n <= 0 is a no-op.
void launch_silu_bf16(const void* x, void* y, int n, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    silu_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)y, n);
}
|
||||
|
||||
// SiLU backward: dx[i] = dy[i] * (sig + x*sig*(1 - sig)) with
// sig = sigmoid(x[i]), for i in [0, n). `x` is the forward input. All math is
// done in fp32 and the gradient is stored as bf16. Layout: 1-D grid of 1-D
// blocks, one thread per element.
__global__ void silu_dx_bf16_k(const __nv_bfloat16* x, const __nv_bfloat16* dy,
                               __nv_bfloat16* dx, int n) {
    // 64-bit index: the 32-bit product can wrap negative for n near 2^31.
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        const float v = __bfloat162float(x[i]);
        const float sig = 1.0f / (1.0f + expf(-v));
        // d/dx [x * sigmoid(x)]
        const float g = sig + v * sig * (1.0f - sig);
        dx[i] = __float2bfloat16(__bfloat162float(dy[i]) * g);
    }
}

// Host launcher: enqueues silu_dx_bf16_k on stream `s` (a cudaStream_t passed
// as void*) with 256-thread blocks. All pointers are bf16 device buffers of at
// least n elements. Asynchronous; n <= 0 is a no-op.
void launch_silu_dx_bf16(const void* x, const void* dy, void* dx, int n, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (n <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: n + blk - 1 overflows int for n near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)n + blk - 1) / blk);
    silu_dx_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)dy, (__nv_bfloat16*)dx, n);
}
|
||||
|
||||
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]. x:[rows,cols], bias:[cols].
//
// The matrix is treated as a flat array of rows*cols elements; the column of
// flat index i is i % cols. The add is done in fp32 and stored as bf16.
// Layout: 1-D grid of 1-D blocks with a grid-stride loop, so any grid size
// covers the whole matrix.
__global__ void add_bias_bf16_k(const __nv_bfloat16* x, const __nv_bfloat16* bias,
                                __nv_bfloat16* out, int rows, int cols) {
    if (rows <= 0 || cols <= 0) return;
    // 64-bit element count: rows * cols overflows int from 2^31 elements on.
    const size_t total = (size_t)rows * (size_t)cols;
    const size_t step = (size_t)gridDim.x * blockDim.x;
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < total;
         i += step) {
        out[i] = __float2bfloat16(__bfloat162float(x[i]) +
                                  __bfloat162float(bias[i % (size_t)cols]));
    }
}

// Host launcher: enqueues add_bias_bf16_k on stream `s` (a cudaStream_t passed
// as void*) with 256-thread blocks. `x` / `out` are bf16 device buffers of
// rows*cols elements, `bias` of cols elements. Asynchronous; an empty matrix
// (rows <= 0 or cols <= 0) is a no-op.
void launch_add_bias_bf16(const void* x, const void* bias, void* out, int rows,
                          int cols, void* s) {
    // Launching zero blocks is an invalid configuration; skip empty inputs.
    if (rows <= 0 || cols <= 0) return;
    const unsigned int blk = 256;
    const size_t total = (size_t)rows * (size_t)cols;
    size_t blocks = (total + blk - 1) / blk;
    // gridDim.x is limited to 2^31 - 1; the grid-stride loop covers the rest.
    const size_t max_blocks = 2147483647u;
    if (blocks > max_blocks) blocks = max_blocks;
    add_bias_bf16_k<<<(unsigned int)blocks, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out,
        rows, cols);
}
|
||||
|
||||
// Column-sum over rows: dbias[c] = sum_r dout[r,c] (bias backward), fp32 accum.
//
// One thread per column: it walks all rows of its column sequentially, adding
// in fp32, and stores the total as bf16. Neighbouring threads read
// neighbouring columns of the same row. Layout: 1-D grid of 1-D blocks over
// the cols dimension. With rows <= 0 the loop does not run and dbias[c] = 0.
__global__ void sum_rows_bf16_k(const __nv_bfloat16* dout, __nv_bfloat16* dbias,
                                int rows, int cols) {
    // 64-bit index: the 32-bit product can wrap negative for cols near 2^31.
    const long long c = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (c < cols) {
        float acc = 0.0f;
        for (int r = 0; r < rows; ++r) {
            // 64-bit offset: r * cols overflows int once rows*cols >= 2^31.
            const size_t idx = (size_t)r * (size_t)cols + (size_t)c;
            acc += __bfloat162float(dout[idx]);
        }
        dbias[c] = __float2bfloat16(acc);
    }
}

// Host launcher: enqueues sum_rows_bf16_k on stream `s` (a cudaStream_t passed
// as void*) with 256-thread blocks. `dout` is a bf16 device buffer of
// rows*cols elements, `dbias` of cols elements. Asynchronous; cols <= 0 is a
// no-op.
void launch_sum_rows_bf16(const void* dout, void* dbias, int rows, int cols, void* s) {
    // Launching zero blocks is an invalid configuration; nothing to write.
    if (cols <= 0) return;
    const int blk = 256;
    // 64-bit ceil-div: cols + blk - 1 overflows int for cols near INT_MAX.
    const unsigned int grid = (unsigned int)(((long long)cols + blk - 1) / blk);
    sum_rows_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)dout, (__nv_bfloat16*)dbias, rows, cols);
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user