Files
xtrain/csrc/ops/optim.cu
Gahow Wang b0e397ca81 perf: GPU AdamW + grad-norm
Eliminate the per-step GPU↔host roundtrip of every parameter/gradient.

- optim.cu: adamw_step (m/v on device, in-place param update), sumsq_accum
  (block-reduced global grad sum-of-squares), scale_inplace.
- GpuAdamW: device m/v state per param; step launches the kernel reading
  each param's .grad() and rewriting the param buffer in place — no host
  roundtrip. Host AdamW kept as the torch-parity reference.
- clip_grad_norm_gpu: device sum-of-squares reduction (only the scalar norm
  comes back), in-place rescale of grads by pre_scale·clip_factor.
- train_loop: use GpuAdamW + clip_grad_norm_gpu.
- test: GPU AdamW vs host reference parity (max abs err < 1e-6).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:53:09 +08:00

87 lines
3.2 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// GPU-side optimizer kernels (Phase T7): AdamW parameter update and the
// global grad-norm reduction + rescale. These eliminate the per-step GPU↔host
// roundtrip of every parameter/gradient that the T6 host AdamW + host clip did.
//
// All F32, row-major, contiguous. The math mirrors xtrain-optim::AdamW::step_host
// (the reference); bias correction is passed in as bc1/bc2 = 1 - beta^t.
#include <math.h>
extern "C" {
// One AdamW step over a single parameter tensor of `n` elements, in place.
//   m <- b1*m + (1-b1)*g
//   v <- b2*v + (1-b2)*g^2
//   p <- p - lr*( (m/bc1) / (sqrt(v/bc2) + eps) + wd*p )
// `m`/`v` are this parameter's moment buffers (persisted on device across steps).
// bc1/bc2 are the bias-correction terms 1 - beta^t, precomputed by the caller.
// The weight-decay term uses the value of p from before this update.
//
// Launch: 1-D grid of 1-D blocks of any size; the grid-stride loop covers all
// n elements, and the index is 64-bit so tensors near INT_MAX elements cannot
// overflow it. No shared memory; each element is handled by exactly one thread.
__global__ void adamw_step_f32(
    float* p, const float* g, float* m, float* v,
    float lr, float b1, float b2, float eps, float wd,
    float bc1, float bc2, int n
) {
    const long long stride = (long long)gridDim.x * blockDim.x;
    for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        float gi = g[i];
        float mi = b1 * m[i] + (1.0f - b1) * gi;
        float vi = b2 * v[i] + (1.0f - b2) * gi * gi;
        m[i] = mi;
        v[i] = vi;
        float mhat = mi / bc1;
        float vhat = vi / bc2;
        float pi = p[i];
        p[i] = pi - lr * (mhat / (sqrtf(vhat) + eps) + wd * pi);
    }
}
// Host launcher for adamw_step_f32 on `stream` (a cudaStream_t passed as void*;
// a null handle means the default stream). The launch is asynchronous and no
// error is checked here; the caller checks cudaGetLastError() / a later sync.
// n <= 0 is a no-op: a 0-block grid is an invalid launch configuration.
void launch_adamw_step_f32(
    float* p, const float* g, float* m, float* v,
    float lr, float b1, float b2, float eps, float wd,
    float bc1, float bc2, int n, void* stream
) {
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div without forming n + block - 1, which overflows for n near INT_MAX.
    const int grid = n / block + (n % block != 0);
    adamw_step_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
        p, g, m, v, lr, b1, b2, eps, wd, bc1, bc2, n);
}
// Accumulate the sum of squares of one gradient tensor into *acc, a single f32
// on device that the caller zeroes beforehand. Each thread sums its
// grid-stride share of g[i]^2. The block reduces those partials with warp
// shuffles plus one shared-memory pass, then thread 0 issues a single
// atomicAdd per block. Blocks may reach the atomic in any order, so the
// result can differ in the last bits between runs.
//
// Launch: 1-D grid of 1-D blocks of any count. blockDim.x must be a multiple
// of 32, because the full-warp shuffle mask assumes every lane is live. It
// must also be at most 1024, since there is one shared slot per warp.
// The 64-bit index keeps tensors near INT_MAX elements from overflowing it.
__global__ void sumsq_accum_f32(const float* g, float* acc, int n) {
    __shared__ float warp_sums[32];
    float v = 0.0f;
    const long long stride = (long long)gridDim.x * blockDim.x;
    for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        float gi = g[i];
        v += gi * gi;
    }
    const int lane = threadIdx.x & 31;
    const int warp = threadIdx.x >> 5;
    const int nwarps = (blockDim.x + 31) >> 5;
    // Intra-warp tree reduction; lane 0 ends up holding the warp total.
#pragma unroll
    for (int off = 16; off > 0; off >>= 1) v += __shfl_down_sync(0xffffffff, v, off);
    if (lane == 0) warp_sums[warp] = v;
    // Every thread reaches this barrier (no early return above).
    __syncthreads();
    if (warp == 0) {
        // Warp 0 reduces the per-warp totals; unused slots contribute zero.
        v = (lane < nwarps) ? warp_sums[lane] : 0.0f;
#pragma unroll
        for (int off = 16; off > 0; off >>= 1) v += __shfl_down_sync(0xffffffff, v, off);
        if (lane == 0) atomicAdd(acc, v);
    }
}
// Host launcher for sumsq_accum_f32 on `stream` (a cudaStream_t passed as void*;
// a null handle means the default stream). The launch is asynchronous and no
// error is checked here. n <= 0 is a no-op: an empty tensor adds nothing, and
// a 0-block grid is an invalid launch configuration.
void launch_sumsq_accum_f32(const float* g, float* acc, int n, void* stream) {
    if (n <= 0) return;
    const int block = 256;  // multiple of 32 and <= 1024, as the kernel requires
    // Ceil-div without forming n + block - 1, which overflows for n near INT_MAX.
    const int grid = n / block + (n % block != 0);
    sumsq_accum_f32<<<grid, block, 0, (cudaStream_t)stream>>>(g, acc, n);
}
// Scale one tensor of `n` f32 elements in place by `factor`. The original
// author's note: this applies pre_scale*clip_factor to each gradient, and it
// is the in-place variant of scale_f32, which is defined outside this file.
//
// Launch: 1-D grid of 1-D blocks of any size; the grid-stride loop covers all
// n elements, and the 64-bit index cannot overflow near INT_MAX elements.
// No shared memory.
__global__ void scale_inplace_f32(float* x, float factor, int n) {
    const long long stride = (long long)gridDim.x * blockDim.x;
    for (long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        x[i] *= factor;
    }
}
// Host launcher for scale_inplace_f32 on `stream` (a cudaStream_t passed as
// void*; a null handle means the default stream). The launch is asynchronous
// and no error is checked here. n <= 0 is a no-op: a 0-block grid is an
// invalid launch configuration.
void launch_scale_inplace_f32(float* x, float factor, int n, void* stream) {
    if (n <= 0) return;
    const int block = 256;
    // Ceil-div without forming n + block - 1, which overflows for n near INT_MAX.
    const int grid = n / block + (n % block != 0);
    scale_inplace_f32<<<grid, block, 0, (cudaStream_t)stream>>>(x, factor, n);
}
}