sft: assistant-only SFT (ignore-index CE) + chat-prompt greedy eval
Enable assistant-only supervised fine-tuning and a fixed chat-prompt eval path used by the v12 SFT runs:

- cross_entropy ignores negative targets (-100 ignore-index), normalizing by valid rows instead of all rows; CUDA fwd/bwd skip t<0 (ops.rs, nn.cu).
- Corpus gains optional labels + load_sft_tsv_cached: two-column TSV is formatted as 'User: .. \nAssistant:' + answer + <|endoftext|>, prompt tokens masked to -100 while answer+EOS are supervised; i32 label cache alongside the u16 token cache; sample() retries windows that are fully masked; eval uses target_window so masking applies to val loss too (data.rs, train_loop.rs).
- train + train_ddp: --sft-tsv selects the TSV loader, --init-ckpt continues training from a base checkpoint.
- greedy_sample: --prompts-file/--prompt/--temperature for fixed chat-prompt generation eval.

Test fixtures updated for the new Corpus.labels field; dropout.rs carries incidental rustfmt.

Not rebuilt locally (no CUDA toolchain on this checkout); correctness rests on the documented v12 base+SFT runs on the GPU box.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -338,7 +338,7 @@ __global__ void cross_entropy_fwd_k(const float* x, const int* target,
|
||||
for (int c = threadIdx.x; c < cols; c += blockDim.x) pr[c] *= inv;
|
||||
if (threadIdx.x == 0) {
|
||||
int t = target[r];
|
||||
loss[r] = -logf(pr[t]);
|
||||
loss[r] = t < 0 ? 0.0f : -logf(pr[t]);
|
||||
}
|
||||
}
|
||||
void launch_cross_entropy_fwd_f32(const float* x, const int* target,
|
||||
@@ -354,8 +354,13 @@ __global__ void cross_entropy_dx_k(const float* probs, const int* target,
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (i >= rows * cols) return;
|
||||
int r = i / cols, c = i % cols;
|
||||
float g = probs[i] - (c == target[r] ? 1.0f : 0.0f);
|
||||
dx[i] = g * scale;
|
||||
int t = target[r];
|
||||
if (t < 0) {
|
||||
dx[i] = 0.0f;
|
||||
} else {
|
||||
float g = probs[i] - (c == t ? 1.0f : 0.0f);
|
||||
dx[i] = g * scale;
|
||||
}
|
||||
}
|
||||
void launch_cross_entropy_dx_f32(const float* probs, const int* target,
|
||||
float* dx, int rows, int cols, float scale, void* s) {
|
||||
|
||||
Reference in New Issue
Block a user