Files
xtrain/csrc/ops/nn.cu
Gahow Wang 3a3425960c post-train: M2c — device-side KV cache (cat_seq), profile-first bottleneck shift
Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token
per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b
per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01.
Both single-seq and batched decode refactored to it; cache is Option<Tensor>
per layer (cleaner than the host Vec + rebuild).

Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch
G-way both still TOKEN-IDENTICAL; GQA training path unaffected.

Honest measurement (the point): removing the host round-trip buys ~10% on pure
single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step
(~8.5 s/step unchanged) — because after M2b batching the rollout is no longer
the step's bottleneck; the per-sample per_token_logp captures + the PG-update
forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson
(cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next
decode lever (ragged batched prefill) targets those, not the cache.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:38:16 +08:00

462 lines
19 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Forward + backward CUDA kernels for the transformer ops the autograd engine
// (Phase T4) needs: elementwise add/mul, broadcast bias add + its row-sum
// backward, RMSNorm, SiLU, RoPE, row-wise softmax, and cross-entropy.
//
// All F32, row-major, contiguous. Forward kernels mirror xserv
// (docs/04-transformer-kernels.md, docs/05-attention.md); the backward kernels
// are new (xserv is inference-only). Reduction helpers are inlined here so this
// file is self-contained (no shared header), matching the existing csrc/ layout.
#include <math.h>
extern "C" {
// --- Warp / block reductions (sum + max), block handles one row ---
// Warp-level sum by a shuffle-down tree (offsets 16, 8, 4, 2, 1). The shuffle
// mask is the full 0xffffffff, so all 32 lanes of the warp must be live. Only
// lane 0 ends up with the complete sum; the other lanes hold partial sums.
__device__ __forceinline__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float other = __shfl_down_sync(0xffffffffu, v, delta);
        v = v + other;
    }
    return v;
}
// Warp-level max by a shuffle-down tree (offsets 16, 8, 4, 2, 1). The shuffle
// mask is the full 0xffffffff, so all 32 lanes of the warp must be live. Only
// lane 0 is guaranteed to hold the maximum over the whole warp.
__device__ __forceinline__ float warp_reduce_max(float v) {
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float other = __shfl_down_sync(0xffffffffu, v, delta);
        v = fmaxf(v, other);
    }
    return v;
}
// Block-wide sum; every thread of the block gets the block total back.
// It contains __syncthreads(), so every thread of the block must call it.
// Two levels:
//   1. Each warp reduces with warp_reduce_sum.
//   2. Lane 0 of each warp parks its partial in shared[warp], and warp 0
//      reduces those partials.
// Preconditions (not checked here):
//   - blockDim.x <= 1024, since shared[] holds at most 32 warp partials.
//   - blockDim.x is a multiple of 32. warp_reduce_sum shuffles with the full
//     0xffffffff mask, so the missing lanes of a partial trailing warp would
//     feed undefined values into that warp's partial.
__device__ __forceinline__ float block_reduce_sum(float v) {
__shared__ float shared[32];
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = warp_reduce_sum(v);
if (lane == 0) shared[warp] = v;
__syncthreads(); // every warp partial is now visible to warp 0
// Slots beyond the number of warps contribute the additive identity.
v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : 0.0f;
if (warp == 0) v = warp_reduce_sum(v);
// broadcast warp-0 lane-0 result to whole block
__shared__ float bcast;
if (threadIdx.x == 0) bcast = v;
__syncthreads();
return bcast;
}
// Block-wide max; every thread of the block gets the block maximum back.
// It contains __syncthreads(), so every thread of the block must call it.
// Two levels:
//   1. Each warp reduces with warp_reduce_max.
//   2. Lane 0 of each warp parks its partial in shared[warp], and warp 0
//      reduces those partials, using -INFINITY as the identity for unused slots.
// Preconditions (not checked here):
//   - blockDim.x <= 1024, since shared[] holds at most 32 warp partials.
//   - blockDim.x is a multiple of 32. warp_reduce_max shuffles with the full
//     0xffffffff mask, so a partial trailing warp would mix undefined lanes
//     into its partial.
__device__ __forceinline__ float block_reduce_max(float v) {
__shared__ float shared[32];
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = warp_reduce_max(v);
if (lane == 0) shared[warp] = v;
__syncthreads(); // every warp partial is now visible to warp 0
v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : -INFINITY;
if (warp == 0) v = warp_reduce_max(v);
// broadcast warp-0 lane-0 result to whole block
__shared__ float bcast;
if (threadIdx.x == 0) bcast = v;
__syncthreads();
return bcast;
}
// =====================================================================
// Elementwise add / mul (same-shape)
// =====================================================================
// Elementwise out[i] = a[i] + b[i] for i in [0, n), one thread per element.
__global__ void add_k(const float* a, const float* b, float* out, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;  // tail threads of the last block
    out[idx] = a[idx] + b[idx];
}
// Launches add_k on stream `s` with 256-thread blocks covering all n elements.
void launch_add_f32(const float* a, const float* b, float* out, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    add_k<<<blocks, threads, 0, (cudaStream_t)s>>>(a, b, out, n);
}
// Elementwise out[i] = a[i] * b[i] for i in [0, n), one thread per element.
__global__ void mul_k(const float* a, const float* b, float* out, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;  // tail threads of the last block
    out[idx] = a[idx] * b[idx];
}
// Launches mul_k on stream `s` with 256-thread blocks covering all n elements.
void launch_mul_f32(const float* a, const float* b, float* out, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    mul_k<<<blocks, threads, 0, (cudaStream_t)s>>>(a, b, out, n);
}
// =====================================================================
// Broadcast bias add: out[r,c] = x[r,c] + bias[c] (x:[rows,cols])
// Backward for bias is a column-sum (sum over rows): dbias[c] = sum_r dout[r,c].
// =====================================================================
// Broadcast bias add on a row-major x:[rows,cols]: out[r,c] = x[r,c] + bias[c].
// One thread per element; the column is recovered as flat_index % cols.
__global__ void add_bias_k(const float* x, const float* bias, float* out,
                           int rows, int cols) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int total = rows * cols;
    if (idx < total) {
        const int col = idx % cols;
        out[idx] = x[idx] + bias[col];
    }
}
// Launches add_bias_k on stream `s` with 256-thread blocks over rows*cols elements.
void launch_add_bias_f32(const float* x, const float* bias, float* out,
                         int rows, int cols, void* s) {
    const int total = rows * cols;
    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;
    add_bias_k<<<blocks, threads, 0, (cudaStream_t)s>>>(x, bias, out, rows, cols);
}
// Bias gradient (column sum): dbias[c] = sum over r of dout[r,c], where dout is
// a row-major [rows,cols] array. The grid has one block per column (cols
// blocks) of 256 threads. Each thread accumulates a strided subset of rows,
// then block_reduce_sum combines the partials.
__global__ void sum_rows_k(const float* dout, float* dbias, int rows, int cols) {
    const int col = blockIdx.x;
    float partial = 0.0f;
    int r = threadIdx.x;
    while (r < rows) {
        partial += dout[r * cols + col];
        r += blockDim.x;
    }
    const float total = block_reduce_sum(partial);
    if (threadIdx.x == 0) dbias[col] = total;
}
// Launches one 256-thread block per column on stream `s`.
void launch_sum_rows_f32(const float* dout, float* dbias, int rows, int cols, void* s) {
    const int threads = 256;
    sum_rows_k<<<cols, threads, 0, (cudaStream_t)s>>>(dout, dbias, rows, cols);
}
// =====================================================================
// RMSNorm: y[r,c] = x[r,c] * inv_rms[r] * gamma[c], inv_rms = rsqrt(mean(x²)+eps)
// x:[rows,cols], gamma:[cols]. Forward also writes inv_rms[rows] for backward.
// =====================================================================
// RMSNorm forward on a row-major x:[rows,cols] with gamma:[cols], one block per
// row r:
//   inv_rms[r] = rsqrt(mean_c(x[r,c]^2) + eps)
//   y[r,c]     = x[r,c] * inv_rms[r] * gamma[c]
// inv_rms[rows] is written for the backward kernels. Threads stride over the
// columns, so any block size works as long as block_reduce_sum's precondition
// holds (a whole number of warps, at most 1024 threads); the launcher ensures it.
__global__ void rms_norm_k(const float* x, const float* gamma, float* y,
                           float* inv_rms, int rows, int cols, float eps) {
    int r = blockIdx.x;
    // 64-bit row offset: an int r * cols overflows once rows*cols > INT_MAX.
    const size_t off = (size_t)r * (size_t)cols;
    const float* xr = x + off;
    float* yr = y + off;
    float ss = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) ss += xr[c] * xr[c];
    ss = block_reduce_sum(ss);
    float ir = rsqrtf(ss / cols + eps);
    if (threadIdx.x == 0) inv_rms[r] = ir;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        yr[c] = xr[c] * ir * gamma[c];
}
// Host launcher: grid = rows, block = min(cols, 1024) rounded UP to a multiple
// of 32 (and at least 32).
// Why round up: block_reduce_sum shuffles with a full warp mask, so a partial
// trailing warp (e.g. cols = 100 -> 100 threads) would read undefined lanes.
// The extra threads fall outside the column loops and contribute 0.
// rows <= 0 is a no-op rather than an invalid 0-block launch.
void launch_rms_norm_f32(const float* x, const float* gamma, float* y,
                         float* inv_rms, int rows, int cols, float eps, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;
    if (blk < 32) blk = 32;
    rms_norm_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, y, inv_rms, rows, cols, eps);
}
// RMSNorm backward w.r.t. x, one block per row r. Define:
//   g[c] = dy[r,c] * gamma[c]
//   ir   = inv_rms[r]  (saved by the forward)
//   n    = cols
// Then:
//   dx[r,c] = ir*g[c] - x[r,c] * ir^3/n * sum_c'(g[c']*x[r,c'])
// dgamma[c] = sum_r dy[r,c]*x[r,c]*ir is computed separately (rms_norm_dgamma_k).
__global__ void rms_norm_dx_k(const float* x, const float* gamma, const float* dy,
                              const float* inv_rms, float* dx, int rows, int cols) {
    int r = blockIdx.x;
    // 64-bit row offset: an int r * cols overflows once rows*cols > INT_MAX.
    const size_t off = (size_t)r * (size_t)cols;
    const float* xr = x + off;
    const float* dyr = dy + off;
    float* dxr = dx + off;
    float ir = inv_rms[r];
    float dot = 0.0f;  // sum_c g[c]*x[c]
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dot += dyr[c] * gamma[c] * xr[c];
    dot = block_reduce_sum(dot);
    float coeff = ir * ir * ir / (float)cols * dot;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dxr[c] = ir * dyr[c] * gamma[c] - xr[c] * coeff;
}
// Host launcher: grid = rows, block = min(cols, 1024) rounded UP to a multiple
// of 32. block_reduce_sum shuffles with a full warp mask, so a partial trailing
// warp would read undefined lanes; the padding threads contribute 0.
// rows <= 0 is a no-op instead of an invalid 0-block launch.
void launch_rms_norm_dx_f32(const float* x, const float* gamma, const float* dy,
                            const float* inv_rms, float* dx, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;
    if (blk < 32) blk = 32;
    rms_norm_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, dy, inv_rms, dx, rows, cols);
}
// RMSNorm gain gradient: dgamma[c] = sum over r of dy[r,c] * x[r,c] * inv_rms[r],
// where x and dy are row-major [rows,cols]. The grid has one block per column
// (cols blocks) of 256 threads. Each thread sums a strided subset of rows, then
// block_reduce_sum combines the partials.
__global__ void rms_norm_dgamma_k(const float* x, const float* dy, const float* inv_rms,
                                  float* dgamma, int rows, int cols) {
    const int col = blockIdx.x;
    float partial = 0.0f;
    int r = threadIdx.x;
    while (r < rows) {
        const int idx = r * cols + col;
        partial += dy[idx] * x[idx] * inv_rms[r];
        r += blockDim.x;
    }
    const float total = block_reduce_sum(partial);
    if (threadIdx.x == 0) dgamma[col] = total;
}
// Launches one 256-thread block per column on stream `s`.
void launch_rms_norm_dgamma_f32(const float* x, const float* dy, const float* inv_rms,
                                float* dgamma, int rows, int cols, void* s) {
    const int threads = 256;
    rms_norm_dgamma_k<<<cols, threads, 0, (cudaStream_t)s>>>(x, dy, inv_rms, dgamma, rows, cols);
}
// =====================================================================
// SiLU: y = x * sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
// =====================================================================
// SiLU forward: y[i] = x[i] * sigmoid(x[i]), evaluated as x / (1 + exp(-x)).
// One thread per element.
__global__ void silu_k(const float* x, float* y, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    const float v = x[idx];
    const float denom = 1.0f + expf(-v);
    y[idx] = v / denom;
}
// Launches silu_k on stream `s` with 256-thread blocks covering all n elements.
void launch_silu_f32(const float* x, float* y, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    silu_k<<<blocks, threads, 0, (cudaStream_t)s>>>(x, y, n);
}
// SiLU backward: with sig = sigmoid(x),
//   dx = dy * (sig + x * sig * (1 - sig)).
// One thread per element.
__global__ void silu_dx_k(const float* x, const float* dy, float* dx, int n) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= n) return;
    const float v = x[idx];
    const float sig = 1.0f / (1.0f + expf(-v));
    const float dsilu = sig + v * sig * (1.0f - sig);
    dx[idx] = dy[idx] * dsilu;
}
// Launches silu_dx_k on stream `s` with 256-thread blocks covering all n elements.
void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void* s) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    silu_dx_k<<<blocks, threads, 0, (cudaStream_t)s>>>(x, dy, dx, n);
}
// =====================================================================
// RoPE (rotate_half layout). x:[tokens, heads, head_dim]; position = token index.
// y[i] = x[i]*cos - x[i+h]*sin
// y[i+h] = x[i+h]*cos + x[i]*sin (i in [0,half), h=half_dim)
// freq[i] = theta^(-2i/head_dim); angle = pos*freq[i].
// Backward is the inverse (transpose) rotation: apply +angle's transpose ≡ -angle.
// dx[i] = dy[i]*cos + dy[i+h]*sin
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
// =====================================================================
// RoPE forward, rotate_half layout, on x:[tokens, heads, head_dim].
//   - Launch: grid (tokens, heads) with head_dim/2 threads. Thread i rotates
//     the pair (i, i + half).
//   - `period` is the per-sequence length. A flattened batch lays B sequences
//     end to end along `tokens`, so a token's position is its index within its
//     own sequence, tok % period.
//   - period == tokens (a single sequence) gives position == row.
// The pair is read before either output is written, so x == y (in-place) is safe.
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
                       float theta, int period) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const float inv_freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)(tok % period) * inv_freq;
    const float cs = cosf(angle);
    const float sn = sinf(angle);
    const int base = (tok * heads + head) * head_dim;
    const float lo = x[base + i];
    const float hi = x[base + i + half];
    y[base + i] = lo * cs - hi * sn;
    y[base + i + half] = hi * cs + lo * sn;
}
// Launches rope_k on stream `s`: one block per (token, head), head_dim/2 threads.
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
                     int head_dim, float theta, int period, void* s) {
    const int threads = head_dim / 2;
    rope_k<<<dim3(tokens, heads), threads, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}
// RoPE at an absolute position offset, for KV-cache decode (forward only).
//   - Same rotate_half math as rope_k, but row `tok` sits at absolute position
//     pos0 + tok, with no modulo. A single new decode token sits at position pos0.
//   - The training-path rope_k (position = tok % period) is left untouched.
//   - Launch: grid (tokens, heads) with head_dim/2 threads. Thread i rotates
//     the pair (i, i + half).
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
                          float theta, int pos0) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const float inv_freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)(pos0 + tok) * inv_freq;
    const float cs = cosf(angle);
    const float sn = sinf(angle);
    const int base = (tok * heads + head) * head_dim;
    const float lo = x[base + i];
    const float hi = x[base + i + half];
    y[base + i] = lo * cs - hi * sn;
    y[base + i + half] = hi * cs + lo * sn;
}
// Launches rope_at_k on stream `s`: one block per (token, head), head_dim/2 threads.
void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
                        int head_dim, float theta, int pos0, void* s) {
    const int threads = head_dim / 2;
    rope_at_k<<<dim3(tokens, heads), threads, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
}
// RoPE with a per-row absolute position, for batched KV-cache decode (M2b).
//   - Row `tok` is rotated to position positions[tok], an i32 per token. For
//     G-way batched decode all G rows share one position; ragged batches may
//     give each row its own.
//   - Forward only; the training-path rope_k is untouched.
//   - Launch: grid (tokens, heads) with head_dim/2 threads.
__global__ void rope_pos_k(const float* x, const int* positions, float* y,
                           int heads, int head_dim, float theta) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const float inv_freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)positions[tok] * inv_freq;
    const float cs = cosf(angle);
    const float sn = sinf(angle);
    const int base = (tok * heads + head) * head_dim;
    const float lo = x[base + i];
    const float hi = x[base + i + half];
    y[base + i] = lo * cs - hi * sn;
    y[base + i + half] = hi * cs + lo * sn;
}
// Launches rope_pos_k on stream `s`: one block per (token, head), head_dim/2 threads.
void launch_rope_pos_f32(const float* x, const int* positions, float* y,
                         int tokens, int heads, int head_dim, float theta, void* s) {
    const int threads = head_dim / 2;
    rope_pos_k<<<dim3(tokens, heads), threads, 0, (cudaStream_t)s>>>(x, positions, y, heads, head_dim, theta);
}
// Concatenate along the sequence (middle) dim:
//   a:[bh,ta,hd] ++ b:[bh,tb,hd] -> out:[bh,ta+tb,hd]
//   out[:, :ta] = a and out[:, ta:] = b.
// ta_hd = ta*hd and tb_hd = tb*hd are the per-row element counts.
// This is the device-side KV-cache append (M2c): K/V stay on the GPU and grow
// by one token per step, removing the host round-trip the M2a/M2b host cache paid.
// Launch: one block per bh row. Row offsets are computed in `long` to avoid int
// overflow on large caches.
__global__ void cat_seq_k(const float* a, const float* b, float* out,
                          int ta_hd, int tb_hd) {
    const long row = blockIdx.x;
    const float* src_a = a + row * ta_hd;
    const float* src_b = b + row * tb_hd;
    float* dst = out + row * (ta_hd + tb_hd);
    for (int j = threadIdx.x; j < ta_hd; j += blockDim.x) dst[j] = src_a[j];
    float* dst_tail = dst + ta_hd;  // where b's row begins inside out's row
    for (int j = threadIdx.x; j < tb_hd; j += blockDim.x) dst_tail[j] = src_b[j];
}
// Launches cat_seq_k on stream `s`: bh blocks of 256 threads.
void launch_cat_seq_f32(const float* a, const float* b, float* out,
                        int bh, int ta_hd, int tb_hd, void* s) {
    const int threads = 256;
    cat_seq_k<<<bh, threads, 0, (cudaStream_t)s>>>(a, b, out, ta_hd, tb_hd);
}
// Per-row scale on a row-major x:[rows,cols]: y[r,c] = x[r,c] * s[r].
// One block per row, with threads striding across the columns (no reduction,
// so any block size is fine).
// Used by the GRPO (M4) policy-gradient backward, where each completion token's
// row of (probs - onehot) is scaled by its own per-token coefficient.
__global__ void scale_rows_k(const float* x, const float* s, float* y,
                             int rows, int cols) {
    const int r = blockIdx.x;
    const float factor = s[r];
    int c = threadIdx.x;
    while (c < cols) {
        y[r * cols + c] = x[r * cols + c] * factor;
        c += blockDim.x;
    }
}
// Launches one block per row on stream `st`, with min(cols, 1024) threads
// clamped to at least 32.
void launch_scale_rows_f32(const float* x, const float* s, float* y,
                           int rows, int cols, void* st) {
    int threads = (cols < 1024) ? cols : 1024;
    threads = (threads < 32) ? 32 : threads;
    scale_rows_k<<<rows, threads, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
}
// RoPE backward: the transpose (inverse) of the rope_k rotation, i.e. rotation
// by -angle, with the same position rule pos = tok % period:
//   dx[i]      = dy[i]*cos + dy[i+half]*sin
//   dx[i+half] = dy[i+half]*cos - dy[i]*sin
// Launch: grid (tokens, heads) with head_dim/2 threads.
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                          float theta, int period) {
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const float inv_freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)(tok % period) * inv_freq;
    const float cs = cosf(angle);
    const float sn = sinf(angle);
    const int base = (tok * heads + head) * head_dim;
    const float g_lo = dy[base + i];
    const float g_hi = dy[base + i + half];
    dx[base + i] = g_lo * cs + g_hi * sn;
    dx[base + i + half] = g_hi * cs - g_lo * sn;
}
// Launches rope_dx_k on stream `s`: one block per (token, head), head_dim/2 threads.
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
                        int head_dim, float theta, int period, void* s) {
    const int threads = head_dim / 2;
    rope_dx_k<<<dim3(tokens, heads), threads, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
}
// =====================================================================
// Row-wise safe softmax. x:[rows,cols] → y. Backward (Jacobian):
// dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c']*y[r,c']))
// =====================================================================
// Row-wise numerically safe softmax on a row-major x:[rows,cols] -> y, one block
// per row: subtract the row max, exponentiate, then normalize by the row sum.
// y[r,:] temporarily holds the unnormalized exponentials. Each thread rescales
// exactly the columns it wrote, so no barrier is needed before the final loop.
// NOTE: a row whose entries are all -INFINITY yields NaN (m = -inf, so
// x - m is NaN), matching the previous behavior.
__global__ void softmax_k(const float* x, float* y, int rows, int cols) {
    int r = blockIdx.x;
    // 64-bit row offset: an int r * cols overflows once rows*cols > INT_MAX.
    const size_t off = (size_t)r * (size_t)cols;
    const float* xr = x + off;
    float* yr = y + off;
    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);
    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float e = expf(xr[c] - m);
        yr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);
    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) yr[c] *= inv;
}
// Host launcher: grid = rows, block = min(cols, 1024) rounded UP to a multiple
// of 32. block_reduce_max/sum shuffle with a full warp mask, so a partial
// trailing warp (e.g. cols = 50) would read undefined lanes. The padding threads
// contribute the reduction identities (-inf / 0).
// rows <= 0 is a no-op instead of an invalid 0-block launch.
void launch_softmax_f32(const float* x, float* y, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;
    if (blk < 32) blk = 32;
    softmax_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, rows, cols);
}
// Softmax backward (Jacobian-vector product), one block per row:
//   dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c'] * y[r,c']))
// y is the saved forward output; all arrays are row-major [rows,cols].
__global__ void softmax_dx_k(const float* y, const float* dy, float* dx,
                             int rows, int cols) {
    int r = blockIdx.x;
    // 64-bit row offset: an int r * cols overflows once rows*cols > INT_MAX.
    const size_t off = (size_t)r * (size_t)cols;
    const float* yr = y + off;
    const float* dyr = dy + off;
    float* dxr = dx + off;
    float dot = 0.0f;  // sum_c dy*y
    for (int c = threadIdx.x; c < cols; c += blockDim.x) dot += dyr[c] * yr[c];
    dot = block_reduce_sum(dot);
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dxr[c] = yr[c] * (dyr[c] - dot);
}
// Host launcher: grid = rows, block = min(cols, 1024) rounded UP to a multiple
// of 32. block_reduce_sum shuffles with a full warp mask, so a partial trailing
// warp would read undefined lanes; the padding threads contribute 0.
// rows <= 0 is a no-op instead of an invalid 0-block launch.
void launch_softmax_dx_f32(const float* y, const float* dy, float* dx,
                           int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;
    if (blk < 32) blk = 32;
    softmax_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(y, dy, dx, rows, cols);
}
// =====================================================================
// Cross-entropy over logits x:[rows,cols] with int target per row.
// Forward writes per-row loss[r] = -log(softmax(x)[target]) and the softmax
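// probs[r,:] (cached for backward); rows with target < 0 are ignored (loss 0).
// Backward: dx[r,c] = scale * (probs[r,c] - onehot[r,c]) with scale = upstream/rows
// for a mean reduction (the 1/rows of the mean loss is folded into dx), and
// dx[r,:] = 0 for ignored rows.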
// =====================================================================
// Cross-entropy forward, one block per row of the row-major logits x:[rows,cols].
// Outputs:
//   - probs[r,:] = softmax(x[r,:]), cached for the backward.
//   - loss[r] = -log(probs[r, target[r]]), or 0 when target[r] < 0 (ignored row).
// The loss is computed in log-sum-exp form, log(sum) - (x[t] - m), directly from
// the logits. This avoids two problems with reading probs back:
//   - A race: the normalized pr[t] is written by thread (t % blockDim.x) in the
//     final loop, with no barrier before thread 0 would read it.
//   - Overflow: -logf of a probability that underflowed to 0 gives +inf.
// Precondition (not checked): 0 <= target[r] < cols for non-ignored rows.
__global__ void cross_entropy_fwd_k(const float* x, const int* target,
                                    float* probs, float* loss, int rows, int cols) {
    int r = blockIdx.x;
    // 64-bit row offset: tokens * vocab easily exceeds INT_MAX.
    const size_t off = (size_t)r * (size_t)cols;
    const float* xr = x + off;
    float* pr = probs + off;
    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);
    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float e = expf(xr[c] - m);
        pr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);
    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;
    if (threadIdx.x == 0) {
        int t = target[r];
        // -log(e_t / sum) = log(sum) - (x_t - m); m and sum are block-uniform.
        loss[r] = t < 0 ? 0.0f : logf(sum) - (xr[t] - m);
    }
}
// Host launcher: grid = rows, block = min(cols, 1024) rounded UP to a multiple
// of 32. block_reduce_max/sum shuffle with a full warp mask, so a partial
// trailing warp would read undefined lanes. The padding threads contribute the
// reduction identities.
// rows <= 0 is a no-op instead of an invalid 0-block launch.
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
                                  float* probs, float* loss, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) / 32 * 32;
    if (blk < 32) blk = 32;
    cross_entropy_fwd_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, target, probs, loss, rows, cols);
}
// Cross-entropy backward, one thread per element of the row-major [rows,cols]:
//   dx[r,c] = scale * (probs[r,c] - [c == target[r]]),  scale = upstream/rows
// dx[r,:] = 0 for ignored rows (target[r] < 0).
// Indexing is 64-bit: rows*cols (tokens x vocab) can exceed INT_MAX, which made
// the former int flat index and bound overflow.
__global__ void cross_entropy_dx_k(const float* probs, const int* target,
                                   float* dx, int rows, int cols, float scale) {
    const size_t n = (size_t)rows * (size_t)cols;
    const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const int r = (int)(i / (size_t)cols);
    const int c = (int)(i % (size_t)cols);
    int t = target[r];
    if (t < 0) {
        dx[i] = 0.0f;
    } else {
        float g = probs[i] - (c == t ? 1.0f : 0.0f);
        dx[i] = g * scale;
    }
}
// Host launcher: 256-thread blocks covering rows*cols elements on stream `s`.
// An empty input (rows or cols <= 0) is a no-op instead of an invalid 0-block launch.
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
                                 float* dx, int rows, int cols, float scale, void* s) {
    if (rows <= 0 || cols <= 0) return;
    const size_t n = (size_t)rows * (size_t)cols;
    const int blk = 256;
    // (n + 255) / 256 stays far below the 2^31-1 grid.x limit for int rows/cols
    // of realistic size.
    const unsigned grid = (unsigned)((n + blk - 1) / blk);
    cross_entropy_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(probs, target, dx, rows, cols, scale);
}
} // extern "C"