autodiff: bf16 mixed-precision path (fp32 master via cast op)
Tensor ops dispatch on dtype: fp32 branch unchanged (bit-identical), bf16 branch routes matmul/attention through GemmEx and elementwise through the bf16 kernels. Norm/softmax/RoPE/cross-entropy upcast to fp32 around the existing fp32 kernels (standard AMP: reductions/loss fp32, matmuls bf16). Transposes route bf16 through fp32 (pure layout). New autodiff `cast` op is the AMP bridge: forward downcasts a fp32 master leaf to bf16 for the matmul; backward upcasts the bf16 grad back to fp32. So the fp32 leaf accumulates an fp32 grad and AdamW / clip / DDP all-reduce stay fp32 and completely unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -13,7 +13,27 @@
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use crate::tape::Var;
|
||||
use xtrain_tensor::Tensor;
|
||||
use xtrain_tensor::{DType, Tensor};
|
||||
|
||||
/// dtype cast as an autograd node (Phase T12 — the AMP bridge between fp32 master
|
||||
/// weights / fp32 reductions and the bf16 compute stream). Forward casts `x` to
|
||||
/// `target`; **backward casts the upstream grad back to `x`'s dtype**. So a fp32
|
||||
/// master-weight leaf fed through `cast(w, BF16)` into a bf16 matmul accumulates
|
||||
/// an **fp32** grad — AdamW / clip / DDP all-reduce stay fp32, untouched.
|
||||
pub fn cast(x: &Var, target: DType) -> Var {
|
||||
let src = x.value().dtype();
|
||||
if src == target {
|
||||
return x.clone();
|
||||
}
|
||||
let out = x.value().to_dtype(target);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
Var::push_grad(&parents[0], d.to_dtype(src));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// `C = A @ B` (2D). Backward: `dA = dC @ Bᵀ`, `dB = Aᵀ @ dC`.
|
||||
pub fn matmul(a: &Var, b: &Var) -> Var {
|
||||
|
||||
@@ -107,6 +107,45 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// --- dtype cast (Phase T12, bf16 mixed precision) ---
|
||||
|
||||
/// Cast between F32 and BF16 (the AMP bridge: fp32 master ↔ bf16 compute).
|
||||
/// Same dtype returns a cheap clone. Requires a contiguous CUDA tensor.
|
||||
/// I32 is not castable here (only used for token-id targets).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn to_dtype(&self, target: DType) -> Self {
|
||||
if self.dtype == target {
|
||||
return self.clone();
|
||||
}
|
||||
assert!(
|
||||
matches!(self.device(), Device::Cuda(_)),
|
||||
"to_dtype requires a CUDA tensor"
|
||||
);
|
||||
assert!(self.is_contiguous(), "to_dtype requires contiguous tensor");
|
||||
let n = self.numel() as i32;
|
||||
let out = Tensor::zeros(&self.shape, target, self.device());
|
||||
match (self.dtype, target) {
|
||||
(DType::F32, DType::BF16) => unsafe {
|
||||
xtrain_cuda::ffi::launch_cast_f32_to_bf16(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
(DType::BF16, DType::F32) => unsafe {
|
||||
xtrain_cuda::ffi::launch_cast_bf16_to_f32(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
(a, b) => panic!("unsupported cast {a} -> {b}"),
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
// --- Host data access (CPU only) ---
|
||||
|
||||
/// Typed read-only view of the data. Requires a contiguous CPU tensor.
|
||||
@@ -136,7 +175,10 @@ impl Tensor {
|
||||
/// GPU. Available only when CUDA was compiled in (`not(no_cuda)`).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn scale(&self, alpha: f32) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "scale only supports F32 in T2");
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"scale supports F32/BF16"
|
||||
);
|
||||
assert!(self.is_contiguous(), "scale requires contiguous tensor");
|
||||
assert!(
|
||||
matches!(self.device(), Device::Cuda(_)),
|
||||
@@ -144,14 +186,27 @@ impl Tensor {
|
||||
);
|
||||
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
alpha,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(), // default stream
|
||||
);
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
alpha,
|
||||
n,
|
||||
std::ptr::null_mut(), // default stream
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
alpha,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -165,8 +220,11 @@ impl Tensor {
|
||||
/// on the same GPU. Available only when CUDA is compiled in.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn matmul(&self, other: &Tensor) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "matmul only supports F32");
|
||||
assert_eq!(other.dtype, DType::F32, "matmul only supports F32");
|
||||
assert_eq!(self.dtype, other.dtype, "matmul dtype mismatch");
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"matmul supports F32/BF16"
|
||||
);
|
||||
assert_eq!(self.ndim(), 2, "matmul requires 2D lhs");
|
||||
assert_eq!(other.ndim(), 2, "matmul requires 2D rhs");
|
||||
assert_eq!(
|
||||
@@ -187,19 +245,36 @@ impl Tensor {
|
||||
let m = self.shape[0];
|
||||
let k = self.shape[1];
|
||||
let n = other.shape[1];
|
||||
let out = Tensor::zeros(&[m, n], DType::F32, self.device());
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
0.0,
|
||||
out.data_ptr() as *mut f32,
|
||||
);
|
||||
let out = Tensor::zeros(&[m, n], self.dtype, self.device());
|
||||
match self.dtype {
|
||||
// fp32 path — unchanged (bit-identical to T7/T10/T11).
|
||||
DType::F32 => xtrain_cuda::cublas::sgemm(
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
0.0,
|
||||
out.data_ptr() as *mut f32,
|
||||
),
|
||||
// bf16 path — GemmEx, bf16 in/out, fp32 accumulation.
|
||||
DType::BF16 => xtrain_cuda::cublas::gemm_ex(
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
other.data_ptr() as *const std::ffi::c_void,
|
||||
0.0,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
@@ -207,9 +282,15 @@ impl Tensor {
|
||||
/// self[i,j]`. Requires a contiguous F32 CUDA tensor.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn transpose_2d(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "transpose only supports F32");
|
||||
assert_eq!(self.ndim(), 2, "transpose_2d requires 2D tensor");
|
||||
assert!(self.is_contiguous(), "transpose requires contiguous tensor");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.transpose_2d()
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
assert_eq!(self.dtype, DType::F32, "transpose supports F32/BF16");
|
||||
assert!(
|
||||
matches!(self.device(), Device::Cuda(_)),
|
||||
"transpose requires a CUDA tensor"
|
||||
@@ -245,35 +326,69 @@ impl Tensor {
|
||||
assert_eq!(dc.shape[0], a.shape[0], "dC rows != A rows (M)");
|
||||
assert_eq!(dc.shape[1], b.shape[1], "dC cols != B cols (N)");
|
||||
|
||||
assert_eq!(a.dtype, b.dtype, "matmul_backward dtype mismatch");
|
||||
assert_eq!(a.dtype, dc.dtype, "matmul_backward dtype mismatch");
|
||||
let (m, k, n) = (a.shape[0], a.shape[1], b.shape[1]);
|
||||
let dt = a.dtype;
|
||||
// dA[M,K] = dC[M,N] · Bᵀ (B stored [K,N], transposed by cuBLAS)
|
||||
let da = Tensor::zeros(&[m, k], DType::F32, a.device());
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
false,
|
||||
true,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
1.0,
|
||||
dc.data_ptr() as *const f32,
|
||||
b.data_ptr() as *const f32,
|
||||
0.0,
|
||||
da.data_ptr() as *mut f32,
|
||||
);
|
||||
let da = Tensor::zeros(&[m, k], dt, a.device());
|
||||
// dB[K,N] = Aᵀ · dC[M,N] (A stored [M,K], transposed by cuBLAS)
|
||||
let db = Tensor::zeros(&[k, n], DType::F32, a.device());
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
true,
|
||||
false,
|
||||
k,
|
||||
n,
|
||||
m,
|
||||
1.0,
|
||||
a.data_ptr() as *const f32,
|
||||
dc.data_ptr() as *const f32,
|
||||
0.0,
|
||||
db.data_ptr() as *mut f32,
|
||||
);
|
||||
let db = Tensor::zeros(&[k, n], dt, a.device());
|
||||
match dt {
|
||||
DType::F32 => {
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
false,
|
||||
true,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
1.0,
|
||||
dc.data_ptr() as *const f32,
|
||||
b.data_ptr() as *const f32,
|
||||
0.0,
|
||||
da.data_ptr() as *mut f32,
|
||||
);
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
true,
|
||||
false,
|
||||
k,
|
||||
n,
|
||||
m,
|
||||
1.0,
|
||||
a.data_ptr() as *const f32,
|
||||
dc.data_ptr() as *const f32,
|
||||
0.0,
|
||||
db.data_ptr() as *mut f32,
|
||||
);
|
||||
}
|
||||
DType::BF16 => {
|
||||
xtrain_cuda::cublas::gemm_ex(
|
||||
false,
|
||||
true,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
1.0,
|
||||
dc.data_ptr() as *const std::ffi::c_void,
|
||||
b.data_ptr() as *const std::ffi::c_void,
|
||||
0.0,
|
||||
da.data_ptr() as *mut std::ffi::c_void,
|
||||
);
|
||||
xtrain_cuda::cublas::gemm_ex(
|
||||
true,
|
||||
false,
|
||||
k,
|
||||
n,
|
||||
m,
|
||||
1.0,
|
||||
a.data_ptr() as *const std::ffi::c_void,
|
||||
dc.data_ptr() as *const std::ffi::c_void,
|
||||
0.0,
|
||||
db.data_ptr() as *mut std::ffi::c_void,
|
||||
);
|
||||
}
|
||||
_ => panic!("matmul_backward supports F32/BF16"),
|
||||
}
|
||||
(da, db)
|
||||
}
|
||||
|
||||
@@ -287,15 +402,28 @@ impl Tensor {
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn add(&self, other: &Tensor) -> Self {
|
||||
self.check_binary(other, "add");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_add_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_add_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_add_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
other.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -304,15 +432,28 @@ impl Tensor {
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn mul(&self, other: &Tensor) -> Self {
|
||||
self.check_binary(other, "mul");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_mul_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_mul_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_mul_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
other.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -324,17 +465,31 @@ impl Tensor {
|
||||
assert_eq!(self.ndim(), 2, "add_bias requires 2D input");
|
||||
assert_eq!(bias.ndim(), 1, "bias must be 1D");
|
||||
assert_eq!(self.shape[1], bias.shape[0], "bias len != cols");
|
||||
assert_eq!(self.dtype, bias.dtype, "add_bias dtype mismatch");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_add_bias_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
bias.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_add_bias_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
bias.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_add_bias_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
bias.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => panic!("add_bias supports F32/BF16"),
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -346,15 +501,27 @@ impl Tensor {
|
||||
pub fn sum_rows(&self) -> Self {
|
||||
assert_eq!(self.ndim(), 2, "sum_rows requires 2D input");
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&[cols], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_sum_rows_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
let out = Tensor::zeros(&[cols], self.dtype, self.device());
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_sum_rows_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_sum_rows_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
rows as i32,
|
||||
cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => panic!("sum_rows supports F32/BF16"),
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -367,6 +534,14 @@ impl Tensor {
|
||||
assert_eq!(self.ndim(), 2, "rms_norm requires 2D input");
|
||||
assert_eq!(gamma.ndim(), 1, "gamma must be 1D");
|
||||
assert_eq!(self.shape[1], gamma.shape[0], "gamma len != cols");
|
||||
// bf16: compute the reduction in fp32 (standard AMP), downcast y back to
|
||||
// bf16. inv_rms stays fp32 (the cache the fp32 backward kernel consumes).
|
||||
if self.dtype == DType::BF16 {
|
||||
let (y, inv_rms) = self
|
||||
.to_dtype(DType::F32)
|
||||
.rms_norm(&gamma.to_dtype(DType::F32), eps);
|
||||
return (y.to_dtype(DType::BF16), inv_rms);
|
||||
}
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let y = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let inv_rms = Tensor::zeros(&[rows], DType::F32, self.device());
|
||||
@@ -394,6 +569,17 @@ impl Tensor {
|
||||
dy: &Tensor,
|
||||
inv_rms: &Tensor,
|
||||
) -> (Tensor, Tensor) {
|
||||
// bf16: upcast (x, gamma, dy) to fp32, run the fp32 backward, downcast the
|
||||
// grads back to bf16 (inv_rms is already the fp32 cache).
|
||||
if x.dtype == DType::BF16 {
|
||||
let (dx, dgamma) = Tensor::rms_norm_backward(
|
||||
&x.to_dtype(DType::F32),
|
||||
&gamma.to_dtype(DType::F32),
|
||||
&dy.to_dtype(DType::F32),
|
||||
inv_rms,
|
||||
);
|
||||
return (dx.to_dtype(DType::BF16), dgamma.to_dtype(DType::BF16));
|
||||
}
|
||||
let (rows, cols) = (x.shape[0], x.shape[1]);
|
||||
let dx = Tensor::zeros(&[rows, cols], DType::F32, x.device());
|
||||
let dgamma = Tensor::zeros(&[cols], DType::F32, x.device());
|
||||
@@ -424,15 +610,30 @@ impl Tensor {
|
||||
/// SiLU forward: `y = x * sigmoid(x)`, elementwise.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn silu(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "silu only supports F32");
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
self.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"silu supports F32/BF16"
|
||||
);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -441,15 +642,28 @@ impl Tensor {
|
||||
/// Inputs are the forward `x` and upstream `dy`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn silu_backward(x: &Tensor, dy: &Tensor) -> Self {
|
||||
let dx = Tensor::zeros(&x.shape, DType::F32, x.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_dx_f32(
|
||||
x.data_ptr() as *const f32,
|
||||
dy.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
x.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
let dx = Tensor::zeros(&x.shape, x.dtype, x.device());
|
||||
let n = x.numel() as i32;
|
||||
match x.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_dx_f32(
|
||||
x.data_ptr() as *const f32,
|
||||
dy.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_silu_dx_bf16(
|
||||
x.data_ptr() as *const std::ffi::c_void,
|
||||
dy.data_ptr() as *const std::ffi::c_void,
|
||||
dx.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => panic!("silu_backward supports F32/BF16"),
|
||||
}
|
||||
dx
|
||||
}
|
||||
@@ -468,6 +682,12 @@ impl Tensor {
|
||||
period > 0 && tokens % period == 0,
|
||||
"tokens must be a multiple of period"
|
||||
);
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.rope(theta, period)
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_rope_f32(
|
||||
@@ -488,6 +708,10 @@ impl Tensor {
|
||||
/// orthogonal map, so it needs no cached forward values, only `theta`/`period`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn rope_backward(dy: &Tensor, theta: f32, period: usize) -> Self {
|
||||
if dy.dtype == DType::BF16 {
|
||||
return Tensor::rope_backward(&dy.to_dtype(DType::F32), theta, period)
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let (tokens, heads, head_dim) = (dy.shape[0], dy.shape[1], dy.shape[2]);
|
||||
let dx = Tensor::zeros(&dy.shape, DType::F32, dy.device());
|
||||
unsafe {
|
||||
@@ -509,6 +733,9 @@ impl Tensor {
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn softmax(&self) -> Self {
|
||||
assert_eq!(self.ndim(), 2, "softmax requires 2D input");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self.to_dtype(DType::F32).softmax().to_dtype(DType::BF16);
|
||||
}
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
unsafe {
|
||||
@@ -527,6 +754,10 @@ impl Tensor {
|
||||
/// Inputs are the forward output `y` and upstream `dy`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn softmax_backward(y: &Tensor, dy: &Tensor) -> Self {
|
||||
if y.dtype == DType::BF16 {
|
||||
return Tensor::softmax_backward(&y.to_dtype(DType::F32), &dy.to_dtype(DType::F32))
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
let (rows, cols) = (y.shape[0], y.shape[1]);
|
||||
let dx = Tensor::zeros(&y.shape, DType::F32, y.device());
|
||||
unsafe {
|
||||
@@ -550,6 +781,11 @@ impl Tensor {
|
||||
assert_eq!(self.ndim(), 2, "cross_entropy requires 2D logits");
|
||||
assert_eq!(target.dtype, DType::I32, "target must be I32");
|
||||
assert_eq!(target.numel(), self.shape[0], "one target per row");
|
||||
// CE math (log-sum-exp) is fp32 (probs/loss cached fp32). The model casts
|
||||
// logits→fp32 before CE; this guard keeps the op robust to bf16 logits.
|
||||
if self.dtype == DType::BF16 {
|
||||
return self.to_dtype(DType::F32).cross_entropy(target);
|
||||
}
|
||||
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||
let probs = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let loss = Tensor::zeros(&[rows], DType::F32, self.device());
|
||||
@@ -658,9 +894,15 @@ impl Tensor {
|
||||
/// own backward is the same op (swap a,b).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn transpose_3d01(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "transpose_3d01 only supports F32");
|
||||
assert_eq!(self.ndim(), 3, "transpose_3d01 requires a 3D tensor");
|
||||
assert!(self.is_contiguous(), "transpose_3d01 requires contiguous");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.transpose_3d01()
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
assert_eq!(self.dtype, DType::F32, "transpose_3d01 supports F32/BF16");
|
||||
let (a, b, c) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let out = Tensor::zeros(&[b, a, c], DType::F32, self.device());
|
||||
unsafe {
|
||||
@@ -689,54 +931,59 @@ impl Tensor {
|
||||
assert_eq!(self.ndim(), 3, "attention Q must be [bh,seq,head_dim]");
|
||||
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
|
||||
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
|
||||
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
|
||||
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
|
||||
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let dev = self.device();
|
||||
let dt = self.dtype;
|
||||
|
||||
// scores[bh,seq,seq] = Q[bh,seq,hd] · Kᵀ[bh,hd,seq]
|
||||
let scores = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
// scores[bh,seq,seq] = Q[bh,seq,hd] · Kᵀ[bh,hd,seq] (GEMM in self dtype)
|
||||
let scores = Tensor::zeros(&[bh, seq, seq], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
false,
|
||||
true,
|
||||
seq,
|
||||
seq,
|
||||
hd,
|
||||
1.0,
|
||||
self.data_ptr() as *const f32,
|
||||
self.data_ptr(),
|
||||
seq * hd,
|
||||
k.data_ptr() as *const f32,
|
||||
k.data_ptr(),
|
||||
seq * hd,
|
||||
0.0,
|
||||
scores.data_ptr() as *mut f32,
|
||||
scores.data_ptr(),
|
||||
seq * seq,
|
||||
bh,
|
||||
);
|
||||
// probs = softmax(causal(scores · scale)), one block per [bh·seq] row.
|
||||
let probs = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
// probs = softmax(causal(scores · scale)). Softmax math is fp32 (stable);
|
||||
// for bf16 we upcast scores → f32 → kernel → downcast probs back to bf16
|
||||
// (so the cached probs activation is half-size). One block per [bh·seq] row.
|
||||
let scores_f32 = scores.to_dtype(DType::F32);
|
||||
let probs_f32 = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_softmax_causal_f32(
|
||||
scores.data_ptr() as *const f32,
|
||||
probs.data_ptr() as *mut f32,
|
||||
scores_f32.data_ptr() as *const f32,
|
||||
probs_f32.data_ptr() as *mut f32,
|
||||
(bh * seq) as i32,
|
||||
seq as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
let probs = probs_f32.to_dtype(dt);
|
||||
// out[bh,seq,hd] = probs[bh,seq,seq] · V[bh,seq,hd]
|
||||
let out = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
let out = Tensor::zeros(&[bh, seq, hd], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
false,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
probs.data_ptr() as *const f32,
|
||||
probs.data_ptr(),
|
||||
seq * seq,
|
||||
v.data_ptr() as *const f32,
|
||||
v.data_ptr(),
|
||||
seq * hd,
|
||||
0.0,
|
||||
out.data_ptr() as *mut f32,
|
||||
out.data_ptr(),
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
@@ -764,45 +1011,44 @@ impl Tensor {
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
|
||||
let dev = q.device();
|
||||
let dt = q.dtype;
|
||||
|
||||
// dP[bh,seq,seq] = dOut[bh,seq,hd] · Vᵀ[bh,hd,seq]
|
||||
let dp = Tensor::zeros(&[bh, seq, seq], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
let dp = Tensor::zeros(&[bh, seq, seq], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
false,
|
||||
true,
|
||||
seq,
|
||||
seq,
|
||||
hd,
|
||||
1.0,
|
||||
dout.data_ptr() as *const f32,
|
||||
dout.data_ptr(),
|
||||
seq * hd,
|
||||
v.data_ptr() as *const f32,
|
||||
v.data_ptr(),
|
||||
seq * hd,
|
||||
0.0,
|
||||
dp.data_ptr() as *mut f32,
|
||||
dp.data_ptr(),
|
||||
seq * seq,
|
||||
bh,
|
||||
);
|
||||
// dV[bh,seq,hd] = Pᵀ[bh,seq,seq] · dOut[bh,seq,hd]
|
||||
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
let dv = Tensor::zeros(&[bh, seq, hd], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
true,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
probs.data_ptr() as *const f32,
|
||||
probs.data_ptr(),
|
||||
seq * seq,
|
||||
dout.data_ptr() as *const f32,
|
||||
dout.data_ptr(),
|
||||
seq * hd,
|
||||
0.0,
|
||||
dv.data_ptr() as *mut f32,
|
||||
dv.data_ptr(),
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
// dScores = softmax Jacobian (per row) applied to dP, then ×scale.
|
||||
// Reuse the row-wise softmax backward over the flattened [bh·seq, seq].
|
||||
// softmax_backward + scale are dtype-aware (fp32 math inside for bf16).
|
||||
let dscores = Tensor::softmax_backward(
|
||||
&probs.reshape(&[bh * seq, seq]),
|
||||
&dp.reshape(&[bh * seq, seq]),
|
||||
@@ -810,38 +1056,36 @@ impl Tensor {
|
||||
.reshape(&[bh, seq, seq]);
|
||||
let dscores = dscores.scale(scale);
|
||||
// dQ[bh,seq,hd] = dScores[bh,seq,seq] · K[bh,seq,hd]
|
||||
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
let dq = Tensor::zeros(&[bh, seq, hd], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
false,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
dscores.data_ptr() as *const f32,
|
||||
dscores.data_ptr(),
|
||||
seq * seq,
|
||||
k.data_ptr() as *const f32,
|
||||
k.data_ptr(),
|
||||
seq * hd,
|
||||
0.0,
|
||||
dq.data_ptr() as *mut f32,
|
||||
dq.data_ptr(),
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
// dK[bh,seq,hd] = dScoresᵀ[bh,seq,seq] · Q[bh,seq,hd]
|
||||
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
let dk = Tensor::zeros(&[bh, seq, hd], dt, dev);
|
||||
strided_batched_gemm(
|
||||
dt,
|
||||
true,
|
||||
false,
|
||||
seq,
|
||||
hd,
|
||||
seq,
|
||||
1.0,
|
||||
dscores.data_ptr() as *const f32,
|
||||
dscores.data_ptr(),
|
||||
seq * seq,
|
||||
q.data_ptr() as *const f32,
|
||||
q.data_ptr(),
|
||||
seq * hd,
|
||||
0.0,
|
||||
dk.data_ptr() as *mut f32,
|
||||
dk.data_ptr(),
|
||||
seq * hd,
|
||||
bh,
|
||||
);
|
||||
@@ -853,9 +1097,15 @@ impl Tensor {
|
||||
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn transpose_4d12(&self) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "transpose_4d12 only supports F32");
|
||||
assert_eq!(self.ndim(), 4, "transpose_4d12 requires a 4D tensor");
|
||||
assert!(self.is_contiguous(), "transpose_4d12 requires contiguous");
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.transpose_4d12()
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
assert_eq!(self.dtype, DType::F32, "transpose_4d12 supports F32/BF16");
|
||||
let (a, b, c, d) = (self.shape[0], self.shape[1], self.shape[2], self.shape[3]);
|
||||
let out = Tensor::zeros(&[a, c, b, d], DType::F32, self.device());
|
||||
unsafe {
|
||||
@@ -875,8 +1125,11 @@ impl Tensor {
|
||||
// Shared validation for same-shape binary elementwise ops.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn check_binary(&self, other: &Tensor, op: &str) {
|
||||
assert_eq!(self.dtype, DType::F32, "{op} only supports F32");
|
||||
assert_eq!(other.dtype, DType::F32, "{op} only supports F32");
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"{op} supports F32/BF16"
|
||||
);
|
||||
assert_eq!(self.dtype, other.dtype, "{op} dtype mismatch");
|
||||
assert_eq!(self.shape(), other.shape(), "{op} shape mismatch");
|
||||
assert_eq!(self.device(), other.device(), "{op} device mismatch");
|
||||
assert!(
|
||||
@@ -886,6 +1139,64 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatch a strided-batched GEMM on `dt`: fp32 → `sgemm_strided_batched`,
|
||||
/// bf16 → `gemm_ex_strided_batched` (bf16 in/out, fp32 accum). Pointers are the
|
||||
/// raw `data_ptr()` bytes of contiguous same-dtype tensors. `alpha=1, beta=0`.
|
||||
/// The fp32 path is bit-identical to the inlined T10 call it replaces.
|
||||
#[cfg(not(no_cuda))]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn strided_batched_gemm(
|
||||
dt: DType,
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
a: *const u8,
|
||||
stride_a: usize,
|
||||
b: *const u8,
|
||||
stride_b: usize,
|
||||
c: *const u8,
|
||||
stride_c: usize,
|
||||
batch: usize,
|
||||
) {
|
||||
match dt {
|
||||
DType::F32 => xtrain_cuda::cublas::sgemm_strided_batched(
|
||||
trans_a,
|
||||
trans_b,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
a as *const f32,
|
||||
stride_a,
|
||||
b as *const f32,
|
||||
stride_b,
|
||||
0.0,
|
||||
c as *mut f32,
|
||||
stride_c,
|
||||
batch,
|
||||
),
|
||||
DType::BF16 => xtrain_cuda::cublas::gemm_ex_strided_batched(
|
||||
trans_a,
|
||||
trans_b,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
a as *const std::ffi::c_void,
|
||||
stride_a,
|
||||
b as *const std::ffi::c_void,
|
||||
stride_b,
|
||||
0.0,
|
||||
c as *mut std::ffi::c_void,
|
||||
stride_c,
|
||||
batch,
|
||||
),
|
||||
_ => panic!("strided_batched_gemm supports F32/BF16"),
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Tensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
|
||||
Reference in New Issue
Block a user