Design doc for the T7 fp32-preserving speedups: cuBLAS matmul fwd/bwd (row-major⟺col-major layout), GPU AdamW + GPU grad-norm (no per-step param/grad roundtrip), drop per-op sync + device memset. Includes the verification table (regression suite green + tok/s 2770→8220 ~3x), the deferred bf16/recompute follow-up rationale, and the T8 all-reduce note. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
10 KiB
Phase T7: Performance — Design Document
Goal
T6 把真训练打通了(TinyStories,loss 10.83→3.43,采样连贯),但吞吐只有 ~2800 tok/s。 T7 的目标是在不牺牲数值正确性的前提下把训练显著加速——上来先把 fp32 路径里那些纯开销榨干, 再视情况上 bf16 / 激活重计算。
按 xtrain.md T7 note 的优先级,保 fp32 数值不回归的三步是 must-have:
- matmul fwd/bwd 走 cuBLAS —— 前向 + 两路反向(
dA=dC·Bᵀ、dB=Aᵀ·dC)全切 cuBLASSgemm。fp32,等价于换了求和顺序的同一个 GEMM,所以正确性自动保住。 - GPU 侧优化器 + grad clip —— 干掉每步把全部参数/梯度 GPU↔host 往返的开销:AdamW 的 m/v 状态搬到 device、update 走 kernel;global grad-norm 用 device reduction,只把那一个标量取回 host。
- stream / 减 sync —— 不再每个 op 之后都
cudaDeviceSynchronize:default stream 上 kernel 本就顺序执行,host 读数据又都走 stream-ordered 的cudaMemcpy,per-op sync 是纯开销,全删。
不做(本 Phase 范围外):分布式数据并行 / NCCL all-reduce(T8)、导出回流 xserv(T9)。
降级出口:bf16 混合精度(④)/ 激活重计算(⑤)改数值、牵动每个 kernel,且本 model 太小(dim=32)属 latency-bound、bf16 tensor-core 收益有限——按 xtrain.md 的 escape hatch,①②③ 交付并实测加速后,④⑤ 记为 follow-up,不把正确性留在半截状态。详见末节。
Module Layout
csrc/ops/optim.cu # 新:GPU AdamW step + global grad sumsq reduce + in-place scale
crates/xtrain-cuda/
├── src/cublas.rs # 新:持久化 cuBLAS handle + row-major sgemm(含转置位)
├── src/ffi.rs # +CUBLAS_OP_T、+optim.cu 的三个 launch_*、+cudaMemset
├── src/memory.rs # GpuBuffer::memset(device 置零,免 H2D 零拷贝)
├── src/lib.rs # pub mod cublas(not(no_cuda))
└── build.rs # +optim.cu
crates/xtrain-tensor/
├── src/tensor.rs # matmul/matmul_backward 改走 cublas::sgemm;删 21 处 per-op sync
└── src/storage.rs # device zeros 改用 memset
crates/xtrain-optim/
├── src/lib.rs # +GpuAdamW(m/v on device,in-place update);host AdamW 留作参考
├── Cargo.toml # xtrain-cuda 升为常规依赖(GpuAdamW 要发 kernel)
└── tests/adamw_gpu.rs # 新:GPU AdamW 对 host 参考逐位一致
crates/xtrain-train/
├── src/clip.rs # +clip_grad_norm_gpu(device reduce + in-place rescale);host 版留作参考
└── src/train_loop.rs # 改用 GpuAdamW + clip_grad_norm_gpu
Key Design Decisions
① cuBLAS matmul(row-major ⟺ col-major)
cuBLAS 是 列主序,我们的张量是 行主序。一个行主序 [r,c]、leading dim = c 的矩阵交给 cuBLAS,
被读作它的转置(列主序 [c,r])。要拿到行主序结果 C[m,n] = opA(A)·opB(B),就让 cuBLAS 算它的列主序转置
Cᵀ[n,m] = opB(B)ᵀ·opA(A)ᵀ——Cᵀ 列主序的字节布局正好就是 C 行主序。
cublas::sgemm(trans_a, trans_b, m, n, k, …) 落地为:第一参 = B(op = trans_b ? N : T),第二参 = A(op = trans_a ? N : T),尺寸 (m=n, n=m, k=k),lda/ldb/ldc = 各自存储态行主序的列数:
let lda = if trans_a { m } else { k }; // A 存 [m,k] 或 [k,m]
let ldb = if trans_b { k } else { n }; // B 存 [k,n] 或 [n,k]
let ldc = n; // Cᵀ 是 [n,m] 列主序 ld=n(== 行主序 C[m,n])
trans_a=trans_b=false 这一支与 T3 测试里的 cuBLAS oracle 逐参数一致(同样 OP_N、交换顺序、m=N/n=M/k=K),所以前向天然对得上。
反向用 cuBLAS 的转置位省两个 transpose kernel:T3 版 matmul_backward 是 dc.matmul(b.transpose_2d()) + a.transpose_2d().matmul(dc),要起两个 transpose kernel + 两次分配。T7 直接:
dA[M,K] = dC[M,N] · Bᵀ → sgemm(trans_a=false, trans_b=true, m=M,n=K,k=N, a=dC, b=B)
dB[K,N] = Aᵀ · dC[M,N] → sgemm(trans_a=true, trans_b=false, m=K,n=N,k=M, a=A, b=dC)
为什么不回归:全程 fp32,cuBLAS 与手写 tiled kernel 算的是同一个 GEMM,只差求和顺序的 rounding。 所以 T3「fwd 对 cuBLAS / bwd 对 finite-diff」的容差不变,下游 autograd grad-check、PyTorch 对拍也不变。
handle 持久化:cuBLAS handle 创建很贵(T3 oracle 每次调用都 create/destroy)。改为 每线程缓存一个 handle,进程生命周期内复用(thread_local! + RefCell<Option<CublasHandle>>)。
② GPU AdamW + GPU grad-norm(去掉每步全参往返)
T6 的瓶颈之一:AdamW::step 把每个参数的 value + grad 全 D2H 拉回 host、host 上跑 AdamW、再 H2D 写回;clip_grad_norm 同理把全部 grad 拉回 host 算范数。3.26M 参数 × 每步两趟 = 大量 PCIe 往返 + 同步。
GpuAdamW:m/v 矩状态以「每参一对 device Tensor」常驻显存,update 是一个 in-place kernel——读参数的 .grad()、原地改写参数 buffer(参数 leaf 的 storage 是 Arc 共享,原地写对所有 clone 可见,leaf 身份跨步稳定,无需 set_value):
m ← β1·m + (1−β1)·g ; v ← β2·v + (1−β2)·g²
p ← p − lr·( (m/bc1) / (√(v/bc2) + ε) + wd·p ) bc1/bc2 = 1−βᵗ(host 传入)
数学与 host AdamW::step_host 逐字对应;host AdamW 原样保留作 PyTorch 对拍的参考,新增 adamw_gpu 测试拿同一组 params/grads 把 GPU 结果对 host 参考逐位比(实测 max abs err = 0)。
clip_grad_norm_gpu:sumsq_accum kernel 对每个 grad 做 block-reduce 后 atomicAdd 到一个 device 标量;只把这一个标量取回 host 求 sqrt、算 clip factor,再用 scale_inplace kernel 原地把每个 grad 乘 pre_scale·factor。整步只回传 1 个 float,不再拉全部 grad。
③ stream / 减 sync
每个 tensor op 之前 Tensor::zeros 分配输出、之后 cudaDeviceSynchronize——两处都是隐藏开销:
- per-op sync 全删(21 处):default stream 上 kernel 顺序执行;任何 host 读数据都走
to_device(Cpu)→ 阻塞且 stream-ordered 的cudaMemcpy,自然等齐前面的 kernel。所以 op 后那次显式 sync 对正确性纯属多余(只是把异步 kernel 错误提前暴露,可接受地推迟到下一次 sync/memcpy)。 - device zeros 改
cudaMemset:原来每个 op 输出都用「host 零 buffer + 阻塞 H2D memcpy」置零,那次 H2D 本身就是个 per-op 同步点 + 一次拷贝;换成 device 端cudaMemset(default stream 上异步,不串行化 stream)。
once-per-step 的 sync(clip 取范数前、AdamW step 末尾)保留——量级是每步一次,非每 op。
CUDA-graph capture 是 optional bonus,本 Phase 未做。
验证方法
两道闸都要过:
A. 数值不回归(fp32 容差不变,全绿)——dash5 实跑:
| 测试 | 闸 | 结果 |
|---|---|---|
| T3 GEMM(fwd vs cuBLAS / bwd vs finite-diff) | rel-err 容差不变 | 5/5 ok |
| T4 autograd grad-check(每 op finite-diff) | ≤2e-2 不变 | 12/12 ok |
| T5 结构 grad-check + overfit + PyTorch 对拍 | overfit 27/27、logits relerr、21 参梯度 rtol 不变 | overfit 2.821→0.004 (27/27);parity logits relerr 1.5e-4、21 grads OK |
| T6 AdamW vs torch + checkpoint round-trip | 轨迹/终参 rtol 不变、逐位一致 | AdamW relerr 4.6e-6;ckpt logit diff 0.0 |
| T7 GPU AdamW vs host 参考 | 逐位一致 | max abs err 0.0 |
B. 吞吐提升——同 model/config(dim 32、4 层、vocab 50257、seq 64、batch 8、~3.26M 参),60 步计时取稳态:
| 步骤 | tok/s | 备注 |
|---|---|---|
| baseline (T6) | ~2770 | 起点 |
| ① cuBLAS matmul | ~3310 | matmul 非主瓶颈(model 小、latency-bound) |
| ② GPU AdamW + grad-norm | ~4070 | 去掉每步全参 GPU↔host 往返 |
| ③ drop per-op sync + memset | ~8220 | 删 21 处 per-op sync 是大头 |
端到端(real_training 800 步,新快路):~8500 tok/s 稳态,loss 10.81→3.90(avg10),采样
Once upon a time, there was a little girl named Lily. She was very happy to play with her mom.——
收敛与 T6 fp32 同轨。
净加速 ~3×,零数值回归。
④⑤ Follow-up(本 Phase 未做,记给后续)
- ④ bf16 混合精度(fp32 master):matmul/激活走 bf16、optimizer 持 fp32 master 拷贝。本 model dim=32 太小、属 launch/latency-bound,bf16 tensor-core 算力收益有限,唯一够大的
lm_head [64,32]@[32,50257]主要吃带宽;且 bf16 改数值、要单独加宽容差 + 重验收敛,风险/收益此规模下不划算。等模型放大或上 T8 多卡再做更值。 - ⑤ 激活重计算:反向重算 block 激活省显存。当前单序列、显存不紧,优先级低。
两者按 escape hatch 推迟,①②③ 的 fp32 加速已完整交付且全测绿。
T8 衔接(数据并行 all-reduce)
T7 之后梯度常驻 device(.grad() 是 device tensor),优化器 update 也全在 device——这正好对接 T8 的 NCCL 数据并行:
- 各 rank 本地
backward后,梯度已在显存里,直接对params的.grad()张量 all-reduce(无需先拉回 host)。 - all-reduce 取 均值后,每 rank 各自跑
GpuAdamW.step——因为各 rank 梯度一致、优化器状态从相同 init 同步演化,参数自然保持一致(无需再同步参数)。 - grad clip 的 global-norm 在 all-reduce 之后算:
clip_grad_norm_gpu的sumsq_accum已是 device reduction,多卡只需把那个标量再 all-reduce 一次(或对已 all-reduce 的梯度本地算,因梯度已一致,结果天然相同)。