post-train: M2 — decode primitives (rope_at + decode_attention)

Two forward-only Tensor primitives the KV-cache decode engine is built on,
each gated by an isolated correctness test:

- rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no
  modulo), as needed for a single decode token. The training rope_k (pos =
  row % period) is left untouched. New forward-only CUDA kernel, so no
  training-path risk.
  Gate: bit-identical to the corresponding row of the full-sequence rope.
- decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from
  the existing strided batched GEMM and plain (non-causal) softmax, so no new
  kernel. Gate: matches the last query row of full causal attention
  (max |Δ| 6e-8).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 12:00:03 +08:00
parent 1574e21d89
commit c88e2ab88c
4 changed files with 245 additions and 0 deletions

@@ -242,6 +242,33 @@ void launch_rope_f32(const float* x, float* y, int tokens, int heads,
    rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
}
// RoPE at an absolute position offset (KV-cache decode-time, forward only). Same
// rotate_half as rope_k, but row `tok`'s position is `pos0 + tok` (no modulo) —
// a single new decode token sits at absolute position pos0. The training rope_k
// (position = tok % period) is left untouched, so this adds no training-path risk.
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
                          float theta, int pos0) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;
    int pos = pos0 + tok;
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);
    int base = (tok * heads + head) * head_dim;
    float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}
void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
                        int head_dim, float theta, int pos0, void* s) {
    dim3 grid(tokens, heads);
    int blk = head_dim / 2;
    rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
}
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                          float theta, int period) {
    int tok = blockIdx.x;