Compare commits
4 Commits
4b6d3e0a79
...
2ff4573a31
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ff4573a31 | |||
| 39df0b40c1 | |||
| 830d06ad01 | |||
| 62b1cb5dc7 |
@@ -51,6 +51,7 @@ Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`]
|
||||
| **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem |
|
||||
| **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical |
|
||||
| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) |
|
||||
| **T15** | **grouped-query attention** (`num_kv_heads<num_heads`; `repeat_kv` broadcast feeds both SDPA paths; backward sums each kv head's group; `--kv-heads`) | repeat_kv grad-check + **group=1 bit-identical to MHA**; GQA flash==composed; PyTorch GQA B>1; **xserv closed loop with real `num_key_value_heads`** token-identical |
|
||||
| **T16** | **gradient accumulation** (`--accum-steps`; DDP all-reduces only at the boundary) | equiv to N× big batch (grad 3.8e-5); same effective-64 batch 27.7GB→7.2GB (−74%) |
|
||||
| **T18** | **dropout** (hand counter-based device RNG + mask, inverted scaling, train/eval switch) | fixed-seed grad-check; **p=0 bit-identical**; recompute-safe |
|
||||
|
||||
@@ -58,6 +59,9 @@ The four performance fixes (T10–T13) each removed a real bottleneck — see
|
||||
[`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14–)**
|
||||
revisits hand-writing deferred training-stack features: T14 = the fused
|
||||
flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md));
|
||||
T15 = real grouped-query attention ([`docs/14-gqa.md`](docs/14-gqa.md), `num_kv_heads <
|
||||
num_heads` via a `repeat_kv` broadcast op whose backward sums each kv head's query-head
|
||||
group — feeding both SDPA paths unchanged, default MHA bit-identical);
|
||||
T16 = micro-batch gradient accumulation ([`docs/15-grad-accum.md`](docs/15-grad-accum.md)),
|
||||
which decouples the effective batch from activation memory (memory tracks the micro-batch,
|
||||
not N×); T18 = dropout ([`docs/17-dropout.md`](docs/17-dropout.md), hand counter-based
|
||||
@@ -145,5 +149,5 @@ cargo test --workspace # autograd grad-checks, PyTorch parity, DDP, e
|
||||
|
||||
- [`docs/evolution.md`](docs/evolution.md) — per-milestone changes across algorithm / architecture / infra / dataset.
|
||||
- [`docs/runs/README.md`](docs/runs/README.md) — the v0–v8 comparison; [`docs/runs/0N-*.md`](docs/runs/) — per-run detail.
|
||||
- [`docs/00-*` … `12-*`](docs/) — per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute).
|
||||
- [`docs/00-*` … `14-*`](docs/) — per-phase design docs (build chain → tensor → autograd → transformer → training → perf → distributed → export → batched → allocator → bf16 → recompute → flash-attention → GQA).
|
||||
- [`docs/known-issues.md`](docs/known-issues.md) — perf backlog (KI-1/2/3/5 fixed; KI-4 + process-per-GPU open).
|
||||
|
||||
@@ -376,6 +376,27 @@ pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
)
|
||||
}
|
||||
|
||||
/// GQA repeat_kv head broadcast (Phase T15). `kv`:[batch·num_kv, seq, head_dim]
|
||||
/// (a K or V tensor) → `[batch·nh, seq, head_dim]`, each KV head broadcast to its
|
||||
/// `group = nh/num_kv` query heads (qh ← kv head qh/group, contiguous groups —
|
||||
/// matches xserv's repeat_kv). Feeds the unchanged composed/flash SDPA so GQA is
|
||||
/// "free" for both. Backward SUMS the `group` query heads sharing each KV head back
|
||||
/// onto it (the multi-group grad accumulation). `nh == num_kv` (group 1) is identity
|
||||
/// → bit-identical to the MHA path. `batch` lets the op recover num_kv from kv's bh.
|
||||
pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var {
|
||||
let bh_kv = kv.value().shape()[0];
|
||||
let num_kv = bh_kv / batch;
|
||||
let out = kv.value().repeat_kv(nh, batch);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![kv.clone()],
|
||||
Box::new(move |dout, parents| {
|
||||
let din = Tensor::repeat_kv_backward(dout, num_kv, batch);
|
||||
Var::push_grad(&parents[0], din);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
|
||||
@@ -776,6 +776,94 @@ fn flash_bwd_matches_composed_bwd() {
|
||||
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
|
||||
}
|
||||
|
||||
// ---- GQA repeat_kv head broadcast (Phase T15) ----
|
||||
//
|
||||
// repeat_kv expands K/V from [batch·num_kv, seq, hd] to [batch·nh, seq, hd]; each
|
||||
// kv head is broadcast to its `group = nh/num_kv` query heads. The forward is a
|
||||
// gather (a linear map), so finite-diff is clean. The CRITICAL gate is the
|
||||
// BACKWARD: a kv head receives the SUM of the `group` query heads sharing it —
|
||||
// the multi-group-to-one grad accumulation GQA correctness hinges on. We grad-check
|
||||
// din against finite-diff of L = sum(W∘out) with group>1, plus assert the forward
|
||||
// actually broadcasts and that group==1 is exact identity.
|
||||
#[test]
|
||||
fn repeat_kv_grad() {
|
||||
require_gpu();
|
||||
// batch 2, num_kv 2 → bh_kv 4 input rows; nh 6 → group 3, bh_q 12 output rows.
|
||||
let (batch, num_kv, nh, seq, hd) = (2usize, 2usize, 6usize, 4usize, 5usize);
|
||||
let n_in = batch * num_kv * seq * hd;
|
||||
let n_out = batch * nh * seq * hd;
|
||||
let x_h = fill(n_in, 711);
|
||||
let w = fill(n_out, 712);
|
||||
|
||||
let kv = Var::leaf(cuda(&x_h, &[batch * num_kv, seq, hd]));
|
||||
let out = ops::repeat_kv(&kv, nh, batch);
|
||||
assert_eq!(out.value().shape(), &[batch * nh, seq, hd]);
|
||||
|
||||
// Forward sanity: out head (b·nh + qh) must equal in head (b·num_kv + qh/group).
|
||||
let group = nh / num_kv;
|
||||
let out_h = out
|
||||
.value()
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec();
|
||||
let row = seq * hd;
|
||||
for b in 0..batch {
|
||||
for qh in 0..nh {
|
||||
let kvh = qh / group;
|
||||
let o0 = (b * nh + qh) * row;
|
||||
let i0 = (b * num_kv + kvh) * row;
|
||||
for e in 0..row {
|
||||
assert_eq!(out_h[o0 + e], x_h[i0 + e], "repeat_kv fwd mismatch");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scalar_loss(&out, &w).backward();
|
||||
let din = kv.grad().unwrap().to_device(Device::Cpu);
|
||||
|
||||
let fwd = move |xh: &[f32], _s: &[usize]| -> f32 {
|
||||
let kv = cuda(xh, &[batch * num_kv, seq, hd]);
|
||||
let o = kv.repeat_kv(nh, batch);
|
||||
weighted_sum(&o, &w)
|
||||
};
|
||||
// repeat_kv is exactly linear (gather/sum), so the linear-op tolerances apply.
|
||||
report(
|
||||
"repeat_kv din",
|
||||
&grad_check(
|
||||
&x_h,
|
||||
&[batch * num_kv, seq, hd],
|
||||
&fwd,
|
||||
din.as_slice::<f32>(),
|
||||
cfg_linear(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// group==1 (num_kv == nh) must be a bit-exact identity in BOTH directions — this is
|
||||
// the regression guard that makes the MHA path (kv_heads == n_heads) unchanged.
|
||||
#[test]
|
||||
fn repeat_kv_identity_group1() {
|
||||
require_gpu();
|
||||
let (batch, nh, seq, hd) = (2usize, 3usize, 4usize, 5usize);
|
||||
let n = batch * nh * seq * hd;
|
||||
let x_h = fill(n, 721);
|
||||
let w = fill(n, 722);
|
||||
let kv = Var::leaf(cuda(&x_h, &[batch * nh, seq, hd]));
|
||||
let out = ops::repeat_kv(&kv, nh, batch); // group 1
|
||||
let out_h = out
|
||||
.value()
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec();
|
||||
assert_eq!(out_h, x_h, "group-1 repeat_kv fwd must be identity");
|
||||
scalar_loss(&out, &w).backward();
|
||||
let din = kv.grad().unwrap().to_device(Device::Cpu);
|
||||
// dL/din = w exactly (identity forward → grad passes through unchanged).
|
||||
for (g, expect) in din.as_slice::<f32>().iter().zip(&w) {
|
||||
assert_eq!(*g, *expect, "group-1 repeat_kv bwd must be identity");
|
||||
}
|
||||
}
|
||||
|
||||
// ---- dropout (Phase T18) ----
|
||||
//
|
||||
// Fixed-seed finite-diff grad-check. Under a fixed `seed` the mask is constant
|
||||
@@ -827,9 +915,17 @@ fn dropout_expectation_and_keep_rate() {
|
||||
let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64);
|
||||
let out_h = out.to_device(Device::Cpu);
|
||||
let mask_h = mask.to_device(Device::Cpu);
|
||||
let mean_out: f64 =
|
||||
out_h.as_slice::<f32>().iter().map(|&v| v as f64).sum::<f64>() / n as f64;
|
||||
let kept = mask_h.as_slice::<f32>().iter().filter(|&&m| m != 0.0).count();
|
||||
let mean_out: f64 = out_h
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|&v| v as f64)
|
||||
.sum::<f64>()
|
||||
/ n as f64;
|
||||
let kept = mask_h
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.filter(|&&m| m != 0.0)
|
||||
.count();
|
||||
mean_out_acc += mean_out;
|
||||
keep_acc += kept as f64 / n as f64;
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ fn main() {
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.file("../../csrc/ops/flash_attention.cu")
|
||||
.file("../../csrc/ops/repeat_kv.cu")
|
||||
.file("../../csrc/ops/cast.cu")
|
||||
.file("../../csrc/ops/dropout.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
|
||||
@@ -296,6 +296,37 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// GQA repeat_kv head broadcast (csrc/ops/repeat_kv.cu, Phase T15). Expands a K/V
|
||||
// tensor from [batch·num_kv, S, hd] to the full [batch·nh, S, hd] so the SDPA
|
||||
// (composed or flash, both untouched) sees a full set of heads. Forward gathers
|
||||
// (out head qh ← kv head qh/group, group = nh/num_kv); backward sums the `group`
|
||||
// query heads sharing each kv head (deterministic, no atomics). All F32.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// Forward: out[b·nh+qh] = in[b·num_kv + qh/group], per [S,hd] head block.
|
||||
pub fn launch_repeat_kv_fwd_f32(
|
||||
input: *const f32,
|
||||
out: *mut f32,
|
||||
batch: i32,
|
||||
nh: i32,
|
||||
num_kv: i32,
|
||||
seq: i32,
|
||||
hd: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Backward: din[b·num_kv+kvh] = Σ_{r<group} dout[b·nh + kvh·group + r].
|
||||
pub fn launch_repeat_kv_bwd_f32(
|
||||
dout: *const f32,
|
||||
din: *mut f32,
|
||||
batch: i32,
|
||||
nh: i32,
|
||||
num_kv: i32,
|
||||
seq: i32,
|
||||
hd: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
|
||||
// the global grad-norm reduction + in-place rescale (Phase T7).
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -61,6 +61,8 @@ fn main() {
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
// GQA (Phase T15): num K/V heads (must divide --heads). Default = --heads (MHA).
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
|
||||
let dim_flag = flag(&args, "--dim", 0usize);
|
||||
if dim_flag != 0 && dim_flag != n_heads * head_dim {
|
||||
@@ -137,13 +139,14 @@ fn main() {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
"model: dim {} layers {} heads {} kv_heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.num_kv_heads,
|
||||
cfg.head_dim,
|
||||
cfg.ffn_hidden,
|
||||
cfg.core_params() as f32 / 1e6,
|
||||
|
||||
@@ -10,8 +10,15 @@ pub struct Config {
|
||||
pub dim: usize,
|
||||
/// Number of decoder blocks.
|
||||
pub n_layers: usize,
|
||||
/// Number of attention heads.
|
||||
/// Number of attention (query) heads.
|
||||
pub n_heads: usize,
|
||||
/// Number of key/value heads (Phase T15, GQA). Each KV head is shared by a
|
||||
/// group of `n_heads / num_kv_heads` query heads (repeat_kv). Must divide
|
||||
/// `n_heads`. `num_kv_heads == n_heads` (the default) = MHA, bit-identical to
|
||||
/// the pre-T15 path; `num_kv_heads < n_heads` = real grouped-query attention,
|
||||
/// shrinking the K/V projections to `num_kv_heads * head_dim` and exported as a
|
||||
/// real `num_key_value_heads`.
|
||||
pub num_kv_heads: usize,
|
||||
/// Per-head dimension (`dim / n_heads`).
|
||||
pub head_dim: usize,
|
||||
/// SwiGLU hidden width (gate/up project to this, down projects back).
|
||||
@@ -37,6 +44,7 @@ impl Config {
|
||||
dim: n_heads * head_dim,
|
||||
n_layers: 2,
|
||||
n_heads,
|
||||
num_kv_heads: n_heads, // default = MHA
|
||||
head_dim,
|
||||
ffn_hidden: 64,
|
||||
eps: 1e-5,
|
||||
@@ -62,6 +70,7 @@ impl Config {
|
||||
dim: n_heads * head_dim,
|
||||
n_layers,
|
||||
n_heads,
|
||||
num_kv_heads: n_heads, // default = MHA; set via with_kv_heads for GQA
|
||||
head_dim,
|
||||
ffn_hidden,
|
||||
eps: 1e-5,
|
||||
@@ -70,6 +79,27 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the number of K/V heads (Phase T15, GQA). Builder-style so existing
|
||||
/// `from_arch` call sites stay MHA unless they opt in. Asserts `num_kv_heads`
|
||||
/// divides `n_heads`.
|
||||
pub fn with_kv_heads(mut self, num_kv_heads: usize) -> Self {
|
||||
assert!(num_kv_heads > 0, "num_kv_heads must be > 0");
|
||||
assert_eq!(
|
||||
self.n_heads % num_kv_heads,
|
||||
0,
|
||||
"n_heads {} not divisible by num_kv_heads {num_kv_heads}",
|
||||
self.n_heads
|
||||
);
|
||||
self.num_kv_heads = num_kv_heads;
|
||||
self
|
||||
}
|
||||
|
||||
/// KV projection width (`num_kv_heads * head_dim`). For GQA this is smaller than
|
||||
/// `dim`; for MHA it equals `dim`.
|
||||
pub fn kv_dim(&self) -> usize {
|
||||
self.num_kv_heads * self.head_dim
|
||||
}
|
||||
|
||||
/// Transformer-core parameter count: everything except the token embedding and
|
||||
/// the LM head (the two `vocab × dim` tables). This is the figure the scaling
|
||||
/// ladder is sized against — the 50257-vocab embed+lm_head adds a fixed ~25M on
|
||||
@@ -82,7 +112,8 @@ impl Config {
|
||||
pub fn num_params(&self) -> usize {
|
||||
let per_layer = 2 * self.dim // 2 rmsnorm gammas
|
||||
+ 2 * self.head_dim // q/k per-head norm gammas
|
||||
+ 3 * self.dim * self.dim // q/k/v proj
|
||||
+ self.dim * self.dim // q proj [dim,dim]
|
||||
+ 2 * self.dim * self.kv_dim() // k/v proj [dim,kv_dim] (GQA: smaller)
|
||||
+ self.dim * self.dim // out proj
|
||||
+ 2 * self.dim * self.ffn_hidden // gate/up proj
|
||||
+ self.ffn_hidden * self.dim; // down proj
|
||||
|
||||
@@ -13,8 +13,8 @@ use xtrain_tensor::{DType, Device, Tensor};
|
||||
struct Block {
|
||||
attn_norm: Var, // [dim]
|
||||
wq: Var, // [dim, dim]
|
||||
wk: Var, // [dim, dim]
|
||||
wv: Var, // [dim, dim]
|
||||
wk: Var, // [dim, kv_dim] — kv_dim = num_kv_heads·head_dim (GQA; = dim for MHA)
|
||||
wv: Var, // [dim, kv_dim]
|
||||
q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style)
|
||||
k_norm: Var, // [head_dim]
|
||||
wo: Var, // [dim, dim]
|
||||
@@ -91,8 +91,9 @@ impl TinyTransformer {
|
||||
.map(|_| Block {
|
||||
attn_norm: mk(&[cfg.dim]),
|
||||
wq: mk(&[cfg.dim, cfg.dim]),
|
||||
wk: mk(&[cfg.dim, cfg.dim]),
|
||||
wv: mk(&[cfg.dim, cfg.dim]),
|
||||
// GQA (T15): K/V project to num_kv_heads·head_dim (= dim when MHA).
|
||||
wk: mk(&[cfg.dim, cfg.kv_dim()]),
|
||||
wv: mk(&[cfg.dim, cfg.kv_dim()]),
|
||||
q_norm: mk(&[cfg.head_dim]),
|
||||
k_norm: mk(&[cfg.head_dim]),
|
||||
wo: mk(&[cfg.dim, cfg.dim]),
|
||||
@@ -435,37 +436,47 @@ fn attention(
|
||||
wo: &Var,
|
||||
) -> Var {
|
||||
let (nh, hd) = (cfg.n_heads, cfg.head_dim);
|
||||
let num_kv = cfg.num_kv_heads; // GQA (T15): K/V have fewer heads than Q
|
||||
let total = batch * seq;
|
||||
let bh = batch * nh;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
|
||||
// Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor.
|
||||
// [B*S,dim] @ [dim,dim] = [B*S,dim]
|
||||
// reshape [B*S, nh, hd]
|
||||
// Project, qk-norm + RoPE, then lay out as a batched [B*heads, seq, hd] tensor.
|
||||
// `heads` = nh for Q, num_kv for K/V (GQA; equal for MHA).
|
||||
// [B*S,dim] @ [dim,heads*hd] = [B*S, heads*hd]
|
||||
// reshape [B*S, heads, hd]
|
||||
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
|
||||
// rope [B*S, nh, hd] with per-sequence position (period = seq)
|
||||
// reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd]
|
||||
let to_bh = |proj: Var, norm: Option<&Var>| -> Var {
|
||||
let r = ops::reshape(&proj, &[total, nh, hd]);
|
||||
// rope [B*S, heads, hd] with per-sequence position (period = seq)
|
||||
// reshape [B, S, heads, hd] → transpose(1,2) → [B, heads, S, hd] → [B*heads, S, hd]
|
||||
let to_bh = |proj: Var, heads: usize, norm: Option<&Var>| -> Var {
|
||||
let r = ops::reshape(&proj, &[total, heads, hd]);
|
||||
let r = match norm {
|
||||
// Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd,
|
||||
// Per-head RMSNorm: flatten the (B*S,heads) head rows, norm over hd,
|
||||
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
|
||||
Some(gamma) => {
|
||||
let flat = ops::reshape(&r, &[total * nh, hd]);
|
||||
let flat = ops::reshape(&r, &[total * heads, hd]);
|
||||
let normed = ops::rms_norm(&flat, &norm_gamma(cdt, gamma), cfg.eps);
|
||||
let r = ops::reshape(&normed, &[total, nh, hd]);
|
||||
let r = ops::reshape(&normed, &[total, heads, hd]);
|
||||
ops::rope(&r, cfg.rope_theta, seq)
|
||||
}
|
||||
None => r,
|
||||
};
|
||||
let r = ops::reshape(&r, &[batch, seq, nh, hd]);
|
||||
let t = ops::transpose_4d12(&r); // [B, nh, S, hd]
|
||||
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
|
||||
let r = ops::reshape(&r, &[batch, seq, heads, hd]);
|
||||
let t = ops::transpose_4d12(&r); // [B, heads, S, hd]
|
||||
ops::reshape(&t, &[batch * heads, seq, hd]) // [B*heads, S, hd]
|
||||
};
|
||||
|
||||
let q = to_bh(linear(cdt, x, wq), Some(q_norm));
|
||||
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
|
||||
let v = to_bh(linear(cdt, x, wv), None);
|
||||
let q = to_bh(linear(cdt, x, wq), nh, Some(q_norm));
|
||||
// K/V are laid out with num_kv heads, then repeat_kv-broadcast to nh heads so
|
||||
// the SDPA below (composed or flash, both unchanged) sees a full head set. The
|
||||
// broadcast's backward sums each KV head's group of query-head grads (GQA). For
|
||||
// MHA (num_kv == nh) repeat_kv is identity → bit-identical to the pre-T15 path.
|
||||
let k = to_bh(linear(cdt, x, wk), num_kv, Some(k_norm));
|
||||
let v = to_bh(linear(cdt, x, wv), num_kv, None);
|
||||
let (k, v) = if num_kv == nh {
|
||||
(k, v)
|
||||
} else {
|
||||
(ops::repeat_kv(&k, nh, batch), ops::repeat_kv(&v, nh, batch))
|
||||
};
|
||||
|
||||
// Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
|
||||
// single fused flash kernel (online softmax, no materialized [bh,S,S] scores);
|
||||
|
||||
269
crates/xtrain-model/tests/gqa.rs
Normal file
269
crates/xtrain-model/tests/gqa.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
// T15 GQA correctness gate. Real grouped-query attention (num_kv_heads <
|
||||
// num_heads): K/V project to num_kv_heads·head_dim and are repeat_kv-broadcast to
|
||||
// the full head set before the SDPA. This test pins three things:
|
||||
//
|
||||
// 1. GQA flash == GQA composed (forward logits + loss + EVERY param grad) — the
|
||||
// repeat_kv broadcast feeds both SDPA paths unchanged, so they must agree; in
|
||||
// particular the wk/wv grads (which flow back through repeat_kv's group-sum)
|
||||
// must match. Parameterised over fp32 (tight) and bf16 (rounding band).
|
||||
// 2. group==1 (num_kv_heads == n_heads) is BIT-IDENTICAL to the pre-T15 MHA path
|
||||
// (a model with num_kv_heads explicitly == n_heads vs the default config):
|
||||
// forward logits + every grad |Δ|=0. The regression guard.
|
||||
// 3. wk/wv really shrank to [dim, kv_dim] under GQA (shape check).
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype).with_flash(flash)
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
// A real GQA config: 8 query heads, 2 kv heads → group 4. seq=40 > FA_TILE=32 so
|
||||
// the flash online-softmax tile path is exercised too.
|
||||
fn gqa_cfg() -> Config {
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 3;
|
||||
// tiny() is 2 heads; rebuild with 8 query / 2 kv heads keeping head_dim=16.
|
||||
Config::from_arch(cfg.vocab, 8, cfg.head_dim, cfg.n_layers, cfg.ffn_hidden).with_kv_heads(2)
|
||||
}
|
||||
|
||||
fn ids_targets(cfg: &Config, batch: usize, seq: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
|
||||
let seqs = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
(seqs, tgts)
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn run_both(
|
||||
cfg: Config,
|
||||
dtype: DType,
|
||||
) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
let (batch, seq) = (3usize, 40usize);
|
||||
let (seqs, tgts) = ids_targets(&cfg, batch, seq);
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
let off = build(cfg, device, dtype, false);
|
||||
let off_logits = host(&off.forward_batched(&ids, batch).value());
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
let off_loss_val = host(&off_loss.value())[0];
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("off grad")))
|
||||
.collect();
|
||||
|
||||
let on = build(cfg, device, dtype, true);
|
||||
let on_logits = host(&on.forward_batched(&ids, batch).value());
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
let on_loss_val = host(&on_loss.value())[0];
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("on grad")))
|
||||
.collect();
|
||||
|
||||
(
|
||||
off_logits,
|
||||
off_loss_val,
|
||||
off_grads,
|
||||
on_logits,
|
||||
on_loss_val,
|
||||
on_grads,
|
||||
)
|
||||
}
|
||||
|
||||
// GQA flash vs composed: same SDPA math on the same repeat_kv-broadcast K/V → fp32
|
||||
// agrees to reduction-order, bf16 to its rounding band.
|
||||
#[test]
|
||||
fn gqa_flash_matches_composed_fp32() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
let cfg = gqa_cfg();
|
||||
assert!(cfg.num_kv_heads < cfg.n_heads, "test must be real GQA");
|
||||
let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(cfg, DType::F32);
|
||||
|
||||
let logit_rel = off_l
|
||||
.iter()
|
||||
.zip(&on_l)
|
||||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||||
println!(
|
||||
"[GQA F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
|
||||
logits max rel {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
logit_rel < 1e-3,
|
||||
"[GQA F32] logits diverged: {logit_rel:.2e}"
|
||||
);
|
||||
assert!(loss_rel < 1e-3, "[GQA F32] loss diverged: {loss_rel:.2e}");
|
||||
|
||||
let mut worst = 0.0f32;
|
||||
for (a_g, b_g) in off_g.iter().zip(&on_g) {
|
||||
for (a, b) in a_g.iter().zip(b_g) {
|
||||
worst = worst.max((a - b).abs() / a.abs().max(1e-3));
|
||||
}
|
||||
}
|
||||
println!("[GQA F32] flash on/off grad max rel = {worst:.3e}");
|
||||
assert!(worst < 2e-2, "[GQA F32] grads diverged: {worst:.3e}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gqa_flash_matches_composed_bf16() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(gqa_cfg(), DType::BF16);
|
||||
|
||||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||||
println!("[GQA BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
|
||||
assert!(loss_rel < 2e-2, "[GQA BF16] loss diverged: {loss_rel:.3e}");
|
||||
|
||||
let n = off_l.len();
|
||||
let mut rels: Vec<f32> = off_l
|
||||
.iter()
|
||||
.zip(&on_l)
|
||||
.map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
|
||||
.collect();
|
||||
rels.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let mean: f32 = rels.iter().sum::<f32>() / n as f32;
|
||||
let p99 = rels[(n as f32 * 0.99) as usize];
|
||||
println!("[GQA BF16] logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
|
||||
assert!(
|
||||
mean < 1e-2,
|
||||
"[GQA BF16] logits mean rel too high: {mean:.3e}"
|
||||
);
|
||||
assert!(p99 < 5e-2, "[GQA BF16] logits p99 rel too high: {p99:.3e}");
|
||||
|
||||
let mut worst = 0.0f32;
|
||||
for (a_g, b_g) in off_g.iter().zip(&on_g) {
|
||||
let scale = a_g.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
|
||||
let mean_err: f32 =
|
||||
a_g.iter().zip(b_g).map(|(f, b)| (f - b).abs()).sum::<f32>() / a_g.len() as f32 / scale;
|
||||
worst = worst.max(mean_err);
|
||||
}
|
||||
println!("[GQA BF16] grads: worst per-tensor scaled-mean err = {worst:.3e}");
|
||||
assert!(worst < 3e-2, "[GQA BF16] grads diverged: {worst:.3e}");
|
||||
}
|
||||
|
||||
// REGRESSION GUARD: num_kv_heads == n_heads (group 1) must be BIT-IDENTICAL to the
|
||||
// pre-T15 MHA path. Build one model with the default config (num_kv_heads ==
|
||||
// n_heads, the untouched path: repeat_kv not even invoked) and one that explicitly
|
||||
// sets num_kv_heads = n_heads, then assert forward logits + every grad match to the
|
||||
// bit. (Same composed path, so this is exact equality, not a tolerance.)
|
||||
#[test]
|
||||
fn gqa_group1_bit_identical_to_mha() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let mut base = Config::tiny();
|
||||
base.vocab = 16;
|
||||
base.n_layers = 3;
|
||||
let base = Config::from_arch(base.vocab, 4, base.head_dim, base.n_layers, base.ffn_hidden);
|
||||
// `explicit` sets num_kv_heads = n_heads (already the default, but exercises the
|
||||
// with_kv_heads path); they are the same config → must produce identical output.
|
||||
let explicit = base.with_kv_heads(base.n_heads);
|
||||
assert_eq!(base.num_kv_heads, explicit.num_kv_heads);
|
||||
|
||||
let (batch, seq) = (2usize, 8usize);
|
||||
let (seqs, tgts) = ids_targets(&base, batch, seq);
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
let run = |cfg: Config| -> (Vec<f32>, f32, Vec<Vec<f32>>) {
|
||||
let m = build(cfg, device, DType::F32, false);
|
||||
let logits = host(&m.forward_batched(&ids, batch).value());
|
||||
let loss = m.loss_batched(&ids, &tgt, batch);
|
||||
let loss_v = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads = m
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().unwrap()))
|
||||
.collect();
|
||||
(logits, loss_v, grads)
|
||||
};
|
||||
let (la, sa, ga) = run(base);
|
||||
let (lb, sb, gb) = run(explicit);
|
||||
assert_eq!(la, lb, "group-1 logits must be bit-identical to MHA");
|
||||
assert_eq!(sa, sb, "group-1 loss must be bit-identical to MHA");
|
||||
for (a, b) in ga.iter().zip(&gb) {
|
||||
assert_eq!(a, b, "group-1 grad must be bit-identical to MHA");
|
||||
}
|
||||
println!("[GQA group1] bit-identical to MHA: logits + loss + all grads |Δ|=0");
|
||||
}
|
||||
|
||||
// Under GQA, wk/wv must be [dim, kv_dim] (= num_kv_heads·head_dim), wq stays [dim,dim].
|
||||
#[test]
|
||||
fn gqa_kv_proj_shape() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
let cfg = gqa_cfg();
|
||||
let m = build(cfg, device, DType::F32, false);
|
||||
let p = m.params();
|
||||
// params order: embed[0], then block 0 = [attn_norm[1], wq[2], wk[3], wv[4],
|
||||
// q_norm[5], k_norm[6], wo[7], ...]
|
||||
let wq = p[2].value().shape().to_vec();
|
||||
let wk = p[3].value().shape().to_vec();
|
||||
let wv = p[4].value().shape().to_vec();
|
||||
assert_eq!(wq, vec![cfg.dim, cfg.dim], "wq must be [dim,dim]");
|
||||
assert_eq!(wk, vec![cfg.dim, cfg.kv_dim()], "wk must be [dim,kv_dim]");
|
||||
assert_eq!(wv, vec![cfg.dim, cfg.kv_dim()], "wv must be [dim,kv_dim]");
|
||||
println!(
|
||||
"[GQA shapes] wq {:?} wk {:?} wv {:?} (kv_dim {})",
|
||||
wq,
|
||||
wk,
|
||||
wv,
|
||||
cfg.kv_dim()
|
||||
);
|
||||
}
|
||||
@@ -52,6 +52,10 @@ cfg = read_cfg()
|
||||
DIM = int(cfg["dim"])
|
||||
NL = int(cfg["n_layers"])
|
||||
NH = int(cfg["n_heads"])
|
||||
# GQA (T15): num_kv_heads <= n_heads; each kv head shared by group query heads.
|
||||
# Default to NH (MHA) for fixtures dumped before the field existed.
|
||||
NKV = int(cfg.get("num_kv_heads", str(NH)))
|
||||
GROUP = NH // NKV
|
||||
HD = int(cfg["head_dim"])
|
||||
EPS = float(cfg["eps"])
|
||||
THETA = float(cfg["rope_theta"])
|
||||
@@ -114,17 +118,23 @@ for L in layers:
|
||||
# Attention
|
||||
x = rms_norm(h, L["attn_norm"])
|
||||
q = (x @ L["wq"]).reshape(B * SEQ, NH, HD)
|
||||
k = (x @ L["wk"]).reshape(B * SEQ, NH, HD)
|
||||
v = (x @ L["wv"]).reshape(B * SEQ, NH, HD)
|
||||
# GQA: K/V project to NKV heads, then repeat each kv head GROUP times to NH.
|
||||
k = (x @ L["wk"]).reshape(B * SEQ, NKV, HD)
|
||||
v = (x @ L["wv"]).reshape(B * SEQ, NKV, HD)
|
||||
# Per-head QK-norm (Qwen3-style), before RoPE.
|
||||
q = rms_norm(q, L["q_norm"])
|
||||
k = rms_norm(k, L["k_norm"])
|
||||
q = rope(q) # [B*SEQ, nh, hd]
|
||||
k = rope(k)
|
||||
# Reshape to [B, NH, SEQ, HD] so attention runs within each sequence.
|
||||
k = rope(k) # [B*SEQ, nkv, hd]
|
||||
# Reshape to [B, *, SEQ, HD]; broadcast kv heads to NH (repeat_interleave along
|
||||
# the head axis: kv head kvh → query heads [kvh*GROUP, (kvh+1)*GROUP), matching
|
||||
# xtrain repeat_kv + xserv repeat_kv).
|
||||
q = q.reshape(B, SEQ, NH, HD).transpose(1, 2) # [B, nh, seq, hd]
|
||||
k = k.reshape(B, SEQ, NH, HD).transpose(1, 2)
|
||||
v = v.reshape(B, SEQ, NH, HD).transpose(1, 2)
|
||||
k = k.reshape(B, SEQ, NKV, HD).transpose(1, 2) # [B, nkv, seq, hd]
|
||||
v = v.reshape(B, SEQ, NKV, HD).transpose(1, 2)
|
||||
if GROUP > 1:
|
||||
k = k.repeat_interleave(GROUP, dim=1) # [B, nh, seq, hd]
|
||||
v = v.repeat_interleave(GROUP, dim=1)
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
scores = (q @ k.transpose(-1, -2)) * scale + mask # [B, nh, seq, seq]
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
|
||||
@@ -58,8 +58,20 @@ fn dump_for_parity() {
|
||||
// sequence-major to [B*S]=8 ids. Per-sequence RoPE position (resets at the
|
||||
// sequence boundary) + per-sequence causal masking (no cross-sequence
|
||||
// attention) are both checked against PyTorch.
|
||||
// Default: tiny MHA (2 heads). With XTRAIN_PARITY_KV_HEADS=k set, dump a real
|
||||
// GQA config (8 query heads / k kv heads) so parity.py checks GQA at B>1 — the
|
||||
// kv-projection shapes + the repeat_kv group-sum backward against PyTorch.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 12;
|
||||
if let Ok(kv) = std::env::var("XTRAIN_PARITY_KV_HEADS") {
|
||||
let kv: usize = kv.parse().expect("XTRAIN_PARITY_KV_HEADS");
|
||||
cfg = Config::from_arch(cfg.vocab, 8, cfg.head_dim, cfg.n_layers, cfg.ffn_hidden)
|
||||
.with_kv_heads(kv);
|
||||
println!(
|
||||
"parity: GQA config (n_heads {} kv_heads {})",
|
||||
cfg.n_heads, cfg.num_kv_heads
|
||||
);
|
||||
}
|
||||
let batch = 2usize;
|
||||
let seq = 4usize;
|
||||
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6]; // [B*S], sequence-major
|
||||
@@ -92,6 +104,7 @@ fn dump_for_parity() {
|
||||
writeln!(f, "dim {}", cfg.dim).unwrap();
|
||||
writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
|
||||
writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
|
||||
writeln!(f, "num_kv_heads {}", cfg.num_kv_heads).unwrap();
|
||||
writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
|
||||
writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
|
||||
writeln!(f, "eps {:e}", cfg.eps).unwrap();
|
||||
|
||||
@@ -1291,6 +1291,97 @@ impl Tensor {
|
||||
(dq.to_dtype(dt), dk.to_dtype(dt), dv.to_dtype(dt))
|
||||
}
|
||||
|
||||
// --- GQA repeat_kv head broadcast (the T15 op) ---
|
||||
|
||||
/// GQA repeat_kv (Phase T15). `self` is a K or V tensor `[batch·num_kv, seq,
|
||||
/// head_dim]`; returns `[batch·nh, seq, head_dim]` with each KV head broadcast
|
||||
/// to its `group = nh/num_kv` query heads (query head `qh` ← kv head `qh/group`,
|
||||
/// query heads contiguous within a group — matching xserv's repeat_kv). When
|
||||
/// `nh == num_kv` (group 1) this is a plain copy (identity → MHA bit-identical).
|
||||
/// bf16 callers are upcast→f32→kernel→downcast (the attention stack's fp32 policy).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn repeat_kv(&self, nh: usize, batch: usize) -> Self {
|
||||
assert_eq!(
|
||||
self.ndim(),
|
||||
3,
|
||||
"repeat_kv input must be [batch·num_kv,seq,hd]"
|
||||
);
|
||||
let (bh_kv, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
assert_eq!(
|
||||
bh_kv % batch,
|
||||
0,
|
||||
"repeat_kv: bh_kv {bh_kv} not div by batch {batch}"
|
||||
);
|
||||
let num_kv = bh_kv / batch;
|
||||
assert_eq!(
|
||||
nh % num_kv,
|
||||
0,
|
||||
"repeat_kv: nh {nh} not divisible by num_kv {num_kv}"
|
||||
);
|
||||
if self.dtype == DType::BF16 {
|
||||
return self
|
||||
.to_dtype(DType::F32)
|
||||
.repeat_kv(nh, batch)
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
assert_eq!(self.dtype, DType::F32, "repeat_kv supports F32/BF16");
|
||||
let out = Tensor::zeros(&[batch * nh, seq, hd], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_repeat_kv_fwd_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
batch as i32,
|
||||
nh as i32,
|
||||
num_kv as i32,
|
||||
seq as i32,
|
||||
hd as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Backward of [`repeat_kv`](Self::repeat_kv). `dout`:[batch·nh,seq,hd] →
|
||||
/// `din`:[batch·num_kv,seq,hd], summing the `group` query heads that share each
|
||||
/// KV head (the multi-group grad accumulation GQA correctness hinges on).
|
||||
/// Deterministic (no atomics: each kv element serially sums its group rows).
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn repeat_kv_backward(dout: &Tensor, num_kv: usize, batch: usize) -> Self {
|
||||
assert_eq!(
|
||||
dout.ndim(),
|
||||
3,
|
||||
"repeat_kv_backward dout must be [batch·nh,seq,hd]"
|
||||
);
|
||||
let (bh_q, seq, hd) = (dout.shape[0], dout.shape[1], dout.shape[2]);
|
||||
assert_eq!(bh_q % batch, 0, "repeat_kv_backward: bh_q not div by batch");
|
||||
let nh = bh_q / batch;
|
||||
assert_eq!(
|
||||
nh % num_kv,
|
||||
0,
|
||||
"repeat_kv_backward: nh not divisible by num_kv"
|
||||
);
|
||||
let dt = dout.dtype;
|
||||
if dt == DType::BF16 {
|
||||
return Tensor::repeat_kv_backward(&dout.to_dtype(DType::F32), num_kv, batch)
|
||||
.to_dtype(DType::BF16);
|
||||
}
|
||||
assert_eq!(dt, DType::F32, "repeat_kv_backward supports F32/BF16");
|
||||
let din = Tensor::zeros(&[batch * num_kv, seq, hd], DType::F32, dout.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_repeat_kv_bwd_f32(
|
||||
dout.data_ptr() as *const f32,
|
||||
din.data_ptr() as *mut f32,
|
||||
batch as i32,
|
||||
nh as i32,
|
||||
num_kv as i32,
|
||||
seq as i32,
|
||||
hd as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
din
|
||||
}
|
||||
|
||||
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
|
||||
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
|
||||
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).
|
||||
|
||||
@@ -174,7 +174,7 @@ fn config_json(cfg: &Config) -> String {
|
||||
ffn = cfg.ffn_hidden,
|
||||
layers = cfg.n_layers,
|
||||
heads = cfg.n_heads,
|
||||
kv_heads = cfg.n_heads, // xtrain is MHA → kv heads == query heads
|
||||
kv_heads = cfg.num_kv_heads, // GQA (T15): real num_key_value_heads (= n_heads for MHA)
|
||||
head_dim = cfg.head_dim,
|
||||
eps = cfg.eps,
|
||||
theta = cfg.rope_theta,
|
||||
@@ -206,6 +206,8 @@ fn main() {
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
// GQA (Phase T15): num K/V heads (must match the trained ckpt; default = --heads).
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
@@ -214,15 +216,16 @@ fn main() {
|
||||
// Size the model from the arch flags + gpt2 vocab; must match the checkpoint.
|
||||
let tok = Tokenizer::from_file(&tok_path);
|
||||
let vocab = tok.vocab_size();
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
println!(
|
||||
"export: ckpt {} → {} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
|
||||
"export: ckpt {} → {} (vocab {}, dim {}, layers {}, heads {}, kv_heads {}, head_dim {})",
|
||||
ckpt.display(),
|
||||
out_dir.display(),
|
||||
cfg.vocab,
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.num_kv_heads,
|
||||
cfg.head_dim,
|
||||
);
|
||||
|
||||
|
||||
@@ -72,6 +72,8 @@ fn main() {
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
// GQA (Phase T15): num K/V heads (must match the ckpt; default = --heads).
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let max_new = flag(&args, "--max-tokens", 40usize);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
@@ -79,14 +81,16 @@ fn main() {
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(&tok_path);
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn);
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
|
||||
.with_kv_heads(kv_heads);
|
||||
println!(
|
||||
"greedy_sample: ckpt {} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
|
||||
"greedy_sample: ckpt {} (vocab {}, dim {}, layers {}, heads {}, kv_heads {}, head_dim {})",
|
||||
ckpt.display(),
|
||||
cfg.vocab,
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.num_kv_heads,
|
||||
cfg.head_dim,
|
||||
);
|
||||
|
||||
|
||||
@@ -88,6 +88,8 @@ fn main() {
|
||||
let head_dim = flag(&args, "--head-dim", 16usize);
|
||||
let n_layers = flag(&args, "--layers", 4usize);
|
||||
let ffn = flag(&args, "--ffn", 64usize);
|
||||
// GQA (Phase T15): num K/V heads (must divide --heads). Default = --heads (MHA).
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
|
||||
let dim_flag = flag(&args, "--dim", 0usize);
|
||||
if dim_flag != 0 && dim_flag != n_heads * head_dim {
|
||||
@@ -160,14 +162,16 @@ fn main() {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let mut cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
let mut cfg =
|
||||
Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
cfg.dropout = dropout;
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
"model: dim {} layers {} heads {} kv_heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.num_kv_heads,
|
||||
cfg.head_dim,
|
||||
cfg.ffn_hidden,
|
||||
cfg.core_params() as f32 / 1e6,
|
||||
|
||||
84
csrc/ops/repeat_kv.cu
Normal file
84
csrc/ops/repeat_kv.cu
Normal file
@@ -0,0 +1,84 @@
|
||||
// repeat_kv: the grouped-query-attention (GQA) head broadcast (Phase T15).
|
||||
//
|
||||
// GQA projects K/V to fewer heads than Q (num_kv_heads < num_heads); each KV head
|
||||
// is SHARED by a group of `group = num_heads / num_kv_heads` query heads. Before
|
||||
// the SDPA (composed or fused-flash, both untouched), we expand the KV tensor from
|
||||
// [B·num_kv, S, hd] to the full [B·nh, S, hd] so the existing per-(batch,head)
|
||||
// attention sees a full set of heads. GQA is then "free" for both SDPA paths.
|
||||
//
|
||||
// Layout: K/V are [bh_kv, S, hd] = [B·num_kv, S, hd] row-major, contiguous; the
|
||||
// output is [bh_q, S, hd] = [B·nh, S, hd]. The head ordering matches xserv's
|
||||
// repeat_kv (crates/xserv-model/src/qwen3.rs): query head qh reads kv head
|
||||
// qh/group, query heads CONTIGUOUS within a group (dst = kvh*group + r). So:
|
||||
//
|
||||
// out[b·nh + qh, :, :] = in[b·num_kv + qh/group, :, :]
|
||||
//
|
||||
// Forward is a gather (each output row copies one input row). Backward is its
|
||||
// transpose: a kv head receives the SUM of the `group` query heads that share it
|
||||
// din[b·num_kv + kvh, e] = Σ_{r∈[0,group)} dout[b·nh + kvh·group + r, e]
|
||||
// — the multi-group-to-one grad accumulation GQA's correctness hinges on. Each
|
||||
// input element is owned by exactly one thread that serially sums its `group`
|
||||
// source rows: race-free, NO atomics, run-to-run DETERMINISTIC. group==1 makes
|
||||
// both directions a plain copy (identity → bit-identical to the MHA path).
|
||||
//
|
||||
// All F32, contiguous. (bf16 callers upcast → f32 on the Rust side and downcast
|
||||
// the f32 result, mirroring the rest of the attention stack's fp32 policy.)
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Forward gather. grid-stride over the bh_q·S·hd output elements; each output
|
||||
// element copies from its kv-head source row. b = (out_bh / nh), qh = out_bh % nh,
|
||||
// kv source bh = b·num_kv + qh/group.
|
||||
__global__ void repeat_kv_fwd_k(const float* in, float* out, int nh, int num_kv,
|
||||
int group, int seq, int hd) {
|
||||
long row_elems = (long)seq * hd; // S·hd per head block
|
||||
// One block per (batch, query-head); threads cover the S·hd block.
|
||||
int out_bh = blockIdx.x; // over B·nh
|
||||
int b = out_bh / nh;
|
||||
int qh = out_bh % nh;
|
||||
int kvh = qh / group;
|
||||
const float* src = in + ((long)b * num_kv + kvh) * row_elems;
|
||||
float* dst = out + (long)out_bh * row_elems;
|
||||
for (long e = threadIdx.x; e < row_elems; e += blockDim.x) dst[e] = src[e];
|
||||
}
|
||||
|
||||
void launch_repeat_kv_fwd_f32(const float* in, float* out, int batch, int nh,
|
||||
int num_kv, int seq, int hd, void* s) {
|
||||
int group = nh / num_kv;
|
||||
int blk = (seq * hd) < 256 ? (seq * hd) : 256;
|
||||
if (blk < 32) blk = 32;
|
||||
repeat_kv_fwd_k<<<batch * nh, blk, 0, (cudaStream_t)s>>>(in, out, nh, num_kv,
|
||||
group, seq, hd);
|
||||
}
|
||||
|
||||
// Backward sum. One block per (batch, kv-head); threads cover the S·hd block.
|
||||
// Each owned input element sums the `group` contiguous query-head source rows.
|
||||
__global__ void repeat_kv_bwd_k(const float* dout, float* din, int nh, int num_kv,
|
||||
int group, int seq, int hd) {
|
||||
long row_elems = (long)seq * hd;
|
||||
int in_bh = blockIdx.x; // over B·num_kv
|
||||
int b = in_bh / num_kv;
|
||||
int kvh = in_bh % num_kv;
|
||||
int qh0 = kvh * group; // first query head sharing this kv head
|
||||
float* dst = din + (long)in_bh * row_elems;
|
||||
const float* base = dout + ((long)b * nh + qh0) * row_elems;
|
||||
for (long e = threadIdx.x; e < row_elems; e += blockDim.x) {
|
||||
float acc = 0.0f;
|
||||
for (int r = 0; r < group; ++r) acc += base[(long)r * row_elems + e];
|
||||
dst[e] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
void launch_repeat_kv_bwd_f32(const float* dout, float* din, int batch, int nh,
|
||||
int num_kv, int seq, int hd, void* s) {
|
||||
int group = nh / num_kv;
|
||||
int blk = (seq * hd) < 256 ? (seq * hd) : 256;
|
||||
if (blk < 32) blk = 32;
|
||||
repeat_kv_bwd_k<<<batch * num_kv, blk, 0, (cudaStream_t)s>>>(dout, din, nh,
|
||||
num_kv, group,
|
||||
seq, hd);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
180
docs/14-gqa.md
Normal file
180
docs/14-gqa.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# Phase T15: Grouped-Query Attention (GQA) — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
到 T14 为止,xtrain 的 attention 都是 **MHA**(`num_kv_heads = num_heads`)——每个
|
||||
query 头有自己独立的 K/V 头。导出 xserv 时 `num_key_value_heads = num_attention_heads`
|
||||
(退化 GQA,docs/08)。
|
||||
|
||||
T15 做**真正的 grouped-query attention**:`num_kv_heads < num_heads`,K/V 只投影到
|
||||
`num_kv_heads · head_dim`,每个 KV 头被一组 `group = num_heads / num_kv_heads` 个 query
|
||||
头**共享**(repeat_kv / broadcast)。GQA 是现代 LLM(Llama-2-70B、Qwen2/3、Mistral)的标配
|
||||
——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 `group` 倍压缩,几乎不掉质量。
|
||||
|
||||
**硬闸门是诚实正确性**,重点在 **repeat_kv 的反向梯度累加**:一个 KV 头被 `group` 个 query 头
|
||||
共享,反向时这 `group` 个 query 头各自对该 KV 头的梯度必须**正确求和**回到那一个共享 KV 头上。
|
||||
这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。
|
||||
|
||||
GQA 必须**同时**接进 T14 的 fused flash kernel(优先)与旧 composed/batched SDPA 路径,且
|
||||
`num_kv_heads == num_heads`(`group = 1`)时与现有 MHA 路径**逐位一致**(回归保护)。
|
||||
|
||||
## 什么是 GQA
|
||||
|
||||
MHA:`num_heads` 个 query 头,每个配一个独立 K/V 头。
|
||||
MQA(multi-query):所有 query 头共享**一个** K/V 头(极端)。
|
||||
GQA:折中——`num_kv_heads` 个 K/V 头,每个被 `group = num_heads/num_kv_heads` 个相邻 query
|
||||
头共享。`num_kv_heads = num_heads` 退化为 MHA,`num_kv_heads = 1` 退化为 MQA。
|
||||
|
||||
```
|
||||
num_heads = 8, num_kv_heads = 2 → group = 4
|
||||
q heads: 0 1 2 3 4 5 6 7
|
||||
kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group(相邻分组,连续)
|
||||
```
|
||||
|
||||
**分组约定必须与 xserv repeat_kv 一致**(闭环命门):xserv 的 `repeat_kv`
|
||||
(`crates/xserv-model/src/qwen3.rs`)把 kv 头 `kvh` 复制到目标头 `dst = kvh*group + r`
|
||||
(`r∈[0,group)`),即**query 头 `qh` 读 kv 头 `qh/group`,组内 query 头连续**。xtrain 的
|
||||
repeat_kv 用同一映射,否则导出的 `q_proj` 行块与 kv 头对不上 → 闭环必崩。
|
||||
|
||||
## Module Layout(surgical:复用已验证的两条 SDPA,GQA = 头维 broadcast op)
|
||||
|
||||
```
|
||||
csrc/ops/repeat_kv.cu # 新:repeat_kv fwd(头块 gather)+ bwd(组内 group 行求和,无 atomic,确定性)
|
||||
crates/xtrain-cuda/
|
||||
├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明(no_cuda 门控)
|
||||
└── build.rs # +repeat_kv.cu
|
||||
crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward([B*kvh,S,hd]→[B*nh,S,hd];bf16 upcast→f32→downcast)
|
||||
crates/xtrain-autodiff/
|
||||
├── src/ops.rs # +ops::repeat_kv 节点(fwd broadcast,bwd 组内求和)
|
||||
└── tests/autograd.rs # +repeat_kv grad-check(含 group>1 的多组梯度累加)
|
||||
crates/xtrain-model/
|
||||
├── src/config.rs # +num_kv_heads 字段(默认 = n_heads → MHA);from_arch 加形参;num_params 计 K/V 投影按 kv_dim
|
||||
├── src/model.rs # wk/wv 投影到 kv_dim;attention() 在 SDPA 前对 K/V 做 ops::repeat_kv;两条路径都吃到 GQA
|
||||
└── tests/gqa.rs # 新:GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
|
||||
crates/xtrain-train/src/bin/train.rs # +--kv-heads flag
|
||||
crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flag(DDP 路径)
|
||||
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-heads;config.json 写真 num_key_value_heads
|
||||
crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQA(kv 投影 + repeat_kv)
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### ① GQA = K/V 头维 broadcast op,喂给**未改动**的两条 SDPA(不写第三套 attention)
|
||||
|
||||
T14 已经有两条**逐位/数值都验证过**的 SDPA:composed(`ops::attention`)与 fused flash
|
||||
(`ops::flash_attention`),二者都吃 `[bh, S, hd]`(`bh = batch·heads`)。GQA 的本质只是「K/V
|
||||
比 Q 少 `group` 倍头,用前把每个 kv 头复制 `group` 份」。所以**最外科**的做法:
|
||||
|
||||
- wk/wv 投影到 `kv_dim = num_kv_heads · head_dim`,按 `[B, num_kv, S, hd] → [B·num_kv, S, hd]`
|
||||
排好(和 Q 的 `[B·nh, S, hd]` 同流水线,只是头数不同)。
|
||||
- 在调 SDPA 之前,对 K、V 各做一个新 autograd op `ops::repeat_kv`,把 `[B·num_kv, S, hd]`
|
||||
**broadcast** 成 `[B·nh, S, hd]`(输出行 `b·nh + qh` = 输入行 `b·num_kv + qh/group` 的拷贝)。
|
||||
- 之后 `ops::attention` / `ops::flash_attention` **一字不改**——它们看到的就是满头的
|
||||
`[B·nh, S, hd]`,GQA 对两条路径**同时、免费**生效。flash kernel / composed kernel 都不用碰。
|
||||
|
||||
**为什么不在 kernel 内做 GQA**:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed
|
||||
的两次 strided GEMM 各算 GQA stride,且两套都要重测——是「第三套 attention 改动」。而 broadcast-op
|
||||
方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check
|
||||
的小 op,正确性风险隔离在一处;(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的**反向**,
|
||||
干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 `[B·nh,S,hd]`(多 `group` 倍)——本规模
|
||||
(训练、seq 不极端)可接受;真要省这份显存是 follow-up(kernel 内 GQA 读取),记进逃生舱不在 T15 做。
|
||||
|
||||
> 备选(不采纳):flash/composed kernel 内直接按 `kv_head = q_head/group` 索引 K/V。省 broadcast
|
||||
> 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。
|
||||
> escape hatch:先 broadcast-op 把正确性 + 闭环钉死,kernel-内 GQA(省显存)留 follow-up。
|
||||
|
||||
### ② repeat_kv 的反向 = 组内求和(确定性,无 atomic)
|
||||
|
||||
`repeat_kv` 前向:`out[b·nh + qh] = in[b·num_kv + qh/group]`(按 `S·hd` 整行拷贝)。
|
||||
|
||||
反向是它的**转置**:一个 kv 头收到它那 `group` 个 query 头的梯度之**和**:
|
||||
```
|
||||
din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]
|
||||
```
|
||||
这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上**不用 atomicAdd**:每个输入
|
||||
(kv-head, 元素)由唯一一个 block 负责,它**串行累加自己那 group 个连续源行**——天然 race-free
|
||||
且**run-to-run 确定**(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。`group=1` 时
|
||||
反向退化为单行拷贝(identity)。
|
||||
|
||||
autograd 层面其实也可以靠引擎的扇出 SUM(把一个 kv Var 喂给 group 个下游),但那样图里要
|
||||
显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的
|
||||
broadcast op,fwd/bwd 各一发 kernel,最简且能单独 grad-check。
|
||||
|
||||
### ③ `num_kv_heads` 进 Config(它改模型尺寸/导出),默认 = n_heads → 退化 MHA
|
||||
|
||||
不同于 T14 的 `use_flash`(运行时旗标,不进 Config),`num_kv_heads` **改 K/V 投影的形状、改参数量、
|
||||
改导出的 `num_key_value_heads`**——它是**架构**的一部分,必须进 `Config` 并落进 checkpoint/导出。
|
||||
|
||||
- `Config` 加 `num_kv_heads: usize`;`from_arch` 加该形参;`Config::tiny()` 默认 `num_kv_heads =
|
||||
n_heads`(MHA)。约束:`num_heads % num_kv_heads == 0`(断言)。
|
||||
- `num_params()`:K/V 投影从 `2·dim·dim` 改成 `2·dim·(num_kv_heads·head_dim)`;QK-norm 的
|
||||
`k_norm` 仍是 `[head_dim]`(per-head,作用在单个 head 向量上,与头数无关)→ 不变。
|
||||
- **`num_kv_heads == n_heads` 时 `group=1`**:`ops::repeat_kv` 是 identity(fwd 单行拷贝、bwd 单行
|
||||
拷贝),wk/wv 形状回到 `[dim,dim]` → 整条图与 T14 的 MHA 路径**逐位一致**(回归保护闸门)。
|
||||
|
||||
> wk/wv 形状从 `[dim,dim]` 变成 `[dim, kv_dim]`:`Block` 里 wk/wv 的 `mk(&[dim, kv_dim])`,
|
||||
> `params()`/`block_params()` 顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是
|
||||
> wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(`transpose` 读 `v.value().shape()`)。
|
||||
|
||||
### ④ bf16 / recompute / dropout / DDP 全部自动兼容
|
||||
|
||||
- **bf16**:`Tensor::repeat_kv` 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel →
|
||||
downcast;kernel 只一份 f32。`ops::repeat_kv` 的 fwd/bwd 都在 SDPA 之前/之后,dtype 与 K/V 流一致。
|
||||
- **recompute(T13)**:repeat_kv 在 `block_forward` 内、`attention()` 里,重算段重跑 `attention()`
|
||||
自然重跑 repeat_kv(无额外状态,确定性)→ 梯度仍逐位一致。
|
||||
- **dropout(T18)**:dropout 接在 attn/mlp 子块**输出**,与 attention 内部的 repeat_kv 正交,不交互。
|
||||
- **DDP**:repeat_kv 不引入新参数;wk/wv 变小(kv_dim)只是参数张量小一圈,`params()` 泛化迭代
|
||||
+ all-reduce 照旧;跨 rank 一致性不受影响。
|
||||
|
||||
### ⑤ 导出 xserv:写真 `num_key_value_heads`,分组约定对齐 repeat_kv
|
||||
|
||||
`export_safetensors.rs` 的 `config.json` 把 `num_key_value_heads` 从「= num_attention_heads」改成
|
||||
**真 `cfg.num_kv_heads`**;`--kv-heads` flag 传入(须与训练 ckpt 一致)。q/k/v_proj 各自按其 shape
|
||||
转置导出(k/v_proj 现在是 `[kv_dim, dim]`,xserv loader 期望的 GQA 形状)。xserv 的 `repeat_kv`
|
||||
用 `dst = kvh·group + r` 分组,与 ① 的 xtrain 约定**逐头对齐** → 同一份权重在两侧前向数学一致,
|
||||
闭环(贪心逐 token 一致)成立。
|
||||
|
||||
## 验证方法
|
||||
|
||||
全部 `#![cfg(not(no_cuda))]` 门控,本地 `cargo check`/`fmt`,构建+实跑在 dash5(8× RTX 5090)。
|
||||
|
||||
### 1. 正确性(硬闸门全绿,dash5 实跑 capture)
|
||||
|
||||
- **repeat_kv finite-diff grad-check**(`autograd.rs::repeat_kv_grad`):**核心闸门**——`group>1`
|
||||
(如 bh: 2 kv 头 → 6 q 头)下 grad-check `din`,验证「多组 q 头梯度求和到一个 kv 头」的反向。
|
||||
外加 `group=1` identity 自检。
|
||||
- **GQA flash==composed**(`gqa.rs`):真 GQA 配置(`num_kv_heads < n_heads`,如 8 头/2 kv 头)下,
|
||||
flash on/off 两个同 init 模型,断 forward logits / loss / **每参数梯度**一致(fp32 紧容差 + bf16
|
||||
舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。
|
||||
- **group=1 与 MHA 逐位一致**(`gqa.rs`):`num_kv_heads = n_heads` 的模型对 T14 的 MHA 模型,
|
||||
forward + 每参数梯度 `|Δ|=0`(回归保护)。
|
||||
- **PyTorch GQA 对拍 B>1**(`parity_dump.rs` + `parity.py`):等价 PyTorch 模型加 GQA(k/v 投影到
|
||||
kv_dim + `repeat_interleave(group)` 分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部
|
||||
参数梯度(composed 与 flash 两条都跑,共用同一 oracle)。
|
||||
- **小 GQA 配置短训收敛**:一个真 GQA 小模型短训,loss 单调降、无 NaN、采样连贯。
|
||||
- **全回归套开/关**:autograd / structural / batched==looped / bf16 / recompute(逐位)/ overfit 27/27 /
|
||||
AdamW(GPU bit-exact + host 对 torch)/ DDP loss-match + 跨 rank(`--test-threads=1`)/ flash /
|
||||
grad_accum / dropout / **xserv 闭环 md5**。MHA 默认(kv=heads)图不变 → 不回归。
|
||||
|
||||
### 2. 闭环(payoff)—— 真 GQA 端到端
|
||||
|
||||
导出一个 `num_key_value_heads < num_attention_heads` 的 GQA checkpoint → xserv 加载 → 贪心生成
|
||||
**对 xtrain 自身逐 token 一致**(BF16 推理 vs f32 训练,与 v1–v8 同款判据)。这是 GQA 真正落地的证明:
|
||||
训练侧的分组、导出的分组、推理侧 xserv 的 repeat_kv 分组三方对齐。
|
||||
|
||||
## 实测结果(dash5 1× / 2× RTX 5090)
|
||||
|
||||
**硬闸门全绿:**
|
||||
|
||||
| 闸门 | 结果 |
|
||||
|---|---|
|
||||
| ① repeat_kv grad-check(**多组 q 头梯度求和到一个 kv 头**,group=3) | **过** — din max_rel **2.05e-4**;group=1 identity 双向**逐位**(fwd/bwd |Δ|=0) |
|
||||
| GQA flash==composed(model 级 8h/2kv,logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 3.0e-4、grad **4.1e-5**;bf16: loss 9.0e-5、logits mean 2.9e-3/p99 1.0e-2、grad scaled-mean 8.9e-3 |
|
||||
| group=1 对 MHA**逐位一致**(回归保护) | **过** — logits + loss + 全部梯度 |Δ|=0 |
|
||||
| ② PyTorch GQA 对拍 B>1(composed & flash,repeat_interleave 分组对齐) | composed: loss **1.74e-8**/logits 2.04e-5/25 grad 进 rtol;flash: loss 1.74e-8/logits 2.28e-5/25 grad 进 rtol |
|
||||
| ③ 小 GQA 配置短训收敛(8h/2kv/hd32/4L/ffn1024,600 步) | train **10.90→3.15** 无 NaN、gnorm 稳 ~1.2、采样连贯英文(~200K tok/s) |
|
||||
| ④ **xserv 闭环真 GQA**(导出 `num_key_value_heads=2 < num_attention_heads=8`,xserv 加载 `heads=8/2 kv`,贪心) | "One day"/"The little" 两 prompt **逐 token 一致**;"Once upon a time" 在 `...Lily's mommy ` 处 BF16 漂移晚分叉(said vs came)——与 v1/v2/v3/T14 同款判据 |
|
||||
| ⑤ 回归套:autograd 23(含 repeat_kv 2)/ structural 5 / batched / bf16 / flash 2 / **gqa 4** / overfit 27/27 / recompute 2 / dropout 6 / grad_accum 3 / checkpoint-roundtrip / AdamW(host 对 torch 4.8e-6) / DDP 3(`--test-threads=1`, loss 5.67e-7+跨 rank 一致) / GEMM / tensor | **全绿** |
|
||||
| ⑤ MHA 默认 export md5(v3 ckpt 用 T15 代码重导 safetensors) | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry/T14 同)→ 默认(kv=heads)export 零漂移 |
|
||||
|
||||
> **诚实记录**:闭环 2/3 prompt 完全 token-identical、1/3 在 BF16 漂移点晚分叉——这恰证明 GQA 分组**正确**:若 kv→q 头映射错,attention 会从第一个生成 token 起就崩(不会是深处近-tie 的晚分叉)。GQA 把 K/V 在显存里物化成满头 `[B·nh,S,hd]`(broadcast-op 方案的代价)——本规模可接受,kernel-内 GQA(省这份显存)留 follow-up。未为凑绿放宽任何容差。
|
||||
@@ -25,6 +25,7 @@
|
||||
| T12 | 算法/Infra | **bf16 混合精度**(fp32 master,cuBLAS GemmEx,norm/softmax/CE 保 fp32) | dim768 OOM 解除,−29% 显存/+13% tok/s(修 KI-2) |
|
||||
| T13 | 算法/Infra | **激活重计算**(per-block gradient checkpointing:前向 no-tape + 反向重算,`backward_seeded`) | 梯度对非重计算版**逐位一致**(0.00);dim768 31.1→14.6GB;**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3,解锁 v8) |
|
||||
| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernel:online softmax、tiled over KV、**不物化 N×N scores**;flash 式 bwd:重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dV);opt-in `--flash`,默认保 composed(Phase 2) | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0;**峰值显存 −16%@seq1024 / −23%@seq2048**(不物化 N×N,收益随 seq 增长);tok/s ~2.3–2.8× 慢(hd=64 小头维干不过 cuBLAS tensor-core,flash 已知权衡=胜场在显存);md5 闭环逐位一致 |
|
||||
| T15 | 模型架构 | **真 GQA**(`num_kv_heads<num_heads`:wk/wv 投影到 `kv_dim`,新 `repeat_kv` broadcast 算子把 K/V 复制 `group=nh/num_kv` 份喂给**未改动**的 composed/flash 两条 SDPA;分组约定对齐 xserv repeat_kv `dst=kvh·group+r`);`repeat_kv` 反向=组内 group 行**确定性求和**(无 atomic)→ 多组 q 头梯度汇一个 kv 头;`num_kv_heads` 进 Config(默认=nh→MHA)、`--kv-heads` flag、导出写真 `num_key_value_heads`(Phase 2) | repeat_kv grad-check 2.1e-4(group3)+group1 identity 逐位;GQA flash==composed fp32 grad 4.1e-5/bf16 在带;**group1 对 MHA 逐位一致**(回归保护);PyTorch GQA B>1 对拍 composed/flash 各 loss 1.7e-8/logits 2.3e-5/25 grad 进 rtol;小 GQA(8h/2kv) 训 600 步 10.9→3.15 连贯;**xserv 闭环真 GQA**(num_kv 2<8):2/3 prompt token-identical、1 在 BF16 漂移处晚分叉;MHA 默认 export md5 逐位一致(b04fc9f9) |
|
||||
| T16 | 算法/Infra | **梯度累积**(N 个 micro-step:每个 micro-loss `×1/N` 再 backward,tape SUM 累加 → 一次 AdamW step+zero;`--accum-steps`);**DDP 只在累积边界 all-reduce**(中间 micro-step 不发 NCCL,`/world` 与 `1/N` 正交);显存随 micro 不随有效 batch | 等效大 batch**逐位贴合**(loss rel 8.5e-8、grad rel 3.8e-5);`accum=1` 逐位回归(0.00);DDP+accum 对单卡 loss 5.7e-7/跨 rank 一致;**显存平**:同有效 batch 64,big-batch 27.7GB→accum(4×16) **7.2GB(−74%)**(big-batch OOM 而 accum 装下);全回归+xserv 闭环 md5 一致 |
|
||||
| T18 | 算法 | **dropout**(手写 counter-based 设备 RNG → Bernoulli mask,训练 inverted 1/(1-p) scaling、eval 恒等);新 autodiff `dropout` 算子(fwd 生成+施加 mask,bwd 用同 mask),接 residual/ffn 两处;`--dropout` flag 默认 0 | 固定 seed grad-check 过;E[out]≈input + keep≈1-p;**p=0 与无 dropout 逐位一致**;recompute(T13) 组合下梯度仍逐位一致(counter-based seed 重算复现同 mask);全回归 + xserv 闭环绿(导出/推理 dropout 关) |
|
||||
|
||||
@@ -53,7 +54,7 @@
|
||||
## 三、各维度的累积演进(轴向看一条线怎么走的)
|
||||
|
||||
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14,online softmax + flash 式 bwd) → 梯度累积(T16,复用 tape SUM,等效大 batch 而显存随 micro) → dropout(T18,counter-based 设备 RNG + inverted scaling,train/eval 切换)。
|
||||
- **模型架构**:固定 Qwen3-style;dim **32→256→384→512→768→1024**(v8 首拨容量轴,头数 24→32);核心参数 **41K→226M**(总 3.26M→329M)。
|
||||
- **模型架构**:固定 Qwen3-style;dim **32→256→384→512→768→1024**(v8 首拨容量轴,头数 24→32);核心参数 **41K→226M**(总 3.26M→329M)。+QK-norm(T9,Qwen3 兼容) → **真 GQA(T15,`num_kv_heads<num_heads`,repeat_kv broadcast + 组内梯度求和;默认=nh→MHA 逐位回归)**——架构补齐到现代 LLM 标配(MHA/GQA/MQA 一条 `num_kv_heads` 轴),两条 SDPA(composed/flash) 共用同一 broadcast,导出真 `num_key_value_heads` 且 xserv 闭环。
|
||||
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024) → flash-attention(T14,不物化 N×N,attention 显存收益随 seq 增长) → 梯度累积(T16,DDP 只在累积边界通信,显存随 micro 不随有效 batch)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。T13/T14/T16 是三条**显存杠杆**(重计算压激活峰值、flash 不物化 N×N attention scores、梯度累积解耦有效 batch 与激活显存),可叠加放大有效 batch。
|
||||
- **数据集**:TinyStories 3MB 切片 → 全量 TinyStories(epoch 0.01→5.33,**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**(2.255B 语料,1.02ep)→ **v7 同子集多 epoch(1.45ep,近顶)→ v8 同子集换大模型**(dim1024,1.05ep)。tokenizer 全程 gpt2 BPE(复用 xserv-tokenizer;v6 刻意不换 tokenizer 以隔离「数据来源」变量,KI-4 留后续版本)。
|
||||
- **v5→v6 数据轴的质变**:v0–v5 都吃合成幼儿故事(TinyStories,低熵、词汇受控),v5 证明同尺寸模型在它上面已饱和;v6 第一版换成**真实教育类网页文本**(FineWeb-edu),语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。
|
||||
|
||||
Reference in New Issue
Block a user