Initial project scaffold
This commit is contained in:
75
kernels/triton/flash_attention_fwd.py
Normal file
75
kernels/triton/flash_attention_fwd.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
except ImportError: # pragma: no cover - depends on local environment
|
||||
triton = None
|
||||
tl = None
|
||||
|
||||
|
||||
TRITON_AVAILABLE = triton is not None
|
||||
|
||||
|
||||
if TRITON_AVAILABLE:
|
||||
|
||||
@triton.jit
|
||||
def flash_attention_fwd_kernel(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
out_ptr,
|
||||
seq_len,
|
||||
head_dim,
|
||||
stride_q_batch,
|
||||
stride_q_head,
|
||||
stride_q_seq,
|
||||
stride_q_dim,
|
||||
stride_k_batch,
|
||||
stride_k_head,
|
||||
stride_k_seq,
|
||||
stride_k_dim,
|
||||
stride_v_batch,
|
||||
stride_v_head,
|
||||
stride_v_seq,
|
||||
stride_v_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_seq,
|
||||
stride_out_dim,
|
||||
causal,
|
||||
block_q: tl.constexpr,
|
||||
block_k: tl.constexpr,
|
||||
block_d: tl.constexpr,
|
||||
):
|
||||
pid_q = tl.program_id(axis=0)
|
||||
pid_bh = tl.program_id(axis=1)
|
||||
# TODO(student): map pid_q and pid_bh to a batch/head/query tile.
|
||||
# TODO(student): load Q, K, and V blocks.
|
||||
# TODO(student): compute scores for the current block pair.
|
||||
# TODO(student): apply optional causal masking.
|
||||
# TODO(student): update online softmax state and accumulate the output block.
|
||||
# TODO(student): store the final output tile.
|
||||
pass
|
||||
|
||||
|
||||
def triton_flash_attention_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
causal: bool = False,
|
||||
block_q: int = 64,
|
||||
block_k: int = 64,
|
||||
) -> torch.Tensor:
|
||||
if not TRITON_AVAILABLE:
|
||||
raise RuntimeError("Triton is not installed in this environment.")
|
||||
if q.shape != k.shape or q.shape != v.shape:
|
||||
raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
|
||||
if q.ndim != 4:
|
||||
raise ValueError("expected [batch, heads, seq, dim] inputs")
|
||||
if not q.is_cuda or not k.is_cuda or not v.is_cuda:
|
||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||
raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")
|
||||
|
||||
Reference in New Issue
Block a user