dropout: device RNG kernel + Tensor fwd/bwd (T18)
csrc/ops/dropout.cu: counter-based RNG (splitmix64 over seed^index) → fp32 uniform → Bernoulli(keep=1-p); fwd writes out=x⊙mask + an fp32 mask buffer (per-element 1/(1-p) or 0); bwd applies the same mask (dx=d⊙mask). fp32 + bf16 activation variants (mask fp32 in both; uniform is dtype-independent so masks match across precisions). Stateless → re-run with same seed = same mask (T13 recompute-safe). Registered in build.rs + FFI decls. Tensor::dropout(p,seed)->(out,mask) and Tensor::dropout_backward(d,mask) wrap the launches (contiguous F32/BF16, default stream, per-op sync via the kernels). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
109
csrc/ops/dropout.cu
Normal file
109
csrc/ops/dropout.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// Dropout kernels (Phase T18).
|
||||
//
|
||||
// A counter-based (stateless) RNG: the keep/drop decision for element `i` is a
|
||||
// pure function of (seed, i) — no global RNG state is advanced. This is what
|
||||
// makes dropout compatible with activation recomputation (T13): when a
|
||||
// checkpointed block re-runs its forward in backward, the SAME seed regenerates
|
||||
// the SAME mask, so the recomputed activations / grads stay bit-identical to the
|
||||
// forward (no mask drift).
|
||||
//
|
||||
// Inverted dropout: at training time kept elements are scaled by 1/(1-p) so the
|
||||
// expectation E[out] == x (no inference-time rescale needed; eval is identity,
|
||||
// handled in Rust by simply not calling dropout).
|
||||
//
|
||||
// key = seed ^ (i * GOLDEN)
|
||||
// h = splitmix64(key) // a few rounds of xorshift/multiply
|
||||
// u = (h >> 40) / 2^24 in [0,1) // 24-bit uniform
|
||||
// keep = u >= p // Bernoulli(keep = 1-p)
|
||||
// out = keep ? x * scale : 0 // scale = 1/(1-p)
|
||||
// mask = keep ? scale : 0 // cached for backward (dx = d * mask)
|
||||
//
|
||||
// fp32 + bf16 variants: bf16 loads/stores half-size activations but the uniform
|
||||
// `u` is always computed in fp32, so the mask distribution is identical across
|
||||
// dtypes (drop decisions don't depend on bf16 rounding). The mask buffer is fp32
|
||||
// in both cases (it stores `scale` or 0 exactly as passed in, so nothing is
// rounded; it holds one fp32 value per element — as large as an fp32 activation —
// and is reused only elementwise in backward).
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// SplitMix64 output function, used here as a stateless counter hash.
//
// Adds the golden-ratio increment to `x`, then applies two xor-shift/multiply
// rounds followed by a final xor-shift, and returns the mixed 64-bit word.
// The result depends only on `x`; no state is kept between calls.
__device__ __forceinline__ uint64_t splitmix64(uint64_t x) {
    uint64_t z = x + 0x9E3779B97F4A7C15ULL;
    // Round 1: fold the high bits down, then multiply to diffuse them.
    z ^= z >> 30;
    z *= 0xBF58476D1CE4E5B9ULL;
    // Round 2: same shape with a different shift and multiplier.
    z ^= z >> 27;
    z *= 0x94D049BB133111EBULL;
    // Final fold so the low bits also depend on the high bits.
    z ^= z >> 31;
    return z;
}
|
||||
|
||||
// Deterministic uniform sample in [0,1) for element `i` under `seed`.
//
// The element index is spread over 64 bits by multiplying with the golden-ratio
// constant, xor-ed with the seed, and hashed with splitmix64. The top 24 bits
// of the hash are scaled by 2^-24, so the result is a multiple of 2^-24 and is
// always computed in fp32 regardless of the activation dtype.
__device__ __forceinline__ float dropout_uniform(uint64_t seed, int i) {
    const uint64_t kGolden = 0x9E3779B97F4A7C15ULL;
    const float kInvTwoPow24 = 1.0f / 16777216.0f;  // 1/2^24

    const uint64_t counter = (uint64_t)i * kGolden;
    // Keep only the 24 most significant bits of the hash.
    const uint64_t top24 = splitmix64(seed ^ counter) >> 40;
    return (float)top24 * kInvTwoPow24;
}
|
||||
|
||||
// fp32 dropout forward, one thread per element.
//
// For each index below `n`, draws dropout_uniform(seed, index) and compares it
// against `p`: if the draw is >= p the element's multiplier is `scale`,
// otherwise it is 0. The multiplier is written to `mask` and the product
// x * multiplier is written to `out`. Threads whose index is >= n do nothing.
__global__ void dropout_fwd_f32_k(const float* x, float* out, float* mask,
                                  float p, float scale, uint64_t seed, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;

    float multiplier = 0.0f;
    if (dropout_uniform(seed, idx) >= p) {
        multiplier = scale;
    }
    // Mask is stored first, then the scaled activation (same order as before).
    mask[idx] = multiplier;
    out[idx] = x[idx] * multiplier;
}
|
||||
// Host launcher for dropout_fwd_f32_k.
//
//   x, out, mask : pointers forwarded unchanged to the kernel
//   p, scale     : drop threshold and kept-element multiplier, forwarded as-is
//   seed         : RNG seed forwarded to the kernel
//   n            : element count; n <= 0 is a no-op
//   s            : stream handle, cast to cudaStream_t for the launch
//
// Uses 256 threads per block and ceil(n / 256) blocks. The launch is only
// enqueued here: this function neither synchronizes nor checks for errors.
void launch_dropout_fwd_f32(const float* x, float* out, float* mask, float p,
                            float scale, uint64_t seed, int n, void* s) {
    // A zero-block grid is rejected by CUDA (invalid configuration), so an
    // empty tensor must skip the launch instead of recording an error.
    if (n <= 0) return;

    const int blk = 256;
    // Ceil-div written so it cannot overflow int when n is close to INT_MAX
    // (the usual `n + blk - 1` form does).
    const int grid = (n - 1) / blk + 1;
    dropout_fwd_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, out, mask, p, scale,
                                                         seed, n);
}
|
||||
|
||||
// fp32 dropout backward, one thread per element.
//
// For each index below `n`, writes dx = d * mask, where `mask` is read as an
// fp32 per-element multiplier. Threads whose index is >= n do nothing.
__global__ void dropout_bwd_f32_k(const float* d, const float* mask, float* dx,
                                  int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;

    dx[idx] = d[idx] * mask[idx];
}
|
||||
// Host launcher for dropout_bwd_f32_k.
//
//   d, mask, dx : pointers forwarded unchanged to the kernel
//   n           : element count; n <= 0 is a no-op
//   s           : stream handle, cast to cudaStream_t for the launch
//
// Uses 256 threads per block and ceil(n / 256) blocks. The launch is only
// enqueued here: this function neither synchronizes nor checks for errors.
void launch_dropout_bwd_f32(const float* d, const float* mask, float* dx, int n,
                            void* s) {
    // A zero-block grid is rejected by CUDA (invalid configuration), so an
    // empty tensor must skip the launch instead of recording an error.
    if (n <= 0) return;

    const int blk = 256;
    // Ceil-div written so it cannot overflow int when n is close to INT_MAX.
    const int grid = (n - 1) / blk + 1;
    dropout_bwd_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(d, mask, dx, n);
}
|
||||
|
||||
// bf16 dropout forward, one thread per element.
//
// The keep/drop decision is the same fp32 comparison as the fp32 kernel:
// dropout_uniform(seed, index) >= p selects `scale`, otherwise 0. That fp32
// multiplier is written to `mask`. The activation is widened from bf16 to
// fp32, multiplied, and rounded back to bf16 for `out`. Threads whose index is
// >= n do nothing.
__global__ void dropout_fwd_bf16_k(const __nv_bfloat16* x, __nv_bfloat16* out,
                                   float* mask, float p, float scale,
                                   uint64_t seed, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;

    float multiplier = 0.0f;
    if (dropout_uniform(seed, idx) >= p) {
        multiplier = scale;
    }
    // Mask is stored first, then the scaled activation (same order as before).
    mask[idx] = multiplier;

    const float scaled = __bfloat162float(x[idx]) * multiplier;
    out[idx] = __float2bfloat16(scaled);
}
|
||||
// Host launcher for dropout_fwd_bf16_k.
//
//   x, out   : untyped pointers, cast to (const) __nv_bfloat16* for the kernel
//   mask     : fp32 mask pointer forwarded unchanged
//   p, scale : drop threshold and kept-element multiplier, forwarded as-is
//   seed     : RNG seed forwarded to the kernel
//   n        : element count; n <= 0 is a no-op
//   s        : stream handle, cast to cudaStream_t for the launch
//
// Uses 256 threads per block and ceil(n / 256) blocks. The launch is only
// enqueued here: this function neither synchronizes nor checks for errors.
void launch_dropout_fwd_bf16(const void* x, void* out, float* mask, float p,
                             float scale, uint64_t seed, int n, void* s) {
    // A zero-block grid is rejected by CUDA (invalid configuration), so an
    // empty tensor must skip the launch instead of recording an error.
    if (n <= 0) return;

    const int blk = 256;
    // Ceil-div written so it cannot overflow int when n is close to INT_MAX.
    const int grid = (n - 1) / blk + 1;
    dropout_fwd_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, mask, p, scale, seed, n);
}
|
||||
|
||||
// bf16 dropout backward, one thread per element.
//
// For each index below `n`, widens the incoming gradient `d` from bf16 to
// fp32, multiplies it by the fp32 `mask` entry, and rounds the product back to
// bf16 into `dx`. Threads whose index is >= n do nothing.
__global__ void dropout_bwd_bf16_k(const __nv_bfloat16* d, const float* mask,
                                   __nv_bfloat16* dx, int n) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;

    const float grad = __bfloat162float(d[idx]) * mask[idx];
    dx[idx] = __float2bfloat16(grad);
}
|
||||
// Host launcher for dropout_bwd_bf16_k.
//
//   d, dx : untyped pointers, cast to (const) __nv_bfloat16* for the kernel
//   mask  : fp32 mask pointer forwarded unchanged
//   n     : element count; n <= 0 is a no-op
//   s     : stream handle, cast to cudaStream_t for the launch
//
// Uses 256 threads per block and ceil(n / 256) blocks. The launch is only
// enqueued here: this function neither synchronizes nor checks for errors.
void launch_dropout_bwd_bf16(const void* d, const float* mask, void* dx, int n,
                             void* s) {
    // A zero-block grid is rejected by CUDA (invalid configuration), so an
    // empty tensor must skip the launch instead of recording an error.
    if (n <= 0) return;

    const int blk = 256;
    // Ceil-div written so it cannot overflow int when n is close to INT_MAX.
    const int grid = (n - 1) / blk + 1;
    dropout_bwd_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
        (const __nv_bfloat16*)d, mask, (__nv_bfloat16*)dx, n);
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user