Compare commits
4 Commits
a67753f516
...
fcf531a9b2
| Author | SHA1 | Date | |
|---|---|---|---|
| fcf531a9b2 | |||
| d96ee0766c | |||
| ce10e4a998 | |||
| 5f060902f6 |
@@ -40,7 +40,7 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
let logits_cpu = logits.to_device(Device::Cpu);
|
||||
|
||||
// Extract last row as f32
|
||||
let last_row: Vec<f32> = match logits.dtype() {
|
||||
let mut last_row: Vec<f32> = match logits.dtype() {
|
||||
DType::F32 => {
|
||||
let data = logits_cpu.as_slice::<f32>();
|
||||
data[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec()
|
||||
@@ -60,6 +60,20 @@ pub fn sample(logits: &Tensor, params: &SamplingParams) -> u32 {
|
||||
return argmax(&last_row);
|
||||
}
|
||||
|
||||
// NaN-safe: sampling path uses partial_cmp().unwrap() in top-k/top-p
|
||||
// sorts and softmax; a single NaN logit would panic the engine thread.
|
||||
// Replace NaN with -inf (equivalent to masking) instead.
|
||||
let mut nan_seen = false;
|
||||
for v in last_row.iter_mut() {
|
||||
if v.is_nan() {
|
||||
nan_seen = true;
|
||||
*v = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
if nan_seen {
|
||||
eprintln!("[sampling] WARNING: NaN logits encountered in sample()");
|
||||
}
|
||||
|
||||
// Apply temperature
|
||||
let mut logits_f32: Vec<f32> = last_row.iter().map(|v| v / params.temperature).collect();
|
||||
|
||||
|
||||
@@ -331,6 +331,10 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
}
|
||||
}
|
||||
|
||||
let fr_value = match normalize_finish_reason(&finish_reason) {
|
||||
Some(s) => serde_json::Value::String(s.to_string()),
|
||||
None => serde_json::Value::Null,
|
||||
};
|
||||
Json(serde_json::json!({
|
||||
"id": id,
|
||||
"object": "chat.completion",
|
||||
@@ -339,7 +343,7 @@ async fn chat_non_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": content },
|
||||
"finish_reason": finish_reason,
|
||||
"finish_reason": fr_value,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_token_count,
|
||||
@@ -412,8 +416,11 @@ fn chat_stream(state: Arc<AppState>, req: ChatRequest) -> Response {
|
||||
make_chunk(&id, &model_name, created, None, Some("assistant"), None);
|
||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||
}
|
||||
let chunk =
|
||||
make_chunk(&id, &model_name, created, None, None, Some(&finish_reason));
|
||||
// Only "stop" and "length" are OpenAI-standard values. Internal
|
||||
// codes like "error" (client-stalled from tp/pp engine) map to
|
||||
// null so SDK clients see a clean stream close.
|
||||
let fr = normalize_finish_reason(&finish_reason);
|
||||
let chunk = make_chunk(&id, &model_name, created, None, None, fr);
|
||||
let _ = sse_tx.send(Ok(Event::default().data(chunk))).await;
|
||||
let _ = sse_tx
|
||||
.send(Ok(Event::default().data("[DONE]".to_string())))
|
||||
@@ -442,6 +449,22 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option<Response> {
|
||||
return Some(bad_request("max_tokens must be greater than 0"));
|
||||
}
|
||||
|
||||
if let Some(t) = req.temperature {
|
||||
if !t.is_finite() || t < 0.0 {
|
||||
return Some(bad_request("temperature must be a finite value >= 0"));
|
||||
}
|
||||
}
|
||||
if let Some(p) = req.top_p {
|
||||
if !p.is_finite() || !(0.0..=1.0).contains(&p) {
|
||||
return Some(bad_request("top_p must be in [0, 1]"));
|
||||
}
|
||||
}
|
||||
if let Some(k) = req.top_k {
|
||||
if k > 1_000_000 {
|
||||
return Some(bad_request("top_k must be <= 1_000_000"));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
@@ -453,9 +476,14 @@ fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Respon
|
||||
.engine_sender
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
sender
|
||||
.send(req)
|
||||
.map_err(|_| service_unavailable("inference engine is not available"))
|
||||
sender.try_send(req).map_err(|err| match err {
|
||||
std::sync::mpsc::TrySendError::Full(_) => {
|
||||
service_unavailable("inference engine is busy, retry later")
|
||||
}
|
||||
std::sync::mpsc::TrySendError::Disconnected(_) => {
|
||||
service_unavailable("inference engine is not available")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn service_unavailable(message: impl Into<String>) -> Response {
|
||||
@@ -532,3 +560,14 @@ fn sampling_params(req: &ChatRequest) -> SamplingParams {
|
||||
top_p: req.top_p.unwrap_or(1.0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Map engine finish_reason strings to OpenAI-standard values. Any engine-internal
|
||||
/// code (e.g. "error" from tp/pp client-stall) collapses to None so SDK clients see
|
||||
/// a clean null instead of an unknown value.
|
||||
fn normalize_finish_reason(fr: &str) -> Option<&'static str> {
|
||||
match fr {
|
||||
"stop" => Some("stop"),
|
||||
"length" => Some("length"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,9 +396,12 @@ fn emit_token(tokenizer: &Tokenizer, seq: &mut Sequence, token_id: u32) {
|
||||
if tokenizer.eos_token_id() == Some(token_id) {
|
||||
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
|
||||
send_token_if_nonempty(seq, tail);
|
||||
try_send_event(seq, GenerateEvent::Done {
|
||||
finish_reason: "stop".to_string(),
|
||||
});
|
||||
try_send_event(
|
||||
seq,
|
||||
GenerateEvent::Done {
|
||||
finish_reason: "stop".to_string(),
|
||||
},
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -407,9 +410,12 @@ fn emit_token(tokenizer: &Tokenizer, seq: &mut Sequence, token_id: u32) {
|
||||
let tail = tokenizer.flush_decode_stream(&mut seq.decode_buffer);
|
||||
send_token_if_nonempty(seq, text);
|
||||
send_token_if_nonempty(seq, tail);
|
||||
try_send_event(seq, GenerateEvent::Done {
|
||||
finish_reason: "length".to_string(),
|
||||
});
|
||||
try_send_event(
|
||||
seq,
|
||||
GenerateEvent::Done {
|
||||
finish_reason: "length".to_string(),
|
||||
},
|
||||
);
|
||||
} else {
|
||||
send_token_if_nonempty(seq, text);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ mod tp_engine;
|
||||
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
extract::DefaultBodyLimit,
|
||||
routing::{get, post},
|
||||
};
|
||||
use engine::GenerateRequest;
|
||||
@@ -15,7 +16,7 @@ use xserv_model::ModelConfig;
|
||||
pub struct AppState {
|
||||
pub model_name: String,
|
||||
pub chat_template: api::ChatTemplate,
|
||||
pub engine_sender: Mutex<mpsc::Sender<GenerateRequest>>,
|
||||
pub engine_sender: Mutex<mpsc::SyncSender<GenerateRequest>>,
|
||||
pub engine_tokenizer: Mutex<xserv_tokenizer::Tokenizer>,
|
||||
pub max_seq_len: usize,
|
||||
}
|
||||
@@ -104,8 +105,10 @@ async fn main() {
|
||||
|
||||
let tokenizer = xserv_tokenizer::Tokenizer::from_file(&model_dir.join("tokenizer.json"));
|
||||
|
||||
// Unbounded channel: allows multiple requests to queue up
|
||||
let (tx, rx) = mpsc::channel::<GenerateRequest>();
|
||||
// Bounded channel to backpressure incoming requests when the engine falls
|
||||
// behind, instead of letting them pile up in RAM. try_send in the API
|
||||
// handler surfaces this as 503 to the client.
|
||||
let (tx, rx) = mpsc::sync_channel::<GenerateRequest>(256);
|
||||
|
||||
let model_dir_clone = model_dir.clone();
|
||||
std::thread::spawn(move || {
|
||||
@@ -140,6 +143,7 @@ async fn main() {
|
||||
.route("/health", get(api::health))
|
||||
.route("/v1/models", get(api::list_models))
|
||||
.route("/v1/chat/completions", post(api::chat_completions))
|
||||
.layer(DefaultBodyLimit::max(4 * 1024 * 1024))
|
||||
.layer(Extension(state));
|
||||
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
|
||||
@@ -318,7 +318,12 @@ pub fn run_pp(
|
||||
/// Returns false if the send would block (client too slow) or the client is
|
||||
/// gone — the caller stops generating so the coordinator thread is free to
|
||||
/// admit the next request instead of blocking on one slow consumer.
|
||||
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &mut Vec<u8>) -> bool {
|
||||
fn emit_text(
|
||||
tokenizer: &Tokenizer,
|
||||
req: &GenerateRequest,
|
||||
token_id: u32,
|
||||
buf: &mut Vec<u8>,
|
||||
) -> bool {
|
||||
if tokenizer.is_eos(token_id) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -346,7 +346,12 @@ pub fn run_tp(
|
||||
/// Returns false if the send would block (client too slow) or the client is
|
||||
/// gone — the caller stops generating so the serial coordinator thread is free
|
||||
/// to admit the next request instead of blocking on one slow consumer.
|
||||
fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: &mut Vec<u8>) -> bool {
|
||||
fn emit_text(
|
||||
tokenizer: &Tokenizer,
|
||||
req: &GenerateRequest,
|
||||
token_id: u32,
|
||||
buf: &mut Vec<u8>,
|
||||
) -> bool {
|
||||
if tokenizer.is_eos(token_id) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,10 @@ __global__ void causal_mask_f32(
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
scores[batch_idx * rows * cols + row * cols + col] = -INFINITY;
|
||||
// 64-bit index: batch * rows * cols overflows int32 at moderate batch
|
||||
// and long context (e.g. batch=128 * heads=28 * seq=32768).
|
||||
long long idx = ((long long)batch_idx * rows + row) * cols + col;
|
||||
scores[idx] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +31,8 @@ __global__ void causal_mask_bf16(
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (col < cols && col > row + offset) {
|
||||
scores[batch_idx * rows * cols + row * cols + col] = __float2bfloat16(-INFINITY);
|
||||
long long idx = ((long long)batch_idx * rows + row) * cols + col;
|
||||
scores[idx] = __float2bfloat16(-INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -464,7 +464,7 @@ __global__ void decode_attention_bf16_kernel(
|
||||
// Shared memory for reduction
|
||||
__shared__ float smem_max[32]; // one per warp
|
||||
__shared__ float smem_sum[32];
|
||||
__shared__ float smem_O[HEAD_DIM_MAX]; // final output accumulator
|
||||
__shared__ float smem_O_warp[32][HEAD_DIM_MAX];
|
||||
|
||||
// Step 1: Block-wide max reduction
|
||||
int lane = tid & 31;
|
||||
@@ -513,35 +513,30 @@ __global__ void decode_attention_bf16_kernel(
|
||||
__syncthreads();
|
||||
global_sum = smem_sum[0];
|
||||
|
||||
// Step 4: Reduce O across block (dimension by dimension using shared mem)
|
||||
// Step 4: Reduce O across block, dim by dim. Store one partial per warp
|
||||
// and sum in warp-id order; atomicAdd made greedy decode nondeterministic
|
||||
// when logits were close (same fix pattern as paged_attention.cu / gemv.cu).
|
||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||
|
||||
// Process head_dim in chunks: each iteration reduces one dimension
|
||||
// Use shared memory accumulator: each warp contributes via warp reduction + atomic
|
||||
// Actually simpler: iterate over dimensions, warp reduce each, then lane0 atomicAdd to smem_O
|
||||
|
||||
// Initialize smem_O
|
||||
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
|
||||
smem_O[d] = 0.0f;
|
||||
for (int i = tid; i < 32 * HEAD_DIM_MAX; i += DECODE_THREADS) {
|
||||
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Each thread adds its local_O contributions via warp reduction + atomicAdd
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
float val = local_O[d];
|
||||
// Warp-level reduction
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
if (lane == 0) {
|
||||
atomicAdd(&smem_O[d], val);
|
||||
}
|
||||
if (lane == 0) smem_O_warp[warp_id][d] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Thread 0..head_dim-1 write final output
|
||||
for (int d = tid; d < head_dim; d += DECODE_THREADS) {
|
||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
|
||||
O_ptr[d] = __float2bfloat16(out * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,12 +16,14 @@ __global__ void dequant_fp8e4m3_to_bf16_kernel(
|
||||
__nv_bfloat16* __restrict__ dst,
|
||||
int num_experts, int rows, int cols
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = num_experts * rows * cols;
|
||||
// 64-bit index: num_experts * rows * cols overflows int32 for 32 experts
|
||||
// at ~8k*8k weight matrices, same class as the MoE fix in cfbd64d.
|
||||
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
|
||||
long long total = (long long)num_experts * rows * cols;
|
||||
if (idx >= total) return;
|
||||
|
||||
int expert_stride = rows * cols;
|
||||
int expert = idx / expert_stride;
|
||||
long long expert_stride = (long long)rows * cols;
|
||||
int expert = (int)(idx / expert_stride);
|
||||
float scale = scales[expert];
|
||||
float val = float(src[idx]) * scale;
|
||||
dst[idx] = __float2bfloat16(val);
|
||||
@@ -36,9 +38,9 @@ void launch_dequant_fp8e4m3_to_bf16(
|
||||
int num_experts, int rows, int cols,
|
||||
void* stream
|
||||
) {
|
||||
int total = num_experts * rows * cols;
|
||||
long long total = (long long)num_experts * rows * cols;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
int grid = (int)((total + block - 1) / block);
|
||||
dequant_fp8e4m3_to_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_fp8_e4m3*)src,
|
||||
(const float*)scales,
|
||||
|
||||
Reference in New Issue
Block a user