perf: GPU AdamW + grad-norm

Eliminate the per-step GPU↔host roundtrip of every parameter/gradient.

- optim.cu: adamw_step (m/v on device, in-place param update), sumsq_accum
  (block-reduced global grad sum-of-squares), scale_inplace.
- GpuAdamW: device m/v state per param; step launches the kernel reading
  each param's .grad() and rewriting the param buffer in place — no host
  roundtrip. Host AdamW kept as the torch-parity reference.
- clip_grad_norm_gpu: device sum-of-squares reduction (only the scalar norm
  comes back), in-place rescale of grads by pre_scale·clip_factor.
- train_loop: use GpuAdamW + clip_grad_norm_gpu.
- test: GPU AdamW vs host reference parity (max abs err < 1e-6).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:53:09 +08:00
parent 0e5c7d22e2
commit b0e397ca81
7 changed files with 342 additions and 5 deletions

View File

@@ -34,6 +34,7 @@ fn main() {
.file("../../csrc/ops/gemm.cu")
.file("../../csrc/ops/nn.cu")
.file("../../csrc/ops/model.cu")
.file("../../csrc/ops/optim.cu")
.compile("xtrain_cuda_kernels");
}

View File

@@ -212,6 +212,34 @@ unsafe extern "C" {
);
}
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
// the global grad-norm reduction + in-place rescale (Phase T7).
#[cfg(not(no_cuda))]
unsafe extern "C" {
// One in-place AdamW step over a parameter tensor of `n` elements. `bc1`/`bc2`
// are the bias-correction denominators 1-beta^t.
#[allow(clippy::too_many_arguments)]
pub fn launch_adamw_step_f32(
p: *mut f32,
g: *const f32,
m: *mut f32,
v: *mut f32,
lr: f32,
b1: f32,
b2: f32,
eps: f32,
wd: f32,
bc1: f32,
bc2: f32,
n: i32,
s: CudaStream,
);
// acc += sum_i g[i]^2 (acc is one f32 on device, pre-zeroed). atomicAdd.
pub fn launch_sumsq_accum_f32(g: *const f32, acc: *mut f32, n: i32, s: CudaStream);
// In-place scalar scale: x[i] *= factor.
pub fn launch_scale_inplace_f32(x: *mut f32, factor: f32, n: i32, s: CudaStream);
}
// cuBLAS — the production GEMM backend (Phase T7) and the correctness oracle the
// T3 GEMM tests still compare against. Declared (and linked, see build.rs) only
// when CUDA is compiled in.

View File

@@ -113,7 +113,7 @@ impl AdamW {
mod gpu {
use super::AdamW;
use xtrain_autodiff::tape::Var;
use xtrain_tensor::{Device, Tensor};
use xtrain_tensor::{DType, Device, Tensor};
impl AdamW {
/// Apply one AdamW step to every parameter `Var`, using `lr` for this step
@@ -125,6 +125,10 @@ mod gpu {
///
/// Does NOT zero grads — the caller does that (matching the GD-step
/// template in the T5 overfit test).
///
/// This is the host-roundtrip reference path; training uses
/// [`GpuAdamW`] (kernel, m/v on device). Both are checked against the
/// torch parity in tests.
pub fn step(&mut self, lr: f32, params: &[Var]) {
let device = params[0].value().device();
let shapes: Vec<Vec<usize>> =
@@ -151,4 +155,93 @@ mod gpu {
}
}
}
/// GPU AdamW (Phase T7): the optimizer state (m/v moments) lives on the device
/// as one tensor pair per parameter, and the update runs as a CUDA kernel that
/// reads each param's `.grad()` and rewrites the param buffer in place — no
/// per-step GPU↔host roundtrip of params/grads. Same math as
/// [`AdamW::step_host`] (the parity reference).
pub struct GpuAdamW {
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
t: u64,
/// Per-parameter (m, v) device buffers, sized lazily on first step.
state: Vec<(Tensor, Tensor)>,
}
impl GpuAdamW {
/// PyTorch-default betas/eps; you set lr (per-step) + weight decay.
pub fn new(weight_decay: f32) -> Self {
Self {
beta1: 0.9,
beta2: 0.999,
eps: 1e-8,
weight_decay,
t: 0,
state: Vec::new(),
}
}
pub fn step_count(&self) -> u64 {
self.t
}
/// One in-place AdamW step over every parameter `Var` at learning rate
/// `lr`. Updates the param value buffer and the device m/v state via the
/// `adamw_step_f32` kernel. Params are mutated in place, so the leaf `Var`
/// identities stay stable across steps (no `set_value`). Does NOT zero
/// grads — the caller does. A param without a grad is skipped this step.
pub fn step(&mut self, lr: f32, params: &[Var]) {
let device = params[0].value().device();
if self.state.is_empty() {
self.state = params
.iter()
.map(|p| {
let shape = p.value().shape().to_vec();
(
Tensor::zeros(&shape, DType::F32, device),
Tensor::zeros(&shape, DType::F32, device),
)
})
.collect();
}
assert_eq!(self.state.len(), params.len(), "param count changed");
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
for (p, (m, v)) in params.iter().zip(&self.state) {
let g = match p.grad() {
Some(g) => g,
None => continue,
};
let pv = p.value();
let n = pv.numel() as i32;
unsafe {
xtrain_cuda::ffi::launch_adamw_step_f32(
pv.data_ptr() as *mut f32,
g.data_ptr() as *const f32,
m.data_ptr() as *mut f32,
v.data_ptr() as *mut f32,
lr,
self.beta1,
self.beta2,
self.eps,
self.weight_decay,
bc1,
bc2,
n,
std::ptr::null_mut(),
);
}
}
xtrain_cuda::device::synchronize().expect("adamw step sync failed");
}
}
}
#[cfg(not(no_cuda))]
pub use gpu::GpuAdamW;

View File

@@ -0,0 +1,76 @@
// GPU AdamW parity (Phase T7): the device-side AdamW kernel (m/v on device, no
// host roundtrip) must produce the same update as the host reference
// `AdamW::step_host` given identical params + grads across several steps with a
// varying lr. This is the new correctness gate for the GPU optimizer; the host
// path itself is already pinned to PyTorch by xtrain-train's adamw_parity test.
//
// Gated #![cfg(not(no_cuda))] (runs on dash5; needs a GPU to link + launch).
#![cfg(not(no_cuda))]
use xtrain_autodiff::tape::Var;
use xtrain_cuda::device;
use xtrain_optim::{AdamW, GpuAdamW};
use xtrain_tensor::{Device, Tensor};
fn grad(step: usize, idx: usize, j: usize) -> f32 {
let s = (step * 13 + idx * 7 + j * 3) as f32;
(s * 0.123).sin() * 0.5
}
#[test]
fn gpu_adamw_matches_host() {
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let dev = Device::Cuda(0);
let wd = 0.1f32;
// Two params of different sizes (exercises per-param device state).
let shapes: Vec<Vec<usize>> = vec![vec![2, 2], vec![3]];
let init: Vec<Vec<f32>> = vec![vec![0.5, -1.0, 2.0, 0.0], vec![1.5, -0.25, 0.75]];
// GPU side: leaf Vars on device.
let params: Vec<Var> = init
.iter()
.zip(&shapes)
.map(|(d, s)| Var::leaf(Tensor::from_slice(d, s).to_device(dev)))
.collect();
let mut gpu_opt = GpuAdamW::new(wd);
// Host reference.
let mut host_params = init.clone();
let mut host_opt = AdamW::new(0.0, wd);
for step in 0..15 {
let lr = 0.01 + 0.001 * step as f32; // varying lr
let grads: Vec<Vec<f32>> = shapes
.iter()
.enumerate()
.map(|(idx, s)| {
let n: usize = s.iter().product();
(0..n).map(|j| grad(step, idx, j)).collect()
})
.collect();
// Push grads onto the GPU Vars, run the device step, then clear.
for (p, (g, s)) in params.iter().zip(grads.iter().zip(&shapes)) {
p.zero_grad();
Var::push_grad(p, Tensor::from_slice(g, s).to_device(dev));
}
gpu_opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
host_opt.step_host(lr, &mut host_params, &grads);
}
let mut max_err = 0.0f32;
for (p, hp) in params.iter().zip(&host_params) {
let got = p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec();
for (a, b) in got.iter().zip(hp) {
max_err = max_err.max((a - b).abs());
}
}
println!("gpu vs host AdamW: max abs err = {max_err:.3e}");
assert!(max_err < 1e-6, "GPU AdamW diverged from host: {max_err:e}");
}

View File

@@ -73,8 +73,61 @@ mod gpu {
}
}
#[cfg(not(no_cuda))]
mod gpu_norm {
use super::clip_scale;
use xtrain_autodiff::tape::Var;
use xtrain_tensor::{DType, Device, Tensor};
/// GPU-side global-norm grad clip (Phase T7): compute the joint L2 norm of all
/// `pre_scale`-applied grads with a device reduction, then rescale every grad
/// in place by `pre_scale·clip_factor` — no per-step grad roundtrip to host
/// (only the single scalar norm comes back). Returns the post-pre_scale total
/// norm. Params without a grad contribute 0 and are skipped on rescale.
pub fn clip_grad_norm_gpu(params: &[Var], max_norm: f32, pre_scale: f32) -> f32 {
let device = params[0].value().device();
// sum-of-squares of the RAW grads accumulated on device.
let acc = Tensor::zeros(&[1], DType::F32, device);
for p in params {
if let Some(g) = p.grad() {
unsafe {
xtrain_cuda::ffi::launch_sumsq_accum_f32(
g.data_ptr() as *const f32,
acc.data_ptr() as *mut f32,
g.numel() as i32,
std::ptr::null_mut(),
);
}
}
}
xtrain_cuda::device::synchronize().expect("grad-norm reduce sync failed");
let raw_sumsq = acc.to_device(Device::Cpu).as_slice::<f32>()[0];
// Norm of the pre_scale-applied grads = pre_scale · sqrt(raw_sumsq).
let total = pre_scale * raw_sumsq.max(0.0).sqrt();
let factor = pre_scale * clip_scale(total, max_norm);
if (factor - 1.0).abs() >= f32::EPSILON {
for p in params {
if let Some(g) = p.grad() {
unsafe {
xtrain_cuda::ffi::launch_scale_inplace_f32(
g.data_ptr() as *mut f32,
factor,
g.numel() as i32,
std::ptr::null_mut(),
);
}
}
}
xtrain_cuda::device::synchronize().expect("grad rescale sync failed");
}
total
}
}
#[cfg(not(no_cuda))]
pub use gpu::clip_grad_norm;
#[cfg(not(no_cuda))]
pub use gpu_norm::clip_grad_norm_gpu;
#[cfg(test)]
mod tests {

View File

@@ -13,11 +13,11 @@ use std::path::PathBuf;
use std::time::Instant;
use xtrain_model::{TinyTransformer, ids_tensor};
use xtrain_optim::AdamW;
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use crate::checkpoint;
use crate::clip::clip_grad_norm;
use crate::clip::clip_grad_norm_gpu;
use crate::data::Corpus;
use crate::schedule::LrSchedule;
@@ -47,7 +47,7 @@ pub fn train(
cfg: &TrainConfig,
) -> Vec<f32> {
let params = model.params();
let mut opt = AdamW::new(cfg.schedule.max_lr, cfg.weight_decay);
let mut opt = GpuAdamW::new(cfg.weight_decay);
let mut rng = cfg.seed;
let mut losses = Vec::with_capacity(cfg.steps);
let inv_batch = 1.0 / cfg.batch_size as f32;
@@ -72,7 +72,7 @@ pub fn train(
losses.push(step_loss);
// Average the summed grads (×1/batch) and clip to the global norm.
let gnorm = clip_grad_norm(&params, cfg.max_grad_norm, inv_batch);
let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, inv_batch);
opt.step(lr, &params);
for p in &params {
p.zero_grad();

86
csrc/ops/optim.cu Normal file
View File

@@ -0,0 +1,86 @@
// GPU-side optimizer kernels (Phase T7): AdamW parameter update and the
// global grad-norm reduction + rescale. These eliminate the per-step GPU↔host
// roundtrip of every parameter/gradient that the T6 host AdamW + host clip did.
//
// All F32, row-major, contiguous. The math mirrors xtrain-optim::AdamW::step_host
// (the reference); bias correction is passed in as bc1/bc2 = 1 - beta^t.
#include <math.h>
extern "C" {
// One AdamW step over a single parameter tensor of `n` elements, in place.
// m ← b1·m + (1-b1)·g
// v ← b2·v + (1-b2)·g²
// p ← p lr·( (m/bc1) / (sqrt(v/bc2) + eps) + wd·p )
// `m`/`v` are this parameter's moment buffers (persisted on device across steps).
__global__ void adamw_step_f32(
float* p, const float* g, float* m, float* v,
float lr, float b1, float b2, float eps, float wd,
float bc1, float bc2, int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= n) return;
float gi = g[idx];
float mi = b1 * m[idx] + (1.0f - b1) * gi;
float vi = b2 * v[idx] + (1.0f - b2) * gi * gi;
m[idx] = mi;
v[idx] = vi;
float mhat = mi / bc1;
float vhat = vi / bc2;
p[idx] -= lr * (mhat / (sqrtf(vhat) + eps) + wd * p[idx]);
}
void launch_adamw_step_f32(
float* p, const float* g, float* m, float* v,
float lr, float b1, float b2, float eps, float wd,
float bc1, float bc2, int n, void* stream
) {
int block = 256;
int grid = (n + block - 1) / block;
adamw_step_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
p, g, m, v, lr, b1, b2, eps, wd, bc1, bc2, n);
}
// Accumulate sum-of-squares of one gradient tensor into *acc (a single f32 on
// device, pre-zeroed by the caller). Block-reduces then one atomicAdd per block.
__global__ void sumsq_accum_f32(const float* g, float* acc, int n) {
__shared__ float shared[32];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
float v = (tid < n) ? g[tid] * g[tid] : 0.0f;
// block reduce
int lane = threadIdx.x & 31;
int warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
#pragma unroll
for (int off = 16; off > 0; off >>= 1) v += __shfl_down_sync(0xffffffff, v, off);
if (lane == 0) shared[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : 0.0f;
if (warp == 0) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1) v += __shfl_down_sync(0xffffffff, v, off);
if (lane == 0) atomicAdd(acc, v);
}
}
void launch_sumsq_accum_f32(const float* g, float* acc, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
sumsq_accum_f32<<<grid, block, 0, (cudaStream_t)stream>>>(g, acc, n);
}
// Scale one tensor in place by a scalar (used to apply pre_scale·clip_factor to
// each gradient). Same as scale_f32 but in place.
__global__ void scale_inplace_f32(float* x, float factor, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) x[idx] *= factor;
}
void launch_scale_inplace_f32(float* x, float factor, int n, void* stream) {
int block = 256;
int grid = (n + block - 1) / block;
scale_inplace_f32<<<grid, block, 0, (cudaStream_t)stream>>>(x, factor, n);
}
}