FlashAttention 3 - A Worklog[WIP]
Hello all. This blog post is about how to implement Flash Attention 3, specifically with CuteDSL, targeting H100 GPUs. The main goal of this work is to create a worklog on incrementally building and adding features to each kernel, to clearly and cleanly understand how we can get to Model FLOPs Utilization(MFU).
Much of this work is inspired by some of the best people whose content I really enjoyed in the GPU kernel optimization and ML performance engineering domain — specifically Simon, Aleksa Gordic, Kapil Sharma, and Hamza for their similar worklogs, which made me want to create one like this in the first place. Since this is going to be my take on such a worklog, I wanted to do something not available on the internet, or at least not in the form of a worklog.
- How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog - Simon Boehm
- Inside NVIDIA GPUs: Anatomy of high performance matmul kernels - Aleksa Gordić
- Learn CUTLASS the hard way! - Kapil Sharma
- Worklog: Optimising GEMM on NVIDIA H100 for cuBLAS-like Performance (WIP) - Hamza
At present, most ML/AI architectures depend on two basic primitives: GEMM and Attention. I’m taking Attention(specifically Flash Attention) and exploring its kernel design and how to push it toward MFU. Below is the rough plan.
Flash Attention 3(FA3) Worklog:
- Attention math
- Online softmax
- GPU memory and compute hierarchy
- Tiling
- FA2
- CuTe DSL fundamentals and primitives
- K1 — naive attention
- K2 — tiled softmax
- K3 — using TMA
- K4 — using WGMMA
- K5 — warp specialization
- K6 — ping-pong scheduling
- K7 — FP8
- K8 — FP8 + incoherent processing
- Flash decode
The plan is to follow the above as one section each. Throughout this worklog, I’ll provide good intuition along with the details on how things work or what we need to make them work. Since each of these can be its own topic, I’ll also add separate notes for many of them, which you can refer to if you want to go in-depth.
This is a work in progress, marked with the “WIP” tag. Once it’s complete, I’ll remove that.
I sincerely thank all the many people who planted this idea in me. I hope this will be helpful for anyone who wants to understand FlashAttention 3 (FA3) from first principles. All the code will be made available here: FlashAttention3
I’ve made sure there aren’t any conceptual mistakes. If you find any, please let me know and I’m happy to correct them and make this useful for everyone. If you’re in this domain, feel free to ping me; I’m happy to collaborate on projects around GPU kernels and ML perf engineering.
I’ll add references along each section, instead of tagging them at the end. I’m assuming basic ML/AI and GPU knowledge to get the best out of this blog. All the analysis will be on H100 SXM GPUs, since FA3 was targeted at the Hopper architecture and aims to squeeze out its full utilization through Hopper-specific features like TMA, WGMMA, asynchronous operations, and FP8 support. One point worth mentioning here: there isn’t a huge algorithmic difference between FA2 and FA3, but a lot of engineering work went into FA3 to take advantage of Hopper. FA2 reached around 70% of theoretical max FLOPS on Ampere (A100) but didn’t exploit Hopper-specific features, while FA3 reaches around 75% on H100 in FP16, and with FP8 it gets close to 1.2 PFLOPS with 2.6x smaller numerical error than baseline FP8 attention. More on this in the upcoming sections.
Also, you’ll come across a lot of acronyms throughout this blog. I expand each term on its first use, so if you hit one you’re unsure about, a quick Ctrl+F will take you back to where it was introduced.
That’s it on the intro. Let’s get started.
Reference of FlashAttention-3 Paper - Link
Attention Math
Attention is the core of transformers, and transformers are the core of large language models (LLMs). Transformers consist of two primitives — attention, followed by feed-forward networks (FFNs) — surrounded by norms and residual connections. Below is an image that captures this.
Attention gives you the context of what is being referred to and meant in a sentence, and the FFN predicts the next token based on this context. Both are critical to each other and together form the backbone of present-day LLMs.
Now, if we look at attention in more detail, it’s composed of three operations:
\[S = Q K^\top \tag{1}\] \[P = \text{softmax}(S) \quad \text{(along the last dimension)} \tag{2}\] \[O = P V \tag{3}\]Refer to Jay Alammar’s transformers blog if you’re not sure of the attention mechanism and what happens under the hood.
Here, (1) and (3) are matrix multiplications. In GPU terminology, we call them GEMM (General Matrix Multiply), and (2) is a softmax along the last dimension. This matters because softmax was a major bottleneck before, and FlashAttention made algorithmic changes to how it’s computed, while staying numerically exact.
Note: There is a scale factor of \(1/\sqrt{d}\) applied to \(S\), which we’re ignoring here for simplicity. The scale keeps the dot products from growing too large as \(d\) increases, which would otherwise push the softmax into saturated, low-gradient regions. In code, we’ll definitely include it. There is also a causal mask that needs to be applied, to make sure each query attends only to keys prior to it and not after. This is needed for next-token prediction.
The reason for optimizing attention is that materializing the intermediate \(S\) and \(P\) matrices is expensive. Why? Think in terms of dimensions and how much memory we’d need to get this computation done.
Let’s do a simple calculation. Let \(Q\), \(K\), and \(V\) be of shape \((b, s, d)\) → (batch, sequence length, model dim). Then \(S = QK^\top\) has shape \((b, s_1, s_2)\). Now apply softmax. What happens when the sequence length is 1 million and batch size is 1? Assuming FP32 (4 bytes per element):
\[1 \times 10^6 \times 10^6 \times 4 \text{ bytes} = 4 \times 10^{12} \text{ bytes} = 4 \text{ TB}\]And that’s just to materialize one \(S\) matrix. We still need to do the softmax and another matmul on top of that. Attention is also computed many times, since LLMs have hundreds of transformer blocks. Most GPUs don’t have 4 TB of memory, and even if they did, it still wouldn’t be feasible — we need space for intermediate results and bookkeeping too.
This is the major bottleneck FlashAttention addresses. At its core, it does two things: online softmax and tiling. Together, these ensure we never materialize the full \(s \times s\) matrix in GPU memory, while producing the exact same numerical result as if we had. This is what made it possible to train models with longer sequence lengths, which gave us more context and helped LLMs become much more useful.
There are a lot of details in and around this, but we’ll focus on FlashAttention 3 going forward, with all necessary information and intuition along the way for each optimization and topic. With a sense of what attention does and why we need to optimize it, we’ll next talk about online softmax.
A side note:
If you want to understand how LLMs work more clearly, try implementing one. Check out the Stanford CS336 course and follow their assignments — they’re self-contained and really push you to have a clear understanding of all the internals of LLMs. I did it as part of my coursework at NYU, and I’m glad to have taken that course. Highly recommend it.
Also, take a look at this: Transformer block accounting and the below image — think through it to understand the memory and FLOPs counts for a transformer block. This gives you a picture of what it takes to run one transformer block, so you can appreciate all the engineering that’s gone into making it fast and efficient.
Online Softmax
First, let’s revisit softmax and where it’s used.
Given a list of numbers \(\{x_1, x_2, \ldots, x_N\}\), for each element we compute \(e^{x_i} / \sum_j e^{x_j}\) and that produces a new list. Why do we do this? To make sure the new list sums to 1. Softmax is applied as the final step so that the probabilities over all choices add up to one. In the case of attention, we apply softmax to \(QK^\top\). These are the scores given to each key for a given query, and all the scores for a given query must add up to 1. We then use these scores to weight the value matrix through the \(PV\) matmul.
Now, one more important aspect: along which dimension of \(S = QK^\top\) do we apply softmax? The shape of \(S\) is \((b, s_1, s_2)\), where \(s_1\) is the query sequence length and \(s_2\) is the key sequence length. The last dimension is the key’s sequence length, and softmax is applied along this one, so that for each query we get the importance of every key and can extract the context.
With the softmax concept in place, we’ll proceed to online softmax. Here, instead of processing the full list at once, we process it in chunks and rescale the partial results as and when needed. At the end, we get the exact same result as if we had processed the whole list at once. To process in chunks and rescale, we need to keep a few temporary variables for bookkeeping. The reason we want this: our bottleneck is that we can’t hold the full list of numbers in memory, and chunk-by-chunk processing avoids that, at the cost of a little bookkeeping. That’s the gist of online softmax. Now let’s look at the actual math.
The math
Before getting to online softmax, we need to address an overflow issue with exponentials. The largest representable value in float16 (fp16) is 65504. For large \(x\), \(e^x\) overflows this range. To mitigate this, we use a trick known as safe softmax: subtract the max before exponentiating.
\[\frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}} = \frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}}, \quad \text{where } m = \max_{j=1}^{N} x_j \tag{4}\]Subtracting \(m\) guarantees \(x_i - m \le 0\), so every exponent is non-positive, which is the safe region for the exponential.
Note: this can make some terms very small. If \(x_i\) is far below the max, \(e^{x_i - m}\) can underflow to 0 in fp16. Unlike overflow, this is harmless: a term that underflows was already negligibly small, so it contributes essentially nothing to the sum or to the final probability. Overflow produces inf and corrupts the whole result, whereas underflow just drops a value that rounds to zero anyway.
3-pass safe softmax
Notation: \(m_i = \max_{j=1}^{i} x_j\) with \(m_0 = -\infty\); \(d_i = \sum_{j=1}^{i} e^{x_j - m_N}\) with \(d_0 = 0\) (\(d_N\) is the safe-softmax denominator); \(a_i\) is the final softmax value.
\[m_i = \max(m_{i-1},\, x_i) \tag{5}\] \[d_i = d_{i-1} + e^{x_i - m_N} \tag{6}\] \[a_i = \frac{e^{x_i - m_N}}{d_N} \tag{7}\] \[\begin{aligned} &\textbf{for } i = 1 \ldots N: \\ &\quad m_i = \max(m_{i-1},\, x_i) \\[6pt] &\textbf{for } i = 1 \ldots N: \\ &\quad d_i = d_{i-1} + e^{x_i - m_N} \\[6pt] &\textbf{for } i = 1 \ldots N: \\ &\quad a_i = \dfrac{e^{x_i - m_N}}{d_N} \end{aligned}\]This algorithm iterates over \([1, N]\) three times. In the context of self-attention, the \(x_i\) are the pre-softmax logits computed by \(QK^\top\). When \(N\) is large, all \(N\) logits don’t fit in memory at once, so we need to break the computation into chunks rather than processing the full row in one shot. That is exactly what online softmax does (the GPU memory hierarchy that makes this necessary is covered in the next section).
From safe softmax to online softmax
If we fuse equations (5), (6), and (7) into a single loop, we reduce the number of global memory passes from 3 to 1. Unfortunately, we cannot fuse (5) and (6) directly, because (6) depends on \(m_N\), which isn’t known until the first loop finishes.
The fix: define a surrogate sequence
\[d_i' := \sum_{j=1}^{i} e^{x_j - m_i} \tag{8}\]as a stand-in for the original \(d_i := \sum_{j=1}^{i} e^{x_j - m_N}\). This removes the dependency on \(m_N\), and the \(N\)-th terms of both sequences are identical: \(d_N = d_N'\). So we can safely replace \(d_N\) in equation (7) with \(d_N'\).
We can also find a recurrence between \(d_i'\) and \(d_{i-1}'\):
\[d_i' = \sum_{j=1}^{i} e^{x_j - m_i} = \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} = \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} = d_{i-1}'\, e^{m_{i-1} - m_i} + e^{x_i - m_i} \tag{9}\]This recurrence only depends on \(m_i\) and \(m_{i-1}\), so we can compute \(m_i\) and \(d_i'\) together in the same loop.
2-pass online softmax
\[\begin{aligned} &\textbf{for } i = 1 \ldots N: \\ &\quad m_i = \max(m_{i-1},\, x_i) \\ &\quad d_i' = d_{i-1}' \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \\[6pt] &\textbf{for } i = 1 \ldots N: \\ &\quad a_i = \dfrac{e^{x_i - m_N}}{d_N'} \end{aligned}\]It still takes two passes. And we can reduce it to a single pass to minimize global I/O. We’ll see that in the FA2 section, after we’ve covered the GPU hierarchy and tiling.
References
- From Online Softmax to FlashAttention - by Zihao Ye (notation used in this section is from here)
GPU Memory and Compute Hierarchy
We left online softmax at two passes. The piece still missing is where those chunks live and how we move them between memory levels. To understand that, and to understand why we process attention in chunks at all, we first need a mental model of how a GPU stores and computes data. That is what this section builds.
We’ll build the memory hierarchy first, then the compute side. Tiling, which combines the two, follows in the next section.
The memory hierarchy
Why a hierarchy? Memory on a GPU is layered based on functionality, bandwidth, size, and access pattern. It is very similar to a CPU’s hierarchy of HDD, RAM, cache, and registers. On a GPU, the levels are: global memory (GMEM, also called HBM, high-bandwidth memory, or off-chip memory), L2 cache, shared memory (SMEM, on-chip memory) and L1 cache, and registers. This ordering goes from largest to smallest in size, lowest to highest in bandwidth, and highest to lowest in latency.
Some rough numbers to keep in mind, for the H100 SXM specifically:
- GMEM / HBM3: ~80 GB capacity, ~3.35 TB/s bandwidth, hundreds of cycles of latency.
- L2 cache: ~50 MB, shared across the whole GPU.
- SMEM / L1: up to 228 KB per SM, configurable as a split between shared memory and L1 cache in code. Bandwidth is roughly an order of magnitude higher than HBM, with latency in the tens of cycles.
- Registers: 65536 32-bit registers per SM (often written as 64K), the fastest storage, allocated per thread.
The point of these numbers is the gap: each step down the hierarchy is dramatically faster but dramatically smaller. That tension is what every kernel optimization has to navigate.
There is a clear tradeoff in designing kernels around these hardware characteristics. This is one of the reasons I love this domain: you can clearly see what you gain and what you give up. One example is occupancy. We want to maximize the number of warps resident per SM, but to do so we need enough resources (registers, shared memory) available on that SM, and resource usage depends on how you framed the algorithm and how you use memory and registers. Balancing this is what it takes to achieve high utilization.
I highly recommend Aleksa Gordic’s writeup on NVIDIA GPU anatomy if you want more depth before proceeding.
The compute hierarchy
We’ve summarized the memory hierarchy. Let’s do a similar one for compute.
Kernels are launched as a grid, which contains multiple blocks, each of which is composed of multiple warps, and each warp is 32 threads. The number of threads per warp is fixed by hardware; warps are not directly exposed as a software unit you tune. Blocks and grids are the software-accessible, tunable units. Multiple blocks can be resident within an SM, depending on the block’s resource requirements and the SM’s available resources. All threads within a block can share that SM’s shared memory to communicate and exchange data.
That covers the structure of parallel compute. The compute units themselves, inside each SM, are: CUDA cores, Tensor cores, and Special Function Units (SFUs). Each is built for specific use cases: CUDA cores for general arithmetic, Tensor cores for matrix multiply and fused multiply-add (FMA), and SFUs for transcendental functions like exponentials.
Because these are separate hardware units, they can run in parallel. I mention this because it is one of the optimizations we’ll apply to make attention faster: use the Tensor cores for the matmuls while the CUDA cores or SFUs simultaneously handle the softmax. This is just to plant the idea; we’ll return to it in the warp specialization and ping-pong sections.
You can reference the NVIDIA H100 GPU Whitepaper for more details.
SIMT: the execution model
One more foundational idea, and arguably the foundation of GPU programming: SIMD/SIMT, Single Instruction Multiple Data / Single Instruction Multiple Threads. We issue a single instruction and apply it across many data elements at once. This is the core reason GPUs accelerate deep learning so well.
Matrix multiplication is a natural fit for this model. Each element of the output matrix is an independent dot product: element \((0,0)\) of the result is the dot product of row 0 of \(A\) with column 0 of \(B\), element \((0,1)\) is row 0 of \(A\) with column 1 of \(B\), and so on. In a basic model where one thread computes one output element, every thread runs the same instruction (a dot product) but operates on different data. That is SIMT.
This is also the bridge to the next section. The block/warp/thread structure above is how work is grouped; tiling is how we shape the data so that grouping maps cleanly onto the memory hierarchy.
Tiling
Before defining tiling, let’s see why we need it, and why it is the thing that lets us combine the memory and compute hierarchies. There are two reasons.
First, we can’t fit all the data in fast memory at once, so we break it into chunks and process them piece by piece. (This is exactly the constraint that forced online softmax into chunked processing.)
Second, we want to avoid repeated transfers from GMEM to SMEM. Moving data is where much of the time goes. HBM bandwidth, while large in absolute terms, is far lower than the rate at which the compute units can consume data, so without care the compute units sit idle waiting for data to arrive. These are memory stalls, and we want to minimize them. Since SMEM is small, we can only hold a small portion of the data at a time, which again pushes us toward breaking the data into small pieces.
Why “tiles”? Because we chunk in 2D, along both rows and columns, so each piece is a tile. The Ampere and Hopper Tensor cores consume matrix fragments directly, so tiling has become the norm, and modern libraries like Triton and CuTe expose tiles rather than scalar elements as their basic primitive.
How tiling works
Here is the standard framing, on a generic GEMM \(C = A \times B\) where \(A\) is \(M \times K\), \(B\) is \(K \times N\), and \(C\) is \(M \times N\).
Instead of one thread computing one full output element (one full dot product over the entire \(K\) dimension), we organize the work as:
- An outer loop over output tiles: the \(M \times N\) output is divided into tiles of size \(T_M \times T_N\). Each block of threads is responsible for one output tile.
- An inner loop over the \(K\) dimension: the shared dimension \(K\) is also split into chunks of size \(T_K\). For each step, the block loads one \(T_M \times T_K\) tile of \(A\) and one \(T_K \times T_N\) tile of \(B\) from GMEM into SMEM, multiplies them, and accumulates the partial result into registers.
After the inner loop finishes walking the full \(K\) dimension, the accumulator in registers holds the complete output tile, which is then written back to GMEM once.
In pseudocode:
# C = A @ B
# A: (M, K), B: (K, N), C: (M, N)
# tile sizes: T_M, T_N, T_K
for each output tile (i, j) in the (M/T_M) x (N/T_N) grid: # outer loop, parallel across blocks
acc = zeros(T_M, T_N) # lives in registers
for k_step in range(0, K, T_K): # inner loop over K
A_tile = load A[i*T_M : (i+1)*T_M, k_step : k_step+T_K] # GMEM -> SMEM
B_tile = load B[k_step : k_step+T_K, j*T_N : (j+1)*T_N] # GMEM -> SMEM
acc += A_tile @ B_tile # compute on SMEM data
C[i*T_M : (i+1)*T_M, j*T_N : (j+1)*T_N] = acc # SMEM/registers -> GMEM
The inner \(K\)-loop is pure accumulation: each \(K\)-step loads a fresh pair of tiles and adds its partial product into acc. The loop walks the shared dimension in \(K/T_K\) steps, and only after the last step does acc hold a complete output tile.
Compare this with the naive version, where each output element is computed independently and every access goes to GMEM:
# naive: one thread, one output element
for i in range(M):
for j in range(N):
acc = 0
for k in range(K):
acc += A[i, k] * B[k, j] # every read hits GMEM
C[i, j] = acc
To tie this back to the compute hierarchy from the previous section, here is how the tiling concepts map onto the GPU’s parallel execution units:
| Tiling concept | Hardware unit | Role |
|---|---|---|
| The full output grid of tiles | grid | The whole launch. One grid covers the entire \(M \times N\) output. |
| One output tile \((i, j)\) | thread block | One block owns one output tile and runs the inner \(K\)-loop for it. Blocks are independent, so all output tiles can be computed in parallel. |
| A sub-tile within the block’s tile | warp (32 threads) | The block’s tile is further divided among its warps, with each warp cooperatively computing a portion. |
| One output element (or small group) | thread | The finest unit of work, executing the same instruction as its warp-mates on different data (SIMT). |
A_tile, B_tile | shared memory (SMEM) | Loaded once from GMEM per \(K\)-step, then read many times by all threads in the block. This shared reuse is why the load is amortized. |
acc (the accumulator) | registers | Per-thread, holds the partial result across the entire inner loop. Written to GMEM only once, at the end. |
The key thing this mapping makes clear: the grid/block split is what gives us parallelism across output tiles, while the SMEM/register split is what gives us data reuse within a tile. Tiling works because it exploits both at once. (The warp-level detail, how a warp’s threads cooperate with the Tensor cores on a fragment, is deferred to the WGMMA section.)
We’ve described what the loops do and how the work maps to hardware. The next subsection makes the payoff concrete: it counts the memory traffic and shows exactly why this structure is faster.
Why tiling reduces memory traffic
The intuition that tiling helps becomes concrete once we count. We’ll count GMEM traffic in elements rather than bytes; the bytes-per-element factor is the same for both versions, so it cancels in the final ratio. For simplicity, assume square tiles of size \(T \times T\) (so \(T_M = T_N = T_K = T\)).
Naive version. Each output element of \(C\) is one dot product over the full \(K\) dimension. Computing it reads \(K\) elements of \(A\) and \(K\) elements of \(B\) from GMEM, so \(2K\) reads per output element. There are \(M \cdot N\) output elements, so total GMEM traffic is:
\[Q_{\text{naive}} = 2 \cdot M \cdot N \cdot K \quad \text{elements}\]The useful compute is one multiply and one add per term, over \(K\) terms, for each of the \(M \cdot N\) outputs:
\[W = 2 \cdot M \cdot N \cdot K \quad \text{FLOPs}\]The arithmetic intensity is the ratio of compute to memory traffic:
\[I_{\text{naive}} = \frac{W}{Q_{\text{naive}}} = \frac{2 \cdot M \cdot N \cdot K}{2 \cdot M \cdot N \cdot K} = 1 \quad \text{FLOP per element}\]One FLOP per element moved. The kernel is entirely memory-bound: the compute units do a single operation for every value dragged across the slow GMEM boundary.
Tiled version. Now the output is divided into \((M/T) \times (N/T)\) tiles. For each output tile, the inner loop walks the \(K\) dimension in \(K/T\) steps. At each step it loads one \(T \times T\) tile of \(A\) and one \(T \times T\) tile of \(B\), which is \(2T^2\) elements. So the GMEM traffic for a single output tile is:
\[\frac{K}{T} \cdot 2T^2 = 2 \cdot T \cdot K \quad \text{elements}\]There are \(\dfrac{M \cdot N}{T^2}\) output tiles in total, so:
\[Q_{\text{tiled}} = \frac{M \cdot N}{T^2} \cdot 2 \cdot T \cdot K = \frac{2 \cdot M \cdot N \cdot K}{T} \quad \text{elements}\]The compute \(W\) is unchanged, \(2 \cdot M \cdot N \cdot K\) FLOPs, since tiling reorganizes the work but does not remove any of it. The arithmetic intensity becomes:
\[I_{\text{tiled}} = \frac{W}{Q_{\text{tiled}}} = \frac{2 \cdot M \cdot N \cdot K}{\,2 \cdot M \cdot N \cdot K \,/\, T\,} = T \quad \text{FLOPs per element}\]The ratio. Comparing the two:
\[\frac{I_{\text{tiled}}}{I_{\text{naive}}} = \frac{T}{1} = T\]Tiling improves arithmetic intensity by a factor of \(T\), the tile size. Equivalently, it cuts GMEM traffic by a factor of \(T\) for the same amount of compute.
Where the factor of \(T\) comes from. It is worth seeing why the ratio is exactly \(T\), because the reuse happens at a finer level than “load tile A, load tile B.” The line acc += A_tile @ B_tile looks like a single operation, but multiplying two \(T \times T\) tiles is itself a \(T \times T \times T\) triple loop: every row of A_tile must be dotted with every column of B_tile. So a single row of A_tile, once it sits in SMEM, is read \(T\) times, once for each output column, rather than being fetched from GMEM \(T\) separate times. Likewise each column of B_tile is reused across all \(T\) rows. In numbers: the tile-tile multiply does \(2T^3\) FLOPs of compute on only \(2T^2\) loaded elements, so each loaded element is touched \(T\) times. That is the factor of \(T\). Note that this reuse lives inside the tile-tile multiply, not across \(K\)-steps: each \(K\)-step still loads a fresh pair of tiles, but each loaded tile, once in SMEM, earns \(T\) uses before being discarded.
What the factor of \(T\) buys you. To see why this matters, compare arithmetic intensity against real hardware. For H100 SXM in BF16, the Tensor cores deliver roughly 990 TFLOP/s of compute, while HBM3 delivers roughly 3.35 TB/s of bandwidth. To stay consistent with our element-based traffic count, we convert that bandwidth into elements per second: at 2 bytes per BF16 element, 3.35 TB/s is about \(1.675 \times 10^{12}\) elements/s. Dividing compute by bandwidth gives the hardware’s break-even arithmetic intensity:
\[I_{\text{ridge}} = \frac{990 \times 10^{12} \text{ FLOP/s}}{1.675 \times 10^{12} \text{ elem/s}} \approx 590 \text{ FLOPs per element}\]This is the roofline model: a kernel’s attainable performance is capped by whichever constraint binds first, memory bandwidth or peak compute. The crossover is the ridge point: any kernel below it is memory-bound, any kernel above it is compute-bound. A naive GEMM sits at \(I = 1\), nearly 600× below the threshold. The Tensor cores can compute almost 600 times faster than HBM can feed them, so they spend the overwhelming majority of their time idle, waiting on data. Tiling raises \(I\) from 1 toward \(T\), moving the kernel up the roofline, toward the regime where compute, not memory, is the limit.
This is the whole game. Tiling does not reduce compute; it reduces how many times each value crosses the slow GMEM boundary, moving the kernel from memory-bound (\(I = 1\)) toward compute-bound (\(I = T\)), which is the regime where the Tensor cores can actually be kept busy.
We’ll cover how this same tiling idea applies to attention, where the second matmul (\(P \times V\)) depends on the softmax output of the first, in the FA2 section. The attention case has an extra wrinkle that plain GEMM does not, and online softmax is what resolves it.
For a deeper, code-level walkthrough of GEMM tiling, see Simon’s GEMM worklog.
References
- How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog - Simon Boehm
- Inside NVIDIA GPUs: Anatomy of high performance matmul kernels - Aleksa Gordić
- NVIDIA H100 GPU Whitepaper
FlashAttention-2 (FA2)
Now that we have all the basic building blocks we need to construct FlashAttention, let’s do it.
First and foremost: FA1 and FA2 are very similar. They differ only in loop ordering and in what is kept in shared memory. We’ll look at FA2 directly. The algorithm at its core is also the same for FA3, so getting it clear here lets us relate to and build on it for the rest of the blog.
Before we get into the GPU-specific details, let’s do the math-heavy part first. The UWashington paper by Zihao Ye does a very good job of this, so we’ll use it as our reference here.
A quick note on notation before the derivation begins. We adopt the paper’s bracket-index style for this section, because shapes will matter at every step and bracket notation makes them self-documenting:
- \(Q[k, :]\): the \(k\)-th row of the \(Q\) matrix, shape \((1, d)\).
- \(K^\top[:, i]\): the \(i\)-th column of \(K^\top\), shape \((d, 1)\).
- \(V[i, :]\): the \(i\)-th row of the \(V\) matrix, shape \((1, d)\).
- \(O[k, :]\): the \(k\)-th row of the output matrix \(O\), shape \((1, d)\).
The running scalars and row vectors carry over from the online-softmax section: \(m_i\) is the running max, \(d_i'\) is the running (surrogate) softmax denominator (both scalars), and we now add \(o_i'\), the running (surrogate) output accumulator (a row vector of shape \((1, d)\)). The computation of every output row is independent, so we explain it for a single row.
Also, as in the online-softmax section, we drop the \(1/\sqrt{d}\) softmax scale and the causal mask here for simplicity. Both are present in real code: the scale keeps the dot products from saturating the softmax as \(d\) grows, and the causal mask ensures each query attends only to keys at or before its position. We omit them only to keep the recurrence readable.
From multi-pass to single-pass attention
In the online-softmax section we asked whether softmax could be done in one pass and the answer was no. But in self-attention our final target is not the attention-score matrix \(A\). It is the output matrix \(O = A V\). So the better question is: can we find a one-pass recurrence for \(O\) instead?
With the notation in hand from the section intro, we add three more symbols specific to this derivation:
- \(x_i\): a scalar, the \(i\)-th attention score. Comes from a row times a column: \(x_i = Q[k, :]\, K^\top[:, i]\).
- \(a_i\): a scalar, the \(i\)-th softmax weight.
- \(o_i\): a row vector of shape \((1, d)\), the partial aggregation \(\sum_{j=1}^{i} a_j\, V[j, :]\). This is the running output state.
We can now write the per-row attention computation as a recurrence. We reuse the 2-pass online softmax from before, and add a second loop that accumulates the output row.
Multi-pass self-attention.
\[\begin{aligned} &\textbf{for } i = 1 \ldots N: \\ &\quad x_i \leftarrow Q[k, :]\, K^\top[:, i] \\ &\quad m_i \leftarrow \max(m_{i-1},\, x_i) \\ &\quad d_i' \leftarrow d_{i-1}'\, e^{m_{i-1} - m_i} + e^{x_i - m_i} \\[6pt] &\textbf{for } i = 1 \ldots N: \\ &\quad a_i \leftarrow \dfrac{e^{x_i - m_N}}{d_N'} &(10)\\ &\quad o_i \leftarrow o_{i-1} + a_i\, V[i, :] &(11)\\[6pt] &O[k, :] \leftarrow o_N \end{aligned}\]A few things to read off this:
- In equation (10), \(a_i\) is scalar: \(x_i, m_N, d_N'\) are all scalars.
- In equation (11), \(a_i\, V[i, :]\) is a scalar times a row vector, producing a row vector of shape \((1, d)\). Adding it to \(o_{i-1}\) (also \((1, d)\)) gives \(o_i\) of shape \((1, d)\). So \(o_i\) accumulates a sum of \(d\)-dimensional row vectors, one per loop iteration.
- After the loop, \(o_N = \sum_{j=1}^{N} a_j\, V[j, :]\), which is the \(k\)-th row of \(O\).
This is multi-pass: the second loop cannot start until the first finishes, because \(a_i\) in equation (10) depends on \(m_N\) and \(d_N'\), which are only known after the first loop completes.
Fusing into a single pass
To fuse the two loops we have to remove the dependency on \(m_N\) and \(d_N'\). Start by substituting the definition of \(a_i\) from equation (10) into equation (11):
\[o_i := \sum_{j=1}^{i} \frac{e^{x_j - m_N}}{d_N'}\, V[j, :] \quad (12)\]Same dependency problem: \(m_N\) and \(d_N'\) aren’t available inside the loop. We play the same surrogate trick we used for \(d_i'\) in the online-softmax section: define a surrogate sequence \(o_i'\) that uses the running values \(m_i\) and \(d_i'\) instead of the final \(m_N\) and \(d_N'\):
\[o_i' := \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'}\, V[j, :]\]The \(N\)-th terms of the two sequences agree. Setting \(i = N\) in the surrogate definition, the running \(m_i\) and \(d_i'\) become the final \(m_N\) and \(d_N'\):
\[o_N' = \sum_{j=1}^{N} \frac{e^{x_j - m_N}}{d_N'}\, V[j, :] = o_N\]The right-hand side is exactly the definition of \(o_N\) from equation (12). So computing the surrogate \(o_N'\) gives us the true output. The reason this works is the same as for \(d_i'\): only the final term of the sequence is needed, and at the final term the running statistics have caught up to the global ones. Revisit the online-softmax section for the fuller argument.
Next we find a recurrence between \(o_i'\) and \(o_{i-1}'\). The goal is to write \(o_i'\) in terms of \(o_{i-1}'\), a constant, and the data of step \(i\). We walk it step by step.
Start from the definition of \(o_i'\) and split off the last term (\(j = i\)):
\[o_i' = \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'}\, V[j, :] = \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'}\, V[j, :] \right) + \frac{e^{x_i - m_i}}{d_i'}\, V[i, :]\]Now we need to massage the first sum so it matches the definition of \(o_{i-1}'\). By definition, \(o_{i-1}' = \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} V[j, :]\). The first sum above has \(m_i\) and \(d_i'\) where we’d like \(m_{i-1}\) and \(d_{i-1}'\). We change them one at a time.
First, rewrite the exponent. Using \(e^{x_j - m_i} = e^{x_j - m_{i-1} + m_{i-1} - m_i} = e^{x_j - m_{i-1}}\, e^{m_{i-1} - m_i}\):
\[\sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'}\, V[j, :] = \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}\, e^{m_{i-1} - m_i}}{d_i'}\, V[j, :]\]The factor \(e^{m_{i-1} - m_i}\) doesn’t depend on \(j\), so it can come out of the sum:
\[= e^{m_{i-1} - m_i} \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_i'}\, V[j, :]\]Now we still have \(d_i'\) where we’d like \(d_{i-1}'\). Multiply and divide by \(d_{i-1}'\):
\[= e^{m_{i-1} - m_i}\, \frac{d_{i-1}'}{d_i'} \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'}\, V[j, :]\]The sum is now exactly \(o_{i-1}'\). Substituting:
\[\sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'}\, V[j, :] = o_{i-1}'\, \frac{d_{i-1}'}{d_i'}\, e^{m_{i-1} - m_i}\]Plug back into the split form of \(o_i'\):
\[o_i' = o_{i-1}'\, \frac{d_{i-1}'}{d_i'}\, e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'}\, V[i, :] \quad (13)\]The recurrence (13) depends only on \(d_i', d_{i-1}', m_i, m_{i-1}, x_i,\) and \(V[i, :]\). Every one of these is available within a single loop iteration, so we can fuse the entire attention computation into one pass:
FlashAttention (single-pass).
\[\begin{aligned} &\textbf{for } i = 1 \ldots N: \\ &\quad x_i \leftarrow Q[k, :]\, K^\top[:, i] \\ &\quad m_i \leftarrow \max(m_{i-1},\, x_i) \\ &\quad d_i' \leftarrow d_{i-1}'\, e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &\quad o_i' \leftarrow o_{i-1}'\, \dfrac{d_{i-1}'}{d_i'}\, e^{m_{i-1} - m_i} + \dfrac{e^{x_i - m_i}}{d_i'}\, V[i, :] \\[6pt] &O[k, :] \leftarrow o_N' \end{aligned}\]A note on the shapes inside the loop: every operation on the right side of the \(o_i'\) update is a scalar coefficient multiplying either \(o_{i-1}'\) (a \((1, d)\) row) or \(V[i, :]\) (a \((1, d)\) row). The result is a \((1, d)\) row, matching \(o_i'\)’s shape. The expensive part of one iteration is the row addition; the rest is scalar arithmetic.
The states \(x_i, m_i, d_i',\) and \(o_i'\) all have small footprints that fit comfortably in GPU shared memory: three scalars and one \((1, d)\) row per query. This is the property that makes FlashAttention work. We never materialize the full \(N \times N\) score matrix, only a handful of running states per row.
Tiling the algorithm
The single-pass algorithm processes the entire sequence as a loop over individual key/value indices. Tiling changes how many rows we process per loop step, in two independent ways. We do them one at a time. Step 1 follows the paper exactly: tile the key/value sequence dimension. Step 2 extends the same idea to query rows, which lifts us into the form the real kernel uses (Algorithm 1).
Step 1: tiling the key/value rows.
The paper does this step first because all operations in the single-pass algorithm are associative: nothing forces us to process one \(K^\top\) column and one \(V\) row at a time. We can process a block of them in one loop step, as long as the recurrence’s max-tracking and rescaling still work.
We do not touch the query row. \(Q[k, :]\) stays whole. What we tile is the key/value sequence dimension: instead of walking it one row at a time, we walk it \(b\) rows at a time, where \(b\) is the tile size.
New notations for this step (matching the paper):
- \(b\): the block size of the tile.
- \(\#\text{tiles} = N / b\): number of tiles in one row of the score matrix.
- \(x_i\): now a row vector of shape \((1, b)\), holding the dot products of \(Q[k, :]\) against the \(i\)-th tile of \(K^\top\). So \(x_i = Q[k, :]\, K^\top[:, (i-1)b : ib]\), covering columns \((i-1)b + 1\) through \(ib\).
- \(x_i[j]\): the \(j\)-th entry of \(x_i\) (scalar). It’s the score of \(Q[k, :]\) against the \(j\)-th key row inside tile \(i\).
- \(m_i^{(\text{local})}\): scalar, the local maximum within \(x_i\).
- \(m_i, d_i', o_i'\): same running states as before (two scalars and one row vector of shape \((1, d)\)). They now update per tile instead of per row.
With those, the tiled algorithm is:
\[\begin{aligned} &\textbf{for } i = 1 \ldots \#\text{tiles}: \\ &\quad x_i \leftarrow Q[k, :]\, K^\top[:, (i-1)b : ib] \\ &\quad m_i^{(\text{local})} \leftarrow \max_{j = 1 \ldots b}\, x_i[j] \\ &\quad m_i \leftarrow \max\!\left(m_{i-1},\, m_i^{(\text{local})}\right) \\ &\quad d_i' \leftarrow d_{i-1}'\, e^{m_{i-1} - m_i} + \sum_{j=1}^{b} e^{x_i[j] - m_i} \\ &\quad o_i' \leftarrow o_{i-1}'\, \dfrac{d_{i-1}'}{d_i'}\, e^{m_{i-1} - m_i} + \sum_{j=1}^{b} \dfrac{e^{x_i[j] - m_i}}{d_i'}\, V[j + (i-1)b, :] \\[6pt] &O[k, :] \leftarrow o_{N/b}' \end{aligned}\]Reading it: \(m_i^{(\text{local})}\) is the max within the current tile only, and the regular \(m_i\) folds that local max into the running global max. The \(d_i'\) update is the same recurrence as before, except the new contribution is a sum over the \(b\) entries of \(x_i\) instead of a single \(e^{x_i - m_i}\). The \(o_i'\) update similarly: the per-step contribution is now a sum of \(b\) scaled value rows. Nothing in the algorithm structure changed; we just process \(b\) entries per loop step instead of one.
This is the form the paper stops at. The states \(x_i, m_i, d_i', o_i'\) all still have small footprints that fit comfortably in GPU shared memory (the largest is \(o_i'\), one row vector of \(d\) elements). One loop step now consumes a whole \((b, d)\) tile of \(K^\top\) and \(V\), which is what lets each tile be loaded into SMEM once and fully consumed there.
Step 2: tiling the query rows.
The paper’s algorithm stops at K/V tiling, but the real kernel (Algorithm 1 below) also tiles query rows. So we take one more step.
So far we have processed one query row \(Q[k, :]\). But every output row is computed independently: row \(k\) of \(O\) never depends on row \(k'\). That independence means we don’t need a new recurrence to handle many query rows. We just run the Step 1 recurrence for a block of query rows at once, and every per-row quantity gains a row dimension. The paper’s \(b\) becomes \(B_c\) (the K/V tile size, “c” for the column dimension of the score matrix), and we introduce \(B_r\) (the Q tile size, “r” for the row dimension).
Concretely, lift each shape:
- The query block becomes \(Q_{[\,\cdot\,]}\) of shape \((B_r, d)\) instead of one row.
- The score \(x_i\) becomes a matrix of shape \((B_r, B_c)\), where row \(r\) contains the scores of query row \(r\) against all \(B_c\) key rows of tile \(i\).
- The running statistics \(m_i\) and \(d_i'\) become vectors of shape \((B_r,)\), one value per query row.
- The running output \(o_i'\) becomes a block of shape \((B_r, d)\).
- \(V_{[(i-1)B_c : iB_c]}\) is the \(i\)-th K/V tile of shape \((B_c, d)\), same as Step 1.
The scalar reductions inside the loop become per-row reductions: the max over a tile becomes a rowmax (reduce the \(B_c\) axis, one max per query row), and the sum becomes a rowsum. The recurrence is otherwise identical:
A note on the shapes inside the loop: \(m_i\) is a \((B_r,)\) vector, so in \(x_i - m_i\) it broadcasts across the \(B_c\) columns of the \((B_r, B_c)\) score matrix. In the \(o_i'\) update, the per-row scalars \(\frac{d_{i-1}'}{d_i'} e^{m_{i-1} - m_i}\) and \(\frac{1}{d_i'}\) multiply each query row of the output block. \(\text{diag}(\cdot)\) is the cleanest notation for this: \(\text{diag}(v)\) turns a length-\(B_r\) vector \(v\) into a \(B_r \times B_r\) diagonal matrix, so that \(\text{diag}(v)\, M\) scales row \(r\) of matrix \(M\) by entry \(v[r]\). There is no new algebra here, just the Step 1 recurrence with a row index attached.
With both tilings in place, one loop step now processes a full \((B_r, B_c)\) score tile: a block of query rows against a block of key/value rows, with all running state held as small \((B_r,)\) vectors and one \((B_r, d)\) output block. This is the form that maps directly onto the GPU, which is the next subsection.
Mapping to hardware
The general mechanics of tiling, the grid/block/warp split, and SMEM-versus-register reuse were covered in the tiling section. Here we cover what is specific to attention, and one important point up front: the hardware mapping does not line up one-to-one with the clean mental model from the tiled algorithm above. The algorithm is a sequential recurrence over rows and tiles; the hardware is a mesh of warps and Tensor cores that only goes fast when work is shaped to fit them. Bridging that gap is exactly what the kernel sections (K1–K8) are about. Here we set up the correct picture so the algorithm walkthrough that follows is straightforward.
The outer loop: blocks over query tiles.
This part does match the mental model. The work is parallelized over query blocks: each thread block takes one \(B_r \times d\) block of query rows and computes the corresponding \(B_r \times d\) block of output rows. Output rows are independent (the property we used to tile Q), so blocks are fully independent and the GPU schedules them across SMs with no inter-block communication. This is the outer loop, line 3 of Algorithm 1.
The inner loop: where the mental model and the hardware diverge.
In the tiled algorithm, the inner loop reads as a simple sequential sweep: for each key/value tile, compute the score tile, update \(m\), \(d'\), \(o'\), move on. It is tempting to map this as “one thread (or one warp) owns one query row and runs the recurrence for it.” That model is intuitive but wrong, and it is worth seeing why, because the real mapping is what every later optimization builds on.
The inner loop’s heavy work is two matmuls:
- \(S = Q_i K_j^\top\), producing a \(B_r \times B_c\) score tile (the first matmul),
- \(\tilde{P} V_j\), producing the \(B_r \times d\) output contribution (the second matmul).
A matmul tile on a modern GPU is not computed by independent per-row threads. It is computed cooperatively by warps through the Tensor cores. A Tensor core consumes a fragment: a small matrix shard whose elements are spread across all 32 lanes of a warp. A single thread does not hold a full row of the result, it holds a few scattered elements of the fragment. So the \(B_r\) query rows of the block are partitioned across the block’s warps as a byproduct of the matmul fragment layout, not by an explicit “row \(r\) → thread \(r\)” assignment.
This has a direct consequence for the softmax. The rowmax and rowsum in the recurrence are reductions along a row, over the \(B_c\) score values for a given query row. But after the first matmul, those \(B_c\) values for one row are spread across multiple lanes of a warp. Computing the max or sum for that row therefore requires warp-level coordination, lanes exchanging partial results through warp-shuffle instructions, not a single thread looping over the row. So the inner loop is genuinely cooperative: warps work together on the matmul tiles, and again on the cross-lane reductions for the softmax statistics, and again on the rescaling of the running output.
So the corrected picture of one inner-loop step is: the block’s warps cooperatively compute the \(S\) tile on the Tensor cores, cooperatively reduce it row-wise (warp shuffles) to get \(m^{(\text{local})}\) and the denominator update, cooperatively rescale the running output, and cooperatively compute the \(\tilde{P} V\) matmul on the Tensor cores. The sequential recurrence is still exactly what executes, it is just executed by a team of warps per step rather than by one thread per row.
The precise fragment layout, how many warps cooperate, and how the two matmuls are scheduled relative to the softmax, depend on the GPU generation. On Hopper this is the warpgroup-level WGMMA instruction, and getting it right (overlapping the matmuls with the softmax, keeping the Tensor cores fed) is the substance of the later kernel sections. We will cover WGMMA and that scheduling in the K4 section; here it is enough to hold the model that the inner loop is warp-cooperative, not per-thread.
The sweep direction itself is simple: a block walks the \(K^\top\) tiles left to right and the \(V\) tiles top to bottom, updating the running state \(m\), \(d'\), \(O\) after each tile.
Why the footprint is what makes this work.
The SRAM footprint of a block depends only on the block sizes (\(B_r\), \(B_c\)) and the head dimension \(d\), never on the sequence length \(L\). At any moment a block holds only its current Q tile, one K tile, one V tile, and the small \((B_r,)\) running-state vectors, all sized by the block parameters and \(d\). The \(N \times N\) score matrix is never materialized. Since shared memory is small (228 KB/SM on H100) and \(L\) can be very large (16k or more), this \(L\)-independence is exactly what lets FlashAttention scale to long context without running out of memory.
This connects back to the occupancy tradeoff from the memory and compute hierarchy section: because the footprint is fixed by the block parameters and \(d\), the block size becomes a tuning knob. Larger blocks mean more reuse per tile but more SRAM and registers per block, which lowers how many blocks can be resident per SM. That balance is exactly the kind of tradeoff we’ll revisit when we profile real kernels.
With this mapping in mind, blocks over query tiles, warp-cooperative inner loop over key/value tiles, two Tensor-core matmuls bridged by a warp-reduced softmax, the full algorithm reads directly off the hardware.
The full algorithm: FlashAttention-2 forward pass
We have now built the algorithm twice: as a per-row recurrence, and as the \(B_r \times B_c\) tiled form that maps onto the hardware. Algorithm 1 below is that same algorithm exactly as the FA2 paper states it. It is kept in the paper’s own notation so it can serve as a faithful reference; the mapping back to the blog’s notation follows.
Algorithm 1 — FlashAttention-2 forward pass
Require: Matrices \(Q, K, V \in \mathbb{R}^{N \times d}\) in HBM, block sizes \(B_c\), \(B_r\).
- Divide \(Q\) into \(T_r = \lceil N / B_r \rceil\) blocks \(Q_1, \ldots, Q_{T_r}\) of size \(B_r \times d\) each, and divide \(K, V\) into \(T_c = \lceil N / B_c \rceil\) blocks \(K_1, \ldots, K_{T_c}\) and \(V_1, \ldots, V_{T_c}\), of size \(B_c \times d\) each.
- Divide the output \(O \in \mathbb{R}^{N \times d}\) into \(T_r\) blocks \(O_1, \ldots, O_{T_r}\) of size \(B_r \times d\) each, and divide the logsumexp \(L\) into \(T_r\) blocks \(L_1, \ldots, L_{T_r}\) of size \(B_r\) each.
- for \(1 \le i \le T_r\) do
- Load \(Q_i\) from HBM to on-chip SRAM.
- On chip, initialize \(O_i^{(0)} = (0)_{B_r \times d}\), \(\ell_i^{(0)} = (0)_{B_r}\), \(m_i^{(0)} = (-\infty)_{B_r}\).
- for \(1 \le j \le T_c\) do
- Load \(K_j, V_j\) from HBM to on-chip SRAM.
- On chip, compute \(S_i^{(j)} = Q_i K_j^\top \in \mathbb{R}^{B_r \times B_c}\).
- On chip, compute \(m_i^{(j)} = \max\!\left(m_i^{(j-1)},\ \text{rowmax}(S_i^{(j)})\right) \in \mathbb{R}^{B_r}\), \(\tilde{P}_i^{(j)} = \exp\!\left(S_i^{(j)} - m_i^{(j)}\right) \in \mathbb{R}^{B_r \times B_c}\) (pointwise), \(\ell_i^{(j)} = e^{\,m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + \text{rowsum}(\tilde{P}_i^{(j)}) \in \mathbb{R}^{B_r}\).
- On chip, compute \(O_i^{(j)} = \text{diag}\!\left(e^{\,m_i^{(j-1)} - m_i^{(j)}}\right) O_i^{(j-1)} + \tilde{P}_i^{(j)} V_j\).
- end for
- On chip, compute \(O_i = \text{diag}\!\left(\ell_i^{(T_c)}\right)^{-1} O_i^{(T_c)}\).
- On chip, compute \(L_i = m_i^{(T_c)} + \log\!\left(\ell_i^{(T_c)}\right)\).
- Write \(O_i\) to HBM as the \(i\)-th block of \(O\).
- Write \(L_i\) to HBM as the \(i\)-th block of \(L\).
- end for
- Return the output \(O\) and the logsumexp \(L\).
Reading the algorithm with what we’ve built. Every line maps onto something we’ve already derived. The paper’s notation lines up with the blog’s as follows: \(S_i^{(j)}\) is our score tile \(x_i\), \(\ell\) is our denominator \(d_i'\), \(m_i^{(j)}\) is our running max \(m_i\), \(\tilde{P}\) is the unnormalized \(e^{x_i - m_i}\) term, and \(O_i^{(j)}\) is our running output \(o_i'\). The superscript \((j)\) is the inner-loop step.
- Lines 1–2 are the tiling setup: \(Q\) is cut into \(T_r\) row blocks of size \(B_r\), and \(K, V\) into \(T_c\) row blocks of size \(B_c\), the two tile sizes from the Q-tiling and K/V-tiling steps above.
- Line 3 iterates over query blocks; line 6 iterates over key/value tiles. The hardware actors behind these loops, one independent thread block per query block, a warp-cooperative sweep inside, were covered in the mapping subsection.
- Lines 4, 7 are the GMEM → SRAM loads: one Q block once per outer iteration, one K and V block per inner step.
- Line 8 is the first matmul, \(S = QK^\top\), producing a \(B_r \times B_c\) score tile.
- Line 9 is the online softmax, batched over \(B_r\) rows. The
rowmaxandrowsumhere are exactly the warp-reduced steps from the hardware mapping above: a row’s \(B_c\) score values are spread across warp lanes, so these reductions run as warp shuffles. \(\ell\) is the running denominator \(d_i'\), and its update is the same surrogate recurrence we derived. - Line 10 is the output recurrence. \(O_i^{(j)}\) is the running output \(o_i'\). The \(\text{diag}(e^{m^{(j-1)} - m^{(j)}})\) factor rescales the previous partial output when the max changes, and \(\tilde{P}_i^{(j)} V_j\) is the second matmul, \(P \times V\), adding the current tile’s contribution.
The one real difference: rescaling is deferred. Look closely at line 10 versus our earlier \(o_i'\) recurrence. Our recurrence divided by \(d_i'\) at every step. Algorithm 1 does not divide inside the inner loop. \(O_i^{(j)}\) is kept unnormalized throughout the sweep, rescaled only by the max-correction factor \(e^{m^{(j-1)} - m^{(j)}}\). The division by the denominator happens exactly once, on line 12, after the inner loop finishes: \(O_i = \text{diag}(\ell_i^{(T_c)})^{-1} O_i^{(T_c)}\).
This is the key FA1 → FA2 change. FA1 rescaled the output by the denominator on every inner step; FA2 moves that division out of the inner loop and applies it a single time at the end. The inner loop then contains only matmuls and the cheap max-correction, with the expensive per-step normalization removed. Fewer non-matmul operations on the hot path means the Tensor cores stay busier, and combined with the loop reordering (outer loop over query blocks, which parallelizes cleanly across thread blocks), this is most of where FA2’s speedup over FA1 comes from.
Line 13, the logsumexp. \(L_i = m_i^{(T_c)} + \log(\ell_i^{(T_c)})\) is only needed for the backward pass, where it lets the gradient recompute the softmax without storing the full probability matrix. Since this blog covers only the forward pass, we include \(L\) for completeness, to keep Algorithm 1 faithful to the paper, but we won’t use it further.
Note: in this section we only cover the forward pass, since inference is the main thing we want to cover. Training additionally needs the backward pass to propagate gradients and update parameters; for that, please refer to the paper.
References
CuTe DSL fundamentals and primitives
We’ve built the FlashAttention algorithm and tiled it for the GPU. From here on, every kernel we write (K1 through K8) will be written in CuTe Python DSL, the surface we use to express layouts, tensors, and Tensor-Core operations on Hopper. This section is the foundation: enough CuTe to read and understand the kernels that follow, built from first principles. Deep, kernel-specific mechanics (TMA hardware behavior, WGMMA fragment internals, swizzle and core-matrix details) are deferred to the kernel sections where they actually matter. Here we cover the abstractions.
We’ll go: what CuTe is and why it exists, then layouts, then tensors, then partitioning, then cute.gemm and cute.copy at a high level. Note: I initially added this part here, but it was very long, so I had to move it to a separate blog post to keep the content structured while still explaining most of it. So, pls refer to this CuTe DSL for the full in-depth ideas and understanding. Very quick one is below:
TL;DR. This section builds the CuTe Python DSL from first principles, ending with everything we need to write the FA3 kernels. CuTe is an embedded DSL inside Python, layered on CUTLASS 3.x (NVIDIA’s C++ template library for high-performance GEMM and convolution kernels). CUTLASS uses templates to specialize one generic kernel into many shape-specific compiled versions; CuTe Python gives us the same specialization power with a much cleaner Python surface. The DSL compiles just-in-time:
@cute.jitfunctions get lowered to PTX/SASS during the Python program’s run, not ahead of time, which is what lets the kernel be specialized for the exact shapes and dtypes the user supplies without paying runtime overhead. The catch: many values must be JIT-time constants (cutlass.const_expr,cutlass.range_constexpr) so the compiler can fold offsets into immediates and unroll loops over fixed shapes. The most important object in CuTe is the layout, a pair of shape and stride that defines a pure function from a coordinate tuple to an integer offset; layouts can be hierarchical (modes nested into tuples), which is what lets them describe non-trivial hardware geometries like the scatter pattern of a WGMMA output fragment. Layouts touch no data, they’re pure compile-time descriptions of how to walk an indexed grid, and a handful of operations (cute.size,cute.slice_,cute.append,cute.flat_divide) lets us compose, slice, and tile them. A tensor is an(iterator, layout)pair: a pointer-like handle attached to a layout, where indexing at coord \(c\) computes the offset via the layout and dereferences; two tensors can share an iterator but use different layouts (aliasing, a no-op at runtime), and identity tensors are a special case whose “values” are just the coordinates themselves, used as probes for figuring out where partitioned data lives in the original tile. To run a matmul, we build a TiledMMA, which bundles an MMA atom (the specific WGMMA instruction we want, with its shape and dtype), an atom layout (how many atom copies tile the work region), and the TV layouts (tv_layout_A/B/C) that encode how threads and per-thread values map to positions in the operand and accumulator tiles; per-thread tensors come frompartition_A/B/C(re-layouts of source tiles, no data motion) andmake_fragment_A/B/C(operand encoding for the Tensor Core;make_fragment_Cis the only one that actually allocates registers). For a \(64 \times 64\) atom, each of the warpgroup’s 128 threads owns 32 output values, scattered across the tile in a hardware-fixed pattern (which is why the C TV layout is hierarchical), and SMEM-sourced operands feed WGMMA through a 64-bit SMEM descriptor that the Tensor Core uses to read SMEM directly with no per-thread loads. The naming convention for partitioned tensors ist<X><mem><Operand>(e.g.,tSsQ= “QK matmul, thread-partitioned, SMEM, Q-operand”), and once internalized it makes kernel code readable at a glance. The dispatch primitives arecute.gemm(issues the WGMMA instructions for one matmul; loops over MMA_M, MMA_N, MMA_K and emits one instruction per atom, fully unrolled at JIT time) andcute.copy(issues a TMA tile copy from GMEM to SMEM asynchronously, with completion signaled on a memory barrier), with the TMA side using TMA atoms (built once at JIT time, packaging the GMEM descriptor and SMEM destination layout),tma_partition(per-thread partitioning for TMA dispatch), and multiple SMEM slots called stages arranged in a circular buffer to keep loads in flight while WGMMA is consuming the previous load. Deferred to the kernel sections: the exact WGMMA fragment layout and Z-pattern, SMEM-descriptor internals, inline PTX for WGMMA issue, the same-accumulator pipelining exception (all K4), TMA descriptors and mbarriers, the producer-consumer pipeline mechanics (K3), and the SMEM swizzle layouts in depth (K3 and K4). Here we cover only the abstractions, and the rest of this section unpacks each piece above.
Experimental setup
Before getting into K1, two practical notes. All measurements in this worklog run on the same hardware and software stack (H100 SXM, fixed CuTe DSL and PyTorch versions, etc.), and all kernels are launched through a single shared harness with the same benchmark and correctness-check conventions. Rather than recapping that in the blog, the details live in the project README: hardware and software versions, the repo layout, the workload sweep (sequence lengths, head dims, causal flag), the tensor layout convention, and how to run any kernel from the command line.
The repo is at: GitHub.
We’ll introduce just enough of the harness inline when K1 lands (one example command, a one-line note on what CHECK=1 does, and what TFLOPS gets reported). From K2 onward, the same conventions apply unchanged. If anything in the setup is unclear while reading a kernel section, the README is the reference.
With that, on to K1.
K1: Naive materialized attention
Refer this blog for full K1 CuTe walkthrough: Link
What’s new
This is the baseline. No fusion, no tiling, no online softmax, no Tensor Cores. Just the three operations of attention done as three separate kernels, with the full \((B, N, N)\) score matrix materialized in GMEM between them. The point isn’t to be fast. The point is to have a “before” picture: a working implementation that we can correctness-check against SDPA, profile, and then start optimizing.
Reading this section, you should expect: the slowest kernel in the worklog, modest tile sizes, no special hardware features used, and a hard wall at large sequence lengths because the materialized matrix doesn’t fit.
The idea
Attention is three operations. K1 implements them in the most direct way possible, as three separately-launched kernels with the intermediate \(S\) matrix written to GMEM and read back:
\[\begin{aligned} S &= Q K^\top \cdot \text{scale} \quad (\text{plus causal mask if enabled}) \\ P &= \text{softmax}(S) \quad (\text{in place, S becomes P}) \\ O &= P V \end{aligned}\]Each kernel uses one thread per output element (one thread per \(S_{i,j}\), one thread per row for the softmax, one thread per \(O_{i,j}\)). No SMEM staging, no Tensor Cores. The whole computation reads from and writes to GMEM directly.
This is exactly the configuration we used to motivate FlashAttention earlier in the blog. From the attention-math section: at \(B=1, N=10^6, d\) moderate, the materialized \(S\) matrix alone needs about 4 TB. For the smaller sweep configs we actually run (\(N \le 16384\), batch folded with heads), the matrix fits, but it dominates GMEM traffic and the arithmetic intensity is essentially \(I = 1\) (every FLOP is preceded by a load). Naive GEMM, naive softmax, naive output projection, three times.
Why bother implementing it? Three reasons. First, it’s the right “before” picture for the worklog. Every later kernel either removes a piece of K1’s badness or replaces a piece with a hardware-specific fast path; you can’t appreciate the gains without seeing the baseline. Second, this is the first time we run CuTe Python end to end, so it exercises the harness (bench.py, ref_check.py, the compile_kernel / run_kernel contract) on something simple. Third, even at this level you can verify the algorithm is right by running CHECK=1 and comparing against SDPA; this gives us confidence the layout and indexing conventions are wired up correctly before we start adding optimizations.
Implementation
The kernel file kernels/k1.py contains three CuTe classes plus the compile_kernel / run_kernel dispatchers that the harness calls.
-
K1Scorecomputes \(S_{i,j,b} = \sum_k Q_{i,k,b} K_{j,k,b} \cdot \text{scale}\), with the causal mask applied as a write of \(-\infty\) when \(j > i\). One thread per \((i, j, b)\) output element. Grid: \((\lceil N / 16 \rceil, \lceil N / 16 \rceil, B)\) blocks, each \(16 \times 16 \times 1\) threads. Reads Q and K from GMEM in FP32, accumulates in FP32, writes the result to the \(S\) buffer. -
K1Softmaxdoes the per-row softmax. One thread per row. The thread reads the entire row of \(S\) from GMEM to find the max, reads it again to compute exponentials and the sum, then reads it a third time to normalize. Grid: \((\lceil N / 256 \rceil, B, 1)\) blocks, each \(256 \times 1 \times 1\) threads. The materialized \(S\) is reused in place: the kernel reads \(S\) and writes \(P\) to the same buffer. -
K1Outputcomputes \(O_{i,j,b} = \sum_k P_{i,k,b} V_{k,j,b}\). One thread per \((i, j, b)\) output element. Grid: \((\lceil d / 16 \rceil, \lceil N / 16 \rceil, B)\) blocks, each \(16 \times 16 \times 1\) threads. Reads \(P\) (still in the \(S\) buffer) and \(V\) from GMEM, accumulates in FP32, writes the result to \(O\) in BF16.
The three kernels share one \(S\) buffer of size \(B \times N \times N \times 4\) bytes, allocated once in compile_kernel and reused across benchmark iterations. run_kernel launches the three compiled kernels in sequence on the current CUDA stream. The full CuTe details (the inner kernel code, primitives used, dispatcher mechanics) are in the standalone K1 walkthrough.
Running K1
From the project root:
CHECK=1 SEQLEN=512 HEADDIM=64 CAUSAL=0 python bench.py k1
Results
Profiling
What’s next
The first thing to fix is the materialized \(S\) matrix. Everything else (Tensor Cores, async loads, warp specialization) is downstream of that decision; as long as we materialize \(S\) we’re paying \(\mathcal{O}(B N^2)\) in memory and \(\mathcal{O}(B N^2)\) in GMEM traffic for it. K2 fuses the three kernels into one and uses the tiled online softmax we derived in the math sections, so \(S\) never leaves SMEM. This is also the first kernel where we’ll see the basic FlashAttention structure (per-row state, max-tracking recurrence, output rescaling) in actual code.
K2: Tiled online softmax (FA2)
Refer this blog for full K2 CuTe walkthrough: Link
What’s new
K1 materialized \(S = QK^\top\) in GMEM and ran three sequential kernels. K2 collapses everything into a single fused kernel that follows the FA2 algorithm directly: per Q-tile we stream KV tiles, recompute \(S\) in SMEM, run online softmax across rows, and accumulate \(O\) in registers. The algorithm is exactly the one derived in the FA2 section; the only thing this kernel adds is the threading and SMEM layout to make it run.
We are still on CUDA cores. The two GEMMs (\(QK^\top\) and \(PV\)) are hand-rolled FP32 dot products. No tensor cores, no TMA, no async pipelining; those land in K3 and K4. The point of K2 is to verify the FA2 algorithm end-to-end in a form that reads top to bottom.
The algorithm in code, at a glance
One CTA produces one Q-tile of \(O\). From the FA2 section: initialise \(m_i = -\infty\), \(\ell_i = 0\), \(O_i = 0\) in registers; load the Q-tile to SMEM; loop over KV tiles in \(j\), each iteration updating \((m_i, \ell_i, O_i)\) via the online softmax recurrence; finalize by dividing by \(\ell_i\).
The kernel body follows that structure literally, in three phases:
- Load Q once. \(Q_i\) is loaded to SMEM at the top of the kernel and reused across all KV iterations. Initialise the softmax state and the output accumulator.
- KV mainloop. For each KV tile \(j\): load \(K_j\) to SMEM, build \(S_{ij} = Q_i K_j^\top \cdot \text{scale}\) in SMEM (with the causal mask applied per-element), run the online softmax row update (which writes \(P_{ij}\) in place over \(S_{ij}\)), load \(V_j\) over \(K_j\) in the same SMEM buffer, accumulate \(O_i \mathrel{+}= P_{ij} V_j\) in registers.
- Finalize. Each row group divides its \(O_i\) slice by \(\ell_i\) and writes the result to GMEM.
The threading uses 256 threads per CTA, organised as 64 row groups of 4 threads each. Each row group is jointly responsible for one row of \(O\) across all KV iterations: the 4 threads split the \(d\) axis of \(O\) four ways in registers, split the \(B_c\) axis of the score row for the softmax reductions four ways, and combine partials with a sub-warp shuffle. The per-row softmax state \((m_i, \ell_i)\) lives in registers replicated across the 4 threads of the group; only the output accumulator is sharded.
What FA2 looks like in memory
Three SMEM tiles, one CTA:
-
sQholds the Q-tile, shape \((B_r, d)\) in BF16. Loaded once, kept across all KV iterations. 16 KB at \(B_r = 64\), \(d = 128\). -
sKVis the rolling KV buffer, shape \((B_c, d)\) in BF16. Used for \(K_j\), then overwritten with \(V_j\) inside the same iteration. The overwrite is safe because \(K_j\) is finished with the moment \(S_{ij}\) has been exponentiated into \(P_{ij}\). One buffer, two roles. 16 KB. -
sSholds the score block \(S_{ij}\), shape \((B_r, B_c)\) in FP32. Built by the QK matmul, overwritten in place by the softmax to hold \(P_{ij}\) before the PV matmul reads it. 16 KB at \(B_r = B_c = 64\).
Total: 48 KB per CTA, comfortably under H100’s 228 KB SMEM/SM. Layouts are plain row-major; no swizzling yet, which means we will pay some SMEM bank-conflict tax that K3 and K4 fix.
The in-place rewriting (sKV holding K then V; sS holding S then P) is the single trick that keeps K2’s SMEM footprint small. Fully-masked rows (rows of a Q-tile that begin below the diagonal of a KV-tile and see no valid keys) produce \(\ell_i = 0\); a guard at finalize substitutes \(1 / \ell_i = 1\) to avoid the \(0 \cdot \infty\) NaN.
Implementation
The kernel file kernels/k2.py contains one CuTe class K2FMHA plus the compile_kernel / run_kernel dispatchers that the harness calls. The kernel uses one CTA per (Q-tile, batch), grid \((\lceil N / B_r \rceil, B, 1)\), block shape \((256, 1, 1)\).
The CuTe machinery K2 introduces relative to K1: a cute.struct to declare and align the SMEM region, cute.arch.warp_reduction_max / _sum for the sub-warp shuffle reductions, cute.math.exp and cute.arch.rcp_approx for the SFU-backed softmax math, and the cutlass.range vs cutlass.range_constexpr distinction for the runtime KV mainloop vs the JIT-unrolled inner loops. The full code walkthrough, variable tables, and per-phase unwrapping live in the standalone K2 walkthrough.
Running K2
CHECK=1 SEQLEN=512 HEADDIM=64 CAUSAL=0 python bench.py k2
Results
Profiling
What’s next
The scalar FP32 GEMMs (\(QK^\top\) and \(PV\)) are the obvious bottleneck. K3 brings TMA to lift the GMEM→SMEM copies off the critical path with async bulk transfers, and introduces the multi-stage pipeline that lets the next tile’s load overlap with the current tile’s compute. K4 then replaces the two GEMMs with WGMMA on the tensor cores.
K3: TMA
Refer this blog for full K3 CuTe walkthrough: Link
What’s new
K2 spent every KV iteration with 256 threads acting as memcpy units, each issuing per-element thread-strided loads to drag Q, K, and V into SMEM. K3 hands all of that off to TMA, the SM90 hardware unit that performs async bulk copies of multidimensional tiles between GMEM and SMEM. A single thread issues the instruction; hardware walks the tile and writes the bytes into SMEM in the destination layout. Completion is tracked by a hardware mbarrier, not by __syncthreads.
On its own, this is a load-path cleanup. But TMA also sets up the producer/consumer pipeline that K4 will lean on: one thread (the producer) issues loads into a circular buffer of SMEM stages while the other threads (the consumer) compute on previously-loaded stages. Multiple stages mean the next tile’s load overlaps with the current tile’s compute.
K3 in this worklog covers three substages of that machinery: K3.1 with one pipeline stage (no overlap, just the protocol), K3.2 with two stages (first real overlap), K3.3 with five (the depth K4 will inherit). The compute body (S = QK^T, online softmax, O += PV) is unchanged from K2 across all three; this section is purely about the load layer.
The idea
TMA’s contract is straightforward: build a tensor descriptor on the host describing the source tensor’s shape, strides, and dtype; pass it to one issuing thread inside the kernel along with a tile coordinate; hardware does the rest. The instruction returns immediately and arrives on an mbarrier when the bytes are in SMEM. The consumer side waits on that barrier before reading.
This buys two things that matter for K3 and beyond. First, the 255 threads that K2 spent on memcpy work are now free; K4 will fill them with WGMMA. Second, TMA can saturate HBM bandwidth from a single issuer more efficiently than 256 threads each issuing per-element loads. Even at one pipeline stage with no overlap, K3.1 has a cleaner load profile than K2.
The actual speedup over K2 is small in K3 because the consumer side (scalar FP32 GEMMs) is far slower than the load side, so the load is already not on the critical path. The wins from TMA and from staging only materialise in K4 once the consumer is on tensor cores.
K and V get separate SMEM regions
K2 used one SMEM buffer for \(K_j\) and \(V_j\), reusing it within an iteration (load \(K_j\), build \(S_{ij}\), exponentiate to \(P_{ij}\) in sS, then overwrite the same buffer with \(V_j\) for PV). K3 cannot do that: with \(N\) pipeline stages, the producer may be loading stage \(j+1\)’s \(K\) at the same time the consumer is reading stage \(j\)’s \(V\). The two tiles must coexist in different SMEM regions.
So in K3, \(K\) and \(V\) get their own SMEM regions (sK, sV), both staged \(N\) deep. They share a stage index and a single mbarrier whose tx_count is the sum of K-bytes and V-bytes, so the barrier fires only when both TMAs of a stage have arrived. But the bytes live in separate buffers.
The stage progression
K3.1 — one stage. Producer and consumer are serialized through the single slot: producer fills, consumer drains, producer refills. No overlap. The point of K3.1 is correctness of the async protocol and the API surface; nothing structural to gain.
K3.2 — two stages. First real overlap: while the consumer computes on stage 0, the producer can issue the TMA for stage 1. The producer/consumer protocol splits into a prefetch phase (fill the first \(N\) stages upfront) and a steady-state phase (consume + reissue per iteration). Same structure K3.3 uses.
K3.3 — five stages. Same protocol, deeper buffer. At steady state the producer is up to four iterations ahead of the consumer. We pick 5 because:
- Latency-hiding bound. TMA round-trip latency on H100 is on the order of hundreds of cycles per tile; one or two stages would already hide the load against K3’s scalar-GEMM consumer. The point of 5 isn’t to hide more latency in K3 but to lock in the depth K4 will use, since K4’s WGMMA consumer is fast enough that several stages of headroom become useful.
- SMEM bound. Each KV stage costs 32 KB at \(B_c = 64, d = 128\) (16 KB for K plus 16 KB for V, separate regions). 5 stages = 160 KB. Adding
sQ(16 KB) andsS(16 KB) brings the K3.3 total to 192 KB, which fits in H100’s 228 KB/SM with one CTA per SM.
Beyond the latency-hiding requirement, extra stages just consume SMEM that could go to occupancy or larger tiles. For K3 specifically, K3.3 may even be slightly slower than K3.1 at higher \(d\) because the larger SMEM footprint cuts occupancy without reducing iteration time. That’s expected; K3.3 is structural setup for K4, not a kernel that’s supposed to win in isolation.
Implementation
The kernel file kernels/k3.py contains one CuTe class K3FMHA and the harness dispatchers. The kernel is parameterised by q_stage and kv_stage; the three substages correspond to kv_stage ∈ {1, 2, 5} (Q always at 1, since it’s loaded once per CTA). Grid \((\lceil N / B_r \rceil, B, 1)\), block \((256, 1, 1)\).
Two pipelines: Q is a one-stage PipelineTmaAsync (Q is loaded once and reused across all KV iterations); KV is an \(N\)-stage PipelineTmaAsync with tx_count = sK_bytes + sV_bytes so a stage’s barrier fires only when both K and V have arrived. SMEM layouts are plain row-major across K3 (no swizzle yet; that arrives in K4 along with WGMMA’s tensor-core access pattern, which is the only thing swizzle actually helps).
The full details — TMA descriptor mechanics, mbarrier internals, the producer/consumer protocol, the CuTe API surface for make_tiled_tma_atom / tma_partition / PipelineTmaAsync.create, the K3.1 walkthrough, the K3.2 prefetch + steady-state pattern — are in the standalone K3 walkthrough.
Running K3
CHECK=1 SEQLEN=512 HEADDIM=64 CAUSAL=0 python bench.py k3
The substage is selected by kv_stage in the kernel class; the harness exposes the three variants as k3_1, k3_2, k3_3 (see the standalone for the exact dispatch).
Results
Profiling
What’s next
K4 replaces the two scalar FP32 GEMMs with WGMMA on the tensor cores. The consumer side becomes fast enough that load latency genuinely matters, and the 5-stage pipeline from K3.3 starts earning its SMEM. K4 also brings the SMEM descriptor model, the swizzled SMEM layouts the tensor cores actually need, and (in K4.2) the QK→PV register handoff that keeps \(P\) in registers between the two matmuls.
K4: WGMMA
Refer this blog for full K4 CuTe walkthrough: Link
What’s new
K3 set up the load layer: TMA bulk copies, 5-stage pipeline, async barriers. But the actual math stayed on the CUDA cores as scalar FP32 dot products. K4 finally brings tensor cores into the picture by replacing both GEMMs (\(QK^\top\) and \(PV\)) with WGMMA, the SM90 warpgroup-level MMA instruction.
K4 splits into two substages, because the move from “no tensor cores” to “both matmuls on tensor cores with register handoff” is a meaningful jump that benefits from being staged:
- K4.1 puts QK on WGMMA, leaves PV as a hand-rolled scalar loop. The QK output
acc_qkis a per-thread register fragment under WGMMA’s C TV layout; K4.1 scatters it back to SMEM as a flat \((B_r, B_c)\) tile so the K3 softmax can run unchanged. - K4.2 puts PV on WGMMA too, which forces \(P\) to flow through registers (not SMEM) for the QK→PV handoff. That in turn forces the softmax to operate directly on the WGMMA C-fragment in registers, and forces us to reckon with \(V\)’s mode-order ambiguity in the WGMMA descriptor.
The threading model, the pipeline, and the load layer are otherwise unchanged from K3.3.
The idea
WGMMA in one paragraph. One WGMMA call computes \(C \mathrel{+}= A \cdot B\) at warpgroup granularity (128 threads, 4 warps) on a fixed atom shape, with the tensor cores doing the actual math. The instruction is asynchronous: it returns immediately, the warpgroup keeps going, and you fence and wait when you need the result. For SM90 BF16 the atom shape is \(64 \times N \times 16\) for \(N \in \{8, 16, \ldots, 256\}\). Larger matmuls iterate WGMMAs over the K axis. Operands come from SMEM with a 64-bit descriptor encoding tile address, swizzle, and stride; the C accumulator lives in registers as a hardware-defined per-thread fragment.
Why 128 threads. A WGMMA is a warpgroup-collective instruction: all 128 threads of a warpgroup must issue together, and hardware reads operand fragments from those 128 threads’ registers and writes the output back to them. We could keep 256 threads in the CTA and have the second warpgroup do separate work, but for K4 the simplest design is one warpgroup as the entire CTA. K5 (warp specialization) is where we’ll start using the second warpgroup.
This drops threads_per_row from 4 to 2, which means each row of \(O\) is now owned by 2 threads splitting the \(d\) axis. For \(d = 128\) each thread holds 64 FP32 values of \(O\), double K3’s 32.
K4.1: the SMEM-laundering half-step
After the WGMMA, acc_qk is a per-thread register fragment under the WGMMA C TV layout. Each of the 128 threads holds 32 FP32 values at tile positions decided by hardware. K4.1 writes this fragment back to SMEM into a plain \((B_r, B_c)\) tile sS, using an identity tensor probe (cute.make_identity_tensor) that reports the \((M, N)\) tile coordinate of each per-thread fragment slot. Each thread sweeps its own fragment, asks the identity tensor where each slot lives, and writes to sS[m, n] accordingly.
The K3 softmax then runs against sS unchanged. PV stays scalar (each thread does \(d / 2\) output columns by walking \(B_c\)). This launders the WGMMA C-layout through SMEM so the rest of the kernel doesn’t have to think about it.
K4.1 is what lets us validate the WGMMA path in isolation: the QK matmul accelerates by roughly two orders of magnitude over the K3 scalar version, but PV stays slow, so the overall result is “QK is now fast; PV is now the dominant cost.” That’s the right intermediate state before committing to K4.2’s harder changes.
K4.2: the register handoff
The natural next step is “put PV on WGMMA too.” That sounds like one change, but it forces a cascade:
-
PV’s A operand must come from registers or SMEM. WGMMA constrains B to SMEM, but A can be either. Re-using K4.1’s “write to SMEM, read back” trick for \(P\) does not work cleanly here, because the WGMMA C layout (where the softmax output sits) and the WGMMA A layout (where PV expects its A operand) are different register-level patterns, and converting between them through SMEM doesn’t line up byte-for-byte. So K4.2 keeps \(P\) in the registers
acc_qkalready lives in and hands them off directly to PV. -
The softmax must run on the register fragment. With no SMEM round-trip on \(P\), the online softmax has to operate directly on
acc_qk’s per-thread fragment. That means knowing the WGMMA C TV layout in full detail: which fragment slots correspond to which \((M, N)\) tile positions, how many distinct M rows each thread owns (2 for the 64×64 atom, not 1), and which threads share a row for the cross-thread reduction (4 sibling threads per row, not the K3 row group of 4 threads doing other things). -
The QK→PV handoff is a register reinterpretation. A small utility
make_acc_into_opallocates a register tensor in the PV A-layout, builds an aliasing view of those same registers in the QK C-layout, and writes the FP32→BF16 cast through the C view. The bytes are reinterpreted in place; no SMEM, no inter-thread shuffles (BF16 only; FP8 in K7 needs shuffles on top of this). -
The persistent accumulator
acc_pvis a register fragment allocated once before the mainloop, zero-initialised so the unified rescale is benign on iter 0, and updated each KV iteration. Cross-thread reduction of the row suma_sumis deferred to the finalize stage and done once after all KV iterations are complete. -
The output scatter writes
acc_pvto GMEM via the same identity-tensor coord-probe pattern from K4.1, but now against the PV tiled MMA’s C-layout.
V’s mode-order trap (callout)
\(V\) stores the same physical bytes as \(K\) (row-major \((N, d, B)\) with \(d\) contiguous), but plays a different role in PV than \(K\) plays in QK. For QK, \(K\) is the B operand and the contraction is \(d\), so \(K\) is K-major. For PV, \(V\) is the B operand and the contraction is \(B_c\); \(V\)’s contiguous axis is still \(d\), but for PV that’s the N axis, not the K axis. So \(V\) needs to be MN-major for PV.
The cleanest way to communicate this to WGMMA is to re-view \(V\) at kernel entry with mode order \((d, k, l)\) instead of \((k, d, l)\). The bytes don’t move; CuTe now sees \(d\) as the leading axis,
sm90_mma_major_mode()returns MN, and both the TMA atom and the WGMMA descriptor derive from this single view consistently.
Why 5 stages finally earns itself
K3.3 carried kv_stage = 5 as structural preparation, not as something the K3 kernel actually benefited from. K4.2 changes the calculus. The PV WGMMA on the consumer side is roughly an order of magnitude faster (in wall time) than the K3 scalar PV; the QK WGMMA is similarly fast; the register softmax removes the SMEM round-trip. One full KV iteration of compute is now genuinely fast, while TMA latency hasn’t changed.
With \(N\) stages the producer can run \(N - 1\) iterations ahead. Three to four iterations of headroom is typically enough to hide TMA latency against K4.2’s compute rate; 5 gives margin. The SMEM bound is the same as K3.3 (192 KB at \(d = 128\)), single CTA per SM. K5 onwards will revisit pipeline depth as warp specialization changes the rate ratio further.
Implementation
K4 ships as two files. kernels/k4_qk.py is the K4.1 half-step kernel (K4QKFMHA), with WGMMA on QK and the coord-tensor scatter to sS. kernels/k4.py is the K4.2 full kernel (K4FMHA) with both matmuls on WGMMA, register softmax, register handoff, and V mode-order fix. Both files share the K3.3 pipeline structure and several utility helpers (layout_separate, layout_acc_mn, reduction_target_n for K4.2 only).
The CuTe machinery K4 introduces relative to K3: the WGMMA atom (sm90_utils.make_trivial_tiled_mma), per-operand TV layouts and partitioning (partition_A/B/C, make_fragment_*), the async pattern (fence/commit_group/wait_group), cute.gemm for issuing one WGMMA, swizzled SMEM layouts (make_smem_layout_a/b), the identity-tensor coord probe (make_identity_tensor), and for K4.2 the helpers above plus the register handoff (make_acc_into_op) and the V re-view (cute.select). All of these get a proper write-up in the standalone K4 walkthrough.
Running K4
# K4.1: WGMMA for QK only
CHECK=1 SEQLEN=512 HEADDIM=64 CAUSAL=0 python bench.py k4_qk
# K4.2: full WGMMA
CHECK=1 SEQLEN=512 HEADDIM=64 CAUSAL=0 python bench.py k4
Results
Profiling
What’s next
K4.2 leaves obvious things on the table. The QK and PV WGMMAs issue serially in each KV iteration, with the softmax sequenced strictly between them. The same 128 threads do load issuance (warp 0), compute, and store — they spend cycles on pipeline housekeeping instead of math. K5 introduces warp specialization: split the warpgroup into “load” warps that just issue TMAs and “compute” warps that just run WGMMA + softmax. K6 then introduces ping-pong scheduling where two compute warpgroups alternate so one runs WGMMA while the other runs softmax, hiding the softmax cost behind GEMM.