Device-resident KV cache: keep K/V on the GPU as [bh,T,hd], grow by one token per step via a new cat_seq kernel (concat along seq) — removes the M2a/M2b per-layer host round-trip (to_cpu/from_slice/re-upload) AND the transpose_3d01. Both single-seq and batched decode refactored to it; cache is Option<Tensor> per layer (cleaner than the host Vec + rebuild). Gates all hold: cat_seq == host concat; decode_kv single-seq + decode_batch G-way both still TOKEN-IDENTICAL; GQA training path unaffected. Honest measurement (the point): removing the host round-trip buys ~10% on pure single-seq decode (133 → 147 tok/s @128) but does NOT move the GRPO step (~8.5 s/step unchanged) — because after M2b batching the rollout is no longer the step's bottleneck; the per-sample per_token_logp captures + the PG-update forwards/backwards (model.forward, full-seq) now dominate. Measure-first lesson (cf. T11/T17/M2a): the long pole shifted to the training-side forwards; the next decode lever (ragged batched prefill) targets those, not the cache. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
462 lines
19 KiB
Plaintext
462 lines
19 KiB
Plaintext
// Forward + backward CUDA kernels for the transformer ops the autograd engine
// (Phase T4) needs: elementwise add/mul, broadcast bias add + its sum-over-rows
// backward, RMSNorm, SiLU, RoPE, row-wise softmax, and cross-entropy. It also
// holds a few helpers added for KV-cache decode and GRPO (rope_at, rope_pos,
// cat_seq, scale_rows).
//
// All F32, row-major, contiguous. Forward kernels mirror xserv
// (docs/04-transformer-kernels.md, docs/05-attention.md); the backward kernels
// are new (xserv is inference-only). Reduction helpers are inlined here so this
// file is self-contained (no shared header), matching the existing csrc/ layout.
|
||
|
||
#include <math.h>
|
||
|
||
extern "C" {
|
||
|
||
// --- Warp / block reductions (sum + max), block handles one row ---
|
||
|
||
// Warp-level reductions. Every lane contributes one value. After the five
// shift-down steps (16, 8, 4, 2, 1), lane 0 holds the result over all 32 lanes;
// the other lanes hold partial values and must not be used. The shuffle mask
// is the full warp, so all 32 lanes must be resident and call these together.
// A lane missing from a partially populated warp yields undefined shuffle data.
__device__ __forceinline__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float from_above = __shfl_down_sync(0xffffffff, v, delta);
        v += from_above;
    }
    return v;
}

__device__ __forceinline__ float warp_reduce_max(float v) {
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float from_above = __shfl_down_sync(0xffffffff, v, delta);
        v = fmaxf(v, from_above);
    }
    return v;
}
|
||
|
||
// Block-wide reductions: every thread passes its partial value and every
// thread receives the block's total (sum) or maximum (max).
//
// Stage 1: each warp reduces its lanes and lane 0 parks the warp's result in
// shared memory. Stage 2: warp 0 reduces those per-warp results (padded with
// the identity element past the last warp), and thread 0 publishes the final
// value through a shared broadcast slot.
//
// Preconditions:
// - Every thread of the block must call these, because they contain
//   __syncthreads().
// - blockDim.x <= 1024, since there are 32 per-warp slots.
// - blockDim.x must be a multiple of 32, because warp_reduce_* shuffles with a
//   full-warp mask.
__device__ __forceinline__ float block_reduce_sum(float v) {
    __shared__ float warp_partials[32];
    __shared__ float bcast;
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;

    // Stage 1: one partial per warp.
    float partial = warp_reduce_sum(v);
    if (lane == 0) warp_partials[warp_id] = partial;
    __syncthreads();

    // Stage 2: warp 0 folds the per-warp partials (0 beyond num_warps).
    float total = 0.0f;
    if (threadIdx.x < num_warps) total = warp_partials[threadIdx.x];
    if (warp_id == 0) total = warp_reduce_sum(total);

    // Broadcast lane 0 of warp 0 to the whole block.
    if (threadIdx.x == 0) bcast = total;
    __syncthreads();
    return bcast;
}

__device__ __forceinline__ float block_reduce_max(float v) {
    __shared__ float warp_partials[32];
    __shared__ float bcast;
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;

    // Stage 1: one partial per warp.
    float partial = warp_reduce_max(v);
    if (lane == 0) warp_partials[warp_id] = partial;
    __syncthreads();

    // Stage 2: warp 0 folds the per-warp partials (-inf beyond num_warps).
    float result = -INFINITY;
    if (threadIdx.x < num_warps) result = warp_partials[threadIdx.x];
    if (warp_id == 0) result = warp_reduce_max(result);

    // Broadcast lane 0 of warp 0 to the whole block.
    if (threadIdx.x == 0) bcast = result;
    __syncthreads();
    return bcast;
}
|
||
|
||
// =====================================================================
|
||
// Elementwise add / mul (same-shape)
|
||
// =====================================================================
|
||
|
||
// Elementwise out[i] = a[i] + b[i] for i in [0, n). One thread per element,
// 256-thread blocks. out may alias a or b: each thread reads and then writes
// only its own index.
__global__ void add_k(const float* a, const float* b, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = a[i] + b[i];
}

void launch_add_f32(const float* a, const float* b, float* out, int n, void* s) {
    // Nothing to do for an empty tensor. Also, a zero-sized grid is rejected by
    // CUDA as an invalid launch configuration, which would surface later as a
    // confusing error on an unrelated call.
    if (n <= 0) return;
    int blk = 256, grid = (n + blk - 1) / blk;
    add_k<<<grid, blk, 0, (cudaStream_t)s>>>(a, b, out, n);
}
|
||
|
||
// Elementwise out[i] = a[i] * b[i] for i in [0, n). One thread per element,
// 256-thread blocks. out may alias a or b: each thread reads and then writes
// only its own index.
__global__ void mul_k(const float* a, const float* b, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = a[i] * b[i];
}

void launch_mul_f32(const float* a, const float* b, float* out, int n, void* s) {
    // Skip empty tensors: a zero-sized grid is an invalid launch configuration.
    if (n <= 0) return;
    int blk = 256, grid = (n + blk - 1) / blk;
    mul_k<<<grid, blk, 0, (cudaStream_t)s>>>(a, b, out, n);
}
|
||
|
||
// =====================================================================
|
||
// Broadcast bias add: out[r,c] = x[r,c] + bias[c] (x:[rows,cols])
|
||
// Backward for bias is a column-sum (sum over rows): dbias[c] = sum_r dout[r,c].
|
||
// =====================================================================
|
||
|
||
// Broadcast bias add over row-major x:[rows, cols]:
//   out[r, c] = x[r, c] + bias[c]
// One thread per element (flat index i, column i % cols), 256-thread blocks.
__global__ void add_bias_k(const float* x, const float* bias, float* out,
                           int rows, int cols) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < rows * cols) out[i] = x[i] + bias[i % cols];
}

void launch_add_bias_f32(const float* x, const float* bias, float* out,
                         int rows, int cols, void* s) {
    // Skip empty matrices: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0 || cols <= 0) return;
    int n = rows * cols, blk = 256, grid = (n + blk - 1) / blk;
    add_bias_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, bias, out, rows, cols);
}
|
||
|
||
// Bias backward: dbias[c] = sum_r dout[r, c] for row-major dout:[rows, cols].
// One block per column. The block's threads stride over rows, and the partials
// are combined with block_reduce_sum (the 256-thread block is a whole number
// of warps, as that helper requires).
// NOTE(review): adjacent threads read addresses `cols` floats apart, so these
// loads are uncoalesced. A column-tiled layout would be faster for tall dout.
__global__ void sum_rows_k(const float* dout, float* dbias, int rows, int cols) {
    int col = blockIdx.x;
    float acc = 0.0f;
    for (int r = threadIdx.x; r < rows; r += blockDim.x)
        acc += dout[r * cols + col];
    acc = block_reduce_sum(acc);
    if (threadIdx.x == 0) dbias[col] = acc;
}

void launch_sum_rows_f32(const float* dout, float* dbias, int rows, int cols, void* s) {
    // cols == 0 would mean a zero-sized grid (invalid launch configuration).
    // rows == 0 is still launched, so every dbias[c] is written as 0.
    if (cols <= 0) return;
    int blk = 256;
    sum_rows_k<<<cols, blk, 0, (cudaStream_t)s>>>(dout, dbias, rows, cols);
}
|
||
|
||
// =====================================================================
|
||
// RMSNorm: y[r,c] = x[r,c] * inv_rms[r] * gamma[c], inv_rms = rsqrt(mean(x²)+eps)
|
||
// x:[rows,cols], gamma:[cols]. Forward also writes inv_rms[rows] for backward.
|
||
// =====================================================================
|
||
|
||
// RMSNorm forward over row-major x:[rows, cols], one block per row:
//   inv_rms[r] = rsqrt(mean_c(x[r,c]^2) + eps)
//   y[r,c]     = x[r,c] * inv_rms[r] * gamma[c]
// inv_rms is written for the backward kernels. Uses block_reduce_sum, so
// blockDim.x must be a whole number of warps (the launcher guarantees this).
__global__ void rms_norm_k(const float* x, const float* gamma, float* y,
                           float* inv_rms, int rows, int cols, float eps) {
    int r = blockIdx.x;
    const float* xr = x + r * cols;
    float* yr = y + r * cols;

    float ss = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) ss += xr[c] * xr[c];
    ss = block_reduce_sum(ss);

    float ir = rsqrtf(ss / cols + eps);
    if (threadIdx.x == 0) inv_rms[r] = ir;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        yr[c] = xr[c] * ir * gamma[c];
}

void launch_rms_norm_f32(const float* x, const float* gamma, float* y,
                         float* inv_rms, int rows, int cols, float eps, void* s) {
    // Skip empty input: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0) return;
    // One thread per column, capped at 1024 and rounded UP to whole warps.
    // block_reduce_sum shuffles with a full-warp mask, and lanes missing from a
    // partial warp contribute undefined values. So e.g. cols = 100 must launch
    // 128 threads, not 100; the surplus threads find no columns and add 0.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    rms_norm_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, y, inv_rms, rows, cols, eps);
}
|
||
|
||
// RMSNorm backward w.r.t. x, one block per row.
// Let g[c] = dy[r,c]*gamma[c], ir = inv_rms[r], n = cols. Then:
//   dx[r,c] = ir*g[c] - x[r,c] * ir^3/n * sum_c'(g[c']*x[r,c'])
// dgamma is produced separately by rms_norm_dgamma_k. Uses block_reduce_sum,
// so blockDim.x must be a whole number of warps (the launcher guarantees this).
__global__ void rms_norm_dx_k(const float* x, const float* gamma, const float* dy,
                              const float* inv_rms, float* dx, int rows, int cols) {
    int r = blockIdx.x;
    const float* xr = x + r * cols;
    const float* dyr = dy + r * cols;
    float* dxr = dx + r * cols;
    float ir = inv_rms[r];

    float dot = 0.0f;  // sum_c g[c]*x[c]
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dot += dyr[c] * gamma[c] * xr[c];
    dot = block_reduce_sum(dot);

    float coeff = ir * ir * ir / (float)cols * dot;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dxr[c] = ir * dyr[c] * gamma[c] - xr[c] * coeff;
}

void launch_rms_norm_dx_f32(const float* x, const float* gamma, const float* dy,
                            const float* inv_rms, float* dx, int rows, int cols, void* s) {
    // Skip empty input: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0) return;
    // Capped at 1024 and rounded UP to whole warps: block_reduce_sum's
    // full-mask shuffles read undefined values from lanes missing in a
    // partial warp.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    rms_norm_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, gamma, dy, inv_rms, dx, rows, cols);
}
|
||
|
||
// RMSNorm backward w.r.t. gamma:
//   dgamma[c] = sum_r dy[r,c] * x[r,c] * inv_rms[r]
// One block per column; the 256 threads stride over rows and combine with
// block_reduce_sum.
// NOTE(review): like sum_rows_k, adjacent threads read addresses `cols`
// floats apart, so these loads are uncoalesced.
__global__ void rms_norm_dgamma_k(const float* x, const float* dy, const float* inv_rms,
                                  float* dgamma, int rows, int cols) {
    int col = blockIdx.x;
    float acc = 0.0f;
    for (int r = threadIdx.x; r < rows; r += blockDim.x)
        acc += dy[r * cols + col] * x[r * cols + col] * inv_rms[r];
    acc = block_reduce_sum(acc);
    if (threadIdx.x == 0) dgamma[col] = acc;
}

void launch_rms_norm_dgamma_f32(const float* x, const float* dy, const float* inv_rms,
                                float* dgamma, int rows, int cols, void* s) {
    // cols == 0 would be a zero-sized grid (invalid launch configuration).
    // rows == 0 still launches, so dgamma is written as all zeros.
    if (cols <= 0) return;
    int blk = 256;
    rms_norm_dgamma_k<<<cols, blk, 0, (cudaStream_t)s>>>(x, dy, inv_rms, dgamma, rows, cols);
}
|
||
|
||
// =====================================================================
|
||
// SiLU: y = x * sigmoid(x). Backward: dx = dy * (sig + x*sig*(1-sig)).
|
||
// =====================================================================
|
||
|
||
// SiLU forward: y[i] = x[i] * sigmoid(x[i]) = x[i] / (1 + exp(-x[i])).
// One thread per element, 256-thread blocks; y may alias x.
__global__ void silu_k(const float* x, float* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { float xv = x[i]; y[i] = xv / (1.0f + expf(-xv)); }
}

void launch_silu_f32(const float* x, float* y, int n, void* s) {
    // Skip empty tensors: a zero-sized grid is an invalid launch configuration.
    if (n <= 0) return;
    int blk = 256, grid = (n + blk - 1) / blk;
    silu_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, n);
}
|
||
|
||
// SiLU backward from the saved input x:
//   dx[i] = dy[i] * (sig + x*sig*(1 - sig)),  sig = sigmoid(x[i])
// One thread per element, 256-thread blocks.
__global__ void silu_dx_k(const float* x, const float* dy, float* dx, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float xv = x[i];
        float sig = 1.0f / (1.0f + expf(-xv));
        dx[i] = dy[i] * (sig + xv * sig * (1.0f - sig));
    }
}

void launch_silu_dx_f32(const float* x, const float* dy, float* dx, int n, void* s) {
    // Skip empty tensors: a zero-sized grid is an invalid launch configuration.
    if (n <= 0) return;
    int blk = 256, grid = (n + blk - 1) / blk;
    silu_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, dy, dx, n);
}
|
||
|
||
// =====================================================================
|
||
// RoPE (rotate_half layout). x:[tokens, heads, head_dim]; position = token index.
|
||
// y[i] = x[i]*cos - x[i+h]*sin
|
||
// y[i+h] = x[i+h]*cos + x[i]*sin (i in [0,half), h=half_dim)
|
||
// freq[i] = theta^(-2i/head_dim); angle = pos*freq[i].
|
||
// Backward applies the transpose (= inverse) of the rotation, i.e. rotation by -angle:
|
||
// dx[i] = dy[i]*cos + dy[i+h]*sin
|
||
// dx[i+h] = dy[i+h]*cos - dy[i]*sin
|
||
// =====================================================================
|
||
|
||
// RoPE forward (rotate_half layout) on row-major x/y:[tokens, heads, head_dim].
// Grid (tokens, heads) with head_dim/2 threads per block. Thread i rotates the
// pair (i, i + head_dim/2) by angle = pos * theta^(-2i/head_dim).
//
// `period` is the sequence length. A flattened batch lays B sequences end to
// end along the `tokens` axis, so each token's RoPE position is its index
// WITHIN its own sequence, `tok % period`. With period == tokens (single
// sequence) this is the original position = row. period must be > 0.
// Assumes an even head_dim; an odd trailing element would be left unwritten.
// y may alias x, since each thread reads its two inputs before writing them.
__global__ void rope_k(const float* x, float* y, int heads, int head_dim,
                       float theta, int period) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;

    int pos = tok % period;
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);

    int base = (tok * heads + head) * head_dim;
    float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}

void launch_rope_f32(const float* x, float* y, int tokens, int heads,
                     int head_dim, float theta, int period, void* s) {
    // Empty input, or head_dim < 2 (nothing to rotate), would produce a
    // zero-sized grid or block, which CUDA rejects as an invalid launch
    // configuration. Skip the launch instead.
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;
    dim3 grid(tokens, heads);
    int blk = head_dim / 2;
    rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}
|
||
|
||
// RoPE at an absolute position offset (KV-cache decode-time, forward only).
// Same rotate_half math as rope_k, but row `tok`'s position is `pos0 + tok`
// with no modulo, so a single new decode token sits at absolute position pos0.
// The training rope_k (position = tok % period) is left untouched, so this adds
// no training-path risk. Layout x/y:[tokens, heads, head_dim]. Assumes an even
// head_dim. y may alias x.
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
                          float theta, int pos0) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;

    int pos = pos0 + tok;
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);

    int base = (tok * heads + head) * head_dim;
    float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}

void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
                        int head_dim, float theta, int pos0, void* s) {
    // Skip launches whose grid or block would be zero-sized
    // (invalid launch configuration).
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;
    dim3 grid(tokens, heads);
    int blk = head_dim / 2;
    rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
}
|
||
|
||
// RoPE with a PER-ROW absolute position (batched KV-cache decode, M2b).
// Row `tok`'s position is `positions[tok]`, one i32 per token, which must be a
// device pointer with `tokens` entries. For G-way batched decode all G rows
// share one decode position; for ragged batches each row carries its own.
// Forward only; the training rope_k is untouched. Assumes an even head_dim.
// y may alias x.
__global__ void rope_pos_k(const float* x, const int* positions, float* y,
                           int heads, int head_dim, float theta) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;

    int pos = positions[tok];
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);

    int base = (tok * heads + head) * head_dim;
    float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}

void launch_rope_pos_f32(const float* x, const int* positions, float* y,
                         int tokens, int heads, int head_dim, float theta, void* s) {
    // Skip launches whose grid or block would be zero-sized
    // (invalid launch configuration).
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;
    dim3 grid(tokens, heads);
    int blk = head_dim / 2;
    rope_pos_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, positions, y, heads, head_dim, theta);
}
|
||
|
||
// Concatenate along the sequence (middle) dim: a:[bh,ta,hd], b:[bh,tb,hd] ->
// out:[bh,ta+tb,hd], with out[:, :ta] = a and out[:, ta:] = b. This is the
// device-side KV-cache append (M2c): it keeps K/V on the GPU and grows by one
// token per step, removing the host round-trip the M2a/M2b host cache paid.
//
// One block per bh row; 256 threads copy each row's slices. The caller passes
// the per-row element counts ta_hd = ta*hd and tb_hd = tb*hd. `a` is never
// read when ta_hd == 0. out must not overlap a or b, because other blocks'
// writes could clobber rows still being read.
__global__ void cat_seq_k(const float* a, const float* b, float* out,
                          int ta_hd, int tb_hd) {
    // 64-bit row offsets: `long` is only 32 bits on LLP64 targets (Windows),
    // so `(long)i * ta_hd` was not a portable overflow guard.
    const long long row = blockIdx.x;
    const int o_hd = ta_hd + tb_hd;
    const float* ar = a + row * ta_hd;
    const float* br = b + row * tb_hd;
    float* outr = out + row * o_hd;

    for (int j = threadIdx.x; j < ta_hd; j += blockDim.x) outr[j] = ar[j];
    for (int j = threadIdx.x; j < tb_hd; j += blockDim.x) outr[ta_hd + j] = br[j];
}

void launch_cat_seq_f32(const float* a, const float* b, float* out,
                        int bh, int ta_hd, int tb_hd, void* s) {
    // bh == 0 would be a zero-sized grid (invalid launch configuration).
    if (bh <= 0) return;
    cat_seq_k<<<bh, 256, 0, (cudaStream_t)s>>>(a, b, out, ta_hd, tb_hd);
}
|
||
|
||
// Per-row scale: y[r,c] = x[r,c] * s[r] on row-major x/y:[rows, cols], with
// one block per row. Used by the GRPO (M4) policy-gradient backward, where each
// completion token's row of (probs - onehot) is scaled by its own per-token
// coefficient. There is no cross-thread reduction, so any block size works.
// y may alias x.
__global__ void scale_rows_k(const float* x, const float* s, float* y,
                             int rows, int cols) {
    int r = blockIdx.x;
    float sr = s[r];
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        y[r * cols + c] = x[r * cols + c] * sr;
}

void launch_scale_rows_f32(const float* x, const float* s, float* y,
                           int rows, int cols, void* st) {
    // rows == 0 would be a zero-sized grid (invalid launch configuration).
    if (rows <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
}
|
||
|
||
// RoPE backward: applies the transpose of rope_k's rotation (rotation by
// -angle) with the same `tok % period` positions:
//   dx[i]     = dy[i]*cos + dy[i+h]*sin
//   dx[i+h]   = dy[i+h]*cos - dy[i]*sin      (h = head_dim/2)
// Grid (tokens, heads), head_dim/2 threads. period must be > 0. dx may alias dy.
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                          float theta, int period) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;

    int pos = tok % period;
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);

    int base = (tok * heads + head) * head_dim;
    float d0 = dy[base + i], d1 = dy[base + i + half];
    dx[base + i] = d0 * c + d1 * sn;
    dx[base + i + half] = d1 * c - d0 * sn;
}

void launch_rope_dx_f32(const float* dy, float* dx, int tokens, int heads,
                        int head_dim, float theta, int period, void* s) {
    // Skip launches whose grid or block would be zero-sized
    // (invalid launch configuration).
    if (tokens <= 0 || heads <= 0 || head_dim < 2) return;
    dim3 grid(tokens, heads);
    int blk = head_dim / 2;
    rope_dx_k<<<grid, blk, 0, (cudaStream_t)s>>>(dy, dx, heads, head_dim, theta, period);
}
|
||
|
||
// =====================================================================
|
||
// Row-wise safe softmax. x:[rows,cols] → y. Backward (Jacobian):
|
||
// dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c']*y[r,c']))
|
||
// =====================================================================
|
||
|
||
// Row-wise safe softmax over row-major x:[rows, cols] -> y, one block per row:
// subtract the row max, exponentiate, then normalise by the row sum.
// Uses block_reduce_max/sum, so blockDim.x must be a whole number of warps
// (the launcher guarantees this).
// NOTE(review): a row that is entirely -INFINITY yields NaN, since
// exp(-inf - (-inf)) = NaN. Callers must ensure every row has a finite entry.
__global__ void softmax_k(const float* x, float* y, int rows, int cols) {
    int r = blockIdx.x;
    const float* xr = x + r * cols;
    float* yr = y + r * cols;

    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);

    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float e = expf(xr[c] - m);
        yr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);

    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) yr[c] *= inv;
}

void launch_softmax_f32(const float* x, float* y, int rows, int cols, void* s) {
    // Skip empty input: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0) return;
    // Capped at 1024 and rounded UP to whole warps. cols is often a sequence
    // length (arbitrary, and growing by one per decode step), and a partial
    // last warp would feed undefined lanes into the full-mask shuffles of
    // block_reduce_*. The surplus threads contribute the identity (-inf / 0).
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    softmax_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, y, rows, cols);
}
|
||
|
||
// Softmax backward from the saved output y, one block per row:
//   dx[r,c] = y[r,c] * (dy[r,c] - sum_c'(dy[r,c'] * y[r,c']))
// Uses block_reduce_sum, so blockDim.x must be a whole number of warps
// (the launcher guarantees this).
__global__ void softmax_dx_k(const float* y, const float* dy, float* dx,
                             int rows, int cols) {
    int r = blockIdx.x;
    const float* yr = y + r * cols;
    const float* dyr = dy + r * cols;
    float* dxr = dx + r * cols;

    float dot = 0.0f;  // sum_c dy*y
    for (int c = threadIdx.x; c < cols; c += blockDim.x) dot += dyr[c] * yr[c];
    dot = block_reduce_sum(dot);

    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        dxr[c] = yr[c] * (dyr[c] - dot);
}

void launch_softmax_dx_f32(const float* y, const float* dy, float* dx,
                           int rows, int cols, void* s) {
    // Skip empty input: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0) return;
    // Capped at 1024 and rounded UP to whole warps, so block_reduce_sum never
    // shuffles from lanes missing in a partial warp.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    softmax_dx_k<<<rows, blk, 0, (cudaStream_t)s>>>(y, dy, dx, rows, cols);
}
|
||
|
||
// =====================================================================
|
||
// Cross-entropy over logits x:[rows,cols] with int target per row.
|
||
// Forward writes per-row loss[r] = -log(softmax(x)[target]) and the softmax
// probs[r,:] (cached for backward). Backward: dx[r,c] = (probs[r,c]-onehot)/rows
// (mean reduction: the 1/rows of the mean loss is folded into dx through the
// scale the caller passes, scale = upstream/rows). Rows with target < 0 are
// ignored: their loss and gradient are 0.
|
||
// =====================================================================
|
||
|
||
// Cross-entropy forward over row-major logits x:[rows, cols], one block per row.
// It writes:
// - probs[r,:] = softmax(x[r,:]), kept for the backward;
// - loss[r] = -log softmax(x[r,:])[target[r]], or 0 for an ignored row
//   (target[r] < 0). Otherwise target[r] must lie in [0, cols).
//
// The loss is computed from the log-sum-exp, loss = m + log(sum) - x[t], not
// from -log(probs[t]), for two reasons:
// (1) probs[t] is normalised by whichever thread owns column t. With no barrier
//     after that loop, thread 0 could read it before the write lands (a data
//     race: unnormalised value).
// (2) It stays finite when probs[t] underflows to 0.
// Uses block_reduce_*, so blockDim.x must be a whole number of warps.
__global__ void cross_entropy_fwd_k(const float* x, const int* target,
                                    float* probs, float* loss, int rows, int cols) {
    int r = blockIdx.x;
    // 64-bit row offset: rows*cols (tokens x vocab) can exceed INT_MAX.
    long long row_off = (long long)r * cols;
    const float* xr = x + row_off;
    float* pr = probs + row_off;

    float m = -INFINITY;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) m = fmaxf(m, xr[c]);
    m = block_reduce_max(m);

    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) {
        float e = expf(xr[c] - m);
        pr[c] = e;
        sum += e;
    }
    sum = block_reduce_sum(sum);

    float inv = 1.0f / sum;
    for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;

    if (threadIdx.x == 0) {
        int t = target[r];
        // Reads only x (never written here) and block-uniform m/sum: race-free.
        loss[r] = t < 0 ? 0.0f : (m + logf(sum)) - xr[t];
    }
}

void launch_cross_entropy_fwd_f32(const float* x, const int* target,
                                  float* probs, float* loss, int rows, int cols, void* s) {
    // Skip empty input: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0) return;
    // Capped at 1024 and rounded UP to whole warps, so block_reduce_* never
    // shuffles from lanes missing in a partial warp.
    int blk = cols < 1024 ? cols : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    cross_entropy_fwd_k<<<rows, blk, 0, (cudaStream_t)s>>>(x, target, probs, loss, rows, cols);
}
|
||
|
||
// Cross-entropy backward, one thread per element of row-major dx:[rows, cols]:
//   dx[r,c] = scale * (probs[r,c] - [c == target[r]])
// and dx[r,:] = 0 for ignored rows (target[r] < 0). For the mean-reduced loss
// the caller passes scale = upstream/rows. Indices are 64-bit because
// rows*cols (tokens x vocab) can exceed INT_MAX.
__global__ void cross_entropy_dx_k(const float* probs, const int* target,
                                   float* dx, int rows, int cols, float scale) {
    long long n = (long long)rows * cols;
    long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    int r = (int)(i / cols), c = (int)(i % cols);
    int t = target[r];
    if (t < 0) {
        dx[i] = 0.0f;
    } else {
        float g = probs[i] - (c == t ? 1.0f : 0.0f);
        dx[i] = g * scale;
    }
}

void launch_cross_entropy_dx_f32(const float* probs, const int* target,
                                 float* dx, int rows, int cols, float scale, void* s) {
    // Skip empty input: a zero-sized grid is an invalid launch configuration.
    if (rows <= 0 || cols <= 0) return;
    // Element count in 64 bits. The old `int n = rows * cols` overflowed
    // (producing a negative grid) once tokens x vocab passed INT_MAX.
    long long n = (long long)rows * cols;
    int blk = 256;
    long long grid = (n + blk - 1) / blk;
    cross_entropy_dx_k<<<(unsigned int)grid, blk, 0, (cudaStream_t)s>>>(
        probs, target, dx, rows, cols, scale);
}
|
||
|
||
} // extern "C"
|