post-train: M4 — clipped_pg_loss + scale_rows (GRPO policy-gradient op)
The GRPO (M4) token-level loss op + the one primitive it needs:

- scale_rows(x[r,c], s[r]): per-row scale (new ~5-line CUDA kernel). The
  clipped-PG backward scales each completion token's row of (probs − onehot)
  by its own per-token coefficient, which cross_entropy_backward's single
  scalar scale can't express.
- clipped_pg_loss(logits, target, logp_old, logp_ref, A, eps, beta):
  per-token ρ_t = exp(logπθ_t − logp_old_t),
  L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL (k3 estimator),
  masked to completion tokens. Backward reuses the CE machinery
  (probs − onehot) + scale_rows.

Gates: grad-check the active PG path + the A=0 (KL-only) path; degenerate
value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
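The loss and the per-token backward coefficient described in the message can be sketched as a CPU reference for value and gradient checks. This is a minimal sketch under stated assumptions, not the commit's implementation: the names `PgToken`, `clipped_pg_loss_ref`, and `pg_row_coef` are hypothetical, the inputs are per-token log-probabilities rather than logits, "mean" is assumed to divide by the number of completion tokens, and the k3 estimator is assumed to be exp(d) − d − 1 with d = logp_ref − logπθ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-token inputs; the real op derives logp from logits + target.
struct PgToken {
    float logp;        // log pi_theta(target_t)
    float logp_old;    // log-prob under the sampling policy
    float logp_ref;    // log-prob under the frozen reference policy
    float adv;         // advantage A
    bool  completion;  // true for completion tokens, false for prompt/padding
};

// L = -mean min(rho*A, clip(rho, 1-eps, 1+eps)*A) + beta * mean k3,
// averaged over completion tokens only (assumption).
float clipped_pg_loss_ref(const std::vector<PgToken>& toks, float eps, float beta) {
    assert(eps >= 0.0f);
    double sum = 0.0;
    int n = 0;
    for (const PgToken& t : toks) {
        if (!t.completion) continue;
        const float rho = std::exp(t.logp - t.logp_old);
        const float clipped = std::min(std::max(rho, 1.0f - eps), 1.0f + eps);
        const float surrogate = std::min(rho * t.adv, clipped * t.adv);
        const float d = t.logp_ref - t.logp;
        const float k3 = std::exp(d) - d - 1.0f;  // k3 KL estimator (assumed form)
        sum += -surrogate + beta * k3;
        ++n;
    }
    return n > 0 ? static_cast<float>(sum / n) : 0.0f;
}

// Per-token coefficient s_t with dL/dlogits[t,:] = s_t * (probs - onehot).
// Because dlogp/dlogits = onehot - probs, s_t = -dL/dlogp_t.
std::vector<float> pg_row_coef(const std::vector<PgToken>& toks, float eps, float beta) {
    assert(eps >= 0.0f);
    int n = 0;
    for (const PgToken& t : toks) n += t.completion ? 1 : 0;
    std::vector<float> s(toks.size(), 0.0f);  // non-completion rows stay 0
    if (n == 0) return s;
    for (std::size_t i = 0; i < toks.size(); ++i) {
        const PgToken& t = toks[i];
        if (!t.completion) continue;
        const float rho = std::exp(t.logp - t.logp_old);
        // The surrogate has zero gradient where the clipped term is the min.
        const bool clipped_out = (t.adv > 0.0f && rho > 1.0f + eps) ||
                                 (t.adv < 0.0f && rho < 1.0f - eps);
        const float dpg = clipped_out ? 0.0f : rho * t.adv;       // d surrogate / d logp
        const float dkl = 1.0f - std::exp(t.logp_ref - t.logp);   // d k3 / d logp
        s[i] = (dpg - beta * dkl) / static_cast<float>(n);
    }
    return s;
}
```

Under these assumptions, a very large `eps` with `beta = 0` reduces the loss to the importance-weighted vanilla PG value, and `adv = 0` leaves only the KL term in the coefficient, which matches the two degenerate gates listed in the message.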
@@ -269,6 +269,23 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
     rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
 }
 
+// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
+// (M4) policy-gradient backward, where each completion token's row of
+// (probs − onehot) is scaled by its own per-token coefficient.
+__global__ void scale_rows_k(const float* x, const float* s, float* y,
+                             int rows, int cols) {
+    int r = blockIdx.x;
+    float sr = s[r];
+    for (int c = threadIdx.x; c < cols; c += blockDim.x)
+        y[r * cols + c] = x[r * cols + c] * sr;
+}
+void launch_scale_rows_f32(const float* x, const float* s, float* y,
+                           int rows, int cols, void* st) {
+    int blk = cols < 1024 ? cols : 1024;
+    if (blk < 32) blk = 32;
+    scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
+}
+
 __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                           float theta, int period) {
     int tok = blockIdx.x;
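A CPU reference is a common way to validate a small kernel like `scale_rows_k`. The following is a minimal sketch of one, assuming row-major layout as in the kernel; `scale_rows_ref` is a hypothetical helper that does not appear in the commit.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for y[r,c] = x[r,c] * s[r] on a row-major [rows, cols] buffer.
// Hypothetical helper, intended only for comparison against the kernel output.
std::vector<float> scale_rows_ref(const std::vector<float>& x,
                                  const std::vector<float>& s,
                                  std::size_t rows, std::size_t cols) {
    assert(x.size() == rows * cols);
    assert(s.size() == rows);
    std::vector<float> y(x.size());
    for (std::size_t r = 0; r < rows; ++r) {
        const float sr = s[r];  // one coefficient per row, as in the kernel
        for (std::size_t c = 0; c < cols; ++c)
            // size_t indexing stays valid when rows * cols exceeds INT_MAX
            y[r * cols + c] = x[r * cols + c] * sr;
    }
    return y;
}
```

The kernel computes `r * cols + c` in `int`, so a buffer with more than 2^31 − 1 elements (for example, many tokens times a large vocabulary) would overflow that index. Whether this matters depends on the shapes this project actually uses, which the diff does not show.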