Flash Attention Tiling Principles
Updated 2026-04-06
Introduction: Why Memory Is the Bottleneck in Standard Attention
In previous articles, we learned about the computation process of Scaled Dot-Product Attention:
Attention(Q, K, V) = softmax(QK^T / √d) V
The standard implementation requires three steps:
- Compute S = QK^T, then store S to HBM
- Compute P = softmax(S), then store P to HBM
- Compute O = PV, then store O to HBM
The problem lies in the intermediate matrices S and P, both of size N×N. When the sequence length is large (e.g., N = 8192), these two matrices require about 256 MB of memory per attention head (in fp16). More critically, these matrices need to be repeatedly read from and written to the GPU’s HBM (High Bandwidth Memory), whose bandwidth cannot keep pace with the GPU’s compute throughput.
Flash Attention (Dao et al., 2022) introduced a key insight: through tiled computation and Online Softmax, we can completely avoid storing the intermediate matrices, reducing memory from O(N²) to O(N) while dramatically reducing the number of HBM accesses.
GPU Memory Hierarchy: SRAM vs HBM
To understand Flash Attention’s design motivation, we must first understand the GPU’s memory hierarchy.
Two-Level Storage
| Storage Level | Type | A100 Capacity | Bandwidth | Characteristics |
|---|---|---|---|---|
| SRAM (on-chip cache) | Registers + Shared Memory | ~20MB (192KB per SM) | ~19 TB/s | Extremely fast, but very small capacity |
| HBM (High Bandwidth Memory) | GPU memory | 40-80 GB | ~1.5-2.0 TB/s | Large capacity, but limited bandwidth |
Key data: SRAM bandwidth is ~10x that of HBM, but its capacity is only ~1/2000 of HBM.
[Figure] GPU memory hierarchy and data-transfer comparison: Standard Attention requires 6 HBM transfers of intermediates (3 reads + 3 writes), while Flash Attention only needs 2 (1 read + 1 write).
Standard Attention’s Memory Access Pattern
The problem with the standard implementation is not the computation (FLOPs), but the memory access volume (IO):
Step 1: Read Q, K from HBM → Compute S = QK^T → Write S to HBM (read 2Nd, write N²)
Step 2: Read S from HBM → Compute P = softmax(S) → Write P to HBM (read N², write N²)
Step 3: Read P, V from HBM → Compute O = PV → Write O to HBM (read N²+Nd, write Nd)
Total HBM accesses: O(Nd + N²). When N ≫ d, the N² term dominates.
Flash Attention’s goal: through tiled computation, keep all intermediate results in SRAM, reducing HBM accesses to O(N²d²/M), where M is the SRAM size.
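To make this accounting concrete, here is a rough back-of-the-envelope counter for both access patterns. It is a sketch: the helper names, the constant factors, and the interpretation of M as an SRAM capacity in elements are illustrative assumptions rather than values from the paper.

```python
# Rough HBM access counts in number of elements moved, following the accounting
# above. Constant factors are illustrative; only the asymptotics matter.

def standard_hbm_accesses(N: int, d: int) -> int:
    reads  = 2 * N * d          # step 1: read Q, K
    writes = N * N              # step 1: write S
    reads += N * N              # step 2: read S
    writes += N * N             # step 2: write P
    reads += N * N + N * d      # step 3: read P, V
    writes += N * d             # step 3: write O
    return reads + writes       # Theta(N*d + N^2)

def flash_hbm_accesses(N: int, d: int, M: int) -> int:
    # K, V are streamed once; Q and O make one round trip per K/V block.
    B_c = max(1, M // (4 * d))              # K/V rows per block that fit in SRAM
    T_c = -(-N // B_c)                      # ceil(N / B_c) outer iterations
    return 2 * N * d + 2 * N * d * T_c      # Theta(N^2 * d^2 / M)

N, d, M = 8192, 64, 100_000                 # M ~ usable SRAM in elements (assumption)
print(f"standard: {standard_hbm_accesses(N, d):.2e}")   # ~2.7e+08
print(f"flash   : {flash_hbm_accesses(N, d, M):.2e}")   # ~2.4e+07, roughly 11x fewer
```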
Tiling Strategy: Tiling Q, K, V
Flash Attention’s first technique is Tiling: splitting Q, K, V into appropriately-sized blocks so that each block fits entirely in SRAM.
Block Size Selection
Given SRAM size M and head dimension d, the block sizes are chosen as:
B_c = ⌈M / (4d)⌉,  B_r = min(⌈M / (4d)⌉, d)
This ensures that a Q block (B_r × d), a K block (B_c × d), a V block (B_c × d), and a local score matrix S_ij (B_r × B_c) all fit in SRAM.
Larger SRAM allows larger blocks, which means fewer outer-loop iterations and therefore fewer HBM accesses.
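A minimal sketch of this arithmetic, assuming the SRAM budget is given as a number of elements (the sram_elems parameter is an assumption; the real budget depends on dtype, register pressure, and kernel details):

```python
import math

def block_sizes(sram_elems: int, d: int, N: int):
    """Return (B_c, B_r, T_c, T_r) for a given SRAM budget and head dimension."""
    B_c = math.ceil(sram_elems / (4 * d))    # K/V rows per block
    B_r = min(B_c, d)                        # Q rows per block
    T_c = math.ceil(N / B_c)                 # number of K/V blocks (outer loop)
    T_r = math.ceil(N / B_r)                 # number of Q blocks (inner loop)
    return B_c, B_r, T_c, T_r

print(block_sizes(sram_elems=24_576, d=64, N=4096))   # (96, 64, 43, 64)
print(block_sizes(sram_elems=49_152, d=64, N=4096))   # doubling SRAM roughly halves T_c
```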
Nested Loop Structure
Flash Attention uses nested loops:
Outer loop (j = 1 to T_c): // iterate over K, V blocks
Load K_j, V_j from HBM to SRAM
Inner loop (i = 1 to T_r): // iterate over Q blocks
Load Q_i, O_i, l_i, m_i from HBM to SRAM
Compute local attention in SRAM
Update O_i, l_i, m_i
Write back to HBM
Where T_c = ⌈N/B_c⌉ is the number of K/V blocks and T_r = ⌈N/B_r⌉ is the number of Q blocks.
Key: The N×N attention matrix is never fully materialized. Only a B_r × B_c tile is computed at a time, then discarded.
Online Softmax: Detailed Derivation of the Core Innovation
The biggest challenge with tiled computation is that softmax needs to see an entire row of data to normalize. If you only see part of the columns, how can you compute the correct softmax?
This is exactly what Online Softmax solves.
Standard Softmax Review
For a row vector x = (x₁, …, x_N), the numerically stable softmax is:
m = max_k(x_k),  f(x) = (e^(x₁ − m), …, e^(x_N − m)),  l = Σ_k f(x)_k,  softmax(x) = f(x) / l
Where m is the maximum value (for numerical stability), f(x) is the shifted exponential vector, and l is the normalization constant.
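As a small illustration, a numerically stable softmax that also exposes the statistics m and l might look like this (an illustrative helper, not library code):

```python
import numpy as np

def stable_softmax(x: np.ndarray):
    """Return (softmax(x), m, l) computed in the numerically stable form."""
    m = x.max()            # maximum value, subtracted for numerical stability
    f = np.exp(x - m)      # shifted exponential vector
    l = f.sum()            # normalization constant
    return f / l, m, l

p, m, l = stable_softmax(np.array([2.1, 3.2, 0.5, 1.7]))
print(p.sum(), m, l)       # ~1.0  3.2  ~1.62
```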
Splitting into Two Blocks
Suppose the vector x is split into two parts x = [x^(1), x^(2)], its first and second halves. We want to show that the global softmax can be derived from the local statistics of the two parts.
The global maximum can be obtained from the local maxima:
m = max(m^(1), m^(2)), where m^(k) = max(x^(k))
The global shifted exponential vector:
f(x) = [e^(m^(1) − m) · f(x^(1)), e^(m^(2) − m) · f(x^(2))]
The global normalization constant:
l = e^(m^(1) − m) · l^(1) + e^(m^(2) − m) · l^(2)
Key insight: The exponential correction factor e^(m^(k) − m) compensates for the difference between the local max and the global max. If the new block has a larger maximum (m^(2) > m^(1)), all previous values need to be multiplied by e^(m^(1) − m^(2)) to correct them.
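A short numerical check of this decomposition (the split point is arbitrary and chosen purely for illustration):

```python
import numpy as np

x = np.array([2.1, 3.2, 0.5, 1.7, 4.0, -0.3])
x1, x2 = x[:3], x[3:]                            # two blocks

m1, m2 = x1.max(), x2.max()                      # local maxima
f1, f2 = np.exp(x1 - m1), np.exp(x2 - m2)        # local shifted exponentials
l1, l2 = f1.sum(), f2.sum()                      # local normalization constants

m = max(m1, m2)                                  # global max from local maxima
f = np.concatenate([np.exp(m1 - m) * f1,         # correct block 1 to the global max
                    np.exp(m2 - m) * f2])        # correct block 2 to the global max
l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2    # corrected normalization constant

reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(f / l, reference))             # True: the decomposition is exact
```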
Recurrence Algorithm
This decomposition can be applied recursively to any number of blocks. For a fixed Q block i, let m_i, l_i, O_i be the per-row statistics after processing the first j − 1 K/V blocks. When the j-th block K_j, V_j arrives:
Step 1: Compute local scores
S_ij = Q_i K_j^T / √d
Step 2: Compute local statistics
m̃_ij = rowmax(S_ij),  P̃_ij = exp(S_ij − m̃_ij),  l̃_ij = rowsum(P̃_ij)
Step 3: Update global statistics
m_i^new = max(m_i, m̃_ij),  l_i^new = e^(m_i − m_i^new) · l_i + e^(m̃_ij − m_i^new) · l̃_ij
Step 4: Correct and update the output
O_i ← (1 / l_i^new) · ( l_i · e^(m_i − m_i^new) · O_i + e^(m̃_ij − m_i^new) · P̃_ij V_j )
What this formula means:
- l_i · O_i: “un-normalizes” the previous output back to the state before dividing by l_i
- e^(m_i − m_i^new): correction factor compensating for the difference between the old max and the new max
- e^(m̃_ij − m_i^new) · P̃_ij V_j: the new block’s contribution (also corrected to the new max)
- 1 / l_i^new: re-normalizes with the new normalization constant (see the sketch after this list)
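Putting the four steps together for one Q block streamed against a sequence of K/V blocks gives the following sketch (plain NumPy in double precision; it mimics the block streaming only and does no actual SRAM management):

```python
import numpy as np

def online_softmax_block(Q_i, kv_blocks):
    """Process one Q block against an iterable of (K_j, V_j) blocks."""
    B_r, d = Q_i.shape
    O = np.zeros((B_r, d))            # running (already normalized) output
    l = np.zeros(B_r)                 # running normalization constants
    m = np.full(B_r, -np.inf)         # running row maxima
    for K_j, V_j in kv_blocks:
        S_ij = Q_i @ K_j.T / np.sqrt(d)            # Step 1: local scores
        m_t  = S_ij.max(axis=1)                    # Step 2: local row maxima
        P_t  = np.exp(S_ij - m_t[:, None])         #         local exponentials
        l_t  = P_t.sum(axis=1)                     #         local row sums
        m_new = np.maximum(m, m_t)                 # Step 3: merge maxima
        l_new = np.exp(m - m_new) * l + np.exp(m_t - m_new) * l_t
        O = ((l * np.exp(m - m_new))[:, None] * O              # Step 4: un-normalize,
             + np.exp(m_t - m_new)[:, None] * (P_t @ V_j))     # correct, add new block,
        O = O / l_new[:, None]                                 # and re-normalize
        m, l = m_new, l_new
    return O
```

Splitting K and V with np.array_split and feeding the pieces to this function reproduces softmax(Q_i K^T / √d) V for that block of queries to machine precision, regardless of how many pieces are used.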
Why Is It Exact?
Online Softmax is not an approximation: it is mathematically exactly equivalent to standard softmax. The entire derivation is based on a simple algebraic identity:
e^(x − m_new) = e^(x − m_old) · e^(m_old − m_new)
Regardless of how many blocks the data is split into or what order they are processed in, the final result is exactly the same.
For example, processing the first block s₁ = [2.1, 3.2]: m₁ = max(2.1, 3.2) = 3.2, exp(s₁ − m₁) = [0.3329, 1.0000], l₁ = 1.3329. There is no need to store the full N×N matrix; we only maintain the m, l, O accumulators.
Worked Example: Flash Attention Tiled Computation
Below is a small example demonstrating, step by step, how Flash Attention processes the first Q block Q_1: it interacts with two K/V blocks in sequence and uses the Online Softmax correction to obtain exact results.
Standard Attention must store the full N×N attention matrix in HBM, so memory is O(N²). Flash Attention’s core idea: split Q, K, V into blocks, compute blockwise in SRAM, and never store the full N×N matrix.
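In place of the interactive demo, here is a self-contained NumPy sketch of the full tiled forward pass. It follows the v1 loop order from the pseudocode above, uses tiny illustrative block sizes (B_r = B_c = 2), and is checked against the standard implementation:

```python
import numpy as np

def flash_attention_forward(Q, K, V, B_r=2, B_c=2):
    """Tiled attention forward pass; the N x N matrix is never materialized."""
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros(N)                       # per-row normalization constants
    m = np.full(N, -np.inf)               # per-row running maxima
    scale = 1.0 / np.sqrt(d)

    for j in range(0, N, B_c):            # outer loop: "load K_j, V_j into SRAM"
        K_j, V_j = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):        # inner loop: "load Q_i, O_i, l_i, m_i"
            Q_i = Q[i:i + B_r]
            S_ij = Q_i @ K_j.T * scale                 # local B_r x B_c tile
            m_t  = S_ij.max(axis=1)
            P_t  = np.exp(S_ij - m_t[:, None])
            l_t  = P_t.sum(axis=1)

            m_new = np.maximum(m[i:i + B_r], m_t)
            alpha = np.exp(m[i:i + B_r] - m_new)       # correction for old values
            beta  = np.exp(m_t - m_new)                # correction for the new tile
            l_new = alpha * l[i:i + B_r] + beta * l_t

            O[i:i + B_r] = ((l[i:i + B_r] * alpha)[:, None] * O[i:i + B_r]
                            + beta[:, None] * (P_t @ V_j)) / l_new[:, None]
            m[i:i + B_r], l[i:i + B_r] = m_new, l_new
    return O

def standard_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 3)) for _ in range(3))
print(np.allclose(flash_attention_forward(Q, K, V), standard_attention(Q, K, V)))  # True
```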
Memory Reduction from O(N²) to O(N): Derivation
Standard Attention Memory
The standard implementation needs to store the intermediate matrices S and P:
Memory_standard = O(N²) (for S and P) + O(Nd) (for Q, K, V, O)
When N ≫ d, the N² term dominates, so memory is O(N²).
Flash Attention Memory
Flash Attention only needs to store the inputs, the output, and the auxiliary statistics:
Memory_flash = O(Nd) (for Q, K, V, O) + O(N) (for the l and m vectors)
No N² term at all! The local score matrix S_ij exists only temporarily in SRAM, never in HBM.
Theorem 1 (Dao et al., 2022): The Flash Attention algorithm returns the exact output O = softmax(QK^T / √d) V, uses O(N²d) FLOPs, and requires only O(N) additional memory beyond inputs and output.
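Plugging in concrete numbers makes the gap tangible; N = 8192 and d = 64 below are illustrative choices for a single attention head in fp16:

```python
def mib(num_elements: int) -> float:
    return num_elements * 2 / 2**20          # fp16 elements (2 bytes) -> MiB

N, d = 8192, 64
standard_extra = 2 * N * N                   # S and P, both N x N, stored in HBM
flash_extra    = 2 * N                       # only the l and m vectors, O(N)
print(f"standard intermediates: {mib(standard_extra):10.2f} MiB")   # 256.00 MiB
print(f"flash statistics:       {mib(flash_extra):10.5f} MiB")      # 0.03125 MiB
```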
IO Complexity Analysis: Why It’s Faster
Flash Attention not only saves memory but also saves time, because the bottleneck of Attention on GPUs is not computation but memory access.
Standard Attention IO Complexity
Θ(Nd + N²) HBM accesses
Flash Attention IO Complexity
Θ(N²d²/M) HBM accesses
Theorem 2 (Dao et al., 2022): Let N be the sequence length, d the head dimension, and M the SRAM size, with d ≤ M ≤ Nd. Standard Attention requires Θ(Nd + N²) HBM accesses; Flash Attention requires Θ(N²d²/M).
Intuitive understanding:
- The outer loop iterates over the T_c K/V blocks; K and V are each loaded from HBM exactly once, Θ(Nd) data in total
- The inner loop iterates over the T_r Q blocks; in each outer iteration, all of Q and O (plus l, m) are read and the updated O, l, m are written back, Θ(Nd) data per outer iteration
- Total accesses: Θ(Nd · T_c)
- Since T_c = ⌈N/B_c⌉ = Θ(Nd/M), we get Θ(N²d²/M)
For typical parameters (d = 64–128, M around 100 KB of SRAM), d²/M is much smaller than 1, so N²d²/M ≪ N². In experiments, Flash Attention is 2-4x faster than the standard implementation.
[Figure] IO complexity comparison (Standard vs Flash v1 vs v2): for long sequences, standard Attention’s HBM traffic explodes while Flash Attention’s grows far more slowly.
Lower Bound
Proposition 3 (Dao et al., 2022): No exact Attention algorithm can achieve o(N²d²/M) HBM accesses for all M in the range [d, Nd].
This means Flash Attention is asymptotically optimal in terms of IO complexity.
Flash Attention v1 vs v2
In 2023, Tri Dao released Flash Attention v2, which further optimized GPU parallelism on top of v1.
| Comparison | Flash Attention v1 | Flash Attention v2 |
|---|---|---|
| Outer loop | Iterates over K/V blocks | Iterates over Q blocks |
| Inner loop | Iterates over Q blocks | Iterates over K/V blocks |
| Inter-block parallelism | Different heads & batch parallel | Additional parallelism on Q block dimension |
| Non-matmul FLOPs | More | Reduced, better Tensor Core utilization |
| Inter-warp communication | Via shared memory | Reduced inter-warp communication |
| A100 utilization | 25-40% of theoretical peak | 50-73% of theoretical peak |
| Relative speedup | Baseline | ~2x over v1 |
Key Improvements in v2
1. Swapped Loop Order
v1’s outer loop iterates over K/V blocks and its inner loop over Q blocks. v2 reverses this: the outer loop iterates over Q blocks, the inner loop over K/V blocks. Each thread block is then responsible for exactly one Q block’s output, which reduces synchronization overhead and adds parallelism along the Q-block dimension, since Q blocks can be distributed across different thread blocks (and thus different streaming multiprocessors).
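A schematic and deliberately simplified way to see the difference is to compare the two iteration schedules; in a real v2 kernel the outer loop is not a loop at all but a grid of independent CUDA thread blocks:

```python
T_r, T_c = 4, 4   # number of Q blocks and K/V blocks (illustrative)

# v1 schedule: K/V-major. Each (Q_i, O_i, l_i, m_i) is revisited once per K/V block.
v1_schedule = [(i, j) for j in range(T_c) for i in range(T_r)]

# v2 schedule: Q-major. One worker finishes an entire output block before moving on,
# so the accumulators stay on-chip and O_i is written to HBM only once.
v2_schedule = [(i, j) for i in range(T_r) for j in range(T_c)]

print(v1_schedule[:6])   # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1)]
print(v2_schedule[:6])   # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1)]
```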
2. Reduced Non-matmul FLOPs
GPU Tensor Cores have extremely high throughput for matrix multiplication, but the rescaling, max, and sum operations in Online Softmax are non-matmul FLOPs. v2 reduces the proportion of these operations through algorithmic adjustments.
3. Better Work Partitioning Between Warps
v2 optimizes task partitioning between warps, reducing the number of synchronizations through shared memory, further improving parallel efficiency.
Summary
Flash Attention solves the memory and speed bottlenecks of standard Attention through three core techniques:
| Technique | Problem Solved | Effect |
|---|---|---|
| Tiling | N×N matrix not stored in HBM | Memory O(N²) → O(N) |
| Online Softmax | Correct normalization in tiled computation | Mathematically exact, zero approximation error |
| IO-aware design | Reduces HBM access count | 2-4x speed improvement |
Core formula quick reference (Online Softmax update for the j-th K/V block):
m_i^new = max(m_i, m̃_ij)
l_i^new = e^(m_i − m_i^new) · l_i + e^(m̃_ij − m_i^new) · l̃_ij
O_i ← (1 / l_i^new) · ( l_i · e^(m_i − m_i^new) · O_i + e^(m̃_ij − m_i^new) · P̃_ij V_j )
Flash Attention has become a standard component in modern large model inference and training. Starting from PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention uses the Flash Attention backend by default. Understanding its tiling principles is an important foundation for deeply understanding LLM system optimization.
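For reference, a minimal call through the PyTorch API might look like the sketch below; the explicit backend pinning via torch.nn.attention.sdpa_kernel assumes a fairly recent PyTorch release (it appeared around 2.3) and a CUDA GPU, so treat that part as an assumption rather than a requirement:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim); fp16 on GPU is what the Flash kernel expects
q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):       # insist on the Flash backend
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)                                    # torch.Size([1, 8, 1024, 64])
```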