post-train: M2 — decode primitives (rope_at + decode_attention)
Two forward-only Tensor primitives the KV-cache decode engine is built on, each gated by an isolated correctness test:

- rope_at(theta, pos0): RoPE at an absolute position (pos = pos0 + row, no modulo) for a single decode token; the training rope_k (pos = row % period) is left untouched. New forward-only CUDA kernel, no training-path risk. Gate: bit-identical to the corresponding row of the full-sequence rope.
- decode_attention(k, v, scale): single-query × cached-K/V SDPA, composed from the existing strided batched GEMM + plain (non-causal) softmax — no new kernel. Gate: equals the last query row of the full causal attention (max |Δ| 6e-8).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -242,6 +242,33 @@ void launch_rope_f32(const float* x, float* y, int tokens, int heads,
|
||||
rope_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, period);
|
||||
}
|
||||
|
||||
// RoPE at an absolute position offset (KV-cache decode-time, forward only).
//
// Uses the rotate_half pairing: element i of a head is rotated together with
// element i + head_dim/2. Row `tok` is rotated for position `pos0 + tok` (no
// modulo), so a single new decode token sits at absolute position pos0. The
// training rope_k (position = tok % period) is left untouched, so this adds no
// training-path risk.
//
// Launch layout:
//   blockIdx.x  -> token row
//   blockIdx.y  -> head
//   threadIdx.x -> pair index i in [0, head_dim/2); extra threads exit early
//
// x and y are indexed as a contiguous [token][head][head_dim] array. Each
// thread reads both of its inputs before writing its two outputs and no other
// thread touches that pair. When head_dim is odd the last element of each head
// is neither read nor written.
//
// Parameters:
//   x, y      input / output buffers (device pointers)
//   heads     number of heads per token row
//   head_dim  elements per head
//   theta     RoPE base; pair i uses frequency theta^(-2i / head_dim)
//   pos0      absolute position assigned to row 0
__global__ void rope_at_k(const float* x, float* y, int heads, int head_dim,
                          float theta, int pos0) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    int half = head_dim / 2;
    int i = threadIdx.x;
    if (i >= half) return;

    // Absolute position of this row: no wrap-around, unlike the training kernel.
    int pos = pos0 + tok;

    // The float expressions below are kept in exactly this form so the result
    // stays comparable with the full-sequence rope kernel.
    float freq = powf(theta, -(float)(2 * i) / (float)head_dim);
    float angle = (float)pos * freq;
    float c = cosf(angle), sn = sinf(angle);

    // Flat offset of this (token, head) slice. Computed in 64 bits so that
    // tokens * heads * head_dim >= 2^31 does not overflow a 32-bit int.
    size_t base = ((size_t)tok * heads + head) * head_dim;

    float x0 = x[base + i], x1 = x[base + i + half];
    y[base + i] = x0 * c - x1 * sn;
    y[base + i + half] = x1 * c + x0 * sn;
}
|
||||
// Host-side launcher for rope_at_k on fp32 data.
//
// Launches one block per (token, head) pair with head_dim/2 threads, one thread
// per rotated element pair, on the stream passed in `s`.
//
// Parameters:
//   x, y      input / output buffers, forwarded to the kernel unchanged
//   tokens    number of token rows (grid.x)
//   heads     number of heads per row (grid.y)
//   head_dim  elements per head; head_dim/2 is used directly as the block size
//   theta     RoPE base, forwarded to the kernel
//   pos0      absolute position of row 0, forwarded to the kernel
//   s         stream handle, cast to cudaStream_t
//
// The launch result is not checked here; a failed launch is only visible to a
// later CUDA error query by the caller.
// NOTE(review): head_dim/2 must fit the device's per-block thread limit and
// heads must fit the grid.y limit — neither is validated here; confirm callers
// stay within range.
void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
                        int head_dim, float theta, int pos0, void* s) {
    int blk = head_dim / 2;

    // A grid or block dimension of zero is an invalid launch configuration and
    // would leave a CUDA error pending. With no pairs to rotate there is
    // nothing to do, so return without launching.
    if (tokens <= 0 || heads <= 0 || blk <= 0) return;

    dim3 grid(tokens, heads);
    rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
}
|
||||
|
||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||
float theta, int period) {
|
||||
int tok = blockIdx.x;
|
||||
|
||||
Reference in New Issue
Block a user