sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval

Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path
used by the v12 SFT runs:

- cross_entropy ignores negative targets (-100 ignore-index), normalizing by
  valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is
  formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens
  masked to -100 while answer+EOS are supervised; i32 label cache alongside the
  u16 token cache; sample() retries windows that are fully masked; eval uses
  target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues
  training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt
  generation eval.

Test fixtures updated for the new Corpus.labels field; dropout.rs carries
incidental rustfmt. Not rebuilt locally (no CUDA toolchain on this checkout);
correctness rests on the documented v12 base+SFT runs on the GPU box.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-29 16:19:02 +08:00
parent 5c27493a90
commit fbf4ac2917
11 changed files with 327 additions and 30 deletions

View File

@@ -338,7 +338,7 @@ __global__ void cross_entropy_fwd_k(const float* x, const int* target,
for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;
if (threadIdx.x == 0) {
int t = target[r];
loss[r] = -logf(pr[t]);
loss[r] = t < 0 ? 0.0f : -logf(pr[t]);
}
}
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
@@ -354,8 +354,13 @@ __global__ void cross_entropy_dx_k(const float* probs, const int* target,
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= rows * cols) return;
int r = i / cols, c = i % cols;
float g = probs[i] - (c == target[r] ? 1.0f : 0.0f);
dx[i] = g * scale;
int t = target[r];
if (t < 0) {
dx[i] = 0.0f;
} else {
float g = probs[i] - (c == t ? 1.0f : 0.0f);
dx[i] = g * scale;
}
}
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
float* dx, int rows, int cols, float scale, void* s) {