Swap forward_verify_paged_decode_attention_with_hidden's projections
from matmul_batched_gemv (per-row bit-exact GEMV) to matmul_2d (cuBLAS
GEMM at m>1). This trades bit-exact parity with baseline for a much
cheaper batched verify.
Micro-benchmark (bench-verify-cost.rs) reveals the huge cost gap:
batched-GEMV verify: 1.05× → 5.14× single decode (linear in batch)
cuBLAS-GEMM verify: 1.04× → 1.20× single decode (nearly flat)
At batch=9 the gap is 4.3× (5.14 / 1.20): the cuBLAS GEMM reads each
projection weight once for all rows, while the batched GEMV re-reads
the weights for every row.
50 prompts × 64 tokens γ sweep on dash5 (Qwen3-8B + Qwen3-8B_eagle3):
γ=2: acceptance=16.9%, speedup_e2e = 1.10× ← best
γ=3: acceptance=11.6%, speedup_e2e = 1.06×
γ=4: acceptance=8.9%, speedup_e2e = 1.02×
γ>4: speedup keeps dropping because acceptance falls faster than the
cheaper verify can make up for.
Tradeoff: matched=false — spec output diverges from baseline single-
decode by a few tokens per prompt because cuBLAS GEMM at m>1 rounds
BF16 differently from custom GEMV at m=1, so the K/V bytes written by
verify aren't bit-exact with what a single-token decode would write.
Downstream this compounds into slightly different token choices.
The spec output is still a VALID target model output; it just takes
a different numerical path. In the sampled prompts both outputs are
coherent English continuations of the prompt. This matches the common
reading of "lossless spec decoding": target distribution preserved
modulo BF16 rounding, not bit-exact with a specific numerical path.
New: crates/xserv-model/src/bin/bench-verify-cost.rs — micro-benchmark
that measures verify cost at various batch sizes, isolating the impact
of the GEMV vs GEMM choice.
Two subtle bugs found and fixed in the γ≥2 speculative loop:
1. Wrong position handling: cache.truncate_sequence(round_pos - 1) was
dropping the K/V of pending_prev, then verify OVERWROTE that slot with
the wrong token. Removed the truncate: verify now starts at
cache.seq_len (== position of pending_prev) and writes γ+1 tokens
forward. Also fixed EAGLE draft positions: pending_prev is at position
p, so step 0 uses position=p (not p+1).
2. EAGLE KV cache accumulated rejected drafts' K/V: each round writes γ
entries to EAGLE's cache regardless of how many drafts were accepted.
Added eagle.truncate_to(new_len) API. After each round, truncate to
eagle_len_before + k + 1 (pending_prev + k accepted drafts).
Also expose Eagle3Head::current_len() getter and Eagle3Head::truncate_to().
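The length bookkeeping behind fix 2 can be sketched as follows. This is
a minimal model, not the real Eagle3Head: DraftCache, write() and
rollback_after_round() are hypothetical names, only the cache length is
tracked, and truncate_to() is assumed to be a no-op when the cache is
already shorter than the requested length.

```rust
/// Toy stand-in for a draft head's KV cache; only the length is modeled.
struct DraftCache {
    len: usize,
}

impl DraftCache {
    fn current_len(&self) -> usize {
        self.len
    }

    /// Append `n` K/V entries (one per draft step).
    fn write(&mut self, n: usize) {
        self.len += n;
    }

    /// Drop everything past `new_len`; a no-op if the cache is already shorter.
    fn truncate_to(&mut self, new_len: usize) {
        self.len = self.len.min(new_len);
    }
}

/// One round of bookkeeping: write `gamma` entries, then keep only
/// pending_prev plus the `accepted` drafts that survived verification.
fn rollback_after_round(cache: &mut DraftCache, gamma: usize, accepted: usize) -> usize {
    let len_before = cache.current_len();
    cache.write(gamma);
    cache.truncate_to(len_before + accepted + 1);
    cache.current_len()
}
```

Starting from length 10 with γ=4 and one accepted draft, the cache ends
the round at length 12 instead of 14, so rejected drafts never leak
into the next round's attention.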
Additionally: return the PRE-norm hidden state as aux (matching vllm's
llama_eagle3.py default `norm_output=False`). Previously the normed
version was returned.
Result: matched=true across the full γ sweep. speedup_e2e remains <1:
γ=1 (single-decode verify): accept=22.7%, speedup=0.95x
γ=1 (batched verify): accept=20.6%, speedup=0.75x
γ=2: accept=12.6%, speedup=0.59x
γ=4: accept=7.6%, speedup=0.41x
γ=8: accept=4.1%, speedup=0.27x
Per-slot diagnostic shows d[0]≈15%, d[1]≈8%, d[2..γ-1] varies. d[0] is
lower than γ=1's 20% because batched verify introduces small numerical
differences vs single-token decode.
Larger γ hurts because:
- verify_cost scales roughly linearly with γ+1 (batched matmul at
batch=γ+1 costs ~γ+1× a single decode).
- accepted tokens per round grows sub-linearly (recursive EAGLE degrades).
- speedup ≈ (1 + accepted_avg) / verify_cost → below 1 across the sweep.
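The approximation in the last bullet can be written down directly. A
small sketch, assuming verify_cost is expressed in units of one single
decode and that draft cost is ignored (as in the formula above); the
function names are illustrative only.

```rust
/// Each round emits one token from the target itself plus `accepted_avg`
/// drafted tokens, and costs `verify_cost` single-decode equivalents.
/// Draft cost is ignored, matching the approximation in the commit text.
fn estimated_speedup(accepted_avg: f64, verify_cost: f64) -> f64 {
    (1.0 + accepted_avg) / verify_cost
}

/// Average accepted tokens per round needed to break even (speedup == 1).
fn break_even_accepted(verify_cost: f64) -> f64 {
    verify_cost - 1.0
}
```

Under the linear-cost assumption from the first bullet (verify_cost of
about γ+1), break_even_accepted returns γ: every draft would have to be
accepted in every round just to match baseline.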
Getting speedup > 1 requires either (a) faster batched verify
(closer to single-decode cost per query row via better GPU utilization),
or (b) better draft accuracy (tree-based drafting to explore multiple
candidates per position, a larger EAGLE head, or a differently-trained
EAGLE variant).
Adds infrastructure for γ≥2 EAGLE speculative decoding:
qwen3.rs:
- New forward_verify_paged_decode_attention_with_hidden: same as the
existing verify but also captures target hidden states at 3 hook
layers, one per verify position. Needed to seed next round's EAGLE.
eagle3.rs:
- step split into step (unchanged public API) + step_with_aux (also
returns final hidden state) + step_recursive (takes fused_h directly,
no fc+3-hidden combine). This mirrors the EAGLE3 paper: γ=1 uses
target hooks + fc; γ≥2 uses previous EAGLE aux as fused_h for
subsequent drafts, approximating target hidden.
bench-eagle3.rs:
- New run_eagle_gamma_multi function with --gamma CLI (default 2).
- Per round: recursive EAGLE γ drafts, verify [prev_token, d0..d_{γ-1}]
in one target forward, accept longest prefix, correction via 1 more
target decode.
- max_seqs bumped to 16 in the paged cache so verify can batch up to
16 rows.
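The "accept longest prefix" step in the per-round bullet can be
sketched as below. This assumes greedy verification and that
target_argmax[i] is the target's greedy token for the position draft i
was proposed for; accepted_prefix_len is a hypothetical helper name,
not a function from bench-eagle3.rs.

```rust
/// Number of leading drafts the target agrees with under greedy verification.
/// The two slices are compared element-wise and counting stops at the first
/// mismatch, so a later coincidental match is never accepted.
fn accepted_prefix_len(drafts: &[u32], target_argmax: &[u32]) -> usize {
    drafts
        .iter()
        .zip(target_argmax)
        .take_while(|(d, t)| d == t)
        .count()
}
```

For drafts [5, 7, 9] against target picks [5, 7, 1] the result is 2;
the third position is then filled by the correction decode.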
γ=2 test result (5 prompts × 32 tokens, dash5):
matched=false — sequences diverge
acceptance_rate = 29.8% at γ=2 (~1.1 tokens accepted per draft)
speedup_e2e = 0.52x (SLOWER than baseline)
The divergence bug is in the verify's re-writing of prev_token's K/V
at position round_pos-1. In principle matmul_batched_gemv at row-0
should be bit-exact with the seed decode's launch_gemv_bf16, but the
sequence output diverges so something is off. Investigation pending
(likely the correction decode step or seed_hooks position offset).
γ=1 path still works correctly (matched=true, acceptance 20%,
speedup 0.95x) from the previous commit. The γ≥2 path is scaffolded
but not yet correct — next step is to debug the verify-write path,
then measure real speedup.
Two fixes to bring EAGLE3 forward in line with vllm's llama_eagle3.py
reference:
1. Residual chain: previously the residual added into post_attention_layernorm
was the token embedding (wrong). Reference uses _norm_after_residual:
residual = fused_h (pre-norm)
hidden_states = hidden_norm(fused_h)
Then post_attention_layernorm is a fused add_rmsnorm(attn_out, residual),
and the final norm is another add_rmsnorm(mlp_out, residual_after_attn).
Neither residual carries the embedding — both carry fused_h forward.
2. KV cache: previously the attention was approximated as "output = V"
because seq_len=1 (no cache), effectively giving EAGLE no history.
Add a real per-Eagle3Head KV cache (1 layer × [1, num_kv_heads,
max_seq_len, head_dim] BF16) that grows as we call step(). Use the
existing decode_attention kernel with a fresh contiguous slice of the
cache each step. reset() clears current_len for a new sequence.
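For reference, the fused add_rmsnorm used in fix 1 can be sketched on
the CPU as follows. This is a plain f32 illustration under the usual
definition (new residual = x + residual, output = RMSNorm of that sum);
the signature is an assumption and does not mirror the real kernel.

```rust
/// Fused add + RMSNorm: returns (normed, new_residual) where
/// new_residual = x + residual and normed = new_residual / rms * weight.
fn add_rmsnorm(x: &[f32], residual: &[f32], weight: &[f32], eps: f32) -> (Vec<f32>, Vec<f32>) {
    let summed: Vec<f32> = x.iter().zip(residual).map(|(a, b)| a + b).collect();
    let mean_sq = summed.iter().map(|v| v * v).sum::<f32>() / summed.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    let normed: Vec<f32> = summed
        .iter()
        .zip(weight)
        .map(|(v, w)| v * inv_rms * w)
        .collect();
    (normed, summed)
}
```

In the residual chain above, the first call would be
add_rmsnorm(attn_out, fused_h, ...), and its returned residual feeds
the second call together with mlp_out, so the embedding never enters
either residual.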
Result on 10 prompts × 32 tokens (γ=1, no batched verify yet):
matched=true across all prompts
acceptance_rate = 20.0% (was 4.7% before residual fix, 1.3% originally)
- Prompt 00 "The capital of France is": 60% (18/30) — best case
- Other prompts: 10-25% — matches EAGLE paper's observation that
structured/factual prompts get higher acceptance
Sanity check (check-eagle3) on Paris prompt now shows:
EAGLE top-5 pairing A: "." / " is" / "," / " Paris" / ".\n"
MATCH: EAGLE agrees with target on next token.
speedup_e2e still 0.95x because γ=1 does 1 target decode per token
regardless of acceptance. Real speedup requires γ≥2 with a single
batched target-verify covering all γ draft tokens; that's the next step.
bench-eagle3.rs runs the full loop: prefill → for each output token, one
EAGLE draft + one target decode with hidden state hook. Measures
acceptance rate and speedup vs pure target decode.
First numbers on dash5 (10 prompts × 32 tokens, γ=1):
matched=true (10/10)
acceptance_rate=1.3% (4/300) ← should be ~60-70% per EAGLE3 paper
speedup_e2e=0.95× ← below 1 because γ=1 does 1 target decode per output token regardless of acceptance
target_steps=320 for 320 tokens
Positive: the plumbing is correct — target/EAGLE both run without error,
output sequences match baseline, all shapes/dtypes check out. The
sanity check earlier showed EAGLE top-5 contains thematically-plausible
tokens (Paris/Tokyo/Madrid for "capital of France is").
Negative: 1.3% acceptance means EAGLE's drafts almost never match the
target's greedy top-1. Root causes to investigate:
1. Token/hook pairing convention. Paper uses (h_that_produced_t_i, t_i)
→ predicts t_{i+1}. My bench does the same but sanity check earlier
suggested pairing might be one off.
2. Missing "training-time test" projection: EAGLE was trained to feed
its own prev output as fused_h for the next step (γ>1 chaining).
Currently we always use target hooks, which is what pairing A/B do
for γ=1, but may not be aligned with training-time behavior.
3. Hook site: I capture x AFTER the residual+MLP. Paper may want x
BEFORE, or the "hidden_states" as used by the final norm+lm_head.
Currently the same tensor feeds into final norm during the target
forward, so pre/post-residual is what I have — but confirming
against reference Python impl is needed.
4. Weight loading: transposes assume [in,out] → [out,in]. Need to
validate at least one output layer's shape against expected.
Next step (deferred to another session): download AngelSlim reference
inference code, run same prompt through it, compare intermediate
activations at each stage to isolate the discrepancy.
The initial [11, 23, 35] (equally-spaced) guess was wrong — EAGLE3 heads
are trained against specific target layer indices, and using different
ones at inference gives wrong outputs. Correct values come from vLLM
speculators' training config for Qwen3-8B:
https://github.com/vllm-project/speculators/blob/main/examples/train/dflash_qwen3_8b_sharegpt_online_5k.sh
which pins target_layer_ids to "2 18 33". Re-running check-eagle3 with
the fix produces coherent top-5 for "The capital of France is":
Old ([11,23,35]): "," / " Paris" / " Madrid" / "." / " Berlin"
New ([2,18,33]): " Paris" / " Tokyo" / " Madrid" / "," / "."
Top-1 still differs from target's next token, but that's because EAGLE
maps (state_that_produced_prev, prev_token) → next, and the exact
pairing convention may need one more offset check when integrated into
the full speculative loop.
- eagle3.rs: Eagle3Head struct loads AngelSlim/Qwen3-8B_eagle3 safetensors,
runs a single draft step via fc(concat(h_low, h_mid, h_high)) +
concat(input_norm(emb), hidden_norm(fused_h)) → 1 midlayer → norm →
lm_head → argmax in draft_vocab(32000) → d2t → target_vocab.
- qwen3.rs: new decode_core_with_hidden method that mirrors decode_core
but captures hidden states at 3 configurable layer indices (default
[11, 23, 35] for the 36-layer Qwen3-8B). Also expose embed_tokens_tensor
and (in eagle3) map_draft_to_target as public accessors.
- loader.rs: make_tensor now pub(crate) so eagle3 can reuse it.
- bin/check-eagle3.rs: sanity binary that loads target + EAGLE, runs one
prefill + one decode + one EAGLE step, prints the top-5 EAGLE predictions.
Verified on dash5 with prompt "The capital of France is":
target says: " Paris" then "."
EAGLE top-5: "," / " Paris" / " Madrid" / "." / " Berlin"
Weights load correctly, d2t mapping works, hidden state hooks are the
right shape ([1, 4096]), and EAGLE produces thematically-relevant tokens.
The top-1 pick "," doesn't match target's "." at this position, but
that's expected: this test uses hidden states from a single decode step
with no recursive chaining. A full speculative loop still needs the
γ≥2 verify + accept path wired up (next step).
- Split Qwen3::forward_decode_paged into decode_prepare (host-side
block allocation + table upload) and decode_core (pure-GPU compute
reading token ids and positions from device buffers via
embedding_device_ids + rope_inplace_device_pos). This makes the
entire Qwen3 decode step CUDA-graph-capturable, mirroring the
gpt_oss.rs architecture.
- Add qwen3_graph.rs: Qwen3DecodeGraph + GraphedQwen3Decoder, a port
of the gpt_oss_graph.rs whole-step capture pattern. Lazy policy:
first decode eager (warms pool + cuBLAS), second captures, rest
replay. Batch>1 always falls back to eager.
- Wire GraphedQwen3Decoder into bench-speculative's draft decode path;
all 4 draft.forward_decode_paged call sites + replay_draft_tokens
now route through the graphed decoder. Per-benchmark caches persist
across prompts for graph reuse.
- Gamma sweep result (10 prompts × 32 tokens, --use-verify-logits):
γ=1 → 0.57×, γ=2 → 0.57×, γ=4 → 0.49×, γ=6 → 0.41×, γ=8 → 0.36×.
All matched=true, verify_decode_mismatches=0.
Acceptance drops sharply with γ (66% → 40% → 25%) because Qwen3-0.6B
is too inaccurate a draft for Qwen3-8B. Speedup still <1.
Current ceiling analysis: verify costs ~13ms (same as one target decode)
so speculative decoding only wins if acceptance × (tokens/round) >>
(draft_cost + verify_cost) / baseline_decode. With this draft model,
the crossover requires either (a) a much smaller verify cost (batch-GEMM
path, which trades correctness), or (b) a fundamentally better drafter
(EAGLE-style heads, or n-gram lookup).
Add launch_gemv_bf16_batched: runs M m=1 GEMVs in a single 3D grid
launch (z = batch row) with numerically identical output to M sequential
launch_gemv_bf16 calls — same K-block partial accumulation, same
fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens:
matched=true, verify_decode_mismatches=0.
Expose as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in
xserv-kernels. Replace the old matmul_rows_gemv helper in qwen3
forward_verify_paged_decode_attention; the per-row loop over matmul_2d +
concat_rows is replaced by a single matmul_batched_gemv call that
allocates the partials buffer in one shot and launches 2 kernels instead
of 2*M.
Current speedup_e2e is 0.47× (same ballpark as Phase 23's 0.44×);
the batched launch saves ~3 ms overhead but this is small relative to
the total 28 ms spec cost. The path forward (per docs/24 §4) is
higher acceptance rate or cheaper draft, not further kernel optimization.
Three related hardening changes for the API surface:
- validate_request rejects NaN/negative temperature, out-of-range top_p,
and absurd top_k before those values reach the CUDA sampling paths.
This keeps NaN out of downstream sampling and matches typical
OpenAI-compatible server behavior (400 instead of 500).
- normalize_finish_reason maps engine strings to the OpenAI-standard
subset. Currently only "error" (from tp/pp engine client-stall) needs
normalization — it collapses to null so SDK clients see a clean stream
close instead of an unknown finish_reason value. Applied to both
streaming (SSE) and non-streaming JSON responses.
- Replace the unbounded std::sync::mpsc engine channel with a bounded
sync_channel(256) and switch submit_to_engine to try_send. A saturated
engine now returns 503 "engine is busy" instead of letting requests
pile up in RAM. Also add axum DefaultBodyLimit(4 MiB) so a malicious
or misbehaving client cannot exhaust memory with an arbitrary JSON POST.
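A minimal sketch of the kind of check validate_request performs. The
struct, the exact top_p range (0, 1], and the MAX_TOP_K bound are
assumptions for illustration; the commit only states which classes of
values are rejected.

```rust
/// Hypothetical request subset; field names are illustrative.
struct SamplingParams {
    temperature: f32,
    top_p: f32,
    top_k: u32,
}

/// Assumed upper bound for a reasonable top_k; the real limit is not stated.
const MAX_TOP_K: u32 = 1024;

/// Returns Err with a message suitable for a 400 response.
fn validate_request(p: &SamplingParams) -> Result<(), String> {
    if p.temperature.is_nan() || p.temperature < 0.0 {
        return Err("temperature must be a non-negative number".to_string());
    }
    // The negated form also rejects NaN, which fails every comparison.
    if !(p.top_p > 0.0 && p.top_p <= 1.0) {
        return Err("top_p must be in (0, 1]".to_string());
    }
    if p.top_k > MAX_TOP_K {
        return Err(format!("top_k must be at most {}", MAX_TOP_K));
    }
    Ok(())
}
```

Writing the top_p check in negated form matters: a plain
`p.top_p <= 0.0 || p.top_p > 1.0` would let NaN through, because every
comparison with NaN is false.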
partial_cmp().unwrap() in the top-k / top-p sort and softmax paths would
panic the engine thread on a single NaN logit. The greedy argmax path
is already NaN-safe. Add a one-pass NaN → -inf sweep on the extracted
last_row before temperature scaling, which is equivalent to masking the
token and keeps the sampler deterministic. Warn once when triggered so
the underlying numeric bug isn't silently hidden.
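The sweep can be sketched as below, assuming the last row has already
been extracted to a host-side f32 slice. mask_nan_logits and
sorted_indices_desc are illustrative names; the returned count is what
a caller could use to decide whether to emit the one-time warning.

```rust
/// Replace NaN logits with -inf so the token is masked; returns how many
/// entries were replaced.
fn mask_nan_logits(row: &mut [f32]) -> usize {
    let mut replaced = 0;
    for v in row.iter_mut() {
        if v.is_nan() {
            *v = f32::NEG_INFINITY;
            replaced += 1;
        }
    }
    replaced
}

/// Token indices sorted by descending logit. The unwrap is safe only
/// after NaNs have been removed, since partial_cmp returns None for NaN.
fn sorted_indices_desc(row: &[f32]) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..row.len()).collect();
    idx.sort_by(|&a, &b| row[b].partial_cmp(&row[a]).unwrap());
    idx
}
```

After the sweep a masked token sorts last and gets zero probability
under softmax, so it can never be sampled.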
Interactive REPL used to always call sample_greedy_last on both the
paged and legacy KV paths, so temperature/top-k/top-p and the repetition
penalty added in the sampling module were unreachable from the CLI.
- flag() helper parses --max-tokens / --temperature / --top-k / --top-p
/ --rep-penalty / --rep-window (defaults preserve prior behavior:
temperature 0, top-p 1, penalty 1, window 512).
- pick_next() dispatches to sample_greedy_penalized only when
temperature==0 and rep_penalty>1, otherwise to sample().
- Both Qwen3/GPT-2 paths and the GptOss paged path share the same
sampler and both feed the rolling history window used for the penalty.
- Prompt input now unescapes literal "\n" so multi-turn prompts can be
typed on one line.
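A rough sketch of how a generic flag() helper and the "\n" unescape
could look. Both signatures are assumptions; in particular, falling
back to the default on a parse error is a choice made here, not
something the commit specifies.

```rust
/// Return the value following `name` in `args`, or `default` if the flag
/// is absent or its value fails to parse.
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    args.iter()
        .position(|a| a == name)
        .and_then(|i| args.get(i + 1))
        .and_then(|v| v.parse().ok())
        .unwrap_or(default)
}

/// Turn a typed backslash followed by 'n' into a real newline.
fn unescape_newlines(line: &str) -> String {
    line.replace("\\n", "\n")
}
```

With this shape, `flag(&args, "--top-p", 1.0f32)` yields 1.0 when the
flag is missing, which is how the defaults preserve the old greedy
behavior.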
Phase 22 lands a correctness-only speculative decoding loop for Qwen3
target + Qwen3 small draft (batch=1, greedy, gamma=4). Phase 23 turns
verify logits into the authoritative acceptance signal so mirror-decode
per accepted token is no longer needed.
- paged_kv_cache: truncate_sequence(slot, new_len) shrinks a registered
sequence, freeing whole physical blocks no longer reachable and
leaving the slot registered. Covered by a CUDA-gated unit test.
- qwen3: forward_verify_paged_decode_attention writes the draft window
into the target cache, runs the same paged decode attention kernel per
draft token, and uses matmul_rows_gemv so linear layers follow the
single-token decode BF16 rounding path.
- bench-speculative: new bench binary drives the state machine with
--gamma / --gen-tokens / --prompts / --use-verify-logits /
--verify-path flash|paged-decode / --dump-verify-mismatches, and
compares baseline vs spec token sequences plus TPOT / tok/s / speedup.
- docs/22 records the decode-authoritative v0 result and dash5 numbers
(matched=true, speedup_e2e ~0.29x, verify_decode_mismatches>0 under
--use-verify-logits).
- docs/23 records the paged-decode verify path (matched=true,
verify_decode_mismatches=0, 50x64 speedup_e2e ~0.44x) and the
next-step performance TODO.
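The block arithmetic behind truncate_sequence can be illustrated as
follows. This is a sketch with a caller-supplied block_size; the helper
names are hypothetical and the real cache also has to update its block
table, which is not modeled here.

```rust
/// Physical blocks needed to hold `len` tokens at `block_size` tokens per block.
fn blocks_for(len: usize, block_size: usize) -> usize {
    (len + block_size - 1) / block_size
}

/// Whole blocks that become unreachable when a sequence shrinks to `new_len`.
/// A partially filled trailing block stays allocated.
fn blocks_freed(old_len: usize, new_len: usize, block_size: usize) -> usize {
    assert!(new_len <= old_len, "truncate can only shrink a sequence");
    blocks_for(old_len, block_size) - blocks_for(new_len, block_size)
}
```

With 16-token blocks, shrinking from 40 to 33 tokens frees nothing
(three blocks are still needed), while shrinking to 16 frees two.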
BF16 greedy decode was sensitive to inter-block scheduling when logits
were close, which broke speculative-decoding verify-vs-decode parity.
- gemv.cu: write per-K-block partials, then reduce in fixed block order
in a second kernel instead of atomicAdd across K-blocks. Scratch
buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems()
exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/
intermediate/vocab from it.
- paged_attention.cu: replace atomicAdd merge of warp outputs with
per-warp shared partials reduced in warp-id order for both the base
and sinks kernels.
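A CPU reference for the two-pass scheme, as a sketch only: tile_k
stands in for GEMV_TILE_K (whose value is not given here), W is assumed
row-major [n, k], and each inner loop over a K-block plays the role of
one CUDA block.

```rust
/// Scratch elements for per-K-block partials: one slot per (output, K-block).
fn gemv_scratch_elems(n: usize, k: usize, tile_k: usize) -> usize {
    n * ((k + tile_k - 1) / tile_k)
}

/// y = W x. Pass 1 writes one partial sum per K-block; pass 2 adds the
/// partials in block order, so the result does not depend on which block
/// happens to finish first.
fn gemv_two_pass(w: &[f32], x: &[f32], n: usize, k: usize, tile_k: usize) -> Vec<f32> {
    let blocks = (k + tile_k - 1) / tile_k;
    let mut partials = vec![0.0f32; gemv_scratch_elems(n, k, tile_k)];
    for row in 0..n {
        for b in 0..blocks {
            let lo = b * tile_k;
            let hi = (lo + tile_k).min(k);
            partials[row * blocks + b] =
                (lo..hi).map(|i| w[row * k + i] * x[i]).sum::<f32>();
        }
    }
    (0..n)
        .map(|row| partials[row * blocks..(row + 1) * blocks].iter().sum::<f32>())
        .collect()
}
```

Order matters because floating-point addition is not associative: in
f32, (1e8 + -1e8) + 1 evaluates to 1 while 1e8 + (-1e8 + 1) evaluates
to 0. An atomicAdd merge picks one of these orders at random per launch.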
All three engines emitted tokens with blocking_send on the single
decode/coordinator OS thread. A streaming client that drains slower than
generation fills its 64-slot channel, and blocking_send then blocks the whole
thread: under continuous batching one slow consumer stalls every other running
sequence (and in the serial TP/PP path it blocks admission of the next request
too). The whole point of continuous batching is defeated.
Fix: switch to try_send. engine.rs sets a client_stalled flag on Full/Closed,
reaped by is_finished() next iteration; tp_engine/pp_engine emit_text returns
bool and the decode loop breaks with finish_reason "error". When the
sequence/request is dropped its sender drops too, closing the channel so the
client receive loop ends rather than hanging. A slow client now only loses its
own sequence, never the batch.
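The tp_engine/pp_engine variant of the fix can be sketched as below.
std::sync::mpsc's bounded channel stands in for whatever channel type
the engines actually use, run_sequence is a hypothetical decode loop,
and the "stop" finish reason is assumed for the normal case.

```rust
use std::sync::mpsc::{SyncSender, TrySendError};

/// Non-blocking emit: returns false when the client is too slow (Full) or
/// gone (Disconnected), so the caller can drop just this sequence.
fn emit_text(tx: &SyncSender<String>, text: &str) -> bool {
    match tx.try_send(text.to_string()) {
        Ok(()) => true,
        Err(TrySendError::Full(_)) | Err(TrySendError::Disconnected(_)) => false,
    }
}

/// Decode loop for one sequence; returns (tokens_emitted, finish_reason).
/// The loop never blocks on the consumer.
fn run_sequence(tx: &SyncSender<String>, tokens: &[&str]) -> (usize, &'static str) {
    for (i, tok) in tokens.iter().enumerate() {
        if !emit_text(tx, tok) {
            return (i, "error");
        }
    }
    (tokens.len(), "stop")
}
```

With a 2-slot channel and no reader, a 3-token sequence ends after two
emits with "error"; the decode thread is free to move on to the other
sequences immediately.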
Verified on dash5: gpt-oss FP8 TP=1 streaming via tp_engine still streams
correctly (SSE chunks, coherent content, no hang); bench-gpt-oss TP=2 5.9ms
TPOT unchanged.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
sample() at temperature 0 copied the full [seq, 201088] BF16 logits
to the host and scanned them every token (~1 ms/token). Use the
Phase 15 argmax kernel (block reduction + 4-byte D2H) when logits are
BF16 on GPU; bench-gpt-oss's greedy sampler likewise. Exact-tie
logits may break differently than the host scan — greedy trajectories
can legitimately diverge at a tie token (GSM8K unchanged).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Split forward_decode_paged into host bookkeeping (decode_prepare +
ids/pos upload + advance_seq_len) and a pure-GPU decode_core. The
paged-KV and sparse-MoE designs already read every per-step variable
(block table, context lens, expert ids) from stable-address device
buffers, so decode_core captures as-is.
GptOssDecodeGraph captures lazily on the second decode step (the
first eager step warms cuBLAS) after a "retained warmup": the step
runs once with the allocator quarantine on, stocking the pool with a
dedicated block for every allocation so the capture itself never
pool-misses (a cudaMalloc while capturing is illegal — and the
capture's own quarantine is what would otherwise starve the pool).
NCCL all-reduces capture cleanly; TP=2 replays in lockstep.
Wired into tp_engine, bench-gpt-oss, and xserv-chat via
GraphedGptOssDecoder (batch>1 falls back to eager;
XSERV_DECODE_GRAPH=0 disables). Greedy tokens identical to eager.
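The lazy policy amounts to a small state machine. A sketch with
hypothetical names; it assumes that batch>1 steps do not count toward
the warm-up, which the text above does not spell out, and it leaves out
the graph object and the retained-warmup allocation pass entirely.

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum Step {
    Eager,
    Capture,
    Replay,
}

/// Hypothetical policy object; the real decoder also owns the captured graph.
struct LazyGraphPolicy {
    enabled: bool,
    single_steps: usize,
}

impl LazyGraphPolicy {
    fn new(enabled: bool) -> Self {
        Self { enabled, single_steps: 0 }
    }

    /// Decide how to run the next decode step for the given batch size.
    fn decide(&mut self, batch: usize) -> Step {
        if !self.enabled || batch > 1 {
            return Step::Eager; // graphs disabled, or batch>1 fallback
        }
        self.single_steps += 1;
        match self.single_steps {
            1 => Step::Eager,   // first step runs eagerly to warm cuBLAS
            2 => Step::Capture, // second step is captured
            _ => Step::Replay,  // every later step replays the graph
        }
    }
}
```

Constructing the policy with enabled=false models XSERV_DECODE_GRAPH=0:
every step stays eager.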
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
- Thread-local launch stream (xserv_cuda::stream): every kernel
wrapper, cublasSetStream, and NCCL collective now launches on
current_stream_raw() — the legacy null stream by default (behavior
unchanged), or the capture stream installed via push_stream during
graph capture. Capture is impossible on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are
quarantined (RetainedBlocks) instead of pooled, so an instantiated
graph keeps exclusive ownership of every intermediate buffer it
references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads
must not poison each other's captures with their own cudaMallocs.
- embedding_device_ids / rope_inplace_device_pos: variants reading
token ids / positions from persistent device buffers, replacing the
per-call host upload that a captured region cannot contain.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
- --pp with gpt-oss now fails with a clear message instead of a
cryptic missing-weight panic inside the Qwen3-only PP engine.
- Sparse GEMV wrappers assert K%16==0 (FP8) / K%32==0 (MXFP4) — the
uint4-vectorized kernels would silently drop a tail otherwise.
- Document the topk_ids buffer holding i32 under an F32 dtype label
(DType has no I32).
- Drop unused imports/locals and the cuBLASLt scale-mode constants
orphaned by the strided-batched FP8 rework (e631a71).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Decode carried three leftover cudaDeviceSynchronize calls (and prefill
carried one) from NaN debugging. The Qwen3 path has none, and the
logits D2H in sample() already orders against the null stream.
add_bias for rows>1 round-tripped the bias through the CPU (D2H + host
tile loop + H2D) on every call — 96 times per prefill across q/k/v/o.
Replace with a bias_add_2d broadcast kernel.
dash5, FP8 TP=2, warm-server: TTFT 35/49/94 -> 29/42/79 ms
(short/medium/long), TPOT 7.19-7.32 -> 6.99-7.21 ms. Greedy tokens
unchanged; GSM8K-50 94%.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
gpt-oss has no single-GPU engine path, so --tp 1 fell through to the
Qwen3-only engine and every request 503'd. Route gpt_oss to run_tp
even at tp=1: NCCL world-1 init works and all_reduce already no-ops
(bench-gpt-oss --tp 1 exercised this path). Quantized gpt-oss (22 GB
FP8 / 13 GB MXFP4) now serves on one 32 GB 5090.
Also fix eval_gsm8k_fast.py --gpu to accept a device list ("2,3"):
it was type=int, so any --tp 2 run pinned CUDA_VISIBLE_DEVICES to one
GPU and rank 1's set_device panicked while rank 0 spun in NCCL init.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Dense MoE replicated x across all 16 local experts and ran the full
batched GEMM, reading every expert's weights per token; the weighted
sum then discarded 12 of 16 results. Decode is memory-bound, so this
was ~8x wasted expert bytes — the entire decode gap vs llama.cpp.
New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read
topk_ids on-device (no host sync) and early-return block-uniformly
for experts other ranks own. FP8 runs W8A16 (activations stay BF16 —
tensor cores are irrelevant at M=1, and activation quantization error
disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the
GEMV epilogue; slot-indexed weighted sum skips (never multiplies)
unwritten non-local slots. Dense path retained for num_tokens > 8
(prefill) and via XSERV_DENSE_MOE=1 for A/B.
dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms.
Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms —
first config where xserv wins decode outright. GSM8K-100: 96% (dense
FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode
CUDA graphs, non-expert quantization, sparse prefill (docs/20).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 +
per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads
the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused
dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill
(M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected
by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].
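For readers unfamiliar with the format, a host-side dequantization of
one MXFP4 block can be sketched as below. The E2M1 value table and the
UE8M0 bias of 127 follow the MX format definition; the nibble order
(low nibble = even element) is an assumption and may differ from what
tools/quantize_mxfp4.py writes. The reserved NaN scale code 255 is not
handled.

```rust
/// E2M1 magnitudes indexed by the low three bits (2 exponent bits, 1 mantissa bit).
const E2M1_MAGNITUDES: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];

/// Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0..=2 pick the magnitude.
fn decode_e2m1(code: u8) -> f32 {
    let mag = E2M1_MAGNITUDES[(code & 0x7) as usize];
    if code & 0x8 != 0 { -mag } else { mag }
}

/// UE8M0 block scale: an unsigned power of two with bias 127.
fn decode_ue8m0(e: u8) -> f32 {
    2.0f32.powi(e as i32 - 127)
}

/// Dequantize one 32-element block stored as 16 packed bytes.
/// Assumption: the low nibble holds the even-indexed element.
fn dequant_block(packed: &[u8; 16], scale: u8) -> [f32; 32] {
    let s = decode_ue8m0(scale);
    let mut out = [0.0f32; 32];
    for (i, &byte) in packed.iter().enumerate() {
        out[2 * i] = decode_e2m1(byte & 0x0F) * s;
        out[2 * i + 1] = decode_e2m1(byte >> 4) * s;
    }
    out
}
```

Each block costs 16 bytes of codes plus 1 byte of scale for 32 weights,
which is where the roughly 4-bit footprint comes from.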
Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to
FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB
5090 with room for KV cache.
NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses
(350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization
foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt
block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The plan-cache fix removed the per-expert heuristic churn but still issued one
cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local
experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over
BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul
(BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token.
A single strided call can't carry a per-batch scalar B-scale, so the per-expert
weight scale moves out of the GEMM epilogue into a fused post-scale kernel
(rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one
pass. This is precision-equivalent: BF16's relative error is scale-invariant, so
scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue.
Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K:
decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms),
throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
batched_gemm_fp8 rebuilt the cuBLASLt matmul descriptor, four matrix
layouts, a preference, and a 4-byte scale alloc, AND ran the algo
heuristic search — once per expert, per GEMM, per layer, on every
forward (~1500 heuristic searches per decoded token). FP8 decode ran at
27.0 ms/tok vs BF16 18.8 ms, i.e. slower than the path it was meant to
accelerate.
Cache the full plan (descriptor + layouts + heuristically-chosen algo)
in a thread-local map keyed by (M, N, K) so the heuristic runs once per
shape and is reused across experts and forwards; allocate the 1.0 scale
buffer once; pass each expert's weight scale via the cuBLASLt B-scale
device pointer instead of folding it into alpha (identical FP32-epilogue
precision, and no host readback of b_scales). The per-expert loop now
issues only cublasLtMatmul.
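The caching pattern itself is independent of cuBLASLt and can be
sketched with the standard library alone. Plan, run_heuristic and the
run counter are placeholders invented for this illustration; the real
plan would hold the descriptor, layouts and chosen algo.

```rust
use std::cell::{Cell, RefCell};
use std::collections::HashMap;

/// Hypothetical cached plan; a stand-in for descriptor + layouts + algo.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Plan {
    algo_id: u32,
}

thread_local! {
    static PLANS: RefCell<HashMap<(usize, usize, usize), Plan>> =
        RefCell::new(HashMap::new());
    static HEURISTIC_RUNS: Cell<usize> = Cell::new(0);
}

/// Stand-in for the expensive heuristic search; counts its invocations.
fn run_heuristic(m: usize, n: usize, k: usize) -> Plan {
    HEURISTIC_RUNS.with(|c| c.set(c.get() + 1));
    Plan { algo_id: ((m + n + k) % 7) as u32 }
}

/// Return the cached plan for this shape, searching only on a miss.
fn plan_for(m: usize, n: usize, k: usize) -> Plan {
    PLANS.with(|plans| {
        *plans
            .borrow_mut()
            .entry((m, n, k))
            .or_insert_with(|| run_heuristic(m, n, k))
    })
}

/// Number of heuristic searches run on the current thread.
fn heuristic_runs() -> usize {
    HEURISTIC_RUNS.with(|c| c.get())
}
```

Because the map is thread-local, each TP rank thread builds its own
cache and no locking is needed on the hot path.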
Measured on dash5 (gpt-oss-20b, TP=2, 5090): FP8 decode TPOT 27.0 -> 17.9
ms, now faster than BF16 (18.8 ms); GSM8K-200 accuracy unchanged
(FP8 93.0% vs BF16 90.5%, within noise).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Replace the W8A16 dequant→BF16-GEMM path with native FP8×FP8→BF16 GEMM
using cuBLASLt on Blackwell (RTX 5090). Both weights (static FP8 E4M3)
and activations (dynamically quantized per-row) are processed directly
on FP8 tensor cores.
Key implementation details:
- cuBLASLt on Blackwell requires transA=T for FP8, so expert weights
are transposed during model loading ([E,K,N] → [E,N,K])
- Per-row activation quantization kernel (absmax/448 → FP8 E4M3)
- Post-GEMM row-wise rescaling recovers per-token precision
- Per-expert loop (not batched) due to cuBLASLt FP8 scale constraints
The same FP8 quantized model files work — no re-quantization needed.
Activation quantization happens dynamically at inference time.
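The per-row scaling arithmetic (absmax/448, then rescale after the
GEMM) can be sketched on the host as follows. The actual cast to an
8-bit float is omitted, so this only shows the scale handling; using
scale 1.0 for an all-zero row is an assumption made here to avoid a
division by zero.

```rust
/// Largest finite magnitude representable in FP8 E4M3.
const FP8_E4M3_MAX: f32 = 448.0;

/// Per-row dynamic scale so that the row's absmax maps to the E4M3 maximum.
fn row_scale(row: &[f32]) -> f32 {
    let absmax = row.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    if absmax == 0.0 { 1.0 } else { absmax / FP8_E4M3_MAX }
}

/// Scale a row into the E4M3 range. Values are divided by the scale and
/// clamped; rounding to an 8-bit float is not modeled.
fn scale_row(row: &[f32]) -> (Vec<f32>, f32) {
    let s = row_scale(row);
    let scaled: Vec<f32> = row
        .iter()
        .map(|v| (v / s).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX))
        .collect();
    (scaled, s)
}

/// Post-GEMM rescale: multiply an output row by its activation scale.
fn rescale_row(out: &mut [f32], a_scale: f32) {
    for v in out.iter_mut() {
        *v *= a_scale;
    }
}
```

Since the scale is computed per row, one outlier token only costs
precision in its own row rather than across the whole activation
matrix.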
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Close the <think> block when EOS or max_tokens interrupts an analysis
channel, and flush stdout after each analysis token so --think streams
smoothly instead of dumping in buffer-sized chunks.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The gpt-oss harmony `analysis` channel is the model's reasoning, analogous
to Qwen3's <think>. With --think, wrap it in a `<think>\n…\n</think>\n\n`
block (gray when color is on, like Qwen3) and then print the final-channel
answer; without --think, suppress the analysis and show only the answer.
Replaces the previous color-gated behavior (analysis shown gray only on a
TTY, with no markers). Analysis is still excluded from the multi-turn
history (answer_ids), so re-prefill drops CoT as before.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Re-render the whole conversation each turn and re-prefill into a freshly
cleared slot, with past assistant messages rendered as completed `final`
channels (analysis dropped, terminated with <|end|> not the <|return|>
stop token) — matching the model's training format and the server's
builder. The previous incremental cache kept every turn's chain-of-thought
plus <|return|> in context, which is out of distribution for harmony
multi-turn. The generator now returns the final-channel text to feed back
as history. Qwen3 keeps the incremental cache (its ChatML format is
unaffected); reset_slot factors out the free+re-register.
NOTE: this corrects the multi-turn *format* but does NOT cure the
long-context collapse on some inputs. That is a forward-pass numerical bug
(NaN / degenerate logits) reproducible in clean bench-gpt-oss independent
of the chat layer — the collapse token is vocab_size-1 (201087), the
all-NaN argmax tie-break. Tracked separately.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
build_prompt_gpt_oss (the hardcoded builder used when a gpt-oss model
ships no Jinja chat template) emitted the same malformed "You are a
helpful assistant." system message that destabilized the CLI. Replace it
with the canonical harmony system message (identity / knowledge cutoff /
current date via strftime_now / Reasoning: low / channels), matching the
chat_template.jinja build_system_message macro and the xserv-chat fix.
Dormant for gpt-oss-20b (it ships a Jinja template, so the template path
runs), but correct now for any gpt-oss model without one.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The hand-rolled gpt-oss system message dropped the canonical harmony
structure (identity / knowledge cutoff / current date / Reasoning level),
putting the model out of distribution — greedy decoding then flipped into
garbage or analysis loops on ~half of single-turn requests. Emit the
canonical system message (matching the model's chat_template.jinja
build_system_message macro) with Reasoning: low, plus a today_ymd() date
helper.
Also:
- Default the repetition penalty to off (1.0). Penalizing the harmony
control tokens (<|channel|>/<|message|>/<|start|>) that must repeat to
open the final channel made gpt-oss stop right after analysis, emitting
nothing.
- Suppress the literal "assistant" role header emitted between the
analysis and final channels (only print in the final channel, moe only;
non-moe Qwen3 stays in Normal and prints as before).
Verified on dash5 (TP=2): single-turn "capital of France" is now stable
across runs with a clean final answer; Qwen3 chat unaffected.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Let the model generate its own <|channel|> routing instead of forcing
<|channel|>final<|message|> — matches the GGUF chat template behavior.
- State machine tracks harmony channels: analysis channel rendered gray,
final channel printed normally, <|end|> stops on final channel only.
- Add repetition penalty (default 1.3 for MoE, 1.0 for Qwen) with 512
token window to prevent greedy decode loops. Configurable via
XSERV_REP_PENALTY and XSERV_REP_WINDOW env vars.
- Fix Length path: use <|end|> instead of <|im_end|> for gpt-oss to
avoid poisoning the KV cache with garbage tokens on truncation.
- Server api.rs: append <|channel|>final<|message|> to the hardcoded
gpt-oss prompt (server expects to post-process the JSON output).
- Add startswith filter to minijinja for harmony template compatibility.
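A minimal sketch of the channel tracking described in this list, assuming
tokens arrive already decoded to strings. HarmonyState, Channel and Action
are hypothetical names for illustration, not the actual xserv-chat types:

```rust
/// Which harmony channel the decoder is currently inside.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Channel {
    Unknown,
    Analysis,
    Final,
}

/// What the chat loop should do with one decoded token.
#[derive(Debug, PartialEq)]
enum Action {
    Skip,
    PrintGray(String),
    Print(String),
    Stop,
}

/// Hypothetical tracker; the real state machine may differ.
struct HarmonyState {
    channel: Channel,
    reading_name: bool,
    name: String,
}

impl HarmonyState {
    fn new() -> Self {
        HarmonyState { channel: Channel::Unknown, reading_name: false, name: String::new() }
    }

    fn step(&mut self, tok: &str) -> Action {
        match tok {
            // The channel name follows <|channel|> and ends at <|message|>.
            "<|channel|>" => {
                self.reading_name = true;
                self.name.clear();
                Action::Skip
            }
            "<|message|>" => {
                if self.reading_name {
                    self.channel = match self.name.trim() {
                        "final" => Channel::Final,
                        "analysis" => Channel::Analysis,
                        _ => Channel::Unknown,
                    };
                    self.reading_name = false;
                }
                Action::Skip
            }
            // <|end|> stops generation on the final channel only.
            "<|end|>" => {
                if self.channel == Channel::Final {
                    Action::Stop
                } else {
                    self.channel = Channel::Unknown;
                    Action::Skip
                }
            }
            "<|start|>" => Action::Skip,
            _ if self.reading_name => {
                self.name.push_str(tok);
                Action::Skip
            }
            _ => match self.channel {
                Channel::Final => Action::Print(tok.to_string()),
                Channel::Analysis => Action::PrintGray(tok.to_string()),
                Channel::Unknown => Action::Skip,
            },
        }
    }
}
```

In this sketch, text seen outside any channel (such as a role header after
<|start|>) is skipped rather than printed; that is a choice made for the
example, not a statement about what this commit does.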
Known issue: gpt-oss multi-turn chat produces NaN once the total context
exceeds ~256 tokens, likely a flash_attention_sinks kernel bug with
sliding-window layers at large kv_len + small q_len. Single-turn and
short multi-turn conversations work correctly.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The gpt-oss harmony format generates internal control tokens
(<|channel|>, <|start|>, <|end|>, <|message|>) that should not appear
in the user-facing output. Additionally, <|end|> marks the end of a
response segment but was not in the model's EOS list, causing the
model to self-prompt into analysis channels and loop.
Fix: treat <|end|> as a stop token and strip all harmony special tokens
from the output stream.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add ChatModel enum dispatching between Qwen3 and GptOss based on
config.is_moe(), following the TP engine pattern.
- Add --tp N flag for tensor-parallel inference (required for the 39 GB
gpt-oss-20b, which doesn't fit on a single 32 GB GPU).
- Add gpt-oss harmony chat template with channel/message format.
- Replace hardcoded is_stop_token() with tokenizer.is_eos() for
multi-model EOS support.
- Restore gpt-oss hardcoded prompt template in server api.rs, lost
during the Jinja template refactor.
- Fix GEMV race condition: the K-split kernel zeroed the FP32
accumulator inside the kernel (block k=0) while other blocks
atomicAdd'd concurrently. Pre-zero with cudaMemsetAsync instead.
- Update benchmark docs with post-fix results.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Load the model's chat_template.jinja (or tokenizer_config.json
chat_template field) at startup and render it with minijinja instead of
hardcoded per-model prompt builders.
Custom Jinja functions: strftime_now (date formatting), raise_exception
(template validation errors). Falls back to Qwen3 ChatML template if
no Jinja template is found.
Removes the hardcoded build_prompt_gpt_oss() — the model's own template
now drives prompt formatting, matching llama.cpp's behavior exactly.
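For reference, a minimal sketch of what a plain ChatML fallback renders to,
assuming the standard <|im_start|>/<|im_end|> framing and none of the
Qwen3-specific extras. Message and render_chatml are hypothetical names;
the actual fallback is a template string rendered by minijinja:

```rust
/// One chat turn; a hypothetical stand-in for the server's message type.
struct Message<'a> {
    role: &'a str,
    content: &'a str,
}

/// Render messages in plain ChatML framing (illustrative only).
fn render_chatml(messages: &[Message<'_>], add_generation_prompt: bool) -> String {
    let mut out = String::new();
    for m in messages {
        out.push_str("<|im_start|>");
        out.push_str(m.role);
        out.push('\n');
        out.push_str(m.content);
        out.push_str("<|im_end|>\n");
    }
    // Open the assistant turn so the model continues from here.
    if add_generation_prompt {
        out.push_str("<|im_start|>assistant\n");
    }
    out
}
```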
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace the per-token CPU-routed MoE forward with an all-GPU path:
1. moe_topk_softmax: GPU top-k + softmax (was CPU sort + softmax)
2. moe_replicate: broadcast input to all local experts
3. cublasGemmStridedBatchedEx: batched expert matmul (was per-expert cuBLAS)
4. moe_weighted_sum: FP32-accumulated weighted sum on GPU (was GPU→CPU→F32→BF16→GPU)
Expert weights stored as contiguous 3D tensors for strided batched GEMM.
Zero CPU↔GPU transfers per MoE layer (was ~40 per token per layer).
Also: configurable geglu_alpha, LayerNorm bias auto-detect, unused-weight
diagnostic at load time.
GSM8K 30-problem: 11/30 → 23/30 (76.7%) vs llama.cpp 30/30 (100%).
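A CPU reference sketch of steps 1 and 4 for a single token, useful as a
mental model of what the GPU kernels compute. It assumes the softmax is
taken over the selected top-k router logits only; topk_softmax and
weighted_sum are hypothetical names, not the kernel entry points:

```rust
/// Pick the k largest router logits and softmax over just those.
/// Returns (expert index, routing weight) pairs.
fn topk_softmax(router_logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut idx: Vec<usize> = (0..router_logits.len()).collect();
    // Stable sort, descending by logit.
    idx.sort_by(|&a, &b| router_logits[b].total_cmp(&router_logits[a]));
    idx.truncate(k);
    let max = idx
        .iter()
        .map(|&i| router_logits[i])
        .fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = idx.iter().map(|&i| (router_logits[i] - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    idx.into_iter().zip(exps).map(|(i, e)| (i, e / sum)).collect()
}

/// Combine per-expert outputs with FP32 accumulation.
fn weighted_sum(expert_out: &[Vec<f32>], routes: &[(usize, f32)]) -> Vec<f32> {
    let hidden = expert_out.first().map_or(0, |v| v.len());
    let mut acc = vec![0.0f32; hidden];
    for &(e, w) in routes {
        for (a, &x) in acc.iter_mut().zip(&expert_out[e]) {
            *a += w * x;
        }
    }
    acc
}
```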
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Parse the model's `pre_tokenizer` section to extract its Split regex
instead of hardcoding the GPT-2 pattern. The gpt-oss-20b model uses
a GPT-4-style regex that produces different word boundaries, causing a
1-token prompt mismatch vs HuggingFace (136 → 135 tokens, now aligned).
Unsupported lookahead `(?!\S)` is stripped — it only affects trailing
whitespace edge cases. Falls back to the old GPT-2/Qwen heuristic if
the model regex fails to compile or is absent.
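The stripping step can be sketched as a plain string replacement applied
before the pattern is compiled. strip_unsupported_lookahead is a
hypothetical helper name, and the pattern fragment in the usage note is
only an illustration:

```rust
/// Remove the negative lookahead `(?!\S)` that the regex engine rejects.
/// Everything else in the pattern is left untouched.
fn strip_unsupported_lookahead(pattern: &str) -> String {
    pattern.replace(r"(?!\S)", "")
}
```

For example, a fragment such as `\s+(?!\S)|\s+` becomes `\s+|\s+`, which
matches the same text except where a whitespace run is followed directly
by a non-whitespace character.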
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add --prompt to override the fixed prompt, and two teacher-forced
diagnostics: --forced runs prefill over prompt+oracle ids and reports
per-position top-1 agreement; --forced-decode walks the oracle trajectory
through the decode path with per-position agreement bucketed by position,
to localize long-context decode divergence from the reference.
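A minimal sketch of the bucketed agreement metric, assuming the model's
top-1 ids and the oracle ids are already aligned per position.
bucketed_agreement is a hypothetical helper, not the tool's actual code:

```rust
/// Top-1 agreement with the oracle, grouped into fixed-size position
/// buckets. Returns (bucket start position, agreement rate) pairs.
fn bucketed_agreement(pred: &[u32], oracle: &[u32], bucket: usize) -> Vec<(usize, f32)> {
    assert!(bucket > 0, "bucket size must be positive");
    let n = pred.len().min(oracle.len());
    let mut out = Vec::new();
    let mut start = 0;
    while start < n {
        let end = (start + bucket).min(n);
        let hits = (start..end).filter(|&i| pred[i] == oracle[i]).count();
        out.push((start, hits as f32 / (end - start) as f32));
        start = end;
    }
    out
}
```

A drop in agreement that appears only in later buckets points at the
long-context decode path rather than at prefill.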
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Route caller-supplied system messages into a harmony 'developer'
instructions block (<|start|>developer<|message|># Instructions...),
keeping the fixed system/meta block for the channel declaration. Harmony
puts user instructions on the developer role, not system.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Use tokenizer.is_eos() (multi-eos) for generation termination in both PP
and TP engines instead of a single eos id, so gpt-oss stops on
<|return|>, <|call|>, or <|endoftext|>.
In the TP engine, optionally apply a repetition penalty on the greedy
decode path (XSERV_REP_PENALTY>1 over XSERV_REP_WINDOW recent tokens; off
by default) to break greedy repetition loops.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Make argmax skip NaN logits (warn once) instead of panicking the engine
thread on a single NaN. Add sample_greedy_penalized() applying an
HF-style repetition penalty over recent ids on the greedy path, to break
greedy repetition loops on reasoning models without touching the forward
pass.
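A minimal sketch of the two pieces, assuming logits are a plain f32 slice
and the penalty follows the usual HF rule (positive logits divided by the
penalty, negative ones multiplied). The signatures are illustrative and
the warn-once logging is omitted:

```rust
use std::collections::HashSet;

/// Argmax that ignores NaN entries. Returns None if every logit is NaN.
fn argmax_skip_nan(logits: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in logits.iter().enumerate() {
        if v.is_nan() {
            continue;
        }
        // Strict comparison keeps the first index on ties.
        if best.map_or(true, |(_, b)| v > b) {
            best = Some((i, v));
        }
    }
    best.map(|(i, _)| i)
}

/// Apply the penalty once per distinct recent id, then take the argmax.
fn sample_greedy_penalized(logits: &mut [f32], recent: &[u32], penalty: f32) -> Option<usize> {
    if penalty != 1.0 {
        let mut seen = HashSet::new();
        for &id in recent {
            let i = id as usize;
            if i < logits.len() && seen.insert(i) {
                let l = logits[i];
                logits[i] = if l > 0.0 { l / penalty } else { l * penalty };
            }
        }
    }
    argmax_skip_nan(logits)
}
```

With a penalty of 1.0 the logits are left untouched, so the result is the
same as plain NaN-skipping greedy argmax.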
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Track an ordered eos_token_ids list (not just one id) and add is_eos().
gpt-oss/harmony ends the assistant turn on <|return|> and also treats
<|call|> and <|endoftext|> as terminators (generation_config.json
eos_token_id = [200002, 199999, 200012]); single-eos families are
unchanged.
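A minimal sketch of the multi-eos check, assuming the ids are kept in the
order they appear in generation_config.json. EosSet is a hypothetical
stand-in for the tokenizer field:

```rust
/// Ordered list of terminator ids; the first one is the primary eos.
struct EosSet {
    eos_token_ids: Vec<u32>,
}

impl EosSet {
    /// True if `id` is any of the configured terminators.
    fn is_eos(&self, id: u32) -> bool {
        self.eos_token_ids.contains(&id)
    }
}
```

A single-eos family is the same structure with a one-element list, so
callers need no special case.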
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Add flash_attention_sinks_bf16 prefill kernel that folds the per-head
attention sink into the softmax denominator (exactly as the decode sink
kernel does) and supports an optional sliding-window mask matching HF
gpt-oss.
Wire it through xserv-kernels (flash_attention_sinks) and use it in
GptOss prefill, replacing the post-hoc sink approximation for an exact
match against the reference math.
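A scalar reference sketch of the sink-aware softmax for one query row,
assuming the sink is a single per-head logit that joins the denominator
but has no value vector. softmax_with_sink is a hypothetical name, and the
sliding-window mask is not shown:

```rust
/// Softmax over `scores` with an extra sink logit in the denominator.
/// The returned probabilities sum to less than 1; the remainder is the
/// mass absorbed by the sink.
fn softmax_with_sink(scores: &[f32], sink: f32) -> Vec<f32> {
    // The running max includes the sink so every exponent is <= 0.
    let m = scores.iter().copied().fold(sink, f32::max);
    let denom: f32 = (sink - m).exp() + scores.iter().map(|&s| (s - m).exp()).sum::<f32>();
    scores.iter().map(|&s| (s - m).exp() / denom).collect()
}
```

With the sink set to negative infinity, this reduces to an ordinary
softmax.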
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The gpt-oss model requires a specific prompt format with <|start|>,
<|message|>, <|end|>, <|channel|> tokens. Without this, the model
produces degenerate output. Auto-detected via config.model_type.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- tp_engine.rs: TpModel enum dispatches between Qwen3 and GptOss based on
config.is_moe(). Server auto-detects model type on startup.
- tools/run_gpt_oss_bench.sh: one-click benchmark comparing xserv (TP=2)
vs llama.cpp (BF16 GGUF) on GSM8K quality + speed
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The custom launch_gemv_bf16 kernel produces NaN when output dimension N
is small (e.g. N=32 for the MoE router). Fall back to cuBLAS GemmEx for
N < 256. Also removes the padding workaround in gpt_oss MoE forward.
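The dispatch rule amounts to a threshold on the output dimension. A
minimal sketch, with GemvBackend and pick_gemv_backend as hypothetical
names rather than the actual xserv-kernels API:

```rust
/// Which implementation handles an m=1 projection.
#[derive(Debug, PartialEq)]
enum GemvBackend {
    CustomGemv,
    CublasGemmEx,
}

/// Small output dimensions go to cuBLAS; the threshold is from this commit.
fn pick_gemv_backend(n: usize) -> GemvBackend {
    if n < 256 {
        GemvBackend::CublasGemmEx
    } else {
        GemvBackend::CustomGemv
    }
}
```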
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>