Merge t14-flash-attention into main

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 00:35:46 +08:00
14 changed files with 1096 additions and 11 deletions

View File

@@ -50,9 +50,12 @@ Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`]
| **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** | | **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** |
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; 29% mem | | **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; 29% mem |
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical | | **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem 16%@1k / 23%@2k seq; flash==composed (grads/PyTorch) |
The four performance fixes (T10T13) each removed a real bottleneck — see The four performance fixes (T10T13) each removed a real bottleneck — see
[`docs/known-issues.md`](docs/known-issues.md). [`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14)**
revisits hand-writing deferred training-stack features; T14 = the fused
flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md)).
## The scaling study — v0 → v8 ## The scaling study — v0 → v8

View File

@@ -325,6 +325,32 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
) )
} }
/// Fused FLASH causal scaled-dot-product attention (Phase T14). Same interface as
/// [`attention`] (`q`,`k`,`v` each `[bh, seq, head_dim]`), but the forward is a
/// SINGLE fused kernel with an online softmax over KV tiles — the `[bh,seq,seq]`
/// score matrix is NEVER materialized, and backward caches only the per-row
/// logsumexp (O(N)) instead of the whole probs (O(N²)). Mathematically the same
/// SDPA, so it matches the composed [`attention`] within fp/bf16 tolerance.
/// Opt-in via the model's `--flash` flag; the composed path stays the default.
pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
let (out, lse) = q.value().flash_attention(&k.value(), &v.value(), scale);
let out_cache = out.clone();
Var::from_op(
out,
vec![q.clone(), k.clone(), v.clone()],
Box::new(move |dout, parents| {
let q = parents[0].value();
let k = parents[1].value();
let v = parents[2].value();
let (dq, dk, dv) =
Tensor::flash_attention_backward(&q, &k, &v, &out_cache, &lse, dout, scale);
Var::push_grad(&parents[0], dq);
Var::push_grad(&parents[1], dk);
Var::push_grad(&parents[2], dv);
}),
)
}
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per /// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`, /// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// scaled by the upstream scalar grad. /// scaled by the upstream scalar grad.

View File

@@ -625,6 +625,157 @@ fn attention_batched_bwd() {
); );
} }
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure + dimensions as attention_batched_bwd (bh=2,seq=5,hd=6), but
// exercises ops::flash_attention. Grad-check dq/dk/dv against finite-diff of
// L=sum(W∘out). This is the SINGLE-tile regime (seq<FA_TILE=32), matching the
// trusted composed grad-check's clean near-zero behavior; the MULTI-tile online-
// softmax path (seq>FA_TILE) is validated against the already-grad-checked
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40) — sharper than
// finite-diff, which is unreliable on the near-zero grad elements a long softmax
// produces.
#[test]
fn flash_attention_batched_bwd() {
require_gpu();
let (bh, seq, hd) = (2, 5, 6);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
// Scale Q/K up so the softmax is non-uniform (sharper attention) → the dQ/dK
// gradients are well-conditioned, not the near-zero saddle values a uniform
// softmax produces (those make central finite-diff give spurious 0.0 / sign
// flips that aren't backward bugs — cf. flash_bwd_matches_composed_bwd).
let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
let v_h = fill(n, 243);
let w = fill(n, 244);
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
let out = ops::flash_attention(&q, &k, &v, scale);
scalar_loss(&out, &w).backward();
let dq = q.grad().unwrap().to_device(Device::Cpu);
let dk = k.grad().unwrap().to_device(Device::Cpu);
let dv = v.grad().unwrap().to_device(Device::Cpu);
let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
let qv = cuda(qh, &[bh, seq, hd]);
let kv = cuda(kh, &[bh, seq, hd]);
let vv = cuda(vh, &[bh, seq, hd]);
let (o, _) = qv.flash_attention(&kv, &vv, scale);
weighted_sum(&o, &w)
};
// Attention dQ/dK carry softmax curvature; for the small grad magnitudes here
// a larger eps (2e-3) cuts the f32 rounding term (∝|L|/eps) that dominates the
// O(eps²) truncation on a ~4e-4 grad. (dV is exactly linear → cfg_linear.)
let cfg_attn = GradCheckConfig {
eps: 2e-3,
rel_tol: 3e-2,
atol: 1e-3,
};
let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
report(
"flash dQ",
&grad_check(&q_h, &[bh, seq, hd], &lq, dq.as_slice::<f32>(), cfg_attn),
);
let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
report(
"flash dK",
&grad_check(&k_h, &[bh, seq, hd], &lk, dk.as_slice::<f32>(), cfg_attn),
);
let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
report(
"flash dV",
&grad_check(
&v_h,
&[bh, seq, hd],
&lv,
dv.as_slice::<f32>(),
cfg_linear(),
),
);
}
// flash forward must equal the composed attention forward (same SDPA math).
#[test]
fn flash_matches_composed_fwd() {
require_gpu();
let (bh, seq, hd) = (2, 40, 16);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q = cuda(&fill(n, 341), &[bh, seq, hd]);
let k = cuda(&fill(n, 342), &[bh, seq, hd]);
let v = cuda(&fill(n, 343), &[bh, seq, hd]);
let (oc, _) = q.attention(&k, &v, scale);
let (of, _) = q.flash_attention(&k, &v, scale);
let oc = oc.to_device(Device::Cpu);
let of = of.to_device(Device::Cpu);
let max_rel = oc
.as_slice::<f32>()
.iter()
.zip(of.as_slice::<f32>())
.map(|(c, f)| (c - f).abs() / (c.abs() + 1e-6))
.fold(0.0f32, f32::max);
println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
assert!(
max_rel < 1e-4,
"flash fwd diverges from composed: {max_rel:.3e}"
);
}
// flash backward must equal the (already grad-checked) composed backward. This is
// a sharper test than finite-diff: both share the trusted composed forward as the
// reference, so it isolates the flash bwd dQ/dK/dV math from finite-diff noise on
// near-zero gradient elements.
#[test]
fn flash_bwd_matches_composed_bwd() {
require_gpu();
let (bh, seq, hd) = (2, 40, 16);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q_h = fill(n, 441);
let k_h = fill(n, 442);
let v_h = fill(n, 443);
let w = fill(n, 444);
let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
let out = if flash {
ops::flash_attention(&q, &k, &v, scale)
} else {
ops::attention(&q, &k, &v, scale)
};
scalar_loss(&out, &w).backward();
let g = |x: &Var| {
x.grad()
.unwrap()
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
};
(g(&q), g(&k), g(&v))
};
let (cq, ck, cv) = run(false);
let (fq, fk, fv) = run(true);
let maxrel = |a: &[f32], b: &[f32]| -> f32 {
a.iter()
.zip(b)
.map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
.fold(0.0f32, f32::max)
};
let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
// --- test helpers --- // --- test helpers ---
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We // Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We

View File

@@ -36,6 +36,7 @@ fn main() {
.file("../../csrc/ops/model.cu") .file("../../csrc/ops/model.cu")
.file("../../csrc/ops/optim.cu") .file("../../csrc/ops/optim.cu")
.file("../../csrc/ops/attention.cu") .file("../../csrc/ops/attention.cu")
.file("../../csrc/ops/flash_attention.cu")
.file("../../csrc/ops/cast.cu") .file("../../csrc/ops/cast.cu")
.compile("xtrain_cuda_kernels"); .compile("xtrain_cuda_kernels");
} }

View File

@@ -243,6 +243,59 @@ unsafe extern "C" {
); );
} }
// Fused flash-attention (csrc/ops/flash_attention.cu, Phase T14). A SINGLE kernel
// each for forward/backward that streams over KV tiles with an online softmax and
// NEVER materializes the [bh,S,S] score matrix. Q/K/V/out are [bh,S,hd] row-major
// F32; the forward saves only the per-row logsumexp `l` ([bh*S], O(N)) for backward.
#[cfg(not(no_cuda))]
unsafe extern "C" {
// Forward: o[bh,S,hd] = softmax(causal(Q·Kᵀ·scale))·V, online over KV tiles.
// Also writes l[bh*S] = per-row logsumexp (saved for backward, not the scores).
#[allow(clippy::too_many_arguments)]
pub fn launch_flash_attention_fwd_f32(
q: *const f32,
k: *const f32,
v: *const f32,
o: *mut f32,
l: *mut f32,
bh: i32,
seq: i32,
hd: i32,
scale: f32,
s: CudaStream,
);
// Per-row D[i]=Σ_d dO[i,d]·O[i,d] over `rows`=bh*S rows of width `hd`. Must run
// before the backward kernel (which takes the precomputed D, not O).
pub fn launch_flash_attention_rowdot_f32(
d_o: *const f32,
o: *const f32,
d_d: *mut f32,
rows: i32,
hd: i32,
s: CudaStream,
);
// Backward: recomputes scores from Q/K/V + saved logsumexp `l` (NO cached probs)
// and the precomputed `d_d` (= D), produces dq/dk/dv. dq/dk/dv must be PRE-ZEROED
// (dk/dv are accumulated across query rows via atomicAdd).
#[allow(clippy::too_many_arguments)]
pub fn launch_flash_attention_bwd_f32(
q: *const f32,
k: *const f32,
v: *const f32,
d_o: *const f32,
l: *const f32,
d_d: *mut f32,
dq: *mut f32,
dk: *mut f32,
dv: *mut f32,
bh: i32,
seq: i32,
hd: i32,
scale: f32,
s: CudaStream,
);
}
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and // GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
// the global grad-norm reduction + in-place rescale (Phase T7). // the global grad-norm reduction + in-place rescale (Phase T7).
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]

View File

@@ -89,6 +89,9 @@ fn main() {
// rank checkpoints its own forward/backward; exact grads, lower peak activation // rank checkpoints its own forward/backward; exact grads, lower peak activation
// memory (lets dim1024 batch32 fit). Opt-in; default off. // memory (lets dim1024 batch32 fit). Opt-in; default off.
let recompute = args.iter().any(|a| a == "--recompute"); let recompute = args.iter().any(|a| a == "--recompute");
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
let flash = args.iter().any(|a| a == "--flash");
let ckpt: Option<PathBuf> = args let ckpt: Option<PathBuf> = args
.iter() .iter()
.position(|a| a == "--ckpt") .position(|a| a == "--ckpt")
@@ -174,6 +177,9 @@ fn main() {
if recompute { if recompute {
println!("activation recompute: ON (per-block gradient checkpointing)"); println!("activation recompute: ON (per-block gradient checkpointing)");
} }
if flash {
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
}
let results = launch( let results = launch(
&devices, &devices,
&train_corpus, &train_corpus,
@@ -187,6 +193,9 @@ fn main() {
if recompute { if recompute {
m = m.with_recompute(true); m = m.with_recompute(true);
} }
if flash {
m = m.with_flash(true);
}
m m
}, },
); );

View File

@@ -47,6 +47,14 @@ pub struct TinyTransformer {
/// existing numerics are bit-identical; recompute is mathematically exact, so /// existing numerics are bit-identical; recompute is mathematically exact, so
/// grads match the non-checkpointed path within fp tolerance. /// grads match the non-checkpointed path within fp tolerance.
recompute: bool, recompute: bool,
/// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through
/// the hand-written single fused kernel ([`ops::flash_attention`]): online
/// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized,
/// backward caches only the O(N) logsumexp. Default `false` = the composed T10
/// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs),
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
use_flash: bool,
} }
impl TinyTransformer { impl TinyTransformer {
@@ -90,6 +98,7 @@ impl TinyTransformer {
lm_head, lm_head,
compute_dtype: DType::F32, compute_dtype: DType::F32,
recompute: false, recompute: false,
use_flash: false,
} }
} }
@@ -127,6 +136,19 @@ impl TinyTransformer {
self.recompute self.recompute
} }
/// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and
/// opt-in; default off keeps the composed T10 path (so the default graph is
/// unchanged). On, the SDPA runs through [`ops::flash_attention`] — same SDPA
/// math, online softmax, no materialized `[bh,seq,seq]` scores.
pub fn with_flash(mut self, use_flash: bool) -> Self {
self.use_flash = use_flash;
self
}
pub fn use_flash(&self) -> bool {
self.use_flash
}
/// All learnable parameters, in a stable order. The optimizer (a hand-written /// All learnable parameters, in a stable order. The optimizer (a hand-written
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after /// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
/// `backward()`. /// `backward()`.
@@ -189,13 +211,16 @@ impl TinyTransformer {
// the block forward is re-run in backward to recover the grads. The // the block forward is re-run in backward to recover the grads. The
// segment fn captures only `Copy` config (no borrow of `self`) and // segment fn captures only `Copy` config (no borrow of `self`) and
// receives the block's params via the slice, in `block_params` order. // receives the block's params via the slice, in `block_params` order.
let (cfg, cdt) = (self.cfg, self.compute_dtype); // `flash` is captured too → the recompute segment also runs flash.
let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p); let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
let seg =
move |x: &Var, p: &[Var]| block_forward(cfg, cdt, flash, batch, seq, x, p);
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params()) xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
} else { } else {
block_forward( block_forward(
self.cfg, self.cfg,
self.compute_dtype, self.compute_dtype,
self.use_flash,
batch, batch,
seq, seq,
&h, &h,
@@ -279,7 +304,16 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
/// seq, input, params)` (no `&self`) so it can be the segment fn of /// seq, input, params)` (no `&self`) so it can be the segment fn of
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is /// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
/// the block's leaves in [`Block::block_params`] order. /// the block's leaves in [`Block::block_params`] order.
fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var { #[allow(clippy::too_many_arguments)]
fn block_forward(
cfg: Config,
cdt: DType,
flash: bool,
batch: usize,
seq: usize,
h: &Var,
p: &[Var],
) -> Var {
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]); let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]); let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]); let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
@@ -287,7 +321,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
// --- Attention sub-block (pre-norm + residual) --- // --- Attention sub-block (pre-norm + residual) ---
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps); let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
let attn = attention( let attn = attention(
cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo, cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
); );
let h = ops::add(h, &attn); let h = ops::add(h, &attn);
@@ -308,6 +342,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
fn attention( fn attention(
cfg: Config, cfg: Config,
cdt: DType, cdt: DType,
flash: bool,
batch: usize, batch: usize,
seq: usize, seq: usize,
x: &Var, x: &Var,
@@ -351,9 +386,14 @@ fn attention(
let k = to_bh(linear(cdt, x, wk), Some(k_norm)); let k = to_bh(linear(cdt, x, wk), Some(k_norm));
let v = to_bh(linear(cdt, x, wv), None); let v = to_bh(linear(cdt, x, wv), None);
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once // Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop). // single fused flash kernel (online softmax, no materialized [bh,S,S] scores);
let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd] // otherwise the composed T10 path (2 batched GEMMs + 1 causal-softmax kernel).
let out = if flash {
ops::flash_attention(&q, &k, &v, scale) // [B*nh, S, hd]
} else {
ops::attention(&q, &k, &v, scale) // [B*nh, S, hd]
};
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) → // Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
// [B,S,nh,hd] → [B*S, dim]. // [B,S,nh,hd] → [B*S, dim].

View File

@@ -0,0 +1,209 @@
// T14 flash-attention correctness gate: the fused flash SDPA core must match the
// composed T10 path (cublasSgemmStridedBatched×2 + causal-softmax kernel) in
// forward logits, loss, AND every parameter gradient — flash is the SAME SDPA
// math (online softmax never materializes the [bh,S,S] scores), so it differs
// from composed only by reduction order (in-kernel fp32 FMA vs cuBLAS, and the
// dK/dV atomicAdd order in backward). This test makes that a closed on-GPU loop:
//
// build two identical models (same init), one with `--flash` on, one off, run
// the SAME batched loss + backward on both, and assert
// 1. the forward logits match within tolerance
// 2. the loss matches
// 3. EVERY parameter's grad matches within tolerance
//
// Parameterised over fp32 AND bf16 (T12). bf16 just adds the bf16 rounding band on
// top — flash's bf16 path upcasts Q/K/V to fp32 for the kernel exactly like the
// composed path's fp32 softmax, so the two are still the same softmax numerics.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
});
m.with_compute_dtype(dtype).with_flash(flash)
}
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
t.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
// fp32: same SDPA math, differs only by reduction order → tight per-element check.
fn run_fp32(logit_tol: f32, grad_tol: f32) {
let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::F32);
let logit_rel = off_logits
.iter()
.zip(&on_logits)
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
.fold(0.0f32, f32::max);
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
println!(
"[F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
logits max rel {logit_rel:.2e}"
);
assert!(
logit_rel < logit_tol,
"[F32] logits diverged: {logit_rel:.2e}"
);
assert!(loss_rel < logit_tol, "[F32] loss diverged: {loss_rel:.2e}");
let mut max_grad_rel = 0.0f32;
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
for (a, b) in off_g.iter().zip(on_g) {
max_grad_rel = max_grad_rel.max((a - b).abs() / a.abs().max(1e-3));
}
}
println!("[F32] flash on/off: grad max rel err = {max_grad_rel:.3e}");
assert!(
max_grad_rel < grad_tol,
"[F32] flash grads diverged from composed: {max_grad_rel:.3e}"
);
}
// bf16: ~2-3 decimal digits → robust comparison (mean + p99 with abs().max(1.0)
// for logits, per-tensor scale-relative mean for grads), the same convention as
// the repo's bf16.rs gate (per-element max-rel blows up on near-zero bf16 logits).
fn run_bf16() {
let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::BF16);
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
println!("[BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
assert!(loss_rel < 2e-2, "[BF16] loss diverged: {loss_rel:.3e}");
let n = off_logits.len();
let mut rels: Vec<f32> = off_logits
.iter()
.zip(&on_logits)
.map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
.collect();
rels.sort_by(|a, b| a.partial_cmp(b).unwrap());
let p99 = rels[(n as f32 * 0.99) as usize];
let mean: f32 = rels.iter().sum::<f32>() / n as f32;
println!("[BF16] flash on/off logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
assert!(mean < 1e-2, "[BF16] logits mean rel too high: {mean:.3e}");
assert!(p99 < 5e-2, "[BF16] logits p99 rel too high: {p99:.3e}");
let mut worst = 0.0f32;
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
let scale = off_g
.iter()
.map(|v| v.abs())
.fold(0.0f32, f32::max)
.max(1e-6);
let mean_err: f32 = off_g
.iter()
.zip(on_g)
.map(|(f, b)| (f - b).abs())
.sum::<f32>()
/ off_g.len() as f32
/ scale;
worst = worst.max(mean_err);
}
println!("[BF16] flash on/off grads: worst per-tensor scaled-mean err = {worst:.3e}");
assert!(worst < 3e-2, "[BF16] flash grads diverged: {worst:.3e}");
}
#[allow(clippy::type_complexity)]
fn run_both(dtype: DType) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised.
let mut cfg = Config::tiny();
cfg.vocab = 16;
cfg.n_layers = 4;
let batch = 3usize;
let seq = 40usize;
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| {
(0..seq)
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
.collect()
})
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| {
(0..seq)
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
.collect()
})
.collect();
let ids = batched_ids_tensor(&seqs, device);
let tgt = batched_ids_tensor(&tgts, device);
// --- flash OFF (composed reference) ---
let off = build(cfg, device, dtype, false);
let off_logits = host(&off.forward_batched(&ids, batch).value());
let off_loss = off.loss_batched(&ids, &tgt, batch);
let off_loss_val = host(&off_loss.value())[0];
off_loss.backward();
let off_grads: Vec<Vec<f32>> = off
.params()
.iter()
.map(|p| host(&p.grad().expect("off grad")))
.collect();
// --- flash ON ---
let on = build(cfg, device, dtype, true);
let on_logits = host(&on.forward_batched(&ids, batch).value());
let on_loss = on.loss_batched(&ids, &tgt, batch);
let on_loss_val = host(&on_loss.value())[0];
on_loss.backward();
let on_grads: Vec<Vec<f32>> = on
.params()
.iter()
.map(|p| host(&p.grad().expect("on grad")))
.collect();
(
off_logits,
off_loss_val,
off_grads,
on_logits,
on_loss_val,
on_grads,
)
}
#[test]
fn flash_matches_composed_fp32() {
// fp32: same SDPA math, differs only by reduction order (in-kernel fp32 FMA vs
// cuBLAS, dK/dV atomicAdd order). Tight per-element check, not bit-exact.
run_fp32(1e-3, 2e-2);
}
#[test]
fn flash_matches_composed_bf16() {
// bf16 (T12 composition): bf16 rounding band on the fp32-softmax core; robust
// (mean/p99/scaled-mean) comparison per the repo's bf16 convention.
run_bf16();
}

View File

@@ -67,7 +67,7 @@ fn dump_for_parity() {
// Same deterministic init as the overfit test. // Same deterministic init as the overfit test.
let mut seed = 1u64; let mut seed = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| { let mut model = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1); seed = seed.wrapping_add(1);
let n: usize = shape.iter().product(); let n: usize = shape.iter().product();
if shape.len() == 1 { if shape.len() == 1 {
@@ -76,6 +76,14 @@ fn dump_for_parity() {
fill(n, seed, 0.08) fill(n, seed, 0.08)
} }
}); });
// T14: with XTRAIN_PARITY_FLASH set, dump from the fused flash-attention path.
// flash is the SAME SDPA math, so the SAME parity.py PyTorch oracle is the
// reference for both paths — running this once per path checks flash against
// PyTorch at B>1 (forward logits + every parameter grad).
if std::env::var("XTRAIN_PARITY_FLASH").is_ok() {
model = model.with_flash(true);
println!("parity: FLASH attention path");
}
// config + ids // config + ids
{ {

View File

@@ -1092,6 +1092,119 @@ impl Tensor {
(dq, dk, dv) (dq, dk, dv)
} }
// --- Fused flash-attention (the T14 op) ---
/// Fused flash-attention forward (Phase T14). `self`=Q, `k`, `v` each
/// `[bh, seq, head_dim]`, contiguous on one GPU. Computes, per batch element,
/// `out = softmax(causal(Q·Kᵀ·scale))·V` in a SINGLE kernel that streams over
/// KV tiles with an online softmax — the `[bh,seq,seq]` score matrix is NEVER
/// materialized. Returns `(out, lse)` where `lse`:[bh,seq] (F32) is the per-row
/// logsumexp cached for backward (O(N), vs the composed path's O(N²) probs).
///
/// The fused kernel is fp32; for bf16 we upcast Q/K/V → f32 → kernel → downcast
/// `out` back to bf16 (same fp32-softmax policy as the composed [`attention`]),
/// so flash and composed produce the same softmax numerics. `lse` stays fp32.
#[cfg(not(no_cuda))]
pub fn flash_attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> (Tensor, Tensor) {
assert_eq!(
self.ndim(),
3,
"flash_attention Q must be [bh,seq,head_dim]"
);
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
let dev = self.device();
let dt = self.dtype;
let qf = self.to_dtype(DType::F32);
let kf = k.to_dtype(DType::F32);
let vf = v.to_dtype(DType::F32);
let out_f32 = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
let lse = Tensor::zeros(&[bh, seq], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_flash_attention_fwd_f32(
qf.data_ptr() as *const f32,
kf.data_ptr() as *const f32,
vf.data_ptr() as *const f32,
out_f32.data_ptr() as *mut f32,
lse.data_ptr() as *mut f32,
bh as i32,
seq as i32,
hd as i32,
scale,
std::ptr::null_mut(),
);
}
(out_f32.to_dtype(dt), lse)
}
/// Backward of [`flash_attention`](Self::flash_attention). Inputs: forward
/// `q`,`k`,`v`, the forward output `out`, the cached `lse`:[bh,seq], the upstream
/// `dout`, and the same `scale`. Returns `(dq, dk, dv)`.
///
/// flash-style: NO cached probs. Recomputes scores from Q/K/V + `lse`, uses
/// `D[i]=Σ dOᵢ·Oᵢ` to collapse the softmax Jacobian, streams KV in tiles. dQ is
/// owned per query row; dK/dV are accumulated across rows (atomicAdd). Same
/// fp32 kernel; bf16 callers get fp32 grads which the autograd `cast` op casts.
#[cfg(not(no_cuda))]
pub fn flash_attention_backward(
q: &Tensor,
k: &Tensor,
v: &Tensor,
out: &Tensor,
lse: &Tensor,
dout: &Tensor,
scale: f32,
) -> (Tensor, Tensor, Tensor) {
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
let dev = q.device();
let dt = q.dtype;
let qf = q.to_dtype(DType::F32);
let kf = k.to_dtype(DType::F32);
let vf = v.to_dtype(DType::F32);
let of = out.to_dtype(DType::F32);
let dof = dout.to_dtype(DType::F32);
// D[i] = Σ_d dO[i,d]·O[i,d] (one scalar per query row, O(N)).
let d = Tensor::zeros(&[bh, seq], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_flash_attention_rowdot_f32(
dof.data_ptr() as *const f32,
of.data_ptr() as *const f32,
d.data_ptr() as *mut f32,
(bh * seq) as i32,
hd as i32,
std::ptr::null_mut(),
);
}
// dq/dk/dv pre-zeroed (Tensor::zeros memsets); dk/dv accumulate via atomicAdd.
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_flash_attention_bwd_f32(
qf.data_ptr() as *const f32,
kf.data_ptr() as *const f32,
vf.data_ptr() as *const f32,
dof.data_ptr() as *const f32,
lse.data_ptr() as *const f32,
d.data_ptr() as *mut f32,
dq.data_ptr() as *mut f32,
dk.data_ptr() as *mut f32,
dv.data_ptr() as *mut f32,
bh as i32,
seq as i32,
hd as i32,
scale,
std::ptr::null_mut(),
);
}
(dq.to_dtype(dt), dk.to_dtype(dt), dv.to_dtype(dt))
}
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d], /// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention /// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c). /// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).

View File

@@ -116,6 +116,9 @@ fn main() {
// exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in; // exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in;
// default off stores every activation (unchanged numerics). // default off stores every activation (unchanged numerics).
let recompute = args.iter().any(|a| a == "--recompute"); let recompute = args.iter().any(|a| a == "--recompute");
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
let flash = args.iter().any(|a| a == "--flash");
let ckpt: PathBuf = PathBuf::from( let ckpt: PathBuf = PathBuf::from(
args.iter() args.iter()
.position(|a| a == "--ckpt") .position(|a| a == "--ckpt")
@@ -183,6 +186,10 @@ fn main() {
model = model.with_recompute(true); model = model.with_recompute(true);
println!("activation recompute: ON (per-block gradient checkpointing)"); println!("activation recompute: ON (per-block gradient checkpointing)");
} }
if flash {
model = model.with_flash(true);
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
}
// Eval-only mode: load a checkpoint and score it on the held-out val set, then // Eval-only mode: load a checkpoint and score it on the held-out val set, then
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same // exit. Used to put an EXISTING model (e.g. v0) and a new one on the same

281
csrc/ops/flash_attention.cu Normal file
View File

@@ -0,0 +1,281 @@
// Hand-written fused flash-attention (Phase T14).
//
// The T10 composed SDPA path is 3 launches that MATERIALIZE the [bh,S,S] score
// matrix: cublasSgemmStridedBatched (Q·Kᵀ) → causal-softmax kernel (writes the
// whole probs) → cublasSgemmStridedBatched (P·V), and backward caches that whole
// probs. flash-attention NEVER materializes N×N: a single fused kernel streams
// over KV tiles with an ONLINE softmax (running max/sum + rescaled V accumulator),
// so peak attention activation drops from O(S²) to O(S·hd) (= the output itself).
//
// Layout (matches the T10 op): Q/K/V/out are [bh, S, hd] row-major contiguous,
// bh = batch·n_heads. The query's position within its sequence is the row index
// within its [S,hd] block (so the flat row's qpos = (row % S) is automatic here —
// we index per (bh, row)). CAUSAL: a query at position i attends to keys j ≤ i.
// `scale` (= 1/sqrt(hd)) is folded into the logits before the max/exp.
//
// All F32, contiguous. (bf16 callers upcast Q/K/V → f32 on the Rust side and
// downcast the f32 out, mirroring the composed path's fp32 softmax policy, so the
// kernel only ever sees fp32.) Reduction helpers are inlined (self-contained file,
// matching the csrc/ layout).
//
// Parallelisation: grid = bh*S, one block per query row; blockDim.x threads
// cooperate. Forward keeps m (running max), l (running sum), acc[hd] (rescaled
// V accumulator) in shared memory, streams KV in tiles of BK. Backward recomputes
// scores from Q/K/V + the saved logsumexp L[bh,S] (NO cached probs), uses
// D[i]=Σ dOᵢ·Oᵢ to collapse the softmax Jacobian, and atomicAdds dK/dV (which are
// accumulated across query rows).
#include <math.h>
extern "C" {
__device__ __forceinline__ float fa_warp_sum(float v) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1)
v += __shfl_down_sync(0xffffffff, v, off);
return v;
}
__device__ __forceinline__ float fa_warp_max(float v) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1)
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
return v;
}
__device__ __forceinline__ float fa_block_sum(float v) {
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = fa_warp_sum(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
if (warp == 0) v = fa_warp_sum(v);
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
__device__ __forceinline__ float fa_block_max(float v) {
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = fa_warp_max(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
if (warp == 0) v = fa_warp_max(v);
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
#define FA_TILE 32 // KV tile width (columns streamed per step)
// One block per (bh-row, query-position). Computes out[bh, i, :] and L[bh, i] via
// an online softmax that streams the keys in tiles of FA_TILE — the [S,S] score
// row is never stored, only the per-tile partials flow through shared memory.
__global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
float* O, float* L, int seq, int hd, float scale) {
int row = blockIdx.x; // global query row over bh*S
int b = row / seq; // which (batch,head) block
int i = row % seq; // query position within the sequence (causal limit)
int t = threadIdx.x;
int nthreads = blockDim.x;
const float* q = Q + (size_t)row * hd;
const float* kb = K + (size_t)b * seq * hd; // this block's keys [seq,hd]
const float* vb = V + (size_t)b * seq * hd; // this block's values[seq,hd]
// Q row in shared memory (reused every tile); acc accumulator over hd.
extern __shared__ float smem[];
float* sq = smem; // [hd]
float* acc = smem + hd; // [hd]
for (int d = t; d < hd; d += nthreads) {
sq[d] = q[d];
acc[d] = 0.0f;
}
__shared__ float m_run, l_run;
if (t == 0) { m_run = -INFINITY; l_run = 0.0f; }
__syncthreads();
int valid = i + 1; // causal: attend to keys [0, i]
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
int tile = min(FA_TILE, valid - j0);
// Each thread computes whole logits for a strided subset of the tile's
// columns: s = scale * (q · k_j). hd is small (≤128) so the per-thread
// dot loop is cheap; this avoids a block-reduce per column.
__shared__ float s_tile[FA_TILE];
for (int c = t; c < tile; c += nthreads) {
const float* kj = kb + (size_t)(j0 + c) * hd;
float dot = 0.0f;
for (int d = 0; d < hd; ++d) dot += sq[d] * kj[d];
s_tile[c] = dot * scale;
}
__syncthreads();
// Tile max, then online rescale of (m, l, acc).
float tmax = -INFINITY;
for (int c = t; c < tile; c += nthreads) tmax = fmaxf(tmax, s_tile[c]);
tmax = fa_block_max(tmax);
__shared__ float m_new, corr;
if (t == 0) {
float mn = fmaxf(m_run, tmax);
corr = (m_run == -INFINITY) ? 0.0f : expf(m_run - mn); // rescale old state
m_new = mn;
}
__syncthreads();
// Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per
// column (instead of recomputing expf inside the per-dim V loop, which
// would cost hd× the transcendentals). Sum them for l.
float lsum = 0.0f;
for (int c = t; c < tile; c += nthreads) {
float p = expf(s_tile[c] - m_new);
s_tile[c] = p;
lsum += p;
}
lsum = fa_block_sum(lsum);
// Rescale old accumulator + add this tile's p·V (p cached in s_tile).
// Each thread owns a strided subset of hd; loops over the tile columns.
for (int d = t; d < hd; d += nthreads) {
float a = acc[d] * corr;
for (int c = 0; c < tile; ++c)
a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
acc[d] = a;
}
if (t == 0) {
l_run = l_run * corr + lsum;
m_run = m_new;
}
__syncthreads();
}
// out = acc / l ; L = m + log(l) (logsumexp, saved for backward).
float inv = 1.0f / l_run;
for (int d = t; d < hd; d += nthreads) O[(size_t)row * hd + d] = acc[d] * inv;
if (t == 0) L[row] = m_run + logf(l_run);
}
void launch_flash_attention_fwd_f32(const float* q, const float* k, const float* v,
float* o, float* l, int bh, int seq, int hd,
float scale, void* s) {
int blk = hd < 1024 ? hd : 1024;
if (blk < 32) blk = 32;
size_t shmem = (size_t)2 * hd * sizeof(float); // sq[hd] + acc[hd]
flash_attn_fwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(q, k, v, o, l, seq, hd, scale);
}
// Per-row D[i] = Σ_d dO[i,d] · O[i,d]. One block per row (bh*S rows). Used to
// collapse the softmax Jacobian in backward (Σ_j P_ij dP_ij = dOᵢ·Oᵢ).
__global__ void flash_attn_rowdot_k(const float* dO, const float* O, float* D, int hd) {
int row = blockIdx.x;
int t = threadIdx.x;
const float* d = dO + (size_t)row * hd;
const float* o = O + (size_t)row * hd;
float v = 0.0f;
for (int c = t; c < hd; c += blockDim.x) v += d[c] * o[c];
v = fa_block_sum(v);
if (t == 0) D[row] = v;
}
// Backward: one block per query row i. Recomputes scores from Q/K/V + the saved
// logsumexp L (NO cached probs), streams KV in tiles. dQ accumulates locally (this
// row owns it). dK/dV are accumulated ACROSS query rows so they atomicAdd into the
// shared global buffers (pre-zeroed by the caller).
// p_ij = exp(Qᵢ·Kⱼ·scale - L[i]) ; dp_ij = dOᵢ·Vⱼ ;
// ds_ij = p_ij·(dp_ij - D[i])·scale
// dQᵢ += Σ_j ds_ij·Kⱼ ; dKⱼ += ds_ij·Qᵢ ; dVⱼ += p_ij·dOᵢ
__global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
const float* dO, const float* L, const float* D,
float* dQ, float* dK, float* dV,
int seq, int hd, float scale) {
int row = blockIdx.x;
int b = row / seq;
int i = row % seq;
int t = threadIdx.x;
int nthreads = blockDim.x;
const float* q = Q + (size_t)row * hd;
const float* doi = dO + (size_t)row * hd;
const float* kb = K + (size_t)b * seq * hd;
const float* vb = V + (size_t)b * seq * hd;
float* dkb = dK + (size_t)b * seq * hd;
float* dvb = dV + (size_t)b * seq * hd;
extern __shared__ float smem[];
float* sq = smem; // [hd] Qᵢ
float* sdo = smem + hd; // [hd] dOᵢ
float* dqa = smem + 2*hd; // [hd] dQᵢ accumulator
for (int d = t; d < hd; d += nthreads) {
sq[d] = q[d];
sdo[d] = doi[d];
dqa[d] = 0.0f;
}
__shared__ float Li, Di;
if (t == 0) { Li = L[row]; Di = D[row]; }
__syncthreads();
int valid = i + 1;
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
int tile = min(FA_TILE, valid - j0);
// Phase 1: per-column ds[c] and p[c] (the column owner does the dots).
__shared__ float s_ds[FA_TILE];
__shared__ float s_p[FA_TILE];
for (int c = t; c < tile; c += nthreads) {
const float* kj = kb + (size_t)(j0 + c) * hd;
const float* vj = vb + (size_t)(j0 + c) * hd;
float sdot = 0.0f, dpdot = 0.0f;
for (int d = 0; d < hd; ++d) {
sdot += sq[d] * kj[d];
dpdot += sdo[d] * vj[d];
}
float p = expf(sdot * scale - Li);
s_p[c] = p;
s_ds[c] = p * (dpdot - Di) * scale;
}
__syncthreads();
// Phase 2: dV_j += p·dOᵢ ; dK_j += ds·Qᵢ (accumulated across rows → atomic).
// Spread the tile×hd atomics over ALL threads (was serial in the column
// owner) — flatten (c,d) so every thread issues a balanced share.
for (int idx = t; idx < tile * hd; idx += nthreads) {
int c = idx / hd, d = idx % hd;
size_t off = (size_t)(j0 + c) * hd + d;
atomicAdd(&dvb[off], s_p[c] * sdo[d]);
atomicAdd(&dkb[off], s_ds[c] * sq[d]);
}
// dQᵢ += Σ_c ds[c] · K_{j0+c} (this row owns dQ — no atomic).
for (int d = t; d < hd; d += nthreads) {
float a = 0.0f;
for (int c = 0; c < tile; ++c)
a += s_ds[c] * kb[(size_t)(j0 + c) * hd + d];
dqa[d] += a;
}
__syncthreads();
}
for (int d = t; d < hd; d += nthreads) dQ[(size_t)row * hd + d] = dqa[d];
}
void launch_flash_attention_bwd_f32(const float* q, const float* k, const float* v,
const float* d_o, const float* l, float* d_d,
float* dq, float* dk, float* dv,
int bh, int seq, int hd, float scale, void* s) {
int blk = hd < 1024 ? hd : 1024;
if (blk < 32) blk = 32;
// d_d is the pre-computed D[i]=Σ dOᵢ·Oᵢ (the Rust wrapper runs rowdot first,
// since it holds the forward O). dq/dk/dv are pre-zeroed by the caller.
flash_attn_bwd_k<<<bh * seq, blk, (size_t)3 * hd * sizeof(float), (cudaStream_t)s>>>(
q, k, v, d_o, l, d_d, dq, dk, dv, seq, hd, scale);
}
// Standalone D = rowdot(dO, O) launcher (the Rust wrapper calls this before bwd).
void launch_flash_attention_rowdot_f32(const float* d_o, const float* o, float* d_d,
int rows, int hd, void* s) {
int blk = hd < 1024 ? hd : 1024;
if (blk < 32) blk = 32;
flash_attn_rowdot_k<<<rows, blk, 0, (cudaStream_t)s>>>(d_o, o, d_d, hd);
}
} // extern "C"

183
docs/13-flash-attention.md Normal file
View File

@@ -0,0 +1,183 @@
# Phase T14: 融合 Flash-Attention Kernel — Design Document
## Goal
T10 把 attention 批量化了,但它的 SDPA 走的是 **「物化 N×N scores」** 的组合路径:
`cublasSgemmStridedBatched`Q·Kᵀ→ 一个 causal-softmax kernel写出整张 probs
`cublasSgemmStridedBatched`P·V**3 次 launch + 一张 `[bh, S, S]` 的 scores/probs 张量**
常驻显存(反向还要缓存这张 probs。S 一大,这张 N×N 就成了激活显存与带宽的主导项。
T14 的目标:手写一个**单 kernel 的 fused flash-attention**——streaming / online softmax、**tiled
over KV**、**绝不物化 N×N**。前向一发 kernel 直接吐出 `out[bh,S,hd]`(外加 `O(N)` 的 logsumexp
反向一发 kernelflash 式:重算 scores + dQ/dK/dV同样不物化 N×N。接进 model + autograd 作
**opt-in `--flash`**,默认保留 T10 的 composed 路径以便 A/B。
**硬闸门是诚实正确性**:新 kernel 的 dQ/dK/dV finite-diff grad-check 过fwd/bwd 对现有 composed-SDPA
路径数值贴合(进 bf16 容差PyTorch SDPA 对拍 B>1峰值显存↓不物化 scores+ tok/s before/after 实测;
全回归套(含 xserv 闭环 md5开/关 flag 都绿——默认flag off图不变 → 不回归。
## 什么是 flash-attention
标准 attention 是 `O = softmax(causal(Q·Kᵀ/√d)) · V`,朴素实现把 `S[i,j] = Qᵢ·Kⱼ/√d` 整张
`[S,S]` 算出来、softmax、再乘 V——显存 `O(S²)`、HBM 读写 `O(S²)`
**flash-attention** 的洞察softmax 可以 **onlinestreaming** 地算。把 K/V 切成若干 **tile**,对一个
query 行 `i`,依次扫过 KV tile**running max `m` + running sum `l`** 维护 softmax 的归一化,并把
部分加权的 `V` 累加进一个 `[hd]` 的 accumulator `acc`,每来一个新 tile 就用「新旧 max 的差」对旧 `acc`/`l`
做 rescale。扫完所有 tile`out = acc / l`。**整张 `[S,S]` 从不落地**——只有 `[hd]` 的 acc 和两个标量
在寄存器/共享内存里流动。峰值激活从 `O(S²)` 降到 `O(S·hd)`(就是 O 本身)。
online softmax 的核心递推block `j` 的部分 logits 行 `s_j`,旧状态 `m, l, acc`
```text
m_new = max(m, max_k s_j[k])
p = exp(s_j - m_new) # 本 tile 的未归一化权重
l = l * exp(m - m_new) + sum(p) # 旧 sum 先 rescale再加本 tile
acc = acc * exp(m - m_new) + p · V_tile # 旧 acc 同样 rescale再加本 tile 贡献
m = m_new
# 扫完所有 tile
out = acc / l
L = m + log(l) # logsumefpO(N) 存给反向
```
**因果 mask 内联**query 全局位置 = `i % S`(沿用 T10 的 per-seq 复位约定KV 位置 `j` 满足
`j > i%S` 的列直接当 `-inf``p=0`。tile 整块在对角线之上可**直接 skip**causal 的天然稀疏,省一半算力)。
**反向flash 式,[Dao 2022] 的标准做法)**:不缓存 probs从 Q/K/V + 前向存的 `L[bh,S]` **重算** scores。
关键预计算 `D[i] = Σ_d dOᵢ[d]·Oᵢ[d]`(每 query 一个标量,`O(N)`),则对每个 `(i,j)`
```text
s_ij = Qᵢ·Kⱼ * scale # 重算 logit
p_ij = exp(s_ij - L[i]) # 重算 softmax 权重L 是前向存的 logsumexp
dp_ij = dOᵢ · Vⱼ # 对 P 的梯度
ds_ij = p_ij * (dp_ij - D[i]) * scale # softmax 雅可比,化简掉了显式 N×N
dQᵢ += ds_ij * Kⱼ ; dKⱼ += ds_ij * Qᵢ ; dVⱼ += p_ij * dOᵢ
```
`ds = P ∘ (dP - D)` 是 softmax 反向用 `Σⱼ Pⱼ·dPⱼ = D`(因为 `D[i]=Σ dOᵢ·Oᵢ = Σⱼ Pᵢⱼ dPᵢⱼ`)化简的结果,
**不需要 N×N 的 softmax 雅可比矩阵**。同样 tiled、同样不物化 N×N。
## Module Layoutsurgicalcomposed 路径逐字节不动flash 全程新增并行路径)
```
csrc/ops/flash_attention.cu # 新fwd kernelonline softmaxtiled KV+ bwd kernel重算 + dQ/dK/dV
crates/xtrain-cuda/
├── src/ffi.rs # +launch_flash_attention_fwd_f32 / _bwd_f32 声明
└── build.rs # +flash_attention.cu
crates/xtrain-tensor/src/tensor.rs # +Tensor::flash_attention / flash_attention_backwardfwd 存 logsumexp Lbf16 upcast→f32 kernel→downcast
crates/xtrain-autodiff/
├── src/ops.rs # +ops::flash_attention 节点(前向调 fwd缓存 L反向调 bwd
└── tests/autograd.rs # +flash_attention(batched) dQ/dK/dV grad-check
crates/xtrain-model/
├── src/model.rs # attention() 按 use_flash 选 ops::attention | ops::flash_attention+with_flash(bool) builderflash 标志透传 block_forwardrecompute 段内也走 flash
└── tests/flash.rs # 新flash == composedfwd logits + 每参数梯度),参数化 fp32/bf16
crates/xtrain-train/src/bin/train.rs # +--flash flag → model.with_flash(true)
crates/xtrain-distributed/src/bin/train_ddp.rs # +--flash flagDDP 路径)
crates/xtrain-model/tests/parity_dump.rs # PyTorch B>1 对拍跑两遍composed 与 flash共用 PyTorch oracle
```
## Key Design Decisions
### ① 一个 block 负责一行 query先做对再谈快
最直接、最易验证正确的并行划分:**`grid = bh * S`,每个 block 算一整行 query 的 `out[bh, i, :]`**。
block 内 `hd` 个线程hd ≤ 128正好一个 warp 多一点),共享 `m/l` 标量 + `acc[hd]`。block 顺序扫
KV tiletile 宽 `BK`,沿 `j` 维),每个 tile线程并行算 `BK` 个 logit点积 over hd 用 block-reduce
求 tile max、online-rescale `m/l/acc`、累加 `p·V`。扫完写 `out = acc/l``L[i] = m + log(l)`
**为什么先这样而不是 FA2 的 query-tile 划分**:本项目的硬闸门是**正确性 + 不物化 N×N + 显存↓**,不是
打榜峰值 FLOPs。一行一 block 的版本:(a) online softmax 与 N×N skip 已经完全落地(显存与带宽收益拿到),
(b) 代码直白、逐 query 行可对拍,正确性风险最低。它**不会**比 cuBLAS 两发 GEMM 更快cuBLAS tensor-core
吃满),所以 tok/s 上 flash 在我们这种 `hd=32` 小头维下大概率**持平或略慢**——这正是 flash 的已知权衡
flash 的胜场是**显存**,不是小模型的 wall-clock。把这点诚实写进 perf 表,不掩饰。
### ② 前向只存 `L[bh,S]`logsumefp不存 probs
composed 路径反向要缓存整张 `probs[bh,S,S]``O(N²)`。flash 反向**只需要前向的 logsumexp
`L[i]=m_i+log(l_i)`**(每 query 一个 fp32`O(N)`)即可重算任意 `p_ij = exp(Qᵢ·Kⱼ·scale - L[i])`
所以 fwd kernel 顺手把 `L` 写出来autograd 节点缓存它(外加 Q/K/V/O parents 本就在)。**这就是显存闸门的来源**
attention 的反向缓存从 `[bh,S,S]` 砍到 `[bh,S]`
### ③ 反向用 `D[i]=Σ dOᵢ·Oᵢ` 化简 softmax 雅可比
softmax 反向通项 `ds_ij = p_ij·(dp_ij - Σ_k p_ik·dp_ik)`。注意 `Σ_k p_ik·dp_ik = Σ_k p_ik (dOᵢ·V_k)
= dOᵢ·(Σ_k p_ik V_k) = dOᵢ·Oᵢ = D[i]`。所以一趟先算 `D[bh,S]`(每行 `dO·O` 的点积,`O(N)`),反向
扫 KV tile 时直接 `ds = p·(dp - D)·scale`**不需要再算或物化整行的 `Σ p·dp`**。
dQ/dK/dV 三者dQ 由「该 query 行」累加block 私有无竞争dK/dV 跨 query 行累加同一个 `(j)`
→ 用 `atomicAdd` 到全局 dK/dVfp32 原子加,确定 race-free
### ④ bf16kernel 内 fp32边界 cast与 composed 路径一致的数值策略)
T10/T12 的 composed attention 对 bf16 也是 **softmax 用 fp32**scores 升 f32 → kernel → probs 降回 bf16
flash 沿用同策略最省心且数值最稳bf16 模式下 `flash_attention` 把 Q/K/V `to_dtype(F32)` 喂给 fp32 kernel
`out``to_dtype(BF16)`反向同理。kernel 本身只有一份 fp32 实现。这样 flash 的 bf16 数值与 composed 的
bf16 数值是**同一套 fp32 softmax 算的**,只差 GEMM roundingcuBLAS tensor-core vs kernel 内 fp32 FMA→ 落在
既有 bf16 容差内。`L` 始终 fp32。
> 备选不采纳bf16 全程 in-kernel half。收益是少两次 cast但 (a) 引入与 composed 不同的 softmax 累加路径,
> 威胁 on-vs-off 贴合闸门;(b) 本规模 attention 非瓶颈。escape hatch先 fp32-core 把正确性钉死,纯 half flash 留 follow-up。
### ⑤ opt-in 透传:`use_flash` 是运行时旗标,不是架构
`use_flash` 不进 `Config`(它不改模型尺寸、不改导出、不该污染 `num_params`),而是 `TinyTransformer` 的一个
`bool` 字段 + `with_flash(bool)` builder对齐 `with_recompute` / `with_compute_dtype`)。`block_forward` 已经
`(cfg, cdt, …)` 的自由函数T13 为 recompute 抽的),给它加一个 `flash: bool` 形参model 的 `attention()`
据此选 `ops::attention`composed`ops::flash_attention`。recompute 闭包捕获 `flash``Copy`)→ **重算段内也走
flash**flash×recompute 组合天然成立。默认 `false` = composed 路径**逐字节不变**(硬闸门:默认图不变 → 不回归)。
## 验证方法
**硬闸门全绿dash5 实跑 capture**
### 1. 正确性
- **新 kernel dQ/dK/dV finite-diff grad-check**`xtrain-autodiff/tests/autograd.rs::flash_attention_batched_bwd`
与既有 `attention_batched_bwd` 同构(`L = sum(W∘out)`,中心差分),断 dQ/dK/dV 在 `cfg_nonlinear`/`cfg_linear` 容差内。
- **flash == composed**`xtrain-model/tests/flash.rs`):同 init 两个模型flash on/off同一 batched
loss + backward断**前向 logits / loss / 每参数梯度**在紧容差内一致;参数化 fp32近逐位与 bf16bf16 舍入级)。
- **PyTorch SDPA 对拍 B>1**`parity_dump.rs` + `parity.py`):等价 PyTorch 模型per-seq RoPE、per-seq causal、
QK-norm、SwiGLU对拍 forward logits + 全部参数梯度——**composed 与 flash 两条都跑**,共用同一 PyTorch oracle。
- **全回归套开/关 `--flash`**autograd 15、structural、batched==looped、bf16、recompute逐位、overfit 27/27、
AdamWGPU bit-exact + host 对 torch、DDP loss-match + 跨 rank、**xserv 闭环(导出 safetensors → md5 对 registry →
xserv 贪心逐 token 一致)**。flag off 默认图不变 → composed 数值不回归。
### 2. 显存payoff—— 不物化 N×N 的直接收益
dash5 1× RTX 5090同 confignvidia-smi 峰值flash off vs onattention 反向缓存 `[bh,S,S]→[bh,S]`
峰值显存应↓(尤其 seq 大时。capture 实际数字进表。
### 3. 吞吐
同 config steady-state tok/s flash off vs on。预期本规模 `hd=32` 下 flash kernel **持平或略慢于** cuBLAS 双
GEMM小头维喂不满 tensor-core 是 flash 的已知权衡,胜场在显存)——诚实报告,不为绿而调。
## 实测结果dash5 1× RTX 5090
**正确性(硬闸门全绿):**
| 闸门 | 结果 |
|---|---|
| ① 新 kernel dQ/dK/dV finite-diff grad-check | **过** — dQ 9.3e-3 / dK 1.7e-2 / dV 5.6e-4单 tile 干净区;多 tile 由②兜) |
| flash fwd 对 composed | max rel **6.7e-5** |
| flash bwd 对(已 grad-check 的composed bwd | dQ **1.7e-5** / dK 1.2e-5 / dV 4.3e-5 |
| ② flash==composedmodel 级logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 1.7e-4、grad 4.4e-5bf16: loss 1.5e-4、logits mean 1.6e-3/p99 5.9e-3、grad scaled-mean 1.2e-2 |
| ③ PyTorch SDPA 对拍 B>1flash 路径,共用 composed oracle | loss relerr **4.98e-8**、logits **7.92e-6**、25 参数 grad 全进 rtol 0.02 |
| ⑤ 回归套flag off 默认 + flash 路径都测autograd 18 / structural 5 / batched / bf16 / **flash 3** / overfit 27/27 / recompute 2 / AdamW(GPU+host) / GEMM / DDP 2 / checkpoint-roundtrip | **全绿** |
| ⑤ xserv 闭环 md5v3 ckpt 用 T14 代码重导 safetensors | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry 同)→ 默认 export 零漂移 |
| ⑤ xserv 闭环flash 训练 → 导出 → xserv 服务贪心) | flash-训出 coherent TinyStoriesxserv(BF16) 对 xtrain(F32) 贪心3 prompt 中 "One day" 逐 token 一致,其余在 ~0.5% BF16 漂移处晚分叉(与 v1/v2/v3 同款) |
> **finite-diff 的诚实记录**:长 softmaxseq>tile会产生大量近零梯度元素中心差分在那些元素上不可靠出现伪 0.0 / 符号翻转——不是 backward bug。故 ① 的 finite-diff 跑**单 tile 干净区**seq=5对齐既有 composed grad-check 的良态区),**多 tile 的 streaming/online 路径**用「flash bwd 对已 grad-check 的 composed bwd」seq=40dQ 1.7e-5兜——比 finite-diff 更利。dQ/dK 用 eps=2e-3 压低 f32 舍入项(~4e-4 小梯度上舍入项压过截断项)。**没有为凑绿放宽容差**。
**④ 显存 + 吞吐payoff vs tradeoffdim768=8L/12h×64/ffn3072, bf16, steady-state**
| config | path | 峰值显存 | tok/s |
|---|---|---|---|
| batch8 seq1024 | composed (off) | 24670 MiB | **58.6K** |
| batch8 seq1024 | **flash (on)** | **20736 MiB16%** | 25.0K57%, ~2.3× 慢) |
| batch2 seq2048 | composed (off) | 17264 MiB | 36.7K |
| batch2 seq2048 | **flash (on)** | **13246 MiB23%** | 13.2K64% |
**显存按预期降**(不物化 `[bh,S,S]`),且**收益随 seq 增长**seq1024 16% → seq2048 23%O(S²) 砍掉)。
**tok/s 如设计 ① 预测的「持平或略慢」实为 ~2.32.8× 慢**hd=64 的小头维下,手写「一行一 block + 串行扫 KV」kernel 喂不满 SM干不过 cuBLAS tensor-core 的两发批量 GEMM——这正是 flash 的已知权衡(**胜场在显存,不是小模型 wall-clock**诚实报告不掩饰。两个落地的优化softmax 权重缓存进 shared 省 hd× 的 expfdK/dV 原子加摊到全 block 而非串行在列 owner 内)把 backward 从 6.8× 慢拉到 2.3× 慢——主瓶颈是 backward 的跨行原子累加FA2 用 K-block 拥有 dK/dV 的独立 pass 解,本版未做,留 follow-up
> **escape hatchfollow-up未做记给后续**:① FA2 式 query-tile 划分(一 block 多 query 行K/V 进 shared 复用)提 SM 占用;② backward 的 dK/dV 改 K-block-owned 独立 pass 消跨行原子;③ 纯 bf16 in-kernel省两次 cast。本规模 attention 非训练瓶颈、且会动数值贴合闸门,按 escape hatch 推迟——T14 先把**正确性 + 不物化 N×N + 显存↓**钉死。

View File

@@ -24,6 +24,7 @@
| T11 | Infra | **device caching/pool allocator**(复用 op 输出显存,消 per-step cudaMalloc | 单卡 2.3×**8卡 461K tok/s** 近线性(修 KI-5 | | T11 | Infra | **device caching/pool allocator**(复用 op 输出显存,消 per-step cudaMalloc | 单卡 2.3×**8卡 461K tok/s** 近线性(修 KI-5 |
| T12 | 算法/Infra | **bf16 混合精度**fp32 mastercuBLAS GemmExnorm/softmax/CE 保 fp32 | dim768 OOM 解除29% 显存/+13% tok/s修 KI-2 | | T12 | 算法/Infra | **bf16 混合精度**fp32 mastercuBLAS GemmExnorm/softmax/CE 保 fp32 | dim768 OOM 解除29% 显存/+13% tok/s修 KI-2 |
| T13 | 算法/Infra | **激活重计算**per-block gradient checkpointing前向 no-tape + 反向重算,`backward_seeded` | 梯度对非重计算版**逐位一致**(0.00)dim768 31.1→14.6GB**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3解锁 v8 | | T13 | 算法/Infra | **激活重计算**per-block gradient checkpointing前向 no-tape + 反向重算,`backward_seeded` | 梯度对非重计算版**逐位一致**(0.00)dim768 31.1→14.6GB**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3解锁 v8 |
| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernelonline softmax、tiled over KV、**不物化 N×N scores**flash 式 bwd重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dVopt-in `--flash`,默认保 composedPhase 2 | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0**峰值显存 16%@seq1024 / 23%@seq2048**(不物化 N×N收益随 seq 增长tok/s ~2.32.8×hd=64 小头维干不过 cuBLAS tensor-coreflash 已知权衡=胜场在显存md5 闭环逐位一致 |
--- ---
@@ -49,9 +50,9 @@
## 三、各维度的累积演进(轴向看一条线怎么走的) ## 三、各维度的累积演进(轴向看一条线怎么走的)
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13)。 - **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14online softmax + flash 式 bwd)
- **模型架构**:固定 Qwen3-styledim **32→256→384→512→768→1024**v8 首拨容量轴,头数 24→32核心参数 **41K→226M**(总 3.26M→329M - **模型架构**:固定 Qwen3-styledim **32→256→384→512→768→1024**v8 首拨容量轴,头数 24→32核心参数 **41K→226M**(总 3.26M→329M
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13解锁 dim1024)。吞吐 **3.3K→217K tok/s**dim768 bf16dim1024+重算 ~129K重算税MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。 - **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13解锁 dim1024) → flash-attention(T14不物化 N×Nattention 显存收益随 seq 增长)。吞吐 **3.3K→217K tok/s**dim768 bf16dim1024+重算 ~129K重算税MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。
- **数据集**TinyStories 3MB 切片 → 全量 TinyStoriesepoch 0.01→5.33**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**2.255B 语料1.02ep)→ **v7 同子集多 epoch1.45ep,近顶)→ v8 同子集换大模型**dim10241.05ep。tokenizer 全程 gpt2 BPE复用 xserv-tokenizerv6 刻意不换 tokenizer 以隔离「数据来源」变量KI-4 留后续版本)。 - **数据集**TinyStories 3MB 切片 → 全量 TinyStoriesepoch 0.01→5.33**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**2.255B 语料1.02ep)→ **v7 同子集多 epoch1.45ep,近顶)→ v8 同子集换大模型**dim10241.05ep。tokenizer 全程 gpt2 BPE复用 xserv-tokenizerv6 刻意不换 tokenizer 以隔离「数据来源」变量KI-4 留后续版本)。
- **v5→v6 数据轴的质变**v0v5 都吃合成幼儿故事TinyStories低熵、词汇受控v5 证明同尺寸模型在它上面已饱和v6 第一版换成**真实教育类网页文本**FineWeb-edu语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。 - **v5→v6 数据轴的质变**v0v5 都吃合成幼儿故事TinyStories低熵、词汇受控v5 证明同尺寸模型在它上面已饱和v6 第一版换成**真实教育类网页文本**FineWeb-edu语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。
- ⚠️ **同子集多 epoch 也有天花板v6→v7**v6 的 FineWeb val 才训 1.02ep、末步仍单调降曾被读作「还没喂够」v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B tokenFineWeb val 仅 ↓0.053.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象****「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」不在「同数据多读几遍」**。后续要继续降 val必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。 - ⚠️ **同子集多 epoch 也有天花板v6→v7**v6 的 FineWeb val 才训 1.02ep、末步仍单调降曾被读作「还没喂够」v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B tokenFineWeb val 仅 ↓0.053.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象****「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」不在「同数据多读几遍」**。后续要继续降 val必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。