Files
xtrain/csrc/ops/nn.cu
Gahow Wang 2c9b58cb3b post-train: M2b — batched KV-cache decode (G-way, token-identical)
The rollout long-pole fix deferred from M2a: decode the G samples of one prompt
in lockstep (one forward per step over the group → G× fewer kernel launches).

- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-
  only kernel) — G rows share one decode position. Gate: == full rope for
  [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step.
  decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G)
  broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-
  sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:18:54 +08:00

443 lines
18 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Forward + backward CUDA kernels for the transformer ops the autograd engine
// (Phase T4) needs: elementwise add/mul, broadcast bias add + its row-sum
// backward, RMSNorm, SiLU, RoPE, row-wise softmax, and cross-entropy.
//
// All F32, row-major, contiguous. Forward kernels mirror xserv
// (docs/04-transformer-kernels.md, docs/05-attention.md); the backward kernels
// are new (xserv is inference-only). Reduction helpers are inlined here so this
// file is self-contained (no shared header), matching the existing csrc/ layout.
#include <math.h>
extern "C" {
// --- Warp / block reductions (sum + max), block handles one row ---
// Warp-level sum using shuffle-down with the full 32-lane mask. Offsets halve
// 16, 8, 4, 2, 1; after the last step lane 0 holds the sum of all 32 lanes'
// inputs (the other lanes hold partial sums and should be ignored).
__device__ __forceinline__ float warp_reduce_sum(float v) {
    #pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float other = __shfl_down_sync(0xffffffffu, v, delta);
        v += other;
    }
    return v;
}
// Warp-level max using shuffle-down with the full 32-lane mask. Same tree as
// warp_reduce_sum; lane 0 ends up with the maximum over all 32 lanes.
__device__ __forceinline__ float warp_reduce_max(float v) {
    #pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float other = __shfl_down_sync(0xffffffffu, v, delta);
        v = fmaxf(v, other);
    }
    return v;
}
// Block-wide sum of `v` over all threads of the block; every thread receives
// the total as the return value.
// Two stages: each warp reduces with warp_reduce_sum and its lane 0 parks the
// warp's partial in shared[warp]; then warp 0 reduces those partials and
// thread 0 publishes the result through `bcast`.
// Preconditions (not checked here):
//  - every thread of the block must call it (it contains __syncthreads());
//  - blockDim.x <= 1024, so at most 32 warp partials fit in shared[32];
//  - blockDim.x should be a multiple of 32: warp_reduce_sum shuffles with the
//    full 0xffffffff mask, and in a partial trailing warp those shuffles would
//    read lanes that were never launched (undefined values). Warp 0 must also
//    be a full warp for the second stage.
__device__ __forceinline__ float block_reduce_sum(float v) {
__shared__ float shared[32];
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = warp_reduce_sum(v);
// after the warp-level reduction only lane 0 holds the complete warp sum
if (lane == 0) shared[warp] = v;
__syncthreads();
// warp 0 gathers one partial per warp; lanes at or beyond nwarps contribute 0
v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : 0.0f;
if (warp == 0) v = warp_reduce_sum(v);
// broadcast warp-0 lane-0 result to whole block
__shared__ float bcast;
if (threadIdx.x == 0) bcast = v;
__syncthreads();
return bcast;
}
// Block-wide maximum of `v` over all threads of the block; every thread
// receives the result as the return value.
// Same two-stage structure as block_reduce_sum (per-warp warp_reduce_max, warp
// 0 combines the per-warp partials, thread 0 broadcasts via `bcast`), with
// -INFINITY as the identity for unused slots.
// Preconditions (not checked here): called by every thread of the block (it
// contains __syncthreads()); blockDim.x <= 1024; blockDim.x a multiple of 32,
// because warp_reduce_max shuffles with the full 0xffffffff mask and a partial
// trailing warp would read lanes that were never launched.
__device__ __forceinline__ float block_reduce_max(float v) {
__shared__ float shared[32];
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = warp_reduce_max(v);
// after the warp-level reduction only lane 0 holds the complete warp max
if (lane == 0) shared[warp] = v;
__syncthreads();
// warp 0 gathers one partial per warp; surplus lanes contribute -inf
v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : -INFINITY;
if (warp == 0) v = warp_reduce_max(v);
// broadcast warp-0 lane-0 result to whole block
__shared__ float bcast;
if (threadIdx.x == 0) bcast = v;
__syncthreads();
return bcast;
}
// =====================================================================
// Elementwise add / mul (same-shape)
// =====================================================================
// out[i] = a[i] + b[i] for i in [0, n); one thread per element, threads past
// the tail exit immediately.
__global__ void add_k(const float* a, const float* b, float* out, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    out[idx] = a[idx] + b[idx];
}
// Enqueue add_k on stream `s`: 256-thread blocks, ceil(n / 256) of them.
void launch_add_f32(const float* a, const float* b, float* out, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    cudaStream_t stream = (cudaStream_t)s;
    add_k<<<blocks, threads, 0, stream>>>(a, b, out, n);
}
// out[i] = a[i] * b[i] for i in [0, n); one thread per element, threads past
// the tail exit immediately.
__global__ void mul_k(const float* a, const float* b, float* out, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    out[idx] = a[idx] * b[idx];
}
// Enqueue mul_k on stream `s`: 256-thread blocks, ceil(n / 256) of them.
void launch_mul_f32(const float* a, const float* b, float* out, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    cudaStream_t stream = (cudaStream_t)s;
    mul_k<<<blocks, threads, 0, stream>>>(a, b, out, n);
}
// =====================================================================
// Broadcast bias add: out[r,c] = x[r,c] + bias[c] (x:[rows,cols])
// Backward for bias is a column-sum (sum over rows): dbias[c] = sum_r dout[r,c].
// =====================================================================
// Flat element-per-thread kernel: element `idx` of the row-major [rows, cols]
// input gets bias[idx % cols] added (the bias broadcasts down the rows).
__global__ void add_bias_k(const float* x, const float* bias, float* out,
                           int rows, int cols) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    const int total = rows * cols;
    if (idx >= total) return;
    const int col = idx % cols;
    out[idx] = x[idx] + bias[col];
}
// Enqueue add_bias_k over all rows*cols elements with 256-thread blocks.
void launch_add_bias_f32(const float* x, const float* bias, float* out,
                         int rows, int cols, void* s) {
    const int total = rows * cols;
    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;
    cudaStream_t stream = (cudaStream_t)s;
    add_bias_k<<<blocks, threads, 0, stream>>>(x, bias, out, rows, cols);
}
// dbias[c] = sum_r dout[r,c]. One block per column, threads stride over rows.
// Block `blockIdx.x` owns one column: each thread accumulates the rows it
// strides over, then block_reduce_sum combines the partials and thread 0
// stores the column total.
__global__ void sum_rows_k(const float* dout, float* dbias, int rows, int cols) {
    const int col = blockIdx.x;
    float partial = 0.0f;
    int r = threadIdx.x;
    while (r < rows) {
        partial += dout[r * cols + col];
        r += blockDim.x;
    }
    const float total = block_reduce_sum(partial);
    if (threadIdx.x == 0) dbias[col] = total;
}
// Enqueue sum_rows_k: one 256-thread block per output column.
void launch_sum_rows_f32(const float* dout, float* dbias, int rows, int cols, void* s) {
    const int threads_per_col = 256;
    cudaStream_t stream = (cudaStream_t)s;
    sum_rows_k<<<cols, threads_per_col, 0, stream>>>(dout, dbias, rows, cols);
}
// =====================================================================
// RMSNorm: y[r,c] = x[r,c] * inv_rms[r] * gamma[c], inv_rms = rsqrt(mean(x²)+eps)
// x:[rows,cols], gamma:[cols]. Forward also writes inv_rms[rows] for backward.
// =====================================================================
// One block per row. Pass 1 reduces the row's sum of squares across the block;
// pass 2 writes x * inv_rms * gamma. Thread 0 also records inv_rms for the
// backward kernels.
__global__ void rms_norm_k(const float* x, const float* gamma, float* y,
                           float* inv_rms, int rows, int cols, float eps) {
    const int row = blockIdx.x;
    const float* src = x + row * cols;
    float* dst = y + row * cols;

    float sq = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        const float v = src[c];
        sq += v * v;
    }
    sq = block_reduce_sum(sq);

    const float scale = rsqrtf(sq / cols + eps);
    if (threadIdx.x == 0) inv_rms[row] = scale;

    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dst[c] = src[c] * scale * gamma[c];
}
// Enqueue rms_norm_k: one block per row.
// The block covers min(cols, 1024) columns but is rounded UP to a whole number
// of warps (minimum one warp). block_reduce_sum shuffles with the full
// 0xffffffff mask, so a partial trailing warp (e.g. cols = 100) would read
// lanes that were never launched, which is undefined. The surplus threads fall
// outside the column loops and contribute 0 to the reduction.
void launch_rms_norm_f32(const float* x, const float* gamma, float* y,
                         float* inv_rms, int rows, int cols, float eps, void* s) {
    if (rows <= 0) return;  // gridDim.x == 0 is an invalid launch configuration
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;  // whole warps only; stays <= 1024
    if (blk < 32) blk = 32;
    rms_norm_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, y, inv_rms, rows, cols, eps);
}
// RMSNorm backward.
// Let g[c] = dy[r,c]*gamma[c], ir = inv_rms[r], n = cols.
// dx[r,c] = ir*g[c] - x[r,c]*ir³/n * sum_c(g[c]*x[r,c])
// dgamma[c] = sum_r dy[r,c] * x[r,c] * ir (accumulated across rows)
// RMSNorm input gradient, one block per row. With g = dy * gamma and
// ir = inv_rms[row], the block first reduces <g, x> for the row, then writes
// dx = ir * g - x * (ir^3 / cols) * <g, x>.
__global__ void rms_norm_dx_k(const float* x, const float* gamma, const float* dy,
                              const float* inv_rms, float* dx, int rows, int cols) {
    const int row = blockIdx.x;
    const float* xrow = x + row * cols;
    const float* grow = dy + row * cols;
    float* outrow = dx + row * cols;
    const float ir = inv_rms[row];

    float gx = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        gx += grow[c] * gamma[c] * xrow[c];
    gx = block_reduce_sum(gx);

    const float coeff = ir * ir * ir / (float)cols * gx;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        outrow[c] = ir * grow[c] * gamma[c] - xrow[c] * coeff;
}
// Enqueue rms_norm_dx_k: one block per row.
// The block size is min(cols, 1024) rounded UP to a whole number of warps
// (minimum 32). block_reduce_sum uses full-mask warp shuffles, so a partial
// trailing warp would read never-launched lanes (undefined values); surplus
// threads skip the column loops and add 0 to the reduction.
void launch_rms_norm_dx_f32(const float* x, const float* gamma, const float* dy,
                            const float* inv_rms, float* dx, int rows, int cols, void* s) {
    if (rows <= 0) return;  // gridDim.x == 0 is an invalid launch configuration
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;  // whole warps only; stays <= 1024
    if (blk < 32) blk = 32;
    rms_norm_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, dy, inv_rms, dx, rows, cols);
}
// dgamma[c] = sum_r dy[r,c] * x[r,c] * inv_rms[r]. One block per column.
// Block `blockIdx.x` owns column `col`: threads stride down the rows summing
// dy * x * inv_rms[r], then the partials are combined with block_reduce_sum.
__global__ void rms_norm_dgamma_k(const float* x, const float* dy, const float* inv_rms,
                                  float* dgamma, int rows, int cols) {
    const int col = blockIdx.x;
    float partial = 0.0f;
    int r = threadIdx.x;
    while (r < rows) {
        const int at = r * cols + col;
        partial += dy[at] * x[at] * inv_rms[r];
        r += blockDim.x;
    }
    partial = block_reduce_sum(partial);
    if (threadIdx.x == 0) dgamma[col] = partial;
}
// Enqueue rms_norm_dgamma_k: one 256-thread block per gamma entry (column).
void launch_rms_norm_dgamma_f32(const float* x, const float* dy, const float* inv_rms,
                                float* dgamma, int rows, int cols, void* s) {
    const int threads_per_col = 256;
    cudaStream_t stream = (cudaStream_t)s;
    rms_norm_dgamma_k<<<cols, threads_per_col, 0, stream>>>(x, dy, inv_rms, dgamma, rows, cols);
}
// =====================================================================
// SiLU: y = x * sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
// =====================================================================
// y[i] = x[i] / (1 + exp(-x[i])), i.e. x * sigmoid(x); one thread per element.
__global__ void silu_k(const float* x, float* y, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    const float v = x[idx];
    y[idx] = v / (1.0f + expf(-v));
}
// Enqueue silu_k on stream `s`: 256-thread blocks, ceil(n / 256) of them.
void launch_silu_f32(const float* x, float* y, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    cudaStream_t stream = (cudaStream_t)s;
    silu_k<<<blocks, threads, 0, stream>>>(x, y, n);
}
// SiLU backward, one thread per element:
// d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)).
__global__ void silu_dx_k(const float* x, const float* dy, float* dx, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    const float xv = x[idx];
    const float sig = 1.0f / (1.0f + expf(-xv));
    dx[idx] = dy[idx] * (sig + xv * sig * (1.0f - sig));
}
// Enqueue silu_dx_k on stream `s`: 256-thread blocks, ceil(n / 256) of them.
void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    cudaStream_t stream = (cudaStream_t)s;
    silu_dx_k<<<blocks, threads, 0, stream>>>(x, dy, dx, n);
}
// =====================================================================
// RoPE (rotate_half layout). x:[tokens, heads, head_dim]; position = token index.
// y[i] = x[i]*cos - x[i+h]*sin
// y[i+h] = x[i+h]*cos + x[i]*sin (i in [0,half), h=half_dim)
// freq[i] = theta^(-2i/head_dim); angle = pos*freq[i].
// Backward is the inverse (transpose) rotation: apply +angle's transpose ≡ -angle.
// dx[i] = dy[i]*cos + dy[i+h]*sin
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
// =====================================================================
// `period` is the sequence length: a flattened batch lays B sequences end to end
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
// own sequence, `tok % period`. With period == tokens (single sequence) this is
// the original position = row.
// RoPE forward. Grid (tokens, heads); thread i rotates the pair (i, i + half)
// of head `blockIdx.y` in token row `blockIdx.x`. The position restarts every
// `period` rows (tok % period), so packed sequences each start at 0.
// Both inputs are loaded before either output is stored.
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
                       float theta, int period) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;

    const int pos = tok % period;
    const float angle = (float)pos * powf(theta, -(float)(2 * i) / (float)head_dim);
    const float cs = cosf(angle);
    const float sn = sinf(angle);

    const int off = (tok * heads + head) * head_dim;
    const float* in = x + off;
    float* out = y + off;
    const float lo = in[i];
    const float hi = in[i + half];
    out[i] = lo * cs - hi * sn;
    out[i + half] = hi * cs + lo * sn;
}
// Enqueue rope_k: grid (tokens, heads), head_dim / 2 threads per block (one per
// rotation pair).
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
                     int head_dim, float theta, int period, void* s) {
    const dim3 grid(tokens, heads);
    const int pairs = head_dim / 2;
    cudaStream_t stream = (cudaStream_t)s;
    rope_k<<<grid, pairs, 0, stream>>>(x, y, heads, head_dim, theta, period);
}
// RoPE at an absolute position offset (KV-cache decode-time, forward only). Same
// rotate_half as rope_k, but row `tok`'s position is `pos0 + tok` (no modulo) —
// a single new decode token sits at absolute position pos0. The training rope_k
// (position = tok % period) is left untouched, so this adds no training-path risk.
// RoPE forward at an absolute offset: row `tok` is rotated as position
// pos0 + tok (no modulo). Same grid/thread mapping and rotate_half math as
// rope_k; both inputs are loaded before either output is stored.
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
                          float theta, int pos0) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;

    const int pos = pos0 + tok;
    const float angle = (float)pos * powf(theta, -(float)(2 * i) / (float)head_dim);
    const float cs = cosf(angle);
    const float sn = sinf(angle);

    const int off = (tok * heads + head) * head_dim;
    const float* in = x + off;
    float* out = y + off;
    const float lo = in[i];
    const float hi = in[i + half];
    out[i] = lo * cs - hi * sn;
    out[i + half] = hi * cs + lo * sn;
}
// Enqueue rope_at_k: grid (tokens, heads), head_dim / 2 threads per block.
void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
                        int head_dim, float theta, int pos0, void* s) {
    const dim3 grid(tokens, heads);
    const int pairs = head_dim / 2;
    cudaStream_t stream = (cudaStream_t)s;
    rope_at_k<<<grid, pairs, 0, stream>>>(x, y, heads, head_dim, theta, pos0);
}
// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b): row `tok`'s
// position is `positions[tok]` (an i32 per token). For G-way batched decode all G
// rows share one decode position; for ragged batches each row carries its own.
// Forward only; the training rope_k is untouched.
// RoPE forward with a per-row position: row `tok` is rotated as position
// positions[tok] (one int per token row). Same grid/thread mapping and
// rotate_half math as rope_k; both inputs are loaded before either store.
__global__ void rope_pos_k(const float* x, const int* positions, float* y,
                           int heads, int head_dim, float theta) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;

    const int pos = positions[tok];
    const float angle = (float)pos * powf(theta, -(float)(2 * i) / (float)head_dim);
    const float cs = cosf(angle);
    const float sn = sinf(angle);

    const int off = (tok * heads + head) * head_dim;
    const float* in = x + off;
    float* out = y + off;
    const float lo = in[i];
    const float hi = in[i + half];
    out[i] = lo * cs - hi * sn;
    out[i + half] = hi * cs + lo * sn;
}
// Enqueue rope_pos_k: grid (tokens, heads), head_dim / 2 threads per block.
void launch_rope_pos_f32(const float* x, const int* positions, float* y,
                         int tokens, int heads, int head_dim, float theta, void* s) {
    const dim3 grid(tokens, heads);
    const int pairs = head_dim / 2;
    cudaStream_t stream = (cudaStream_t)s;
    rope_pos_k<<<grid, pairs, 0, stream>>>(x, positions, y, heads, head_dim, theta);
}
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
// (M4) policy-gradient backward, where each completion token's row of
// (probs − onehot) is scaled by its own per-token coefficient.
// Block `blockIdx.x` scales its row of the row-major [rows, cols] input by
// s[row]; threads stride across the columns.
__global__ void scale_rows_k(const float* x, const float* s, float* y,
                             int rows, int cols) {
    const int row = blockIdx.x;
    const float factor = s[row];
    const float* src = x + row * cols;
    float* dst = y + row * cols;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dst[c] = src[c] * factor;
}
// Enqueue scale_rows_k: one block per row, min(cols, 1024) threads (at least 32).
// The stream parameter is named `st` because `s` is the per-row scale vector.
void launch_scale_rows_f32(const float* x, const float* s, float* y,
                           int rows, int cols, void* st) {
    int threads = (cols < 1024) ? cols : 1024;
    if (threads < 32) threads = 32;
    cudaStream_t stream = (cudaStream_t)st;
    scale_rows_k<<<rows, threads, 0, stream>>>(x, s, y, rows, cols);
}
// RoPE backward: applies the transposed rotation (angle -> -angle) to the
// upstream gradient. Grid/thread mapping and position rule (tok % period)
// match rope_k; both inputs are loaded before either output is stored.
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                          float theta, int period) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;

    const int pos = tok % period;
    const float angle = (float)pos * powf(theta, -(float)(2 * i) / (float)head_dim);
    const float cs = cosf(angle);
    const float sn = sinf(angle);

    const int off = (tok * heads + head) * head_dim;
    const float* gin = dy + off;
    float* gout = dx + off;
    const float lo = gin[i];
    const float hi = gin[i + half];
    gout[i] = lo * cs + hi * sn;
    gout[i + half] = hi * cs - lo * sn;
}
// Enqueue rope_dx_k: grid (tokens, heads), head_dim / 2 threads per block.
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
                        int head_dim, float theta, int period, void* s) {
    const dim3 grid(tokens, heads);
    const int pairs = head_dim / 2;
    cudaStream_t stream = (cudaStream_t)s;
    rope_dx_k<<<grid, pairs, 0, stream>>>(dy, dx, heads, head_dim, theta, period);
}
// =====================================================================
// Row-wise safe softmax. x:[rows,cols] → y. Backward (Jacobian):
// dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c']*y[r,c']))
// =====================================================================
// Row softmax, one block per row, in three passes:
//   1) block-wide row max (subtracted for a numerically safe exponent);
//   2) exp(x - max) stashed in the output, plus a block-wide sum;
//   3) each thread rescales the output entries it wrote in pass 2.
__global__ void softmax_k(const float* x, float* y, int rows, int cols) {
    const int row = blockIdx.x;
    const float* in = x + row * cols;
    float* out = y + row * cols;

    float mx = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) mx = fmaxf(mx, in[c]);
    mx = block_reduce_max(mx);

    float denom = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        const float e = expf(in[c] - mx);
        out[c] = e;
        denom += e;
    }
    denom = block_reduce_sum(denom);

    const float inv = 1.0f / denom;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) out[c] *= inv;
}
// Enqueue softmax_k: one block per row.
// The block size is min(cols, 1024) rounded UP to a whole number of warps
// (minimum 32). block_reduce_max/sum shuffle with a full 0xffffffff mask, so a
// partial trailing warp (e.g. cols = 77 attention scores) would read lanes that
// were never launched, which is undefined. The surplus threads skip the column
// loops and feed the reduction identities (-inf for max, 0 for sum).
void launch_softmax_f32(const float* x, float* y, int rows, int cols, void* s) {
    if (rows <= 0) return;  // gridDim.x == 0 is an invalid launch configuration
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;  // whole warps only; stays <= 1024
    if (blk < 32) blk = 32;
    softmax_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, rows, cols);
}
// Softmax Jacobian-vector product, one block per row: reduce <dy, y> over the
// row, then dx = y * (dy - <dy, y>).
__global__ void softmax_dx_k(const float* y, const float* dy, float* dx,
                             int rows, int cols) {
    const int row = blockIdx.x;
    const float* prob = y + row * cols;
    const float* grad = dy + row * cols;
    float* out = dx + row * cols;

    float inner = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) inner += grad[c] * prob[c];
    inner = block_reduce_sum(inner);

    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        out[c] = prob[c] * (grad[c] - inner);
}
// Enqueue softmax_dx_k: one block per row.
// The block size is min(cols, 1024) rounded UP to a whole number of warps
// (minimum 32), because block_reduce_sum uses full-mask warp shuffles and a
// partial trailing warp would read never-launched lanes. Surplus threads skip
// the column loops and contribute 0 to the reduction.
void launch_softmax_dx_f32(const float* y, const float* dy, float* dx,
                           int rows, int cols, void* s) {
    if (rows <= 0) return;  // gridDim.x == 0 is an invalid launch configuration
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;  // whole warps only; stays <= 1024
    if (blk < 32) blk = 32;
    softmax_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(y, dy, dx, rows, cols);
}
// =====================================================================
// Cross-entropy over logits x:[rows,cols] with an int target per row.
// Forward writes per-row loss[r] = -log(softmax(x)[target]) and the softmax
// probs[r,:] (cached for backward). Backward: dx[r,c] = scale*(probs[r,c]-onehot),
// where the caller passes scale = upstream/rows (mean reduction: the 1/rows of
// the mean loss is folded into dx). Rows whose target is negative are ignored:
// their loss and their gradient are 0.
// =====================================================================
// Cross-entropy forward, one block per row r (blockDim.x must be a multiple of
// 32, as required by block_reduce_max/sum). Writes softmax(x[r,:]) into
// probs[r,:] and -log softmax(x[r,:])[target[r]] into loss[r]; a negative
// target marks an ignored row (loss 0; probs are still written).
__global__ void cross_entropy_fwd_k(const float* x, const int* target,
                                    float* probs, float* loss, int rows, int cols) {
    const int r = blockIdx.x;
    // 64-bit row offset: rows*cols exceeds INT_MAX for LM-sized logits
    // (e.g. 16384 tokens x 151936 vocab), which would overflow an int product.
    const size_t row_off = (size_t)r * (size_t)cols;
    const float* xr = x + row_off;
    float* pr = probs + row_off;

    // Row max, subtracted for a numerically safe exponent.
    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);

    // Unnormalized exponentials (stashed in probs) and their block-wide sum.
    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float e = expf(xr[c] - m);
        pr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);
    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;

    if (threadIdx.x == 0) {
        int t = target[r];
        // -log softmax(x)[t] = log(sum) - (x[t] - m), computed from the logits and
        // the block-broadcast sum instead of reading pr[t]. pr[t] is rescaled by
        // thread (t % blockDim.x) with no barrier before this point, so reading
        // it here races and can observe the unnormalized value. Also,
        // -logf(pr[t]) becomes +inf once the target probability underflows to 0,
        // while this log-sum-exp form stays finite.
        loss[r] = t < 0 ? 0.0f : logf(sum) - (xr[t] - m);
    }
}
// Enqueue cross_entropy_fwd_k: one block per row.
// The block size is min(cols, 1024) rounded UP to a whole number of warps
// (minimum 32). block_reduce_max/sum shuffle with a full 0xffffffff mask, so a
// partial trailing warp would read lanes that were never launched (undefined).
// Surplus threads skip the column loops and feed the reduction identities.
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
                                  float* probs, float* loss, int rows, int cols, void* s) {
    if (rows <= 0) return;  // gridDim.x == 0 is an invalid launch configuration
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;  // whole warps only; stays <= 1024
    if (blk < 32) blk = 32;
    cross_entropy_fwd_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, target, probs, loss, rows, cols);
}
// dx[r,c] = scale * (probs[r,c] - [c==target]). scale = upstream/rows.
// Cross-entropy backward over the flat row-major [rows, cols] buffer:
// dx[r,c] = scale * (probs[r,c] - [c == target[r]]), or 0 when target[r] < 0.
// A grid-stride loop over a 64-bit flat index keeps this correct for any grid
// size and for rows*cols beyond INT_MAX (LM-sized logits), where the previous
// int index and int rows*cols bound overflowed.
__global__ void cross_entropy_dx_k(const float* probs, const int* target,
                                   float* dx, int rows, int cols, float scale) {
    if (rows <= 0 || cols <= 0) return;
    const size_t ncols = (size_t)cols;
    const size_t total = (size_t)rows * ncols;
    const size_t stride = (size_t)gridDim.x * blockDim.x;
    for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < total; i += stride) {
        const int r = (int)(i / ncols);
        const int c = (int)(i - (size_t)r * ncols);
        const int t = target[r];
        if (t < 0) {
            dx[i] = 0.0f;  // ignored row: no gradient
        } else {
            const float g = probs[i] - (c == t ? 1.0f : 0.0f);
            dx[i] = g * scale;
        }
    }
}
// Enqueue cross_entropy_dx_k with 256-thread blocks covering rows*cols elements.
// The element count is computed in 64 bits: `int n = rows * cols` overflows
// for LM-sized logits (e.g. 16384 x 151936). The grid is clamped to the
// gridDim.x limit (2^31 - 1). An empty problem launches nothing, since a
// zero-sized grid is an invalid launch configuration.
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
                                 float* dx, int rows, int cols, float scale, void* s) {
    const long long n = (long long)rows * (long long)cols;
    if (rows <= 0 || cols <= 0 || n <= 0) return;
    const int blk = 256;
    long long grid = (n + blk - 1) / blk;
    if (grid > 2147483647LL) grid = 2147483647LL;  // max gridDim.x
    cross_entropy_dx_k<<<(unsigned int)grid, blk, 0, (cudaStream_t)s>>>(probs, target, dx, rows, cols, scale);
}
} // extern "C"