The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-only
  kernel) — G rows share one decode position. Gate: == full rope for [0..n],
  == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the
  single-sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
613 lines
18 KiB
Rust
613 lines
18 KiB
Rust
use std::ffi::c_void;
|
|
use std::os::raw::c_char;
|
|
|
|
/// Opaque CUDA stream handle, carried across the FFI boundary as an untyped pointer.
pub type CudaStream = *mut c_void;

/// `kind` value for `cudaMemcpy`: copy from host memory to device memory.
pub const CUDA_MEMCPY_H2D: i32 = 1;
/// `kind` value for `cudaMemcpy`: copy from device memory to host memory.
pub const CUDA_MEMCPY_D2H: i32 = 2;

/// Status code the CUDA runtime returns when a call succeeded.
pub const CUDA_SUCCESS: i32 = 0;
/// Status code the CUDA runtime returns when a device allocation failed.
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
|
|
|
|
// CUDA runtime entry points. Every `i32` return value is a CUDA status code
// (`CUDA_SUCCESS` on success).
unsafe extern "C" {
    // --- Device ---

    /// Stores the number of available devices in `*count`.
    pub fn cudaGetDeviceCount(count: *mut i32) -> i32;
    /// Selects `device` as the current device.
    pub fn cudaSetDevice(device: i32) -> i32;
    /// Stores the index of the current device in `*device`.
    pub fn cudaGetDevice(device: *mut i32) -> i32;
    /// Blocks until the device has finished all outstanding work.
    pub fn cudaDeviceSynchronize() -> i32;

    // --- Memory ---

    /// Allocates `size` bytes of device memory and stores the pointer in `*devptr`.
    pub fn cudaMalloc(devptr: *mut *mut u8, size: usize) -> i32;
    /// Releases device memory previously obtained from `cudaMalloc`.
    pub fn cudaFree(devptr: *mut u8) -> i32;
    /// Copies `count` bytes from `src` to `dst`; `kind` selects the direction
    /// (e.g. `CUDA_MEMCPY_H2D` / `CUDA_MEMCPY_D2H`).
    pub fn cudaMemcpy(dst: *mut u8, src: *const u8, count: usize, kind: i32) -> i32;
    /// Fills `count` bytes at `devptr` with the byte value `value`.
    pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32;

    // --- Error ---

    /// Returns a NUL-terminated description of the status code `error`.
    pub fn cudaGetErrorString(error: i32) -> *const c_char;
}
|
|
|
|
// GPU kernels compiled from csrc/ by build.rs. Only linked when CUDA is
|
|
// actually compiled (i.e. nvcc was present).
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// Vector-add smoke test (csrc/test/vecadd.cu).
|
|
pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);
|
|
|
|
// Elementwise scale: out[i] = in[i] * alpha (csrc/ops/elementwise.cu).
|
|
pub fn launch_scale_f32(
|
|
input: *const f32,
|
|
out: *mut f32,
|
|
alpha: f32,
|
|
n: i32,
|
|
stream: CudaStream,
|
|
);
|
|
|
|
// Tiled GEMM: C = A @ B, row-major F32. A:[M,K] B:[K,N] C:[M,N]
|
|
// (csrc/ops/gemm.cu).
|
|
pub fn launch_gemm_tiled_f32(
|
|
a: *const f32,
|
|
b: *const f32,
|
|
c: *mut f32,
|
|
m: i32,
|
|
n: i32,
|
|
k: i32,
|
|
stream: CudaStream,
|
|
);
|
|
|
|
// Out-of-place 2D transpose: out[j,i] = in[i,j]. in:[rows,cols] row-major,
|
|
// out:[cols,rows] row-major (csrc/ops/gemm.cu).
|
|
pub fn launch_transpose_f32(
|
|
input: *const f32,
|
|
out: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
stream: CudaStream,
|
|
);
|
|
}
|
|
|
|
// Transformer / autograd op kernels (csrc/ops/nn.cu). Forward + backward for the
|
|
// ops the Phase T4 tape engine needs. All F32, row-major, contiguous.
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// Elementwise: out = a + b ; out = a * b.
|
|
pub fn launch_add_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);
|
|
pub fn launch_mul_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);
|
|
|
|
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]. x:[rows,cols], bias:[cols].
|
|
pub fn launch_add_bias_f32(
|
|
x: *const f32,
|
|
bias: *const f32,
|
|
out: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
// Column-sum (over rows): dbias[c] = sum_r dout[r,c]. Bias backward.
|
|
pub fn launch_sum_rows_f32(
|
|
dout: *const f32,
|
|
dbias: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
|
|
// RMSNorm forward: writes y[rows,cols] and inv_rms[rows] (cached for bwd).
|
|
pub fn launch_rms_norm_f32(
|
|
x: *const f32,
|
|
gamma: *const f32,
|
|
y: *mut f32,
|
|
inv_rms: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
eps: f32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_rms_norm_dx_f32(
|
|
x: *const f32,
|
|
gamma: *const f32,
|
|
dy: *const f32,
|
|
inv_rms: *const f32,
|
|
dx: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_rms_norm_dgamma_f32(
|
|
x: *const f32,
|
|
dy: *const f32,
|
|
inv_rms: *const f32,
|
|
dgamma: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
|
|
// SiLU: y = x*sigmoid(x); backward dx.
|
|
pub fn launch_silu_f32(x: *const f32, y: *mut f32, n: i32, s: CudaStream);
|
|
pub fn launch_silu_dx_f32(x: *const f32, dy: *const f32, dx: *mut f32, n: i32, s: CudaStream);
|
|
|
|
// RoPE (rotate_half), x:[tokens,heads,head_dim], position = (token index %
|
|
// period). `period` = sequence length, so a flattened batch of sequences gets
|
|
// per-sequence positions; period == tokens reproduces the single-sequence case.
|
|
pub fn launch_rope_f32(
|
|
x: *const f32,
|
|
y: *mut f32,
|
|
tokens: i32,
|
|
heads: i32,
|
|
head_dim: i32,
|
|
theta: f32,
|
|
period: i32,
|
|
s: CudaStream,
|
|
);
|
|
// RoPE at an absolute position offset (KV-cache decode, forward only): row
|
|
// `tok`'s position is `pos0 + tok` (no modulo). For a single decode token
|
|
// (tokens == 1) the one row sits at absolute position `pos0`.
|
|
pub fn launch_rope_at_f32(
|
|
x: *const f32,
|
|
y: *mut f32,
|
|
tokens: i32,
|
|
heads: i32,
|
|
head_dim: i32,
|
|
theta: f32,
|
|
pos0: i32,
|
|
s: CudaStream,
|
|
);
|
|
// RoPE with a per-row absolute position (batched KV-cache decode, M2b): row
|
|
// `tok`'s position is `positions[tok]`. Forward only.
|
|
pub fn launch_rope_pos_f32(
|
|
x: *const f32,
|
|
positions: *const i32,
|
|
y: *mut f32,
|
|
tokens: i32,
|
|
heads: i32,
|
|
head_dim: i32,
|
|
theta: f32,
|
|
s: CudaStream,
|
|
);
|
|
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
|
|
pub fn launch_scale_rows_f32(
|
|
x: *const f32,
|
|
s: *const f32,
|
|
y: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
stream: CudaStream,
|
|
);
|
|
pub fn launch_rope_dx_f32(
|
|
dy: *const f32,
|
|
dx: *mut f32,
|
|
tokens: i32,
|
|
heads: i32,
|
|
head_dim: i32,
|
|
theta: f32,
|
|
period: i32,
|
|
s: CudaStream,
|
|
);
|
|
|
|
// Row-wise softmax + Jacobian backward.
|
|
pub fn launch_softmax_f32(x: *const f32, y: *mut f32, rows: i32, cols: i32, s: CudaStream);
|
|
pub fn launch_softmax_dx_f32(
|
|
y: *const f32,
|
|
dy: *const f32,
|
|
dx: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
|
|
// Cross-entropy: fwd writes probs[rows,cols] + per-row loss[rows];
|
|
// bwd dx = scale*(probs - onehot).
|
|
pub fn launch_cross_entropy_fwd_f32(
|
|
x: *const f32,
|
|
target: *const i32,
|
|
probs: *mut f32,
|
|
loss: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_cross_entropy_dx_f32(
|
|
probs: *const f32,
|
|
target: *const i32,
|
|
dx: *mut f32,
|
|
rows: i32,
|
|
cols: i32,
|
|
scale: f32,
|
|
s: CudaStream,
|
|
);
|
|
}
|
|
|
|
// Structural ops for the tiny transformer (csrc/ops/model.cu): token embedding
|
|
// (gather fwd / scatter-add bwd) and a 3D axis-(0,1) transpose for the multi-head
|
|
// attention layout. F32 values, I32 ids, row-major contiguous.
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// Embedding: out[s,:] = table[ids[s], :]. table:[vocab,dim], ids:[seq] (I32).
|
|
pub fn launch_embedding_fwd_f32(
|
|
table: *const f32,
|
|
ids: *const i32,
|
|
out: *mut f32,
|
|
seq: i32,
|
|
dim: i32,
|
|
s: CudaStream,
|
|
);
|
|
// Scatter-add: dtable[ids[s],:] += dout[s,:] (dtable pre-zeroed; atomic).
|
|
pub fn launch_embedding_bwd_f32(
|
|
dout: *const f32,
|
|
ids: *const i32,
|
|
dtable: *mut f32,
|
|
seq: i32,
|
|
dim: i32,
|
|
s: CudaStream,
|
|
);
|
|
|
|
// 3D axis-(0,1) transpose: in:[a,b,c] -> out:[b,a,c]. out[j,i,k]=in[i,j,k].
|
|
pub fn launch_transpose_3d01_f32(
|
|
input: *const f32,
|
|
out: *mut f32,
|
|
a: i32,
|
|
b: i32,
|
|
c: i32,
|
|
s: CudaStream,
|
|
);
|
|
// 4D axis-(1,2) transpose: in:[a,b,c,d] -> out:[a,c,b,d]. out[i,k,j,l]=in[i,j,k,l].
|
|
pub fn launch_transpose_4d12_f32(
|
|
input: *const f32,
|
|
out: *mut f32,
|
|
a: i32,
|
|
b: i32,
|
|
c: i32,
|
|
d: i32,
|
|
s: CudaStream,
|
|
);
|
|
}
|
|
|
|
// Batched attention helper (csrc/ops/attention.cu): causal row-wise softmax over
|
|
// score rows [rows, seq] with query position = (row % seq); scales logits by
|
|
// `scale` (= 1/sqrt(head_dim)) and masks future columns to probability 0.
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
pub fn launch_softmax_causal_f32(
|
|
x: *const f32,
|
|
y: *mut f32,
|
|
rows: i32,
|
|
seq: i32,
|
|
scale: f32,
|
|
s: CudaStream,
|
|
);
|
|
}
|
|
|
|
// Fused flash-attention (csrc/ops/flash_attention.cu, Phase T14). A SINGLE kernel
|
|
// each for forward/backward that streams over KV tiles with an online softmax and
|
|
// NEVER materializes the [bh,S,S] score matrix. Q/K/V/out are [bh,S,hd] row-major
|
|
// F32; the forward saves only the per-row logsumexp `l` ([bh*S], O(N)) for backward.
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// Forward: o[bh,S,hd] = softmax(causal(Q·Kᵀ·scale))·V, online over KV tiles.
|
|
// Also writes l[bh*S] = per-row logsumexp (saved for backward, not the scores).
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn launch_flash_attention_fwd_f32(
|
|
q: *const f32,
|
|
k: *const f32,
|
|
v: *const f32,
|
|
o: *mut f32,
|
|
l: *mut f32,
|
|
bh: i32,
|
|
seq: i32,
|
|
hd: i32,
|
|
scale: f32,
|
|
s: CudaStream,
|
|
);
|
|
// Per-row D[i]=Σ_d dO[i,d]·O[i,d] over `rows`=bh*S rows of width `hd`. Must run
|
|
// before the backward kernel (which takes the precomputed D, not O).
|
|
pub fn launch_flash_attention_rowdot_f32(
|
|
d_o: *const f32,
|
|
o: *const f32,
|
|
d_d: *mut f32,
|
|
rows: i32,
|
|
hd: i32,
|
|
s: CudaStream,
|
|
);
|
|
// Backward: recomputes scores from Q/K/V + saved logsumexp `l` (NO cached probs)
|
|
// and the precomputed `d_d` (= D), produces dq/dk/dv. dq/dk/dv must be PRE-ZEROED
|
|
// (dk/dv are accumulated across query rows via atomicAdd).
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn launch_flash_attention_bwd_f32(
|
|
q: *const f32,
|
|
k: *const f32,
|
|
v: *const f32,
|
|
d_o: *const f32,
|
|
l: *const f32,
|
|
d_d: *mut f32,
|
|
dq: *mut f32,
|
|
dk: *mut f32,
|
|
dv: *mut f32,
|
|
bh: i32,
|
|
seq: i32,
|
|
hd: i32,
|
|
scale: f32,
|
|
s: CudaStream,
|
|
);
|
|
}
|
|
|
|
// GQA repeat_kv head broadcast (csrc/ops/repeat_kv.cu, Phase T15). Expands a K/V
|
|
// tensor from [batch·num_kv, S, hd] to the full [batch·nh, S, hd] so the SDPA
|
|
// (composed or flash, both untouched) sees a full set of heads. Forward gathers
|
|
// (out head qh ← kv head qh/group, group = nh/num_kv); backward sums the `group`
|
|
// query heads sharing each kv head (deterministic, no atomics). All F32.
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// Forward: out[b·nh+qh] = in[b·num_kv + qh/group], per [S,hd] head block.
|
|
pub fn launch_repeat_kv_fwd_f32(
|
|
input: *const f32,
|
|
out: *mut f32,
|
|
batch: i32,
|
|
nh: i32,
|
|
num_kv: i32,
|
|
seq: i32,
|
|
hd: i32,
|
|
s: CudaStream,
|
|
);
|
|
// Backward: din[b·num_kv+kvh] = Σ_{r<group} dout[b·nh + kvh·group + r].
|
|
pub fn launch_repeat_kv_bwd_f32(
|
|
dout: *const f32,
|
|
din: *mut f32,
|
|
batch: i32,
|
|
nh: i32,
|
|
num_kv: i32,
|
|
seq: i32,
|
|
hd: i32,
|
|
s: CudaStream,
|
|
);
|
|
}
|
|
|
|
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
|
|
// the global grad-norm reduction + in-place rescale (Phase T7).
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// One in-place AdamW step over a parameter tensor of `n` elements. `bc1`/`bc2`
|
|
// are the bias-correction denominators 1-beta^t.
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn launch_adamw_step_f32(
|
|
p: *mut f32,
|
|
g: *const f32,
|
|
m: *mut f32,
|
|
v: *mut f32,
|
|
lr: f32,
|
|
b1: f32,
|
|
b2: f32,
|
|
eps: f32,
|
|
wd: f32,
|
|
bc1: f32,
|
|
bc2: f32,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
// acc += sum_i g[i]^2 (acc is one f32 on device, pre-zeroed). atomicAdd.
|
|
pub fn launch_sumsq_accum_f32(g: *const f32, acc: *mut f32, n: i32, s: CudaStream);
|
|
// In-place scalar scale: x[i] *= factor.
|
|
pub fn launch_scale_inplace_f32(x: *mut f32, factor: f32, n: i32, s: CudaStream);
|
|
}
|
|
|
|
// cuBLAS is the production GEMM backend (Phase T7) and remains the correctness
// oracle the T3 GEMM tests compare against. It is declared (and linked, see
// build.rs) only when CUDA is compiled in.

/// Opaque cuBLAS context handle, carried across the FFI boundary as an untyped pointer.
#[cfg(not(no_cuda))]
pub type CublasHandle = *mut c_void;
|
|
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
pub fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
|
|
pub fn cublasDestroy_v2(handle: CublasHandle) -> i32;
|
|
pub fn cublasSgemm_v2(
|
|
handle: CublasHandle,
|
|
transa: i32,
|
|
transb: i32,
|
|
m: i32,
|
|
n: i32,
|
|
k: i32,
|
|
alpha: *const f32,
|
|
a: *const f32,
|
|
lda: i32,
|
|
b: *const f32,
|
|
ldb: i32,
|
|
beta: *const f32,
|
|
c: *mut f32,
|
|
ldc: i32,
|
|
) -> i32;
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn cublasSgemmStridedBatched(
|
|
handle: CublasHandle,
|
|
transa: i32,
|
|
transb: i32,
|
|
m: i32,
|
|
n: i32,
|
|
k: i32,
|
|
alpha: *const f32,
|
|
a: *const f32,
|
|
lda: i32,
|
|
stride_a: i64,
|
|
b: *const f32,
|
|
ldb: i32,
|
|
stride_b: i64,
|
|
beta: *const f32,
|
|
c: *mut f32,
|
|
ldc: i32,
|
|
stride_c: i64,
|
|
batch_count: i32,
|
|
) -> i32;
|
|
}
|
|
|
|
/// cuBLAS operation code: use the operand as-is (no transpose).
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_N: i32 = 0;
/// cuBLAS operation code: use the transposed operand.
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_T: i32 = 1;

// --- bf16 mixed precision (Phase T12) ---
//
// Numeric values of the cudaDataType / cublasComputeType enums (same as xserv's
// gemm.rs). The bf16 GEMM takes bf16 inputs/outputs and accumulates in fp32
// (CUBLAS_COMPUTE_32F).

/// cudaDataType value for 32-bit float data.
#[cfg(not(no_cuda))]
pub const CUDA_R_32F: i32 = 0;
/// cudaDataType value for bf16 data.
#[cfg(not(no_cuda))]
pub const CUDA_R_16BF: i32 = 14;
/// cublasComputeType value selecting fp32 accumulation.
#[cfg(not(no_cuda))]
pub const CUBLAS_COMPUTE_32F: i32 = 68;
/// CUBLAS_GEMM_DEFAULT — let cuBLAS pick the algorithm.
#[cfg(not(no_cuda))]
pub const CUBLAS_GEMM_DEFAULT: i32 = -1;
|
|
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
// General GEMM with explicit in/out + compute types (bf16 path). `alpha`/
|
|
// `beta` are fp32 host scalars (compute type is fp32). Pointers are void* so
|
|
// the same FFI serves bf16 / fp32.
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn cublasGemmEx(
|
|
handle: CublasHandle,
|
|
transa: i32,
|
|
transb: i32,
|
|
m: i32,
|
|
n: i32,
|
|
k: i32,
|
|
alpha: *const std::ffi::c_void,
|
|
a: *const std::ffi::c_void,
|
|
a_type: i32,
|
|
lda: i32,
|
|
b: *const std::ffi::c_void,
|
|
b_type: i32,
|
|
ldb: i32,
|
|
beta: *const std::ffi::c_void,
|
|
c: *mut std::ffi::c_void,
|
|
c_type: i32,
|
|
ldc: i32,
|
|
compute_type: i32,
|
|
algo: i32,
|
|
) -> i32;
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub fn cublasGemmStridedBatchedEx(
|
|
handle: CublasHandle,
|
|
transa: i32,
|
|
transb: i32,
|
|
m: i32,
|
|
n: i32,
|
|
k: i32,
|
|
alpha: *const std::ffi::c_void,
|
|
a: *const std::ffi::c_void,
|
|
a_type: i32,
|
|
lda: i32,
|
|
stride_a: i64,
|
|
b: *const std::ffi::c_void,
|
|
b_type: i32,
|
|
ldb: i32,
|
|
stride_b: i64,
|
|
beta: *const std::ffi::c_void,
|
|
c: *mut std::ffi::c_void,
|
|
c_type: i32,
|
|
ldc: i32,
|
|
stride_c: i64,
|
|
batch_count: i32,
|
|
compute_type: i32,
|
|
algo: i32,
|
|
) -> i32;
|
|
}
|
|
|
|
// bf16 cast + elementwise kernels (csrc/ops/cast.cu). Pointers are void* (bf16
|
|
// buffers); f32 sides are typed. The activation stream flows bf16; the math
|
|
// accumulates in fp32 inside each kernel.
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
pub fn launch_cast_f32_to_bf16(input: *const f32, out: *mut c_void, n: i32, s: CudaStream);
|
|
pub fn launch_cast_bf16_to_f32(input: *const c_void, out: *mut f32, n: i32, s: CudaStream);
|
|
|
|
pub fn launch_add_bf16(
|
|
a: *const c_void,
|
|
b: *const c_void,
|
|
out: *mut c_void,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_mul_bf16(
|
|
a: *const c_void,
|
|
b: *const c_void,
|
|
out: *mut c_void,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_scale_bf16(
|
|
input: *const c_void,
|
|
out: *mut c_void,
|
|
alpha: f32,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_silu_bf16(x: *const c_void, y: *mut c_void, n: i32, s: CudaStream);
|
|
pub fn launch_silu_dx_bf16(
|
|
x: *const c_void,
|
|
dy: *const c_void,
|
|
dx: *mut c_void,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_add_bias_bf16(
|
|
x: *const c_void,
|
|
bias: *const c_void,
|
|
out: *mut c_void,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_sum_rows_bf16(
|
|
dout: *const c_void,
|
|
dbias: *mut c_void,
|
|
rows: i32,
|
|
cols: i32,
|
|
s: CudaStream,
|
|
);
|
|
}
|
|
|
|
// Dropout (Phase T18, csrc/ops/dropout.cu). A counter-based (stateless) RNG: the
|
|
// keep/drop decision for element `i` is `hash(seed, i)` — no global state, so a
|
|
// re-run with the same `seed` reproduces the same mask (compatible with T13
|
|
// activation recomputation). Forward writes `out = x ⊙ mask` and the fp32 `mask`
|
|
// buffer (mask[i] = (1/(1-p)) if kept else 0, the inverted-dropout scale);
|
|
// backward applies the SAME mask: dx = d ⊙ mask. fp32 + bf16 activation variants
|
|
// (mask is fp32 in both; the uniform is computed in fp32, dtype-independent).
|
|
#[cfg(not(no_cuda))]
|
|
unsafe extern "C" {
|
|
pub fn launch_dropout_fwd_f32(
|
|
x: *const f32,
|
|
out: *mut f32,
|
|
mask: *mut f32,
|
|
p: f32,
|
|
scale: f32,
|
|
seed: u64,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_dropout_bwd_f32(
|
|
d: *const f32,
|
|
mask: *const f32,
|
|
dx: *mut f32,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_dropout_fwd_bf16(
|
|
x: *const c_void,
|
|
out: *mut c_void,
|
|
mask: *mut f32,
|
|
p: f32,
|
|
scale: f32,
|
|
seed: u64,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
pub fn launch_dropout_bwd_bf16(
|
|
d: *const c_void,
|
|
mask: *const f32,
|
|
dx: *mut c_void,
|
|
n: i32,
|
|
s: CudaStream,
|
|
);
|
|
}
|