Flash Attention (Dao et al., 2022; FlashAttention-2, Dao, 2023) solves a key bottleneck: standard attention implementations are IO-bound, spending most of their time reading and writing the attention matrix to HBM (GPU global memory). Flash Attention tiles the computation so the working set stays in fast on-chip SRAM.
Why Standard Attention Is Slow
The naive attention algorithm materialises the full N×N attention matrix in HBM. For a 32K-token sequence in a 40-layer model, that is roughly 32,000² × 40 × 2 bytes ≈ 80 GB of HBM reads and writes per forward pass, per attention head, no matter how fast the GPU's compute is. The GPU spends most of its time waiting on memory transfers.
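A quick back-of-the-envelope check of that figure in Python (a minimal sketch; the sequence length, layer count, and fp16 element size are the values from the example above):

```python
# Back-of-the-envelope HBM traffic for naively materialised attention scores.
# Values mirror the example above: 32K tokens, 40 layers, fp16 (2-byte) scores.
seq_len = 32_000          # tokens in the context
n_layers = 40             # transformer layers
bytes_per_elem = 2        # fp16 / bf16

# One N x N score matrix per layer, per attention head.
score_matrix_bytes = seq_len * seq_len * bytes_per_elem
total_bytes = score_matrix_bytes * n_layers

print(f"per layer, per head: {score_matrix_bytes / 1e9:.1f} GB")        # ~2.0 GB
print(f"per forward pass, per head: {total_bytes / 1e9:.1f} GB")        # ~82 GB
```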
What Flash Attention Does
It reorders the attention computation into small blocks that fit in the GPU's on-chip SRAM (shared memory), fuses the softmax normalisation across blocks using an online algorithm, and never writes the full N×N matrix to HBM. The result is O(N) memory instead of O(N²): the same mathematical output, computed faster and with dramatically less VRAM.
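The block-wise online-softmax trick can be sketched in plain PyTorch. This is an unfused reference for a single head, not the actual CUDA kernel (which fuses these steps and keeps each tile in SRAM); the block size and shapes are illustrative:

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Single-head attention computed one key/value tile at a time.

    q, k, v: (seq_len, head_dim) tensors. Only (seq_len, block_size) score
    tiles ever exist; the full (seq_len, seq_len) matrix is never built.
    """
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.zeros_like(q)
    # Running per-row softmax statistics, updated as each tile arrives.
    row_max = torch.full((seq_len, 1), float("-inf"), dtype=q.dtype, device=q.device)
    row_sum = torch.zeros(seq_len, 1, dtype=q.dtype, device=q.device)
    for start in range(0, seq_len, block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale                      # (seq_len, block)
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        correction = torch.exp(row_max - new_max)           # rescale old stats
        p = torch.exp(scores - new_max)                     # unnormalised probs
        out = out * correction + p @ v_blk
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max
    return out / row_sum
```

Up to floating-point error, the result equals torch.softmax(q @ k.T * scale, dim=-1) @ v; the difference is that peak memory scales with block_size rather than with the sequence length squared.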
Flash Attention 2 Improvements
- Better parallelism across attention heads (fewer sync barriers)
- Full GQA and MQA support (see the sketch after this list)
- ~2× throughput vs FA1 on A100
- Supports causal masking, ALiBi, RoPE without overhead
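With the flash-attn 2.x Python package, for instance, GQA is expressed simply by passing K and V with fewer heads than Q (a hedged sketch; the head counts, shapes, and dimensions are illustrative, and the available keyword arguments depend on your flash-attn version):

```python
import torch
from flash_attn import flash_attn_func  # pip install flash-attn (FA2)

batch, seqlen, head_dim = 1, 4096, 128
n_q_heads, n_kv_heads = 32, 8            # GQA: 4 query heads share each KV head

# flash-attn expects (batch, seqlen, nheads, headdim), fp16/bf16, on GPU.
q = torch.randn(batch, seqlen, n_q_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, n_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, n_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")

# KV heads are shared across query-head groups inside the kernel;
# the full attention matrix is never materialised in HBM.
out = flash_attn_func(q, k, v, causal=True)   # (batch, seqlen, n_q_heads, head_dim)
```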
Flash Attention 3 (2024)
Flash Attention 3 targets the H100's Hopper Tensor Cores. It overlaps computation with data transfer (an asynchronous pipeline), reaching roughly 75% utilisation of the H100's peak FP16 throughput. Available in vLLM ≥0.4 and SGLang.
Why It Matters for On-Premise
Flash Attention is enabled by default, or a single flag away, in most modern inference stacks (vLLM, Ollama, llama.cpp). For PyTorch-based serving, though, you must make sure your environment actually includes it: pip install flash-attn compiles the kernels and needs a compatible GCC and the CUDA SDK. Without it, a 32K context can OOM a GPU that would otherwise handle it fine.
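A quick sanity check of whether your stack can actually use it (a sketch: the attn_implementation flag is the Hugging Face Transformers spelling, and the model id is purely illustrative):

```python
import torch

# 1. Does this PyTorch build ship the fused flash-attention SDPA backend?
print("PyTorch flash SDPA available:", torch.backends.cuda.flash_sdp_enabled())

# 2. Is the standalone flash-attn package importable (needed by many HF models)?
try:
    import flash_attn
    print("flash-attn version:", flash_attn.__version__)
except ImportError:
    print("flash-attn not installed; pip install flash-attn (needs GCC + CUDA SDK)")

# 3. Ask Transformers for it explicitly rather than silently falling back.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",        # illustrative model id
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",   # raises if flash-attn is unusable
)
```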