Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> # Conflicts: # README.md # crates/xtrain-autodiff/tests/autograd.rs # crates/xtrain-model/src/model.rs # crates/xtrain-train/src/bin/train.rs # crates/xtrain-train/src/train_loop.rs # docs/evolution.md
519 lines
22 KiB
Rust
519 lines
22 KiB
Rust
//! The tiny transformer forward graph + parameter container (Phase T5).
|
||
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use std::cell::Cell;
|
||
|
||
use crate::config::Config;
|
||
use xtrain_autodiff::ops;
|
||
use xtrain_autodiff::tape::Var;
|
||
use xtrain_tensor::{DType, Device, Tensor};
|
||
|
||
/// One decoder block's learnable tensors.
|
||
struct Block {
|
||
attn_norm: Var, // [dim]
|
||
wq: Var, // [dim, dim]
|
||
wk: Var, // [dim, dim]
|
||
wv: Var, // [dim, dim]
|
||
q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style)
|
||
k_norm: Var, // [head_dim]
|
||
wo: Var, // [dim, dim]
|
||
ffn_norm: Var, // [dim]
|
||
w_gate: Var, // [dim, ffn_hidden]
|
||
w_up: Var, // [dim, ffn_hidden]
|
||
w_down: Var, // [ffn_hidden, dim]
|
||
}
|
||
|
||
/// A tiny RoPE+RMSNorm+SwiGLU decoder. Holds every parameter as a leaf [`Var`];
|
||
/// `forward` builds an autograd graph over them.
|
||
pub struct TinyTransformer {
|
||
cfg: Config,
|
||
embed: Var, // [vocab, dim]
|
||
blocks: Vec<Block>,
|
||
final_norm: Var, // [dim]
|
||
lm_head: Var, // [dim, vocab]
|
||
/// Compute dtype for the forward graph (Phase T12). `F32` (default) = the
|
||
/// original path, bit-identical to T10/T11. `BF16` = mixed precision: the
|
||
/// parameter leaves stay fp32 (master), but each linear's weight is cast to
|
||
/// bf16 on the fly and the activation stream flows bf16 (see
|
||
/// `docs/11-bf16-mixed-precision.md`). The cast op's backward upcasts the bf16
|
||
/// weight grad back to fp32, so AdamW/clip/DDP stay fp32 and unchanged.
|
||
compute_dtype: DType,
|
||
/// Activation recomputation / gradient checkpointing (Phase T13, KI-3). When
|
||
/// `true`, each transformer block's forward runs through
|
||
/// [`xtrain_autodiff::checkpoint`]: the block's internal activations are NOT
|
||
/// kept on the tape during forward (only the block input is), and the block
|
||
/// forward is re-run during backward to recover them. Trades ~one extra forward
|
||
/// per block for a large drop in peak activation memory → lets dim1024 batch32
|
||
/// fit. Default `false` = the unchanged path (every activation stored), so
|
||
/// existing numerics are bit-identical; recompute is mathematically exact, so
|
||
/// grads match the non-checkpointed path within fp tolerance.
|
||
recompute: bool,
|
||
/// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through
|
||
/// the hand-written single fused kernel ([`ops::flash_attention`]): online
|
||
/// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized,
|
||
/// backward caches only the O(N) logsumexp. Default `false` = the composed T10
|
||
/// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs),
|
||
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
|
||
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
|
||
use_flash: bool,
|
||
/// Training mode for dropout (Phase T18). `true` → the attn/MLP sub-block
|
||
/// outputs pass through `ops::dropout` (with `cfg.dropout` and a per-step,
|
||
/// per-site seed); `false` (default) → dropout is identity (eval/sampling/
|
||
/// export). `Cell` so `train()`/`eval()` flip it through `&self` (the forward
|
||
/// takes `&self`). When `cfg.dropout == 0` this flag is irrelevant — the graph
|
||
/// is bit-identical to the no-dropout path either way.
|
||
training: Cell<bool>,
|
||
/// Per-step dropout RNG seed (Phase T18). Bumped once at the start of each
|
||
/// TRAINING forward so every step draws fresh masks; combined with the layer
|
||
/// index + a per-site constant to give each dropout site its own seed. The RNG
|
||
/// is counter-based, so re-running a checkpointed block's forward in backward
|
||
/// (T13) reproduces the same seed → the same mask (recompute stays exact).
|
||
step_seed: Cell<u64>,
|
||
}
|
||
|
||
impl TinyTransformer {
|
||
/// Build a model with parameters initialised from `init(shape) -> host data`.
|
||
/// The caller controls initialisation (deterministic for tests / PyTorch
|
||
/// parity). `init` receives the logical shape and returns row-major data.
|
||
pub fn new(cfg: Config, device: Device, mut init: impl FnMut(&[usize]) -> Vec<f32>) -> Self {
|
||
let leaf = |data: Vec<f32>, shape: &[usize]| -> Var {
|
||
Var::leaf(Tensor::from_slice(&data, shape).to_device(device))
|
||
};
|
||
let mut mk = |shape: &[usize]| -> Var {
|
||
let data = init(shape);
|
||
assert_eq!(data.len(), shape.iter().product::<usize>(), "init size");
|
||
leaf(data, shape)
|
||
};
|
||
|
||
let embed = mk(&[cfg.vocab, cfg.dim]);
|
||
let blocks = (0..cfg.n_layers)
|
||
.map(|_| Block {
|
||
attn_norm: mk(&[cfg.dim]),
|
||
wq: mk(&[cfg.dim, cfg.dim]),
|
||
wk: mk(&[cfg.dim, cfg.dim]),
|
||
wv: mk(&[cfg.dim, cfg.dim]),
|
||
q_norm: mk(&[cfg.head_dim]),
|
||
k_norm: mk(&[cfg.head_dim]),
|
||
wo: mk(&[cfg.dim, cfg.dim]),
|
||
ffn_norm: mk(&[cfg.dim]),
|
||
w_gate: mk(&[cfg.dim, cfg.ffn_hidden]),
|
||
w_up: mk(&[cfg.dim, cfg.ffn_hidden]),
|
||
w_down: mk(&[cfg.ffn_hidden, cfg.dim]),
|
||
})
|
||
.collect();
|
||
let final_norm = mk(&[cfg.dim]);
|
||
let lm_head = mk(&[cfg.dim, cfg.vocab]);
|
||
|
||
Self {
|
||
cfg,
|
||
embed,
|
||
blocks,
|
||
final_norm,
|
||
lm_head,
|
||
compute_dtype: DType::F32,
|
||
recompute: false,
|
||
use_flash: false,
|
||
training: Cell::new(false),
|
||
step_seed: Cell::new(0),
|
||
}
|
||
}
|
||
|
||
pub fn config(&self) -> &Config {
|
||
&self.cfg
|
||
}
|
||
|
||
/// Set the forward compute dtype (Phase T12). `BF16` enables mixed precision
|
||
/// (fp32 master weights, bf16 linears + activations); `F32` (the default) is
|
||
/// the unchanged full-precision path. Builder-style so existing call sites
|
||
/// that don't opt in keep the fp32 numerics bit-for-bit.
|
||
pub fn with_compute_dtype(mut self, dtype: DType) -> Self {
|
||
assert!(
|
||
matches!(dtype, DType::F32 | DType::BF16),
|
||
"compute_dtype must be F32 or BF16"
|
||
);
|
||
self.compute_dtype = dtype;
|
||
self
|
||
}
|
||
|
||
pub fn compute_dtype(&self) -> DType {
|
||
self.compute_dtype
|
||
}
|
||
|
||
/// Enable per-block activation recomputation / gradient checkpointing (Phase
|
||
/// T13). Builder-style and opt-in; default off keeps the unchanged tape (every
|
||
/// activation stored). On, each block's forward is wrapped in
|
||
/// [`xtrain_autodiff::checkpoint`] — exact grads, lower peak activation memory.
|
||
pub fn with_recompute(mut self, recompute: bool) -> Self {
|
||
self.recompute = recompute;
|
||
self
|
||
}
|
||
|
||
pub fn recompute(&self) -> bool {
|
||
self.recompute
|
||
}
|
||
|
||
/// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and
|
||
/// opt-in; default off keeps the composed T10 path (so the default graph is
|
||
/// unchanged). On, the SDPA runs through [`ops::flash_attention`] — same SDPA
|
||
/// math, online softmax, no materialized `[bh,seq,seq]` scores.
|
||
pub fn with_flash(mut self, use_flash: bool) -> Self {
|
||
self.use_flash = use_flash;
|
||
self
|
||
}
|
||
|
||
pub fn use_flash(&self) -> bool {
|
||
self.use_flash
|
||
}
|
||
|
||
/// Switch to training mode (Phase T18): dropout (if `cfg.dropout > 0`) is
|
||
/// active in subsequent forwards. The training loop calls this before stepping.
|
||
pub fn train(&self) {
|
||
self.training.set(true);
|
||
}
|
||
|
||
/// Switch to eval mode (Phase T18): dropout is identity. Held-out eval,
|
||
/// autoregressive sampling, and weight export all run in this mode (default).
|
||
pub fn eval(&self) {
|
||
self.training.set(false);
|
||
}
|
||
|
||
pub fn is_training(&self) -> bool {
|
||
self.training.get()
|
||
}
|
||
|
||
/// Builder-style train/eval toggle (Phase T18) — handy for tests that want a
|
||
/// model fixed in one mode. Equivalent to [`train`](Self::train) /
|
||
/// [`eval`](Self::eval) but chains off `new(..)`.
|
||
pub fn with_training(self, training: bool) -> Self {
|
||
self.training.set(training);
|
||
self
|
||
}
|
||
|
||
/// All learnable parameters, in a stable order. The optimizer (a hand-written
|
||
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
|
||
/// `backward()`.
|
||
pub fn params(&self) -> Vec<Var> {
|
||
let mut ps = vec![self.embed.clone()];
|
||
for b in &self.blocks {
|
||
ps.extend([
|
||
b.attn_norm.clone(),
|
||
b.wq.clone(),
|
||
b.wk.clone(),
|
||
b.wv.clone(),
|
||
b.q_norm.clone(),
|
||
b.k_norm.clone(),
|
||
b.wo.clone(),
|
||
b.ffn_norm.clone(),
|
||
b.w_gate.clone(),
|
||
b.w_up.clone(),
|
||
b.w_down.clone(),
|
||
]);
|
||
}
|
||
ps.push(self.final_norm.clone());
|
||
ps.push(self.lm_head.clone());
|
||
ps
|
||
}
|
||
|
||
/// Forward over a single sequence of token `ids` (`[seq]` I32 on this
|
||
/// model's device). Returns the logits [`Var`] of shape `[seq, vocab]`. This
|
||
/// is the batch-1 special case of [`forward_batched`](Self::forward_batched)
|
||
/// (used by the autoregressive sampler / inference path).
|
||
pub fn forward(&self, ids: &Tensor) -> Var {
|
||
self.forward_batched(ids, 1)
|
||
}
|
||
|
||
/// Batched forward over `batch` sequences of equal length `seq`, flattened to
|
||
/// `[batch*seq]` I32 ids in sequence-major order (sequence 0's `seq` tokens,
|
||
/// then sequence 1's, …). Returns logits `[batch*seq, vocab]` in the SAME flat
|
||
/// layout. The whole graph runs on the flattened tokens so every linear
|
||
/// projection is ONE big `[batch*seq, dim] × [dim, out]` GEMM (the
|
||
/// GPU-filling win); only attention is sequence-aware (per-sequence causal
|
||
/// mask + RoPE position, NO cross-sequence attention).
|
||
pub fn forward_batched(&self, ids: &Tensor, batch: usize) -> Var {
|
||
let total = ids.shape()[0];
|
||
assert_eq!(
|
||
total % batch,
|
||
0,
|
||
"ids len {total} not divisible by batch {batch}"
|
||
);
|
||
let seq = total / batch;
|
||
|
||
// Dropout (T18) is active only in training mode with p>0; otherwise it is
|
||
// identity (`ops::dropout` no-ops at p==0). Bump the per-step seed ONCE per
|
||
// training forward so each step draws fresh masks (counter-based RNG, so a
|
||
// checkpointed block's recompute reproduces the same seed → same mask).
|
||
let dropout_p = if self.training.get() {
|
||
self.cfg.dropout
|
||
} else {
|
||
0.0
|
||
};
|
||
if dropout_p > 0.0 {
|
||
self.step_seed.set(self.step_seed.get().wrapping_add(1));
|
||
}
|
||
let base_seed = self.step_seed.get();
|
||
|
||
// Embedding gathers from the fp32 master table; in bf16 mode cast the
|
||
// activation stream to bf16 here (norms are cast to bf16 gammas too).
|
||
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
|
||
if self.compute_dtype == DType::BF16 {
|
||
h = ops::cast(&h, DType::BF16);
|
||
}
|
||
for (li, b) in self.blocks.iter().enumerate() {
|
||
// Per-layer dropout seed: a deterministic function of (base_seed,
|
||
// layer index) — NOT a mutable counter — so the checkpoint recompute
|
||
// (which re-derives it from the captured base_seed/li) gets the same
|
||
// masks. The block derives its two per-site seeds from this.
|
||
let block_seed = base_seed
|
||
.wrapping_mul(0x100000001B3)
|
||
.wrapping_add(li as u64);
|
||
h = if self.recompute {
|
||
// Activation recomputation (T13): run the whole block forward inside
|
||
// `checkpoint` so its internal activations aren't kept on the tape;
|
||
// the block forward is re-run in backward to recover the grads. The
|
||
// segment fn captures only `Copy` config (no borrow of `self`) and
|
||
// receives the block's params via the slice, in `block_params` order.
|
||
// `flash` is captured too → the recompute segment also runs flash;
|
||
// `dropout_p`/`block_seed` are captured so the recompute re-derives
|
||
// the same per-site dropout masks (counter-based RNG, exact).
|
||
let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
|
||
let seg = move |x: &Var, p: &[Var]| {
|
||
block_forward(cfg, cdt, flash, batch, seq, dropout_p, block_seed, x, p)
|
||
};
|
||
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
|
||
} else {
|
||
block_forward(
|
||
self.cfg,
|
||
self.compute_dtype,
|
||
self.use_flash,
|
||
batch,
|
||
seq,
|
||
dropout_p,
|
||
block_seed,
|
||
&h,
|
||
&b.block_params(),
|
||
)
|
||
};
|
||
}
|
||
|
||
let h = ops::rms_norm(
|
||
&h,
|
||
&norm_gamma(self.compute_dtype, &self.final_norm),
|
||
self.cfg.eps,
|
||
);
|
||
// lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the
|
||
// cross_entropy op upcasts to fp32 internally (no persistent fp32 logits
|
||
// buffer, a real saving at vocab 50257), and its backward casts dx back.
|
||
linear(self.compute_dtype, &h, &self.lm_head) // [batch*seq, vocab]
|
||
}
|
||
|
||
/// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
|
||
pub fn loss(&self, ids: &Tensor, targets: &Tensor) -> Var {
|
||
let logits = self.forward(ids);
|
||
ops::cross_entropy(&logits, targets)
|
||
}
|
||
|
||
/// Batched cross-entropy mean loss: `forward_batched(ids, batch)` against
|
||
/// flat `targets` (`[batch*seq]` I32, same sequence-major layout). The CE mean
|
||
/// is over all `batch*seq` rows — identical to averaging the per-sequence
|
||
/// losses, so the loss value matches the looped single-sequence path.
|
||
pub fn loss_batched(&self, ids: &Tensor, targets: &Tensor, batch: usize) -> Var {
|
||
let logits = self.forward_batched(ids, batch);
|
||
ops::cross_entropy(&logits, targets)
|
||
}
|
||
}
|
||
|
||
impl Block {
|
||
/// The block's learnable leaves, in the fixed order the segment forward
|
||
/// (`block_forward`) indexes them — matches the per-block slice in
|
||
/// [`TinyTransformer::params`]. This is the param order `checkpoint` passes to
|
||
/// the recompute closure.
|
||
fn block_params(&self) -> Vec<Var> {
|
||
vec![
|
||
self.attn_norm.clone(),
|
||
self.wq.clone(),
|
||
self.wk.clone(),
|
||
self.wv.clone(),
|
||
self.q_norm.clone(),
|
||
self.k_norm.clone(),
|
||
self.wo.clone(),
|
||
self.ffn_norm.clone(),
|
||
self.w_gate.clone(),
|
||
self.w_up.clone(),
|
||
self.w_down.clone(),
|
||
]
|
||
}
|
||
}
|
||
|
||
/// Project `x` (activation, in the compute dtype) by weight `w` (an fp32 master
|
||
/// leaf). In bf16 mode the weight is cast to bf16 via the autograd `cast` op (whose
|
||
/// backward upcasts the grad to fp32); in fp32 mode this is just `matmul(x, w)`.
|
||
fn linear(cdt: DType, x: &Var, w: &Var) -> Var {
|
||
match cdt {
|
||
DType::F32 => ops::matmul(x, w),
|
||
DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)),
|
||
_ => unreachable!(),
|
||
}
|
||
}
|
||
|
||
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast op,
|
||
/// grad upcast) in bf16 mode; identity in fp32 mode.
|
||
fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
|
||
match cdt {
|
||
DType::F32 => gamma.clone(),
|
||
DType::BF16 => ops::cast(gamma, DType::BF16),
|
||
_ => unreachable!(),
|
||
}
|
||
}
|
||
|
||
/// One transformer block's forward: pre-norm + multi-head causal attention +
|
||
/// (T18) dropout + residual, then pre-norm + SwiGLU MLP + dropout + residual.
|
||
/// Attention runs the composed or fused-flash (T14) SDPA per `flash`. Pure in
|
||
/// `(cfg, cdt, flash, batch, seq, dropout_p, block_seed, input, params)` (no
|
||
/// `&self`, all `Copy`) so it can be the segment fn of
|
||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13) — the
|
||
/// recompute re-derives the same per-site seeds, so the dropout masks are
|
||
/// reproduced bit-for-bit. `dropout_p == 0` makes `ops::dropout` a no-op (the
|
||
/// graph is then identical to the pre-T18 path). `params` is the block's leaves in
|
||
/// [`Block::block_params`] order.
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn block_forward(
|
||
cfg: Config,
|
||
cdt: DType,
|
||
flash: bool,
|
||
batch: usize,
|
||
seq: usize,
|
||
dropout_p: f32,
|
||
block_seed: u64,
|
||
h: &Var,
|
||
p: &[Var],
|
||
) -> Var {
|
||
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
|
||
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
|
||
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
|
||
|
||
// Per-site dropout seeds (XOR a site constant into the block seed) so the two
|
||
// residual-path dropouts draw independent masks within the same step/layer.
|
||
let attn_seed = block_seed ^ 0x0A7700;
|
||
let ffn_seed = block_seed ^ 0x0FF700;
|
||
|
||
// --- Attention sub-block (pre-norm + dropout + residual) ---
|
||
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
|
||
let attn = attention(
|
||
cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||
);
|
||
let attn = ops::dropout(&attn, dropout_p, attn_seed);
|
||
let h = ops::add(h, &attn);
|
||
|
||
// --- MLP sub-block (pre-norm + dropout + residual) ---
|
||
let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps);
|
||
let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down);
|
||
let mlp = ops::dropout(&mlp, dropout_p, ffn_seed);
|
||
ops::add(&h, &mlp)
|
||
}
|
||
|
||
/// Multi-head causal self-attention over a flattened batch. `x`:[batch*seq,dim]
|
||
/// (already normed), laid out sequence-major. The Q/K/V/O projections are big
|
||
/// `[batch*seq, dim]` GEMMs; the scaled-dot-product attention itself runs as a
|
||
/// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each attends
|
||
/// within its own `[seq,seq]` causal window (NO cross-sequence attention), with
|
||
/// RoPE positions reset per sequence (`period = seq`). Causal masking is applied
|
||
/// inside the fused op's softmax kernel (no additive `[seq,seq]` mask tensor).
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn attention(
|
||
cfg: Config,
|
||
cdt: DType,
|
||
flash: bool,
|
||
batch: usize,
|
||
seq: usize,
|
||
x: &Var,
|
||
wq: &Var,
|
||
wk: &Var,
|
||
wv: &Var,
|
||
q_norm: &Var,
|
||
k_norm: &Var,
|
||
wo: &Var,
|
||
) -> Var {
|
||
let (nh, hd) = (cfg.n_heads, cfg.head_dim);
|
||
let total = batch * seq;
|
||
let bh = batch * nh;
|
||
let scale = 1.0 / (hd as f32).sqrt();
|
||
|
||
// Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor.
|
||
// [B*S,dim] @ [dim,dim] = [B*S,dim]
|
||
// reshape [B*S, nh, hd]
|
||
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
|
||
// rope [B*S, nh, hd] with per-sequence position (period = seq)
|
||
// reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd]
|
||
let to_bh = |proj: Var, norm: Option<&Var>| -> Var {
|
||
let r = ops::reshape(&proj, &[total, nh, hd]);
|
||
let r = match norm {
|
||
// Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd,
|
||
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
|
||
Some(gamma) => {
|
||
let flat = ops::reshape(&r, &[total * nh, hd]);
|
||
let normed = ops::rms_norm(&flat, &norm_gamma(cdt, gamma), cfg.eps);
|
||
let r = ops::reshape(&normed, &[total, nh, hd]);
|
||
ops::rope(&r, cfg.rope_theta, seq)
|
||
}
|
||
None => r,
|
||
};
|
||
let r = ops::reshape(&r, &[batch, seq, nh, hd]);
|
||
let t = ops::transpose_4d12(&r); // [B, nh, S, hd]
|
||
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
|
||
};
|
||
|
||
let q = to_bh(linear(cdt, x, wq), Some(q_norm));
|
||
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
|
||
let v = to_bh(linear(cdt, x, wv), None);
|
||
|
||
// Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
|
||
// single fused flash kernel (online softmax, no materialized [bh,S,S] scores);
|
||
// otherwise the composed T10 path (2 batched GEMMs + 1 causal-softmax kernel).
|
||
let out = if flash {
|
||
ops::flash_attention(&q, &k, &v, scale) // [B*nh, S, hd]
|
||
} else {
|
||
ops::attention(&q, &k, &v, scale) // [B*nh, S, hd]
|
||
};
|
||
|
||
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
|
||
// [B,S,nh,hd] → [B*S, dim].
|
||
let out = ops::reshape(&out, &[batch, nh, seq, hd]);
|
||
let out = ops::transpose_4d12(&out); // [B, S, nh, hd]
|
||
let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim]
|
||
linear(cdt, &concat, wo) // out projection
|
||
}
|
||
|
||
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim].
|
||
fn swiglu_mlp(cdt: DType, x: &Var, w_gate: &Var, w_up: &Var, w_down: &Var) -> Var {
|
||
let gate = linear(cdt, x, w_gate); // [seq, ffn_hidden]
|
||
let up = linear(cdt, x, w_up); // [seq, ffn_hidden]
|
||
let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up
|
||
linear(cdt, &act, w_down) // [seq, dim]
|
||
}
|
||
|
||
/// Materialise a parameter's value back to a host `Vec<f32>` (for the GD step
|
||
/// and PyTorch parity export).
|
||
pub fn param_to_host(v: &Var) -> Vec<f32> {
|
||
v.value().to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||
}
|
||
|
||
/// Build an I32 id tensor on `device` from token ids.
|
||
pub fn ids_tensor(ids: &[i32], device: Device) -> Tensor {
|
||
Tensor::from_slice(ids, &[ids.len()]).to_device(device)
|
||
}
|
||
|
||
/// Flatten `batch` equal-length sequences into one `[batch*seq]` I32 tensor in
|
||
/// sequence-major order (the layout `forward_batched` expects). Each row of
|
||
/// `seqs` is one sequence; all must have the same length.
|
||
pub fn batched_ids_tensor(seqs: &[Vec<i32>], device: Device) -> Tensor {
|
||
assert!(!seqs.is_empty(), "empty batch");
|
||
let seq = seqs[0].len();
|
||
let mut flat = Vec::with_capacity(seqs.len() * seq);
|
||
for s in seqs {
|
||
assert_eq!(s.len(), seq, "ragged batch: sequences must be equal length");
|
||
flat.extend_from_slice(s);
|
||
}
|
||
Tensor::from_slice(&flat, &[flat.len()]).to_device(device)
|
||
}
|