// Eliminate the per-step GPU↔host roundtrip of every parameter/gradient.
//
// - optim.cu: adamw_step (m/v on device, in-place param update), sumsq_accum
//   (block-reduced global grad sum-of-squares), scale_inplace.
// - GpuAdamW: device m/v state per param; step launches the kernel, which reads
//   each param's .grad() and rewrites the param buffer in place — no host
//   roundtrip. The host AdamW is kept as the torch-parity reference.
// - clip_grad_norm_gpu: device sum-of-squares reduction (only the scalar norm
//   comes back), then an in-place rescale of grads by pre_scale·clip_factor.
// - train_loop: use GpuAdamW + clip_grad_norm_gpu.
// - test: GPU AdamW vs. host reference parity (max abs err < 1e-6).
//
// Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
87 lines
3.2 KiB
Plaintext
87 lines
3.2 KiB
Plaintext
// GPU-side optimizer kernels (Phase T7): the AdamW parameter update and the
// global grad-norm reduction + rescale. These eliminate the per-step GPU↔host
// roundtrip of every parameter/gradient that the T6 host AdamW + host clip
// incurred.
//
// All F32, row-major, contiguous. The math mirrors xtrain-optim::AdamW::step_host
// (the reference); bias correction is passed in precomputed as
// bc1 = 1 - beta1^t and bc2 = 1 - beta2^t.
|
||
|
||
#include <math.h>
|
||
|
||
extern "C" {
|
||
|
||
// Applies a single AdamW update, in place, to one parameter tensor holding `n`
// floats. For every element, with gradient g:
//   m ← b1·m + (1−b1)·g
//   v ← b2·v + (1−b2)·g²
//   p ← p − lr·( (m/bc1) / (sqrt(v/bc2) + eps) + wd·p )
//
// Buffers:
//   p     parameter values; read and overwritten with the updated values.
//   g     gradient values; read only.
//   m, v  first/second moment buffers for this parameter; read and overwritten
//         with the new moments (meant to persist on device across steps).
// Scalars: lr (learning rate), b1/b2 (moment decay rates), eps (added to the
// denominator), wd (decoupled weight decay), bc1/bc2 (bias-correction divisors
// supplied by the caller).
//
// Launch layout: 1-D grid of 1-D blocks, one thread per element. Threads whose
// flat index is >= n touch nothing, so the grid may overshoot `n`.
__global__ void adamw_step_f32(
    float* p, const float* g, float* m, float* v,
    float lr, float b1, float b2, float eps, float wd,
    float bc1, float bc2, int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float grad = g[i];

        // Exponential moving averages of the gradient and its square.
        float m_new = b1 * m[i] + (1.0f - b1) * grad;
        float v_new = b2 * v[i] + (1.0f - b2) * grad * grad;
        m[i] = m_new;
        v[i] = v_new;

        // Bias-corrected moments.
        float m_hat = m_new / bc1;
        float v_hat = v_new / bc2;

        // Adam direction plus decoupled weight decay, both scaled by lr.
        p[i] -= lr * (m_hat / (sqrtf(v_hat) + eps) + wd * p[i]);
    }
}
|
||
|
||
// Host-side launcher for adamw_step_f32: enqueues one AdamW update over the `n`
// elements of p/g/m/v on `stream`, using 256-thread blocks and enough blocks to
// cover every element.
//
// `stream` is an opaque handle that is cast to cudaStream_t. All other
// arguments are forwarded to the kernel unchanged. The launch is asynchronous
// and its status is not queried here, so any launch error remains available to
// the caller through cudaGetLastError().
//
// n <= 0 is a no-op: nothing is launched and no buffer is touched.
void launch_adamw_step_f32(
    float* p, const float* g, float* m, float* v,
    float lr, float b1, float b2, float eps, float wd,
    float bc1, float bc2, int n, void* stream
) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // tensor must skip the launch instead of leaving a pending CUDA error.
    if (n <= 0) return;

    const int block = 256;
    // Ceil-div in 64-bit: `n + block - 1` would overflow int once n is within
    // block-1 of INT_MAX.
    const unsigned int grid =
        (unsigned int)(((long long)n + block - 1) / block);

    adamw_step_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
        p, g, m, v, lr, b1, b2, eps, wd, bc1, bc2, n);
}
|
||
|
||
// Adds the sum of squares of g[0..n) to the single float at `acc`, which lives
// on the device and must already hold the running total (zero for the first
// tensor; the caller is responsible for clearing it).
//
// Each thread squares at most one element (threads with flat index >= n
// contribute 0). Partials are summed per warp with shuffles, the per-warp sums
// are combined through shared memory by warp 0, and the block then issues
// exactly one atomicAdd on `acc`.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per element. blockDim.x
// must be a multiple of 32, because the shuffles name all 32 lanes in their
// mask, and at most 1024, because there are only 32 per-warp slots.
__global__ void sumsq_accum_f32(const float* g, float* acc, int n) {
    __shared__ float warp_sums[32];  // one partial per warp in the block

    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    const int lane_id = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int warp_count = (blockDim.x + 31) >> 5;

    // Per-thread contribution: g² inside the tensor, 0 in the grid tail.
    float partial = 0.0f;
    if (gid < n) {
        const float x = g[gid];
        partial = x * x;
    }

    // Stage 1: tree-reduce within each warp; lane 0 ends up with the warp sum.
    #pragma unroll
    for (int delta = 16; delta > 0; delta >>= 1) {
        partial += __shfl_down_sync(0xffffffff, partial, delta);
    }
    if (lane_id == 0) warp_sums[warp_id] = partial;

    // Every warp's slot must be written before warp 0 reads them.
    __syncthreads();

    // Stage 2: warp 0 folds the per-warp sums; unused slots count as 0.
    if (warp_id == 0) {
        float total = (lane_id < warp_count) ? warp_sums[lane_id] : 0.0f;
        #pragma unroll
        for (int delta = 16; delta > 0; delta >>= 1) {
            total += __shfl_down_sync(0xffffffff, total, delta);
        }
        if (lane_id == 0) atomicAdd(acc, total);
    }
}
|
||
|
||
// Host-side launcher for sumsq_accum_f32: enqueues, on `stream`, the
// accumulation of the sum of squares of g[0..n) into the device float `*acc`.
// Uses 256-thread blocks (a multiple of 32, as the kernel's warp shuffles
// require) and enough blocks to cover every element.
//
// `stream` is an opaque handle that is cast to cudaStream_t. The launch is
// asynchronous and its status is not queried here, so any launch error remains
// available to the caller through cudaGetLastError().
//
// n <= 0 is a no-op: an empty tensor adds nothing, so `*acc` is left as is.
void launch_sumsq_accum_f32(const float* g, float* acc, int n, void* stream) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // tensor must skip the launch instead of leaving a pending CUDA error.
    if (n <= 0) return;

    const int block = 256;
    // Ceil-div in 64-bit: `n + block - 1` would overflow int once n is within
    // block-1 of INT_MAX.
    const unsigned int grid =
        (unsigned int)(((long long)n + block - 1) / block);

    sumsq_accum_f32<<<grid, block, 0, (cudaStream_t)stream>>>(g, acc, n);
}
|
||
|
||
// Multiplies each of the `n` floats in `x` by `factor`, overwriting `x`. Used to
// apply pre_scale·clip_factor to a gradient tensor. The original note describes
// it as the in-place counterpart of scale_f32, which is not defined in this
// part of the file.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per element. Threads whose
// flat index is >= n do nothing.
__global__ void scale_inplace_f32(float* x, float factor, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    x[i] = x[i] * factor;
}
|
||
|
||
// Host-side launcher for scale_inplace_f32: enqueues, on `stream`, the in-place
// multiplication of x[0..n) by `factor`, using 256-thread blocks and enough
// blocks to cover every element.
//
// `stream` is an opaque handle that is cast to cudaStream_t. The launch is
// asynchronous and its status is not queried here, so any launch error remains
// available to the caller through cudaGetLastError().
//
// n <= 0 is a no-op: nothing is launched and `x` is not touched.
void launch_scale_inplace_f32(float* x, float factor, int n, void* stream) {
    // A grid of zero blocks is an invalid launch configuration, so an empty
    // tensor must skip the launch instead of leaving a pending CUDA error.
    if (n <= 0) return;

    const int block = 256;
    // Ceil-div in 64-bit: `n + block - 1` would overflow int once n is within
    // block-1 of INT_MAX.
    const unsigned int grid =
        (unsigned int)(((long long)n + block - 1) / block);

    scale_inplace_f32<<<grid, block, 0, (cudaStream_t)stream>>>(x, factor, n);
}
|
||
|
||
}
|