Commit Graph

4 Commits

Author SHA1 Message Date
5f3b81ac96 test+bins: flash grad-check, flash==composed, PyTorch parity, --flash flag
autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile)
+ flash_matches_composed_fwd. model/tests/flash.rs: flash==composed
on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump:
XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle
(PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:10:39 +08:00
5353b38402 model: batched forward [B,S]
forward_batched(ids[B*S], batch)/loss_batched: run B equal-length sequences as
ONE forward over flattened [B*S] ids, so every linear is one big [B*S,dim] GEMM.
Attention reshapes to [B*nh,S,hd], runs the fused batched causal SDPA (per-seq
mask + RoPE period=S, no cross-sequence attention), writes back [B*S,dim]. The
old per-(batch,head) loop + host-round-tripping split/merge_heads + the additive
causal_mask leaf are gone. forward(ids[seq]) is now forward_batched(ids,1), so
the sampler / inference path (batch=1) is unchanged.

+batched_ids_tensor helper. New batched.rs test: batched forward == looped
single-sequence (logits identical 0.0, grads 6.4e-4, loss identical). PyTorch
parity now exercises B>1 (B=2,S=4): loss 5e-8, logits 6.9e-6, all 25 param
grads within rtol — verifying per-seq RoPE position + per-seq causal masking.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:25 +08:00
7a4f69e430 model: add per-head QK-norm (Qwen3-compat) for xserv export
xserv's Qwen3 forward unconditionally applies per-head RMSNorm to Q and K
(q_norm/k_norm, shape [head_dim]) before RoPE — even gamma=1 is a real RMS
divide, not identity. xtrain never had this, so an exact xserv<->xtrain loop
was structurally impossible. Add it (reusing the 2D rms_norm op on the
[seq*nh, hd] head rows, inserted between reshape and rope to mirror
qwen3.rs's order) so the trained model is genuinely Qwen3-compatible.

params() inserts q_norm,k_norm after wv; num_params() counts them; the
PyTorch parity refs (parity.py / adamw_parity.py) + their name lists add the
same step so the dumps stay self-consistent.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 17:33:19 +08:00
3366f30c4d model: PyTorch parity harness (weight dump + equivalent torch model)
parity_dump.rs (#[ignore] fixture generator) dumps the model's exact
weights, ids, forward logits, loss, and per-param grads after one
backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W
convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA),
runs fwd+bwd, and compares logits + every grad within rtol.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:07:30 +08:00