Gahow Wang 7821bd9c34
autograd: batch dim for ops (flatten linears, batched attention)

Add the batched-forward primitives. Linears/norms/elementwise/embedding/CE
already act on flat [rows,dim], so they work unchanged on [B*S,dim]; only
attention + RoPE need sequence awareness:

- RoPE: kernel takes a `period` (= seq len) so position = row % period, i.e.
  the per-sequence position on a flattened batch (period == total tokens gives
  the single-sequence case).
- Fused batched causal attention: new `Tensor::attention`/`attention_backward`
  + ops node. QKᵀ and PV run as cublasSgemmStridedBatched over the B*nh
  (sequence,head) blocks (new sgemm_strided_batched binding), with a causal
  softmax kernel (scale + per-row causal mask inline) between them. The whole
  attention is 3 launches regardless of B*nh: no per-head/per-seq loop, no
  host round-trip.
- transpose_4d12 ([B,S,nh,hd] <-> [B,nh,S,hd]) to lay out the batched heads.
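For reference, a minimal CPU sketch of the `period` indexing on a flattened [B*S, dim] buffer, in C++ to match csrc. The pairing scheme (interleaved even/odd elements), the frequency formula with base 10000 and a single head spanning `dim` are assumptions, not taken from the kernel, and `rope_forward` is a hypothetical name. What it is meant to show is only that row r is rotated by position r % period, so each of the B sequences restarts at position 0.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for RoPE on a flattened batch x[rows, dim].
// Assumptions (not from the commit): interleaved (even, odd) pairs,
// theta_i = base^(-2i/dim) with base = 10000, one head spanning dim.
// Row r uses position r % period: with period = S every sequence in a
// [B*S, dim] buffer restarts at 0; with period = rows it reduces to the
// single-sequence case.
void rope_forward(std::vector<float>& x, std::size_t rows, std::size_t dim,
                  std::size_t period, float base = 10000.0f) {
    assert(dim % 2 == 0 && period > 0 && x.size() == rows * dim);
    for (std::size_t r = 0; r < rows; ++r) {
        const float pos = static_cast<float>(r % period);  // per-sequence position
        for (std::size_t i = 0; i < dim / 2; ++i) {
            const float theta = std::pow(base, -2.0f * i / dim);
            const float c = std::cos(pos * theta);
            const float s = std::sin(pos * theta);
            float& a = x[r * dim + 2 * i];
            float& b = x[r * dim + 2 * i + 1];
            const float a0 = a, b0 = b;
            // 2D rotation of the pair (a, b) by angle pos * theta.
            a = a0 * c - b0 * s;
            b = a0 * s + b0 * c;
        }
    }
}
```

With B = 2, S = 3 and period = 3, rows 0 and 3 (and rows 1 and 4) receive identical rotations; calling it with period = 6 instead gives every row its own index as position, i.e. the single-sequence behaviour.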

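A CPU reference for what the fused path computes can help when checking the kernels. This sketch assumes contiguous row-major storage and the usual 1/sqrt(hd) scale (the commit says only "scale"); `transpose_4d12` here is a hypothetical free function that merely shares the op's name, and `causal_attention` is hypothetical too. After the transpose each (sequence, head) block is a contiguous S×hd slab, so block `blk` starts at `blk * S * hd`, the fixed stride a strided-batched GEMM would step by with batchCount = B*nh (S×S score blocks stored contiguously would likewise sit S*S apart).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper: swap dims 1 and 2 of x[B, D1, D2, D3] -> y[B, D2, D1, D3].
// With (D1, D2) = (S, nh) this maps [B,S,nh,hd] -> [B,nh,S,hd];
// calling it again with (nh, S) restores the original layout.
std::vector<float> transpose_4d12(const std::vector<float>& x, std::size_t B,
                                  std::size_t D1, std::size_t D2, std::size_t D3) {
    std::vector<float> y(x.size());
    for (std::size_t b = 0; b < B; ++b)
        for (std::size_t i = 0; i < D1; ++i)
            for (std::size_t j = 0; j < D2; ++j)
                for (std::size_t d = 0; d < D3; ++d)
                    y[((b * D2 + j) * D1 + i) * D3 + d] = x[((b * D1 + i) * D2 + j) * D3 + d];
    return y;
}

// Hypothetical CPU reference for batched causal attention on [B*nh, S, hd]
// inputs: softmax(scale * Q K^T, causal) V computed per (sequence, head) block.
std::vector<float> causal_attention(const std::vector<float>& q, const std::vector<float>& k,
                                    const std::vector<float>& v, std::size_t BH,
                                    std::size_t S, std::size_t hd) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(hd));  // assumed scale
    std::vector<float> out(BH * S * hd, 0.0f), p(S);
    for (std::size_t blk = 0; blk < BH; ++blk) {
        const std::size_t off = blk * S * hd;  // fixed stride between blocks
        for (std::size_t i = 0; i < S; ++i) {
            // Scores for query row i; keys j > i are masked by never computing them.
            float mx = -std::numeric_limits<float>::infinity();
            for (std::size_t j = 0; j <= i; ++j) {
                float dot = 0.0f;
                for (std::size_t d = 0; d < hd; ++d)
                    dot += q[off + i * hd + d] * k[off + j * hd + d];
                p[j] = dot * scale;
                if (p[j] > mx) mx = p[j];
            }
            // Numerically stable softmax over the unmasked keys.
            float sum = 0.0f;
            for (std::size_t j = 0; j <= i; ++j) { p[j] = std::exp(p[j] - mx); sum += p[j]; }
            // Weighted sum of value rows (the PV product for row i).
            for (std::size_t j = 0; j <= i; ++j)
                for (std::size_t d = 0; d < hd; ++d)
                    out[off + i * hd + d] += (p[j] / sum) * v[off + j * hd + d];
        }
    }
    return out;
}
```

Skipping j > i is equivalent to writing -inf into those scores before the softmax. A quick sanity check is that row 0 of every block must equal that block's first value row, since it can only attend to itself.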
grad-checks: the new batched-RoPE, transpose_4d12 and batched-attention
dQ/dK/dV checks all pass finite-diff (attn dK 1.5e-2, dQ 7.5e-3, dV 2.9e-4;
the rest are tighter), alongside the existing 12.
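Finite-diff checks of this kind usually follow the central-difference recipe; a minimal version is sketched below. The commit does not say which error metric its numbers are, so the max absolute difference used here is an assumption, as are the name `grad_check`, the double precision and the default eps.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical finite-difference gradient check using central differences:
// numeric_i = (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps).
// Returns the max absolute difference from the analytic gradient.
double grad_check(const std::function<double(const std::vector<double>&)>& f,
                  const std::vector<double>& x, const std::vector<double>& analytic,
                  double eps = 1e-4) {
    assert(x.size() == analytic.size());
    std::vector<double> xp = x;
    double max_err = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        xp[i] = x[i] + eps;
        const double fp = f(xp);
        xp[i] = x[i] - eps;
        const double fm = f(xp);
        xp[i] = x[i];  // restore before perturbing the next coordinate
        const double numeric = (fp - fm) / (2.0 * eps);
        max_err = std::max(max_err, std::fabs(numeric - analytic[i]));
    }
    return max_err;
}
```

For f(x) = Σ x², passing the analytic gradient 2x gives an error near zero, while a wrong gradient entry shows up directly as its offset from the numeric estimate.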

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:15 +08:00