
Flash Attention Tiling Principles

Updated 2026-04-06

Introduction: Why Memory Is the Bottleneck in Standard Attention

In previous articles, we learned about the computation process of Scaled Dot-Product Attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

The standard implementation requires three steps:

  1. Compute $S = QK^T \in \mathbb{R}^{N \times N}$ — store to HBM
  2. Compute $P = \text{softmax}(S) \in \mathbb{R}^{N \times N}$ — store to HBM
  3. Compute $O = PV \in \mathbb{R}^{N \times d}$ — store to HBM

The problem lies in the intermediate matrices $S$ and $P$, both of size $N \times N$. When the sequence length $N$ is large (e.g., $N = 4096$), these two matrices require $2 \times 4096^2 \times 2\,\text{bytes} \approx 64\,\text{MB}$ of memory in fp16. More critically, these matrices must be repeatedly read from and written to the GPU's HBM (High Bandwidth Memory), whose bandwidth is far lower than the GPU's compute throughput.

Flash Attention (Dao et al., 2022) introduced a key insight: through tiled computation and Online Softmax, we can completely avoid storing the $N \times N$ intermediate matrices, reducing memory from $O(N^2)$ to $O(N)$ while dramatically reducing the number of HBM accesses.
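For concreteness, the three steps look like this in a minimal NumPy sketch (illustrative only; the naive version below materializes the full N×N matrices S and P, which is exactly the problem):

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive scaled dot-product attention; materializes the N x N matrices S and P."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                 # (N, N) score matrix, written to HBM
    S_max = S.max(axis=-1, keepdims=True)    # numerically stable softmax
    P = np.exp(S - S_max)
    P = P / P.sum(axis=-1, keepdims=True)    # (N, N) probabilities, written to HBM
    return P @ V                             # (N, d) output
```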

GPU Memory Hierarchy: SRAM vs HBM

To understand Flash Attention’s design motivation, we must first understand the GPU’s memory hierarchy.

Two-Level Storage

| Storage Level | Type | A100 Capacity | Bandwidth | Characteristics |
| --- | --- | --- | --- | --- |
| SRAM (on-chip cache) | Registers + Shared Memory | ~20 MB (192 KB per SM) | ~19 TB/s | Extremely fast, but very small capacity |
| HBM (High Bandwidth Memory) | GPU memory | 40-80 GB | ~1.5-2.0 TB/s | Large capacity, but limited bandwidth |

Key data: SRAM bandwidth is ~10x that of HBM, but its capacity is only ~1/2000 of HBM.

Diagram: GPU Memory Hierarchy & Data Transfer Comparison. SRAM (~20 MB, ~19 TB/s) vs HBM (80 GB, ~2 TB/s); the HBM link is the bandwidth bottleneck.

Standard Attention (6 HBM transfers):

  • Step 1: HBM → SRAM, read Q, K
  • Step 1: SRAM → HBM, write S = QKᵀ
  • Step 2: HBM → SRAM, read S
  • Step 2: SRAM → HBM, write P = softmax(S)
  • Step 3: HBM → SRAM, read P, V
  • Step 3: SRAM → HBM, write O = PV

Flash Attention (2 HBM transfers):

  • Load: HBM → SRAM, read Q, K, V blocks
  • Compute: in SRAM, QKᵀ → scale → mask → softmax → ×V (all in SRAM)
  • Write: SRAM → HBM, write the final O only

Standard Attention requires 6 HBM transfers (3 reads + 3 writes); Flash Attention needs only 2 (1 read phase + 1 write phase).

Standard Attention’s Memory Access Pattern

The problem with the standard implementation is not the computation (FLOPs), but the memory access volume (IO):

Step 1: Read Q, K from HBM → Compute S = QK^T → Write S to HBM     (read 2Nd, write N²)
Step 2: Read S from HBM     → Compute P = softmax(S) → Write P to HBM (read N², write N²)
Step 3: Read P, V from HBM  → Compute O = PV → Write O to HBM        (read N²+Nd, write Nd)

Total HBM accesses: $\Theta(Nd + N^2)$. When $N \gg d$, the $N^2$ term dominates.

Flash Attention's goal: through tiled computation, keep all intermediate results in SRAM, reducing HBM accesses to $\Theta(N^2 d^2 M^{-1})$, where $M$ is the SRAM size.

Tiling Strategy: Tiling Q, K, V

Flash Attention’s first technique is Tiling: splitting Q, K, V into appropriately-sized blocks so that each block fits entirely in SRAM.

Block Size Selection

Given SRAM size MM and head dimension dd:

$$B_c = \left\lceil \frac{M}{4d} \right\rceil, \quad B_r = \min\!\left(\left\lceil \frac{M}{4d} \right\rceil,\; d\right)$$

This ensures that a $B_r \times d$ Q block, a $B_c \times d$ K block, a $B_c \times d$ V block, and a $B_r \times B_c$ local score matrix all fit in SRAM.
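A quick sketch of this arithmetic (the helper name and the assumption that M is counted in matrix elements are mine, not from the paper):

```python
import math

def flash_block_sizes(M_elems, d, N):
    """Block sizes and block counts from SRAM capacity M (in elements), head dim d, seq len N."""
    Bc = math.ceil(M_elems / (4 * d))   # K/V block rows
    Br = min(Bc, d)                     # Q block rows
    Tc = math.ceil(N / Bc)              # number of K/V blocks
    Tr = math.ceil(N / Br)              # number of Q blocks
    return Bc, Br, Tc, Tr

# Example matching the calculator below: M = 102400 elements, d = 64, N = 512
print(flash_block_sizes(102400, 64, 512))  # (400, 64, 2, 8)
```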

Block Size Calculator (interactive widget)

Example values: with d = 64 and M ≈ 100K elements of SRAM, Bc = ⌈M/(4d)⌉ = 400 and Br = min(Bc, d) = 64. For N = 512, Q (512×64) is split into Tr = 8 blocks of 64×64, and K, V (512×64) into Tc = 2 blocks of up to 400×64.

Larger SRAM → larger blocks → fewer outer loops → less HBM access (here, Tr × Tc = 16 block computations).

Nested Loop Structure

Flash Attention uses nested loops:

Outer loop (j = 1 to T_c):     // iterate over K, V blocks
  Load K_j, V_j from HBM to SRAM
  Inner loop (i = 1 to T_r):   // iterate over Q blocks
    Load Q_i, O_i, l_i, m_i from HBM to SRAM
    Compute local attention in SRAM
    Update O_i, l_i, m_i
    Write back to HBM

Where $T_c = \lceil N / B_c \rceil$ is the number of K/V blocks and $T_r = \lceil N / B_r \rceil$ is the number of Q blocks.

Key: the $N \times N$ attention matrix is never fully materialized. Only a $B_r \times B_c$ tile is computed at a time, then discarded.
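Putting the loop structure together with the Online Softmax update derived in the next section, here is a self-contained NumPy sketch of a FlashAttention-1-style forward pass (no masking or dropout; an illustrative re-implementation, not the actual CUDA kernel):

```python
import numpy as np

def flash_attention_v1(Q, K, V, Br, Bc):
    """FlashAttention-1 style forward: outer loop over K/V blocks, inner loop over Q blocks."""
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros(N)                  # running row sums of exp(S - m)
    m = np.full(N, -np.inf)          # running row maxima
    scale = 1.0 / np.sqrt(d)

    for j in range(0, N, Bc):                        # outer: K/V blocks (kept in "SRAM")
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):                    # inner: Q blocks
            Qi = Q[i:i + Br]
            S = Qi @ Kj.T * scale                    # (Br, Bc) local scores
            m_tilde = S.max(axis=1)                  # local row max
            P = np.exp(S - m_tilde[:, None])         # local un-normalized probabilities
            l_tilde = P.sum(axis=1)                  # local row sum

            m_new = np.maximum(m[i:i + Br], m_tilde)
            alpha = np.exp(m[i:i + Br] - m_new)      # corrects the old statistics
            beta = np.exp(m_tilde - m_new)           # corrects the new block
            l_new = alpha * l[i:i + Br] + beta * l_tilde

            O[i:i + Br] = (alpha[:, None] * l[i:i + Br, None] * O[i:i + Br]
                           + beta[:, None] * (P @ Vj)) / l_new[:, None]
            m[i:i + Br], l[i:i + Br] = m_new, l_new
    return O
```

On small inputs this matches the naive implementation from the introduction up to floating-point error; at any moment only one Br×Bc tile of scores exists.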

Online Softmax: Detailed Derivation of the Core Innovation

The biggest challenge with tiled computation is that softmax needs to see an entire row of data to normalize. If you only see part of the columns, how can you compute the correct softmax?

This is exactly what Online Softmax solves.

Standard Softmax Review

For a row vector $x \in \mathbb{R}^B$, the numerically stable softmax is:

$$m(x) = \max_i x_i, \quad f(x) = \begin{bmatrix} e^{x_1 - m(x)} & \cdots & e^{x_B - m(x)} \end{bmatrix}, \quad \ell(x) = \sum_i f(x)_i, \quad \text{softmax}(x) = \frac{f(x)}{\ell(x)}$$

Where $m(x)$ is the maximum value (for numerical stability), $f(x)$ is the shifted exponential vector, and $\ell(x)$ is the normalization constant.

Splitting into Two Blocks

Suppose the vector $x$ is split into two parts $x = [x^{(1)}, x^{(2)}]$, where $x^{(1)}, x^{(2)} \in \mathbb{R}^B$. We want to show that the global softmax can be derived from the local statistics of the two parts.

The global maximum can be obtained from local maxima:

$$m(x) = \max\!\big(m(x^{(1)}),\, m(x^{(2)})\big)$$

The global shifted exponential vector:

$$f(x) = \begin{bmatrix} e^{m(x^{(1)}) - m(x)} f(x^{(1)}) & e^{m(x^{(2)}) - m(x)} f(x^{(2)}) \end{bmatrix}$$

The global normalization constant:

$$\ell(x) = e^{m(x^{(1)}) - m(x)} \ell(x^{(1)}) + e^{m(x^{(2)}) - m(x)} \ell(x^{(2)})$$

Key insight: the exponential correction factor $e^{m(x^{(1)}) - m(x)}$ compensates for the difference between the local max and the global max. If the new block has a larger maximum ($m(x^{(2)}) > m(x^{(1)})$), all previous $e^{x_i - m_{\text{old}}}$ values need to be multiplied by $e^{m_{\text{old}} - m_{\text{new}}}$ to correct them.
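The two-block case can be checked numerically in a few lines (a sketch; the variable names mirror the formulas above):

```python
import numpy as np

x = np.array([2.1, 3.2, 4.1, 1.5])
x1, x2 = x[:2], x[2:]                       # split into two blocks

# Local statistics per block
m1, m2 = x1.max(), x2.max()
f1, f2 = np.exp(x1 - m1), np.exp(x2 - m2)
l1, l2 = f1.sum(), f2.sum()

# Merge: correct each block's statistics to the global max
m = max(m1, m2)
l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2
f = np.concatenate([np.exp(m1 - m) * f1, np.exp(m2 - m) * f2])

ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(f / l, ref))  # True: block-wise merge equals the global softmax
```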

Recurrence Algorithm

This decomposition can be applied recursively to any number of blocks. Let $m_j$, $\ell_j$, $O_j$ be the statistics after processing the $j$-th block. When the $(j+1)$-th block arrives:

Step 1: Compute local scores

$$\tilde{S} = Q_i K_{j+1}^T / \sqrt{d}$$

Step 2: Compute local statistics

$$\tilde{m} = \text{rowmax}(\tilde{S}), \quad \tilde{P} = \exp(\tilde{S} - \tilde{m}), \quad \tilde{\ell} = \text{rowsum}(\tilde{P})$$

Step 3: Update global statistics

$$m^{\text{new}} = \max(m_j, \tilde{m}), \quad \ell^{\text{new}} = e^{m_j - m^{\text{new}}} \cdot \ell_j + e^{\tilde{m} - m^{\text{new}}} \cdot \tilde{\ell}$$

Step 4: Correct and update the output

$$O^{\text{new}} = \text{diag}(\ell^{\text{new}})^{-1} \!\left( \text{diag}(\ell_j) \cdot e^{m_j - m^{\text{new}}} \cdot O_j + e^{\tilde{m} - m^{\text{new}}} \cdot \tilde{P} \cdot V_{j+1} \right)$$

What this formula means:

  • $\text{diag}(\ell_j) \cdot O_j$: "un-normalizes" the previous output back to the state before dividing by $\ell$
  • $e^{m_j - m^{\text{new}}}$: correction factor compensating for the difference between old max and new max
  • $e^{\tilde{m} - m^{\text{new}}} \cdot \tilde{P} \cdot V_{j+1}$: the new block's contribution (also corrected to the new max)
  • $\text{diag}(\ell^{\text{new}})^{-1}$: re-normalizes with the new normalization constant

Why Is It Exact?

Online Softmax is not an approximation — it is mathematically exactly equivalent to standard softmax. The entire derivation is based on a simple algebraic identity:

$$\frac{e^{x_i - m_{\text{old}}}}{e^{m_{\text{new}} - m_{\text{old}}}} = e^{x_i - m_{\text{new}}}$$

Regardless of how many blocks the data is split into or what order they are processed in, the final result is exactly the same.
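As a quick check of this exactness, the recurrence can be run over an arbitrary split of a row (a sketch, using the same block values as the demo below); every split yields the identical softmax:

```python
import numpy as np

def online_softmax(blocks):
    """Online softmax over a list of 1-D blocks; returns the concatenated softmax."""
    m, l = -np.inf, 0.0
    for b in blocks:
        m_new = max(m, b.max())
        l = np.exp(m - m_new) * l + np.exp(b - m_new).sum()  # correct old sum, add new block
        m = m_new
    # Second pass only to print the full distribution; Flash Attention never needs it
    return np.concatenate([np.exp(b - m) for b in blocks]) / l

x = np.array([2.1, 3.2, 4.1, 1.5, 2.8, 3.0])
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(online_softmax([x[:2], x[2:4], x[4:]]), ref))  # True
print(np.allclose(online_softmax([x[:3], x[3:]]), ref))          # True
```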

Interactive demo (three blocks: B₁ = [2.1, 3.2], B₂ = [4.1, 1.5], B₃ = [2.8, 3.0]).

Block 1 (initialize): s₁ = [2.1, 3.2] → m₁ = max(2.1, 3.2) = 3.2 → exp(s₁ − m₁) = [0.3329, 1.0000] → ℓ₁ = 1.3329

Running state after block 1: m = 3.2, ℓ = 1.3329, O = [0.7251, 0.1499]

No need to store the full N×N matrix; only the m, ℓ, O accumulators are maintained.

Interactive Demo: Flash Attention Tiled Computation

Below is a small example with $N = 4$, $d = 3$, $B = 2$, demonstrating step by step how Flash Attention processes the first Q block ($t_1, t_2$), interacting with two K/V blocks sequentially, and using the Online Softmax correction to obtain exact results.

Q, K, V matrices and blocking

Standard Attention stores the full N×N attention matrix in HBM, with O(N²) memory. Flash Attention's core idea: split Q, K, V into blocks, compute blockwise in SRAM, and never store the full N×N matrix.

Q ∈ ℝ^(4×3) (rows t₁–t₄, columns d₁–d₃):
[[ 0.05,  0.11,  0.42],
 [ 0.03,  0.89,  0.59],
 [ 0.63,  0.06,  0.25],
 [-0.56,  0.56,  0.76]]

K ∈ ℝ^(4×3):
[[ 0.99, -0.13,  0.51],
 [-0.54, -0.85,  0.13],
 [ 0.17, -0.34,  0.28],
 [ 0.42, -0.63, -0.28]]

V ∈ ℝ^(4×3):
[[-0.07,  0.10,  0.13],
 [ 0.89, -0.59,  0.14],
 [-0.29,  0.79,  0.78],
 [-0.13,  0.65,  0.68]]
Blocking: block size Br = Bc = 2. The first block is rows (t₁, t₂), the second block is rows (t₃, t₄). The demo processes both K/V blocks using Q's first block as the example.
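This toy example can also be checked end to end with the standard_attention and flash_attention_v1 sketches from earlier in this article (again with Br = Bc = 2):

```python
import numpy as np
# Assumes standard_attention and flash_attention_v1 from the earlier sketches are defined.

Q = np.array([[0.05, 0.11, 0.42], [0.03, 0.89, 0.59], [0.63, 0.06, 0.25], [-0.56, 0.56, 0.76]])
K = np.array([[0.99, -0.13, 0.51], [-0.54, -0.85, 0.13], [0.17, -0.34, 0.28], [0.42, -0.63, -0.28]])
V = np.array([[-0.07, 0.10, 0.13], [0.89, -0.59, 0.14], [-0.29, 0.79, 0.78], [-0.13, 0.65, 0.68]])

# Tiled result equals the exact attention output for the 4x3 toy example
print(np.allclose(flash_attention_v1(Q, K, V, Br=2, Bc=2), standard_attention(Q, K, V)))  # True
```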

Memory Reduction from $O(N^2)$ to $O(N)$: Derivation

Standard Attention Memory

The standard implementation needs to store the intermediate matrices $S$ and $P$:

$$\text{Memory} = \underbrace{Nd}_{Q} + \underbrace{Nd}_{K} + \underbrace{Nd}_{V} + \underbrace{N^2}_{S} + \underbrace{N^2}_{P} + \underbrace{Nd}_{O} = \Theta(Nd + N^2)$$

When $N \gg d$, the $O(N^2)$ term dominates.

Flash Attention Memory

Flash Attention only needs to store inputs, outputs, and auxiliary statistics:

$$\text{Memory} = \underbrace{Nd}_{Q} + \underbrace{Nd}_{K} + \underbrace{Nd}_{V} + \underbrace{Nd}_{O} + \underbrace{N}_{\ell} + \underbrace{N}_{m} = \Theta(Nd)$$

No $N^2$ terms at all! The local $B_r \times B_c$ score matrix exists only temporarily in SRAM, never in HBM.
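A back-of-the-envelope comparison in fp16 for a single attention head (the numbers follow directly from the two formulas above):

```python
N, d, bytes_per_elem = 4096, 64, 2

standard = (3 * N * d + 2 * N * N + N * d) * bytes_per_elem   # Q, K, V, S, P, O
flash    = (4 * N * d + 2 * N) * bytes_per_elem               # Q, K, V, O, l, m

print(f"standard: {standard / 2**20:.1f} MiB")  # ~66.0 MiB, dominated by S and P
print(f"flash:    {flash / 2**20:.1f} MiB")     # ~2.0 MiB
```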

Theorem 1 (Dao et al., 2022): The Flash Attention algorithm returns $O = \text{softmax}(QK^T)V$, uses $O(N^2 d)$ FLOPs, and requires only $O(N)$ additional memory.

IO Complexity Analysis: Why It’s Faster

Flash Attention not only saves memory but also saves time, because the bottleneck of Attention on GPUs is not computation but memory access.

Standard Attention IO Complexity

$$\text{HBM accesses} = \Theta(Nd + N^2)$$

Flash Attention IO Complexity

Theorem 2 (Dao et al., 2022): Let $N$ be the sequence length, $d$ the head dimension, and $M$ the SRAM size ($d \leq M \leq Nd$). Standard Attention requires $\Theta(Nd + N^2)$ HBM accesses; Flash Attention requires $\Theta(N^2 d^2 M^{-1})$.

Intuitive understanding:

  • The outer loop iterates over $T_c = N/B_c$ K/V blocks, each loading $\Theta(B_c d) = \Theta(M)$ data
  • The inner loop iterates over $T_r = N/B_r$ Q blocks, each loading and writing back $\Theta(B_r d)$ data
  • Total accesses: $T_c \times (M + T_r \times B_r d) = \Theta(Nd) + \frac{N}{B_c} \times \frac{N}{B_r} \times B_r d = \Theta(Nd) + \frac{N^2 d}{B_c}$, where the second term dominates
  • Since $B_c = \Theta(M/d)$, we get $\frac{N^2 d}{B_c} = \Theta(N^2 d^2 / M)$

For typical parameters ($d = 64\text{-}128$, $M \approx 100\,\text{KB}$), $d^2$ is smaller than $M$, so $N^2 d^2 / M \ll N^2$. In experiments, Flash Attention is 2-4x faster than the standard implementation.
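Plugging these numbers into the two asymptotic expressions (ignoring constant factors, and assuming M is counted in fp16 elements, so 100 KB ≈ 51200 elements):

```python
N, d, M = 4096, 64, 51200   # M in elements (~100 KB of fp16)

standard_io = N * d + N * N            # Theta(Nd + N^2)
flash_io = N * N * d * d // M          # Theta(N^2 d^2 / M)

print(standard_io / flash_io)          # ~12.7x fewer HBM accesses in this setting
```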

Figure: IO complexity comparison, HBM accesses (log scale) as a function of sequence length N (256 to 64K), for Standard Attention $\Theta(Nd + N^2)$ and Flash Attention $\Theta(N^2 d^2 / M)$ (v1 and v2 share the same asymptotic IO complexity). For long sequences, standard IO explodes while Flash Attention's access count grows with a much smaller constant.

Lower Bound

Proposition 3 (Dao et al., 2022): No exact Attention algorithm can achieve $o(N^2 d^2 M^{-1})$ HBM accesses for all $M \in [d, Nd]$.

This means Flash Attention is asymptotically optimal in terms of IO complexity.

Flash Attention v1 vs v2

In 2023, Tri Dao released Flash Attention v2, which further optimized GPU parallelism on top of v1.

| Comparison | Flash Attention v1 | Flash Attention v2 |
| --- | --- | --- |
| Outer loop | Iterates over K/V blocks | Iterates over Q blocks |
| Inner loop | Iterates over Q blocks | Iterates over K/V blocks |
| Inter-block parallelism | Parallel over heads and batch | Additional parallelism over the Q block dimension |
| Non-matmul FLOPs | More | Reduced; better Tensor Core utilization |
| Inter-warp communication | Via shared memory | Reduced inter-warp communication |
| A100 utilization | 25-40% of theoretical peak | 50-73% of theoretical peak |
| Relative speedup | Baseline | ~2x over v1 |

Key Improvements in v2

1. Swapped Loop Order

v1’s outer loop iterates over K/V blocks, inner loop over Q blocks. v2 reverses this: the outer loop iterates over Q blocks, the inner loop over K/V blocks. This way, each thread block is responsible for only one Q block’s output, reducing synchronization overhead, and allows parallelism across the Q block dimension by distributing to different thread blocks (streaming multiprocessors).

2. Reduced Non-matmul FLOPs

GPU Tensor Cores have extremely high throughput for matrix multiplication, but the rescaling, max, and sum operations in Online Softmax are non-matmul FLOPs. v2 reduces the proportion of these operations through algorithmic adjustments.

3. Better Work Partitioning Between Warps

v2 optimizes task partitioning between warps, reducing the number of synchronizations through shared memory and further improving parallel efficiency.

Summary

Flash Attention solves the memory and speed bottlenecks of standard Attention through three core techniques:

| Technique | Problem Solved | Effect |
| --- | --- | --- |
| Tiling | $N \times N$ matrix never stored in HBM | Memory $O(N^2) \to O(N)$ |
| Online Softmax | Correct normalization under tiled computation | Mathematically exact, zero approximation error |
| IO-aware design | Reduces HBM access count | 2-4x speed improvement |

Core formula quick reference:

$$m^{\text{new}} = \max(m^{\text{old}}, \tilde{m}), \quad \ell^{\text{new}} = e^{m^{\text{old}} - m^{\text{new}}} \ell^{\text{old}} + e^{\tilde{m} - m^{\text{new}}} \tilde{\ell}$$
$$O^{\text{new}} = \text{diag}(\ell^{\text{new}})^{-1}\!\left(e^{m^{\text{old}} - m^{\text{new}}} \text{diag}(\ell^{\text{old}}) O^{\text{old}} + e^{\tilde{m} - m^{\text{new}}} \tilde{P} V\right)$$

Flash Attention has become a standard component in modern large model training and inference. Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention can automatically dispatch to a Flash Attention backend when the inputs allow it. Understanding its tiling principles is an important foundation for deeply understanding LLM system optimization.
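For reference, a minimal PyTorch usage sketch; the explicit backend selection at the end is optional, and the torch.nn.attention API requires a recent PyTorch version (2.3+) and a supported GPU:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)  # (batch, heads, N, d)
k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)

# PyTorch picks an efficient backend (Flash Attention when supported)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Optionally restrict dispatch to the Flash Attention backend
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```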