#include <cstddef>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include "../common.cuh"

// Embedding lookup: out[token_idx, :] = table[token_ids[token_idx], :]
//
// Expected launch layout: one block per token (gridDim.x == num_tokens, as set
// by the launchers below). The threads of a block loop over the hidden_size
// elements of that token's row with stride blockDim.x, so any blockDim.x >= 1
// produces a complete row.
//
// Token ids outside [0, vocab_size) yield an all-zero output row rather than
// leaving that row unwritten (uninitialized).
//
// Row offsets are computed in size_t so that vocab_size * hidden_size and
// num_tokens * hidden_size may exceed INT_MAX without integer overflow.
//
// Parameters:
//   table      [vocab_size, hidden_size] row-major embedding table (device ptr)
//   token_ids  [num_tokens] token ids (device ptr)
//   out        [num_tokens, hidden_size] output (device ptr)
//   zero       value written for out-of-range token ids
template <typename T>
__device__ __forceinline__ void embedding_row_copy(
    const T* __restrict__ table,
    const int* __restrict__ token_ids,
    T* __restrict__ out,
    int hidden_size,
    int vocab_size,
    T zero
) {
    const int token_idx = blockIdx.x;
    const int token = token_ids[token_idx];  // same address for the whole block
    T* dst = out + static_cast<size_t>(token_idx) * static_cast<size_t>(hidden_size);

    if (token < 0 || token >= vocab_size) {
        // Invalid id: write a deterministic zero row instead of leaving garbage.
        for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
            dst[i] = zero;
        }
        return;
    }

    const T* row = table + static_cast<size_t>(token) * static_cast<size_t>(hidden_size);
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        dst[i] = row[i];
    }
}

// fp32 embedding lookup; see embedding_row_copy for layout and semantics.
__global__ void embedding_f32(
    const float* __restrict__ table,      // [vocab_size, hidden_size]
    const int* __restrict__ token_ids,    // [num_tokens]
    float* __restrict__ out,              // [num_tokens, hidden_size]
    int hidden_size,
    int vocab_size
) {
    embedding_row_copy<float>(table, token_ids, out, hidden_size, vocab_size, 0.0f);
}

// bf16 embedding lookup; see embedding_row_copy for layout and semantics.
__global__ void embedding_bf16(
    const __nv_bfloat16* __restrict__ table,   // [vocab_size, hidden_size]
    const int* __restrict__ token_ids,         // [num_tokens]
    __nv_bfloat16* __restrict__ out,           // [num_tokens, hidden_size]
    int hidden_size,
    int vocab_size
) {
    embedding_row_copy<__nv_bfloat16>(table, token_ids, out, hidden_size, vocab_size,
                                      __float2bfloat16(0.0f));
}

// Threads per block: one thread per element up to a cap of 256.
// Callers must ensure hidden_size > 0 (a 0-thread block is an invalid launch).
static int embedding_block_size(int hidden_size) {
    return (hidden_size < 256) ? hidden_size : 256;
}

extern "C" {

// Host launcher for embedding_f32.
//   stream: a cudaStream_t passed as void* (nullptr selects the default stream).
// Empty inputs (num_tokens <= 0 or hidden_size <= 0) are a no-op, since a
// zero-sized grid/block would otherwise be reported as a launch error.
// CUDA_CHECK_LAST_ERROR comes from common.cuh; presumably it checks
// cudaGetLastError() for launch-configuration errors — confirm there.
void launch_embedding_f32(const void* table, const void* token_ids, void* out,
                          int num_tokens, int hidden_size, int vocab_size,
                          void* stream) {
    if (num_tokens <= 0 || hidden_size <= 0) return;
    const int block = embedding_block_size(hidden_size);
    embedding_f32<<<num_tokens, block, 0, static_cast<cudaStream_t>(stream)>>>(
        static_cast<const float*>(table),
        static_cast<const int*>(token_ids),
        static_cast<float*>(out),
        hidden_size, vocab_size);
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for embedding_bf16; same contract as launch_embedding_f32.
void launch_embedding_bf16(const void* table, const void* token_ids, void* out,
                           int num_tokens, int hidden_size, int vocab_size,
                           void* stream) {
    if (num_tokens <= 0 || hidden_size <= 0) return;
    const int block = embedding_block_size(hidden_size);
    embedding_bf16<<<num_tokens, block, 0, static_cast<cudaStream_t>(stream)>>>(
        static_cast<const __nv_bfloat16*>(table),
        static_cast<const int*>(token_ids),
        static_cast<__nv_bfloat16*>(out),
        hidden_size, vocab_size);
    CUDA_CHECK_LAST_ERROR();
}

}  // extern "C"