dropout: device RNG kernel + Tensor fwd/bwd (T18)
csrc/ops/dropout.cu: counter-based RNG (splitmix64 over seed^index) → fp32 uniform → Bernoulli(keep=1-p); fwd writes out=x⊙mask + an fp32 mask buffer (per-element 1/(1-p) or 0); bwd applies the same mask (dx=d⊙mask). fp32 + bf16 activation variants (mask fp32 in both; uniform is dtype-independent so masks match across precisions). Stateless → re-run with same seed = same mask (T13 recompute-safe). Registered in build.rs + FFI decls. Tensor::dropout(p,seed)->(out,mask) and Tensor::dropout_backward(d,mask) wrap the launches (contiguous F32/BF16, default stream, per-op sync via the kernels). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,7 @@ fn main() {
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.file("../../csrc/ops/cast.cu")
|
||||
.file("../../csrc/ops/dropout.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -447,3 +447,48 @@ unsafe extern "C" {
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// Dropout (Phase T18, csrc/ops/dropout.cu). A counter-based (stateless) RNG: the
|
||||
// keep/drop decision for element `i` is `hash(seed, i)` — no global state, so a
|
||||
// re-run with the same `seed` reproduces the same mask (compatible with T13
|
||||
// activation recomputation). Forward writes `out = x ⊙ mask` and the fp32 `mask`
|
||||
// buffer (mask[i] = (1/(1-p)) if kept else 0, the inverted-dropout scale);
|
||||
// backward applies the SAME mask: dx = d ⊙ mask. fp32 + bf16 activation variants
|
||||
// (mask is fp32 in both; the uniform is computed in fp32, dtype-independent).
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
pub fn launch_dropout_fwd_f32(
|
||||
x: *const f32,
|
||||
out: *mut f32,
|
||||
mask: *mut f32,
|
||||
p: f32,
|
||||
scale: f32,
|
||||
seed: u64,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_bwd_f32(
|
||||
d: *const f32,
|
||||
mask: *const f32,
|
||||
dx: *mut f32,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_fwd_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
mask: *mut f32,
|
||||
p: f32,
|
||||
scale: f32,
|
||||
seed: u64,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_bwd_bf16(
|
||||
d: *const c_void,
|
||||
mask: *const f32,
|
||||
dx: *mut c_void,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -668,6 +668,92 @@ impl Tensor {
|
||||
dx
|
||||
}
|
||||
|
||||
/// Dropout forward (Phase T18). Returns `(out, mask)` where, for each element
|
||||
/// `i`, a counter-based RNG draws `u = hash(seed, i) ∈ [0,1)` and keeps the
|
||||
/// element iff `u >= p`; kept elements are scaled by `1/(1-p)` (inverted
|
||||
/// dropout, so `E[out] == x`). `mask[i]` stores that per-element factor
|
||||
/// (`1/(1-p)` if kept, else `0`) for the backward to reuse — the same mask, so
|
||||
/// the op is a fixed elementwise scale w.r.t. `x` (and finite-diff-checkable).
|
||||
///
|
||||
/// The mask depends only on `(seed, i)`, NOT on `self`'s values, so a re-run
|
||||
/// with the same `seed` reproduces the same mask (T13 recompute stays exact).
|
||||
/// `mask` is always fp32 (the uniform is computed in fp32, dtype-independent);
|
||||
/// `out` matches `self`'s dtype. Requires `0 <= p < 1`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn dropout(&self, p: f32, seed: u64) -> (Self, Self) {
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"dropout supports F32/BF16"
|
||||
);
|
||||
assert!((0.0..1.0).contains(&p), "dropout p must be in [0,1)");
|
||||
assert!(self.is_contiguous(), "dropout requires contiguous tensor");
|
||||
let scale = 1.0 / (1.0 - p);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
let mask = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_fwd_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
mask.data_ptr() as *mut f32,
|
||||
p,
|
||||
scale,
|
||||
seed,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_fwd_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
mask.data_ptr() as *mut f32,
|
||||
p,
|
||||
scale,
|
||||
seed,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
(out, mask)
|
||||
}
|
||||
|
||||
/// Dropout backward: `dx = d ⊙ mask` (the SAME `mask` the forward cached).
|
||||
/// `d` is the upstream grad (activation dtype); `mask` is the fp32 factor
|
||||
/// tensor from [`Self::dropout`]. Output matches `d`'s dtype.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn dropout_backward(d: &Tensor, mask: &Tensor) -> Self {
|
||||
assert_eq!(d.numel(), mask.numel(), "dropout_backward shape mismatch");
|
||||
assert_eq!(mask.dtype, DType::F32, "dropout mask must be F32");
|
||||
let dx = Tensor::zeros(&d.shape, d.dtype, d.device());
|
||||
let n = d.numel() as i32;
|
||||
match d.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_bwd_f32(
|
||||
d.data_ptr() as *const f32,
|
||||
mask.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_bwd_bf16(
|
||||
d.data_ptr() as *const std::ffi::c_void,
|
||||
mask.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => panic!("dropout_backward supports F32/BF16"),
|
||||
}
|
||||
dx
|
||||
}
|
||||
|
||||
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; each token's
|
||||
/// position is `row % period`. `period` = sequence length, so a flattened
|
||||
/// batch `[B*S,heads,head_dim]` gets per-sequence positions (pass `period=S`);
|
||||
|
||||
Reference in New Issue
Block a user