Initial project scaffold

This commit is contained in:
wjh
2026-04-10 13:15:06 +00:00
commit a4a6b1f1c8
94 changed files with 3964 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import torch
try:
import triton
import triton.language as tl
except ImportError: # pragma: no cover - depends on local environment
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:
@triton.jit
def flash_attention_fwd_kernel(
q_ptr,
k_ptr,
v_ptr,
out_ptr,
seq_len,
head_dim,
stride_q_batch,
stride_q_head,
stride_q_seq,
stride_q_dim,
stride_k_batch,
stride_k_head,
stride_k_seq,
stride_k_dim,
stride_v_batch,
stride_v_head,
stride_v_seq,
stride_v_dim,
stride_out_batch,
stride_out_head,
stride_out_seq,
stride_out_dim,
causal,
block_q: tl.constexpr,
block_k: tl.constexpr,
block_d: tl.constexpr,
):
pid_q = tl.program_id(axis=0)
pid_bh = tl.program_id(axis=1)
# TODO(student): map pid_q and pid_bh to a batch/head/query tile.
# TODO(student): load Q, K, and V blocks.
# TODO(student): compute scores for the current block pair.
# TODO(student): apply optional causal masking.
# TODO(student): update online softmax state and accumulate the output block.
# TODO(student): store the final output tile.
pass
def triton_flash_attention_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = False,
block_q: int = 64,
block_k: int = 64,
) -> torch.Tensor:
if not TRITON_AVAILABLE:
raise RuntimeError("Triton is not installed in this environment.")
if q.shape != k.shape or q.shape != v.shape:
raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
if q.ndim != 4:
raise ValueError("expected [batch, heads, seq, dim] inputs")
if not q.is_cuda or not k.is_cuda or not v.is_cuda:
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")