autodiff: bf16 mixed-precision path (fp32 master via cast op)

Tensor ops dispatch on dtype: fp32 branch unchanged (bit-identical),
bf16 branch routes matmul/attention through GemmEx and elementwise
through the bf16 kernels. Norm/softmax/RoPE/cross-entropy upcast to
fp32 around the existing fp32 kernels (standard AMP: reductions/loss
fp32, matmuls bf16). Transposes route bf16 through fp32 (pure layout).

The new autodiff `cast` op is the AMP bridge: forward downcasts an fp32
master leaf to bf16 for the matmul; backward upcasts the bf16 grad
back to fp32. So the fp32 leaf accumulates an fp32 grad, and AdamW /
clip / DDP all-reduce stay fp32 and remain completely unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 14:14:48 +08:00
parent d05115ddf3
commit b0086b5214
2 changed files with 490 additions and 159 deletions

View File

@@ -13,7 +13,27 @@
#![cfg(not(no_cuda))]
use crate::tape::Var;
use xtrain_tensor::Tensor;
use xtrain_tensor::{DType, Tensor};
/// Autograd-aware dtype conversion (Phase T12): the bridge between fp32 master
/// weights / fp32 reductions and the bf16 compute stream.
///
/// * Forward: produces `x`'s value converted to `target`.
/// * Backward: converts the upstream gradient back to the dtype `x` had at
///   forward time and pushes it onto `x`.
///
/// Consequently an fp32 leaf routed through `cast(w, BF16)` receives its
/// gradient in fp32, which is what lets the fp32 optimizer / clipping /
/// all-reduce code keep operating on fp32 grads.
///
/// When `x` already has dtype `target`, no node is recorded and a clone of `x`
/// is returned.
pub fn cast(x: &Var, target: DType) -> Var {
    // Remember the input dtype: the backward closure restores grads to it.
    let origin = x.value().dtype();
    if origin != target {
        let converted = x.value().to_dtype(target);
        let inputs = vec![x.clone()];
        Var::from_op(
            converted,
            inputs,
            Box::new(move |upstream, parents| {
                // Gradient flows back in the parent's original dtype.
                let restored = upstream.to_dtype(origin);
                Var::push_grad(&parents[0], restored);
            }),
        )
    } else {
        // Nothing to convert — hand back the same variable, no new tape node.
        x.clone()
    }
}
/// `C = A @ B` (2D). Backward: `dA = dC @ Bᵀ`, `dB = Aᵀ @ dC`.
pub fn matmul(a: &Var, b: &Var) -> Var {

View File

@@ -107,6 +107,45 @@ impl Tensor {
}
}
// --- dtype cast (Phase T12, bf16 mixed precision) ---
/// Cast between F32 and BF16 (the AMP bridge: fp32 master ↔ bf16 compute).
///
/// If `target` equals the current dtype, returns `self.clone()` without
/// launching a kernel. Otherwise allocates a new tensor of the same shape on
/// the same device and fills it with the converted values.
///
/// I32 is not castable here (only used for token-id targets).
///
/// # Panics
///
/// * if the tensor is not on a CUDA device,
/// * if the tensor is not contiguous,
/// * if the element count does not fit in `i32` (the cast kernels take an
///   `i32` length; a plain `as i32` would silently truncate and convert only
///   part of the tensor),
/// * if the `(source, target)` pair is anything other than F32↔BF16.
#[cfg(not(no_cuda))]
pub fn to_dtype(&self, target: DType) -> Self {
    if self.dtype == target {
        return self.clone();
    }
    assert!(
        matches!(self.device(), Device::Cuda(_)),
        "to_dtype requires a CUDA tensor"
    );
    assert!(self.is_contiguous(), "to_dtype requires contiguous tensor");
    // Checked narrowing: large activations (e.g. [bh, seq, seq] attention
    // scores) can exceed i32::MAX elements; fail loudly instead of wrapping.
    let n = i32::try_from(self.numel())
        .expect("to_dtype: element count exceeds i32::MAX (kernel length is i32)");
    let out = Tensor::zeros(&self.shape, target, self.device());
    match (self.dtype, target) {
        // SAFETY: `self` is a contiguous CUDA tensor (asserted above) holding
        // `n` F32 elements; `out` was just allocated with the same shape on
        // the same device with dtype BF16. A null stream pointer is passed,
        // as for the other kernel launches in this file.
        (DType::F32, DType::BF16) => unsafe {
            xtrain_cuda::ffi::launch_cast_f32_to_bf16(
                self.data_ptr() as *const f32,
                out.data_ptr() as *mut std::ffi::c_void,
                n,
                std::ptr::null_mut(),
            );
        },
        // SAFETY: mirror image of the arm above — `self` holds `n` BF16
        // elements, `out` is a freshly allocated F32 tensor of the same shape.
        (DType::BF16, DType::F32) => unsafe {
            xtrain_cuda::ffi::launch_cast_bf16_to_f32(
                self.data_ptr() as *const std::ffi::c_void,
                out.data_ptr() as *mut f32,
                n,
                std::ptr::null_mut(),
            );
        },
        (a, b) => panic!("unsupported cast {a} -> {b}"),
    }
    out
}
// --- Host data access (CPU only) ---
/// Typed read-only view of the data. Requires a contiguous CPU tensor.
@@ -136,7 +175,10 @@ impl Tensor {
/// GPU. Available only when CUDA was compiled in (`not(no_cuda)`).
#[cfg(not(no_cuda))]
pub fn scale(&self, alpha: f32) -> Self {
assert_eq!(self.dtype, DType::F32, "scale only supports F32 in T2");
assert!(
matches!(self.dtype, DType::F32 | DType::BF16),
"scale supports F32/BF16"
);
assert!(self.is_contiguous(), "scale requires contiguous tensor");
assert!(
matches!(self.device(), Device::Cuda(_)),
@@ -144,14 +186,27 @@ impl Tensor {
);
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
unsafe {
xtrain_cuda::ffi::launch_scale_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
alpha,
self.numel() as i32,
std::ptr::null_mut(), // default stream
);
let n = self.numel() as i32;
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_scale_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
alpha,
n,
std::ptr::null_mut(), // default stream
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_scale_bf16(
self.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
alpha,
n,
std::ptr::null_mut(),
);
},
_ => unreachable!(),
}
out
}
@@ -165,8 +220,11 @@ impl Tensor {
/// on the same GPU. Available only when CUDA is compiled in.
#[cfg(not(no_cuda))]
pub fn matmul(&self, other: &Tensor) -> Self {
assert_eq!(self.dtype, DType::F32, "matmul only supports F32");
assert_eq!(other.dtype, DType::F32, "matmul only supports F32");
assert_eq!(self.dtype, other.dtype, "matmul dtype mismatch");
assert!(
matches!(self.dtype, DType::F32 | DType::BF16),
"matmul supports F32/BF16"
);
assert_eq!(self.ndim(), 2, "matmul requires 2D lhs");
assert_eq!(other.ndim(), 2, "matmul requires 2D rhs");
assert_eq!(
@@ -187,19 +245,36 @@ impl Tensor {
let m = self.shape[0];
let k = self.shape[1];
let n = other.shape[1];
let out = Tensor::zeros(&[m, n], DType::F32, self.device());
xtrain_cuda::cublas::sgemm(
false,
false,
m,
n,
k,
1.0,
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
0.0,
out.data_ptr() as *mut f32,
);
let out = Tensor::zeros(&[m, n], self.dtype, self.device());
match self.dtype {
// fp32 path — unchanged (bit-identical to T7/T10/T11).
DType::F32 => xtrain_cuda::cublas::sgemm(
false,
false,
m,
n,
k,
1.0,
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
0.0,
out.data_ptr() as *mut f32,
),
// bf16 path — GemmEx, bf16 in/out, fp32 accumulation.
DType::BF16 => xtrain_cuda::cublas::gemm_ex(
false,
false,
m,
n,
k,
1.0,
self.data_ptr() as *const std::ffi::c_void,
other.data_ptr() as *const std::ffi::c_void,
0.0,
out.data_ptr() as *mut std::ffi::c_void,
),
_ => unreachable!(),
}
out
}
@@ -207,9 +282,15 @@ impl Tensor {
/// self[i,j]`. Requires a contiguous F32 or BF16 CUDA tensor (BF16 is routed through F32).
#[cfg(not(no_cuda))]
pub fn transpose_2d(&self) -> Self {
assert_eq!(self.dtype, DType::F32, "transpose only supports F32");
assert_eq!(self.ndim(), 2, "transpose_2d requires 2D tensor");
assert!(self.is_contiguous(), "transpose requires contiguous tensor");
if self.dtype == DType::BF16 {
return self
.to_dtype(DType::F32)
.transpose_2d()
.to_dtype(DType::BF16);
}
assert_eq!(self.dtype, DType::F32, "transpose supports F32/BF16");
assert!(
matches!(self.device(), Device::Cuda(_)),
"transpose requires a CUDA tensor"
@@ -245,35 +326,69 @@ impl Tensor {
assert_eq!(dc.shape[0], a.shape[0], "dC rows != A rows (M)");
assert_eq!(dc.shape[1], b.shape[1], "dC cols != B cols (N)");
assert_eq!(a.dtype, b.dtype, "matmul_backward dtype mismatch");
assert_eq!(a.dtype, dc.dtype, "matmul_backward dtype mismatch");
let (m, k, n) = (a.shape[0], a.shape[1], b.shape[1]);
let dt = a.dtype;
// dA[M,K] = dC[M,N] · Bᵀ (B stored [K,N], transposed by cuBLAS)
let da = Tensor::zeros(&[m, k], DType::F32, a.device());
xtrain_cuda::cublas::sgemm(
false,
true,
m,
k,
n,
1.0,
dc.data_ptr() as *const f32,
b.data_ptr() as *const f32,
0.0,
da.data_ptr() as *mut f32,
);
let da = Tensor::zeros(&[m, k], dt, a.device());
// dB[K,N] = Aᵀ · dC[M,N] (A stored [M,K], transposed by cuBLAS)
let db = Tensor::zeros(&[k, n], DType::F32, a.device());
xtrain_cuda::cublas::sgemm(
true,
false,
k,
n,
m,
1.0,
a.data_ptr() as *const f32,
dc.data_ptr() as *const f32,
0.0,
db.data_ptr() as *mut f32,
);
let db = Tensor::zeros(&[k, n], dt, a.device());
match dt {
DType::F32 => {
xtrain_cuda::cublas::sgemm(
false,
true,
m,
k,
n,
1.0,
dc.data_ptr() as *const f32,
b.data_ptr() as *const f32,
0.0,
da.data_ptr() as *mut f32,
);
xtrain_cuda::cublas::sgemm(
true,
false,
k,
n,
m,
1.0,
a.data_ptr() as *const f32,
dc.data_ptr() as *const f32,
0.0,
db.data_ptr() as *mut f32,
);
}
DType::BF16 => {
xtrain_cuda::cublas::gemm_ex(
false,
true,
m,
k,
n,
1.0,
dc.data_ptr() as *const std::ffi::c_void,
b.data_ptr() as *const std::ffi::c_void,
0.0,
da.data_ptr() as *mut std::ffi::c_void,
);
xtrain_cuda::cublas::gemm_ex(
true,
false,
k,
n,
m,
1.0,
a.data_ptr() as *const std::ffi::c_void,
dc.data_ptr() as *const std::ffi::c_void,
0.0,
db.data_ptr() as *mut std::ffi::c_void,
);
}
_ => panic!("matmul_backward supports F32/BF16"),
}
(da, db)
}
@@ -287,15 +402,28 @@ impl Tensor {
#[cfg(not(no_cuda))]
pub fn add(&self, other: &Tensor) -> Self {
self.check_binary(other, "add");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_add_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
self.numel() as i32,
std::ptr::null_mut(),
);
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
let n = self.numel() as i32;
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_add_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
n,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_add_bf16(
self.data_ptr() as *const std::ffi::c_void,
other.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
n,
std::ptr::null_mut(),
);
},
_ => unreachable!(),
}
out
}
@@ -304,15 +432,28 @@ impl Tensor {
#[cfg(not(no_cuda))]
pub fn mul(&self, other: &Tensor) -> Self {
self.check_binary(other, "mul");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_mul_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
self.numel() as i32,
std::ptr::null_mut(),
);
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
let n = self.numel() as i32;
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_mul_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
n,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_mul_bf16(
self.data_ptr() as *const std::ffi::c_void,
other.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
n,
std::ptr::null_mut(),
);
},
_ => unreachable!(),
}
out
}
@@ -324,17 +465,31 @@ impl Tensor {
assert_eq!(self.ndim(), 2, "add_bias requires 2D input");
assert_eq!(bias.ndim(), 1, "bias must be 1D");
assert_eq!(self.shape[1], bias.shape[0], "bias len != cols");
assert_eq!(self.dtype, bias.dtype, "add_bias dtype mismatch");
let (rows, cols) = (self.shape[0], self.shape[1]);
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_add_bias_f32(
self.data_ptr() as *const f32,
bias.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_add_bias_f32(
self.data_ptr() as *const f32,
bias.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_add_bias_bf16(
self.data_ptr() as *const std::ffi::c_void,
bias.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
},
_ => panic!("add_bias supports F32/BF16"),
}
out
}
@@ -346,15 +501,27 @@ impl Tensor {
pub fn sum_rows(&self) -> Self {
assert_eq!(self.ndim(), 2, "sum_rows requires 2D input");
let (rows, cols) = (self.shape[0], self.shape[1]);
let out = Tensor::zeros(&[cols], DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_sum_rows_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
let out = Tensor::zeros(&[cols], self.dtype, self.device());
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_sum_rows_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_sum_rows_bf16(
self.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
},
_ => panic!("sum_rows supports F32/BF16"),
}
out
}
@@ -367,6 +534,14 @@ impl Tensor {
assert_eq!(self.ndim(), 2, "rms_norm requires 2D input");
assert_eq!(gamma.ndim(), 1, "gamma must be 1D");
assert_eq!(self.shape[1], gamma.shape[0], "gamma len != cols");
// bf16: compute the reduction in fp32 (standard AMP), downcast y back to
// bf16. inv_rms stays fp32 (the cache the fp32 backward kernel consumes).
if self.dtype == DType::BF16 {
let (y, inv_rms) = self
.to_dtype(DType::F32)
.rms_norm(&gamma.to_dtype(DType::F32), eps);
return (y.to_dtype(DType::BF16), inv_rms);
}
let (rows, cols) = (self.shape[0], self.shape[1]);
let y = Tensor::zeros(&self.shape, DType::F32, self.device());
let inv_rms = Tensor::zeros(&[rows], DType::F32, self.device());
@@ -394,6 +569,17 @@ impl Tensor {
dy: &Tensor,
inv_rms: &Tensor,
) -> (Tensor, Tensor) {
// bf16: upcast (x, gamma, dy) to fp32, run the fp32 backward, downcast the
// grads back to bf16 (inv_rms is already the fp32 cache).
if x.dtype == DType::BF16 {
let (dx, dgamma) = Tensor::rms_norm_backward(
&x.to_dtype(DType::F32),
&gamma.to_dtype(DType::F32),
&dy.to_dtype(DType::F32),
inv_rms,
);
return (dx.to_dtype(DType::BF16), dgamma.to_dtype(DType::BF16));
}
let (rows, cols) = (x.shape[0], x.shape[1]);
let dx = Tensor::zeros(&[rows, cols], DType::F32, x.device());
let dgamma = Tensor::zeros(&[cols], DType::F32, x.device());
@@ -424,15 +610,30 @@ impl Tensor {
/// SiLU forward: `y = x * sigmoid(x)`, elementwise.
#[cfg(not(no_cuda))]
pub fn silu(&self) -> Self {
assert_eq!(self.dtype, DType::F32, "silu only supports F32");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_silu_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
self.numel() as i32,
std::ptr::null_mut(),
);
assert!(
matches!(self.dtype, DType::F32 | DType::BF16),
"silu supports F32/BF16"
);
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
let n = self.numel() as i32;
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_silu_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
n,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_silu_bf16(
self.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
n,
std::ptr::null_mut(),
);
},
_ => unreachable!(),
}
out
}
@@ -441,15 +642,28 @@ impl Tensor {
/// Inputs are the forward `x` and upstream `dy`.
#[cfg(not(no_cuda))]
pub fn silu_backward(x: &Tensor, dy: &Tensor) -> Self {
let dx = Tensor::zeros(&x.shape, DType::F32, x.device());
unsafe {
xtrain_cuda::ffi::launch_silu_dx_f32(
x.data_ptr() as *const f32,
dy.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
x.numel() as i32,
std::ptr::null_mut(),
);
let dx = Tensor::zeros(&x.shape, x.dtype, x.device());
let n = x.numel() as i32;
match x.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_silu_dx_f32(
x.data_ptr() as *const f32,
dy.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
n,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_silu_dx_bf16(
x.data_ptr() as *const std::ffi::c_void,
dy.data_ptr() as *const std::ffi::c_void,
dx.data_ptr() as *mut std::ffi::c_void,
n,
std::ptr::null_mut(),
);
},
_ => panic!("silu_backward supports F32/BF16"),
}
dx
}
@@ -468,6 +682,12 @@ impl Tensor {
period > 0 && tokens % period == 0,
"tokens must be a multiple of period"
);
if self.dtype == DType::BF16 {
return self
.to_dtype(DType::F32)
.rope(theta, period)
.to_dtype(DType::BF16);
}
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_rope_f32(
@@ -488,6 +708,10 @@ impl Tensor {
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
#[cfg(not(no_cuda))]
pub fn rope_backward(dy: &Tensor, theta: f32, period: usize) -> Self {
if dy.dtype == DType::BF16 {
return Tensor::rope_backward(&dy.to_dtype(DType::F32), theta, period)
.to_dtype(DType::BF16);
}
let (tokens, heads, head_dim) = (dy.shape[0], dy.shape[1], dy.shape[2]);
let dx = Tensor::zeros(&dy.shape, DType::F32, dy.device());
unsafe {
@@ -509,6 +733,9 @@ impl Tensor {
#[cfg(not(no_cuda))]
pub fn softmax(&self) -> Self {
assert_eq!(self.ndim(), 2, "softmax requires 2D input");
if self.dtype == DType::BF16 {
return self.to_dtype(DType::F32).softmax().to_dtype(DType::BF16);
}
let (rows, cols) = (self.shape[0], self.shape[1]);
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
@@ -527,6 +754,10 @@ impl Tensor {
/// Inputs are the forward output `y` and upstream `dy`.
#[cfg(not(no_cuda))]
pub fn softmax_backward(y: &Tensor, dy: &Tensor) -> Self {
if y.dtype == DType::BF16 {
return Tensor::softmax_backward(&y.to_dtype(DType::F32), &dy.to_dtype(DType::F32))
.to_dtype(DType::BF16);
}
let (rows, cols) = (y.shape[0], y.shape[1]);
let dx = Tensor::zeros(&y.shape, DType::F32, y.device());
unsafe {
@@ -550,6 +781,11 @@ impl Tensor {
assert_eq!(self.ndim(), 2, "cross_entropy requires 2D logits");
assert_eq!(target.dtype, DType::I32, "target must be I32");
assert_eq!(target.numel(), self.shape[0], "one target per row");
// CE math (log-sum-exp) is fp32 (probs/loss cached fp32). The model casts
// logits→fp32 before CE; this guard keeps the op robust to bf16 logits.
if self.dtype == DType::BF16 {
return self.to_dtype(DType::F32).cross_entropy(target);
}
let (rows, cols) = (self.shape[0], self.shape[1]);
let probs = Tensor::zeros(&self.shape, DType::F32, self.device());
let loss = Tensor::zeros(&[rows], DType::F32, self.device());
@@ -658,9 +894,15 @@ impl Tensor {
/// own backward is the same op (swap a,b).
#[cfg(not(no_cuda))]
pub fn transpose_3d01(&self) -> Self {
assert_eq!(self.dtype, DType::F32, "transpose_3d01 only supports F32");
assert_eq!(self.ndim(), 3, "transpose_3d01 requires a 3D tensor");
assert!(self.is_contiguous(), "transpose_3d01 requires contiguous");
if self.dtype == DType::BF16 {
return self
.to_dtype(DType::F32)
.transpose_3d01()
.to_dtype(DType::BF16);
}
assert_eq!(self.dtype, DType::F32, "transpose_3d01 supports F32/BF16");
let (a, b, c) = (self.shape[0], self.shape[1], self.shape[2]);
let out = Tensor::zeros(&[b, a, c], DType::F32, self.device());
unsafe {
@@ -689,54 +931,59 @@ impl Tensor {
assert_eq!(self.ndim(), 3, "attention Q must be [bh,seq,head_dim]");
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
let dev = self.device();
let dt = self.dtype;
// scores[bh,seq,seq] = Q[bh,seq,hd] · Kᵀ[bh,hd,seq]
let scores = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
xtrain_cuda::cublas::sgemm_strided_batched(
// scores[bh,seq,seq] = Q[bh,seq,hd] · Kᵀ[bh,hd,seq] (GEMM in self dtype)
let scores = Tensor::zeros(&[bh, seq, seq], dt, dev);
strided_batched_gemm(
dt,
false,
true,
seq,
seq,
hd,
1.0,
self.data_ptr() as *const f32,
self.data_ptr(),
seq * hd,
k.data_ptr() as *const f32,
k.data_ptr(),
seq * hd,
0.0,
scores.data_ptr() as *mut f32,
scores.data_ptr(),
seq * seq,
bh,
);
// probs = softmax(causal(scores · scale)), one block per [bh·seq] row.
let probs = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
// probs = softmax(causal(scores · scale)). Softmax math is fp32 (stable);
// for bf16 we upcast scores → f32 → kernel → downcast probs back to bf16
// (so the cached probs activation is half-size). One block per [bh·seq] row.
let scores_f32 = scores.to_dtype(DType::F32);
let probs_f32 = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_softmax_causal_f32(
scores.data_ptr() as *const f32,
probs.data_ptr() as *mut f32,
scores_f32.data_ptr() as *const f32,
probs_f32.data_ptr() as *mut f32,
(bh * seq) as i32,
seq as i32,
scale,
std::ptr::null_mut(),
);
}
let probs = probs_f32.to_dtype(dt);
// out[bh,seq,hd] = probs[bh,seq,seq] · V[bh,seq,hd]
let out = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
xtrain_cuda::cublas::sgemm_strided_batched(
let out = Tensor::zeros(&[bh, seq, hd], dt, dev);
strided_batched_gemm(
dt,
false,
false,
seq,
hd,
seq,
1.0,
probs.data_ptr() as *const f32,
probs.data_ptr(),
seq * seq,
v.data_ptr() as *const f32,
v.data_ptr(),
seq * hd,
0.0,
out.data_ptr() as *mut f32,
out.data_ptr(),
seq * hd,
bh,
);
@@ -764,45 +1011,44 @@ impl Tensor {
) -> (Tensor, Tensor, Tensor) {
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
let dev = q.device();
let dt = q.dtype;
// dP[bh,seq,seq] = dOut[bh,seq,hd] · Vᵀ[bh,hd,seq]
let dp = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
xtrain_cuda::cublas::sgemm_strided_batched(
let dp = Tensor::zeros(&[bh, seq, seq], dt, dev);
strided_batched_gemm(
dt,
false,
true,
seq,
seq,
hd,
1.0,
dout.data_ptr() as *const f32,
dout.data_ptr(),
seq * hd,
v.data_ptr() as *const f32,
v.data_ptr(),
seq * hd,
0.0,
dp.data_ptr() as *mut f32,
dp.data_ptr(),
seq * seq,
bh,
);
// dV[bh,seq,hd] = Pᵀ[bh,seq,seq] · dOut[bh,seq,hd]
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
xtrain_cuda::cublas::sgemm_strided_batched(
let dv = Tensor::zeros(&[bh, seq, hd], dt, dev);
strided_batched_gemm(
dt,
true,
false,
seq,
hd,
seq,
1.0,
probs.data_ptr() as *const f32,
probs.data_ptr(),
seq * seq,
dout.data_ptr() as *const f32,
dout.data_ptr(),
seq * hd,
0.0,
dv.data_ptr() as *mut f32,
dv.data_ptr(),
seq * hd,
bh,
);
// dScores = softmax Jacobian (per row) applied to dP, then ×scale.
// Reuse the row-wise softmax backward over the flattened [bh·seq, seq].
// softmax_backward + scale are dtype-aware (fp32 math inside for bf16).
let dscores = Tensor::softmax_backward(
&probs.reshape(&[bh * seq, seq]),
&dp.reshape(&[bh * seq, seq]),
@@ -810,38 +1056,36 @@ impl Tensor {
.reshape(&[bh, seq, seq]);
let dscores = dscores.scale(scale);
// dQ[bh,seq,hd] = dScores[bh,seq,seq] · K[bh,seq,hd]
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
xtrain_cuda::cublas::sgemm_strided_batched(
let dq = Tensor::zeros(&[bh, seq, hd], dt, dev);
strided_batched_gemm(
dt,
false,
false,
seq,
hd,
seq,
1.0,
dscores.data_ptr() as *const f32,
dscores.data_ptr(),
seq * seq,
k.data_ptr() as *const f32,
k.data_ptr(),
seq * hd,
0.0,
dq.data_ptr() as *mut f32,
dq.data_ptr(),
seq * hd,
bh,
);
// dK[bh,seq,hd] = dScoresᵀ[bh,seq,seq] · Q[bh,seq,hd]
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
xtrain_cuda::cublas::sgemm_strided_batched(
let dk = Tensor::zeros(&[bh, seq, hd], dt, dev);
strided_batched_gemm(
dt,
true,
false,
seq,
hd,
seq,
1.0,
dscores.data_ptr() as *const f32,
dscores.data_ptr(),
seq * seq,
q.data_ptr() as *const f32,
q.data_ptr(),
seq * hd,
0.0,
dk.data_ptr() as *mut f32,
dk.data_ptr(),
seq * hd,
bh,
);
@@ -853,9 +1097,15 @@ impl Tensor {
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).
#[cfg(not(no_cuda))]
pub fn transpose_4d12(&self) -> Self {
assert_eq!(self.dtype, DType::F32, "transpose_4d12 only supports F32");
assert_eq!(self.ndim(), 4, "transpose_4d12 requires a 4D tensor");
assert!(self.is_contiguous(), "transpose_4d12 requires contiguous");
if self.dtype == DType::BF16 {
return self
.to_dtype(DType::F32)
.transpose_4d12()
.to_dtype(DType::BF16);
}
assert_eq!(self.dtype, DType::F32, "transpose_4d12 supports F32/BF16");
let (a, b, c, d) = (self.shape[0], self.shape[1], self.shape[2], self.shape[3]);
let out = Tensor::zeros(&[a, c, b, d], DType::F32, self.device());
unsafe {
@@ -875,8 +1125,11 @@ impl Tensor {
// Shared validation for same-shape binary elementwise ops.
#[cfg(not(no_cuda))]
fn check_binary(&self, other: &Tensor, op: &str) {
assert_eq!(self.dtype, DType::F32, "{op} only supports F32");
assert_eq!(other.dtype, DType::F32, "{op} only supports F32");
assert!(
matches!(self.dtype, DType::F32 | DType::BF16),
"{op} supports F32/BF16"
);
assert_eq!(self.dtype, other.dtype, "{op} dtype mismatch");
assert_eq!(self.shape(), other.shape(), "{op} shape mismatch");
assert_eq!(self.device(), other.device(), "{op} device mismatch");
assert!(
@@ -886,6 +1139,64 @@ impl Tensor {
}
}
/// Strided-batched GEMM dispatcher keyed on `dt`.
///
/// * `DType::F32`  → `xtrain_cuda::cublas::sgemm_strided_batched`
/// * `DType::BF16` → `xtrain_cuda::cublas::gemm_ex_strided_batched`
///   (bf16 in/out, fp32 accumulation)
///
/// `a`, `b` and `c` are the raw `data_ptr()` bytes of contiguous tensors that
/// all have dtype `dt`; they are reinterpreted as the element pointer type the
/// selected backend expects. `c` is the output buffer even though it arrives
/// as a `*const u8` (that is what `data_ptr()` yields at the call sites).
/// The scaling factors are fixed at `alpha = 1`, `beta = 0`.
///
/// The fp32 path is bit-identical to the inlined T10 call it replaces.
///
/// # Panics
///
/// Panics for any dtype other than F32 / BF16.
#[cfg(not(no_cuda))]
#[allow(clippy::too_many_arguments)]
fn strided_batched_gemm(
    dt: DType,
    trans_a: bool,
    trans_b: bool,
    m: usize,
    n: usize,
    k: usize,
    a: *const u8,
    stride_a: usize,
    b: *const u8,
    stride_b: usize,
    c: *const u8,
    stride_c: usize,
    batch: usize,
) {
    match dt {
        DType::F32 => {
            // View the byte pointers as f32 buffers for the SGEMM wrapper.
            let lhs = a as *const f32;
            let rhs = b as *const f32;
            let dst = c as *mut f32;
            xtrain_cuda::cublas::sgemm_strided_batched(
                trans_a, trans_b, m, n, k, 1.0, lhs, stride_a, rhs, stride_b, 0.0, dst, stride_c,
                batch,
            )
        }
        DType::BF16 => {
            // GemmEx takes untyped pointers; element type is implied by the wrapper.
            let lhs = a as *const std::ffi::c_void;
            let rhs = b as *const std::ffi::c_void;
            let dst = c as *mut std::ffi::c_void;
            xtrain_cuda::cublas::gemm_ex_strided_batched(
                trans_a, trans_b, m, n, k, 1.0, lhs, stride_a, rhs, stride_b, 0.0, dst, stride_c,
                batch,
            )
        }
        _ => panic!("strided_batched_gemm supports F32/BF16"),
    }
}
impl std::fmt::Debug for Tensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(