dropout: device RNG kernel + Tensor fwd/bwd (T18)

csrc/ops/dropout.cu: counter-based RNG (splitmix64 over seed^index) → fp32
uniform → Bernoulli(keep=1-p); fwd writes out=x⊙mask + an fp32 mask buffer
(per-element 1/(1-p) or 0); bwd applies the same mask (dx=d⊙mask). fp32 + bf16
activation variants (mask fp32 in both; uniform is dtype-independent so masks
match across precisions). Stateless → re-run with same seed = same mask (T13
recompute-safe). Registered in build.rs + FFI decls.

Tensor::dropout(p,seed)->(out,mask) and Tensor::dropout_backward(d,mask) wrap the
launches (contiguous F32/BF16, default stream, per-op sync via the kernels).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 00:05:18 +08:00
parent 6b8c1e4e0f
commit 1fdd0c5002
4 changed files with 241 additions and 0 deletions

View File

@@ -37,6 +37,7 @@ fn main() {
.file("../../csrc/ops/optim.cu")
.file("../../csrc/ops/attention.cu")
.file("../../csrc/ops/cast.cu")
.file("../../csrc/ops/dropout.cu")
.compile("xtrain_cuda_kernels");
}

View File

@@ -447,3 +447,48 @@ unsafe extern "C" {
s: CudaStream,
);
}
// Dropout (Phase T18, csrc/ops/dropout.cu). A counter-based (stateless) RNG: the
// keep/drop decision for element `i` is `hash(seed, i)` — no global state, so a
// re-run with the same `seed` reproduces the same mask (compatible with T13
// activation recomputation). Forward writes `out = x ⊙ mask` and the fp32 `mask`
// buffer (mask[i] = (1/(1-p)) if kept else 0, the inverted-dropout scale);
// backward applies the SAME mask: dx = d ⊙ mask. fp32 + bf16 activation variants
// (mask is fp32 in both; the uniform is computed in fp32, dtype-independent).
#[cfg(not(no_cuda))]
unsafe extern "C" {
pub fn launch_dropout_fwd_f32(
x: *const f32,
out: *mut f32,
mask: *mut f32,
p: f32,
scale: f32,
seed: u64,
n: i32,
s: CudaStream,
);
pub fn launch_dropout_bwd_f32(
d: *const f32,
mask: *const f32,
dx: *mut f32,
n: i32,
s: CudaStream,
);
pub fn launch_dropout_fwd_bf16(
x: *const c_void,
out: *mut c_void,
mask: *mut f32,
p: f32,
scale: f32,
seed: u64,
n: i32,
s: CudaStream,
);
pub fn launch_dropout_bwd_bf16(
d: *const c_void,
mask: *const f32,
dx: *mut c_void,
n: i32,
s: CudaStream,
);
}

View File

@@ -668,6 +668,92 @@ impl Tensor {
dx
}
/// Dropout forward (Phase T18). Returns `(out, mask)` where, for each element
/// `i`, a counter-based RNG draws `u = hash(seed, i) ∈ [0,1)` and keeps the
/// element iff `u >= p`; kept elements are scaled by `1/(1-p)` (inverted
/// dropout, so `E[out] == x`). `mask[i]` stores that per-element factor
/// (`1/(1-p)` if kept, else `0`) for the backward to reuse — the same mask, so
/// the op is a fixed elementwise scale w.r.t. `x` (and finite-diff-checkable).
///
/// The mask depends only on `(seed, i)`, NOT on `self`'s values, so a re-run
/// with the same `seed` reproduces the same mask (T13 recompute stays exact).
/// `mask` is always fp32 (the uniform is computed in fp32, dtype-independent);
/// `out` matches `self`'s dtype. Requires `0 <= p < 1`.
#[cfg(not(no_cuda))]
pub fn dropout(&self, p: f32, seed: u64) -> (Self, Self) {
assert!(
matches!(self.dtype, DType::F32 | DType::BF16),
"dropout supports F32/BF16"
);
assert!((0.0..1.0).contains(&p), "dropout p must be in [0,1)");
assert!(self.is_contiguous(), "dropout requires contiguous tensor");
let scale = 1.0 / (1.0 - p);
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
let mask = Tensor::zeros(&self.shape, DType::F32, self.device());
let n = self.numel() as i32;
match self.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_dropout_fwd_f32(
self.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
mask.data_ptr() as *mut f32,
p,
scale,
seed,
n,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_dropout_fwd_bf16(
self.data_ptr() as *const std::ffi::c_void,
out.data_ptr() as *mut std::ffi::c_void,
mask.data_ptr() as *mut f32,
p,
scale,
seed,
n,
std::ptr::null_mut(),
);
},
_ => unreachable!(),
}
(out, mask)
}
/// Dropout backward: `dx = d ⊙ mask` (the SAME `mask` the forward cached).
/// `d` is the upstream grad (activation dtype); `mask` is the fp32 factor
/// tensor from [`Self::dropout`]. Output matches `d`'s dtype.
#[cfg(not(no_cuda))]
pub fn dropout_backward(d: &Tensor, mask: &Tensor) -> Self {
assert_eq!(d.numel(), mask.numel(), "dropout_backward shape mismatch");
assert_eq!(mask.dtype, DType::F32, "dropout mask must be F32");
let dx = Tensor::zeros(&d.shape, d.dtype, d.device());
let n = d.numel() as i32;
match d.dtype {
DType::F32 => unsafe {
xtrain_cuda::ffi::launch_dropout_bwd_f32(
d.data_ptr() as *const f32,
mask.data_ptr() as *const f32,
dx.data_ptr() as *mut f32,
n,
std::ptr::null_mut(),
);
},
DType::BF16 => unsafe {
xtrain_cuda::ffi::launch_dropout_bwd_bf16(
d.data_ptr() as *const std::ffi::c_void,
mask.data_ptr() as *const f32,
dx.data_ptr() as *mut std::ffi::c_void,
n,
std::ptr::null_mut(),
);
},
_ => panic!("dropout_backward supports F32/BF16"),
}
dx
}
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; each token's
/// position is `row % period`. `period` = sequence length, so a flattened
/// batch `[B*S,heads,head_dim]` gets per-sequence positions (pass `period=S`);

109
csrc/ops/dropout.cu Normal file
View File

@@ -0,0 +1,109 @@
// Dropout kernels (Phase T18).
//
// A counter-based (stateless) RNG: the keep/drop decision for element `i` is a
// pure function of (seed, i) — no global RNG state is advanced. This is what
// makes dropout compatible with activation recomputation (T13): when a
// checkpointed block re-runs its forward in backward, the SAME seed regenerates
// the SAME mask, so the recomputed activations / grads stay bit-identical to the
// forward (no mask drift).
//
// Inverted dropout: at training time kept elements are scaled by 1/(1-p) so the
// expectation E[out] == x (no inference-time rescale needed; eval is identity,
// handled in Rust by simply not calling dropout).
//
// key = seed ^ (i * GOLDEN)
// h = splitmix64(key) // a few rounds of xorshift/multiply
// u = (h >> 40) / 2^24 in [0,1) // 24-bit uniform
// keep = u >= p // Bernoulli(keep = 1-p)
// out = keep ? x * scale : 0 // scale = 1/(1-p)
// mask = keep ? scale : 0 // cached for backward (dx = d * mask)
//
// fp32 + bf16 variants: bf16 loads/stores half-size activations but the uniform
// `u` is always computed in fp32, so the mask distribution is identical across
// dtypes (drop decisions don't depend on bf16 rounding). The mask buffer is fp32
// in both cases (it stores `scale` or 0 — exactly representable, tiny relative to
// the activation, reused only elementwise in backward).
#include <cuda_bf16.h>
#include <stdint.h>
extern "C" {
// splitmix64: cheap, well-mixed counter hash. Maps a 64-bit counter to a 64-bit
// pseudo-random output; we only need the high bits for a uniform.
__device__ __forceinline__ uint64_t splitmix64(uint64_t x) {
x += 0x9E3779B97F4A7C15ULL;
x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
return x ^ (x >> 31);
}
// Uniform [0,1) for element i under `seed`, computed in fp32 (dtype-independent).
__device__ __forceinline__ float dropout_uniform(uint64_t seed, int i) {
uint64_t key = seed ^ ((uint64_t)i * 0x9E3779B97F4A7C15ULL);
uint64_t h = splitmix64(key);
// Top 24 bits → [0,1) with 2^-24 resolution.
return (float)(h >> 40) * (1.0f / 16777216.0f); // 1/2^24
}
// fp32 forward: out[i] = keep ? x[i]*scale : 0 ; mask[i] = keep ? scale : 0.
__global__ void dropout_fwd_f32_k(const float* x, float* out, float* mask,
float p, float scale, uint64_t seed, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float keep = (dropout_uniform(seed, i) >= p) ? scale : 0.0f;
mask[i] = keep;
out[i] = x[i] * keep;
}
}
void launch_dropout_fwd_f32(const float* x, float* out, float* mask, float p,
float scale, uint64_t seed, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
dropout_fwd_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, out, mask, p, scale,
seed, n);
}
// Backward applies the SAME cached mask elementwise: dx[i] = d[i] * mask[i].
__global__ void dropout_bwd_f32_k(const float* d, const float* mask, float* dx,
int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) dx[i] = d[i] * mask[i];
}
void launch_dropout_bwd_f32(const float* d, const float* mask, float* dx, int n,
void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
dropout_bwd_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(d, mask, dx, n);
}
// bf16 forward: activation is bf16; mask is fp32 (stores `scale` or 0). Uniform
// is fp32, so the mask matches the fp32 path bit-for-bit (same drop decisions).
__global__ void dropout_fwd_bf16_k(const __nv_bfloat16* x, __nv_bfloat16* out,
float* mask, float p, float scale,
uint64_t seed, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
float keep = (dropout_uniform(seed, i) >= p) ? scale : 0.0f;
mask[i] = keep;
out[i] = __float2bfloat16(__bfloat162float(x[i]) * keep);
}
}
void launch_dropout_fwd_bf16(const void* x, void* out, float* mask, float p,
float scale, uint64_t seed, int n, void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
dropout_fwd_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, mask, p, scale, seed, n);
}
__global__ void dropout_bwd_bf16_k(const __nv_bfloat16* d, const float* mask,
__nv_bfloat16* dx, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) dx[i] = __float2bfloat16(__bfloat162float(d[i]) * mask[i]);
}
void launch_dropout_bwd_bf16(const void* d, const float* mask, void* dx, int n,
void* s) {
int blk = 256, grid = (n + blk - 1) / blk;
dropout_bwd_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
(const __nv_bfloat16*)d, mask, (__nv_bfloat16*)dx, n);
}
} // extern "C"