from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def flash_attention_fwd_kernel(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        seq_len,
        head_dim,
        stride_q_batch,
        stride_q_head,
        stride_q_seq,
        stride_q_dim,
        stride_k_batch,
        stride_k_head,
        stride_k_seq,
        stride_k_dim,
        stride_v_batch,
        stride_v_head,
        stride_v_seq,
        stride_v_dim,
        stride_out_batch,
        stride_out_head,
        stride_out_seq,
        stride_out_dim,
        causal,
        block_q: tl.constexpr,
        block_k: tl.constexpr,
        block_d: tl.constexpr,
    ):
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)
        # TODO(student): map pid_q and pid_bh to a batch/head/query tile.
        # TODO(student): load Q, K, and V blocks.
        # TODO(student): compute scores for the current block pair.
        # TODO(student): apply optional causal masking.
        # TODO(student): update online softmax state and accumulate the output block.
        # TODO(student): store the final output tile.
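        #
        # Reference recurrence for the TODOs above (a sketch of the standard
        # online softmax, not the graded solution; `scale` is assumed to be
        # 1/sqrt(head_dim)). For each K/V block j:
        #     S_j   = (Q_tile @ K_j^T) * scale
        #     m_new = maximum(m, rowmax(S_j))
        #     p     = exp(S_j - m_new)
        #     l     = l * exp(m - m_new) + rowsum(p)
        #     acc   = acc * exp(m - m_new) + p @ V_j
        #     m     = m_new
        # After the last block, the output tile is acc / l.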
        pass


def triton_flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_q: int = 64,
    block_k: int = 64,
) -> torch.Tensor:
    """Validate inputs before launching the FlashAttention forward kernel."""
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(f"q, k, v shapes must match; got {q.shape}, {k.shape}, {v.shape}")
    if q.ndim != 4:
        raise ValueError("expected [batch, heads, seq, dim] inputs")
    if not q.is_cuda or not k.is_cuda or not v.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")
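

# A minimal pure-PyTorch reference for checking the Triton kernel's output
# (a sketch; the helper name, and the assumption that the lab targets standard
# scaled dot-product attention, are mine rather than part of the handout):
def reference_attention_fwd(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
) -> torch.Tensor:
    # Scores for [batch, heads, seq, dim] inputs, scaled by 1/sqrt(head_dim).
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if causal:
        # Mask out keys that lie strictly after each query position.
        seq_q, seq_k = scores.shape[-2:]
        mask = torch.triu(
            torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.einsum("bhqk,bhkd->bhqd", scores.softmax(dim=-1), v)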