Gahow Wang
d05115ddf3
cuda: bf16 cuBLAS GemmEx (16BF in/out, fp32 accum) + cast kernels
Add the bf16 compute primitives for T12 mixed precision:
- DType::BF16 (half::bf16 as TensorDType), 2 bytes.
- cublasGemmEx / cublasGemmStridedBatchedEx FFI + CUDA_R_16BF /
CUBLAS_COMPUTE_32F constants (values per xserv gemm.rs).
- cublas::gemm_ex / gemm_ex_strided_batched: same row-major⟺col-major
transpose algebra as sgemm, bf16 in/out, fp32 accumulation.
- csrc/ops/cast.cu: f32<->bf16 casts plus bf16 elementwise kernels (add/mul/
scale/silu(+dx)/add_bias/sum_rows); each loads bf16, widens to fp32,
computes, and stores bf16.
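
Reference sketch of the row-major⟺col-major transpose algebra with bf16
in/out and fp32 accumulation. This is a CPU illustration only, not the code in
this commit: the helper names (f32_to_bf16_bits, bf16_bits_to_f32,
gemm_col_major, gemm_row_major) are hypothetical, bf16 is modeled as raw u16
bits instead of half::bf16, and a plain column-major loop stands in for
cublasGemmEx. The idea is that a row-major buffer read as column-major is the
transpose, so C = A * B in row-major is computed as C^T = B^T * A^T by swapping
the operands and the m/n dimensions.

```rust
/// Round an f32 to bf16 bits (the upper 16 bits) with round-to-nearest-even.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Set a mantissa bit so truncation cannot turn NaN into infinity.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, so exact ties round to even.
    let round_bias = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(round_bias) >> 16) as u16
}

/// Widen bf16 bits to f32; this direction is exact.
fn bf16_bits_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}

/// Column-major GEMM without transposes: C (m x n) = A (m x k) * B (k x n).
/// Stand-in for a column-major BLAS call; accumulates in fp32, stores bf16.
fn gemm_col_major(
    m: usize, n: usize, k: usize,
    a: &[u16], lda: usize,
    b: &[u16], ldb: usize,
    c: &mut [u16], ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += bf16_bits_to_f32(a[i + p * lda]) * bf16_bits_to_f32(b[p + j * ldb]);
            }
            c[i + j * ldc] = f32_to_bf16_bits(acc);
        }
    }
}

/// Row-major C (m x n) = A (m x k) * B (k x n) using the column-major routine.
/// The row-major buffers are A^T, B^T, C^T when read as column-major, so we
/// compute C^T (n x m) = B^T (n x k) * A^T (k x m): swap operands and m/n.
fn gemm_row_major(m: usize, n: usize, k: usize, a: &[u16], b: &[u16], c: &mut [u16]) {
    gemm_col_major(n, m, k, b, n, a, k, c, n);
}
```

With A = [[1,2,3],[4,5,6]] and B = [[7,8],[9,10],[11,12]] stored row-major,
gemm_row_major(2, 2, 3, ...) fills C with [[58,64],[139,154]]; all of these
values are exactly representable in bf16, so no rounding error is visible.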
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>