FlashAttention 3 - A Worklog
Hello all. This blog post is about how to implement FlashAttention 3, specifically with CuteDSL, targeting H100 GPUs. The main goal of this work is to create a worklog that builds each kernel incrementally, adding features one at a time, so we can clearly and cleanly understand how to push toward high Model FLOPs Utilization (MFU).
Much of this work is inspired by people whose content I've really enjoyed in the GPU kernel optimization and ML performance engineering space, specifically Simon Boehm, Aleksa Gordić, Kapil Sharma, and Hamza, whose worklogs made me want to create one of my own. Since this is my take on such a worklog, I wanted to cover something not already available on the internet, or at least not in the form of a worklog.
- How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog - Simon Boehm
- Inside NVIDIA GPUs: Anatomy of high performance matmul kernels - Aleksa Gordić
- Learn CUTLASS the hard way! - Kapil Sharma
- Worklog: Optimising GEMM on NVIDIA H100 for cuBLAS-like Performance (WIP) - Hamza
At present, most ML/AI architectures depend on two basic primitives: GEMM and attention. I'm taking attention (specifically FlashAttention) and exploring its kernel design and how to push it toward high MFU. Below is the rough plan.
Flash Attention 3(FA3) Worklog:
- Attention math
- Online softmax
- Tiling and GPU memory hierarchy
- FA2
- CuTe DSL fundamentals and primitives
- K1 — naive attention
- K2 — tiled softmax
- K3 — using TMA
- K4 — using WGMMA
- K5 — warp specialization
- K6 — ping-pong scheduling
- K7 — FP8
- K8 — FP8 + incoherent processing
- Flash decode
The plan is to cover each of the above as its own section. Throughout this worklog, I'll provide intuition along with the details of how things work and what we need to make them work. Since each of these can be its own topic, I'll also add separate notes for many of them, which you can refer to if you want to go in-depth. I'll use Excalidraw to illustrate the mental images we can form to build good intuition.
This is a work in progress, marked with the “WIP” tag. Once it’s complete, I’ll remove that.
I sincerely thank all the many people who planted this idea in me. I hope this will be helpful for anyone who wants to understand FlashAttention 3 (FA3) from first principles. All the code will be made available here: FlashAttention3
I’m not going to use LLMs for most of this content, for a few reasons. That said, I’ll make sure there aren’t any conceptual mistakes. If you find any, please let me know and I’m happy to correct them and make this useful for everyone. If you’re in this domain, feel free to ping me — I’m happy to collaborate on projects around GPU kernels and ML perf engineering.
I’ll add references along each section, instead of tagging them at the end. I’m assuming basic ML/AI and GPU knowledge to get the best out of this blog. All the analysis will be on H100 SXM GPUs, since FA3 was targeted at the Hopper architecture and aims to squeeze out its full utilization through Hopper-specific features like TMA, WGMMA, asynchronous operations, and FP8 support. One point worth mentioning here: there isn’t a huge algorithmic difference between FA2 and FA3, but a lot of engineering work went into FA3 to take advantage of Hopper. FA2 reached around 70% of theoretical max FLOPs on Ampere (A100) but didn’t exploit Hopper-specific features. FA3 reaches around 75% on H100 in FP16, and with FP8 it gets close to 1.2 PFLOPs/s with 2.6x smaller numerical error than baseline FP8 attention. More on this in the upcoming sections.
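As a quick sanity check on those numbers, here's the rough arithmetic. The peak figures below are the commonly quoted dense Tensor Core rates for the H100 SXM, so treat this as a back-of-the-envelope sketch rather than an exact accounting:

```python
# Rough utilization math behind the numbers above. Peak values are the
# commonly quoted dense Tensor Core rates for H100 SXM (approximate).
H100_PEAK_FP16_TFLOPS = 989    # dense BF16/FP16 Tensor Core peak
H100_PEAK_FP8_TFLOPS = 1979    # dense FP8 Tensor Core peak

fa3_fp16_tflops = 740          # FA3's reported FP16 forward throughput
fa3_fp8_tflops = 1200          # "close to 1.2 PFLOPs/s" with FP8

print(f"FP16 utilization: {fa3_fp16_tflops / H100_PEAK_FP16_TFLOPS:.0%}")  # ~75%
print(f"FP8 utilization:  {fa3_fp8_tflops / H100_PEAK_FP8_TFLOPS:.0%}")    # ~61%
```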
That’s it on the intro. Let’s get started.
Reference of FlashAttention-3 Paper - Link
Attention Math
Attention is the core of transformers, and transformers are the core of large language models (LLMs). Transformers consist of two primitives — attention, followed by feed-forward networks (FFNs) — surrounded by norms and residual connections. Below is an image that captures this.
Attention gives you the context of what is being referred to and meant in a sentence, and the FFN predicts the next token based on this context. Each is critical to the other, and together they form the backbone of present-day LLMs.
Now, if we look at attention in more detail, it’s composed of three operations:
\[S = Q K^\top \tag{1}\]
\[P = \text{softmax}(S) \quad \text{(along the last dimension)} \tag{2}\]
\[O = P V \tag{3}\]
Refer to Jay Alammar’s transformers blog if you’re not sure of the attention mechanism and what happens under the hood.
Here, (1) and (3) are matrix multiplications (in GPU terminology, we call them GEMMs, for General Matrix Multiply), and (2) is a softmax along the last dimension. This matters because softmax was a major bottleneck in earlier attention implementations, and FlashAttention made algorithmic changes to how it's computed, while staying numerically exact.
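To make the three steps concrete, here's a minimal, unoptimized reference in PyTorch. This is just a sketch for clarity, not one of the kernels we'll build, and it omits the usual \(1/\sqrt{d}\) scaling to mirror the equations above:

```python
import torch

def naive_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Unoptimized reference: materializes the full (b, s, s) score matrix."""
    # Q, K, V: (batch, seq_len, d)
    S = Q @ K.transpose(-2, -1)    # (1) GEMM -> scores, shape (b, s, s)
    P = torch.softmax(S, dim=-1)   # (2) softmax along the last dimension
    O = P @ V                      # (3) GEMM -> output, shape (b, s, d)
    return O

# Tiny example: batch=2, seq_len=128, d=64
Q, K, V = (torch.randn(2, 128, 64) for _ in range(3))
print(naive_attention(Q, K, V).shape)  # torch.Size([2, 128, 64])
```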
The reason for optimizing attention is that materializing the intermediate \(S\) and \(P\) matrices is expensive. Why? Think in terms of dimensions and how much memory we’d need to get this computation done.
Let’s do a simple calculation. Let \(Q\), \(K\), and \(V\) be of shape \((b, s, d)\) → (batch, sequence length, model dim). Then \(S = QK^\top\) has shape \((b, s, s)\). Now apply softmax. What happens when the sequence length is 1 million and batch size is 1? Assuming FP32 (4 bytes per element):
\[1 \times 10^6 \times 10^6 \times 4 \text{ bytes} = 4 \times 10^{12} \text{ bytes} = 4 \text{ TB}\]
And that’s just to materialize one \(S\) matrix. We still need to do the softmax and another matmul on top of that. Attention is also computed many times, since LLMs have hundreds of transformer blocks. Most GPUs don’t have 4 TB of memory, and even if they did, it still wouldn’t be feasible — we need space for intermediate results and bookkeeping too.
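If you want to play with these numbers yourself, here's a tiny helper for the same back-of-the-envelope math (a sketch; the dtype sizes are the usual 4/2/1 bytes for FP32/FP16/FP8):

```python
def score_matrix_bytes(batch: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Memory needed to materialize S = QK^T with shape (batch, seq_len, seq_len)."""
    return batch * seq_len * seq_len * bytes_per_elem

# 1M sequence length, batch 1, FP32 (4 bytes): the 4 TB from above.
print(score_matrix_bytes(1, 1_000_000, 4) / 1e12, "TB")  # 4.0 TB
# Even FP16 (2 bytes) only halves it, still far beyond an 80 GB H100.
print(score_matrix_bytes(1, 1_000_000, 2) / 1e12, "TB")  # 2.0 TB
```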
This is the major bottleneck FlashAttention addresses. At its core, it does two things: online softmax and tiling. Together, these ensure we never materialize the full \(s \times s\) matrix in GPU memory, while producing the exact same numerical result as if we had. This is what made it possible to train models with longer sequence lengths, which gave us more context and helped LLMs become much more useful.
There are a lot of details in and around this, but we’ll focus on FlashAttention 3 going forward, with the necessary information and intuition along the way for each optimization and topic. With an understanding of what attention does and why we need to optimize it in place, we’ll talk about online softmax next.
A side note:
If you want to understand how LLMs work more clearly, try implementing one. Check out the Stanford CS336 course and follow their assignments; they’re self-contained and really push you to have a clear understanding of all the internals of LLMs. I did it as part of my coursework at NYU, and I’m glad I took that course and worked through the assignments. Highly recommend it.
Also, take a look at this: Transformer block accounting, and the image below; think through it to understand the memory and FLOPs counts for a transformer block. This gives you a picture of what it takes to run one transformer block, so you can appreciate all the engineering that’s gone into making it fast and efficient.
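If you want a starting point for that accounting, here's a rough forward-pass FLOPs estimate for one vanilla transformer block. It's a sketch that only counts the matmuls (using the standard 2·M·N·K rule), assumes a 4x FFN expansion, and ignores norms, softmax, and activations, which are comparatively cheap:

```python
def transformer_block_flops(seq_len: int, d_model: int) -> int:
    """Approximate forward FLOPs for one block; matmuls only, 4x FFN expansion."""
    qkv_proj = 3 * 2 * seq_len * d_model * d_model        # Q, K, V projections
    scores   = 2 * seq_len * seq_len * d_model            # S = Q K^T
    attn_out = 2 * seq_len * seq_len * d_model            # O = P V
    out_proj = 2 * seq_len * d_model * d_model            # attention output projection
    ffn      = 2 * 2 * seq_len * d_model * (4 * d_model)  # FFN up + down projections
    return qkv_proj + scores + attn_out + out_proj + ffn

# Example: seq_len=8192, d_model=4096
print(f"{transformer_block_flops(8192, 4096) / 1e12:.1f} TFLOPs")  # ~4.4 TFLOPs
```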