// Forward + backward CUDA kernels for the transformer ops the autograd engine
// (Phase T4) needs: elementwise add/mul, broadcast bias add + its row-sum
// backward, RMSNorm, SiLU, RoPE, row-wise softmax, and cross-entropy.
//
// All F32, row-major, contiguous. Forward kernels mirror xserv
// (docs/04-transformer-kernels.md, docs/05-attention.md); the backward kernels
// are new (xserv is inference-only). Reduction helpers are inlined here so this
// file is self-contained (no shared header), matching the existing csrc/ layout.
//
// Host launchers take the CUDA stream as an opaque `void*` and cast it to
// cudaStream_t at the launch site. They do not check for launch errors.

#include <cuda_runtime.h>
#include <math.h>

extern "C" {

// --- Warp / block reductions (sum + max), block handles one row ---
//
// block_reduce_{sum,max} accept any 1-D blockDim.x in [1, 1024] and return the
// result to every thread of the block. They declare static __shared__ scratch
// and call __syncthreads(), so every thread of the block must reach the call.

// Sum across all 32 lanes of a warp; the total is valid in lane 0.
// Precondition: every lane of the warp exists and executes this call.
__device__ __forceinline__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int off = 16; off > 0; off >>= 1)
        v += __shfl_down_sync(0xffffffff, v, off);
    return v;
}

// Max across all 32 lanes of a warp; the result is valid in lane 0.
// Precondition: every lane of the warp exists and executes this call.
__device__ __forceinline__ float warp_reduce_max(float v) {
#pragma unroll
    for (int off = 16; off > 0; off >>= 1)
        v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
    return v;
}

// Sum over lanes [0, live) of the calling warp (1 <= live <= 32); valid in lane 0.
// The last warp of a block whose size is not a multiple of 32 is partial, and a
// shuffle that reads a non-existent lane yields an undefined value. So only the
// live lanes are named in the mask, and out-of-range partners are skipped.
__device__ __forceinline__ float warp_reduce_sum_partial(float v, int live) {
    if (live >= 32) return warp_reduce_sum(v);
    const unsigned mask = (1u << live) - 1u;
    const int lane = threadIdx.x & 31;
#pragma unroll
    for (int off = 16; off > 0; off >>= 1) {
        const float other = __shfl_down_sync(mask, v, off);
        if (lane + off < live) v += other;
    }
    return v;
}

// Max over lanes [0, live) of the calling warp (1 <= live <= 32); valid in lane 0.
// Same partial-warp handling as warp_reduce_sum_partial.
__device__ __forceinline__ float warp_reduce_max_partial(float v, int live) {
    if (live >= 32) return warp_reduce_max(v);
    const unsigned mask = (1u << live) - 1u;
    const int lane = threadIdx.x & 31;
#pragma unroll
    for (int off = 16; off > 0; off >>= 1) {
        const float other = __shfl_down_sync(mask, v, off);
        if (lane + off < live) v = fmaxf(v, other);
    }
    return v;
}

// Block-wide sum of `v`; every thread of the block receives the total.
__device__ __forceinline__ float block_reduce_sum(float v) {
    __shared__ float shared[32];  // one partial per warp (blockDim.x <= 1024)
    __shared__ float bcast;
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    const int nwarps = (blockDim.x + 31) >> 5;
    // Stage 1: each warp reduces its own (possibly partial) set of lanes.
    v = warp_reduce_sum_partial(v, min(32, (int)blockDim.x - (warp << 5)));
    if (lane == 0) shared[warp] = v;
    __syncthreads();
    // Stage 2: warp 0 folds the per-warp partials (slots >= nwarps count as 0).
    // Warp 0 is itself partial when blockDim.x < 32.
    v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : 0.0f;
    if (warp == 0) v = warp_reduce_sum_partial(v, min(32, (int)blockDim.x));
    // Broadcast the warp-0 lane-0 result to the whole block.
    if (threadIdx.x == 0) bcast = v;
    __syncthreads();
    return bcast;
}

// Block-wide max of `v`; every thread of the block receives the result.
__device__ __forceinline__ float block_reduce_max(float v) {
    __shared__ float shared[32];  // one partial per warp (blockDim.x <= 1024)
    __shared__ float bcast;
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    const int nwarps = (blockDim.x + 31) >> 5;
    v = warp_reduce_max_partial(v, min(32, (int)blockDim.x - (warp << 5)));
    if (lane == 0) shared[warp] = v;
    __syncthreads();
    // Slots >= nwarps hold the max identity (-INFINITY).
    v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : -INFINITY;
    if (warp == 0) v = warp_reduce_max_partial(v, min(32, (int)blockDim.x));
    if (threadIdx.x == 0) bcast = v;
    __syncthreads();
    return bcast;
}

// =====================================================================
// Elementwise add / mul (same-shape)
// =====================================================================

// out[i] = a[i] + b[i] for i in [0, n). One thread per element.
__global__ void add_k(const float* a, const float* b, float* out, int n) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = a[i] + b[i];
}

void launch_add_f32(const float* a, const float* b, float* out, int n, void* s) {
    if (n <= 0) return;                 // a zero-block grid is an invalid launch
    const int blk = 256;
    const int grid = (n - 1) / blk + 1; // ceil(n / blk) without overflow near INT_MAX
    add_k<<<grid, blk, 0, (cudaStream_t)s>>>(a, b, out, n);
}

// out[i] = a[i] * b[i] for i in [0, n). One thread per element.
__global__ void mul_k(const float* a, const float* b, float* out, int n) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = a[i] * b[i];
}

void launch_mul_f32(const float* a, const float* b, float* out, int n, void* s) {
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n - 1) / blk + 1;
    mul_k<<<grid, blk, 0, (cudaStream_t)s>>>(a, b, out, n);
}

// =====================================================================
// Broadcast bias add: out[r,c] = x[r,c] + bias[c]   (x:[rows,cols])
// Backward for bias is a column-sum (sum over rows): dbias[c] = sum_r dout[r,c].
// =====================================================================

// One thread per element of x; 64-bit indexing so rows*cols may exceed INT_MAX.
__global__ void add_bias_k(const float* x, const float* bias, float* out, int rows, int cols) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < (long long)rows * cols) out[i] = x[i] + bias[i % cols];
}

void launch_add_bias_f32(const float* x, const float* bias, float* out, int rows, int cols, void* s) {
    if (rows <= 0 || cols <= 0) return;
    const long long n = (long long)rows * cols;
    const int blk = 256;
    const int grid = (int)((n + blk - 1) / blk);
    add_bias_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, bias, out, rows, cols);
}

// dbias[c] = sum_r dout[r,c]. One block per column, threads stride over rows.
// Grid: one block per column `col`. Each thread accumulates a strided subset of
// rows, then the block reduces them.
__global__ void sum_rows_k(const float* dout, float* dbias, int rows, int cols) {
    const int col = blockIdx.x;
    float acc = 0.0f;
    for (int r = threadIdx.x; r < rows; r += blockDim.x)
        acc += dout[(size_t)r * cols + col];
    acc = block_reduce_sum(acc);
    if (threadIdx.x == 0) dbias[col] = acc;
}

void launch_sum_rows_f32(const float* dout, float* dbias, int rows, int cols, void* s) {
    if (cols <= 0) return;  // one block per column; a zero-block grid is invalid
    const int blk = 256;
    sum_rows_k<<<cols, blk, 0, (cudaStream_t)s>>>(dout, dbias, rows, cols);
}

// =====================================================================
// RMSNorm: y[r,c] = x[r,c] * inv_rms[r] * gamma[c],  inv_rms = rsqrt(mean(x²)+eps)
// x:[rows,cols], gamma:[cols]. Forward also writes inv_rms[rows] for backward.
// =====================================================================

// Grid: one block per row r; threads stride over the row's columns.
__global__ void rms_norm_k(const float* x, const float* gamma, float* y, float* inv_rms,
                           int rows, int cols, float eps) {
    const int r = blockIdx.x;
    const float* xr = x + (size_t)r * cols;  // 64-bit offset: rows*cols may exceed INT_MAX
    float* yr = y + (size_t)r * cols;

    float ss = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) ss += xr[c] * xr[c];
    ss = block_reduce_sum(ss);

    const float ir = rsqrtf(ss / cols + eps);
    if (threadIdx.x == 0) inv_rms[r] = ir;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) yr[c] = xr[c] * ir * gamma[c];
}

void launch_rms_norm_f32(const float* x, const float* gamma, float* y, float* inv_rms,
                         int rows, int cols, float eps, void* s) {
    if (rows <= 0) return;
    // Threads per row: min(cols, 1024), rounded up to a whole number of warps
    // (minimum 32). Warp-shuffle reductions are only well-defined when every
    // lane of the warp exists; surplus threads simply contribute 0.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    rms_norm_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, y, inv_rms, rows, cols, eps);
}

// RMSNorm backward.
// Let g[c] = dy[r,c]*gamma[c], ir = inv_rms[r], n = cols.
//   dx[r,c]   = ir*g[c] - x[r,c]*ir³/n * sum_c(g[c]*x[r,c])
//   dgamma[c] = sum_r dy[r,c] * x[r,c] * ir        (accumulated across rows)

// Grid: one block per row r. Needs inv_rms[r] saved by the forward pass.
__global__ void rms_norm_dx_k(const float* x, const float* gamma, const float* dy,
                              const float* inv_rms, float* dx, int rows, int cols) {
    const int r = blockIdx.x;
    const size_t off = (size_t)r * cols;  // 64-bit row offset
    const float* xr = x + off;
    const float* dyr = dy + off;
    float* dxr = dx + off;
    const float ir = inv_rms[r];

    float dot = 0.0f;  // sum_c g[c]*x[c]
    for (int c = threadIdx.x; c < cols; c += blockDim.x) dot += dyr[c] * gamma[c] * xr[c];
    dot = block_reduce_sum(dot);

    const float coeff = ir * ir * ir / (float)cols * dot;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dxr[c] = ir * dyr[c] * gamma[c] - xr[c] * coeff;
}

void launch_rms_norm_dx_f32(const float* x, const float* gamma, const float* dy,
                            const float* inv_rms, float* dx, int rows, int cols, void* s) {
    if (rows <= 0) return;
    // Whole warps only (32..1024 threads): the shuffle reduction needs every lane present.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    rms_norm_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, dy, inv_rms, dx, rows, cols);
}

// dgamma[c] = sum_r dy[r,c] * x[r,c] * inv_rms[r]. One block per column.
// Threads stride over rows, so neighbouring threads read addresses `cols`
// floats apart (uncoalesced).
__global__ void rms_norm_dgamma_k(const float* x, const float* dy, const float* inv_rms,
                                  float* dgamma, int rows, int cols) {
    const int col = blockIdx.x;
    float acc = 0.0f;
    for (int r = threadIdx.x; r < rows; r += blockDim.x) {
        const size_t i = (size_t)r * cols + col;
        acc += dy[i] * x[i] * inv_rms[r];
    }
    acc = block_reduce_sum(acc);
    if (threadIdx.x == 0) dgamma[col] = acc;
}

void launch_rms_norm_dgamma_f32(const float* x, const float* dy, const float* inv_rms,
                                float* dgamma, int rows, int cols, void* s) {
    if (cols <= 0) return;  // one block per column; a zero-block grid is invalid
    const int blk = 256;
    rms_norm_dgamma_k<<<cols, blk, 0, (cudaStream_t)s>>>(x, dy, inv_rms, dgamma, rows, cols);
}

// =====================================================================
// SiLU: y = x * sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
// =====================================================================

// y[i] = x[i] * sigmoid(x[i]), written as x / (1 + exp(-x)). One thread per element.
__global__ void silu_k(const float* x, float* y, int n) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        const float xv = x[i];
        y[i] = xv / (1.0f + expf(-xv));
    }
}

void launch_silu_f32(const float* x, float* y, int n, void* s) {
    if (n <= 0) return;                 // a zero-block grid is an invalid launch
    const int blk = 256;
    const int grid = (n - 1) / blk + 1; // ceil(n / blk) without overflow near INT_MAX
    silu_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, n);
}

// dx[i] = dy[i] * d/dx silu(x[i]); recomputes sigmoid from the saved input x.
__global__ void silu_dx_k(const float* x, const float* dy, float* dx, int n) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        const float xv = x[i];
        const float sig = 1.0f / (1.0f + expf(-xv));
        dx[i] = dy[i] * (sig + xv * sig * (1.0f - sig));
    }
}

void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void* s) {
    if (n <= 0) return;
    const int blk = 256;
    const int grid = (n - 1) / blk + 1;
    silu_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, dy, dx, n);
}

// =====================================================================
// RoPE (rotate_half layout). x:[tokens, heads, head_dim]; position = token index.
//   y[i]   = x[i]*cos - x[i+h]*sin
//   y[i+h] = x[i+h]*cos + x[i]*sin        (i in [0,half), h=half_dim)
//   freq[i] = theta^(-2i/head_dim);  angle = pos*freq[i].
// Backward is the inverse (transpose) rotation, i.e. a rotation by -angle:
//   dx[i]   = dy[i]*cos + dy[i+h]*sin
//   dx[i+h] = dy[i+h]*cos - dy[i]*sin
// =====================================================================
// `period` is the sequence length: a flattened batch lays B sequences end to end
// along the `tokens` axis, so each token's RoPE position is its index WITHIN its
// own sequence, `tok % period`. With period == tokens (single sequence) this is
// the original position = row.
// Grid (tokens, heads); block of head_dim/2 threads, one per rotated pair (i, i+half).
__global__ void rope_k(const float* x, float* y, int heads, int head_dim, float theta, int period) {
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    // `tok % period` is undefined for period <= 0; treat that as one sequence.
    const int pos = period > 0 ? tok % period : tok;
    const float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)pos * freq;
    const float c = cosf(angle), sn = sinf(angle);
    const size_t base = ((size_t)tok * heads + head) * head_dim;  // 64-bit offset
    const float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}

// head_dim/2 is the block size, so it must not exceed 1024 threads.
void launch_rope_f32(const float* x, float* y, int tokens, int heads, int head_dim,
                     float theta, int period, void* s) {
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;  // empty grid/block is invalid
    const dim3 grid(tokens, heads);
    const int blk = head_dim / 2;
    rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}

// RoPE at an absolute position offset (KV-cache decode-time, forward only). Same
// rotate_half as rope_k, but row `tok`'s position is `pos0 + tok` (no modulo) —
// a single new decode token sits at absolute position pos0. The training rope_k
// (position = tok % period) is left untouched, so this adds no training-path risk.
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim, float theta, int pos0) {
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int pos = pos0 + tok;
    const float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)pos * freq;
    const float c = cosf(angle), sn = sinf(angle);
    const size_t base = ((size_t)tok * heads + head) * head_dim;
    const float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}

void launch_rope_at_f32(const float* x, float* y, int tokens, int heads, int head_dim,
                        float theta, int pos0, void* s) {
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;
    const dim3 grid(tokens, heads);
    const int blk = head_dim / 2;
    rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
}

// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b): row `tok`'s
// position is `positions[tok]` (an i32 per token).
// For G-way batched decode all G rows share one decode position; for ragged
// batches each row carries its own. Forward only; the training rope_k is untouched.
// Grid (tokens, heads); block of head_dim/2 threads.
__global__ void rope_pos_k(const float* x, const int* positions, float* y, int heads, int head_dim,
                           float theta) {
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    const int pos = positions[tok];
    const float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)pos * freq;
    const float c = cosf(angle), sn = sinf(angle);
    const size_t base = ((size_t)tok * heads + head) * head_dim;  // 64-bit offset
    const float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}

void launch_rope_pos_f32(const float* x, const int* positions, float* y, int tokens, int heads,
                         int head_dim, float theta, void* s) {
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;  // empty grid/block is invalid
    const dim3 grid(tokens, heads);
    const int blk = head_dim / 2;
    rope_pos_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, positions, y, heads, head_dim, theta);
}

// Concatenate along the sequence (middle) dim: a:[bh,ta,hd], b:[bh,tb,hd] →
// out:[bh,ta+tb,hd] with out[:, :ta]=a, out[:, ta:]=b. The device-side KV-cache
// append (M2c): keeps K/V on the GPU and grows by one token per step, removing the
// host round-trip the M2a/M2b host cache paid. One block per bh row.
// ta_hd = ta*hd and tb_hd = tb*hd are the flattened per-row lengths.
__global__ void cat_seq_k(const float* a, const float* b, float* out, int ta_hd, int tb_hd) {
    const int row = blockIdx.x;  // bh row
    const int o_hd = ta_hd + tb_hd;
    const float* ar = a + (long)row * ta_hd;
    const float* br = b + (long)row * tb_hd;
    float* outr = out + (long)row * o_hd;
    for (int j = threadIdx.x; j < ta_hd; j += blockDim.x) outr[j] = ar[j];
    for (int j = threadIdx.x; j < tb_hd; j += blockDim.x) outr[ta_hd + j] = br[j];
}

void launch_cat_seq_f32(const float* a, const float* b, float* out, int bh, int ta_hd, int tb_hd,
                        void* s) {
    if (bh <= 0) return;  // one block per bh row; a zero-block grid is invalid
    cat_seq_k<<<bh, 256, 0, (cudaStream_t)s>>>(a, b, out, ta_hd, tb_hd);
}

// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row.
// Used by the GRPO (M4) policy-gradient backward, where each completion token's
// row of (probs − onehot) is scaled by its own per-token coefficient.
__global__ void scale_rows_k(const float* x, const float* s, float* y, int rows, int cols) {
    const int r = blockIdx.x;
    const float sr = s[r];
    const size_t off = (size_t)r * cols;  // 64-bit row offset
    for (int c = threadIdx.x; c < cols; c += blockDim.x) y[off + c] = x[off + c] * sr;
}

// `st` is the CUDA stream (`s` is already the per-row scale vector).
void launch_scale_rows_f32(const float* x, const float* s, float* y, int rows, int cols, void* st) {
    if (rows <= 0) return;  // one block per row; a zero-block grid is invalid
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
}

// RoPE backward: rotates dy by -angle (the transpose of rope_k's rotation),
// with the same grid layout and `tok % period` positions as rope_k.
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, float theta, int period) {
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const int half = head_dim / 2;
    const int i = threadIdx.x;
    if (i >= half) return;
    // `tok % period` is undefined for period <= 0; treat that as one sequence.
    const int pos = period > 0 ? tok % period : tok;
    const float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    const float angle = (float)pos * freq;
    const float c = cosf(angle), sn = sinf(angle);
    const size_t base = ((size_t)tok * heads + head) * head_dim;
    const float d0 = dy[base + i], d1 = dy[base + i + half];
    dx[base + i] = d0 * c + d1 * sn;
    dx[base + i + half] = d1 * c - d0 * sn;
}

void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads, int head_dim,
                        float theta, int period, void* s) {
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;  // empty grid/block is invalid
    const dim3 grid(tokens, heads);
    const int blk = head_dim / 2;
    rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
}

// =====================================================================
// Row-wise safe softmax. x:[rows,cols] → y.
// Backward (Jacobian):
//   dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c']*y[r,c']))
// =====================================================================

// Grid: one block per row. Subtracts the row max before exp for stability.
// A row that is entirely -INFINITY produces NaN (exp(-inf - -inf)).
__global__ void softmax_k(const float* x, float* y, int rows, int cols) {
    const int r = blockIdx.x;
    const float* xr = x + (size_t)r * cols;  // 64-bit row offset
    float* yr = y + (size_t)r * cols;

    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);

    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        const float e = expf(xr[c] - m);
        yr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);

    const float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) yr[c] *= inv;
}

void launch_softmax_f32(const float* x, float* y, int rows, int cols, void* s) {
    if (rows <= 0) return;
    // Whole warps only (32..1024 threads): the shuffle reductions need every lane present.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    softmax_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, rows, cols);
}

// Grid: one block per row; y is the forward softmax output.
__global__ void softmax_dx_k(const float* y, const float* dy, float* dx, int rows, int cols) {
    const int r = blockIdx.x;
    const size_t off = (size_t)r * cols;
    const float* yr = y + off;
    const float* dyr = dy + off;
    float* dxr = dx + off;

    float dot = 0.0f;  // sum_c dy*y
    for (int c = threadIdx.x; c < cols; c += blockDim.x) dot += dyr[c] * yr[c];
    dot = block_reduce_sum(dot);

    for (int c = threadIdx.x; c < cols; c += blockDim.x) dxr[c] = yr[c] * (dyr[c] - dot);
}

void launch_softmax_dx_f32(const float* y, const float* dy, float* dx, int rows, int cols, void* s) {
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    softmax_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(y, dy, dx, rows, cols);
}

// =====================================================================
// Cross-entropy over logits x:[rows,cols] with int target per row.
// Forward writes per-row loss[r] = -log(softmax(x)[target]) and the softmax
// probs[r,:] (cached for backward). Backward: dx[r,c] = scale * (probs[r,c] - onehot),
// where a mean-reduced loss passes scale = upstream/rows, folding the mean's
// 1/rows factor into dx. Rows with a negative target get loss 0 and dx 0.
// =====================================================================

// Grid: one block per row r. Writes the normalized softmax to probs[r,:] and the
// per-row loss to loss[r]. target[r] < 0 marks an ignored row (loss 0).
// NOTE(review): target[r] >= cols is not bounds-checked — callers must ensure it.
__global__ void cross_entropy_fwd_k(const float* x, const int* target, float* probs, float* loss,
                                    int rows, int cols) {
    const int r = blockIdx.x;
    const float* xr = x + (size_t)r * cols;  // 64-bit offset: rows*vocab may exceed INT_MAX
    float* pr = probs + (size_t)r * cols;

    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);

    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        const float e = expf(xr[c] - m);
        pr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);

    const float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;

    // Loss via log-sum-exp straight from the logits: -log p[t] = m + log(sum) - x[t].
    // Reading pr[t] here instead would race with whichever thread normalizes that
    // entry (there is no barrier after the loop above), and -logf(p[t]) would be
    // +inf whenever p[t] underflows to 0.
    if (threadIdx.x == 0) {
        const int t = target[r];
        loss[r] = t < 0 ? 0.0f : (m + logf(sum)) - xr[t];
    }
}

void launch_cross_entropy_fwd_f32(const float* x, const int* target, float* probs, float* loss,
                                  int rows, int cols, void* s) {
    if (rows <= 0) return;
    // Whole warps only (32..1024 threads): the shuffle reductions need every lane present.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    cross_entropy_fwd_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, target, probs, loss, rows, cols);
}

// dx[r,c] = scale * (probs[r,c] - [c==target]).  scale = upstream/rows.
// One thread per element; rows with target < 0 get a zero gradient.
__global__ void cross_entropy_dx_k(const float* probs, const int* target, float* dx, int rows, int cols,
                                   float scale) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= (long long)rows * cols) return;
    const int r = (int)(i / cols);
    const int c = (int)(i % cols);
    const int t = target[r];
    if (t < 0) {
        dx[i] = 0.0f;
        return;
    }
    const float g = probs[i] - (c == t ? 1.0f : 0.0f);
    dx[i] = g * scale;
}

void launch_cross_entropy_dx_f32(const float* probs, const int* target, float* dx, int rows, int cols,
                                 float scale, void* s) {
    if (rows <= 0 || cols <= 0) return;  // a zero-block grid is an invalid launch
    const long long n = (long long)rows * cols;  // may exceed INT_MAX for large vocabularies
    const int blk = 256;
    const int grid = (int)((n + blk - 1) / blk);
    cross_entropy_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(probs, target, dx, rows, cols, scale);
}

} // extern "C"