perf: GPU AdamW + grad-norm
Eliminate the per-step GPU↔host roundtrip of every parameter/gradient. - optim.cu: adamw_step (m/v on device, in-place param update), sumsq_accum (block-reduced global grad sum-of-squares), scale_inplace. - GpuAdamW: device m/v state per param; step launches the kernel reading each param's .grad() and rewriting the param buffer in place — no host roundtrip. Host AdamW kept as the torch-parity reference. - clip_grad_norm_gpu: device sum-of-squares reduction (only the scalar norm comes back), in-place rescale of grads by pre_scale·clip_factor. - train_loop: use GpuAdamW + clip_grad_norm_gpu. - test: GPU AdamW vs host reference parity (max abs err < 1e-6). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -34,6 +34,7 @@ fn main() {
|
||||
.file("../../csrc/ops/gemm.cu")
|
||||
.file("../../csrc/ops/nn.cu")
|
||||
.file("../../csrc/ops/model.cu")
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -212,6 +212,34 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
|
||||
// the global grad-norm reduction + in-place rescale (Phase T7).
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// One in-place AdamW step over a parameter tensor of `n` elements. `bc1`/`bc2`
|
||||
// are the bias-correction denominators 1-beta^t.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn launch_adamw_step_f32(
|
||||
p: *mut f32,
|
||||
g: *const f32,
|
||||
m: *mut f32,
|
||||
v: *mut f32,
|
||||
lr: f32,
|
||||
b1: f32,
|
||||
b2: f32,
|
||||
eps: f32,
|
||||
wd: f32,
|
||||
bc1: f32,
|
||||
bc2: f32,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// acc += sum_i g[i]^2 (acc is one f32 on device, pre-zeroed). atomicAdd.
|
||||
pub fn launch_sumsq_accum_f32(g: *const f32, acc: *mut f32, n: i32, s: CudaStream);
|
||||
// In-place scalar scale: x[i] *= factor.
|
||||
pub fn launch_scale_inplace_f32(x: *mut f32, factor: f32, n: i32, s: CudaStream);
|
||||
}
|
||||
|
||||
// cuBLAS — the production GEMM backend (Phase T7) and the correctness oracle the
|
||||
// T3 GEMM tests still compare against. Declared (and linked, see build.rs) only
|
||||
// when CUDA is compiled in.
|
||||
|
||||
@@ -113,7 +113,7 @@ impl AdamW {
|
||||
mod gpu {
|
||||
use super::AdamW;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
|
||||
impl AdamW {
|
||||
/// Apply one AdamW step to every parameter `Var`, using `lr` for this step
|
||||
@@ -125,6 +125,10 @@ mod gpu {
|
||||
///
|
||||
/// Does NOT zero grads — the caller does that (matching the GD-step
|
||||
/// template in the T5 overfit test).
|
||||
///
|
||||
/// This is the host-roundtrip reference path; training uses
|
||||
/// [`GpuAdamW`] (kernel, m/v on device). Both are checked against the
|
||||
/// torch parity in tests.
|
||||
pub fn step(&mut self, lr: f32, params: &[Var]) {
|
||||
let device = params[0].value().device();
|
||||
let shapes: Vec<Vec<usize>> =
|
||||
@@ -151,4 +155,93 @@ mod gpu {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU AdamW (Phase T7): the optimizer state (m/v moments) lives on the device
|
||||
/// as one tensor pair per parameter, and the update runs as a CUDA kernel that
|
||||
/// reads each param's `.grad()` and rewrites the param buffer in place — no
|
||||
/// per-step GPU↔host roundtrip of params/grads. Same math as
|
||||
/// [`AdamW::step_host`] (the parity reference).
|
||||
pub struct GpuAdamW {
|
||||
beta1: f32,
|
||||
beta2: f32,
|
||||
eps: f32,
|
||||
weight_decay: f32,
|
||||
t: u64,
|
||||
/// Per-parameter (m, v) device buffers, sized lazily on first step.
|
||||
state: Vec<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl GpuAdamW {
|
||||
/// PyTorch-default betas/eps; you set lr (per-step) + weight decay.
|
||||
pub fn new(weight_decay: f32) -> Self {
|
||||
Self {
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
eps: 1e-8,
|
||||
weight_decay,
|
||||
t: 0,
|
||||
state: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn step_count(&self) -> u64 {
|
||||
self.t
|
||||
}
|
||||
|
||||
/// One in-place AdamW step over every parameter `Var` at learning rate
|
||||
/// `lr`. Updates the param value buffer and the device m/v state via the
|
||||
/// `adamw_step_f32` kernel. Params are mutated in place, so the leaf `Var`
|
||||
/// identities stay stable across steps (no `set_value`). Does NOT zero
|
||||
/// grads — the caller does. A param without a grad is skipped this step.
|
||||
pub fn step(&mut self, lr: f32, params: &[Var]) {
|
||||
let device = params[0].value().device();
|
||||
if self.state.is_empty() {
|
||||
self.state = params
|
||||
.iter()
|
||||
.map(|p| {
|
||||
let shape = p.value().shape().to_vec();
|
||||
(
|
||||
Tensor::zeros(&shape, DType::F32, device),
|
||||
Tensor::zeros(&shape, DType::F32, device),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
assert_eq!(self.state.len(), params.len(), "param count changed");
|
||||
|
||||
self.t += 1;
|
||||
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
|
||||
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
|
||||
|
||||
for (p, (m, v)) in params.iter().zip(&self.state) {
|
||||
let g = match p.grad() {
|
||||
Some(g) => g,
|
||||
None => continue,
|
||||
};
|
||||
let pv = p.value();
|
||||
let n = pv.numel() as i32;
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_adamw_step_f32(
|
||||
pv.data_ptr() as *mut f32,
|
||||
g.data_ptr() as *const f32,
|
||||
m.data_ptr() as *mut f32,
|
||||
v.data_ptr() as *mut f32,
|
||||
lr,
|
||||
self.beta1,
|
||||
self.beta2,
|
||||
self.eps,
|
||||
self.weight_decay,
|
||||
bc1,
|
||||
bc2,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("adamw step sync failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use gpu::GpuAdamW;
|
||||
|
||||
76
crates/xtrain-optim/tests/adamw_gpu.rs
Normal file
76
crates/xtrain-optim/tests/adamw_gpu.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
// GPU AdamW parity (Phase T7): the device-side AdamW kernel (m/v on device, no
|
||||
// host roundtrip) must produce the same update as the host reference
|
||||
// `AdamW::step_host` given identical params + grads across several steps with a
|
||||
// varying lr. This is the new correctness gate for the GPU optimizer; the host
|
||||
// path itself is already pinned to PyTorch by xtrain-train's adamw_parity test.
|
||||
//
|
||||
// Gated #![cfg(not(no_cuda))] (runs on dash5; needs a GPU to link + launch).
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_optim::{AdamW, GpuAdamW};
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
fn grad(step: usize, idx: usize, j: usize) -> f32 {
|
||||
let s = (step * 13 + idx * 7 + j * 3) as f32;
|
||||
(s * 0.123).sin() * 0.5
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gpu_adamw_matches_host() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let dev = Device::Cuda(0);
|
||||
|
||||
let wd = 0.1f32;
|
||||
// Two params of different sizes (exercises per-param device state).
|
||||
let shapes: Vec<Vec<usize>> = vec![vec![2, 2], vec![3]];
|
||||
let init: Vec<Vec<f32>> = vec![vec![0.5, -1.0, 2.0, 0.0], vec![1.5, -0.25, 0.75]];
|
||||
|
||||
// GPU side: leaf Vars on device.
|
||||
let params: Vec<Var> = init
|
||||
.iter()
|
||||
.zip(&shapes)
|
||||
.map(|(d, s)| Var::leaf(Tensor::from_slice(d, s).to_device(dev)))
|
||||
.collect();
|
||||
let mut gpu_opt = GpuAdamW::new(wd);
|
||||
|
||||
// Host reference.
|
||||
let mut host_params = init.clone();
|
||||
let mut host_opt = AdamW::new(0.0, wd);
|
||||
|
||||
for step in 0..15 {
|
||||
let lr = 0.01 + 0.001 * step as f32; // varying lr
|
||||
let grads: Vec<Vec<f32>> = shapes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, s)| {
|
||||
let n: usize = s.iter().product();
|
||||
(0..n).map(|j| grad(step, idx, j)).collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Push grads onto the GPU Vars, run the device step, then clear.
|
||||
for (p, (g, s)) in params.iter().zip(grads.iter().zip(&shapes)) {
|
||||
p.zero_grad();
|
||||
Var::push_grad(p, Tensor::from_slice(g, s).to_device(dev));
|
||||
}
|
||||
gpu_opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
|
||||
host_opt.step_host(lr, &mut host_params, &grads);
|
||||
}
|
||||
|
||||
let mut max_err = 0.0f32;
|
||||
for (p, hp) in params.iter().zip(&host_params) {
|
||||
let got = p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
for (a, b) in got.iter().zip(hp) {
|
||||
max_err = max_err.max((a - b).abs());
|
||||
}
|
||||
}
|
||||
println!("gpu vs host AdamW: max abs err = {max_err:.3e}");
|
||||
assert!(max_err < 1e-6, "GPU AdamW diverged from host: {max_err:e}");
|
||||
}
|
||||
@@ -73,8 +73,61 @@ mod gpu {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
mod gpu_norm {
|
||||
use super::clip_scale;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// GPU-side global-norm grad clip (Phase T7): compute the joint L2 norm of all
|
||||
/// `pre_scale`-applied grads with a device reduction, then rescale every grad
|
||||
/// in place by `pre_scale·clip_factor` — no per-step grad roundtrip to host
|
||||
/// (only the single scalar norm comes back). Returns the post-pre_scale total
|
||||
/// norm. Params without a grad contribute 0 and are skipped on rescale.
|
||||
pub fn clip_grad_norm_gpu(params: &[Var], max_norm: f32, pre_scale: f32) -> f32 {
|
||||
let device = params[0].value().device();
|
||||
// sum-of-squares of the RAW grads accumulated on device.
|
||||
let acc = Tensor::zeros(&[1], DType::F32, device);
|
||||
for p in params {
|
||||
if let Some(g) = p.grad() {
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_sumsq_accum_f32(
|
||||
g.data_ptr() as *const f32,
|
||||
acc.data_ptr() as *mut f32,
|
||||
g.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("grad-norm reduce sync failed");
|
||||
let raw_sumsq = acc.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
// Norm of the pre_scale-applied grads = pre_scale · sqrt(raw_sumsq).
|
||||
let total = pre_scale * raw_sumsq.max(0.0).sqrt();
|
||||
let factor = pre_scale * clip_scale(total, max_norm);
|
||||
if (factor - 1.0).abs() >= f32::EPSILON {
|
||||
for p in params {
|
||||
if let Some(g) = p.grad() {
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_scale_inplace_f32(
|
||||
g.data_ptr() as *mut f32,
|
||||
factor,
|
||||
g.numel() as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
xtrain_cuda::device::synchronize().expect("grad rescale sync failed");
|
||||
}
|
||||
total
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use gpu::clip_grad_norm;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use gpu_norm::clip_grad_norm_gpu;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -13,11 +13,11 @@ use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
use xtrain_model::{TinyTransformer, ids_tensor};
|
||||
use xtrain_optim::AdamW;
|
||||
use xtrain_optim::GpuAdamW;
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
use crate::checkpoint;
|
||||
use crate::clip::clip_grad_norm;
|
||||
use crate::clip::clip_grad_norm_gpu;
|
||||
use crate::data::Corpus;
|
||||
use crate::schedule::LrSchedule;
|
||||
|
||||
@@ -47,7 +47,7 @@ pub fn train(
|
||||
cfg: &TrainConfig,
|
||||
) -> Vec<f32> {
|
||||
let params = model.params();
|
||||
let mut opt = AdamW::new(cfg.schedule.max_lr, cfg.weight_decay);
|
||||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||||
let mut rng = cfg.seed;
|
||||
let mut losses = Vec::with_capacity(cfg.steps);
|
||||
let inv_batch = 1.0 / cfg.batch_size as f32;
|
||||
@@ -72,7 +72,7 @@ pub fn train(
|
||||
losses.push(step_loss);
|
||||
|
||||
// Average the summed grads (×1/batch) and clip to the global norm.
|
||||
let gnorm = clip_grad_norm(¶ms, cfg.max_grad_norm, inv_batch);
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, inv_batch);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
|
||||
86
csrc/ops/optim.cu
Normal file
86
csrc/ops/optim.cu
Normal file
@@ -0,0 +1,86 @@
|
||||
// GPU-side optimizer kernels (Phase T7): AdamW parameter update and the
|
||||
// global grad-norm reduction + rescale. These eliminate the per-step GPU↔host
|
||||
// roundtrip of every parameter/gradient that the T6 host AdamW + host clip did.
|
||||
//
|
||||
// All F32, row-major, contiguous. The math mirrors xtrain-optim::AdamW::step_host
|
||||
// (the reference); bias correction is passed in as bc1/bc2 = 1 - beta^t.
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// One AdamW step over a single parameter tensor of `n` elements, in place.
|
||||
// m ← b1·m + (1-b1)·g
|
||||
// v ← b2·v + (1-b2)·g²
|
||||
// p ← p − lr·( (m/bc1) / (sqrt(v/bc2) + eps) + wd·p )
|
||||
// `m`/`v` are this parameter's moment buffers (persisted on device across steps).
|
||||
__global__ void adamw_step_f32(
|
||||
float* p, const float* g, float* m, float* v,
|
||||
float lr, float b1, float b2, float eps, float wd,
|
||||
float bc1, float bc2, int n
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= n) return;
|
||||
float gi = g[idx];
|
||||
float mi = b1 * m[idx] + (1.0f - b1) * gi;
|
||||
float vi = b2 * v[idx] + (1.0f - b2) * gi * gi;
|
||||
m[idx] = mi;
|
||||
v[idx] = vi;
|
||||
float mhat = mi / bc1;
|
||||
float vhat = vi / bc2;
|
||||
p[idx] -= lr * (mhat / (sqrtf(vhat) + eps) + wd * p[idx]);
|
||||
}
|
||||
|
||||
void launch_adamw_step_f32(
|
||||
float* p, const float* g, float* m, float* v,
|
||||
float lr, float b1, float b2, float eps, float wd,
|
||||
float bc1, float bc2, int n, void* stream
|
||||
) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
adamw_step_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
p, g, m, v, lr, b1, b2, eps, wd, bc1, bc2, n);
|
||||
}
|
||||
|
||||
// Accumulate sum-of-squares of one gradient tensor into *acc (a single f32 on
|
||||
// device, pre-zeroed by the caller). Block-reduces then one atomicAdd per block.
|
||||
__global__ void sumsq_accum_f32(const float* g, float* acc, int n) {
|
||||
__shared__ float shared[32];
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
float v = (tid < n) ? g[tid] * g[tid] : 0.0f;
|
||||
// block reduce
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1) v += __shfl_down_sync(0xffffffff, v, off);
|
||||
if (lane == 0) shared[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? shared[threadIdx.x] : 0.0f;
|
||||
if (warp == 0) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1) v += __shfl_down_sync(0xffffffff, v, off);
|
||||
if (lane == 0) atomicAdd(acc, v);
|
||||
}
|
||||
}
|
||||
|
||||
void launch_sumsq_accum_f32(const float* g, float* acc, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
sumsq_accum_f32<<<grid, block, 0, (cudaStream_t)stream>>>(g, acc, n);
|
||||
}
|
||||
|
||||
// Scale one tensor in place by a scalar (used to apply pre_scale·clip_factor to
|
||||
// each gradient). Same as scale_f32 but in place.
|
||||
__global__ void scale_inplace_f32(float* x, float factor, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) x[idx] *= factor;
|
||||
}
|
||||
|
||||
void launch_scale_inplace_f32(float* x, float factor, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
scale_inplace_f32<<<grid, block, 0, (cudaStream_t)stream>>>(x, factor, n);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user