Compare commits
26 Commits
31cc2bf745
...
4b6d3e0a79
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b6d3e0a79 | |||
| c36cdf74d1 | |||
| f26db882e5 | |||
| 9e958cb0f9 | |||
| 80fafa1914 | |||
| e625aa05dd | |||
| 5eb27783f8 | |||
| 1fdd0c5002 | |||
| 6b8c1e4e0f | |||
| 8bd7db16e1 | |||
| b06b553f99 | |||
| abe5ceb913 | |||
| 7a03b0054a | |||
| d01fec6639 | |||
| 9064ced4c2 | |||
| d217f4fbd3 | |||
| 4d7b69f8d4 | |||
| 9b05f4f93f | |||
| c0f0b67510 | |||
| 80602099dc | |||
| f38beb0346 | |||
| 01fb22d114 | |||
| 5f3b81ac96 | |||
| 0e20821633 | |||
| 326a6fadfe | |||
| 65a2264227 |
11
README.md
11
README.md
@@ -50,9 +50,18 @@ Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`]
|
||||
| **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** |
|
||||
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
|
||||
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
|
||||
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) |
|
||||
| **T16** | **gradient accumulation** (`--accum-steps`; DDP all-reduces only at the boundary) | equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%) |
|
||||
| **T18** | **dropout** (hand counter-based device RNG + mask, inverted scaling, train/eval switch) | fixed-seed grad-check; **p=0 bit-identical**; recompute-safe |
|
||||
|
||||
The four performance fixes (T10–T13) each removed a real bottleneck — see
|
||||
[`docs/known-issues.md`](docs/known-issues.md).
|
||||
[`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14–)**
|
||||
revisits hand-writing deferred training-stack features: T14 = the fused
|
||||
flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md));
|
||||
T16 = micro-batch gradient accumulation ([`docs/15-grad-accum.md`](docs/15-grad-accum.md)),
|
||||
which decouples the effective batch from activation memory (memory tracks the micro-batch,
|
||||
not N×); T18 = dropout ([`docs/17-dropout.md`](docs/17-dropout.md), hand counter-based
|
||||
device RNG + mask, inverted scaling, train/eval switch).
|
||||
|
||||
## The scaling study — v0 → v8
|
||||
|
||||
|
||||
@@ -140,6 +140,31 @@ pub fn swiglu(gate: &Var, up: &Var) -> Var {
|
||||
mul(&silu(gate), up)
|
||||
}
|
||||
|
||||
/// Dropout (Phase T18). With probability `p` zero each element, scale the kept
|
||||
/// ones by `1/(1-p)` (inverted dropout — `E[out] == x`). The keep/drop mask is
|
||||
/// drawn by a counter-based RNG from `(seed, element index)`, so it is fully
|
||||
/// determined by `seed` (same `seed` ⇒ same mask: stable across the T13 recompute
|
||||
/// re-run, and held fixed across the ± perturbation of a finite-diff grad-check).
|
||||
/// Forward caches the per-element scale `mask`; **backward applies the same mask**
|
||||
/// (`dx = d ⊙ mask`), making dropout a fixed elementwise linear map of `x`.
|
||||
///
|
||||
/// `p == 0` is a no-op: returns `x.clone()` (no node added) so the default graph
|
||||
/// is bit-identical to the no-dropout path. eval-time identity is handled by the
|
||||
/// caller simply not invoking dropout (the model's train/eval switch).
|
||||
pub fn dropout(x: &Var, p: f32, seed: u64) -> Var {
|
||||
if p == 0.0 {
|
||||
return x.clone();
|
||||
}
|
||||
let (out, mask) = x.value().dropout(p, seed);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
Var::push_grad(&parents[0], Tensor::dropout_backward(d, &mask));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]` with per-sequence position
|
||||
/// `row % period` (`period` = sequence length; `period == tokens` for a single
|
||||
/// sequence). Orthogonal map, so the backward is the inverse rotation of `dy` — no
|
||||
@@ -325,6 +350,32 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
)
|
||||
}
|
||||
|
||||
/// Fused FLASH causal scaled-dot-product attention (Phase T14). Same interface as
|
||||
/// [`attention`] (`q`,`k`,`v` each `[bh, seq, head_dim]`), but the forward is a
|
||||
/// SINGLE fused kernel with an online softmax over KV tiles — the `[bh,seq,seq]`
|
||||
/// score matrix is NEVER materialized, and backward caches only the per-row
|
||||
/// logsumexp (O(N)) instead of the whole probs (O(N²)). Mathematically the same
|
||||
/// SDPA, so it matches the composed [`attention`] within fp/bf16 tolerance.
|
||||
/// Opt-in via the model's `--flash` flag; the composed path stays the default.
|
||||
pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
let (out, lse) = q.value().flash_attention(&k.value(), &v.value(), scale);
|
||||
let out_cache = out.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![q.clone(), k.clone(), v.clone()],
|
||||
Box::new(move |dout, parents| {
|
||||
let q = parents[0].value();
|
||||
let k = parents[1].value();
|
||||
let v = parents[2].value();
|
||||
let (dq, dk, dv) =
|
||||
Tensor::flash_attention_backward(&q, &k, &v, &out_cache, &lse, dout, scale);
|
||||
Var::push_grad(&parents[0], dq);
|
||||
Var::push_grad(&parents[1], dk);
|
||||
Var::push_grad(&parents[2], dv);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
|
||||
@@ -625,6 +625,247 @@ fn attention_batched_bwd() {
|
||||
);
|
||||
}
|
||||
|
||||
// ---- fused FLASH causal attention (the T14 op) ----
|
||||
// Same structure + dimensions as attention_batched_bwd (bh=2,seq=5,hd=6), but
|
||||
// exercises ops::flash_attention. Grad-check dq/dk/dv against finite-diff of
|
||||
// L=sum(W∘out). This is the SINGLE-tile regime (seq<FA_TILE=32), matching the
|
||||
// trusted composed grad-check's clean near-zero behavior; the MULTI-tile online-
|
||||
// softmax path (seq>FA_TILE) is validated against the already-grad-checked
|
||||
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40) — sharper than
|
||||
// finite-diff, which is unreliable on the near-zero grad elements a long softmax
|
||||
// produces.
|
||||
#[test]
|
||||
fn flash_attention_batched_bwd() {
|
||||
require_gpu();
|
||||
let (bh, seq, hd) = (2, 5, 6);
|
||||
let n = bh * seq * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
// Scale Q/K up so the softmax is non-uniform (sharper attention) → the dQ/dK
|
||||
// gradients are well-conditioned, not the near-zero saddle values a uniform
|
||||
// softmax produces (those make central finite-diff give spurious 0.0 / sign
|
||||
// flips that aren't backward bugs — cf. flash_bwd_matches_composed_bwd).
|
||||
let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
|
||||
let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
|
||||
let v_h = fill(n, 243);
|
||||
let w = fill(n, 244);
|
||||
|
||||
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
|
||||
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
|
||||
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
|
||||
let out = ops::flash_attention(&q, &k, &v, scale);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dq = q.grad().unwrap().to_device(Device::Cpu);
|
||||
let dk = k.grad().unwrap().to_device(Device::Cpu);
|
||||
let dv = v.grad().unwrap().to_device(Device::Cpu);
|
||||
|
||||
let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
|
||||
let qv = cuda(qh, &[bh, seq, hd]);
|
||||
let kv = cuda(kh, &[bh, seq, hd]);
|
||||
let vv = cuda(vh, &[bh, seq, hd]);
|
||||
let (o, _) = qv.flash_attention(&kv, &vv, scale);
|
||||
weighted_sum(&o, &w)
|
||||
};
|
||||
// Attention dQ/dK carry softmax curvature; for the small grad magnitudes here
|
||||
// a larger eps (2e-3) cuts the f32 rounding term (∝|L|/eps) that dominates the
|
||||
// O(eps²) truncation on a ~4e-4 grad. (dV is exactly linear → cfg_linear.)
|
||||
let cfg_attn = GradCheckConfig {
|
||||
eps: 2e-3,
|
||||
rel_tol: 3e-2,
|
||||
atol: 1e-3,
|
||||
};
|
||||
let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
|
||||
let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
|
||||
report(
|
||||
"flash dQ",
|
||||
&grad_check(&q_h, &[bh, seq, hd], &lq, dq.as_slice::<f32>(), cfg_attn),
|
||||
);
|
||||
let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
|
||||
let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
|
||||
report(
|
||||
"flash dK",
|
||||
&grad_check(&k_h, &[bh, seq, hd], &lk, dk.as_slice::<f32>(), cfg_attn),
|
||||
);
|
||||
let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
|
||||
let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
|
||||
report(
|
||||
"flash dV",
|
||||
&grad_check(
|
||||
&v_h,
|
||||
&[bh, seq, hd],
|
||||
&lv,
|
||||
dv.as_slice::<f32>(),
|
||||
cfg_linear(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// flash forward must equal the composed attention forward (same SDPA math).
|
||||
#[test]
|
||||
fn flash_matches_composed_fwd() {
|
||||
require_gpu();
|
||||
let (bh, seq, hd) = (2, 40, 16);
|
||||
let n = bh * seq * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let q = cuda(&fill(n, 341), &[bh, seq, hd]);
|
||||
let k = cuda(&fill(n, 342), &[bh, seq, hd]);
|
||||
let v = cuda(&fill(n, 343), &[bh, seq, hd]);
|
||||
let (oc, _) = q.attention(&k, &v, scale);
|
||||
let (of, _) = q.flash_attention(&k, &v, scale);
|
||||
let oc = oc.to_device(Device::Cpu);
|
||||
let of = of.to_device(Device::Cpu);
|
||||
let max_rel = oc
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.zip(of.as_slice::<f32>())
|
||||
.map(|(c, f)| (c - f).abs() / (c.abs() + 1e-6))
|
||||
.fold(0.0f32, f32::max);
|
||||
println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
|
||||
assert!(
|
||||
max_rel < 1e-4,
|
||||
"flash fwd diverges from composed: {max_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
// flash backward must equal the (already grad-checked) composed backward. This is
|
||||
// a sharper test than finite-diff: both share the trusted composed forward as the
|
||||
// reference, so it isolates the flash bwd dQ/dK/dV math from finite-diff noise on
|
||||
// near-zero gradient elements.
|
||||
#[test]
|
||||
fn flash_bwd_matches_composed_bwd() {
|
||||
require_gpu();
|
||||
let (bh, seq, hd) = (2, 40, 16);
|
||||
let n = bh * seq * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let q_h = fill(n, 441);
|
||||
let k_h = fill(n, 442);
|
||||
let v_h = fill(n, 443);
|
||||
let w = fill(n, 444);
|
||||
|
||||
let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
|
||||
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
|
||||
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
|
||||
let out = if flash {
|
||||
ops::flash_attention(&q, &k, &v, scale)
|
||||
} else {
|
||||
ops::attention(&q, &k, &v, scale)
|
||||
};
|
||||
scalar_loss(&out, &w).backward();
|
||||
let g = |x: &Var| {
|
||||
x.grad()
|
||||
.unwrap()
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
};
|
||||
(g(&q), g(&k), g(&v))
|
||||
};
|
||||
let (cq, ck, cv) = run(false);
|
||||
let (fq, fk, fv) = run(true);
|
||||
let maxrel = |a: &[f32], b: &[f32]| -> f32 {
|
||||
a.iter()
|
||||
.zip(b)
|
||||
.map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
|
||||
.fold(0.0f32, f32::max)
|
||||
};
|
||||
let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
|
||||
println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
|
||||
assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
|
||||
assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
|
||||
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
|
||||
}
|
||||
|
||||
// ---- dropout (Phase T18) ----
|
||||
//
|
||||
// Fixed-seed finite-diff grad-check. Under a fixed `seed` the mask is constant
|
||||
// (it depends only on (seed, index), NOT on x), so dropout is a fixed elementwise
|
||||
// linear map `out_i = c_i·x_i` and the central difference of L is differentiable:
|
||||
// the ± perturbation of each x_i sees the SAME mask. The forward function in the
|
||||
// closure calls `ops::dropout(x, p, SEED)` with the same SEED, so it reproduces
|
||||
// the same mask both times.
|
||||
#[test]
|
||||
fn dropout_bwd() {
|
||||
require_gpu();
|
||||
const SEED: u64 = 0xD120_FE5E;
|
||||
let p = 0.3f32;
|
||||
let (m, n) = (16, 12);
|
||||
let x_h = fill(m * n, 71);
|
||||
let w = fill(m * n, 72);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[m, n]));
|
||||
let out = ops::dropout(&x, p, SEED);
|
||||
scalar_loss(&out, &w).backward();
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| {
|
||||
let o = ops::dropout(&Var::leaf(cuda(v, s)), p, SEED);
|
||||
weighted_sum(&o.value(), &wf)
|
||||
};
|
||||
report(
|
||||
"dropout dX",
|
||||
&grad_check(&x_h, &[m, n], &lx, dx.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// Inverted-dropout expectation + keep-rate check. Over a large tensor and a sweep
|
||||
// of seeds, the mean of dropout(x) tracks the mean of x (E[out] ≈ x, the inverted
|
||||
// 1/(1-p) scaling), and the kept fraction tracks 1-p (the RNG is ~Bernoulli).
|
||||
#[test]
|
||||
fn dropout_expectation_and_keep_rate() {
|
||||
require_gpu();
|
||||
let p = 0.25f32;
|
||||
let n = 200_000usize;
|
||||
let x_h = vec![1.0f32; n]; // mean(x) = 1 → mean(out) should ≈ 1
|
||||
let x = cuda(&x_h, &[n]);
|
||||
|
||||
let trials = 8;
|
||||
let mut mean_out_acc = 0.0f64;
|
||||
let mut keep_acc = 0.0f64;
|
||||
for t in 0..trials {
|
||||
let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64);
|
||||
let out_h = out.to_device(Device::Cpu);
|
||||
let mask_h = mask.to_device(Device::Cpu);
|
||||
let mean_out: f64 =
|
||||
out_h.as_slice::<f32>().iter().map(|&v| v as f64).sum::<f64>() / n as f64;
|
||||
let kept = mask_h.as_slice::<f32>().iter().filter(|&&m| m != 0.0).count();
|
||||
mean_out_acc += mean_out;
|
||||
keep_acc += kept as f64 / n as f64;
|
||||
}
|
||||
let mean_out = mean_out_acc / trials as f64;
|
||||
let keep_rate = keep_acc / trials as f64;
|
||||
println!(
|
||||
"dropout p={p}: E[out]={mean_out:.5} (input mean 1.0), keep_rate={keep_rate:.5} (1-p={:.3})",
|
||||
1.0 - p
|
||||
);
|
||||
assert!(
|
||||
(mean_out - 1.0).abs() < 0.01,
|
||||
"E[out] {mean_out} not ≈ input mean 1.0 (inverted scaling broken)"
|
||||
);
|
||||
assert!(
|
||||
(keep_rate - (1.0 - p) as f64).abs() < 0.01,
|
||||
"keep_rate {keep_rate} not ≈ 1-p {}",
|
||||
1.0 - p
|
||||
);
|
||||
}
|
||||
|
||||
// p=0 is a no-op (the op returns x.clone(), no node) → output is bit-identical to
|
||||
// x and its grad flows straight through (the default-graph regression guard at the
|
||||
// op level; the model-level bit-identity is in xtrain-model/tests/dropout.rs).
|
||||
#[test]
|
||||
fn dropout_p0_is_identity() {
|
||||
require_gpu();
|
||||
let (m, n) = (8, 5);
|
||||
let x_h = fill(m * n, 91);
|
||||
let x = cuda(&x_h, &[m, n]);
|
||||
let (out, _mask) = x.dropout(0.0, 12345);
|
||||
let out_h = out.to_device(Device::Cpu);
|
||||
for (a, b) in x_h.iter().zip(out_h.as_slice::<f32>()) {
|
||||
assert_eq!(*a, *b, "p=0 dropout must be identity");
|
||||
}
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||||
|
||||
@@ -36,7 +36,9 @@ fn main() {
|
||||
.file("../../csrc/ops/model.cu")
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.file("../../csrc/ops/flash_attention.cu")
|
||||
.file("../../csrc/ops/cast.cu")
|
||||
.file("../../csrc/ops/dropout.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -243,6 +243,59 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// Fused flash-attention (csrc/ops/flash_attention.cu, Phase T14). A SINGLE kernel
|
||||
// each for forward/backward that streams over KV tiles with an online softmax and
|
||||
// NEVER materializes the [bh,S,S] score matrix. Q/K/V/out are [bh,S,hd] row-major
|
||||
// F32; the forward saves only the per-row logsumexp `l` ([bh*S], O(N)) for backward.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// Forward: o[bh,S,hd] = softmax(causal(Q·Kᵀ·scale))·V, online over KV tiles.
|
||||
// Also writes l[bh*S] = per-row logsumexp (saved for backward, not the scores).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn launch_flash_attention_fwd_f32(
|
||||
q: *const f32,
|
||||
k: *const f32,
|
||||
v: *const f32,
|
||||
o: *mut f32,
|
||||
l: *mut f32,
|
||||
bh: i32,
|
||||
seq: i32,
|
||||
hd: i32,
|
||||
scale: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Per-row D[i]=Σ_d dO[i,d]·O[i,d] over `rows`=bh*S rows of width `hd`. Must run
|
||||
// before the backward kernel (which takes the precomputed D, not O).
|
||||
pub fn launch_flash_attention_rowdot_f32(
|
||||
d_o: *const f32,
|
||||
o: *const f32,
|
||||
d_d: *mut f32,
|
||||
rows: i32,
|
||||
hd: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Backward: recomputes scores from Q/K/V + saved logsumexp `l` (NO cached probs)
|
||||
// and the precomputed `d_d` (= D), produces dq/dk/dv. dq/dk/dv must be PRE-ZEROED
|
||||
// (dk/dv are accumulated across query rows via atomicAdd).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn launch_flash_attention_bwd_f32(
|
||||
q: *const f32,
|
||||
k: *const f32,
|
||||
v: *const f32,
|
||||
d_o: *const f32,
|
||||
l: *const f32,
|
||||
d_d: *mut f32,
|
||||
dq: *mut f32,
|
||||
dk: *mut f32,
|
||||
dv: *mut f32,
|
||||
bh: i32,
|
||||
seq: i32,
|
||||
hd: i32,
|
||||
scale: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
|
||||
// the global grad-norm reduction + in-place rescale (Phase T7).
|
||||
#[cfg(not(no_cuda))]
|
||||
@@ -447,3 +500,48 @@ unsafe extern "C" {
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// Dropout (Phase T18, csrc/ops/dropout.cu). A counter-based (stateless) RNG: the
|
||||
// keep/drop decision for element `i` is `hash(seed, i)` — no global state, so a
|
||||
// re-run with the same `seed` reproduces the same mask (compatible with T13
|
||||
// activation recomputation). Forward writes `out = x ⊙ mask` and the fp32 `mask`
|
||||
// buffer (mask[i] = (1/(1-p)) if kept else 0, the inverted-dropout scale);
|
||||
// backward applies the SAME mask: dx = d ⊙ mask. fp32 + bf16 activation variants
|
||||
// (mask is fp32 in both; the uniform is computed in fp32, dtype-independent).
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
pub fn launch_dropout_fwd_f32(
|
||||
x: *const f32,
|
||||
out: *mut f32,
|
||||
mask: *mut f32,
|
||||
p: f32,
|
||||
scale: f32,
|
||||
seed: u64,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_bwd_f32(
|
||||
d: *const f32,
|
||||
mask: *const f32,
|
||||
dx: *mut f32,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_fwd_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
mask: *mut f32,
|
||||
p: f32,
|
||||
scale: f32,
|
||||
seed: u64,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_bwd_bf16(
|
||||
d: *const c_void,
|
||||
mask: *const f32,
|
||||
dx: *mut c_void,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -74,6 +74,10 @@ fn main() {
|
||||
// Optimization knobs (mirror bin/train).
|
||||
let steps: usize = flag(&args, "--steps", 100);
|
||||
let batch: usize = flag(&args, "--batch", 16);
|
||||
// Micro-batch gradient accumulation (Phase T16): effective global batch =
|
||||
// accum_steps × batch, all-reducing only at the accumulation boundary. Default
|
||||
// 1 = no accumulation (bit-identical to the pre-T16 DDP path).
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
@@ -89,6 +93,9 @@ fn main() {
|
||||
// rank checkpoints its own forward/backward; exact grads, lower peak activation
|
||||
// memory (lets dim1024 batch32 fit). Opt-in; default off.
|
||||
let recompute = args.iter().any(|a| a == "--recompute");
|
||||
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
|
||||
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
|
||||
let flash = args.iter().any(|a| a == "--flash");
|
||||
let ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -147,6 +154,7 @@ fn main() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len,
|
||||
batch_size: batch,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
@@ -164,8 +172,9 @@ fn main() {
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch}, lr {max_lr:.1e}→{min_lr:.1e}, \
|
||||
eval every {eval_every}"
|
||||
"training: {steps} steps, seq {seq_len}, global batch {batch} × accum {accum_steps} = \
|
||||
effective global batch {}, lr {max_lr:.1e}→{min_lr:.1e}, eval every {eval_every}",
|
||||
batch * accum_steps
|
||||
);
|
||||
|
||||
if bf16 {
|
||||
@@ -174,6 +183,9 @@ fn main() {
|
||||
if recompute {
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
if flash {
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
@@ -187,6 +199,9 @@ fn main() {
|
||||
if recompute {
|
||||
m = m.with_recompute(true);
|
||||
}
|
||||
if flash {
|
||||
m = m.with_flash(true);
|
||||
}
|
||||
m
|
||||
},
|
||||
);
|
||||
|
||||
@@ -35,6 +35,13 @@ pub struct DdpConfig {
|
||||
pub seq_len: usize,
|
||||
/// Global batch size; must be divisible by the world size.
|
||||
pub batch_size: usize,
|
||||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||
/// accumulates grads over `accum_steps` micro-batches, giving an EFFECTIVE
|
||||
/// global batch of `accum_steps × batch_size`. The cross-rank all-reduce
|
||||
/// fires ONLY at the accumulation boundary (after the last micro-step) —
|
||||
/// intermediate micro-steps skip the NCCL collective entirely. `1` = no
|
||||
/// accumulation (bit-identical to the pre-T16 DDP path).
|
||||
pub accum_steps: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
@@ -96,6 +103,7 @@ pub fn train_rank(
|
||||
// (sum across ranks, /world) then gives Σ_global/(world·b_local) = Σ_global/
|
||||
// B_global — already the global-batch mean — so the clip pre-scale is 1.0.
|
||||
let batch_local = cfg.batch_size / ctx.world;
|
||||
let accum = cfg.accum_steps.max(1);
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
// Rank 0 owns the held-out eval + best-val checkpoint (params are identical
|
||||
@@ -105,36 +113,51 @@ pub fn train_rank(
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Draw the whole global batch from the shared RNG (same on every rank);
|
||||
// collect only this rank's shard (global index % world == rank) and run it
|
||||
// as ONE batched forward/backward. The union of shards == the single-GPU
|
||||
// batch; each rank's backward yields its local mean (Σ_local / b_local).
|
||||
let mut inputs = Vec::with_capacity(batch_local);
|
||||
let mut targets_v = Vec::with_capacity(batch_local);
|
||||
for i in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
if i % ctx.world == ctx.rank {
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
// Accumulate grads over `accum` micro-batches, then ONE optimizer step
|
||||
// (Phase T16). Per micro-batch: draw the whole micro global batch from the
|
||||
// shared RNG (same on every rank), keep only this rank's shard (global index
|
||||
// % world == rank), run it as ONE batched forward/backward. Each micro-loss
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads across the `accum` micro-backwards) so the boundary grad equals a
|
||||
// single step over an `accum × batch_size` global batch. `accum == 1` skips
|
||||
// the scale → bit-identical to the pre-T16 DDP path. The cross-rank
|
||||
// all-reduce fires ONLY after the last micro-step (intermediate micro-steps
|
||||
// are local-only, no NCCL).
|
||||
let mut local_sum = 0.0f32; // Σ over micro of (local_mean · b_local)
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(batch_local);
|
||||
let mut targets_v = Vec::with_capacity(batch_local);
|
||||
for i in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
if i % ctx.world == ctx.rank {
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
}
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||||
local_sum += read_scalar(&loss) * batch_local as f32; // local mean·b_local
|
||||
if accum == 1 {
|
||||
loss.backward();
|
||||
} else {
|
||||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
}
|
||||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, batch_local);
|
||||
let local_mean = read_scalar(&loss); // Σ_local / b_local
|
||||
loss.backward();
|
||||
tokens_seen += (batch_local * cfg.seq_len) as u64;
|
||||
|
||||
// AllReduce(sum) + /world the grads → every rank holds Σ_global/B_global
|
||||
// (local means summed over ranks, /world = global mean). See note above.
|
||||
// Accumulation boundary: ONE AllReduce(sum) + /world over the accumulated
|
||||
// grads → every rank holds the effective-batch (accum·B_global) mean grad
|
||||
// (the per-micro 1/accum scaling is already baked into each backward; the
|
||||
// /world here is orthogonal to accum). Intermediate micro-steps issued NO
|
||||
// NCCL — only this single boundary collective per optimizer step.
|
||||
ctx.all_reduce_average_grads(¶ms);
|
||||
// Reported loss = global mean: sum the per-rank local sums (= mean·b_local)
|
||||
// across ranks, /B_global. With equal b_local this is mean over ranks.
|
||||
let step_loss =
|
||||
all_reduce_loss(ctx, local_mean * batch_local as f32) / cfg.batch_size as f32;
|
||||
// Reported loss = effective-batch mean: AllReduce(sum) the per-rank local
|
||||
// sums across ranks, /(accum·B_global).
|
||||
let step_loss = all_reduce_loss(ctx, local_sum) / (accum * cfg.batch_size) as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Grads are already the global-batch mean — just clip (pre-scale 1.0).
|
||||
// Grads are already the effective-batch mean — just clip (pre-scale 1.0).
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
|
||||
@@ -94,6 +94,7 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len: 32,
|
||||
batch_size: 8, // global; 4 per rank with world=2
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
@@ -195,6 +196,127 @@ fn ddp_matches_single_gpu_and_params_consistent() {
|
||||
assert!(max_sdiff < 1e-2, "DDP params diverged from single-GPU");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddp_with_accum_matches_single_gpu_big_batch() {
|
||||
// T16: DDP + gradient accumulation must match a single-GPU big-batch baseline
|
||||
// of the SAME effective batch. world=2, accum=2, per-rank micro-batch 2 →
|
||||
// effective global batch = world·accum·b_local = 2·2·2 = 8. Compared against a
|
||||
// single-GPU run with batch 8, accum 1 (the big-batch baseline). The all-reduce
|
||||
// fires only at the accumulation boundary (once per optimizer step, not per
|
||||
// micro-step) — enforced by the train_rank implementation; the load-bearing
|
||||
// gate here is that loss + final params still match the big-batch baseline.
|
||||
let world = 2usize;
|
||||
if device::device_count().unwrap_or(0) < world as i32 {
|
||||
eprintln!("skip: need >= {world} GPUs");
|
||||
return;
|
||||
}
|
||||
|
||||
let vocab = 64usize;
|
||||
let cfg = test_config(vocab);
|
||||
let corpus = synth_corpus(vocab, 4096);
|
||||
let steps = 20usize;
|
||||
let effective_batch = 8usize; // world(2) · accum(2) · b_local(2)
|
||||
let sched = LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
warmup: 3,
|
||||
total: steps,
|
||||
};
|
||||
|
||||
// Single-GPU big-batch baseline: world=1, accum=1, batch = effective_batch.
|
||||
let baseline_cfg = DdpConfig {
|
||||
seq_len: 32,
|
||||
batch_size: effective_batch,
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: sched,
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 1_000_000,
|
||||
seed: 7,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
ckpt_path: None,
|
||||
};
|
||||
let (single_losses, single_params) = run_single_gpu(cfg, &corpus, &baseline_cfg);
|
||||
|
||||
// DDP + accumulation: world=2, accum=2 → per-rank micro-batch = batch/world = 2.
|
||||
let ddp_cfg = DdpConfig {
|
||||
batch_size: effective_batch / 2, // per-step global batch; ×accum = effective
|
||||
accum_steps: 2,
|
||||
..baseline_cfg
|
||||
};
|
||||
let devices = [0u32, 1u32];
|
||||
let id = get_unique_id();
|
||||
let results: Vec<(Vec<f32>, Vec<Vec<f32>>)> = std::thread::scope(|s| {
|
||||
let handles: Vec<_> = devices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(rank, &dev)| {
|
||||
let ddp_cfg = ddp_cfg.clone();
|
||||
let corpus = &corpus;
|
||||
s.spawn(move || {
|
||||
let ctx = DdpContext::init(rank, world, id, dev);
|
||||
let device = Device::Cuda(dev);
|
||||
let model = build_model(cfg, device);
|
||||
let res = train_rank(&ctx, &model, device, corpus, None, &ddp_cfg);
|
||||
let host = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
|
||||
.collect::<Vec<_>>();
|
||||
(res.losses, host)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
handles.into_iter().map(|h| h.join().unwrap()).collect()
|
||||
});
|
||||
|
||||
let (ddp_losses, ddp_p0) = &results[0];
|
||||
let (_, ddp_p1) = &results[1];
|
||||
|
||||
// (a) Loss trajectory matches the single-GPU big-batch baseline.
|
||||
let mut max_rel = 0.0f32;
|
||||
for (s, d) in single_losses.iter().zip(ddp_losses) {
|
||||
max_rel = max_rel.max((s - d).abs() / s.abs().max(1e-6));
|
||||
}
|
||||
println!(
|
||||
"DDP+accum(w2·a2·b2) vs single-GPU big-batch(8): single[last]={:.6} ddp[last]={:.6} max_rel={max_rel:.2e}",
|
||||
single_losses.last().unwrap(),
|
||||
ddp_losses.last().unwrap()
|
||||
);
|
||||
assert!(
|
||||
max_rel < 1e-3,
|
||||
"DDP+accum loss diverged from big-batch baseline: {max_rel:.3e}"
|
||||
);
|
||||
|
||||
// (b) Cross-rank parameter agreement (same KI-5 ULP tolerance as the base test).
|
||||
let mut max_pdiff = 0.0f32;
|
||||
for (a, b) in ddp_p0.iter().zip(ddp_p1) {
|
||||
for (x, y) in a.iter().zip(b) {
|
||||
max_pdiff = max_pdiff.max((x - y).abs());
|
||||
}
|
||||
}
|
||||
println!("DDP+accum cross-rank max |param diff| = {max_pdiff:.3e}");
|
||||
assert!(
|
||||
max_pdiff < 1e-6,
|
||||
"ranks' params drifted apart: {max_pdiff:.3e}"
|
||||
);
|
||||
|
||||
// (c) Final params match single-GPU big-batch within fp tolerance.
|
||||
let mut max_sdiff = 0.0f32;
|
||||
for (a, b) in ddp_p0.iter().zip(&single_params) {
|
||||
for (x, y) in a.iter().zip(b) {
|
||||
max_sdiff = max_sdiff.max((x - y).abs() / y.abs().max(1e-6));
|
||||
}
|
||||
}
|
||||
println!("DDP+accum vs single-GPU big-batch max rel |param diff| = {max_sdiff:.3e}");
|
||||
assert!(
|
||||
max_sdiff < 1e-2,
|
||||
"DDP+accum params diverged from big-batch baseline"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddp_throughput_scaling() {
|
||||
let max_gpus = device::device_count().unwrap_or(0) as usize;
|
||||
@@ -230,6 +352,7 @@ fn ddp_throughput_scaling() {
|
||||
let dcfg = DdpConfig {
|
||||
seq_len,
|
||||
batch_size: per_gpu_batch * world,
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 1e-3,
|
||||
|
||||
@@ -20,6 +20,11 @@ pub struct Config {
|
||||
pub eps: f32,
|
||||
/// RoPE base frequency (theta).
|
||||
pub rope_theta: f32,
|
||||
/// Dropout probability `p` (Phase T18). Applied at the attention/MLP sub-block
|
||||
/// outputs (before each residual add) at TRAINING time, with inverted scaling
|
||||
/// `1/(1-p)`; disabled (identity) at eval. Default `0.0` = no dropout, and the
|
||||
/// forward graph is then bit-identical to the pre-T18 path.
|
||||
pub dropout: f32,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -36,6 +41,7 @@ impl Config {
|
||||
ffn_hidden: 64,
|
||||
eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
dropout: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +66,7 @@ impl Config {
|
||||
ffn_hidden,
|
||||
eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
dropout: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::cell::Cell;
|
||||
|
||||
use crate::config::Config;
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
@@ -47,6 +49,27 @@ pub struct TinyTransformer {
|
||||
/// existing numerics are bit-identical; recompute is mathematically exact, so
|
||||
/// grads match the non-checkpointed path within fp tolerance.
|
||||
recompute: bool,
|
||||
/// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through
|
||||
/// the hand-written single fused kernel ([`ops::flash_attention`]): online
|
||||
/// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized,
|
||||
/// backward caches only the O(N) logsumexp. Default `false` = the composed T10
|
||||
/// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs),
|
||||
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
|
||||
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
|
||||
use_flash: bool,
|
||||
/// Training mode for dropout (Phase T18). `true` → the attn/MLP sub-block
|
||||
/// outputs pass through `ops::dropout` (with `cfg.dropout` and a per-step,
|
||||
/// per-site seed); `false` (default) → dropout is identity (eval/sampling/
|
||||
/// export). `Cell` so `train()`/`eval()` flip it through `&self` (the forward
|
||||
/// takes `&self`). When `cfg.dropout == 0` this flag is irrelevant — the graph
|
||||
/// is bit-identical to the no-dropout path either way.
|
||||
training: Cell<bool>,
|
||||
/// Per-step dropout RNG seed (Phase T18). Bumped once at the start of each
|
||||
/// TRAINING forward so every step draws fresh masks; combined with the layer
|
||||
/// index + a per-site constant to give each dropout site its own seed. The RNG
|
||||
/// is counter-based, so re-running a checkpointed block's forward in backward
|
||||
/// (T13) reproduces the same seed → the same mask (recompute stays exact).
|
||||
step_seed: Cell<u64>,
|
||||
}
|
||||
|
||||
impl TinyTransformer {
|
||||
@@ -90,6 +113,9 @@ impl TinyTransformer {
|
||||
lm_head,
|
||||
compute_dtype: DType::F32,
|
||||
recompute: false,
|
||||
use_flash: false,
|
||||
training: Cell::new(false),
|
||||
step_seed: Cell::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,6 +153,43 @@ impl TinyTransformer {
|
||||
self.recompute
|
||||
}
|
||||
|
||||
/// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and
|
||||
/// opt-in; default off keeps the composed T10 path (so the default graph is
|
||||
/// unchanged). On, the SDPA runs through [`ops::flash_attention`] — same SDPA
|
||||
/// math, online softmax, no materialized `[bh,seq,seq]` scores.
|
||||
pub fn with_flash(mut self, use_flash: bool) -> Self {
|
||||
self.use_flash = use_flash;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn use_flash(&self) -> bool {
|
||||
self.use_flash
|
||||
}
|
||||
|
||||
/// Switch to training mode (Phase T18): dropout (if `cfg.dropout > 0`) is
|
||||
/// active in subsequent forwards. The training loop calls this before stepping.
|
||||
pub fn train(&self) {
|
||||
self.training.set(true);
|
||||
}
|
||||
|
||||
/// Switch to eval mode (Phase T18): dropout is identity. Held-out eval,
|
||||
/// autoregressive sampling, and weight export all run in this mode (default).
|
||||
pub fn eval(&self) {
|
||||
self.training.set(false);
|
||||
}
|
||||
|
||||
pub fn is_training(&self) -> bool {
|
||||
self.training.get()
|
||||
}
|
||||
|
||||
/// Builder-style train/eval toggle (Phase T18) — handy for tests that want a
|
||||
/// model fixed in one mode. Equivalent to [`train`](Self::train) /
|
||||
/// [`eval`](Self::eval) but chains off `new(..)`.
|
||||
pub fn with_training(self, training: bool) -> Self {
|
||||
self.training.set(training);
|
||||
self
|
||||
}
|
||||
|
||||
/// All learnable parameters, in a stable order. The optimizer (a hand-written
|
||||
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
|
||||
/// `backward()`.
|
||||
@@ -176,28 +239,57 @@ impl TinyTransformer {
|
||||
);
|
||||
let seq = total / batch;
|
||||
|
||||
// Dropout (T18) is active only in training mode with p>0; otherwise it is
|
||||
// identity (`ops::dropout` no-ops at p==0). Bump the per-step seed ONCE per
|
||||
// training forward so each step draws fresh masks (counter-based RNG, so a
|
||||
// checkpointed block's recompute reproduces the same seed → same mask).
|
||||
let dropout_p = if self.training.get() {
|
||||
self.cfg.dropout
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
if dropout_p > 0.0 {
|
||||
self.step_seed.set(self.step_seed.get().wrapping_add(1));
|
||||
}
|
||||
let base_seed = self.step_seed.get();
|
||||
|
||||
// Embedding gathers from the fp32 master table; in bf16 mode cast the
|
||||
// activation stream to bf16 here (norms are cast to bf16 gammas too).
|
||||
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
|
||||
if self.compute_dtype == DType::BF16 {
|
||||
h = ops::cast(&h, DType::BF16);
|
||||
}
|
||||
for b in &self.blocks {
|
||||
for (li, b) in self.blocks.iter().enumerate() {
|
||||
// Per-layer dropout seed: a deterministic function of (base_seed,
|
||||
// layer index) — NOT a mutable counter — so the checkpoint recompute
|
||||
// (which re-derives it from the captured base_seed/li) gets the same
|
||||
// masks. The block derives its two per-site seeds from this.
|
||||
let block_seed = base_seed
|
||||
.wrapping_mul(0x100000001B3)
|
||||
.wrapping_add(li as u64);
|
||||
h = if self.recompute {
|
||||
// Activation recomputation (T13): run the whole block forward inside
|
||||
// `checkpoint` so its internal activations aren't kept on the tape;
|
||||
// the block forward is re-run in backward to recover the grads. The
|
||||
// segment fn captures only `Copy` config (no borrow of `self`) and
|
||||
// receives the block's params via the slice, in `block_params` order.
|
||||
let (cfg, cdt) = (self.cfg, self.compute_dtype);
|
||||
let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p);
|
||||
// `flash` is captured too → the recompute segment also runs flash;
|
||||
// `dropout_p`/`block_seed` are captured so the recompute re-derives
|
||||
// the same per-site dropout masks (counter-based RNG, exact).
|
||||
let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
|
||||
let seg = move |x: &Var, p: &[Var]| {
|
||||
block_forward(cfg, cdt, flash, batch, seq, dropout_p, block_seed, x, p)
|
||||
};
|
||||
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
|
||||
} else {
|
||||
block_forward(
|
||||
self.cfg,
|
||||
self.compute_dtype,
|
||||
self.use_flash,
|
||||
batch,
|
||||
seq,
|
||||
dropout_p,
|
||||
block_seed,
|
||||
&h,
|
||||
&b.block_params(),
|
||||
)
|
||||
@@ -275,25 +367,48 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
|
||||
}
|
||||
|
||||
/// One transformer block's forward: pre-norm + multi-head causal attention +
|
||||
/// residual, then pre-norm + SwiGLU MLP + residual. Pure in `(cfg, cdt, batch,
|
||||
/// seq, input, params)` (no `&self`) so it can be the segment fn of
|
||||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
|
||||
/// the block's leaves in [`Block::block_params`] order.
|
||||
fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var {
|
||||
/// (T18) dropout + residual, then pre-norm + SwiGLU MLP + dropout + residual.
|
||||
/// Attention runs the composed or fused-flash (T14) SDPA per `flash`. Pure in
|
||||
/// `(cfg, cdt, flash, batch, seq, dropout_p, block_seed, input, params)` (no
|
||||
/// `&self`, all `Copy`) so it can be the segment fn of
|
||||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13) — the
|
||||
/// recompute re-derives the same per-site seeds, so the dropout masks are
|
||||
/// reproduced bit-for-bit. `dropout_p == 0` makes `ops::dropout` a no-op (the
|
||||
/// graph is then identical to the pre-T18 path). `params` is the block's leaves in
|
||||
/// [`Block::block_params`] order.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn block_forward(
|
||||
cfg: Config,
|
||||
cdt: DType,
|
||||
flash: bool,
|
||||
batch: usize,
|
||||
seq: usize,
|
||||
dropout_p: f32,
|
||||
block_seed: u64,
|
||||
h: &Var,
|
||||
p: &[Var],
|
||||
) -> Var {
|
||||
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
|
||||
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
|
||||
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
|
||||
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
// Per-site dropout seeds (XOR a site constant into the block seed) so the two
|
||||
// residual-path dropouts draw independent masks within the same step/layer.
|
||||
let attn_seed = block_seed ^ 0x0A7700;
|
||||
let ffn_seed = block_seed ^ 0x0FF700;
|
||||
|
||||
// --- Attention sub-block (pre-norm + dropout + residual) ---
|
||||
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
|
||||
let attn = attention(
|
||||
cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||||
cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||||
);
|
||||
let attn = ops::dropout(&attn, dropout_p, attn_seed);
|
||||
let h = ops::add(h, &attn);
|
||||
|
||||
// --- MLP sub-block (pre-norm + residual) ---
|
||||
// --- MLP sub-block (pre-norm + dropout + residual) ---
|
||||
let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps);
|
||||
let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down);
|
||||
let mlp = ops::dropout(&mlp, dropout_p, ffn_seed);
|
||||
ops::add(&h, &mlp)
|
||||
}
|
||||
|
||||
@@ -308,6 +423,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
|
||||
fn attention(
|
||||
cfg: Config,
|
||||
cdt: DType,
|
||||
flash: bool,
|
||||
batch: usize,
|
||||
seq: usize,
|
||||
x: &Var,
|
||||
@@ -351,9 +467,14 @@ fn attention(
|
||||
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
|
||||
let v = to_bh(linear(cdt, x, wv), None);
|
||||
|
||||
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
|
||||
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
|
||||
let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd]
|
||||
// Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
|
||||
// single fused flash kernel (online softmax, no materialized [bh,S,S] scores);
|
||||
// otherwise the composed T10 path (2 batched GEMMs + 1 causal-softmax kernel).
|
||||
let out = if flash {
|
||||
ops::flash_attention(&q, &k, &v, scale) // [B*nh, S, hd]
|
||||
} else {
|
||||
ops::attention(&q, &k, &v, scale) // [B*nh, S, hd]
|
||||
};
|
||||
|
||||
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
|
||||
// [B,S,nh,hd] → [B*S, dim].
|
||||
|
||||
283
crates/xtrain-model/tests/dropout.rs
Normal file
283
crates/xtrain-model/tests/dropout.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
// T18 dropout model-level gates.
|
||||
//
|
||||
// 1. p=0 bit-identical: a model built with cfg.dropout=0 (in either train or
|
||||
// eval mode) produces logits/loss/grads bit-for-bit identical to the same
|
||||
// model with no dropout field touched — the default forward graph is
|
||||
// unchanged (the regression guard).
|
||||
// 2. eval identity: with p>0 but eval mode, the forward equals the p=0 forward
|
||||
// bit-for-bit (dropout is OFF at eval).
|
||||
// 3. train vs eval differ: with p>0 and train mode, the forward differs from
|
||||
// eval (dropout actually does something) and grads are still finite.
|
||||
// 4. recompute compatibility: with p>0 + train + recompute, grads match the
|
||||
// non-recompute path (the counter-based seed reproduces the same mask on the
|
||||
// backward re-run — T13 stays exact even with dropout in the block).
|
||||
//
|
||||
// (The fixed-seed grad-check of the dropout op and the E[out]≈x / keep-rate check
|
||||
// live in xtrain-autodiff/tests/autograd.rs; p>0 training convergence is the
|
||||
// dash5 short run noted in docs/17-dropout.md.)
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
fn tiny_cfg(dropout: f32) -> Config {
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
cfg.dropout = dropout;
|
||||
cfg
|
||||
}
|
||||
|
||||
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
|
||||
let (batch, seq) = (3usize, 6usize);
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
(
|
||||
batched_ids_tensor(&seqs, device),
|
||||
batched_ids_tensor(&tgts, device),
|
||||
)
|
||||
}
|
||||
|
||||
fn require_gpu() -> Device {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
Device::Cuda(0)
|
||||
}
|
||||
|
||||
// Run forward+backward, return (logits, loss, per-param grads).
|
||||
fn fwd_bwd(
|
||||
m: &TinyTransformer,
|
||||
ids: &xtrain_tensor::Tensor,
|
||||
tgt: &xtrain_tensor::Tensor,
|
||||
batch: usize,
|
||||
) -> (Vec<f32>, f32, Vec<Vec<f32>>) {
|
||||
let logits = host(&m.forward_batched(ids, batch).value());
|
||||
let loss = m.loss_batched(ids, tgt, batch);
|
||||
let loss_val = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads: Vec<Vec<f32>> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
(logits, loss_val, grads)
|
||||
}
|
||||
|
||||
// --- Gate 3: p=0 is bit-identical to the no-dropout path (default graph). ---
|
||||
#[test]
|
||||
fn dropout_p0_bit_identical() {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
|
||||
// Reference: cfg.dropout default (0.0), never touched train/eval.
|
||||
let cfg0 = tiny_cfg(0.0);
|
||||
let (ids, tgt) = batch_data(&cfg0, device);
|
||||
let ref_m = build(cfg0, device);
|
||||
let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);
|
||||
|
||||
// p=0 in TRAINING mode: the seed bump is gated on p>0, the op no-ops at p==0,
|
||||
// so the graph must be byte-identical.
|
||||
let p0_train = build(tiny_cfg(0.0), device);
|
||||
p0_train.train();
|
||||
let (lt, lst, gt) = fwd_bwd(&p0_train, &ids, &tgt, batch);
|
||||
|
||||
assert_eq!(ref_logits, lt, "p=0 train logits not bit-identical");
|
||||
assert_eq!(ref_loss, lst, "p=0 train loss not bit-identical");
|
||||
for (i, (a, b)) in ref_grads.iter().zip(>).enumerate() {
|
||||
assert_eq!(a, b, "p=0 train grad[{i}] not bit-identical");
|
||||
}
|
||||
println!("p=0 (train) vs no-dropout: logits/loss/grads bit-identical ✅");
|
||||
}
|
||||
|
||||
// --- Gate 2: eval is exact identity (p>0 but eval mode == p=0). ---
|
||||
#[test]
|
||||
fn dropout_eval_is_identity() {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
let cfg = tiny_cfg(0.2);
|
||||
let (ids, tgt) = batch_data(&cfg, device);
|
||||
|
||||
// p=0 reference and a p=0.2 model held in eval — outputs must match bit-for-bit.
|
||||
let ref_m = build(tiny_cfg(0.0), device);
|
||||
let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);
|
||||
|
||||
let eval_m = build(cfg, device);
|
||||
eval_m.eval(); // explicit; also the default
|
||||
let (el, els, eg) = fwd_bwd(&eval_m, &ids, &tgt, batch);
|
||||
|
||||
assert_eq!(ref_logits, el, "eval (p>0) logits not identity");
|
||||
assert_eq!(ref_loss, els, "eval (p>0) loss not identity");
|
||||
for (i, (a, b)) in ref_grads.iter().zip(&eg).enumerate() {
|
||||
assert_eq!(a, b, "eval (p>0) grad[{i}] not identity");
|
||||
}
|
||||
println!("eval (p=0.2) == no-dropout: bit-identical (eval is identity) ✅");
|
||||
}
|
||||
|
||||
// --- Gate (train vs eval differ): with p>0 + train, dropout actually fires. ---
|
||||
#[test]
|
||||
fn dropout_train_differs_from_eval() {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
let cfg = tiny_cfg(0.3);
|
||||
let (ids, _tgt) = batch_data(&cfg, device);
|
||||
|
||||
let m = build(cfg, device);
|
||||
m.eval();
|
||||
let eval_logits = host(&m.forward_batched(&ids, batch).value());
|
||||
m.train();
|
||||
let train_logits = host(&m.forward_batched(&ids, batch).value());
|
||||
|
||||
let max_diff = eval_logits
|
||||
.iter()
|
||||
.zip(&train_logits)
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0f32, f32::max);
|
||||
assert!(
|
||||
max_diff > 1e-4 && train_logits.iter().all(|v| v.is_finite()),
|
||||
"train logits should differ from eval (dropout active) and be finite; max_diff={max_diff}"
|
||||
);
|
||||
println!("train vs eval logits max diff {max_diff:.4e} (dropout active in train) ✅");
|
||||
}
|
||||
|
||||
// --- Gate 4: p>0 + recompute grads match non-recompute (T13 stays exact). ---
|
||||
// The counter-based seed is a pure function of (step_seed, layer, site); the
|
||||
// checkpoint backward re-runs block_forward and re-derives the SAME seeds, so the
|
||||
// recomputed dropout masks match the forward — grads stay bit-identical.
|
||||
fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
let cfg = tiny_cfg(0.2);
|
||||
let (ids, tgt) = batch_data(&cfg, device);
|
||||
|
||||
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
|
||||
// to 1 on the first training forward in BOTH, so they draw the same masks.
|
||||
let off = build(cfg, device).with_compute_dtype(dtype).with_training(true);
|
||||
let on = build(cfg, device)
|
||||
.with_compute_dtype(dtype)
|
||||
.with_recompute(true)
|
||||
.with_training(true);
|
||||
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
|
||||
let mut max_rel = 0.0f32;
|
||||
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
||||
max_rel = max_rel.max((a - b).abs() / a.abs().max(1e-3));
|
||||
}
|
||||
println!("[{dtype:?}] dropout p=0.2 recompute on/off grad max rel = {max_rel:.3e}");
|
||||
assert!(
|
||||
max_rel < grad_tol,
|
||||
"[{dtype:?}] recompute grads diverged with dropout: {max_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dropout_recompute_matches_fp32() {
|
||||
recompute_with_dropout(DType::F32, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dropout_recompute_matches_bf16() {
|
||||
recompute_with_dropout(DType::BF16, 5e-3);
|
||||
}
|
||||
|
||||
// --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18)
|
||||
// together in the SAME model still grad-checks. Build two identical models, both
|
||||
// in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off.
|
||||
// The dropout site seeds are a pure function of (step_seed, layer, site) and are
|
||||
// INDEPENDENT of flash, so both models draw the SAME masks on their first training
|
||||
// forward → the only difference is the SDPA reduction order. Assert logits/loss/
|
||||
// grads match within the flash-vs-composed tolerance and are finite. This is the
|
||||
// orthogonality check for the two Phase-2 features landing together.
|
||||
#[test]
|
||||
fn flash_plus_dropout_grad_check_fp32() {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
// seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
cfg.dropout = 0.2;
|
||||
let seq = 40usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
// Both: same init, train mode (dropout active), same step_seed progression →
|
||||
// identical masks; one composed SDPA, one flash SDPA.
|
||||
let off = build(cfg, device).with_training(true);
|
||||
let on = build(cfg, device).with_flash(true).with_training(true);
|
||||
|
||||
let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch);
|
||||
let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch);
|
||||
|
||||
assert!(
|
||||
on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()),
|
||||
"flash+dropout produced non-finite logits/grads"
|
||||
);
|
||||
|
||||
let logit_rel = off_logits
|
||||
.iter()
|
||||
.zip(&on_logits)
|
||||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||||
let mut grad_rel = 0.0f32;
|
||||
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
||||
grad_rel = grad_rel.max((a - b).abs() / a.abs().max(1e-3));
|
||||
}
|
||||
println!(
|
||||
"[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \
|
||||
logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}"
|
||||
);
|
||||
// Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash
|
||||
// differs from composed only by reduction order; dropout masks are identical.
|
||||
assert!(logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}");
|
||||
assert!(loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}");
|
||||
assert!(grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}");
|
||||
}
|
||||
209
crates/xtrain-model/tests/flash.rs
Normal file
209
crates/xtrain-model/tests/flash.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
// T14 flash-attention correctness gate: the fused flash SDPA core must match the
|
||||
// composed T10 path (cublasSgemmStridedBatched×2 + causal-softmax kernel) in
|
||||
// forward logits, loss, AND every parameter gradient — flash is the SAME SDPA
|
||||
// math (online softmax never materializes the [bh,S,S] scores), so it differs
|
||||
// from composed only by reduction order (in-kernel fp32 FMA vs cuBLAS, and the
|
||||
// dK/dV atomicAdd order in backward). This test makes that a closed on-GPU loop:
|
||||
//
|
||||
// build two identical models (same init), one with `--flash` on, one off, run
|
||||
// the SAME batched loss + backward on both, and assert
|
||||
// 1. the forward logits match within tolerance
|
||||
// 2. the loss matches
|
||||
// 3. EVERY parameter's grad matches within tolerance
|
||||
//
|
||||
// Parameterised over fp32 AND bf16 (T12). bf16 just adds the bf16 rounding band on
|
||||
// top — flash's bf16 path upcasts Q/K/V to fp32 for the kernel exactly like the
|
||||
// composed path's fp32 softmax, so the two are still the same softmax numerics.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype).with_flash(flash)
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
// fp32: same SDPA math, differs only by reduction order → tight per-element check.
|
||||
fn run_fp32(logit_tol: f32, grad_tol: f32) {
|
||||
let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::F32);
|
||||
|
||||
let logit_rel = off_logits
|
||||
.iter()
|
||||
.zip(&on_logits)
|
||||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||||
println!(
|
||||
"[F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
|
||||
logits max rel {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
logit_rel < logit_tol,
|
||||
"[F32] logits diverged: {logit_rel:.2e}"
|
||||
);
|
||||
assert!(loss_rel < logit_tol, "[F32] loss diverged: {loss_rel:.2e}");
|
||||
|
||||
let mut max_grad_rel = 0.0f32;
|
||||
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
|
||||
for (a, b) in off_g.iter().zip(on_g) {
|
||||
max_grad_rel = max_grad_rel.max((a - b).abs() / a.abs().max(1e-3));
|
||||
}
|
||||
}
|
||||
println!("[F32] flash on/off: grad max rel err = {max_grad_rel:.3e}");
|
||||
assert!(
|
||||
max_grad_rel < grad_tol,
|
||||
"[F32] flash grads diverged from composed: {max_grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
// bf16: ~2-3 decimal digits → robust comparison (mean + p99 with abs().max(1.0)
|
||||
// for logits, per-tensor scale-relative mean for grads), the same convention as
|
||||
// the repo's bf16.rs gate (per-element max-rel blows up on near-zero bf16 logits).
|
||||
fn run_bf16() {
|
||||
let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::BF16);
|
||||
|
||||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||||
println!("[BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
|
||||
assert!(loss_rel < 2e-2, "[BF16] loss diverged: {loss_rel:.3e}");
|
||||
|
||||
let n = off_logits.len();
|
||||
let mut rels: Vec<f32> = off_logits
|
||||
.iter()
|
||||
.zip(&on_logits)
|
||||
.map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
|
||||
.collect();
|
||||
rels.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let p99 = rels[(n as f32 * 0.99) as usize];
|
||||
let mean: f32 = rels.iter().sum::<f32>() / n as f32;
|
||||
println!("[BF16] flash on/off logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
|
||||
assert!(mean < 1e-2, "[BF16] logits mean rel too high: {mean:.3e}");
|
||||
assert!(p99 < 5e-2, "[BF16] logits p99 rel too high: {p99:.3e}");
|
||||
|
||||
let mut worst = 0.0f32;
|
||||
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
|
||||
let scale = off_g
|
||||
.iter()
|
||||
.map(|v| v.abs())
|
||||
.fold(0.0f32, f32::max)
|
||||
.max(1e-6);
|
||||
let mean_err: f32 = off_g
|
||||
.iter()
|
||||
.zip(on_g)
|
||||
.map(|(f, b)| (f - b).abs())
|
||||
.sum::<f32>()
|
||||
/ off_g.len() as f32
|
||||
/ scale;
|
||||
worst = worst.max(mean_err);
|
||||
}
|
||||
println!("[BF16] flash on/off grads: worst per-tensor scaled-mean err = {worst:.3e}");
|
||||
assert!(worst < 3e-2, "[BF16] flash grads diverged: {worst:.3e}");
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn run_both(dtype: DType) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
let batch = 3usize;
|
||||
let seq = 40usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
// --- flash OFF (composed reference) ---
|
||||
let off = build(cfg, device, dtype, false);
|
||||
let off_logits = host(&off.forward_batched(&ids, batch).value());
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
let off_loss_val = host(&off_loss.value())[0];
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("off grad")))
|
||||
.collect();
|
||||
|
||||
// --- flash ON ---
|
||||
let on = build(cfg, device, dtype, true);
|
||||
let on_logits = host(&on.forward_batched(&ids, batch).value());
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
let on_loss_val = host(&on_loss.value())[0];
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("on grad")))
|
||||
.collect();
|
||||
|
||||
(
|
||||
off_logits,
|
||||
off_loss_val,
|
||||
off_grads,
|
||||
on_logits,
|
||||
on_loss_val,
|
||||
on_grads,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flash_matches_composed_fp32() {
|
||||
// fp32: same SDPA math, differs only by reduction order (in-kernel fp32 FMA vs
|
||||
// cuBLAS, dK/dV atomicAdd order). Tight per-element check, not bit-exact.
|
||||
run_fp32(1e-3, 2e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flash_matches_composed_bf16() {
|
||||
// bf16 (T12 composition): bf16 rounding band on the fp32-softmax core; robust
|
||||
// (mean/p99/scaled-mean) comparison per the repo's bf16 convention.
|
||||
run_bf16();
|
||||
}
|
||||
@@ -67,7 +67,7 @@ fn dump_for_parity() {
|
||||
|
||||
// Same deterministic init as the overfit test.
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
let mut model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
@@ -76,6 +76,14 @@ fn dump_for_parity() {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
// T14: with XTRAIN_PARITY_FLASH set, dump from the fused flash-attention path.
|
||||
// flash is the SAME SDPA math, so the SAME parity.py PyTorch oracle is the
|
||||
// reference for both paths — running this once per path checks flash against
|
||||
// PyTorch at B>1 (forward logits + every parameter grad).
|
||||
if std::env::var("XTRAIN_PARITY_FLASH").is_ok() {
|
||||
model = model.with_flash(true);
|
||||
println!("parity: FLASH attention path");
|
||||
}
|
||||
|
||||
// config + ids
|
||||
{
|
||||
|
||||
@@ -668,6 +668,92 @@ impl Tensor {
|
||||
dx
|
||||
}
|
||||
|
||||
/// Dropout forward (Phase T18). Returns `(out, mask)` where, for each element
|
||||
/// `i`, a counter-based RNG draws `u = hash(seed, i) ∈ [0,1)` and keeps the
|
||||
/// element iff `u >= p`; kept elements are scaled by `1/(1-p)` (inverted
|
||||
/// dropout, so `E[out] == x`). `mask[i]` stores that per-element factor
|
||||
/// (`1/(1-p)` if kept, else `0`) for the backward to reuse — the same mask, so
|
||||
/// the op is a fixed elementwise scale w.r.t. `x` (and finite-diff-checkable).
|
||||
///
|
||||
/// The mask depends only on `(seed, i)`, NOT on `self`'s values, so a re-run
|
||||
/// with the same `seed` reproduces the same mask (T13 recompute stays exact).
|
||||
/// `mask` is always fp32 (the uniform is computed in fp32, dtype-independent);
|
||||
/// `out` matches `self`'s dtype. Requires `0 <= p < 1`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn dropout(&self, p: f32, seed: u64) -> (Self, Self) {
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"dropout supports F32/BF16"
|
||||
);
|
||||
assert!((0.0..1.0).contains(&p), "dropout p must be in [0,1)");
|
||||
assert!(self.is_contiguous(), "dropout requires contiguous tensor");
|
||||
let scale = 1.0 / (1.0 - p);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
let mask = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_fwd_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
mask.data_ptr() as *mut f32,
|
||||
p,
|
||||
scale,
|
||||
seed,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_fwd_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
mask.data_ptr() as *mut f32,
|
||||
p,
|
||||
scale,
|
||||
seed,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
(out, mask)
|
||||
}
|
||||
|
||||
/// Dropout backward: `dx = d ⊙ mask` (the SAME `mask` the forward cached).
|
||||
/// `d` is the upstream grad (activation dtype); `mask` is the fp32 factor
|
||||
/// tensor from [`Self::dropout`]. Output matches `d`'s dtype.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn dropout_backward(d: &Tensor, mask: &Tensor) -> Self {
|
||||
assert_eq!(d.numel(), mask.numel(), "dropout_backward shape mismatch");
|
||||
assert_eq!(mask.dtype, DType::F32, "dropout mask must be F32");
|
||||
let dx = Tensor::zeros(&d.shape, d.dtype, d.device());
|
||||
let n = d.numel() as i32;
|
||||
match d.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_bwd_f32(
|
||||
d.data_ptr() as *const f32,
|
||||
mask.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_bwd_bf16(
|
||||
d.data_ptr() as *const std::ffi::c_void,
|
||||
mask.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => panic!("dropout_backward supports F32/BF16"),
|
||||
}
|
||||
dx
|
||||
}
|
||||
|
||||
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; each token's
|
||||
/// position is `row % period`. `period` = sequence length, so a flattened
|
||||
/// batch `[B*S,heads,head_dim]` gets per-sequence positions (pass `period=S`);
|
||||
@@ -1092,6 +1178,119 @@ impl Tensor {
|
||||
(dq, dk, dv)
|
||||
}
|
||||
|
||||
// --- Fused flash-attention (the T14 op) ---
|
||||
|
||||
/// Fused flash-attention forward (Phase T14). `self`=Q, `k`, `v` each
|
||||
/// `[bh, seq, head_dim]`, contiguous on one GPU. Computes, per batch element,
|
||||
/// `out = softmax(causal(Q·Kᵀ·scale))·V` in a SINGLE kernel that streams over
|
||||
/// KV tiles with an online softmax — the `[bh,seq,seq]` score matrix is NEVER
|
||||
/// materialized. Returns `(out, lse)` where `lse`:[bh,seq] (F32) is the per-row
|
||||
/// logsumexp cached for backward (O(N), vs the composed path's O(N²) probs).
|
||||
///
|
||||
/// The fused kernel is fp32; for bf16 we upcast Q/K/V → f32 → kernel → downcast
|
||||
/// `out` back to bf16 (same fp32-softmax policy as the composed [`attention`]),
|
||||
/// so flash and composed produce the same softmax numerics. `lse` stays fp32.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn flash_attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> (Tensor, Tensor) {
|
||||
assert_eq!(
|
||||
self.ndim(),
|
||||
3,
|
||||
"flash_attention Q must be [bh,seq,head_dim]"
|
||||
);
|
||||
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
|
||||
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
|
||||
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
|
||||
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
|
||||
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let dev = self.device();
|
||||
let dt = self.dtype;
|
||||
|
||||
let qf = self.to_dtype(DType::F32);
|
||||
let kf = k.to_dtype(DType::F32);
|
||||
let vf = v.to_dtype(DType::F32);
|
||||
let out_f32 = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
let lse = Tensor::zeros(&[bh, seq], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_flash_attention_fwd_f32(
|
||||
qf.data_ptr() as *const f32,
|
||||
kf.data_ptr() as *const f32,
|
||||
vf.data_ptr() as *const f32,
|
||||
out_f32.data_ptr() as *mut f32,
|
||||
lse.data_ptr() as *mut f32,
|
||||
bh as i32,
|
||||
seq as i32,
|
||||
hd as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
(out_f32.to_dtype(dt), lse)
|
||||
}
|
||||
|
||||
/// Backward of [`flash_attention`](Self::flash_attention). Inputs: forward
|
||||
/// `q`,`k`,`v`, the forward output `out`, the cached `lse`:[bh,seq], the upstream
|
||||
/// `dout`, and the same `scale`. Returns `(dq, dk, dv)`.
|
||||
///
|
||||
/// flash-style: NO cached probs. Recomputes scores from Q/K/V + `lse`, uses
|
||||
/// `D[i]=Σ dOᵢ·Oᵢ` to collapse the softmax Jacobian, streams KV in tiles. dQ is
|
||||
/// owned per query row; dK/dV are accumulated across rows (atomicAdd). Same
|
||||
/// fp32 kernel; bf16 callers get fp32 grads which the autograd `cast` op casts.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn flash_attention_backward(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
out: &Tensor,
|
||||
lse: &Tensor,
|
||||
dout: &Tensor,
|
||||
scale: f32,
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
|
||||
let dev = q.device();
|
||||
let dt = q.dtype;
|
||||
|
||||
let qf = q.to_dtype(DType::F32);
|
||||
let kf = k.to_dtype(DType::F32);
|
||||
let vf = v.to_dtype(DType::F32);
|
||||
let of = out.to_dtype(DType::F32);
|
||||
let dof = dout.to_dtype(DType::F32);
|
||||
// D[i] = Σ_d dO[i,d]·O[i,d] (one scalar per query row, O(N)).
|
||||
let d = Tensor::zeros(&[bh, seq], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_flash_attention_rowdot_f32(
|
||||
dof.data_ptr() as *const f32,
|
||||
of.data_ptr() as *const f32,
|
||||
d.data_ptr() as *mut f32,
|
||||
(bh * seq) as i32,
|
||||
hd as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
// dq/dk/dv pre-zeroed (Tensor::zeros memsets); dk/dv accumulate via atomicAdd.
|
||||
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_flash_attention_bwd_f32(
|
||||
qf.data_ptr() as *const f32,
|
||||
kf.data_ptr() as *const f32,
|
||||
vf.data_ptr() as *const f32,
|
||||
dof.data_ptr() as *const f32,
|
||||
lse.data_ptr() as *const f32,
|
||||
d.data_ptr() as *mut f32,
|
||||
dq.data_ptr() as *mut f32,
|
||||
dk.data_ptr() as *mut f32,
|
||||
dv.data_ptr() as *mut f32,
|
||||
bh as i32,
|
||||
seq as i32,
|
||||
hd as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
(dq.to_dtype(dt), dk.to_dtype(dt), dv.to_dtype(dt))
|
||||
}
|
||||
|
||||
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
|
||||
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
|
||||
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).
|
||||
|
||||
@@ -101,6 +101,10 @@ fn main() {
|
||||
// Optimization knobs.
|
||||
let steps: usize = flag(&args, "--steps", 2000);
|
||||
let batch_size: usize = flag(&args, "--batch", 8);
|
||||
// Micro-batch gradient accumulation (Phase T16): effective batch =
|
||||
// accum_steps × batch, at one micro-batch's activation-memory cost. Default 1
|
||||
// = no accumulation (bit-identical to the pre-T16 path).
|
||||
let accum_steps: usize = flag(&args, "--accum-steps", 1).max(1);
|
||||
let seq_len: usize = flag(&args, "--seq", 64);
|
||||
let max_lr: f32 = flag(&args, "--max-lr", 3e-3);
|
||||
let min_lr: f32 = flag(&args, "--min-lr", max_lr * 0.1);
|
||||
@@ -109,6 +113,10 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
// Dropout (Phase T18): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (forward graph bit-identical to the no-dropout path).
|
||||
let dropout: f32 = flag(&args, "--dropout", 0.0f32);
|
||||
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
|
||||
// activations. Opt-in; default fp32 reproduces v0–v4 numerics.
|
||||
let bf16 = args.iter().any(|a| a == "--bf16");
|
||||
@@ -116,6 +124,9 @@ fn main() {
|
||||
// exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in;
|
||||
// default off stores every activation (unchanged numerics).
|
||||
let recompute = args.iter().any(|a| a == "--recompute");
|
||||
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
|
||||
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
|
||||
let flash = args.iter().any(|a| a == "--flash");
|
||||
let ckpt: PathBuf = PathBuf::from(
|
||||
args.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -149,7 +160,8 @@ fn main() {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
let mut cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
cfg.dropout = dropout;
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||
@@ -183,6 +195,13 @@ fn main() {
|
||||
model = model.with_recompute(true);
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
if flash {
|
||||
model = model.with_flash(true);
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
@@ -201,6 +220,7 @@ fn main() {
|
||||
let tcfg = TrainConfig {
|
||||
seq_len,
|
||||
batch_size,
|
||||
accum_steps,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr,
|
||||
@@ -219,10 +239,13 @@ fn main() {
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}, eval every {}",
|
||||
"training: {} steps, seq {}, batch {} × accum {} = effective batch {}, \
|
||||
lr {:.1e}→{:.1e}, eval every {}",
|
||||
tcfg.steps,
|
||||
tcfg.seq_len,
|
||||
tcfg.batch_size,
|
||||
tcfg.accum_steps,
|
||||
tcfg.batch_size * tcfg.accum_steps,
|
||||
tcfg.schedule.max_lr,
|
||||
tcfg.schedule.min_lr,
|
||||
tcfg.eval_every
|
||||
|
||||
@@ -27,6 +27,12 @@ use crate::schedule::LrSchedule;
|
||||
pub struct TrainConfig {
|
||||
pub seq_len: usize,
|
||||
pub batch_size: usize,
|
||||
/// Micro-batch gradient accumulation (Phase T16): each optimizer step
|
||||
/// accumulates grads over `accum_steps` micro-batches of `batch_size`
|
||||
/// sequences, giving an EFFECTIVE batch of `accum_steps × batch_size` at the
|
||||
/// activation-memory cost of a single micro-batch. `1` = no accumulation
|
||||
/// (bit-identical to the pre-T16 path).
|
||||
pub accum_steps: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
@@ -74,28 +80,47 @@ pub fn train(
|
||||
// Best-val checkpointing only kicks in when we actually evaluate.
|
||||
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||||
|
||||
let accum = cfg.accum_steps.max(1);
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Sample `batch_size` sequences and run them as ONE batched forward/
|
||||
// backward. The CE mean over all batch*seq rows is the batch-mean loss, so
|
||||
// backward already yields the batch-mean gradient (clip pre-scale = 1.0).
|
||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||
for _ in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
// Accumulate grads over `accum` micro-batches of `batch_size` sequences,
|
||||
// then take ONE optimizer step (Phase T16). Each micro-batch is ONE batched
|
||||
// forward/backward; its loss is the CE mean over batch*seq rows, so backward
|
||||
// yields that micro-batch's mean grad. To make the SUM over `accum` micro-
|
||||
// batches equal a single step over an `accum × batch` batch, each micro-loss
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads). `accum == 1` skips the scale entirely → bit-identical to pre-T16.
|
||||
let mut step_loss_sum = 0.0f32;
|
||||
// Training mode → dropout active (T18; no-op when cfg.dropout == 0). Set
|
||||
// each step so it is restored after a periodic eval flips to eval mode.
|
||||
// Each micro-step's forward bumps the per-step seed → fresh masks.
|
||||
model.train();
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||
for _ in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
inputs.push(input);
|
||||
targets_v.push(target);
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||||
step_loss_sum += read_scalar(&loss);
|
||||
if accum == 1 {
|
||||
loss.backward();
|
||||
} else {
|
||||
xtrain_autodiff::ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
}
|
||||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||||
}
|
||||
let ids = batched_ids_tensor(&inputs, device);
|
||||
let targets = batched_ids_tensor(&targets_v, device);
|
||||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||||
let step_loss = read_scalar(&loss);
|
||||
loss.backward();
|
||||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||||
// Reported loss = mean over the effective batch = mean of the raw micro
|
||||
// losses (each is itself a micro-batch mean of equal size).
|
||||
let step_loss = step_loss_sum / accum as f32;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Backward already produced the batch-mean gradient — just clip it.
|
||||
// Backward already produced the effective-batch mean gradient — just clip.
|
||||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
@@ -169,6 +194,8 @@ pub fn eval_loss(
|
||||
if valid.len() <= seq + 1 {
|
||||
return f32::NAN;
|
||||
}
|
||||
// Eval mode → dropout is identity (T18).
|
||||
model.eval();
|
||||
let n_win = (valid.len() - 1) / seq; // disjoint windows that fit
|
||||
let batches = batches.max(1).min(n_win.max(1));
|
||||
let stride = (n_win / batches).max(1);
|
||||
|
||||
294
crates/xtrain-train/tests/grad_accum.rs
Normal file
294
crates/xtrain-train/tests/grad_accum.rs
Normal file
@@ -0,0 +1,294 @@
|
||||
// T16 gradient-accumulation correctness gates.
|
||||
//
|
||||
// Gradient accumulation is mathematically EXACT: accumulating the grads of N
|
||||
// micro-batches of B sequences (each micro-loss scaled by 1/N before backward,
|
||||
// the tape SUM-accumulating) equals a single step over one N·B-sequence batch.
|
||||
// This file makes that a closed loop on-GPU, plus the accum_steps=1 bit-identity
|
||||
// regression guard.
|
||||
//
|
||||
// 1. accum_equiv_big_batch: same init, same N·B sequences in the same order.
|
||||
// Path A = ONE batched loss over all N·B (the big-batch baseline). Path B =
|
||||
// N micro-backwards of B each, scale(1/N), tape SUM. Assert loss and EVERY
|
||||
// parameter grad match within fp tolerance (only the summation order differs,
|
||||
// like the T8 DDP-vs-single-GPU and T13 recompute gates).
|
||||
// 2. accum1_bit_identical: accum_steps=1 must reproduce the no-accum path
|
||||
// bit-for-bit (the implementation skips the ×1/1 scale entirely) — every
|
||||
// parameter grad max|Δ| == 0.0.
|
||||
// 3. accum_train_converges: drive the real `train()` loop with accum and assert
|
||||
// the per-step effective-batch loss trace tracks a big-batch baseline (errors
|
||||
// stay bounded over many AdamW steps, not just one).
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::Device;
|
||||
use xtrain_train::data::Corpus;
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
use xtrain_train::{TrainConfig, train};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
// `n` deterministic (seq, target) pairs for the equivalence tests.
|
||||
fn make_seqs(n: usize, seq: usize, vocab: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
|
||||
let seqs = (0..n)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts = (0..n)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
(seqs, tgts)
|
||||
}
|
||||
|
||||
// Run one big-batch forward/backward over all `seqs` and return the grads.
|
||||
fn big_batch_grads(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
seqs: &[Vec<i32>],
|
||||
tgts: &[Vec<i32>],
|
||||
) -> (f32, Vec<Vec<f32>>) {
|
||||
let n = seqs.len();
|
||||
let ids = batched_ids_tensor(seqs, device);
|
||||
let tgt = batched_ids_tensor(tgts, device);
|
||||
let loss = model.loss_batched(&ids, &tgt, n);
|
||||
let loss_val = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("grad")))
|
||||
.collect();
|
||||
(loss_val, grads)
|
||||
}
|
||||
|
||||
// Accumulate over `accum` micro-batches of `b` sequences (drawn in order from the
|
||||
// flat `seqs`/`tgts`), scaling each micro-loss by 1/accum before backward; the
|
||||
// tape SUM-accumulates. Returns the mean of the raw micro losses + accumulated grads.
|
||||
fn accum_grads(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
seqs: &[Vec<i32>],
|
||||
tgts: &[Vec<i32>],
|
||||
accum: usize,
|
||||
b: usize,
|
||||
scale: bool,
|
||||
) -> (f32, Vec<Vec<f32>>) {
|
||||
let mut loss_sum = 0.0f32;
|
||||
for m in 0..accum {
|
||||
let s = &seqs[m * b..(m + 1) * b];
|
||||
let t = &tgts[m * b..(m + 1) * b];
|
||||
let ids = batched_ids_tensor(s, device);
|
||||
let tgt = batched_ids_tensor(t, device);
|
||||
let loss = model.loss_batched(&ids, &tgt, b);
|
||||
loss_sum += host(&loss.value())[0];
|
||||
if scale {
|
||||
ops::scale(&loss, 1.0 / accum as f32).backward();
|
||||
} else {
|
||||
loss.backward(); // accum==1 bit-identity path
|
||||
}
|
||||
}
|
||||
let grads = model
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("grad")))
|
||||
.collect();
|
||||
(loss_sum / accum as f32, grads)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accum_equiv_big_batch() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 3;
|
||||
let b = 2usize; // micro-batch
|
||||
let accum = 4usize; // → effective batch 8
|
||||
let seq = 6usize;
|
||||
let (seqs, tgts) = make_seqs(b * accum, seq, cfg.vocab);
|
||||
|
||||
// Big-batch baseline (accum_steps=1, batch = b·accum).
|
||||
let big = build(cfg, device);
|
||||
let (big_loss, big_grads) = big_batch_grads(&big, device, &seqs, &tgts);
|
||||
|
||||
// Accumulated (accum micro-batches of b, scale 1/accum).
|
||||
let acc = build(cfg, device);
|
||||
let (acc_loss, acc_grads) = accum_grads(&acc, device, &seqs, &tgts, accum, b, true);
|
||||
|
||||
let loss_rel = (big_loss - acc_loss).abs() / big_loss.abs().max(1e-4);
|
||||
let mut max_grad_rel = 0.0f32;
|
||||
for (bg, ag) in big_grads.iter().zip(&acc_grads) {
|
||||
for (x, y) in bg.iter().zip(ag) {
|
||||
max_grad_rel = max_grad_rel.max((x - y).abs() / x.abs().max(1e-3));
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"accum=={accum}×b{b} vs big-batch{}: loss {big_loss:.6}/{acc_loss:.6} (rel {loss_rel:.2e}), \
|
||||
grad max rel {max_grad_rel:.3e}",
|
||||
b * accum
|
||||
);
|
||||
// fp summation order differs (big batch sums b·accum rows once; accum sums per
|
||||
// micro then across micros) → tight fp tol, same convention as T13 recompute.
|
||||
assert!(loss_rel < 1e-5, "loss diverged: {loss_rel:.2e}");
|
||||
assert!(
|
||||
max_grad_rel < 1e-4,
|
||||
"accum grads diverged from big batch: {max_grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accum1_bit_identical() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 3;
|
||||
let b = 4usize;
|
||||
let seq = 6usize;
|
||||
let (seqs, tgts) = make_seqs(b, seq, cfg.vocab);
|
||||
|
||||
// No-accum reference: one batched loss + backward (the pre-T16 path).
|
||||
let reference = build(cfg, device);
|
||||
let (_, ref_grads) = big_batch_grads(&reference, device, &seqs, &tgts);
|
||||
|
||||
// accum_steps=1 path: the loop runs ONE micro-batch and (by design) skips the
|
||||
// ×1/1 scale → must be byte-for-byte identical to the reference backward.
|
||||
let accum1 = build(cfg, device);
|
||||
let (_, a1_grads) = accum_grads(&accum1, device, &seqs, &tgts, 1, b, false);
|
||||
|
||||
let mut max_abs = 0.0f32;
|
||||
for (r, a) in ref_grads.iter().zip(&a1_grads) {
|
||||
for (x, y) in r.iter().zip(a) {
|
||||
max_abs = max_abs.max((x - y).abs());
|
||||
}
|
||||
}
|
||||
println!("accum_steps=1 vs no-accum: grad max |Δ| = {max_abs:.3e}");
|
||||
assert_eq!(
|
||||
max_abs, 0.0,
|
||||
"accum_steps=1 not bit-identical to no-accum: {max_abs:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
// A self-contained synthetic corpus (no tokenizer / data file needed).
|
||||
fn synth_corpus(vocab: usize, n_tokens: usize) -> Corpus {
|
||||
Corpus {
|
||||
tokens: (0..n_tokens)
|
||||
.map(|i| (i * 7 + 3) as i32 % vocab as i32)
|
||||
.collect(),
|
||||
vocab_size: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn accum_train_converges() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let vocab = 64usize;
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = vocab;
|
||||
cfg.n_layers = 2;
|
||||
let corpus = synth_corpus(vocab, 4096);
|
||||
let steps = 20usize;
|
||||
let seq = 32usize;
|
||||
|
||||
// Same per-step RNG stream + effective batch 8 either way: the big-batch run
|
||||
// (accum=1, batch=8) and the accumulated run (accum=4, batch=2) draw the SAME
|
||||
// 8 sequences per step in the same order, so the per-step loss/grads — and thus
|
||||
// the whole AdamW trajectory — track within fp tolerance.
|
||||
let sched = LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
warmup: 3,
|
||||
total: steps,
|
||||
};
|
||||
let base = |batch, accum| TrainConfig {
|
||||
seq_len: seq,
|
||||
batch_size: batch,
|
||||
accum_steps: accum,
|
||||
steps,
|
||||
schedule: sched.clone(),
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 1_000_000,
|
||||
ckpt_path: None,
|
||||
ckpt_every: 0,
|
||||
eval_every: 0,
|
||||
eval_batches: 0,
|
||||
seed: 7,
|
||||
};
|
||||
|
||||
let big_model = build(cfg, device);
|
||||
let big = train(&big_model, device, &corpus, None, &base(8, 1)).train_losses;
|
||||
|
||||
let acc_model = build(cfg, device);
|
||||
let acc = train(&acc_model, device, &corpus, None, &base(2, 4)).train_losses;
|
||||
|
||||
let mut max_rel = 0.0f32;
|
||||
for (x, y) in big.iter().zip(&acc) {
|
||||
max_rel = max_rel.max((x - y).abs() / x.abs().max(1e-6));
|
||||
}
|
||||
// Final params should also stay close (errors don't blow up over the run).
|
||||
let mut max_pdiff = 0.0f32;
|
||||
for (p, q) in big_model.params().iter().zip(&acc_model.params()) {
|
||||
for (x, y) in host(&p.value()).iter().zip(host(&q.value())) {
|
||||
max_pdiff = max_pdiff.max((x - y).abs() / x.abs().max(1e-6));
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"accum(4×2) vs big(8) over {steps} steps: loss[last] {:.6}/{:.6} max_rel {max_rel:.2e}, \
|
||||
final param max rel {max_pdiff:.2e}",
|
||||
big.last().unwrap(),
|
||||
acc.last().unwrap()
|
||||
);
|
||||
assert!(
|
||||
max_rel < 1e-3,
|
||||
"accum loss trajectory diverged: {max_rel:.3e}"
|
||||
);
|
||||
assert!(
|
||||
max_pdiff < 1e-2,
|
||||
"accum final params diverged: {max_pdiff:.3e}"
|
||||
);
|
||||
}
|
||||
@@ -84,6 +84,7 @@ fn trains_on_tinystories() {
|
||||
let tcfg = TrainConfig {
|
||||
seq_len: 64,
|
||||
batch_size: 8,
|
||||
accum_steps: 1,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
|
||||
109
csrc/ops/dropout.cu
Normal file
109
csrc/ops/dropout.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// Dropout kernels (Phase T18).
|
||||
//
|
||||
// A counter-based (stateless) RNG: the keep/drop decision for element `i` is a
|
||||
// pure function of (seed, i) — no global RNG state is advanced. This is what
|
||||
// makes dropout compatible with activation recomputation (T13): when a
|
||||
// checkpointed block re-runs its forward in backward, the SAME seed regenerates
|
||||
// the SAME mask, so the recomputed activations / grads stay bit-identical to the
|
||||
// forward (no mask drift).
|
||||
//
|
||||
// Inverted dropout: at training time kept elements are scaled by 1/(1-p) so the
|
||||
// expectation E[out] == x (no inference-time rescale needed; eval is identity,
|
||||
// handled in Rust by simply not calling dropout).
|
||||
//
|
||||
// key = seed ^ (i * GOLDEN)
|
||||
// h = splitmix64(key) // a few rounds of xorshift/multiply
|
||||
// u = (h >> 40) / 2^24 in [0,1) // 24-bit uniform
|
||||
// keep = u >= p // Bernoulli(keep = 1-p)
|
||||
// out = keep ? x * scale : 0 // scale = 1/(1-p)
|
||||
// mask = keep ? scale : 0 // cached for backward (dx = d * mask)
|
||||
//
|
||||
// fp32 + bf16 variants: bf16 loads/stores half-size activations but the uniform
|
||||
// `u` is always computed in fp32, so the mask distribution is identical across
|
||||
// dtypes (drop decisions don't depend on bf16 rounding). The mask buffer is fp32
|
||||
// in both cases (it stores `scale` or 0 — exactly representable, tiny relative to
|
||||
// the activation, reused only elementwise in backward).
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// splitmix64: cheap, well-mixed counter hash. Maps a 64-bit counter to a 64-bit
|
||||
// pseudo-random output; we only need the high bits for a uniform.
|
||||
__device__ __forceinline__ uint64_t splitmix64(uint64_t x) {
|
||||
x += 0x9E3779B97F4A7C15ULL;
|
||||
x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
|
||||
x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
|
||||
return x ^ (x >> 31);
|
||||
}
|
||||
|
||||
// Uniform [0,1) for element i under `seed`, computed in fp32 (dtype-independent).
|
||||
__device__ __forceinline__ float dropout_uniform(uint64_t seed, int i) {
|
||||
uint64_t key = seed ^ ((uint64_t)i * 0x9E3779B97F4A7C15ULL);
|
||||
uint64_t h = splitmix64(key);
|
||||
// Top 24 bits → [0,1) with 2^-24 resolution.
|
||||
return (float)(h >> 40) * (1.0f / 16777216.0f); // 1/2^24
|
||||
}
|
||||
|
||||
// fp32 forward: out[i] = keep ? x[i]*scale : 0 ; mask[i] = keep ? scale : 0.
|
||||
__global__ void dropout_fwd_f32_k(const float* x, float* out, float* mask,
|
||||
float p, float scale, uint64_t seed, int n) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) {
|
||||
float keep = (dropout_uniform(seed, i) >= p) ? scale : 0.0f;
|
||||
mask[i] = keep;
|
||||
out[i] = x[i] * keep;
|
||||
}
|
||||
}
|
||||
void launch_dropout_fwd_f32(const float* x, float* out, float* mask, float p,
|
||||
float scale, uint64_t seed, int n, void* s) {
|
||||
int blk = 256, grid = (n + blk - 1) / blk;
|
||||
dropout_fwd_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, out, mask, p, scale,
|
||||
seed, n);
|
||||
}
|
||||
|
||||
// Backward applies the SAME cached mask elementwise: dx[i] = d[i] * mask[i].
|
||||
__global__ void dropout_bwd_f32_k(const float* d, const float* mask, float* dx,
|
||||
int n) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dx[i] = d[i] * mask[i];
|
||||
}
|
||||
void launch_dropout_bwd_f32(const float* d, const float* mask, float* dx, int n,
|
||||
void* s) {
|
||||
int blk = 256, grid = (n + blk - 1) / blk;
|
||||
dropout_bwd_f32_k<<<grid, blk, 0, (cudaStream_t)s>>>(d, mask, dx, n);
|
||||
}
|
||||
|
||||
// bf16 forward: activation is bf16; mask is fp32 (stores `scale` or 0). Uniform
|
||||
// is fp32, so the mask matches the fp32 path bit-for-bit (same drop decisions).
|
||||
__global__ void dropout_fwd_bf16_k(const __nv_bfloat16* x, __nv_bfloat16* out,
|
||||
float* mask, float p, float scale,
|
||||
uint64_t seed, int n) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) {
|
||||
float keep = (dropout_uniform(seed, i) >= p) ? scale : 0.0f;
|
||||
mask[i] = keep;
|
||||
out[i] = __float2bfloat16(__bfloat162float(x[i]) * keep);
|
||||
}
|
||||
}
|
||||
void launch_dropout_fwd_bf16(const void* x, void* out, float* mask, float p,
|
||||
float scale, uint64_t seed, int n, void* s) {
|
||||
int blk = 256, grid = (n + blk - 1) / blk;
|
||||
dropout_fwd_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, mask, p, scale, seed, n);
|
||||
}
|
||||
|
||||
__global__ void dropout_bwd_bf16_k(const __nv_bfloat16* d, const float* mask,
|
||||
__nv_bfloat16* dx, int n) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i < n) dx[i] = __float2bfloat16(__bfloat162float(d[i]) * mask[i]);
|
||||
}
|
||||
void launch_dropout_bwd_bf16(const void* d, const float* mask, void* dx, int n,
|
||||
void* s) {
|
||||
int blk = 256, grid = (n + blk - 1) / blk;
|
||||
dropout_bwd_bf16_k<<<grid, blk, 0, (cudaStream_t)s>>>(
|
||||
(const __nv_bfloat16*)d, mask, (__nv_bfloat16*)dx, n);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
281
csrc/ops/flash_attention.cu
Normal file
281
csrc/ops/flash_attention.cu
Normal file
@@ -0,0 +1,281 @@
|
||||
// Hand-written fused flash-attention (Phase T14).
|
||||
//
|
||||
// The T10 composed SDPA path is 3 launches that MATERIALIZE the [bh,S,S] score
|
||||
// matrix: cublasSgemmStridedBatched (Q·Kᵀ) → causal-softmax kernel (writes the
|
||||
// whole probs) → cublasSgemmStridedBatched (P·V), and backward caches that whole
|
||||
// probs. flash-attention NEVER materializes N×N: a single fused kernel streams
|
||||
// over KV tiles with an ONLINE softmax (running max/sum + rescaled V accumulator),
|
||||
// so peak attention activation drops from O(S²) to O(S·hd) (= the output itself).
|
||||
//
|
||||
// Layout (matches the T10 op): Q/K/V/out are [bh, S, hd] row-major contiguous,
|
||||
// bh = batch·n_heads. The query's position within its sequence is the row index
|
||||
// within its [S,hd] block (so the flat row's qpos = (row % S) is automatic here —
|
||||
// we index per (bh, row)). CAUSAL: a query at position i attends to keys j ≤ i.
|
||||
// `scale` (= 1/sqrt(hd)) is folded into the logits before the max/exp.
|
||||
//
|
||||
// All F32, contiguous. (bf16 callers upcast Q/K/V → f32 on the Rust side and
|
||||
// downcast the f32 out, mirroring the composed path's fp32 softmax policy, so the
|
||||
// kernel only ever sees fp32.) Reduction helpers are inlined (self-contained file,
|
||||
// matching the csrc/ layout).
|
||||
//
|
||||
// Parallelisation: grid = bh*S, one block per query row; blockDim.x threads
|
||||
// cooperate. Forward keeps m (running max), l (running sum), acc[hd] (rescaled
|
||||
// V accumulator) in shared memory, streams KV in tiles of BK. Backward recomputes
|
||||
// scores from Q/K/V + the saved logsumexp L[bh,S] (NO cached probs), uses
|
||||
// D[i]=Σ dOᵢ·Oᵢ to collapse the softmax Jacobian, and atomicAdds dK/dV (which are
|
||||
// accumulated across query rows).
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
__device__ __forceinline__ float fa_warp_sum(float v) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1)
|
||||
v += __shfl_down_sync(0xffffffff, v, off);
|
||||
return v;
|
||||
}
|
||||
__device__ __forceinline__ float fa_warp_max(float v) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1)
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
|
||||
return v;
|
||||
}
|
||||
__device__ __forceinline__ float fa_block_sum(float v) {
|
||||
__shared__ float sh[32];
|
||||
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
v = fa_warp_sum(v);
|
||||
if (lane == 0) sh[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
|
||||
if (warp == 0) v = fa_warp_sum(v);
|
||||
__shared__ float bc;
|
||||
if (threadIdx.x == 0) bc = v;
|
||||
__syncthreads();
|
||||
return bc;
|
||||
}
|
||||
__device__ __forceinline__ float fa_block_max(float v) {
|
||||
__shared__ float sh[32];
|
||||
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
v = fa_warp_max(v);
|
||||
if (lane == 0) sh[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
|
||||
if (warp == 0) v = fa_warp_max(v);
|
||||
__shared__ float bc;
|
||||
if (threadIdx.x == 0) bc = v;
|
||||
__syncthreads();
|
||||
return bc;
|
||||
}
|
||||
|
||||
#define FA_TILE 32 // KV tile width (columns streamed per step)
|
||||
|
||||
// One block per (bh-row, query-position). Computes out[bh, i, :] and L[bh, i] via
|
||||
// an online softmax that streams the keys in tiles of FA_TILE — the [S,S] score
|
||||
// row is never stored, only the per-tile partials flow through shared memory.
|
||||
__global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
|
||||
float* O, float* L, int seq, int hd, float scale) {
|
||||
int row = blockIdx.x; // global query row over bh*S
|
||||
int b = row / seq; // which (batch,head) block
|
||||
int i = row % seq; // query position within the sequence (causal limit)
|
||||
int t = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
|
||||
const float* q = Q + (size_t)row * hd;
|
||||
const float* kb = K + (size_t)b * seq * hd; // this block's keys [seq,hd]
|
||||
const float* vb = V + (size_t)b * seq * hd; // this block's values[seq,hd]
|
||||
|
||||
// Q row in shared memory (reused every tile); acc accumulator over hd.
|
||||
extern __shared__ float smem[];
|
||||
float* sq = smem; // [hd]
|
||||
float* acc = smem + hd; // [hd]
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
sq[d] = q[d];
|
||||
acc[d] = 0.0f;
|
||||
}
|
||||
__shared__ float m_run, l_run;
|
||||
if (t == 0) { m_run = -INFINITY; l_run = 0.0f; }
|
||||
__syncthreads();
|
||||
|
||||
int valid = i + 1; // causal: attend to keys [0, i]
|
||||
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
|
||||
int tile = min(FA_TILE, valid - j0);
|
||||
// Each thread computes whole logits for a strided subset of the tile's
|
||||
// columns: s = scale * (q · k_j). hd is small (≤128) so the per-thread
|
||||
// dot loop is cheap; this avoids a block-reduce per column.
|
||||
__shared__ float s_tile[FA_TILE];
|
||||
for (int c = t; c < tile; c += nthreads) {
|
||||
const float* kj = kb + (size_t)(j0 + c) * hd;
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < hd; ++d) dot += sq[d] * kj[d];
|
||||
s_tile[c] = dot * scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Tile max, then online rescale of (m, l, acc).
|
||||
float tmax = -INFINITY;
|
||||
for (int c = t; c < tile; c += nthreads) tmax = fmaxf(tmax, s_tile[c]);
|
||||
tmax = fa_block_max(tmax);
|
||||
|
||||
__shared__ float m_new, corr;
|
||||
if (t == 0) {
|
||||
float mn = fmaxf(m_run, tmax);
|
||||
corr = (m_run == -INFINITY) ? 0.0f : expf(m_run - mn); // rescale old state
|
||||
m_new = mn;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per
|
||||
// column (instead of recomputing expf inside the per-dim V loop, which
|
||||
// would cost hd× the transcendentals). Sum them for l.
|
||||
float lsum = 0.0f;
|
||||
for (int c = t; c < tile; c += nthreads) {
|
||||
float p = expf(s_tile[c] - m_new);
|
||||
s_tile[c] = p;
|
||||
lsum += p;
|
||||
}
|
||||
lsum = fa_block_sum(lsum);
|
||||
|
||||
// Rescale old accumulator + add this tile's p·V (p cached in s_tile).
|
||||
// Each thread owns a strided subset of hd; loops over the tile columns.
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
float a = acc[d] * corr;
|
||||
for (int c = 0; c < tile; ++c)
|
||||
a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
|
||||
acc[d] = a;
|
||||
}
|
||||
if (t == 0) {
|
||||
l_run = l_run * corr + lsum;
|
||||
m_run = m_new;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// out = acc / l ; L = m + log(l) (logsumexp, saved for backward).
|
||||
float inv = 1.0f / l_run;
|
||||
for (int d = t; d < hd; d += nthreads) O[(size_t)row * hd + d] = acc[d] * inv;
|
||||
if (t == 0) L[row] = m_run + logf(l_run);
|
||||
}
|
||||
|
||||
void launch_flash_attention_fwd_f32(const float* q, const float* k, const float* v,
|
||||
float* o, float* l, int bh, int seq, int hd,
|
||||
float scale, void* s) {
|
||||
int blk = hd < 1024 ? hd : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
size_t shmem = (size_t)2 * hd * sizeof(float); // sq[hd] + acc[hd]
|
||||
flash_attn_fwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(q, k, v, o, l, seq, hd, scale);
|
||||
}
|
||||
|
||||
// Per-row D[i] = Σ_d dO[i,d] · O[i,d]. One block per row (bh*S rows). Used to
|
||||
// collapse the softmax Jacobian in backward (Σ_j P_ij dP_ij = dOᵢ·Oᵢ).
|
||||
__global__ void flash_attn_rowdot_k(const float* dO, const float* O, float* D, int hd) {
|
||||
int row = blockIdx.x;
|
||||
int t = threadIdx.x;
|
||||
const float* d = dO + (size_t)row * hd;
|
||||
const float* o = O + (size_t)row * hd;
|
||||
float v = 0.0f;
|
||||
for (int c = t; c < hd; c += blockDim.x) v += d[c] * o[c];
|
||||
v = fa_block_sum(v);
|
||||
if (t == 0) D[row] = v;
|
||||
}
|
||||
|
||||
// Backward: one block per query row i. Recomputes scores from Q/K/V + the saved
|
||||
// logsumexp L (NO cached probs), streams KV in tiles. dQ accumulates locally (this
|
||||
// row owns it). dK/dV are accumulated ACROSS query rows so they atomicAdd into the
|
||||
// shared global buffers (pre-zeroed by the caller).
|
||||
// p_ij = exp(Qᵢ·Kⱼ·scale - L[i]) ; dp_ij = dOᵢ·Vⱼ ;
|
||||
// ds_ij = p_ij·(dp_ij - D[i])·scale
|
||||
// dQᵢ += Σ_j ds_ij·Kⱼ ; dKⱼ += ds_ij·Qᵢ ; dVⱼ += p_ij·dOᵢ
|
||||
__global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
|
||||
const float* dO, const float* L, const float* D,
|
||||
float* dQ, float* dK, float* dV,
|
||||
int seq, int hd, float scale) {
|
||||
int row = blockIdx.x;
|
||||
int b = row / seq;
|
||||
int i = row % seq;
|
||||
int t = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
|
||||
const float* q = Q + (size_t)row * hd;
|
||||
const float* doi = dO + (size_t)row * hd;
|
||||
const float* kb = K + (size_t)b * seq * hd;
|
||||
const float* vb = V + (size_t)b * seq * hd;
|
||||
float* dkb = dK + (size_t)b * seq * hd;
|
||||
float* dvb = dV + (size_t)b * seq * hd;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
float* sq = smem; // [hd] Qᵢ
|
||||
float* sdo = smem + hd; // [hd] dOᵢ
|
||||
float* dqa = smem + 2*hd; // [hd] dQᵢ accumulator
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
sq[d] = q[d];
|
||||
sdo[d] = doi[d];
|
||||
dqa[d] = 0.0f;
|
||||
}
|
||||
__shared__ float Li, Di;
|
||||
if (t == 0) { Li = L[row]; Di = D[row]; }
|
||||
__syncthreads();
|
||||
|
||||
int valid = i + 1;
|
||||
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
|
||||
int tile = min(FA_TILE, valid - j0);
|
||||
// Phase 1: per-column ds[c] and p[c] (the column owner does the dots).
|
||||
__shared__ float s_ds[FA_TILE];
|
||||
__shared__ float s_p[FA_TILE];
|
||||
for (int c = t; c < tile; c += nthreads) {
|
||||
const float* kj = kb + (size_t)(j0 + c) * hd;
|
||||
const float* vj = vb + (size_t)(j0 + c) * hd;
|
||||
float sdot = 0.0f, dpdot = 0.0f;
|
||||
for (int d = 0; d < hd; ++d) {
|
||||
sdot += sq[d] * kj[d];
|
||||
dpdot += sdo[d] * vj[d];
|
||||
}
|
||||
float p = expf(sdot * scale - Li);
|
||||
s_p[c] = p;
|
||||
s_ds[c] = p * (dpdot - Di) * scale;
|
||||
}
|
||||
__syncthreads();
|
||||
// Phase 2: dV_j += p·dOᵢ ; dK_j += ds·Qᵢ (accumulated across rows → atomic).
|
||||
// Spread the tile×hd atomics over ALL threads (was serial in the column
|
||||
// owner) — flatten (c,d) so every thread issues a balanced share.
|
||||
for (int idx = t; idx < tile * hd; idx += nthreads) {
|
||||
int c = idx / hd, d = idx % hd;
|
||||
size_t off = (size_t)(j0 + c) * hd + d;
|
||||
atomicAdd(&dvb[off], s_p[c] * sdo[d]);
|
||||
atomicAdd(&dkb[off], s_ds[c] * sq[d]);
|
||||
}
|
||||
// dQᵢ += Σ_c ds[c] · K_{j0+c} (this row owns dQ — no atomic).
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
float a = 0.0f;
|
||||
for (int c = 0; c < tile; ++c)
|
||||
a += s_ds[c] * kb[(size_t)(j0 + c) * hd + d];
|
||||
dqa[d] += a;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
for (int d = t; d < hd; d += nthreads) dQ[(size_t)row * hd + d] = dqa[d];
|
||||
}
|
||||
|
||||
void launch_flash_attention_bwd_f32(const float* q, const float* k, const float* v,
|
||||
const float* d_o, const float* l, float* d_d,
|
||||
float* dq, float* dk, float* dv,
|
||||
int bh, int seq, int hd, float scale, void* s) {
|
||||
int blk = hd < 1024 ? hd : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
// d_d is the pre-computed D[i]=Σ dOᵢ·Oᵢ (the Rust wrapper runs rowdot first,
|
||||
// since it holds the forward O). dq/dk/dv are pre-zeroed by the caller.
|
||||
flash_attn_bwd_k<<<bh * seq, blk, (size_t)3 * hd * sizeof(float), (cudaStream_t)s>>>(
|
||||
q, k, v, d_o, l, d_d, dq, dk, dv, seq, hd, scale);
|
||||
}
|
||||
|
||||
// Standalone D = rowdot(dO, O) launcher (the Rust wrapper calls this before bwd).
|
||||
void launch_flash_attention_rowdot_f32(const float* d_o, const float* o, float* d_d,
|
||||
int rows, int hd, void* s) {
|
||||
int blk = hd < 1024 ? hd : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
flash_attn_rowdot_k<<<rows, blk, 0, (cudaStream_t)s>>>(d_o, o, d_d, hd);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
183
docs/13-flash-attention.md
Normal file
183
docs/13-flash-attention.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# Phase T14: 融合 Flash-Attention Kernel — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
T10 把 attention 批量化了,但它的 SDPA 走的是 **「物化 N×N scores」** 的组合路径:
|
||||
`cublasSgemmStridedBatched`(Q·Kᵀ)→ 一个 causal-softmax kernel(写出整张 probs)→
|
||||
`cublasSgemmStridedBatched`(P·V),**3 次 launch + 一张 `[bh, S, S]` 的 scores/probs 张量**
|
||||
常驻显存(反向还要缓存这张 probs)。S 一大,这张 N×N 就成了激活显存与带宽的主导项。
|
||||
|
||||
T14 的目标:手写一个**单 kernel 的 fused flash-attention**——streaming / online softmax、**tiled
|
||||
over KV**、**绝不物化 N×N**。前向一发 kernel 直接吐出 `out[bh,S,hd]`(外加 `O(N)` 的 logsumexp);
|
||||
反向一发 kernel(flash 式:重算 scores + dQ/dK/dV,同样不物化 N×N)。接进 model + autograd 作
|
||||
**opt-in `--flash`**,默认保留 T10 的 composed 路径以便 A/B。
|
||||
|
||||
**硬闸门是诚实正确性**:新 kernel 的 dQ/dK/dV finite-diff grad-check 过;fwd/bwd 对现有 composed-SDPA
|
||||
路径数值贴合(进 bf16 容差);PyTorch SDPA 对拍 B>1;峰值显存↓(不物化 scores)+ tok/s before/after 实测;
|
||||
全回归套(含 xserv 闭环 md5)开/关 flag 都绿——默认(flag off)图不变 → 不回归。
|
||||
|
||||
## 什么是 flash-attention
|
||||
|
||||
标准 attention 是 `O = softmax(causal(Q·Kᵀ/√d)) · V`,朴素实现把 `S[i,j] = Qᵢ·Kⱼ/√d` 整张
|
||||
`[S,S]` 算出来、softmax、再乘 V——显存 `O(S²)`、HBM 读写 `O(S²)`。
|
||||
|
||||
**flash-attention** 的洞察:softmax 可以 **online(streaming)** 地算。把 K/V 切成若干 **tile**,对一个
|
||||
query 行 `i`,依次扫过 KV tile,用 **running max `m` + running sum `l`** 维护 softmax 的归一化,并把
|
||||
部分加权的 `V` 累加进一个 `[hd]` 的 accumulator `acc`,每来一个新 tile 就用「新旧 max 的差」对旧 `acc`/`l`
|
||||
做 rescale。扫完所有 tile,`out = acc / l`。**整张 `[S,S]` 从不落地**——只有 `[hd]` 的 acc 和两个标量
|
||||
在寄存器/共享内存里流动。峰值激活从 `O(S²)` 降到 `O(S·hd)`(就是 O 本身)。
|
||||
|
||||
online softmax 的核心递推(block `j` 的部分 logits 行 `s_j`,旧状态 `m, l, acc`):
|
||||
|
||||
```text
|
||||
m_new = max(m, max_k s_j[k])
|
||||
p = exp(s_j - m_new) # 本 tile 的未归一化权重
|
||||
l = l * exp(m - m_new) + sum(p) # 旧 sum 先 rescale,再加本 tile
|
||||
acc = acc * exp(m - m_new) + p · V_tile # 旧 acc 同样 rescale,再加本 tile 贡献
|
||||
m = m_new
|
||||
# 扫完所有 tile:
|
||||
out = acc / l
|
||||
L = m + log(l) # logsumefp,O(N) 存给反向
|
||||
```
|
||||
|
||||
**因果 mask 内联**:query 全局位置 = `i % S`(沿用 T10 的 per-seq 复位约定),KV 位置 `j` 满足
|
||||
`j > i%S` 的列直接当 `-inf`(`p=0`)。tile 整块在对角线之上可**直接 skip**(causal 的天然稀疏,省一半算力)。
|
||||
|
||||
**反向(flash 式,[Dao 2022] 的标准做法)**:不缓存 probs,从 Q/K/V + 前向存的 `L[bh,S]` **重算** scores。
|
||||
关键预计算 `D[i] = Σ_d dOᵢ[d]·Oᵢ[d]`(每 query 一个标量,`O(N)`),则对每个 `(i,j)`:
|
||||
|
||||
```text
|
||||
s_ij = Qᵢ·Kⱼ * scale # 重算 logit
|
||||
p_ij = exp(s_ij - L[i]) # 重算 softmax 权重(L 是前向存的 logsumexp)
|
||||
dp_ij = dOᵢ · Vⱼ # 对 P 的梯度
|
||||
ds_ij = p_ij * (dp_ij - D[i]) * scale # softmax 雅可比,化简掉了显式 N×N
|
||||
dQᵢ += ds_ij * Kⱼ ; dKⱼ += ds_ij * Qᵢ ; dVⱼ += p_ij * dOᵢ
|
||||
```
|
||||
|
||||
`ds = P ∘ (dP - D)` 是 softmax 反向用 `Σⱼ Pⱼ·dPⱼ = D`(因为 `D[i]=Σ dOᵢ·Oᵢ = Σⱼ Pᵢⱼ dPᵢⱼ`)化简的结果,
|
||||
**不需要 N×N 的 softmax 雅可比矩阵**。同样 tiled、同样不物化 N×N。
|
||||
|
||||
## Module Layout(surgical:composed 路径逐字节不动,flash 全程新增并行路径)
|
||||
|
||||
```
|
||||
csrc/ops/flash_attention.cu # 新:fwd kernel(online softmax,tiled KV)+ bwd kernel(重算 + dQ/dK/dV)
|
||||
crates/xtrain-cuda/
|
||||
├── src/ffi.rs # +launch_flash_attention_fwd_f32 / _bwd_f32 声明
|
||||
└── build.rs # +flash_attention.cu
|
||||
crates/xtrain-tensor/src/tensor.rs # +Tensor::flash_attention / flash_attention_backward(fwd 存 logsumexp L;bf16 upcast→f32 kernel→downcast)
|
||||
crates/xtrain-autodiff/
|
||||
├── src/ops.rs # +ops::flash_attention 节点(前向调 fwd,缓存 L,反向调 bwd)
|
||||
└── tests/autograd.rs # +flash_attention(batched) dQ/dK/dV grad-check
|
||||
crates/xtrain-model/
|
||||
├── src/model.rs # attention() 按 use_flash 选 ops::attention | ops::flash_attention;+with_flash(bool) builder;flash 标志透传 block_forward(recompute 段内也走 flash)
|
||||
└── tests/flash.rs # 新:flash == composed(fwd logits + 每参数梯度),参数化 fp32/bf16
|
||||
crates/xtrain-train/src/bin/train.rs # +--flash flag → model.with_flash(true)
|
||||
crates/xtrain-distributed/src/bin/train_ddp.rs # +--flash flag(DDP 路径)
|
||||
crates/xtrain-model/tests/parity_dump.rs # PyTorch B>1 对拍跑两遍:composed 与 flash(共用 PyTorch oracle)
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### ① 一个 block 负责一行 query(先做对,再谈快)
|
||||
|
||||
最直接、最易验证正确的并行划分:**`grid = bh * S`,每个 block 算一整行 query 的 `out[bh, i, :]`**。
|
||||
block 内 `hd` 个线程(hd ≤ 128,正好一个 warp 多一点),共享 `m/l` 标量 + `acc[hd]`。block 顺序扫
|
||||
KV tile(tile 宽 `BK`,沿 `j` 维),每个 tile:线程并行算 `BK` 个 logit(点积 over hd 用 block-reduce)、
|
||||
求 tile max、online-rescale `m/l/acc`、累加 `p·V`。扫完写 `out = acc/l` 与 `L[i] = m + log(l)`。
|
||||
|
||||
**为什么先这样而不是 FA2 的 query-tile 划分**:本项目的硬闸门是**正确性 + 不物化 N×N + 显存↓**,不是
|
||||
打榜峰值 FLOPs。一行一 block 的版本:(a) online softmax 与 N×N skip 已经完全落地(显存与带宽收益拿到),
|
||||
(b) 代码直白、逐 query 行可对拍,正确性风险最低。它**不会**比 cuBLAS 两发 GEMM 更快(cuBLAS tensor-core
|
||||
吃满),所以 tok/s 上 flash 在我们这种 `hd=32` 小头维下大概率**持平或略慢**——这正是 flash 的已知权衡
|
||||
(flash 的胜场是**显存**,不是小模型的 wall-clock)。把这点诚实写进 perf 表,不掩饰。
|
||||
|
||||
### ② 前向只存 `L[bh,S]`(logsumefp),不存 probs
|
||||
|
||||
composed 路径反向要缓存整张 `probs[bh,S,S]`(`O(N²)`)。flash 反向**只需要前向的 logsumexp
|
||||
`L[i]=m_i+log(l_i)`**(每 query 一个 fp32,`O(N)`)即可重算任意 `p_ij = exp(Qᵢ·Kⱼ·scale - L[i])`。
|
||||
所以 fwd kernel 顺手把 `L` 写出来,autograd 节点缓存它(外加 Q/K/V/O parents 本就在)。**这就是显存闸门的来源**:
|
||||
attention 的反向缓存从 `[bh,S,S]` 砍到 `[bh,S]`。
|
||||
|
||||
### ③ 反向用 `D[i]=Σ dOᵢ·Oᵢ` 化简 softmax 雅可比
|
||||
|
||||
softmax 反向通项 `ds_ij = p_ij·(dp_ij - Σ_k p_ik·dp_ik)`。注意 `Σ_k p_ik·dp_ik = Σ_k p_ik (dOᵢ·V_k)
|
||||
= dOᵢ·(Σ_k p_ik V_k) = dOᵢ·Oᵢ = D[i]`。所以一趟先算 `D[bh,S]`(每行 `dO·O` 的点积,`O(N)`),反向
|
||||
扫 KV tile 时直接 `ds = p·(dp - D)·scale`,**不需要再算或物化整行的 `Σ p·dp`**。
|
||||
dQ/dK/dV 三者:dQ 由「该 query 行」累加(block 私有,无竞争);dK/dV 跨 query 行累加同一个 `(j)`
|
||||
→ 用 `atomicAdd` 到全局 dK/dV(fp32 原子加,确定 race-free)。
|
||||
|
||||
### ④ bf16:kernel 内 fp32,边界 cast(与 composed 路径一致的数值策略)
|
||||
|
||||
T10/T12 的 composed attention 对 bf16 也是 **softmax 用 fp32**(scores 升 f32 → kernel → probs 降回 bf16)。
|
||||
flash 沿用同策略,最省心且数值最稳:bf16 模式下 `flash_attention` 把 Q/K/V `to_dtype(F32)` 喂给 fp32 kernel,
|
||||
`out` 再 `to_dtype(BF16)`;反向同理。kernel 本身只有一份 fp32 实现。这样 flash 的 bf16 数值与 composed 的
|
||||
bf16 数值是**同一套 fp32 softmax 算的**,只差 GEMM rounding(cuBLAS tensor-core vs kernel 内 fp32 FMA)→ 落在
|
||||
既有 bf16 容差内。`L` 始终 fp32。
|
||||
|
||||
> 备选(不采纳):bf16 全程 in-kernel half。收益是少两次 cast,但 (a) 引入与 composed 不同的 softmax 累加路径,
|
||||
> 威胁 on-vs-off 贴合闸门;(b) 本规模 attention 非瓶颈。escape hatch:先 fp32-core 把正确性钉死,纯 half flash 留 follow-up。
|
||||
|
||||
### ⑤ opt-in 透传:`use_flash` 是运行时旗标,不是架构
|
||||
|
||||
`use_flash` 不进 `Config`(它不改模型尺寸、不改导出、不该污染 `num_params`),而是 `TinyTransformer` 的一个
|
||||
`bool` 字段 + `with_flash(bool)` builder(对齐 `with_recompute` / `with_compute_dtype`)。`block_forward` 已经
|
||||
是 `(cfg, cdt, …)` 的自由函数(T13 为 recompute 抽的),给它加一个 `flash: bool` 形参,model 的 `attention()`
|
||||
据此选 `ops::attention`(composed)或 `ops::flash_attention`。recompute 闭包捕获 `flash`(`Copy`)→ **重算段内也走
|
||||
flash**,flash×recompute 组合天然成立。默认 `false` = composed 路径**逐字节不变**(硬闸门:默认图不变 → 不回归)。
|
||||
|
||||
## 验证方法
|
||||
|
||||
**硬闸门全绿(dash5 实跑 capture):**
|
||||
|
||||
### 1. 正确性
|
||||
|
||||
- **新 kernel dQ/dK/dV finite-diff grad-check**(`xtrain-autodiff/tests/autograd.rs::flash_attention_batched_bwd`):
|
||||
与既有 `attention_batched_bwd` 同构(`L = sum(W∘out)`,中心差分),断 dQ/dK/dV 在 `cfg_nonlinear`/`cfg_linear` 容差内。
|
||||
- **flash == composed**(`xtrain-model/tests/flash.rs`):同 init 两个模型(flash on/off),同一 batched
|
||||
loss + backward,断**前向 logits / loss / 每参数梯度**在紧容差内一致;参数化 fp32(近逐位)与 bf16(bf16 舍入级)。
|
||||
- **PyTorch SDPA 对拍 B>1**(`parity_dump.rs` + `parity.py`):等价 PyTorch 模型(per-seq RoPE、per-seq causal、
|
||||
QK-norm、SwiGLU)对拍 forward logits + 全部参数梯度——**composed 与 flash 两条都跑**,共用同一 PyTorch oracle。
|
||||
- **全回归套开/关 `--flash`**:autograd 15、structural、batched==looped、bf16、recompute(逐位)、overfit 27/27、
|
||||
AdamW(GPU bit-exact + host 对 torch)、DDP loss-match + 跨 rank、**xserv 闭环(导出 safetensors → md5 对 registry →
|
||||
xserv 贪心逐 token 一致)**。flag off 默认图不变 → composed 数值不回归。
|
||||
|
||||
### 2. 显存(payoff)—— 不物化 N×N 的直接收益
|
||||
|
||||
dash5 1× RTX 5090,同 config,nvidia-smi 峰值,flash off vs on:attention 反向缓存 `[bh,S,S]→[bh,S]`,
|
||||
峰值显存应↓(尤其 seq 大时)。capture 实际数字进表。
|
||||
|
||||
### 3. 吞吐
|
||||
|
||||
同 config steady-state tok/s flash off vs on。预期:本规模 `hd=32` 下 flash kernel **持平或略慢于** cuBLAS 双
|
||||
GEMM(小头维喂不满 tensor-core 是 flash 的已知权衡,胜场在显存)——诚实报告,不为绿而调。
|
||||
|
||||
## 实测结果(dash5 1× RTX 5090)
|
||||
|
||||
**正确性(硬闸门全绿):**
|
||||
|
||||
| 闸门 | 结果 |
|
||||
|---|---|
|
||||
| ① 新 kernel dQ/dK/dV finite-diff grad-check | **过** — dQ 9.3e-3 / dK 1.7e-2 / dV 5.6e-4(单 tile 干净区;多 tile 由②兜) |
|
||||
| flash fwd 对 composed | max rel **6.7e-5** |
|
||||
| flash bwd 对(已 grad-check 的)composed bwd | dQ **1.7e-5** / dK 1.2e-5 / dV 4.3e-5 |
|
||||
| ② flash==composed(model 级,logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 1.7e-4、grad 4.4e-5;bf16: loss 1.5e-4、logits mean 1.6e-3/p99 5.9e-3、grad scaled-mean 1.2e-2 |
|
||||
| ③ PyTorch SDPA 对拍 B>1(flash 路径,共用 composed oracle) | loss relerr **4.98e-8**、logits **7.92e-6**、25 参数 grad 全进 rtol 0.02 |
|
||||
| ⑤ 回归套(flag off 默认 + flash 路径都测):autograd 18 / structural 5 / batched / bf16 / **flash 3** / overfit 27/27 / recompute 2 / AdamW(GPU+host) / GEMM / DDP 2 / checkpoint-roundtrip | **全绿** |
|
||||
| ⑤ xserv 闭环 md5(v3 ckpt 用 T14 代码重导 safetensors) | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry 同)→ 默认 export 零漂移 |
|
||||
| ⑤ xserv 闭环(flash 训练 → 导出 → xserv 服务贪心) | flash-训出 coherent TinyStories;xserv(BF16) 对 xtrain(F32) 贪心:3 prompt 中 "One day" 逐 token 一致,其余在 ~0.5% BF16 漂移处晚分叉(与 v1/v2/v3 同款) |
|
||||
|
||||
> **finite-diff 的诚实记录**:长 softmax(seq>tile)会产生大量近零梯度元素,中心差分在那些元素上不可靠(出现伪 0.0 / 符号翻转——不是 backward bug)。故 ① 的 finite-diff 跑**单 tile 干净区**(seq=5,对齐既有 composed grad-check 的良态区),**多 tile 的 streaming/online 路径**用「flash bwd 对已 grad-check 的 composed bwd」(seq=40,dQ 1.7e-5)兜——比 finite-diff 更利。dQ/dK 用 eps=2e-3 压低 f32 舍入项(~4e-4 小梯度上舍入项压过截断项)。**没有为凑绿放宽容差**。
|
||||
|
||||
**④ 显存 + 吞吐(payoff vs tradeoff,dim768=8L/12h×64/ffn3072, bf16, steady-state):**
|
||||
|
||||
| config | path | 峰值显存 | tok/s |
|
||||
|---|---|---|---|
|
||||
| batch8 seq1024 | composed (off) | 24670 MiB | **58.6K** |
|
||||
| batch8 seq1024 | **flash (on)** | **20736 MiB(−16%)** | 25.0K(−57%, ~2.3× 慢) |
|
||||
| batch2 seq2048 | composed (off) | 17264 MiB | 36.7K |
|
||||
| batch2 seq2048 | **flash (on)** | **13246 MiB(−23%)** | 13.2K(−64%) |
|
||||
|
||||
→ **显存按预期降**(不物化 `[bh,S,S]`),且**收益随 seq 增长**(seq1024 −16% → seq2048 −23%,O(S²) 砍掉)。
|
||||
**tok/s 如设计 ① 预测的「持平或略慢」实为 ~2.3–2.8× 慢**:hd=64 的小头维下,手写「一行一 block + 串行扫 KV」kernel 喂不满 SM,干不过 cuBLAS tensor-core 的两发批量 GEMM——这正是 flash 的已知权衡(**胜场在显存,不是小模型 wall-clock**),诚实报告不掩饰。两个落地的优化(softmax 权重缓存进 shared 省 hd× 的 expf;dK/dV 原子加摊到全 block 而非串行在列 owner 内)把 backward 从 6.8× 慢拉到 2.3× 慢——主瓶颈是 backward 的跨行原子累加(FA2 用 K-block 拥有 dK/dV 的独立 pass 解,本版未做,留 follow-up)。
|
||||
|
||||
> **escape hatch(follow-up,未做,记给后续)**:① FA2 式 query-tile 划分(一 block 多 query 行,K/V 进 shared 复用)提 SM 占用;② backward 的 dK/dV 改 K-block-owned 独立 pass 消跨行原子;③ 纯 bf16 in-kernel(省两次 cast)。本规模 attention 非训练瓶颈、且会动数值贴合闸门,按 escape hatch 推迟——T14 先把**正确性 + 不物化 N×N + 显存↓**钉死。
|
||||
165
docs/15-grad-accum.md
Normal file
165
docs/15-grad-accum.md
Normal file
@@ -0,0 +1,165 @@
|
||||
# Phase T16: Gradient Accumulation — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
在已有的训练 loop(T6/T10)与 DDP(T8)之上,加 **micro-batch 梯度累积**:把 `accum_steps=N`
|
||||
个 **micro-step** 的梯度在 tape 里累加起来,再做**一次** `AdamW.step` + `zero_grad`——得到
|
||||
**有效 batch = N × micro_batch** 的更新,而显存只占**一个 micro-batch** 的激活峰值(不随 N 增长)。
|
||||
|
||||
两条硬约束:
|
||||
|
||||
1. **数值等效**:`accum_steps=N`(N 个 micro-step 后一次 step)必须对住「一个 N× 大 batch
|
||||
的单 step」——梯度/loss 在仓内既有容差内**逐位贴合**。这是核心等效性证明。
|
||||
2. **DDP 只在累积边界通信**:`world>1` 下,N 个 micro-step 里**只在最后一个**做 all-reduce
|
||||
(中间 micro-step **跳过跨卡通信**),最终喂给优化器的仍是 global 有效 batch 的均值梯度,
|
||||
loss 对单卡。
|
||||
|
||||
并暴露 train 入口的 `--accum-steps` flag。`accum_steps=1` 必须对当前无累积路径**逐位一致**
|
||||
(回归保护)。
|
||||
|
||||
**不做**:micro-batch 间变 LR / 变 batch(恒定 micro_batch);累积里换 dropout RNG(T18 才有
|
||||
dropout);ZeRO(T17)。本 Phase 只动**优化器 step 的节奏**与 **DDP 通信门控**,复用 tape 既有
|
||||
的 SUM 累加。
|
||||
|
||||
## Module Layout
|
||||
|
||||
```
|
||||
crates/xtrain-train/src/
|
||||
├── train_loop.rs # TrainConfig += accum_steps;inner micro-loop(缩放 loss + tape SUM)
|
||||
└── bin/train.rs # 新 --accum-steps flag;打印有效 batch
|
||||
|
||||
crates/xtrain-distributed/src/
|
||||
└── ddp.rs # DdpConfig += accum_steps;all-reduce 门控到累积边界
|
||||
|
||||
crates/xtrain-train/tests/
|
||||
└── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)
|
||||
|
||||
crates/xtrain-distributed/tests/
|
||||
└── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)
|
||||
|
||||
docs/15-grad-accum.md # 本文
|
||||
```
|
||||
|
||||
无新 crate、无新 kernel、无新 autograd op——梯度累积是**纯调度**:tape 早已 SUM 累加,
|
||||
缩放用既有 `ops::scale`,DDP 通信用既有 `all_reduce_average_grads`,只是改**调用节奏与门控**。
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### ① 等效性的数学:缩放每个 micro-loss 为 `1/N`
|
||||
|
||||
模型的 `loss_batched` 是 **CE-mean over `batch*seq` 行**(见 `model.rs`)。设一个 micro-batch 有
|
||||
`B` 序列、seq 长 `S`,记某 micro-step 那批 `B*S` 行的 per-row 梯度之和为 `Σ_micro`:
|
||||
|
||||
- **大 batch 基线**(有效 batch `N·B`):一次 `loss_batched(N·B 序列)` = CE-mean over `N·B·S` 行
|
||||
→ backward 给 `G_big = Σ_all / (N·B·S)`,其中 `Σ_all = Σ_n Σ_micro_n`。
|
||||
- **累积**(N 个 micro-step,每个 `B`):micro-step n 的 `loss_batched(B)` = CE-mean over `B·S` 行
|
||||
→ 若直接 backward 得 `Σ_micro_n / (B·S)`;**N 个 backward 之间不 `zero_grad`**,tape SUM 累加 →
|
||||
`Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big`。
|
||||
|
||||
差一个因子 N。修正:**每个 micro-loss 先 `ops::scale(loss, 1/N)` 再 backward**——`scale` 的
|
||||
backward 把上游梯度乘 `1/N`(见 `ops.rs`),于是每个 micro 贡献 `Σ_micro_n / (N·B·S)`,
|
||||
累积后 `Σ_all / (N·B·S) = G_big`,**与大 batch 逐位等效**(仅 fp 求和顺序不同 → 进容差,和
|
||||
T8 DDP-vs-单卡同性质)。
|
||||
|
||||
> 为什么不在 clip 里用 `pre_scale=1/N`?clip 的 `pre_scale` 已被 batch-mean 占用(=1.0)。
|
||||
> 在 loss 上 `scale(1/N)` 更内聚:缩放穿过既有 autograd,不碰 clip/optimizer,且 `N=1` 时
|
||||
> `scale(1.0)` 的 backward 是恒等乘 1 —— 这正是 `accum_steps=1` 逐位回归的保证(见 ④)。
|
||||
|
||||
报告的 step-loss = N 个 micro 的**原始** loss(未缩放值)之和 / N = 有效 batch 的 mean loss,
|
||||
和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。
|
||||
|
||||
### ② 单卡 train loop:inner micro-loop
|
||||
|
||||
每个 optimizer step:
|
||||
|
||||
```text
|
||||
for micro in 0..N:
|
||||
抽 B 序列 → loss = loss_batched(B)
|
||||
step_loss_acc += raw_loss(loss) # 累报告用的原始 loss
|
||||
scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度
|
||||
# —— 累积边界 ——
|
||||
clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值
|
||||
opt.step(lr); zero_grad()
|
||||
losses.push(step_loss_acc / N)
|
||||
tokens_seen += N * B * S # 有效 batch tok
|
||||
```
|
||||
|
||||
`accum_steps` 默认 1 → micro-loop 跑一次、`scale(loss,1.0)`、不在 micro 间 zero_grad(本就如此)
|
||||
→ 与现路径完全等价。**每个 micro-step 的计算图在它自己的 backward 后即可释放**(Rust `Rc` 在
|
||||
循环变量出作用域时 drop),所以**显存峰值 = 单个 micro-batch 的激活**,不随 N 增长(③ 实测)。
|
||||
|
||||
抽样次序保持:单卡仍是连续从 RNG 抽 `N·B` 序列;与「大 batch 抽 `N·B`」逐序列对齐,只是分 N 组
|
||||
forward——并集同序,所以 `Σ_all` 的项一致。
|
||||
|
||||
### ③ 显存平 + 有效 batch 实测
|
||||
|
||||
「显存不随 N 增长」是 grad-accum 的卖点,要**实测**而非断言:固定有效 batch `E = N·B`,跑
|
||||
`(N=1,B=E)`(大 batch)vs `(N=E,B=1)`(极端累积),用 `nvidia-smi`/`cudaMemGetInfo` 量峰值显存——
|
||||
后者应**显著低**(少 N× 激活)。train 入口打印 `effective batch = accum_steps × batch`。
|
||||
|
||||
### ④ `accum_steps=1` 逐位回归
|
||||
|
||||
`N=1` 时 inner loop 跑一次、`scale(loss, 1.0)`。`ops::scale(_, 1.0)` 的 fwd 是
|
||||
`value.scale(1.0)`、bwd 是 `grad.scale(1.0)`——数学恒等。为**绝对**逐位(连一次 `×1.0` kernel
|
||||
都不引入),实现里 `N==1` 直接 `loss.backward()`(跳过 scale),与现路径**字节一致**。测试
|
||||
`accum1_bit_identical_to_no_accum` 锁这条。
|
||||
|
||||
### ⑤ DDP:all-reduce 门控到累积边界
|
||||
|
||||
T8 的 `all_reduce_average_grads(params)` 每 step 调一次。grad-accum 下**只在最后一个
|
||||
micro-step 之后调一次**——中间 micro-step 的 backward 只在本卡 tape 里 SUM,**不发 NCCL**。
|
||||
|
||||
均值的账(沿用 T8 的「通信里 /world,clip 里 /b_local」拆分,再叠加 ① 的 /N):
|
||||
|
||||
```text
|
||||
每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
|
||||
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
|
||||
all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world
|
||||
→ 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b
|
||||
clip pre_scale = 1.0
|
||||
```
|
||||
|
||||
精确推导:每卡每 micro 的 `loss_batched(B_local)` 是 **本卡 mean over `B_local·S` 行**。
|
||||
`scale(1/N)` 后 backward = `Σ_local_micro / (N · B_local · S)`。N 个 micro tape SUM →
|
||||
`Σ_local_all / (N · B_local · S)`,其中 `Σ_local_all` = 本卡 `N·B_local` 行之和。
|
||||
`all_reduce(sum)` 跨 world 卡 → `Σ_global_all / (N · B_local · S)`(`Σ_global_all` = 全
|
||||
`world·N·B_local = N·B_global` 行之和);`/world` → `Σ_global_all / (N · B_local · S · world)`
|
||||
`= Σ_global_all / (N · B_global · S)`(因 `B_global = world·B_local`)。这正是**有效 batch
|
||||
`N·B_global` 的 mean 梯度**——与单卡「有效 batch `N·B_global` 的大 batch 单 step」逐位等效
|
||||
(求和顺序差进容差)。
|
||||
|
||||
> 关键正确性点:`all_reduce_average_grads` 里的 `/world` 是按 **world** 缩放(与 N 无关);N 的
|
||||
> 那个 `1/N` 已由 ① 的 `scale` 在每个 micro 的 backward 里完成。两者正交,不会互相污染。
|
||||
> 单卡(`world=1`)退化:all-reduce 是 no-op,`/world=1`,只剩 ① 的 `1/N` → 与 ② 一致。
|
||||
|
||||
DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。
|
||||
|
||||
### ⑥ 不变量小结
|
||||
|
||||
| | 单卡基线(大 batch E) | 单卡 accum(N×B=E) | DDP accum(world, N×B_local·world=E) |
|
||||
|---|---|---|---|
|
||||
| loss 缩放 | 无(CE-mean) | 每 micro `×1/N` | 每 micro `×1/N` |
|
||||
| grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 |
|
||||
| 跨卡通信 | — | — | **仅累积边界 1 次** all-reduce + /world |
|
||||
| clip pre_scale | 1.0 | 1.0 | 1.0 |
|
||||
| 显存峰值 | E 的激活 | **B 的激活** | **B_local 的激活** |
|
||||
|
||||
## 验证方法(验收,全部 dash5 实跑 capture)
|
||||
|
||||
GPU 测试 `#[cfg(not(no_cuda))]` 门控。
|
||||
|
||||
1. **等效性(核心硬闸门)** `grad_accum.rs::accum_equiv_big_batch`:同 init、同数据同序,
|
||||
跑「`accum_steps=N`, micro_batch=B」与「`accum_steps=1`, batch=N·B」各一 step,断言
|
||||
①loss、②**每个参数的 grad** rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP
|
||||
闸门约定)。多步版(跑 K 个 optimizer step)再断言**终参**贴合(误差不发散)。
|
||||
2. **`accum_steps=1` 逐位回归** `grad_accum.rs::accum1_bit_identical`:`accum_steps=1` 与现
|
||||
no-accum 路径同 init/同数据 → 每参数 grad `max|Δ| == 0.0`(④ 跳过 scale,字节一致)。
|
||||
3. **DDP+accum 对单卡** `ddp_correctness.rs`(扩既有 `ddp_matches_single_gpu…`):单卡
|
||||
有效 batch `E` 的大 batch baseline vs `world=2 + accum_steps=N`(每卡每 micro `B_local`,
|
||||
`world·N·B_local=E`)→ loss 轨迹 `max_rel<1e-3`、跨 rank 参数一致、且 only-at-boundary 通信
|
||||
(micro 间不发 NCCL,由实现保证 + 不变量推导)。
|
||||
4. **显存平 + 有效 batch** :固定有效 batch,量 `(N=1,大batch)` vs `(N=大,micro=1)` 峰值显存
|
||||
(后者显著低),train 入口打印 effective batch。capture nvidia-smi。
|
||||
5. **全回归套**:autograd grad-check / structural / batched==looped / bf16 / recompute(逐位)/
|
||||
overfit 27/27 / AdamW(GPU bit-exact + host vs torch)/ DDP loss-match + 跨 rank / **xserv
|
||||
闭环 md5**——`accum_steps=1` 默认值保证全部不回归。
|
||||
155
docs/17-dropout.md
Normal file
155
docs/17-dropout.md
Normal file
@@ -0,0 +1,155 @@
|
||||
# Phase T18: Dropout(device RNG + mask)— Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
在已有的 tape autograd 引擎(T4)+ tiny transformer(T5)之上,**手写一个 dropout 算子**:
|
||||
训练时按 Bernoulli(keep = 1−p) 生成一个 0/1 mask,丢弃的元素置 0、保留的元素按
|
||||
**inverted dropout** 乘 `1/(1−p)`(让训练期望与推理一致);推理(eval)时 dropout 是**恒等**。
|
||||
新增一个 autodiff `dropout` 节点:**前向生成并施加 mask,反向施加同一个 mask**。
|
||||
接到模型的标准位置(residual 之前的 attention / MLP 子块输出;attention-probs dropout 不做,见下)。
|
||||
通过 `Config.dropout` / `--dropout` 暴露 `p`,**默认 `p=0`**。
|
||||
|
||||
明确范围(T18 只做这些):
|
||||
|
||||
1. 一个 device 端 **counter-based RNG**(Philox 风格的 bit-mix),按 `(seed, 元素下标)` 无状态地产出
|
||||
每元素的 Bernoulli 抽样 → 0/1 mask(保留=1,丢弃=0),同 seed **逐位可复现**。
|
||||
2. 一个 `dropout` autodiff 节点(fwd 生成 mask + 施加 inverted scaling;bwd 用**缓存的同一 mask**)。
|
||||
3. 模型里加 **training / eval 开关**:train 走 dropout、eval/采样/导出走恒等。
|
||||
4. `p` 经 `Config.dropout` 落地,`bin/train` 加 `--dropout` flag。
|
||||
|
||||
明确**不做**:attention-probs(softmax 后)dropout——本项目 attention 是**一个 fused batched SDPA 算子**
|
||||
(`ops::attention`,softmax 在 kernel 内部不物化 probs 给外部施加 mask),在其上插 dropout 要么改 fused kernel、
|
||||
要么退回组合路径,**不值当**且偏离「标准 residual/ffn dropout」这条主线。文档明确记下「只做 residual-path dropout」。
|
||||
|
||||
## Module Layout
|
||||
|
||||
```
|
||||
csrc/ops/dropout.cu # 新:counter-based RNG mask 生成 + 施加 (fwd) / 反向施加同 mask
|
||||
# fp32 + bf16 两条(activation 流可能是 bf16,对齐 cast.cu 风格)
|
||||
|
||||
crates/xtrain-cuda/
|
||||
├── build.rs # 新增 dropout.cu
|
||||
└── src/ffi.rs # 新增 launch_dropout_{f32,bf16} 声明(no_cuda 门控)
|
||||
|
||||
crates/xtrain-tensor/
|
||||
└── src/tensor.rs # 新增 Tensor::dropout_mask_apply(p, seed) -> (out, mask)
|
||||
# Tensor::dropout_apply_mask(&mask) -> out(bwd 用)
|
||||
|
||||
crates/xtrain-autodiff/
|
||||
├── src/ops.rs # 新增节点 dropout(x, p, seed)(p==0 提前返回 x.clone(),零节点)
|
||||
└── tests/autograd.rs # 新增:固定 seed grad-check(mask 跨 ± 扰动固定)+ 期望保持数值检查
|
||||
|
||||
crates/xtrain-model/
|
||||
├── src/config.rs # Config 加 dropout: f32(默认 0)
|
||||
├── src/model.rs # train/eval 开关(Cell<bool>)+ 在 attn/ffn 子块输出接 dropout;
|
||||
│ # per-site 确定性 seed(与 checkpoint recompute 兼容)
|
||||
└── tests/dropout.rs # 新增:p=0 逐位一致 / eval 恒等 / 期望保持 / p>0 小训练收敛
|
||||
|
||||
crates/xtrain-train/src/bin/train.rs # --dropout flag → Config.dropout;训练 model.train(),sample 前 model.eval()
|
||||
```
|
||||
|
||||
为什么 RNG/mask 落在 `tensor.rs`(而非引擎):和 `scale`/`silu` 一样是一个 device kernel 的薄封装;
|
||||
autodiff 层只负责把它包成带 backward 的 `Var` 节点(对齐 T4 既有分层)。
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### RNG:counter-based(Philox 风格),无状态、可复现、与重计算兼容
|
||||
|
||||
mask[i] 只由 `(seed, i)` 决定,**不读取任何可变 RNG 状态**:
|
||||
|
||||
```
|
||||
key = seed XOR (i * 0x9E3779B97F4A7C15) // golden-ratio 常数打散下标
|
||||
h = splitmix64(key) // 几轮 bit-mix(xorshift+乘法)
|
||||
u = (h >> 40) as f32 / 2^24 // [0,1) 均匀
|
||||
keep = u >= p // Bernoulli(keep = 1−p)
|
||||
out[i] = keep ? x[i] * (1/(1−p)) : 0
|
||||
```
|
||||
|
||||
选 counter-based 而非「per-step 推进一个全局 LCG 状态」的关键原因 = **激活重计算(T13)**:
|
||||
checkpoint 的 segment 在 backward 时会**重跑一遍 forward**(`segment_fn` 再执行)。
|
||||
若 dropout 用「调用时推进的可变状态」,重跑会拿到**不同的 mask** → 梯度与前向用的 mask 不一致 → 错。
|
||||
counter-based + **每个 dropout 站点一个确定性 seed**(见下)保证:重跑同 seed → **同 mask**,
|
||||
重计算依旧逐位一致(T13 的硬闸门不被 dropout 破坏)。
|
||||
|
||||
> 复现性:同一 `(seed, p, shape)` 下 mask 逐位确定;fp32/bf16 mask 判定都在 fp32 里算 `u`(bf16 仅存/取
|
||||
> activation),所以两精度的 mask **同分布**(drop 与否由 fp32 `u` 决定,不受 bf16 舍入影响)。
|
||||
|
||||
### 每个 dropout 站点的确定性 seed(兼容 checkpoint 重算)
|
||||
|
||||
模型持有一个 `base_seed`(`Cell<u64>`,每个训练 step 自增一次 → 每步换 mask)。`block_forward`
|
||||
收到 `block_seed = base_seed XOR layer_index`,块内两处 dropout 再各 XOR 一个站点常量
|
||||
(attn=0xA77, ffn=0xF7N)派生出**该站点的 seed**。这些都是**纯函数**(只看 `base_seed + layer_index +
|
||||
站点常量`,无可变推进),所以:
|
||||
|
||||
- 同一 step 内不同站点 mask 不同(seed 不同);
|
||||
- checkpoint 重算 `block_forward` 时,`block_seed` 由捕获的 `base_seed`/`layer_index` 重新算出 → **同 seed → 同 mask**;
|
||||
- 跨 step mask 变化(`base_seed` 每步 +1)。
|
||||
|
||||
`base_seed` 的自增放在**训练入口**(`loss_batched` 训练态调用时 advance 一次)。eval/`forward`/采样
|
||||
**不 advance、不插 dropout**(恒等)。
|
||||
|
||||
### train / eval 开关
|
||||
|
||||
`TinyTransformer` 加一个 `Cell<bool> training`(默认 **false** = eval,安全:未显式开训练就不丢弃):
|
||||
|
||||
- `model.train()` / `model.eval()` 切换(builder 风格 `with_training(bool)` 也提供,给测试)。
|
||||
- `forward_batched` 里:`p > 0 && training` 才在 attn/ffn 子块输出插 `ops::dropout`;否则**完全不建 dropout 节点**。
|
||||
- 因此 **`p == 0`** 或 **eval** → forward 图与改动前**逐字节相同**(`ops::dropout` 在 `p==0` 时也提前
|
||||
`return x.clone()`,双保险)→ 满足「p=0 与无 dropout 逐位一致」回归闸门。
|
||||
|
||||
训练 loop(`train`)开 `model.train()`;`eval_loss` / `generate` / 导出 `forward` 走 eval(恒等)——
|
||||
导出的模型权重不含任何 dropout,xserv 闭环不受影响。
|
||||
|
||||
### dropout 接在哪(wiring)
|
||||
|
||||
接**两处 residual-path dropout**(标准 Pre-LN transformer 位置,对齐 GPT/LLaMA 训练实践):
|
||||
|
||||
```
|
||||
h = h + dropout( attention(rms_norm(h)) ) # attn 子块输出,残差前
|
||||
h = h + dropout( swiglu_mlp(rms_norm(h)) ) # ffn 子块输出,残差前
|
||||
```
|
||||
|
||||
**不做** attention-probs dropout(理由见 Goal:fused SDPA 不物化 probs)。embedding dropout 也不做(非必需)。
|
||||
|
||||
### dropout 节点的 backward(为什么 grad-check 成立)
|
||||
|
||||
```
|
||||
fwd: out = x ⊙ mask ⊙ (1/(1−p)) # mask 由 seed 生成,缓存进 backward 闭包
|
||||
bwd: dx = d ⊙ mask ⊙ (1/(1−p)) # 用同一个缓存 mask
|
||||
```
|
||||
|
||||
dropout 在 **固定 mask** 下是一个逐元素线性映射 `out_i = c_i · x_i`(`c_i ∈ {0, 1/(1−p)}`),
|
||||
其梯度就是 `dx_i = c_i · d_i`。finite-diff grad-check 之所以成立,关键是**前向缓存的 mask 在 ± 扰动两次
|
||||
forward 里保持不变**——本设计天然满足:mask 只由 `(seed, i)` 决定,与 `x` 的值无关,扰动 `x` 不改 mask。
|
||||
(grad-check 直接对 `ops::dropout` 节点跑:同一个 `seed` 调两次 forward 得到同一 mask,函数处处可微。)
|
||||
|
||||
### 与既有特性的组合
|
||||
|
||||
- **bf16(T12)**:activation 流是 bf16 时,dropout kernel 走 bf16 分支(load→fp32 判 mask→store bf16),
|
||||
mask 判定在 fp32,和 cast.cu 既有 bf16 elementwise 同风格;grad 也在 activation dtype(接回 bf16 链)。
|
||||
- **重计算(T13)**:见上「counter-based + 确定性 seed」——重算 mask 与前向逐位相同,T13 闸门不破。
|
||||
- **DDP(T8)**:每 rank 独立跑自己的 forward/backward,各自的 mask 由各 rank 的 `base_seed` 决定。
|
||||
本任务的 DDP 闸门是「loss 对单卡 / 跨 rank 参数一致」,在 **dropout 关(默认 p=0)** 的回归配置下跑,
|
||||
不引入跨 rank mask 同步需求(p>0 时各 rank mask 本就该不同,属正常 DDP 语义)。
|
||||
- **梯度累积(T16)/ flash(T14)**:本分支独立于二者,不依赖其未合并改动。
|
||||
|
||||
## 验证方法
|
||||
|
||||
全部 `#![cfg(not(no_cuda))]` 门控;本地只 `cargo check`/`fmt`,构建 + 实跑在 dash5(8× RTX 5090, sm_120)。
|
||||
|
||||
**硬闸门(全绿,诚实正确性,不放宽容差)**:
|
||||
|
||||
1. **固定 seed grad-check**(`autograd.rs::dropout_bwd`):对 `ops::dropout(x, p, seed)` 同一 seed
|
||||
跑 finite-diff(mask 跨 ± 扰动固定)→ `dx` 对中心差分通过(线性 op,用 `cfg_linear` 容差)。
|
||||
2. **train/eval + 期望保持**(`dropout.rs`):
|
||||
- eval 恒等:`dropout` 关时 `out == x` **逐位**;
|
||||
- 期望保持:大张量、训练态、对多组随机 mask 取均值,`E[out] ≈ x`(inverted scaling 正确),给数值;
|
||||
- 实际 keep 比例 ≈ `1−p`(验证 RNG 分布)。
|
||||
3. **p=0 逐位一致**(`dropout.rs`):同 init 两个模型,一个不设 dropout、一个 `dropout=0`,
|
||||
同 batch forward+backward → **logits/loss/每参数 grad 逐位相同**(`|Δ| == 0`)。
|
||||
4. **p>0 小训练收敛**(`dropout.rs`,或 dash5 短跑):小模型开 `p=0.1` 训若干步,**loss 下降、无 NaN**。
|
||||
5. **全回归套绿**:autograd grad-checks、structural、batched==looped、bf16、recompute(逐位一致)、
|
||||
overfit 27/27、AdamW(GPU bit-exact + host vs torch)、DDP(loss-match + 跨 rank)、
|
||||
**xserv 闭环**(导出 md5 vs registry、token-identical;导出/推理 dropout **关**,导出模型不受影响)。
|
||||
|
||||
dash5 capture 每个闸门的 pass + 关键数字(max rel-err、期望 vs input、p=0 的 `|Δ|`、训练 loss 轨迹)。
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
---
|
||||
|
||||
## 一、基建 phase(T1–T13)—— 主要动「算法」与「Infra」
|
||||
## 一、基建 phase(T1–T13 + Phase 2 systems-depth)—— 主要动「算法」与「Infra」
|
||||
|
||||
| Phase | 维度 | 变化 | 结果 / 验证 |
|
||||
|---|---|---|---|
|
||||
@@ -24,6 +24,9 @@
|
||||
| T11 | Infra | **device caching/pool allocator**(复用 op 输出显存,消 per-step cudaMalloc) | 单卡 2.3×;**8卡 461K tok/s** 近线性(修 KI-5) |
|
||||
| T12 | 算法/Infra | **bf16 混合精度**(fp32 master,cuBLAS GemmEx,norm/softmax/CE 保 fp32) | dim768 OOM 解除,−29% 显存/+13% tok/s(修 KI-2) |
|
||||
| T13 | 算法/Infra | **激活重计算**(per-block gradient checkpointing:前向 no-tape + 反向重算,`backward_seeded`) | 梯度对非重计算版**逐位一致**(0.00);dim768 31.1→14.6GB;**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3,解锁 v8) |
|
||||
| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernel:online softmax、tiled over KV、**不物化 N×N scores**;flash 式 bwd:重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dV);opt-in `--flash`,默认保 composed(Phase 2) | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0;**峰值显存 −16%@seq1024 / −23%@seq2048**(不物化 N×N,收益随 seq 增长);tok/s ~2.3–2.8× 慢(hd=64 小头维干不过 cuBLAS tensor-core,flash 已知权衡=胜场在显存);md5 闭环逐位一致 |
|
||||
| T16 | 算法/Infra | **梯度累积**(N 个 micro-step:每个 micro-loss `×1/N` 再 backward,tape SUM 累加 → 一次 AdamW step+zero;`--accum-steps`);**DDP 只在累积边界 all-reduce**(中间 micro-step 不发 NCCL,`/world` 与 `1/N` 正交);显存随 micro 不随有效 batch | 等效大 batch**逐位贴合**(loss rel 8.5e-8、grad rel 3.8e-5);`accum=1` 逐位回归(0.00);DDP+accum 对单卡 loss 5.7e-7/跨 rank 一致;**显存平**:同有效 batch 64,big-batch 27.7GB→accum(4×16) **7.2GB(−74%)**(big-batch OOM 而 accum 装下);全回归+xserv 闭环 md5 一致 |
|
||||
| T18 | 算法 | **dropout**(手写 counter-based 设备 RNG → Bernoulli mask,训练 inverted 1/(1-p) scaling、eval 恒等);新 autodiff `dropout` 算子(fwd 生成+施加 mask,bwd 用同 mask),接 residual/ffn 两处;`--dropout` flag 默认 0 | 固定 seed grad-check 过;E[out]≈input + keep≈1-p;**p=0 与无 dropout 逐位一致**;recompute(T13) 组合下梯度仍逐位一致(counter-based seed 重算复现同 mask);全回归 + xserv 闭环绿(导出/推理 dropout 关) |
|
||||
|
||||
---
|
||||
|
||||
@@ -49,9 +52,9 @@
|
||||
|
||||
## 三、各维度的累积演进(轴向看一条线怎么走的)
|
||||
|
||||
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13)。
|
||||
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14,online softmax + flash 式 bwd) → 梯度累积(T16,复用 tape SUM,等效大 batch 而显存随 micro) → dropout(T18,counter-based 设备 RNG + inverted scaling,train/eval 切换)。
|
||||
- **模型架构**:固定 Qwen3-style;dim **32→256→384→512→768→1024**(v8 首拨容量轴,头数 24→32);核心参数 **41K→226M**(总 3.26M→329M)。
|
||||
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。
|
||||
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024) → flash-attention(T14,不物化 N×N,attention 显存收益随 seq 增长) → 梯度累积(T16,DDP 只在累积边界通信,显存随 micro 不随有效 batch)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。T13/T14/T16 是三条**显存杠杆**(重计算压激活峰值、flash 不物化 N×N attention scores、梯度累积解耦有效 batch 与激活显存),可叠加放大有效 batch。
|
||||
- **数据集**:TinyStories 3MB 切片 → 全量 TinyStories(epoch 0.01→5.33,**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**(2.255B 语料,1.02ep)→ **v7 同子集多 epoch(1.45ep,近顶)→ v8 同子集换大模型**(dim1024,1.05ep)。tokenizer 全程 gpt2 BPE(复用 xserv-tokenizer;v6 刻意不换 tokenizer 以隔离「数据来源」变量,KI-4 留后续版本)。
|
||||
- **v5→v6 数据轴的质变**:v0–v5 都吃合成幼儿故事(TinyStories,低熵、词汇受控),v5 证明同尺寸模型在它上面已饱和;v6 第一版换成**真实教育类网页文本**(FineWeb-edu),语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。
|
||||
- ⚠️ **同子集多 epoch 也有天花板(v6→v7)**:v6 的 FineWeb val 才训 1.02ep、末步仍单调降,曾被读作「还没喂够」;v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B token),FineWeb val 仅 ↓0.05(3.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象**:**「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」,不在「同数据多读几遍」**。后续要继续降 val,必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。
|
||||
|
||||
Reference in New Issue
Block a user