Files
xtrain/crates/xtrain-cuda/src/ffi.rs
Gahow Wang 2c9b58cb3b post-train: M2b — batched KV-cache decode (G-way, token-identical)
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:18:54 +08:00

613 lines
18 KiB
Rust

use std::ffi::c_void;
use std::os::raw::c_char;
/// Opaque CUDA stream handle, represented on the Rust side as an untyped
/// pointer. Every `launch_*` declaration in this file takes one as its last
/// argument.
pub type CudaStream = *mut c_void;
/// `kind` argument for `cudaMemcpy`: copy from host memory to device memory
/// (the CUDA runtime's `cudaMemcpyHostToDevice`).
pub const CUDA_MEMCPY_H2D: i32 = 1;
/// `kind` argument for `cudaMemcpy`: copy from device memory to host memory
/// (the CUDA runtime's `cudaMemcpyDeviceToHost`).
pub const CUDA_MEMCPY_D2H: i32 = 2;
/// Status code the CUDA runtime calls return on success (`cudaSuccess`).
pub const CUDA_SUCCESS: i32 = 0;
/// Status code for a failed device allocation (`cudaErrorMemoryAllocation`).
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
unsafe extern "C" {
    // ---- Device management ----

    /// Writes the number of CUDA devices to `*count`; returns a CUDA status code.
    pub fn cudaGetDeviceCount(count: *mut i32) -> i32;

    /// Selects `device` as the current device; returns a CUDA status code.
    pub fn cudaSetDevice(device: i32) -> i32;

    /// Writes the index of the current device to `*device`; returns a CUDA status code.
    pub fn cudaGetDevice(device: *mut i32) -> i32;

    /// Blocks until the device has finished all outstanding work; returns a CUDA status code.
    pub fn cudaDeviceSynchronize() -> i32;

    // ---- Device memory ----

    /// Allocates `size` bytes of device memory and stores the address in `*devptr`.
    pub fn cudaMalloc(devptr: *mut *mut u8, size: usize) -> i32;

    /// Releases device memory previously obtained from `cudaMalloc`.
    pub fn cudaFree(devptr: *mut u8) -> i32;

    /// Copies `count` bytes from `src` to `dst`; `kind` selects the direction
    /// (see `CUDA_MEMCPY_H2D` / `CUDA_MEMCPY_D2H`).
    pub fn cudaMemcpy(dst: *mut u8, src: *const u8, count: usize, kind: i32) -> i32;

    /// Fills `count` bytes of device memory at `devptr` with the byte `value`.
    pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32;

    // ---- Error reporting ----

    /// Returns a NUL-terminated, human-readable description of the status code `error`.
    pub fn cudaGetErrorString(error: i32) -> *const c_char;
}
// Hand-written GPU kernels, compiled from csrc/ by build.rs. They are linked
// only when CUDA was actually compiled (i.e. nvcc was present), hence the cfg.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Vector-add smoke test (csrc/test/vecadd.cu).
    pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream);

    /// Elementwise scale: `out[i] = in[i] * alpha` (csrc/ops/elementwise.cu).
    pub fn launch_scale_f32(
        input: *const f32,
        out: *mut f32,
        alpha: f32,
        n: i32,
        stream: CudaStream,
    );

    /// Tiled GEMM on row-major F32: `C = A @ B` with `A:[M,K]`, `B:[K,N]`,
    /// `C:[M,N]` (csrc/ops/gemm.cu).
    pub fn launch_gemm_tiled_f32(
        a: *const f32,
        b: *const f32,
        c: *mut f32,
        m: i32,
        n: i32,
        k: i32,
        stream: CudaStream,
    );

    /// Out-of-place 2D transpose: `out[j,i] = in[i,j]`, where `in` is
    /// `[rows, cols]` and `out` is `[cols, rows]`, both row-major
    /// (csrc/ops/gemm.cu).
    pub fn launch_transpose_f32(
        input: *const f32,
        out: *mut f32,
        rows: i32,
        cols: i32,
        stream: CudaStream,
    );
}
// Transformer / autograd op kernels (csrc/ops/nn.cu): forward and backward
// passes of the ops the Phase T4 tape engine needs. Everything here is F32,
// row-major and contiguous.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Elementwise sum: `out = a + b`.
    pub fn launch_add_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);

    /// Elementwise product: `out = a * b`.
    pub fn launch_mul_f32(a: *const f32, b: *const f32, out: *mut f32, n: i32, s: CudaStream);

    /// Broadcast bias add: `out[r,c] = x[r,c] + bias[c]`, with `x:[rows,cols]`
    /// and `bias` of length `cols`.
    pub fn launch_add_bias_f32(
        x: *const f32,
        bias: *const f32,
        out: *mut f32,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// Bias backward, a column sum over rows: `dbias[c] = sum_r dout[r,c]`.
    pub fn launch_sum_rows_f32(
        dout: *const f32,
        dbias: *mut f32,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// RMSNorm forward. Writes `y` (`[rows,cols]`) and `inv_rms` (one value per
    /// row), the latter cached for the backward pass.
    pub fn launch_rms_norm_f32(
        x: *const f32,
        gamma: *const f32,
        y: *mut f32,
        inv_rms: *mut f32,
        rows: i32,
        cols: i32,
        eps: f32,
        s: CudaStream,
    );

    /// RMSNorm backward w.r.t. the input: writes `dx`, reading the `inv_rms`
    /// buffer produced by the forward.
    pub fn launch_rms_norm_dx_f32(
        x: *const f32,
        gamma: *const f32,
        dy: *const f32,
        inv_rms: *const f32,
        dx: *mut f32,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// RMSNorm backward w.r.t. the gain: writes `dgamma`, reading the `inv_rms`
    /// buffer produced by the forward.
    pub fn launch_rms_norm_dgamma_f32(
        x: *const f32,
        dy: *const f32,
        inv_rms: *const f32,
        dgamma: *mut f32,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// SiLU forward: `y = x * sigmoid(x)`.
    pub fn launch_silu_f32(x: *const f32, y: *mut f32, n: i32, s: CudaStream);

    /// SiLU backward: writes `dx` from the forward input `x` and upstream `dy`.
    pub fn launch_silu_dx_f32(x: *const f32, dy: *const f32, dx: *mut f32, n: i32, s: CudaStream);

    /// RoPE (rotate_half) over `x:[tokens,heads,head_dim]`. A row's position is
    /// `token index % period`; `period` is the sequence length, so a flattened
    /// batch of sequences gets per-sequence positions, and `period == tokens`
    /// reproduces the single-sequence case.
    pub fn launch_rope_f32(
        x: *const f32,
        y: *mut f32,
        tokens: i32,
        heads: i32,
        head_dim: i32,
        theta: f32,
        period: i32,
        s: CudaStream,
    );

    /// RoPE at an absolute position offset (KV-cache decode, forward only).
    /// Row `tok` sits at position `pos0 + tok`, with no modulo; for a single
    /// decode token (`tokens == 1`) that one row is at absolute position `pos0`.
    pub fn launch_rope_at_f32(
        x: *const f32,
        y: *mut f32,
        tokens: i32,
        heads: i32,
        head_dim: i32,
        theta: f32,
        pos0: i32,
        s: CudaStream,
    );

    /// RoPE with a per-row absolute position (batched KV-cache decode, M2b):
    /// row `tok` sits at `positions[tok]`. Forward only.
    pub fn launch_rope_pos_f32(
        x: *const f32,
        positions: *const i32,
        y: *mut f32,
        tokens: i32,
        heads: i32,
        head_dim: i32,
        theta: f32,
        s: CudaStream,
    );

    /// Per-row scale: `y[r,c] = x[r,c] * s[r]` (GRPO policy-gradient backward).
    /// Note that here `s` is the scale vector and the stream is `stream`.
    pub fn launch_scale_rows_f32(
        x: *const f32,
        s: *const f32,
        y: *mut f32,
        rows: i32,
        cols: i32,
        stream: CudaStream,
    );

    /// RoPE backward: writes `dx` from `dy`. Takes the same `tokens` / `heads` /
    /// `head_dim` / `theta` / `period` arguments as `launch_rope_f32`.
    pub fn launch_rope_dx_f32(
        dy: *const f32,
        dx: *mut f32,
        tokens: i32,
        heads: i32,
        head_dim: i32,
        theta: f32,
        period: i32,
        s: CudaStream,
    );

    /// Row-wise softmax forward.
    pub fn launch_softmax_f32(x: *const f32, y: *mut f32, rows: i32, cols: i32, s: CudaStream);

    /// Row-wise softmax backward (Jacobian product): writes `dx` from the
    /// forward output `y` and upstream `dy`.
    pub fn launch_softmax_dx_f32(
        y: *const f32,
        dy: *const f32,
        dx: *mut f32,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// Cross-entropy forward: writes `probs` (`[rows,cols]`) and a per-row
    /// `loss` (one value per row).
    pub fn launch_cross_entropy_fwd_f32(
        x: *const f32,
        target: *const i32,
        probs: *mut f32,
        loss: *mut f32,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// Cross-entropy backward: `dx = scale * (probs - onehot)`.
    pub fn launch_cross_entropy_dx_f32(
        probs: *const f32,
        target: *const i32,
        dx: *mut f32,
        rows: i32,
        cols: i32,
        scale: f32,
        s: CudaStream,
    );
}
// Structural ops for the tiny transformer (csrc/ops/model.cu): the token
// embedding (gather forward / scatter-add backward) plus the axis transposes
// used for the multi-head attention layout. Values are F32, ids are I32, and
// everything is row-major contiguous.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Embedding gather: `out[s,:] = table[ids[s], :]`, with `table:[vocab,dim]`
    /// and `ids` holding `seq` I32 entries.
    pub fn launch_embedding_fwd_f32(
        table: *const f32,
        ids: *const i32,
        out: *mut f32,
        seq: i32,
        dim: i32,
        s: CudaStream,
    );

    /// Embedding backward, an atomic scatter-add: `dtable[ids[s],:] += dout[s,:]`.
    /// `dtable` must be zeroed beforehand.
    pub fn launch_embedding_bwd_f32(
        dout: *const f32,
        ids: *const i32,
        dtable: *mut f32,
        seq: i32,
        dim: i32,
        s: CudaStream,
    );

    /// 3D transpose of axes (0,1): `in:[a,b,c] -> out:[b,a,c]`, i.e.
    /// `out[j,i,k] = in[i,j,k]`.
    pub fn launch_transpose_3d01_f32(
        input: *const f32,
        out: *mut f32,
        a: i32,
        b: i32,
        c: i32,
        s: CudaStream,
    );

    /// 4D transpose of axes (1,2): `in:[a,b,c,d] -> out:[a,c,b,d]`, i.e.
    /// `out[i,k,j,l] = in[i,j,k,l]`.
    pub fn launch_transpose_4d12_f32(
        input: *const f32,
        out: *mut f32,
        a: i32,
        b: i32,
        c: i32,
        d: i32,
        s: CudaStream,
    );
}
// Batched attention helper (csrc/ops/attention.cu).
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Causal row-wise softmax over score rows `[rows, seq]`. The query
    /// position of a row is `row % seq`; logits are multiplied by `scale`
    /// (= 1/sqrt(head_dim)) and future columns are masked to probability 0.
    pub fn launch_softmax_causal_f32(
        x: *const f32,
        y: *mut f32,
        rows: i32,
        seq: i32,
        scale: f32,
        s: CudaStream,
    );
}
// Fused flash-attention (csrc/ops/flash_attention.cu, Phase T14). A SINGLE kernel
// each for forward/backward that streams over KV tiles with an online softmax and
// NEVER materializes the [bh,S,S] score matrix. Q/K/V/out are [bh,S,hd] row-major
// F32; the forward saves only the per-row logsumexp `l` ([bh*S], O(N)) for backward.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Forward: `o[bh,S,hd] = softmax(causal(Q·Kᵀ·scale))·V`, computed online
    /// over KV tiles. Also writes `l` (`bh*S` entries), the per-row logsumexp
    /// that is saved for the backward pass in place of the scores.
    #[allow(clippy::too_many_arguments)]
    pub fn launch_flash_attention_fwd_f32(
        q: *const f32,
        k: *const f32,
        v: *const f32,
        o: *mut f32,
        l: *mut f32,
        bh: i32,
        seq: i32,
        hd: i32,
        scale: f32,
        s: CudaStream,
    );

    /// Per-row `D[i] = Σ_d dO[i,d]·O[i,d]` over `rows` (= bh*S) rows of width
    /// `hd`, written to `d_d`. Must run before the backward kernel, which takes
    /// the precomputed D rather than O.
    pub fn launch_flash_attention_rowdot_f32(
        d_o: *const f32,
        o: *const f32,
        d_d: *mut f32,
        rows: i32,
        hd: i32,
        s: CudaStream,
    );

    /// Backward: recomputes the scores from Q/K/V and the saved logsumexp `l`
    /// (NO cached probabilities), combines them with the precomputed `d_d`
    /// (= D from `launch_flash_attention_rowdot_f32`) and produces `dq`/`dk`/`dv`.
    ///
    /// `d_d` is an input here, so it is declared `*const` like the other
    /// read-only operands; a `*mut f32` still coerces at the call site.
    ///
    /// `dq`/`dk`/`dv` must be PRE-ZEROED: `dk`/`dv` are accumulated across query
    /// rows via atomicAdd.
    #[allow(clippy::too_many_arguments)]
    pub fn launch_flash_attention_bwd_f32(
        q: *const f32,
        k: *const f32,
        v: *const f32,
        d_o: *const f32,
        l: *const f32,
        d_d: *const f32,
        dq: *mut f32,
        dk: *mut f32,
        dv: *mut f32,
        bh: i32,
        seq: i32,
        hd: i32,
        scale: f32,
        s: CudaStream,
    );
}
// GQA repeat_kv head broadcast (csrc/ops/repeat_kv.cu, Phase T15). A K/V tensor
// of shape [batch·num_kv, S, hd] is expanded to the full [batch·nh, S, hd] so
// that the SDPA (composed or flash, both untouched) sees a full set of heads.
// With group = nh/num_kv, the forward gathers (out head qh ← kv head qh/group)
// and the backward sums the `group` query heads that share each kv head
// (deterministic, no atomics). All F32.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Forward gather, per `[S,hd]` head block:
    /// `out[b·nh+qh] = in[b·num_kv + qh/group]`.
    pub fn launch_repeat_kv_fwd_f32(
        input: *const f32,
        out: *mut f32,
        batch: i32,
        nh: i32,
        num_kv: i32,
        seq: i32,
        hd: i32,
        s: CudaStream,
    );

    /// Backward reduction:
    /// `din[b·num_kv+kvh] = Σ_{r<group} dout[b·nh + kvh·group + r]`.
    pub fn launch_repeat_kv_bwd_f32(
        dout: *const f32,
        din: *mut f32,
        batch: i32,
        nh: i32,
        num_kv: i32,
        seq: i32,
        hd: i32,
        s: CudaStream,
    );
}
// GPU-side optimizer kernels (csrc/ops/optim.cu): the AdamW step, with m/v
// kept on the device, and the global grad-norm reduction plus in-place rescale
// (Phase T7).
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// One in-place AdamW step over a parameter tensor of `n` elements.
    /// `bc1` / `bc2` are the bias-correction denominators `1 - beta^t`.
    #[allow(clippy::too_many_arguments)]
    pub fn launch_adamw_step_f32(
        p: *mut f32,
        g: *const f32,
        m: *mut f32,
        v: *mut f32,
        lr: f32,
        b1: f32,
        b2: f32,
        eps: f32,
        wd: f32,
        bc1: f32,
        bc2: f32,
        n: i32,
        s: CudaStream,
    );

    /// Sum-of-squares accumulation via atomicAdd: `acc += sum_i g[i]^2`.
    /// `acc` is a single f32 on the device and must be zeroed beforehand.
    pub fn launch_sumsq_accum_f32(g: *const f32, acc: *mut f32, n: i32, s: CudaStream);

    /// In-place scalar scale: `x[i] *= factor`.
    pub fn launch_scale_inplace_f32(x: *mut f32, factor: f32, n: i32, s: CudaStream);
}
// cuBLAS: the production GEMM backend (Phase T7), and still the correctness
// oracle that the T3 GEMM tests compare against. It is declared (and linked,
// see build.rs) only when CUDA is compiled in.

/// Opaque cuBLAS context handle, represented as an untyped pointer.
#[cfg(not(no_cuda))]
pub type CublasHandle = *mut c_void;

#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Creates a cuBLAS context and stores it in `*handle`; returns a cuBLAS status code.
    pub fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;

    /// Destroys a context created by `cublasCreate_v2`; returns a cuBLAS status code.
    pub fn cublasDestroy_v2(handle: CublasHandle) -> i32;

    /// Single-precision GEMM, `C = alpha * op(A) * op(B) + beta * C`, where
    /// `transa` / `transb` select `op` (see `CUBLAS_OP_N` / `CUBLAS_OP_T`).
    /// Returns a cuBLAS status code.
    pub fn cublasSgemm_v2(
        handle: CublasHandle,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: *const f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: *const f32,
        c: *mut f32,
        ldc: i32,
    ) -> i32;

    /// Strided-batched single-precision GEMM: `batch_count` GEMMs whose
    /// operands are `stride_a` / `stride_b` / `stride_c` elements apart.
    /// Returns a cuBLAS status code.
    #[allow(clippy::too_many_arguments)]
    pub fn cublasSgemmStridedBatched(
        handle: CublasHandle,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: *const f32,
        a: *const f32,
        lda: i32,
        stride_a: i64,
        b: *const f32,
        ldb: i32,
        stride_b: i64,
        beta: *const f32,
        c: *mut f32,
        ldc: i32,
        stride_c: i64,
        batch_count: i32,
    ) -> i32;
}
/// cuBLAS operation selector for `transa` / `transb`: use the matrix as stored
/// (no transpose).
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_N: i32 = 0;
/// cuBLAS operation selector for `transa` / `transb`: use the transposed matrix.
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_T: i32 = 1;
// --- bf16 mixed precision (Phase T12) ---
//
// cudaDataType / cublasComputeType enum values (same as xserv's gemm.rs). The
// bf16 GEMM uses bf16 in/out with fp32 accumulation (CUBLAS_COMPUTE_32F).
/// `cudaDataType` tag for 32-bit float elements; passed as an `a_type` /
/// `b_type` / `c_type` argument of the `*Ex` GEMM entry points.
#[cfg(not(no_cuda))]
pub const CUDA_R_32F: i32 = 0;
/// `cudaDataType` tag for bfloat16 elements; passed as an `a_type` / `b_type` /
/// `c_type` argument of the `*Ex` GEMM entry points.
#[cfg(not(no_cuda))]
pub const CUDA_R_16BF: i32 = 14;
/// `cublasComputeType_t` value selecting fp32 accumulation; passed as the
/// `compute_type` argument of the `*Ex` GEMM entry points.
#[cfg(not(no_cuda))]
pub const CUBLAS_COMPUTE_32F: i32 = 68;
/// CUBLAS_GEMM_DEFAULT — let cuBLAS pick the algorithm.
#[cfg(not(no_cuda))]
pub const CUBLAS_GEMM_DEFAULT: i32 = -1;
#[cfg(not(no_cuda))]
unsafe extern "C" {
// General GEMM with explicit in/out + compute types (bf16 path). `alpha`/
// `beta` are fp32 host scalars (compute type is fp32). Pointers are void* so
// the same FFI serves bf16 / fp32.
#[allow(clippy::too_many_arguments)]
pub fn cublasGemmEx(
handle: CublasHandle,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const std::ffi::c_void,
a: *const std::ffi::c_void,
a_type: i32,
lda: i32,
b: *const std::ffi::c_void,
b_type: i32,
ldb: i32,
beta: *const std::ffi::c_void,
c: *mut std::ffi::c_void,
c_type: i32,
ldc: i32,
compute_type: i32,
algo: i32,
) -> i32;
#[allow(clippy::too_many_arguments)]
pub fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32,
transb: i32,
m: i32,
n: i32,
k: i32,
alpha: *const std::ffi::c_void,
a: *const std::ffi::c_void,
a_type: i32,
lda: i32,
stride_a: i64,
b: *const std::ffi::c_void,
b_type: i32,
ldb: i32,
stride_b: i64,
beta: *const std::ffi::c_void,
c: *mut std::ffi::c_void,
c_type: i32,
ldc: i32,
stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
) -> i32;
}
// bf16 cast and elementwise kernels (csrc/ops/cast.cu). bf16 buffers travel as
// void* pointers, while the f32 sides stay typed. The activation stream flows
// in bf16; the math inside each kernel accumulates in fp32.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Casts `n` f32 values from `input` into the bf16 buffer `out`.
    pub fn launch_cast_f32_to_bf16(input: *const f32, out: *mut c_void, n: i32, s: CudaStream);

    /// Casts `n` bf16 values from `input` into the f32 buffer `out`.
    pub fn launch_cast_bf16_to_f32(input: *const c_void, out: *mut f32, n: i32, s: CudaStream);

    /// bf16 variant of the elementwise add (`launch_add_f32`).
    pub fn launch_add_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        s: CudaStream,
    );

    /// bf16 variant of the elementwise multiply (`launch_mul_f32`).
    pub fn launch_mul_bf16(
        a: *const c_void,
        b: *const c_void,
        out: *mut c_void,
        n: i32,
        s: CudaStream,
    );

    /// bf16 variant of the elementwise scale (`launch_scale_f32`); `alpha`
    /// stays an f32 scalar.
    pub fn launch_scale_bf16(
        input: *const c_void,
        out: *mut c_void,
        alpha: f32,
        n: i32,
        s: CudaStream,
    );

    /// bf16 variant of the SiLU forward (`launch_silu_f32`).
    pub fn launch_silu_bf16(x: *const c_void, y: *mut c_void, n: i32, s: CudaStream);

    /// bf16 variant of the SiLU backward (`launch_silu_dx_f32`).
    pub fn launch_silu_dx_bf16(
        x: *const c_void,
        dy: *const c_void,
        dx: *mut c_void,
        n: i32,
        s: CudaStream,
    );

    /// bf16 variant of the broadcast bias add (`launch_add_bias_f32`).
    pub fn launch_add_bias_bf16(
        x: *const c_void,
        bias: *const c_void,
        out: *mut c_void,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );

    /// bf16 variant of the bias-backward column sum (`launch_sum_rows_f32`).
    pub fn launch_sum_rows_bf16(
        dout: *const c_void,
        dbias: *mut c_void,
        rows: i32,
        cols: i32,
        s: CudaStream,
    );
}
// Dropout (Phase T18, csrc/ops/dropout.cu), built on a counter-based
// (stateless) RNG: the keep/drop decision for element `i` is `hash(seed, i)`.
// There is no global state, so re-running with the same `seed` reproduces the
// same mask, which keeps it compatible with T13 activation recomputation.
//
// The forward writes `out = x ⊙ mask` together with the fp32 `mask` buffer,
// where mask[i] is the inverted-dropout scale 1/(1-p) for a kept element and 0
// otherwise. The backward applies the SAME mask: dx = d ⊙ mask. There are fp32
// and bf16 activation variants; the mask is fp32 in both, and the uniform is
// computed in fp32 regardless of the activation dtype.
#[cfg(not(no_cuda))]
unsafe extern "C" {
    /// Dropout forward on f32 activations: writes `out` and the fp32 `mask`.
    pub fn launch_dropout_fwd_f32(
        x: *const f32,
        out: *mut f32,
        mask: *mut f32,
        p: f32,
        scale: f32,
        seed: u64,
        n: i32,
        s: CudaStream,
    );

    /// Dropout backward on f32 gradients: `dx = d ⊙ mask`.
    pub fn launch_dropout_bwd_f32(
        d: *const f32,
        mask: *const f32,
        dx: *mut f32,
        n: i32,
        s: CudaStream,
    );

    /// Dropout forward on bf16 activations (`x` / `out` are bf16 buffers); the
    /// `mask` it writes is still fp32.
    pub fn launch_dropout_fwd_bf16(
        x: *const c_void,
        out: *mut c_void,
        mask: *mut f32,
        p: f32,
        scale: f32,
        seed: u64,
        n: i32,
        s: CudaStream,
    );

    /// Dropout backward on bf16 gradients (`d` / `dx` are bf16 buffers), using
    /// the fp32 `mask`.
    pub fn launch_dropout_bwd_bf16(
        d: *const c_void,
        mask: *const f32,
        dx: *mut c_void,
        n: i32,
        s: CudaStream,
    );
}