Add the bf16 compute primitives for T12 mixed precision: - DType::BF16 (half::bf16 as TensorDType), 2 bytes. - cublasGemmEx / cublasGemmStridedBatchedEx FFI + CUDA_R_16BF / CUBLAS_COMPUTE_32F constants (values per xserv gemm.rs). - cublas::gemm_ex / gemm_ex_strided_batched: same row-major⟺col-major transpose algebra as sgemm, bf16 in/out, fp32 accumulation. - csrc/ops/cast.cu: f32<->bf16 cast + bf16 elementwise (add/mul/scale/ silu(+dx)/add_bias/sum_rows), each load->fp32->compute->store bf16. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
142 lines
5.9 KiB
Plaintext
142 lines
5.9 KiB
Plaintext
// bf16 mixed-precision kernels (Phase T12, KI-2).
|
|
//
|
|
// Two groups:
|
|
// 1. f32 <-> bf16 cast — the bridge between fp32 master weights / fp32
|
|
// reductions and the bf16 compute/activation stream.
|
|
// 2. bf16 elementwise ops (add / mul / scale / silu, the silu backward, and the broadcast bias add with its column-sum backward) — the
|
|
// residual-stream ops that flow bf16 activations. Each loads bf16 -> float,
|
|
// computes in fp32, stores bf16 (so the math accumulates in fp32 while the
|
|
// stored activation is half-size). Matmuls go through cuBLAS GemmEx
|
|
// (cublas.rs); norm / softmax / rope / cross-entropy stay fp32 (the Rust
|
|
// wrappers upcast around the existing fp32 kernels).
|
|
//
|
|
// bf16 is __nv_bfloat16; __float2bfloat16 / __bfloat162float round-trip via fp32.
|
|
|
|
#include <cuda_bf16.h>
|
|
|
|
extern "C" {
|
|
|
|
// --- f32 <-> bf16 cast ---
|
|
|
|
// Converts fp32 values to bf16, one element per thread.
//   in  : source array, read at indices [0, n)
//   out : destination array, written at indices [0, n)
//   n   : element count; threads whose global index is >= n do nothing
// Uses a 1-D launch: the element index is blockIdx.x * blockDim.x + threadIdx.x.
__global__ void cast_f32_to_bf16_k(const float* in, __nv_bfloat16* out, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;  // tail threads of the last block
    const float value = in[idx];
    out[idx] = __float2bfloat16(value);
}
|
|
// Host-side launcher for cast_f32_to_bf16_k (fp32 -> bf16 cast of n elements).
//   in  : source floats, forwarded to the kernel unchanged
//   out : destination buffer, reinterpreted as __nv_bfloat16*
//   n   : element count; n <= 0 is a no-op
//   s   : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_cast_f32_to_bf16(const float* in, void* out, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    cast_f32_to_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        in, (__nv_bfloat16*)out, n);
}
|
|
|
|
// Converts bf16 values to fp32, one element per thread.
//   in  : source array, read at indices [0, n)
//   out : destination array, written at indices [0, n)
//   n   : element count; threads whose global index is >= n do nothing
// Uses a 1-D launch: the element index is blockIdx.x * blockDim.x + threadIdx.x.
__global__ void cast_bf16_to_f32_k(const __nv_bfloat16* in, float* out, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;  // tail threads of the last block
    const __nv_bfloat16 packed = in[idx];
    out[idx] = __bfloat162float(packed);
}
|
|
// Host-side launcher for cast_bf16_to_f32_k (bf16 -> fp32 cast of n elements).
//   in  : source buffer, reinterpreted as const __nv_bfloat16*
//   out : destination floats, forwarded to the kernel unchanged
//   n   : element count; n <= 0 is a no-op
//   s   : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_cast_bf16_to_f32(const void* in, float* out, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    cast_bf16_to_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)in, out, n);
}
|
|
|
|
// --- bf16 elementwise (load->fp32->compute->store bf16) ---
|
|
|
|
// Elementwise sum of two bf16 arrays: out[i] = a[i] + b[i].
// Both operands are widened to fp32, added in fp32, and the result is
// converted back to bf16 on store.
//   a, b : inputs, read at indices [0, n)
//   out  : output, written at indices [0, n)
//   n    : element count; threads whose global index is >= n do nothing
__global__ void add_bf16_k(const __nv_bfloat16* a, const __nv_bfloat16* b,
                           __nv_bfloat16* out, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    const float lhs = __bfloat162float(a[idx]);
    const float rhs = __bfloat162float(b[idx]);
    out[idx] = __float2bfloat16(lhs + rhs);
}
|
|
// Host-side launcher for add_bf16_k (out = a + b over n bf16 elements).
//   a, b : input buffers, reinterpreted as const __nv_bfloat16*
//   out  : output buffer, reinterpreted as __nv_bfloat16*
//   n    : element count; n <= 0 is a no-op
//   s    : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_add_bf16(const void* a, const void* b, void* out, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    add_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
|
|
|
|
// Elementwise product of two bf16 arrays: out[i] = a[i] * b[i].
// Both operands are widened to fp32, multiplied in fp32, and the result is
// converted back to bf16 on store.
//   a, b : inputs, read at indices [0, n)
//   out  : output, written at indices [0, n)
//   n    : element count; threads whose global index is >= n do nothing
__global__ void mul_bf16_k(const __nv_bfloat16* a, const __nv_bfloat16* b,
                           __nv_bfloat16* out, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    const float lhs = __bfloat162float(a[idx]);
    const float rhs = __bfloat162float(b[idx]);
    out[idx] = __float2bfloat16(lhs * rhs);
}
|
|
// Host-side launcher for mul_bf16_k (out = a * b over n bf16 elements).
//   a, b : input buffers, reinterpreted as const __nv_bfloat16*
//   out  : output buffer, reinterpreted as __nv_bfloat16*
//   n    : element count; n <= 0 is a no-op
//   s    : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    mul_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n);
}
|
|
|
|
// Scales a bf16 array by an fp32 scalar: out[i] = in[i] * alpha.
// The element is widened to fp32, multiplied in fp32, and the result is
// converted back to bf16 on store.
//   in    : input, read at indices [0, n)
//   out   : output, written at indices [0, n)
//   alpha : fp32 multiplier applied to every element
//   n     : element count; threads whose global index is >= n do nothing
__global__ void scale_bf16_k(const __nv_bfloat16* in, __nv_bfloat16* out,
                             float alpha, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    const float value = __bfloat162float(in[idx]);
    out[idx] = __float2bfloat16(value * alpha);
}
|
|
// Host-side launcher for scale_bf16_k (out = in * alpha over n bf16 elements).
//   in    : input buffer, reinterpreted as const __nv_bfloat16*
//   out   : output buffer, reinterpreted as __nv_bfloat16*
//   alpha : fp32 multiplier forwarded to the kernel
//   n     : element count; n <= 0 is a no-op
//   s     : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_scale_bf16(const void* in, void* out, float alpha, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    scale_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)in, (__nv_bfloat16*)out, alpha, n);
}
|
|
|
|
// SiLU: y = x*sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
//
// Forward SiLU kernel, one element per thread. The input is widened to fp32,
// the sigmoid and product are evaluated in fp32, and the result is converted
// back to bf16 on store.
//   x : input, read at indices [0, n)
//   y : output, written at indices [0, n)
//   n : element count; threads whose global index is >= n do nothing
__global__ void silu_bf16_k(const __nv_bfloat16* x, __nv_bfloat16* y, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    const float v = __bfloat162float(x[idx]);
    const float sig = 1.0f / (1.0f + expf(-v));  // logistic sigmoid of v
    y[idx] = __float2bfloat16(v * sig);
}
|
|
// Host-side launcher for silu_bf16_k (y = x * sigmoid(x) over n bf16 elements).
//   x : input buffer, reinterpreted as const __nv_bfloat16*
//   y : output buffer, reinterpreted as __nv_bfloat16*
//   n : element count; n <= 0 is a no-op
//   s : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_silu_bf16(const void* x, void* y, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    silu_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)y, n);
}
|
|
|
|
// SiLU backward kernel, one element per thread:
//   dx[i] = dy[i] * (sig + x[i]*sig*(1 - sig)),  sig = sigmoid(x[i]).
// Inputs are widened to fp32, the gradient factor is evaluated in fp32, and
// the result is converted back to bf16 on store.
//   x  : forward-pass input, read at indices [0, n)
//   dy : upstream gradient, read at indices [0, n)
//   dx : output gradient, written at indices [0, n)
//   n  : element count; threads whose global index is >= n do nothing
__global__ void silu_dx_bf16_k(const __nv_bfloat16* x, const __nv_bfloat16* dy,
                               __nv_bfloat16* dx, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    const float v = __bfloat162float(x[idx]);
    const float sig = 1.0f / (1.0f + expf(-v));  // logistic sigmoid of v
    // d/dv [v * sigmoid(v)]
    const float g = sig + v * sig * (1.0f - sig);
    dx[idx] = __float2bfloat16(__bfloat162float(dy[idx]) * g);
}
|
|
// Host-side launcher for silu_dx_bf16_k (SiLU backward over n bf16 elements).
//   x  : forward-pass input buffer, reinterpreted as const __nv_bfloat16*
//   dy : upstream-gradient buffer, reinterpreted as const __nv_bfloat16*
//   dx : output-gradient buffer, reinterpreted as __nv_bfloat16*
//   n  : element count; n <= 0 is a no-op
//   s  : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(n / 256) blocks. This
// function neither synchronizes nor checks for launch errors.
void launch_silu_dx_bf16(const void* x, const void* dy, void* dx, int n, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // input must skip the launch rather than record a CUDA error.
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;
    silu_dx_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)dy, (__nv_bfloat16*)dx, n);
}
|
|
|
|
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]. x:[rows,cols], bias:[cols].
//
// x and out are indexed as flat row-major arrays of rows*cols elements; the
// column of flat index i is i % cols. Each sum is computed in fp32 and stored
// as bf16.
//
// The element count and flat index are 64-bit so that tensors with more than
// INT_MAX elements are indexed correctly. The kernel is a grid-stride loop, so
// it produces the full result for any 1-D grid size, including a grid that the
// launcher capped below ceil(rows*cols / blockDim.x).
__global__ void add_bias_bf16_k(const __nv_bfloat16* x, const __nv_bfloat16* bias,
                                __nv_bfloat16* out, int rows, int cols) {
    const long long total = (long long)rows * cols;
    const long long stride = (long long)gridDim.x * blockDim.x;
    for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
         i < total; i += stride) {
        out[i] = __float2bfloat16(__bfloat162float(x[i]) +
                                  __bfloat162float(bias[i % cols]));
    }
}

// Host-side launcher for add_bias_bf16_k.
//   x    : input buffer, reinterpreted as const __nv_bfloat16*
//   bias : bias buffer, reinterpreted as const __nv_bfloat16*
//   out  : output buffer, reinterpreted as __nv_bfloat16*
//   rows : number of rows; rows <= 0 is a no-op
//   cols : number of columns (and bias length); cols <= 0 is a no-op
//   s    : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks. This function neither
// synchronizes nor checks for launch errors.
void launch_add_bias_bf16(const void* x, const void* bias, void* out, int rows,
                          int cols, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // tensor must skip the launch rather than record a CUDA error.
    if (rows <= 0 || cols <= 0) return;

    const int blk = 256;
    // rows * cols can exceed INT_MAX, so size the grid in 64-bit arithmetic.
    const long long total = (long long)rows * cols;
    long long blocks = (total + blk - 1) / blk;
    // gridDim.x is limited to 2^31 - 1; the kernel's grid-stride loop covers
    // any elements beyond what a capped grid reaches in one pass.
    const long long max_grid_x = 2147483647LL;
    if (blocks > max_grid_x) blocks = max_grid_x;

    add_bias_bf16_k<<<(unsigned int)blocks, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out,
        rows, cols);
}
|
|
|
|
// Column-sum over rows: dbias[c] = sum_r dout[r,c] (bias backward), fp32 accum.
//
// One thread per column c: the thread reads dout[r*cols + c] for r = 0..rows-1
// in increasing r, accumulates in fp32, and stores the total as bf16.
//   dout  : input, indexed as a flat row-major [rows, cols] array
//   dbias : output, written at indices [0, cols)
//   rows  : number of rows summed; with rows <= 0 every dbias[c] is set to 0
//   cols  : number of columns; threads whose global index is >= cols do nothing
__global__ void sum_rows_bf16_k(const __nv_bfloat16* dout, __nv_bfloat16* dbias,
                                int rows, int cols) {
    const int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= cols) return;

    float acc = 0.0f;
    // Track the flat offset in 64 bits: r * cols + c overflows int as soon as
    // rows * cols exceeds INT_MAX, which would read from a wrong address.
    long long offset = c;
    for (int r = 0; r < rows; ++r, offset += cols) {
        acc += __bfloat162float(dout[offset]);
    }
    dbias[c] = __float2bfloat16(acc);
}
|
|
// Host-side launcher for sum_rows_bf16_k (dbias[c] = sum over rows of dout[r,c]).
//   dout  : input buffer, reinterpreted as const __nv_bfloat16*
//   dbias : output buffer, reinterpreted as __nv_bfloat16*
//   rows  : number of rows, forwarded to the kernel
//   cols  : number of columns; cols <= 0 is a no-op
//   s     : reinterpreted as cudaStream_t and used as the launch stream
// The kernel is enqueued with 256-thread blocks and ceil(cols / 256) blocks
// (one thread per column). This function neither synchronizes nor checks for
// launch errors.
void launch_sum_rows_bf16(const void* dout, void* dbias, int rows, int cols, void* s) {
    // A grid of zero blocks is an invalid launch configuration, so a tensor
    // with no columns must skip the launch rather than record a CUDA error.
    // rows == 0 still launches so that dbias is filled with zeros.
    if (cols <= 0) return;
    const int blk = 256;
    const int grid = (cols + blk - 1) / blk;
    sum_rows_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)dout, (__nv_bfloat16*)dbias, rows, cols);
}
|
|
|
|
} // extern "C"
|