post-train: M2d — batch the GRPO training-side forwards (op + module + wiring)
After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backwards. M2d packs all N=B·G ragged
samples of a step into ONE forward_batched.
Enabling property — right-padding is free under causal attention: a real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal the unpadded single-sequence forward; pad
rows are masked out (target=-100).
- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
(the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
forward_batched(N) → slice back to real length) + looped baselines +
inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
[chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
(rollout / capture / inner) to keep the step decomposition honest. Default micro =
B·G; bench-measured 9× on the training forwards.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -597,3 +597,87 @@ pub fn clipped_pg_loss(
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Batched GRPO clipped-PG loss over `N` ragged completions packed into ONE
|
||||
/// `forward_batched` (M2d): `logits` is `[R, vocab]` with `R = N·Lmax` rows in
|
||||
/// sequence-major order (sample 0's `Lmax` rows, then sample 1's, …), each ragged
|
||||
/// completion right-padded to the batch's `Lmax`. Prompt AND pad rows are masked
|
||||
/// (`target < 0`), so they contribute nothing and carry no gradient — the
|
||||
/// **right-pad-is-free-under-causal-attention** property (a real completion row
|
||||
/// never attends to the trailing pad rows, so its logits equal the unpadded
|
||||
/// single-sequence forward's).
|
||||
///
|
||||
/// Unlike the per-sample [`clipped_pg_loss`] (which folds a single scalar
|
||||
/// `advantage` and a global `1/N_tokens` normaliser), this op takes **per-row**
|
||||
/// `advantage[t]` (the owning sample's group-relative `A`) and **per-row**
|
||||
/// `weight[t]` (the full normaliser, e.g. `1/(N_samples · n_s)` for sample `s`'s
|
||||
/// completion rows, `0` at masked rows). It does NOT compute its own `inv_n`. With
|
||||
/// `weight[t] = 1/(N_samples·n_s)` and `advantage[t] = A_s` this is **bit-equivalent
|
||||
/// to the looped path** `Σ_s scale·(1/n_s)·clipped_pg_loss_s` (`scale = 1/N_samples`):
|
||||
/// the per-row backward is local (`cross_entropy_backward` is row-wise), so the
|
||||
/// batched row-`t` gradient equals the looped sample-`s` row-`t` gradient, and the
|
||||
/// scalar loss equals the looped weighted sum. (`tests/autograd.rs`:
|
||||
/// `clipped_pg_loss_batched_matches_looped`.) Degenerate points match
|
||||
/// [`clipped_pg_loss`] (`A=0` ⇒ KL only; `ε→∞` ⇒ vanilla PG; `β=0` ⇒ no KL).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn clipped_pg_loss_batched(
|
||||
logits: &Var,
|
||||
target: &Tensor,
|
||||
logp_old: &[f32],
|
||||
logp_ref: &[f32],
|
||||
advantage: &[f32],
|
||||
weight: &[f32],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
) -> Var {
|
||||
use xtrain_tensor::Device;
|
||||
let logit_dtype = logits.value().dtype();
|
||||
let (probs, per_row) = logits.value().cross_entropy(target);
|
||||
let rows = per_row.shape()[0];
|
||||
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
|
||||
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per row");
|
||||
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per row");
|
||||
assert_eq!(advantage.len(), rows, "advantage must have one entry per row");
|
||||
assert_eq!(weight.len(), rows, "weight must have one entry per row");
|
||||
|
||||
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
|
||||
let mut loss_val = 0f32;
|
||||
for t in 0..rows {
|
||||
if target_h[t] < 0 {
|
||||
continue; // masked (prompt or pad) row — no contribution, no gradient
|
||||
}
|
||||
let (a, w) = (advantage[t], weight[t]);
|
||||
let lp = -per_row_h[t]; // logπθ_t
|
||||
let ratio = (lp - logp_old[t]).exp();
|
||||
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
|
||||
let (unclipped_term, clipped_term) = (ratio * a, clipped * a);
|
||||
let pg_t = unclipped_term.min(clipped_term);
|
||||
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
|
||||
let d = logp_ref[t] - lp;
|
||||
let kl_t = d.exp() - d - 1.0;
|
||||
let pg_grad = if active { -a * ratio } else { 0.0 };
|
||||
let kl_grad = beta * (1.0 - d.exp());
|
||||
// The full per-row normaliser is folded into s (no global inv_n here).
|
||||
s[t] = -(pg_grad + kl_grad) * w;
|
||||
loss_val += (-pg_t + beta * kl_t) * w;
|
||||
}
|
||||
let dev = logits.value().device();
|
||||
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
|
||||
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
|
||||
|
||||
let target = target.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![logits.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
|
||||
let mut dx = ce.scale_rows(&s_dev);
|
||||
if up != 1.0 {
|
||||
dx = dx.scale(up);
|
||||
}
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -23,15 +23,15 @@ fn main() {
|
||||
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_autodiff::ops;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor};
|
||||
use xtrain_model::{Config, TinyTransformer, generate_cached_batch};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::{DType, Device};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, per_token_logp_batched};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
@@ -117,20 +117,6 @@ fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) ->
|
||||
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
|
||||
}
|
||||
|
||||
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= −per_row
|
||||
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||||
let logits = model.forward(&ids_tensor(input, device)).value();
|
||||
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||||
per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|p| -p)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
@@ -149,6 +135,9 @@ fn main() {
|
||||
let group: usize = flag(&args, "--group", 6);
|
||||
let n_prompts: usize = flag(&args, "--prompts", 8);
|
||||
let inner: usize = flag(&args, "--inner", 1);
|
||||
// M2d: pack the step's N=B·G ragged samples into forward_batched chunks of this
|
||||
// many samples (bounds the [chunk·Lmax, vocab] logits memory). Default = whole batch.
|
||||
let micro: usize = flag(&args, "--micro", n_prompts * group.max(1));
|
||||
let temp: f32 = flag(&args, "--temp", 1.0);
|
||||
let beta: f32 = flag(&args, "--beta", 0.04);
|
||||
let eps: f32 = flag(&args, "--eps", 0.2);
|
||||
@@ -188,16 +177,17 @@ fn main() {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
|
||||
// Per-window phase timers (ms): rollout / capture / inner — to keep the step
|
||||
// decomposition honest (M2d cut the training-side forwards 9×, so the question is
|
||||
// what now dominates the step).
|
||||
let (mut t_roll, mut t_cap, mut t_inner) = (0f32, 0f32, 0f32);
|
||||
for step in 0..steps {
|
||||
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
|
||||
struct Sample {
|
||||
input: Vec<i32>,
|
||||
target: Vec<i32>,
|
||||
adv: f32,
|
||||
logp_old: Vec<f32>,
|
||||
logp_ref: Vec<f32>,
|
||||
}
|
||||
let mut batch: Vec<Sample> = Vec::new();
|
||||
// Collect ALL the step's framed samples first (input, target, adv), so the
|
||||
// training-side forwards can be batched across the whole step (M2d) instead of
|
||||
// run one ragged sequence at a time.
|
||||
let t0 = std::time::Instant::now();
|
||||
let mut raw: Vec<(Vec<i32>, Vec<i32>, f32)> = Vec::new();
|
||||
for _ in 0..n_prompts {
|
||||
let p = gen_problem(&mut rng, &gcfg);
|
||||
let prompt_ids: Vec<i32> = tok
|
||||
@@ -230,53 +220,69 @@ fn main() {
|
||||
for (seg, r) in &comps {
|
||||
let adv = (r - mean) / (std + 1e-4);
|
||||
let (input, target) = frame(&tok, &p.question(), seg);
|
||||
let logp_old = per_token_logp(&policy, device, &input, &target);
|
||||
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||||
let logp_ref = match &reference {
|
||||
Some(r) => per_token_logp(r, device, &input, &target),
|
||||
None => vec![0.0; logp_old.len()],
|
||||
};
|
||||
batch.push(Sample { input, target, adv, logp_old, logp_ref });
|
||||
raw.push((input, target, adv));
|
||||
}
|
||||
}
|
||||
|
||||
// ---- K inner clipped-PG epochs over the captured batch ----
|
||||
if !batch.is_empty() {
|
||||
let scale = 1.0 / batch.len() as f32;
|
||||
t_roll += t0.elapsed().as_secs_f32() * 1e3;
|
||||
|
||||
// ---- Batched capture (M2d): logπ_old (policy) + logπ_ref (frozen) over ALL
|
||||
// samples in forward_batched chunks, instead of one forward per sample. ----
|
||||
if !raw.is_empty() {
|
||||
let t1 = std::time::Instant::now();
|
||||
let io: Vec<(Vec<i32>, Vec<i32>)> = raw.iter().map(|(i, t, _)| (i.clone(), t.clone())).collect();
|
||||
let logp_old = per_token_logp_batched(&policy, device, &io, micro);
|
||||
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||||
let logp_ref = match &reference {
|
||||
Some(r) => per_token_logp_batched(r, device, &io, micro),
|
||||
None => raw.iter().map(|(i, _, _)| vec![0.0; i.len()]).collect(),
|
||||
};
|
||||
let batch: Vec<PgSample> = raw
|
||||
.iter()
|
||||
.zip(logp_old)
|
||||
.zip(logp_ref)
|
||||
.map(|(((input, target, adv), lo), lr)| PgSample {
|
||||
input: input.clone(),
|
||||
target: target.clone(),
|
||||
adv: *adv,
|
||||
logp_old: lo,
|
||||
logp_ref: lr,
|
||||
})
|
||||
.collect();
|
||||
t_cap += t1.elapsed().as_secs_f32() * 1e3;
|
||||
|
||||
// ---- K inner clipped-PG epochs, batched over the captured samples ----
|
||||
let t2 = std::time::Instant::now();
|
||||
for _ in 0..inner {
|
||||
for s in &batch {
|
||||
let logits = policy.forward(&ids_tensor(&s.input, device));
|
||||
let loss = ops::clipped_pg_loss(
|
||||
&logits,
|
||||
&ids_tensor(&s.target, device),
|
||||
&s.logp_old,
|
||||
&s.logp_ref,
|
||||
s.adv,
|
||||
eps,
|
||||
beta,
|
||||
);
|
||||
ops::scale(&loss, scale).backward();
|
||||
}
|
||||
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
|
||||
let _ = xtrain_train::clip::clip_grad_norm_gpu(¶ms, clip, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
}
|
||||
t_inner += t2.elapsed().as_secs_f32() * 1e3;
|
||||
}
|
||||
|
||||
if (step + 1) % log_every == 0 || step == steps - 1 {
|
||||
let w = log_every.min(step + 1) as f32; // steps in this window
|
||||
println!(
|
||||
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
|
||||
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s | ms/step roll {:.0} cap {:.0} inner {:.0}",
|
||||
step + 1,
|
||||
win_reward / win_n.max(1) as f32,
|
||||
win_solved,
|
||||
win_n,
|
||||
start.elapsed().as_secs_f32(),
|
||||
t_roll / w,
|
||||
t_cap / w,
|
||||
t_inner / w,
|
||||
);
|
||||
win_reward = 0.0;
|
||||
win_solved = 0;
|
||||
win_n = 0;
|
||||
t_roll = 0.0;
|
||||
t_cap = 0.0;
|
||||
t_inner = 0.0;
|
||||
// Periodic save so a later OOM (naive rollout fragments the allocator —
|
||||
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
|
||||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save");
|
||||
|
||||
162
crates/xtrain-train/src/grpo_batch.rs
Normal file
162
crates/xtrain-train/src/grpo_batch.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! Batched GRPO training-side forwards (post-training M2d). After M2b/M2c made the
|
||||
//! rollout cheap, the GRPO **step** is dominated by the per-sample full-sequence
|
||||
//! forwards: the `per_token_logp` captures (policy + reference) and the inner
|
||||
//! clipped-PG `forward`/`backward`s — each a single-sequence `forward` over a short
|
||||
//! ragged completion. This module packs the `N = B·G` ragged samples of a step into
|
||||
//! ONE `forward_batched`, amortising the per-launch overhead across N (the same win
|
||||
//! M2b gave the rollout).
|
||||
//!
|
||||
//! The enabling property: **right-padding is free under causal attention.** Pad each
|
||||
//! ragged completion on the RIGHT to the batch's `Lmax`; a real completion row is at
|
||||
//! an earlier position than the trailing pad, and causal masking forbids attending
|
||||
//! forward, so its logits are bit-identical to the unpadded single-sequence forward.
|
||||
//! The pad rows' own outputs are garbage but are masked out (`target = -100`).
|
||||
//!
|
||||
//! Both the looped (baseline) and batched paths live here so they share one source of
|
||||
//! truth — `bin/bench_grpo_batch` A/Bs them (timing + a closeness gate), and the
|
||||
//! per-row equivalence of the loss op is pinned by `clipped_pg_loss_batched_matches_looped`
|
||||
//! in `xtrain-autodiff/tests/autograd.rs`.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_model::{TinyTransformer, ids_tensor};
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
/// One framed completion of a GRPO step: the next-token `(input, target)` pair
/// (prompt positions masked to `-100` in `target`), its group-relative `adv`, and the
/// per-position rollout-time / reference logprobs the clipped-PG loss needs.
///
/// Plain data, so it derives `Debug` (diagnostics), `Clone` (independent copies for
/// A/B runs) and `PartialEq` (test comparisons).
#[derive(Debug, Clone, PartialEq)]
pub struct PgSample {
    /// Input token ids of the framed sequence.
    pub input: Vec<i32>,
    /// Next-token targets; negative entries (e.g. `-100`) mark masked positions.
    pub target: Vec<i32>,
    /// Group-relative advantage shared by every position of this sample.
    pub adv: f32,
    /// Per-position `logπ` captured from the policy at rollout time; the batched step
    /// copies it into a slot of `input.len()` entries, so the lengths must match.
    pub logp_old: Vec<f32>,
    /// Per-position `logπ` under the reference model (same length requirement); the
    /// trainer passes zeros when no reference model is in use.
    pub logp_ref: Vec<f32>,
}
|
||||
|
||||
// ------------------------------- looped (baseline) -------------------------------
|
||||
|
||||
/// Per-position `logπ(target_t)` of one framed `(input, target)` pair (= `−per_row`
|
||||
/// of cross_entropy; masked positions are 0). One single-sequence forward, no grad.
|
||||
pub fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||||
let logits = model.forward(&ids_tensor(input, device)).value();
|
||||
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||||
per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|p| -p)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// One inner clipped-PG epoch the looped way: per sample, a single-sequence forward +
|
||||
/// [`ops::clipped_pg_loss`] scaled by `1/N` + backward (grads accumulate on `model`'s
|
||||
/// params). Returns the summed scaled loss. Caller does clip + opt.step + zero_grad.
|
||||
pub fn inner_pg_step_looped(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
batch: &[PgSample],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
) -> f32 {
|
||||
let scale = 1.0 / batch.len() as f32;
|
||||
let mut total = 0f32;
|
||||
for s in batch {
|
||||
let logits = model.forward(&ids_tensor(&s.input, device));
|
||||
let loss = ops::clipped_pg_loss(&logits, &ids_tensor(&s.target, device), &s.logp_old, &s.logp_ref, s.adv, eps, beta);
|
||||
let scaled = ops::scale(&loss, scale);
|
||||
total += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
scaled.backward();
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
// ------------------------------- batched (M2d) -----------------------------------
|
||||
|
||||
/// Right-pad `m = rows.len()` ragged `i32` rows (each at most `lmax` long) into one flat
/// `[m*lmax]` sequence-major buffer: row `i` is copied to the start of slot
/// `[i*lmax, (i+1)*lmax)` and the remainder of the slot keeps the `pad` value. Used for
/// both the id stream (pad = 0, arbitrary) and the target stream (pad = −100, ignored by
/// cross_entropy).
///
/// # Panics
/// If any row is longer than `lmax`. Without this check an over-long row that is not
/// the last one would silently spill into the next sample's slot.
fn pack_i32(rows: &[&[i32]], lmax: usize, pad: i32) -> Vec<i32> {
    let mut flat = vec![pad; rows.len() * lmax];
    for (i, row) in rows.iter().enumerate() {
        assert!(
            row.len() <= lmax,
            "pack_i32: row {i} has {} entries but lmax is {lmax}",
            row.len()
        );
        let start = i * lmax;
        flat[start..start + row.len()].copy_from_slice(row);
    }
    flat
}
|
||||
|
||||
/// Batched [`per_token_logp`]: pack `samples` (each `(input, target)`) right-padded to
|
||||
/// `Lmax`, run ONE `forward_batched(batch = N)`, and slice each sample's `logπ` back to
|
||||
/// its real length. Equal to looping [`per_token_logp`] (right-pad is free under causal
|
||||
/// attention), to bf16 batch-reduction tolerance. `samples` are processed in chunks of
|
||||
/// `micro` (≥1) to bound the `[chunk*Lmax, vocab]` logits memory.
|
||||
pub fn per_token_logp_batched(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
samples: &[(Vec<i32>, Vec<i32>)],
|
||||
micro: usize,
|
||||
) -> Vec<Vec<f32>> {
|
||||
let mut out = Vec::with_capacity(samples.len());
|
||||
for chunk in samples.chunks(micro.max(1)) {
|
||||
let m = chunk.len();
|
||||
let lmax = chunk.iter().map(|(i, _)| i.len()).max().unwrap();
|
||||
let ins: Vec<&[i32]> = chunk.iter().map(|(i, _)| i.as_slice()).collect();
|
||||
let tgs: Vec<&[i32]> = chunk.iter().map(|(_, t)| t.as_slice()).collect();
|
||||
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
|
||||
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
|
||||
let logits = model.forward_batched(&ids, m).value();
|
||||
let (_, per_row) = logits.cross_entropy(&tgt);
|
||||
let pr = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
for (i, (inp, _)) in chunk.iter().enumerate() {
|
||||
let b = i * lmax;
|
||||
out.push((0..inp.len()).map(|r| -pr[b + r]).collect());
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// One inner clipped-PG epoch, batched: pack the batch (in `micro`-sized chunks) and run
|
||||
/// ONE `forward_batched` + [`ops::clipped_pg_loss_batched`] + backward per chunk. The
|
||||
/// per-row `weight = 1/(N·n_s)` uses the GLOBAL `N = batch.len()` (not the chunk size),
|
||||
/// so chunked grad-accumulation reproduces the looped `Σ_s (1/N)(1/n_s)…` exactly.
|
||||
/// Returns the summed loss. Caller does clip + opt.step + zero_grad.
|
||||
pub fn inner_pg_step_batched(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
batch: &[PgSample],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
micro: usize,
|
||||
) -> f32 {
|
||||
let inv_n = 1.0 / batch.len() as f32;
|
||||
let mut total = 0f32;
|
||||
for chunk in batch.chunks(micro.max(1)) {
|
||||
let m = chunk.len();
|
||||
let lmax = chunk.iter().map(|s| s.input.len()).max().unwrap();
|
||||
let ins: Vec<&[i32]> = chunk.iter().map(|s| s.input.as_slice()).collect();
|
||||
let tgs: Vec<&[i32]> = chunk.iter().map(|s| s.target.as_slice()).collect();
|
||||
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
|
||||
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
|
||||
|
||||
let mut logp_old = vec![0f32; m * lmax];
|
||||
let mut logp_ref = vec![0f32; m * lmax];
|
||||
let mut advantage = vec![0f32; m * lmax];
|
||||
let mut weight = vec![0f32; m * lmax];
|
||||
for (i, s) in chunk.iter().enumerate() {
|
||||
let b = i * lmax;
|
||||
let li = s.input.len();
|
||||
logp_old[b..b + li].copy_from_slice(&s.logp_old);
|
||||
logp_ref[b..b + li].copy_from_slice(&s.logp_ref);
|
||||
let n_s = s.target.iter().filter(|&&t| t >= 0).count().max(1) as f32;
|
||||
let w = inv_n / n_s; // = 1/(N · n_s)
|
||||
for r in 0..lmax {
|
||||
advantage[b + r] = s.adv;
|
||||
weight[b + r] = w;
|
||||
}
|
||||
}
|
||||
let logits = model.forward_batched(&ids, m);
|
||||
let loss = ops::clipped_pg_loss_batched(&logits, &tgt, &logp_old, &logp_ref, &advantage, &weight, eps, beta);
|
||||
total += loss.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
loss.backward();
|
||||
}
|
||||
total
|
||||
}
|
||||
@@ -15,6 +15,8 @@ pub mod task;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod checkpoint;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod grpo_batch;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod sample;
|
||||
#[cfg(not(no_cuda))]
|
||||
mod train_loop;
|
||||
|
||||
Reference in New Issue
Block a user