"""Lab skeleton: a FlashAttention forward pass in Triton.

Students fill in the TODO blocks in ``flash_attention_fwd_kernel`` and the
launch logic in ``triton_flash_attention_fwd``.
"""

from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def flash_attention_fwd_kernel(
        q_ptr, k_ptr, v_ptr, out_ptr,
        seq_len, head_dim,
        stride_q_batch, stride_q_head, stride_q_seq, stride_q_dim,
        stride_k_batch, stride_k_head, stride_k_seq, stride_k_dim,
        stride_v_batch, stride_v_head, stride_v_seq, stride_v_dim,
        stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim,
        causal,
        block_q: tl.constexpr,
        block_k: tl.constexpr,
        block_d: tl.constexpr,
    ):
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)
        # TODO(student): map pid_q and pid_bh to a batch/head/query tile.
        # TODO(student): load Q, K, and V blocks.
        # TODO(student): compute scores for the current block pair.
        # TODO(student): apply optional causal masking.
        # TODO(student): update online softmax state and accumulate the output block.
        # TODO(student): store the final output tile.
        pass


def triton_flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_q: int = 64,
    block_k: int = 64,
) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
    if q.ndim != 4:
        raise ValueError("expected [batch, heads, seq, dim] inputs")
    if not q.is_cuda or not k.is_cuda or not v.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")
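
if TRITON_AVAILABLE:

    # ------------------------------------------------------------------
    # Reference sketch, not the assignment solution: one possible way to
    # fill in the TODOs, kept separate so the graded skeleton above stays
    # intact. Assumed beyond the skeleton's signature: an extra `num_heads`
    # arg to decode the flattened (batch, head) program id, an explicit
    # `sm_scale`, `causal` promoted to `tl.constexpr`, and a single BLOCK_D
    # tile covering the whole head dimension. All `_sketch` names are
    # illustrative.
    # ------------------------------------------------------------------
    @triton.jit
    def _flash_fwd_sketch(
        q_ptr, k_ptr, v_ptr, out_ptr,
        seq_len, head_dim, num_heads, sm_scale,
        stride_qb, stride_qh, stride_qs, stride_qd,
        stride_kb, stride_kh, stride_ks, stride_kd,
        stride_vb, stride_vh, stride_vs, stride_vd,
        stride_ob, stride_oh, stride_os, stride_od,
        CAUSAL: tl.constexpr,
        BLOCK_Q: tl.constexpr,
        BLOCK_K: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)
        batch = pid_bh // num_heads
        head = pid_bh % num_heads
        offs_q = pid_q * BLOCK_Q + tl.arange(0, BLOCK_Q)
        offs_d = tl.arange(0, BLOCK_D)
        q_mask = (offs_q[:, None] < seq_len) & (offs_d[None, :] < head_dim)

        # One Q tile stays resident while we sweep over the K/V tiles.
        q_base = q_ptr + batch * stride_qb + head * stride_qh
        q = tl.load(
            q_base + offs_q[:, None] * stride_qs + offs_d[None, :] * stride_qd,
            mask=q_mask, other=0.0,
        )
        # Online-softmax state: running row max, running row sum, accumulator.
        m_i = tl.full((BLOCK_Q,), float("-inf"), dtype=tl.float32)
        l_i = tl.zeros((BLOCK_Q,), dtype=tl.float32)
        acc = tl.zeros((BLOCK_Q, BLOCK_D), dtype=tl.float32)
        k_base = k_ptr + batch * stride_kb + head * stride_kh
        v_base = v_ptr + batch * stride_vb + head * stride_vh

        # Causal masking lets the sweep stop after the diagonal tile.
        hi = seq_len
        if CAUSAL:
            hi = tl.minimum((pid_q + 1) * BLOCK_Q, seq_len)
        for start_k in range(0, hi, BLOCK_K):
            offs_k = start_k + tl.arange(0, BLOCK_K)
            kv_mask = (offs_k[:, None] < seq_len) & (offs_d[None, :] < head_dim)
            k = tl.load(
                k_base + offs_k[:, None] * stride_ks + offs_d[None, :] * stride_kd,
                mask=kv_mask, other=0.0,
            )
            # Scores for this (Q tile, K tile) pair; tl.dot accumulates in fp32.
            scores = tl.dot(q, tl.trans(k)) * sm_scale
            scores = tl.where(offs_k[None, :] < seq_len, scores, float("-inf"))
            if CAUSAL:
                scores = tl.where(offs_q[:, None] >= offs_k[None, :], scores, float("-inf"))
            # Rescale prior state by exp(m_old - m_new), then fold in this tile.
            m_new = tl.maximum(m_i, tl.max(scores, axis=1))
            alpha = tl.exp(m_i - m_new)
            p = tl.exp(scores - m_new[:, None])
            l_i = l_i * alpha + tl.sum(p, axis=1)
            acc = acc * alpha[:, None]
            v = tl.load(
                v_base + offs_k[:, None] * stride_vs + offs_d[None, :] * stride_vd,
                mask=kv_mask, other=0.0,
            )
            acc += tl.dot(p.to(v.dtype), v)
            m_i = m_new

        # Normalize and write back; rows past seq_len are masked out.
        acc = acc / l_i[:, None]
        out_base = out_ptr + batch * stride_ob + head * stride_oh
        tl.store(
            out_base + offs_q[:, None] * stride_os + offs_d[None, :] * stride_od,
            acc.to(out_ptr.dtype.element_ty),
            mask=q_mask,
        )

    def _launch_flash_fwd_sketch(q, k, v, causal=False, block_q=64, block_k=64):
        # Illustrative launcher for the sketch; input validation would mirror
        # triton_flash_attention_fwd above.
        batch, heads, seq_len, head_dim = q.shape
        out = torch.empty_like(q)
        # tl.dot wants tile sides of at least 16; pad BLOCK_D to a power of 2.
        block_d = max(triton.next_power_of_2(head_dim), 16)
        grid = (triton.cdiv(seq_len, block_q), batch * heads)
        _flash_fwd_sketch[grid](
            q, k, v, out,
            seq_len, head_dim, heads, head_dim ** -0.5,
            *q.stride(), *k.stride(), *v.stride(), *out.stride(),
            CAUSAL=causal, BLOCK_Q=block_q, BLOCK_K=block_k, BLOCK_D=block_d,
        )
        return out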