Files
xtrain/csrc/ops/nn.cu
Gahow Wang fbf4ac2917 sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:

- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
  valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
  formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
  masked to -100 while answer+EOS are supervised; i32 label cache alongside the
  u16 token cache; sample() retries windows that are fully masked; eval uses
  target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
  training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
  generation eval.

Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 16:19:02 +08:00

372 lines
15 KiB
Plaintext

// Forward + backward CUDA kernels for the transformer ops the autograd engine
// (Phase T4) needs: elementwise add/mul, broadcast bias add + its row-sum
// backward, RMSNorm, SiLU, RoPE, row-wise softmax, and cross-entropy.
//
// All F32, row-major, contiguous. Forward kernels mirror xserv
// (docs/04-transformer-kernels.md, docs/05-attention.md); the backward kernels
// are new (xserv is inference-only). Reduction helpers are inlined here so this
// file is self-contained (no shared header), matching the existing csrc/ layout.
#include <math.h>
extern "C" {
// --- Warp / block reductions (sum + max), block handles one row ---
// Sums `v` over the 32 lanes of the calling warp with a shuffle-down tree
// (offsets 16, 8, 4, 2, 1). Lane 0 ends up with the complete sum; the other
// lanes hold partial results. Every lane named by the full mask must call it.
__device__ __forceinline__ float warp_reduce_sum(float v) {
    const unsigned full_mask = 0xffffffffu;
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float neighbour = __shfl_down_sync(full_mask, v, delta);
        v = v + neighbour;
    }
    return v;
}
// Maximum of `v` over the 32 lanes of the calling warp, using the same
// shuffle-down tree as warp_reduce_sum. Lane 0 ends up with the warp maximum;
// the other lanes hold partial results.
__device__ __forceinline__ float warp_reduce_max(float v) {
    const unsigned full_mask = 0xffffffffu;
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float neighbour = __shfl_down_sync(full_mask, v, delta);
        v = fmaxf(v, neighbour);
    }
    return v;
}
// Block-wide sum of `v` along threadIdx.x.
//   1. each warp reduces its own lanes (warp_reduce_sum),
//   2. lane 0 of every warp parks its partial in shared memory,
//   3. warp 0 reduces the per-warp partials,
//   4. thread 0 publishes the total so every thread returns the same value.
// Contains two __syncthreads(), so all threads of the block must call it.
__device__ __forceinline__ float block_reduce_sum(float v) {
    __shared__ float partials[32];  // one slot per warp, sized for 32 warps
    __shared__ float total;         // broadcast slot for the final result

    const int tid = threadIdx.x;
    const int lane_id = tid % 32;
    const int warp_id = tid / 32;
    const int warp_count = (blockDim.x + 31) / 32;

    const float warp_total = warp_reduce_sum(v);
    if (lane_id == 0) partials[warp_id] = warp_total;
    __syncthreads();

    // Threads without a matching partial feed the additive identity.
    float acc = 0.0f;
    if (tid < warp_count) acc = partials[tid];
    if (warp_id == 0) acc = warp_reduce_sum(acc);

    if (tid == 0) total = acc;
    __syncthreads();
    return total;
}
// Block-wide maximum of `v` along threadIdx.x; same two-stage scheme as
// block_reduce_sum (per-warp reduce, partials in shared memory, warp 0 reduces
// the partials, thread 0 broadcasts). Contains two __syncthreads(), so all
// threads of the block must call it.
__device__ __forceinline__ float block_reduce_max(float v) {
    __shared__ float partials[32];  // one slot per warp, sized for 32 warps
    __shared__ float result;        // broadcast slot for the final maximum

    const int tid = threadIdx.x;
    const int lane_id = tid % 32;
    const int warp_id = tid / 32;
    const int warp_count = (blockDim.x + 31) / 32;

    const float warp_best = warp_reduce_max(v);
    if (lane_id == 0) partials[warp_id] = warp_best;
    __syncthreads();

    // Threads without a matching partial feed the identity of max.
    float best = -INFINITY;
    if (tid < warp_count) best = partials[tid];
    if (warp_id == 0) best = warp_reduce_max(best);

    if (tid == 0) result = best;
    __syncthreads();
    return result;
}
// =====================================================================
// Elementwise add / mul (same-shape)
// =====================================================================
// out[i] = a[i] + b[i] for i in [0, n); one element per thread on a 1-D grid.
__global__ void add_k(const float* a, const float* b, float* out, int n) {
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx >= n) return;  // surplus threads of the last block
    out[idx] = a[idx] + b[idx];
}
// Launches add_k over n elements (256 threads per block) on stream `s`, which
// is cast to cudaStream_t. An empty tensor (n <= 0) is a no-op: a zero-block
// grid is an invalid launch configuration and would leave a CUDA error behind.
void launch_add_f32(const float* a, const float* b, float* out, int n, void* s) {
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;  // ceil(n / blk)
    add_k<<<grid, blk, 0, (cudaStream_t)s>>>(a, b, out, n);
}
// out[i] = a[i] * b[i] for i in [0, n); one element per thread on a 1-D grid.
__global__ void mul_k(const float* a, const float* b, float* out, int n) {
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx >= n) return;  // surplus threads of the last block
    out[idx] = a[idx] * b[idx];
}
// Launches mul_k over n elements (256 threads per block) on stream `s`, which
// is cast to cudaStream_t. An empty tensor (n <= 0) is a no-op: a zero-block
// grid is an invalid launch configuration and would leave a CUDA error behind.
void launch_mul_f32(const float* a, const float* b, float* out, int n, void* s) {
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;  // ceil(n / blk)
    mul_k<<<grid, blk, 0, (cudaStream_t)s>>>(a, b, out, n);
}
// =====================================================================
// Broadcast bias add: out[r,c] = x[r,c] + bias[c] (x:[rows,cols])
// Backward for bias is a column-sum (sum over rows): dbias[c] = sum_r dout[r,c].
// =====================================================================
// out[r,c] = x[r,c] + bias[c] over a flat row-major [rows, cols] buffer.
// One element per thread; the column is recovered as flat_index % cols.
__global__ void add_bias_k(const float* x, const float* bias, float* out,
                           int rows, int cols) {
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    const int total = rows * cols;
    if (idx >= total) return;  // surplus threads of the last block
    const int col = idx % cols;
    out[idx] = x[idx] + bias[col];
}
// Launches add_bias_k over rows*cols elements (256 threads per block) on
// stream `s`, which is cast to cudaStream_t. An empty matrix is a no-op: a
// zero-block grid is an invalid launch configuration and would leave a CUDA
// error behind.
void launch_add_bias_f32(const float* x, const float* bias, float* out,
                         int rows, int cols, void* s) {
    const int n = rows * cols;
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;  // ceil(n / blk)
    add_bias_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, bias, out, rows, cols);
}
// Column sum of a row-major [rows, cols] matrix: dbias[c] = sum_r dout[r,c].
// blockIdx.x selects the column; the block's threads stride over the rows and
// combine their partials with block_reduce_sum. Thread 0 writes the result.
__global__ void sum_rows_k(const float* dout, float* dbias, int rows, int cols) {
    const int column = blockIdx.x;
    const int step = blockDim.x;
    float partial = 0.0f;
    int r = threadIdx.x;
    while (r < rows) {
        partial += dout[r * cols + column];
        r += step;
    }
    const float total = block_reduce_sum(partial);
    if (threadIdx.x == 0) dbias[column] = total;
}
// Launches sum_rows_k with one 256-thread block per column on stream `s`,
// which is cast to cudaStream_t. With no columns there is nothing to write and
// a zero-block grid would be an invalid launch, so that case is a no-op.
// rows == 0 still launches: every dbias[c] is then written as 0.
void launch_sum_rows_f32(const float* dout, float* dbias, int rows, int cols, void* s) {
    if (cols <= 0) return;
    const int blk = 256;
    sum_rows_k<<<cols, blk, 0, (cudaStream_t)s>>>(dout, dbias, rows, cols);
}
// =====================================================================
// RMSNorm: y[r,c] = x[r,c] * inv_rms[r] * gamma[c], inv_rms = rsqrt(mean(x²)+eps)
// x:[rows,cols], gamma:[cols]. Forward also writes inv_rms[rows] for backward.
// =====================================================================
// RMSNorm forward for one row per block (blockIdx.x is the row).
//   inv_rms[r] = rsqrt(mean_c(x[r,c]^2) + eps)
//   y[r,c]     = x[r,c] * inv_rms[r] * gamma[c]
// Threads stride over the columns; the sum of squares is combined with
// block_reduce_sum, so every thread sees the same inverse RMS.
__global__ void rms_norm_k(const float* x, const float* gamma, float* y,
                           float* inv_rms, int rows, int cols, float eps) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;
    const float* x_row = x + row * cols;
    float* y_row = y + row * cols;

    float sq_sum = 0.0f;
    for (int c = tid; c < cols; c += step) {
        sq_sum += x_row[c] * x_row[c];
    }
    sq_sum = block_reduce_sum(sq_sum);

    const float inv = rsqrtf(sq_sum / cols + eps);
    if (tid == 0) inv_rms[row] = inv;  // saved for the backward kernels

    for (int c = tid; c < cols; c += step) {
        y_row[c] = x_row[c] * inv * gamma[c];
    }
}
// Launches rms_norm_k with one block per row on stream `s`, which is cast to
// cudaStream_t.
//
// Block size: min(cols, 1024), at least 32, rounded UP to a multiple of 32.
// The kernel's block reduction shuffles with a full-warp mask; a block whose
// last warp is partial (e.g. cols = 100 -> 100 threads) would read inactive
// lanes, which is undefined. Surplus threads skip the column loops and only
// feed the identity into the reduction.
//
// rows <= 0 is a no-op: a zero-block grid is an invalid launch configuration.
void launch_rms_norm_f32(const float* x, const float* gamma, float* y,
                         float* inv_rms, int rows, int cols, float eps, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    blk = (blk + 31) / 32 * 32;  // whole warps only; 1024 stays 1024
    rms_norm_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, y, inv_rms, rows, cols, eps);
}
// RMSNorm backward.
// Let g[c] = dy[r,c]*gamma[c], ir = inv_rms[r], n = cols.
// dx[r,c] = ir*g[c] - x[r,c]*ir³/n * sum_c(g[c]*x[r,c])
// dgamma[c] = sum_r dy[r,c] * x[r,c] * ir (accumulated across rows)
// RMSNorm input gradient for one row per block (blockIdx.x is the row).
// With g[c] = dy[r,c]*gamma[c], ir = inv_rms[r], n = cols:
//   dx[r,c] = ir*g[c] - x[r,c] * (ir^3 / n) * sum_c(g[c]*x[r,c])
// The row dot product is combined with block_reduce_sum.
__global__ void rms_norm_dx_k(const float* x, const float* gamma, const float* dy,
                              const float* inv_rms, float* dx, int rows, int cols) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;
    const float* x_row = x + row * cols;
    const float* dy_row = dy + row * cols;
    float* dx_row = dx + row * cols;
    const float ir = inv_rms[row];

    // sum_c g[c] * x[c]
    float gx = 0.0f;
    for (int c = tid; c < cols; c += step) {
        gx += dy_row[c] * gamma[c] * x_row[c];
    }
    gx = block_reduce_sum(gx);

    const float coeff = ir * ir * ir / (float)cols * gx;
    for (int c = tid; c < cols; c += step) {
        dx_row[c] = ir * dy_row[c] * gamma[c] - x_row[c] * coeff;
    }
}
// Launches rms_norm_dx_k with one block per row on stream `s`, which is cast
// to cudaStream_t.
//
// Block size: min(cols, 1024), at least 32, rounded UP to a multiple of 32.
// The kernel's block reduction shuffles with a full-warp mask; a partial last
// warp would read inactive lanes, which is undefined. Surplus threads skip the
// column loops and only feed 0 into the reduction.
//
// rows <= 0 is a no-op: a zero-block grid is an invalid launch configuration.
void launch_rms_norm_dx_f32(const float* x, const float* gamma, const float* dy,
                            const float* inv_rms, float* dx, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    blk = (blk + 31) / 32 * 32;  // whole warps only; 1024 stays 1024
    rms_norm_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, dy, inv_rms, dx, rows, cols);
}
// RMSNorm weight gradient: dgamma[c] = sum_r dy[r,c] * x[r,c] * inv_rms[r].
// blockIdx.x selects the column; the block's threads stride over the rows and
// combine their partials with block_reduce_sum. Thread 0 writes the result.
__global__ void rms_norm_dgamma_k(const float* x, const float* dy, const float* inv_rms,
                                  float* dgamma, int rows, int cols) {
    const int column = blockIdx.x;
    const int step = blockDim.x;
    float partial = 0.0f;
    int r = threadIdx.x;
    while (r < rows) {
        const int idx = r * cols + column;
        partial += dy[idx] * x[idx] * inv_rms[r];
        r += step;
    }
    const float total = block_reduce_sum(partial);
    if (threadIdx.x == 0) dgamma[column] = total;
}
// Launches rms_norm_dgamma_k with one 256-thread block per column on stream
// `s`, which is cast to cudaStream_t. With no columns there is nothing to
// write and a zero-block grid would be an invalid launch, so that case is a
// no-op. rows == 0 still launches: every dgamma[c] is then written as 0.
void launch_rms_norm_dgamma_f32(const float* x, const float* dy, const float* inv_rms,
                                float* dgamma, int rows, int cols, void* s) {
    if (cols <= 0) return;
    const int blk = 256;
    rms_norm_dgamma_k<<<cols, blk, 0, (cudaStream_t)s>>>(x, dy, inv_rms, dgamma, rows, cols);
}
// =====================================================================
// SiLU: y = x * sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
// =====================================================================
// y[i] = x[i] * sigmoid(x[i]), evaluated as x / (1 + exp(-x)).
// One element per thread on a 1-D grid.
__global__ void silu_k(const float* x, float* y, int n) {
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx >= n) return;  // surplus threads of the last block
    const float v = x[idx];
    y[idx] = v / (1.0f + expf(-v));
}
// Launches silu_k over n elements (256 threads per block) on stream `s`, which
// is cast to cudaStream_t. An empty tensor (n <= 0) is a no-op: a zero-block
// grid is an invalid launch configuration and would leave a CUDA error behind.
void launch_silu_f32(const float* x, float* y, int n, void* s) {
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;  // ceil(n / blk)
    silu_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, n);
}
// SiLU backward: dx[i] = dy[i] * (sig + x*sig*(1 - sig)), sig = sigmoid(x[i]).
// One element per thread on a 1-D grid.
__global__ void silu_dx_k(const float* x, const float* dy, float* dx, int n) {
    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx >= n) return;  // surplus threads of the last block
    const float v = x[idx];
    const float s = 1.0f / (1.0f + expf(-v));
    dx[idx] = dy[idx] * (s + v * s * (1.0f - s));
}
// Launches silu_dx_k over n elements (256 threads per block) on stream `s`,
// which is cast to cudaStream_t. An empty tensor (n <= 0) is a no-op: a
// zero-block grid is an invalid launch configuration and would leave a CUDA
// error behind.
void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void* s) {
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n + blk - 1) / blk;  // ceil(n / blk)
    silu_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, dy, dx, n);
}
// =====================================================================
// RoPE (rotate_half layout). x:[tokens, heads, head_dim]; position = token index.
// y[i] = x[i]*cos - x[i+h]*sin
// y[i+h] = x[i+h]*cos + x[i]*sin (i in [0,half), h=half_dim)
// freq[i] = theta^(-2i/head_dim); angle = pos*freq[i].
// Backward is the inverse (transpose) rotation: apply +angle's transpose ≡ -angle.
// dx[i] = dy[i]*cos + dy[i+h]*sin
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
// =====================================================================
// `period` is the sequence length: a flattened batch lays B sequences end to end
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
// own sequence, `tok % period`. With period == tokens (single sequence) this is
// the original position = row.
// RoPE forward (rotate_half layout). Grid is (token, head); threadIdx.x is the
// pair index i in [0, head_dim/2). Each thread rotates (x[i], x[i+half]) by
// angle = (token % period) * theta^(-2i/head_dim). Both inputs are read before
// either output is written.
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
                       float theta, int period) {
    const int pair = threadIdx.x;
    const int half_dim = head_dim / 2;
    if (pair < half_dim) {
        const int token = blockIdx.x;
        const int h = blockIdx.y;
        const int position = token % period;  // index within its own sequence
        const float inv_freq = powf(theta, -(float)(2 * pair) / (float)head_dim);
        const float phi = (float)position * inv_freq;
        const float cos_phi = cosf(phi);
        const float sin_phi = sinf(phi);

        const int lo = (token * heads + h) * head_dim + pair;
        const int hi = lo + half_dim;
        const float a = x[lo];
        const float b = x[hi];
        y[lo] = a * cos_phi - b * sin_phi;
        y[hi] = b * cos_phi + a * sin_phi;
    }
}
// Launches rope_k with a (tokens, heads) grid and head_dim/2 threads per block
// (one thread per rotated pair) on stream `s`, which is cast to cudaStream_t.
// Nothing is launched when any launch dimension would be zero (no tokens, no
// heads, or head_dim < 2): a zero-sized grid or block is an invalid launch
// configuration and would leave a CUDA error behind.
void launch_rope_f32(const float* x, float* y, int tokens, int heads,
                     int head_dim, float theta, int period, void* s) {
    const int blk = head_dim / 2;
    if (tokens <= 0 || heads <= 0 || blk <= 0) return;
    dim3 grid(tokens, heads);
    rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}
// RoPE backward: applies the transposed rotation (same angle as rope_k, sine
// terms with opposite sign). Grid is (token, head); threadIdx.x is the pair
// index i in [0, head_dim/2).
//   dx[i]      = dy[i]*cos + dy[i+half]*sin
//   dx[i+half] = dy[i+half]*cos - dy[i]*sin
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                          float theta, int period) {
    const int pair = threadIdx.x;
    const int half_dim = head_dim / 2;
    if (pair < half_dim) {
        const int token = blockIdx.x;
        const int h = blockIdx.y;
        const int position = token % period;  // index within its own sequence
        const float inv_freq = powf(theta, -(float)(2 * pair) / (float)head_dim);
        const float phi = (float)position * inv_freq;
        const float cos_phi = cosf(phi);
        const float sin_phi = sinf(phi);

        const int lo = (token * heads + h) * head_dim + pair;
        const int hi = lo + half_dim;
        const float g_lo = dy[lo];
        const float g_hi = dy[hi];
        dx[lo] = g_lo * cos_phi + g_hi * sin_phi;
        dx[hi] = g_hi * cos_phi - g_lo * sin_phi;
    }
}
// Launches rope_dx_k with a (tokens, heads) grid and head_dim/2 threads per
// block (one thread per rotated pair) on stream `s`, which is cast to
// cudaStream_t. Nothing is launched when any launch dimension would be zero
// (no tokens, no heads, or head_dim < 2): a zero-sized grid or block is an
// invalid launch configuration and would leave a CUDA error behind.
void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
                        int head_dim, float theta, int period, void* s) {
    const int blk = head_dim / 2;
    if (tokens <= 0 || heads <= 0 || blk <= 0) return;
    dim3 grid(tokens, heads);
    rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
}
// =====================================================================
// Row-wise safe softmax. x:[rows,cols] → y. Backward (Jacobian):
// dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c']*y[r,c']))
// =====================================================================
// Row-wise softmax, one row per block (blockIdx.x is the row). Three passes
// with threads striding over the columns:
//   1. row maximum (block_reduce_max),
//   2. exp(x - max) written to y and summed (block_reduce_sum),
//   3. y scaled by 1/sum.
// Each thread revisits only the columns it wrote in pass 2.
__global__ void softmax_k(const float* x, float* y, int rows, int cols) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;
    const float* in_row = x + row * cols;
    float* out_row = y + row * cols;

    float row_max = -INFINITY;
    for (int c = tid; c < cols; c += step) {
        row_max = fmaxf(row_max, in_row[c]);
    }
    row_max = block_reduce_max(row_max);

    float denom = 0.0f;
    for (int c = tid; c < cols; c += step) {
        const float e = expf(in_row[c] - row_max);
        out_row[c] = e;
        denom += e;
    }
    denom = block_reduce_sum(denom);

    const float inv = 1.0f / denom;
    for (int c = tid; c < cols; c += step) {
        out_row[c] *= inv;
    }
}
// Launches softmax_k with one block per row on stream `s`, which is cast to
// cudaStream_t.
//
// Block size: min(cols, 1024), at least 32, rounded UP to a multiple of 32.
// The kernel's block reductions shuffle with a full-warp mask; a partial last
// warp (e.g. cols = 100 -> 100 threads) would read inactive lanes, which is
// undefined. Surplus threads skip the column loops and only feed the identity
// values into the reductions.
//
// rows <= 0 is a no-op: a zero-block grid is an invalid launch configuration.
void launch_softmax_f32(const float* x, float* y, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    blk = (blk + 31) / 32 * 32;  // whole warps only; 1024 stays 1024
    softmax_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, rows, cols);
}
// Softmax backward, one row per block (blockIdx.x is the row):
//   dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c'] * y[r,c']))
// The row dot product is combined with block_reduce_sum.
__global__ void softmax_dx_k(const float* y, const float* dy, float* dx,
                             int rows, int cols) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;
    const float* y_row = y + row * cols;
    const float* dy_row = dy + row * cols;
    float* dx_row = dx + row * cols;

    // sum_c dy * y
    float weighted = 0.0f;
    for (int c = tid; c < cols; c += step) {
        weighted += dy_row[c] * y_row[c];
    }
    weighted = block_reduce_sum(weighted);

    for (int c = tid; c < cols; c += step) {
        dx_row[c] = y_row[c] * (dy_row[c] - weighted);
    }
}
// Launches softmax_dx_k with one block per row on stream `s`, which is cast to
// cudaStream_t.
//
// Block size: min(cols, 1024), at least 32, rounded UP to a multiple of 32.
// The kernel's block reduction shuffles with a full-warp mask; a partial last
// warp would read inactive lanes, which is undefined. Surplus threads skip the
// column loops and only feed 0 into the reduction.
//
// rows <= 0 is a no-op: a zero-block grid is an invalid launch configuration.
void launch_softmax_dx_f32(const float* y, const float* dy, float* dx,
                           int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    blk = (blk + 31) / 32 * 32;  // whole warps only; 1024 stays 1024
    softmax_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(y, dy, dx, rows, cols);
}
// =====================================================================
// Cross-entropy over logits x:[rows,cols] with int target per row.
// Forward writes per-row loss[r] = -log(softmax(x)[target]) and the softmax
// probs[r,:] (cached for backward). A negative target marks an ignored row:
// its loss is 0 and its dx row is all zeros. Backward for every other row:
// dx[r,c] = scale * (probs[r,c] - onehot), where the caller-supplied `scale`
// carries the upstream gradient together with the mean-reduction divisor.
// =====================================================================
// Cross-entropy forward, one row per block (blockIdx.x is the row).
// Writes probs[r,:] = softmax(x[r,:]) and loss[r] = -log(probs[r, target[r]]).
// A negative target marks an ignored row: its loss is 0 (probs is still
// written).
//
// The row offset is computed in 64 bits: r * cols as a 32-bit product
// overflows once the logits hold more than 2^31 elements.
__global__ void cross_entropy_fwd_k(const float* x, const int* target,
                                    float* probs, float* loss, int rows, int cols) {
    const int r = blockIdx.x;
    const size_t row_off = (size_t)r * (size_t)cols;
    const float* xr = x + row_off;
    float* pr = probs + row_off;

    // Pass 1: row maximum, subtracted before exp.
    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);

    // Pass 2: unnormalized exponentials and their sum.
    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float e = expf(xr[c] - m);
        pr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);

    // Pass 3: normalize. Each thread rescales only the columns it wrote.
    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;

    // pr[target] is normally rescaled by a thread other than thread 0. Without
    // this barrier thread 0 could read the value before it was normalized and
    // report the wrong loss.
    __syncthreads();

    if (threadIdx.x == 0) {
        int t = target[r];
        loss[r] = t < 0 ? 0.0f : -logf(pr[t]);
    }
}
// Launches cross_entropy_fwd_k with one block per row on stream `s`, which is
// cast to cudaStream_t.
//
// Block size: min(cols, 1024), at least 32, rounded UP to a multiple of 32.
// The kernel's block reductions shuffle with a full-warp mask; a partial last
// warp would read inactive lanes, which is undefined. Surplus threads skip the
// column loops and only feed the identity values into the reductions.
//
// rows <= 0 is a no-op: a zero-block grid is an invalid launch configuration.
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
                                  float* probs, float* loss, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    blk = (blk + 31) / 32 * 32;  // whole warps only; 1024 stays 1024
    cross_entropy_fwd_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, target, probs, loss, rows, cols);
}
// dx[r,c] = scale * (probs[r,c] - [c==target[r]]); `scale` is supplied by the
// caller. Rows with a negative target (ignored rows) get an all-zero gradient.
//
// Flat indices are 64-bit and the kernel walks the buffer with a stride of
// gridDim.x * blockDim.x, so it stays correct when rows*cols exceeds 2^31 and
// for any grid size the launcher picks.
__global__ void cross_entropy_dx_k(const float* probs, const int* target,
                                   float* dx, int rows, int cols, float scale) {
    const long long total = (long long)rows * (long long)cols;
    const long long stride = (long long)gridDim.x * (long long)blockDim.x;
    long long i = (long long)blockIdx.x * (long long)blockDim.x + threadIdx.x;
    for (; i < total; i += stride) {
        const int r = (int)(i / cols);
        const int c = (int)(i % cols);
        const int t = target[r];
        if (t < 0) {
            dx[i] = 0.0f;
        } else {
            const float g = probs[i] - (c == t ? 1.0f : 0.0f);
            dx[i] = g * scale;
        }
    }
}
// Launches cross_entropy_dx_k over rows*cols elements (256 threads per block)
// on stream `s`, which is cast to cudaStream_t. The element count is computed
// in 64 bits, the grid is capped at the maximum x-dimension (the kernel's
// stride loop covers any remainder), and an empty input is a no-op because a
// zero-block grid is an invalid launch configuration.
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
                                 float* dx, int rows, int cols, float scale, void* s) {
    const long long n = (long long)rows * (long long)cols;
    if (n <= 0) return;
    const int blk = 256;
    long long grid = (n + blk - 1) / blk;  // ceil(n / blk)
    if (grid > 2147483647LL) grid = 2147483647LL;
    cross_entropy_dx_k<<<(unsigned int)grid, blk, 0, (cudaStream_t)s>>>(probs, target, dx, rows, cols, scale);
}
} // extern "C"