The rollout long-pole fix deferred from M2a: decode the G samples of one prompt in lockstep (one forward per step over the group → G× fewer kernel launches).
- rope_pos(x, positions[]): RoPE with a per-row absolute position (new forward-only kernel) — G rows share one decode position. Gate: == full rope for [0..n], == rope_at(P) per row for uniform P (bit-identical).
- generate_cached_batch: BatchKVCache [T, G·num_kv, hd] + batched decode_step. decode_attention is already batch-agnostic (bh = G·nh); repeat_kv(nh, batch=G) broadcasts per group. No finished-mask / ragged prompts yet (perf-only / next).
- Gate (tests/decode_batch.rs): all G greedy rows token-identical to the single-sequence decode (8 query / 2 kv heads → exercises repeat_kv batching).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
442 lines
18 KiB
Rust
442 lines
18 KiB
Rust
//! KV-cache incremental-decode engine (post-training M2a single sequence, plus the M2b lockstep batch decode).
|
||
//!
|
||
//! The naive sampler ([`crate::TinyTransformer`] via `train::sample::generate`)
|
||
//! re-runs the full forward over the whole growing prefix every step — O(t²) and
|
||
//! a fresh autograd graph per token. This is the inference engine that replaces it:
|
||
//! a per-layer **K/V cache** + a **single-token incremental forward** that processes
|
||
//! one new token at a time, attending to the cached keys/values.
|
||
//!
|
||
//! Built on three primitives, all gated by their own correctness tests:
|
||
//! - [`Tensor::rope_at`](xtrain_tensor::Tensor::rope_at): RoPE at the token's
|
||
//! absolute position (not row-in-tile), so cached post-RoPE K matches the full
|
||
//! forward (bit-identical, `integration::rope_at_matches_full_rope_row`).
|
||
//! - [`Tensor::decode_attention`](xtrain_tensor::Tensor::decode_attention): the
|
||
//! single-query × cached-K/V SDPA, equal to the full causal attention's last row
|
||
//! (`integration::decode_attention_matches_full_attention_last_row`).
|
||
//! - this module's per-token block forward, mirroring `model::block_forward` at the
|
||
//! raw-Tensor level (no autograd tape — inference needs no gradients).
|
||
//!
|
||
//! Correctness gate (the M2 centerpiece): KV-cache greedy decode is **token-
|
||
//! identical** to the naive full-recompute greedy (`tests/decode_kv.rs`).
|
||
//!
|
||
//! Prefill is just the first `prompt.len()` decode steps (one token at a time) —
//! one code path, at the cost of a non-batched prefill. The M2b lockstep batch
//! decode keeps that token-at-a-time prefill; batched prefill and ragged
//! (different-length) batch decode are still follow-ups. The cache is
//! host-accumulated (token-major f32) and the K/V tensor is rebuilt per step: the
//! per-step download is small (`num_kv·head_dim` floats/token/layer), but the rebuild
//! re-uploads the whole cached prefix every step. This is the honest M2a baseline —
//! moving the cache device-side is a follow-up.
|
||
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use crate::TinyTransformer;
|
||
use xtrain_tensor::{DType, Device, Tensor};
|
||
|
||
/// Single-sequence K/V cache, accumulated on the host one token at a time.
///
/// Layer `li` keeps its keys in `k[li]` and its values in `v[li]`. Each is a flattened
/// token-major `[T, num_kv, head_dim]` f32 buffer that grows by `num_kv·head_dim`
/// values per decode step. Storing f32 is an exact upcast of the bf16 projection
/// output, so casting back to the compute dtype when the K/V tensor is rebuilt
/// reproduces the bf16 values bit-for-bit.
struct KVCache {
    /// Keys: one flattened buffer per layer.
    k: Vec<Vec<f32>>,
    /// Values: one flattened buffer per layer (same layout as `k`).
    v: Vec<Vec<f32>>,
}

impl KVCache {
    /// An empty cache with one key buffer and one value buffer for each of `n_layers` layers.
    fn new(n_layers: usize) -> Self {
        let empty_layers = || -> Vec<Vec<f32>> { (0..n_layers).map(|_| Vec::new()).collect() };
        KVCache {
            k: empty_layers(),
            v: empty_layers(),
        }
    }

    /// Push one token's K and V slabs (`num_kv·head_dim` f32 each) onto layer `li`.
    fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
        let (keys, values) = (&mut self.k[li], &mut self.v[li]);
        keys.extend_from_slice(k_tok);
        values.extend_from_slice(v_tok);
    }
}
|
||
|
||
/// Linear `x @ W` in the compute dtype — mirrors `model::linear` (bf16 casts the
|
||
/// fp32-master weight to bf16 on the fly; the activation stream is already bf16).
|
||
fn linear_t(cdt: DType, x: &Tensor, w: &Tensor) -> Tensor {
|
||
match cdt {
|
||
DType::F32 => x.matmul(w),
|
||
DType::BF16 => x.matmul(&w.to_dtype(DType::BF16)),
|
||
_ => unreachable!("compute dtype must be F32/BF16"),
|
||
}
|
||
}
|
||
|
||
/// A norm/QK-norm gamma in the compute dtype — mirrors `model::norm_gamma`.
|
||
fn gamma_t(cdt: DType, g: &Tensor) -> Tensor {
|
||
match cdt {
|
||
DType::F32 => g.clone(),
|
||
DType::BF16 => g.to_dtype(DType::BF16),
|
||
_ => unreachable!("compute dtype must be F32/BF16"),
|
||
}
|
||
}
|
||
|
||
/// Greedy KV-cache decode: continue `prompt` by `max_new` tokens, argmax each step.
|
||
/// Returns the full token sequence (prompt + generated), matching the naive
|
||
/// `sample::generate` interface for `temperature == 0`. Token-identical to the
|
||
/// naive full-recompute greedy (gated by `tests/decode_kv.rs`).
|
||
pub fn generate_greedy_cached(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
prompt: &[i32],
|
||
max_new: usize,
|
||
) -> Vec<i32> {
|
||
let mut rng = 0u64;
|
||
generate_cached(model, device, prompt, max_new, 0.0, &mut rng)
|
||
}
|
||
|
||
/// KV-cache decode with temperature sampling (`temperature == 0` → greedy argmax,
|
||
/// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`).
|
||
/// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row
|
||
/// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far
|
||
/// lighter on memory + the allocator — the naive sampler fragments the caching
|
||
/// allocator over a long rollout, which is the M4 "rollout is the long pole" wall.
|
||
pub fn generate_cached(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
prompt: &[i32],
|
||
max_new: usize,
|
||
temperature: f32,
|
||
rng_state: &mut u64,
|
||
) -> Vec<i32> {
|
||
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||
let cfg = model.config();
|
||
let cdt = model.compute_dtype();
|
||
let n_layers = cfg.n_layers;
|
||
|
||
// params() is a stable, documented order (see TinyTransformer::params):
|
||
// [0] = embed [vocab, dim]
|
||
// [1 + li*11 .. +11] = layer li's 11 leaves, in block_params order:
|
||
// attn_norm, wq, wk, wv, q_norm, k_norm, wo, ffn_norm, w_gate, w_up, w_down
|
||
// [1 + n_layers*11] = final_norm [dim]
|
||
// [1 + n_layers*11 + 1] = lm_head [dim, vocab]
|
||
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
|
||
assert_eq!(
|
||
params.len(),
|
||
1 + n_layers * 11 + 2,
|
||
"unexpected param layout for decode"
|
||
);
|
||
let embed = ¶ms[0];
|
||
let final_norm = ¶ms[1 + n_layers * 11];
|
||
let lm_head = ¶ms[1 + n_layers * 11 + 1];
|
||
|
||
let mut cache = KVCache::new(n_layers);
|
||
let mut tokens = prompt.to_vec();
|
||
|
||
// Prefill: feed each prompt token in order; the last step's logits are the
|
||
// distribution for the first generated token.
|
||
let mut logits = Vec::new();
|
||
for (pos, &tok) in prompt.iter().enumerate() {
|
||
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, tok, pos, embed, final_norm, lm_head);
|
||
}
|
||
|
||
for _ in 0..max_new {
|
||
let next = if temperature <= 0.0 {
|
||
argmax(&logits) as i32
|
||
} else {
|
||
sample_temperature(&logits, temperature, rng_state) as i32
|
||
};
|
||
tokens.push(next);
|
||
let pos = tokens.len() - 1; // absolute position of the token just appended
|
||
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
|
||
}
|
||
tokens
|
||
}
|
||
|
||
/// Draw an index from `softmax(row / temperature)` by inverse-CDF sampling.
///
/// Uses the same 64-bit LCG step and the same f32 arithmetic as the naive
/// `sample::sample_temperature`, so identical seeds give identical draws. The row
/// max is subtracted before exponentiating, which keeps `exp` from overflowing.
/// `rng_state` is advanced exactly once per call.
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
    let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = row
        .iter()
        .map(|&logit| ((logit - peak) / temperature).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // One LCG step; the top 32 bits give a uniform draw in [0, 1].
    *rng_state = rng_state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    let unit = (*rng_state >> 32) as f32 / u32::MAX as f32;
    let target = unit * total;

    // First index whose running (unnormalised) mass reaches the target. Float
    // rounding can leave the total just short of the target, so fall back to the last index.
    let mut running = 0.0;
    weights
        .iter()
        .position(|&w| {
            running += w;
            running >= target
        })
        .unwrap_or_else(|| weights.len() - 1)
}
|
||
|
||
/// One incremental decode step for token `tok` at absolute position `pos`: append
|
||
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn decode_step(
|
||
params: &[Tensor],
|
||
cfg: &crate::Config,
|
||
cdt: DType,
|
||
device: Device,
|
||
cache: &mut KVCache,
|
||
tok: i32,
|
||
pos: usize,
|
||
embed: &Tensor,
|
||
final_norm: &Tensor,
|
||
lm_head: &Tensor,
|
||
) -> Vec<f32> {
|
||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||
let dim = cfg.dim;
|
||
let kv_dim = num_kv * hd;
|
||
let scale = 1.0 / (hd as f32).sqrt();
|
||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||
let n_layers = cfg.n_layers;
|
||
|
||
// Embedding (fp32 table) → activation stream in the compute dtype.
|
||
let ids = Tensor::from_slice(&[tok], &[1]).to_device(device);
|
||
let mut h = embed.embedding(&ids); // [1, dim] f32
|
||
if cdt == DType::BF16 {
|
||
h = h.to_dtype(DType::BF16);
|
||
}
|
||
|
||
for li in 0..n_layers {
|
||
let base = 1 + li * 11;
|
||
let (attn_norm, wq, wk, wv) =
|
||
(¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]);
|
||
let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]);
|
||
let (ffn_norm, w_gate, w_up, w_down) =
|
||
(¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]);
|
||
|
||
// --- Attention sub-block (pre-norm + cached-KV attention + residual) ---
|
||
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [1, dim]
|
||
|
||
// Q: project → per-head QK-norm → RoPE at absolute position `pos`.
|
||
let q = linear_t(cdt, &normed, wq).reshape(&[1, nh, hd]); // [1, nh, hd]
|
||
let q = q.reshape(&[nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
|
||
let q = q.reshape(&[1, nh, hd]).rope_at(theta, pos);
|
||
let q_bh = q.reshape(&[nh, 1, hd]); // seq=1 ⇒ the head-transpose is a no-op on data
|
||
|
||
// K: same as Q (QK-norm + RoPE); cache token-major. V: project only.
|
||
let k = linear_t(cdt, &normed, wk).reshape(&[1, num_kv, hd]);
|
||
let k = k.reshape(&[num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||
let k_tok = k.reshape(&[1, num_kv, hd]).rope_at(theta, pos); // [1, num_kv, hd]
|
||
let v_tok = linear_t(cdt, &normed, wv).reshape(&[1, num_kv, hd]);
|
||
|
||
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||
cache.append(li, &k_host, &v_host);
|
||
|
||
// Rebuild the full K/V for this layer: token-major [T,num_kv,hd] → [num_kv,T,hd]
|
||
// → repeat_kv to [nh,T,hd].
|
||
let t_len = cache.k[li].len() / kv_dim;
|
||
let build = |flat: &[f32]| -> Tensor {
|
||
let bh_kv = Tensor::from_slice(flat, &[t_len, num_kv, hd])
|
||
.to_device(device)
|
||
.transpose_3d01(); // [num_kv, T, hd], f32
|
||
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, 1) } // [nh, T, hd]
|
||
};
|
||
let k_full = build(&cache.k[li]);
|
||
let v_full = build(&cache.v[li]);
|
||
|
||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [nh, hd]
|
||
let attn = attn.reshape(&[1, dim]); // concat heads (nh·hd == dim)
|
||
let attn_out = linear_t(cdt, &attn, wo); // [1, dim]
|
||
h = h.add(&attn_out);
|
||
|
||
// --- MLP sub-block (pre-norm + SwiGLU + residual) ---
|
||
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
|
||
let gate = linear_t(cdt, &normed, w_gate);
|
||
let up = linear_t(cdt, &normed, w_up);
|
||
let act = gate.silu().mul(&up); // swiglu = silu(gate) ∘ up
|
||
let down = linear_t(cdt, &act, w_down);
|
||
h = h.add(&down);
|
||
}
|
||
|
||
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
|
||
let logits = linear_t(cdt, &h, lm_head); // [1, vocab]
|
||
logits
|
||
.to_dtype(DType::F32)
|
||
.to_device(Device::Cpu)
|
||
.as_slice::<f32>()
|
||
.to_vec()
|
||
}
|
||
|
||
/// Index of the largest value in `row`.
///
/// Ties resolve to the LAST maximal index, following the `Iterator::max_by`
/// convention. Panics on an empty row, or when any pairwise comparison involves a NaN.
fn argmax(row: &[f32]) -> usize {
    let winner = row
        .iter()
        .copied()
        .enumerate()
        .reduce(|best, cand| match best.1.partial_cmp(&cand.1).unwrap() {
            std::cmp::Ordering::Greater => best,
            std::cmp::Ordering::Less | std::cmp::Ordering::Equal => cand,
        });
    winner.unwrap().0
}
|
||
|
||
// ===================================================================
|
||
// M2b — batched KV-cache decode (G samples of one prompt, in lockstep)
|
||
// ===================================================================
|
||
|
||
/// Host-side K/V store for `G` sequences decoded in lockstep.
///
/// Each layer's buffers are flattened seq-major `[T, G·num_kv, head_dim]` f32. Every
/// decode step appends one `G·num_kv·hd` slab, and the step re-lays it out as
/// `[G·num_kv, T, hd]` when rebuilding the K/V tensor. This is the single-sequence
/// cache layout with a `G` factor folded into the head axis.
struct BatchKVCache {
    /// Keys: one flattened buffer per layer.
    k: Vec<Vec<f32>>,
    /// Values: one flattened buffer per layer (same layout as `k`).
    v: Vec<Vec<f32>>,
}

impl BatchKVCache {
    /// An empty cache with one key buffer and one value buffer per layer.
    fn new(n_layers: usize) -> Self {
        let k: Vec<Vec<f32>> = std::iter::repeat_with(Vec::new).take(n_layers).collect();
        let v = k.clone();
        Self { k, v }
    }

    /// Append one step's slab for all `G` sequences (`G·num_kv·hd` f32 each for K and V)
    /// to layer `li`.
    fn append(&mut self, li: usize, k_tok: &[f32], v_tok: &[f32]) {
        for (dst, src) in [(&mut self.k[li], k_tok), (&mut self.v[li], v_tok)] {
            dst.extend_from_slice(src);
        }
    }
}
|
||
|
||
/// Batched KV-cache decode: roll out `n_samples` (G) completions of the SAME
|
||
/// `prompt` in lockstep — all G share the prompt, so they advance at one common
|
||
/// decode position each step (uniform RoPE via `rope_pos`). Returns G full token
|
||
/// sequences (prompt + sampled continuation). The G-way batching amortises the
|
||
/// per-step kernel launches across G (the rollout long-pole). Token-identical per
|
||
/// row to G independent single-sequence decodes (gated by `tests/decode_batch.rs`).
|
||
///
|
||
/// `temperature == 0` ⇒ greedy (all G identical); `> 0` ⇒ independent samples
|
||
/// (per-row draw from one shared `rng_state`). No finished-mask: all G generate
|
||
/// `max_new` tokens; the caller cuts each at `<|endoftext|>` (a perf-only early
|
||
/// stop is the M2b+ follow-up). Ragged (different-length prompts) is also deferred.
|
||
pub fn generate_cached_batch(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
prompt: &[i32],
|
||
n_samples: usize,
|
||
max_new: usize,
|
||
temperature: f32,
|
||
rng_state: &mut u64,
|
||
) -> Vec<Vec<i32>> {
|
||
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||
assert!(n_samples > 0, "n_samples must be > 0");
|
||
let cfg = model.config();
|
||
let cdt = model.compute_dtype();
|
||
let n_layers = cfg.n_layers;
|
||
let params: Vec<Tensor> = model.params().iter().map(|p| p.value()).collect();
|
||
let embed = ¶ms[0];
|
||
let final_norm = ¶ms[1 + n_layers * 11];
|
||
let lm_head = ¶ms[1 + n_layers * 11 + 1];
|
||
|
||
let g = n_samples;
|
||
let mut cache = BatchKVCache::new(n_layers);
|
||
let mut seqs: Vec<Vec<i32>> = vec![prompt.to_vec(); g];
|
||
|
||
// Prefill: feed each prompt token (identical across G) at its position.
|
||
let mut logits = Vec::new(); // [G, vocab] flattened
|
||
for (pos, &tok) in prompt.iter().enumerate() {
|
||
let toks = vec![tok; g];
|
||
logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &toks, pos, embed, final_norm, lm_head);
|
||
}
|
||
|
||
let vocab = cfg.vocab;
|
||
for _ in 0..max_new {
|
||
let mut next = Vec::with_capacity(g);
|
||
for row in 0..g {
|
||
let lg = &logits[row * vocab..(row + 1) * vocab];
|
||
let t = if temperature <= 0.0 {
|
||
argmax(lg) as i32
|
||
} else {
|
||
sample_temperature(lg, temperature, rng_state) as i32
|
||
};
|
||
next.push(t);
|
||
seqs[row].push(t);
|
||
}
|
||
let pos = seqs[0].len() - 1; // all G are at the same position
|
||
logits = decode_step_batch(¶ms, cfg, cdt, device, &mut cache, &next, pos, embed, final_norm, lm_head);
|
||
}
|
||
seqs
|
||
}
|
||
|
||
/// One batched decode step: `toks` is one current token per sequence (`[G]`), all at
|
||
/// absolute position `pos`. Appends each sequence's K/V and returns logits `[G·vocab]`.
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn decode_step_batch(
|
||
params: &[Tensor],
|
||
cfg: &crate::Config,
|
||
cdt: DType,
|
||
device: Device,
|
||
cache: &mut BatchKVCache,
|
||
toks: &[i32],
|
||
pos: usize,
|
||
embed: &Tensor,
|
||
final_norm: &Tensor,
|
||
lm_head: &Tensor,
|
||
) -> Vec<f32> {
|
||
let (nh, hd, num_kv) = (cfg.n_heads, cfg.head_dim, cfg.num_kv_heads);
|
||
let dim = cfg.dim;
|
||
let g = toks.len();
|
||
let g_kv = g * num_kv; // batch·num_kv heads in the cache
|
||
let scale = 1.0 / (hd as f32).sqrt();
|
||
let (theta, eps) = (cfg.rope_theta, cfg.eps);
|
||
let n_layers = cfg.n_layers;
|
||
// Uniform per-row position (all G at the same decode step).
|
||
let positions = Tensor::from_slice(&vec![pos as i32; g], &[g]).to_device(device);
|
||
|
||
let ids = Tensor::from_slice(toks, &[g]).to_device(device);
|
||
let mut h = embed.embedding(&ids); // [G, dim] f32
|
||
if cdt == DType::BF16 {
|
||
h = h.to_dtype(DType::BF16);
|
||
}
|
||
|
||
for li in 0..n_layers {
|
||
let base = 1 + li * 11;
|
||
let (attn_norm, wq, wk, wv) =
|
||
(¶ms[base], ¶ms[base + 1], ¶ms[base + 2], ¶ms[base + 3]);
|
||
let (q_norm, k_norm, wo) = (¶ms[base + 4], ¶ms[base + 5], ¶ms[base + 6]);
|
||
let (ffn_norm, w_gate, w_up, w_down) =
|
||
(¶ms[base + 7], ¶ms[base + 8], ¶ms[base + 9], ¶ms[base + 10]);
|
||
|
||
let normed = h.rms_norm(&gamma_t(cdt, attn_norm), eps).0; // [G, dim]
|
||
|
||
// Q: project → per-head QK-norm → RoPE at `pos` for every row.
|
||
let q = linear_t(cdt, &normed, wq).reshape(&[g, nh, hd]);
|
||
let q = q.reshape(&[g * nh, hd]).rms_norm(&gamma_t(cdt, q_norm), eps).0;
|
||
let q = q.reshape(&[g, nh, hd]).rope_pos(&positions, theta);
|
||
let q_bh = q.reshape(&[g * nh, 1, hd]); // bh = G·nh
|
||
|
||
let k = linear_t(cdt, &normed, wk).reshape(&[g, num_kv, hd]);
|
||
let k = k.reshape(&[g * num_kv, hd]).rms_norm(&gamma_t(cdt, k_norm), eps).0;
|
||
let k_tok = k.reshape(&[g, num_kv, hd]).rope_pos(&positions, theta);
|
||
let v_tok = linear_t(cdt, &normed, wv).reshape(&[g, num_kv, hd]);
|
||
|
||
let k_host = k_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||
let v_host = v_tok.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||
cache.append(li, &k_host, &v_host);
|
||
|
||
// Rebuild [T, G·num_kv, hd] → [G·num_kv, T, hd] → repeat_kv to [G·nh, T, hd].
|
||
let t_len = cache.k[li].len() / (g_kv * hd);
|
||
let build = |flat: &[f32]| -> Tensor {
|
||
let bh_kv = Tensor::from_slice(flat, &[t_len, g_kv, hd])
|
||
.to_device(device)
|
||
.transpose_3d01(); // [G·num_kv, T, hd], f32
|
||
let bh_kv = if cdt == DType::BF16 { bh_kv.to_dtype(DType::BF16) } else { bh_kv };
|
||
if num_kv == nh { bh_kv } else { bh_kv.repeat_kv(nh, g) } // [G·nh, T, hd]
|
||
};
|
||
let k_full = build(&cache.k[li]);
|
||
let v_full = build(&cache.v[li]);
|
||
|
||
let attn = q_bh.decode_attention(&k_full, &v_full, scale); // [G·nh, hd]
|
||
let attn = attn.reshape(&[g, dim]); // concat heads per sequence
|
||
let attn_out = linear_t(cdt, &attn, wo);
|
||
h = h.add(&attn_out);
|
||
|
||
let normed = h.rms_norm(&gamma_t(cdt, ffn_norm), eps).0;
|
||
let gate = linear_t(cdt, &normed, w_gate);
|
||
let up = linear_t(cdt, &normed, w_up);
|
||
let act = gate.silu().mul(&up);
|
||
let down = linear_t(cdt, &act, w_down);
|
||
h = h.add(&down);
|
||
}
|
||
|
||
let h = h.rms_norm(&gamma_t(cdt, final_norm), eps).0;
|
||
linear_t(cdt, &h, lm_head)
|
||
.to_dtype(DType::F32)
|
||
.to_device(Device::Cpu)
|
||
.as_slice::<f32>()
|
||
.to_vec()
|
||
}
|