Let’s do a small exercise. Try to memorize this sequence of random numbers: 3 9 14 56 89 1 20 43 81 23. Do you feel a mental strain somewhere around the 7th number? That’s the empirical limit of our capacity for processing information. We humans can reliably hold 7 ± 2 discrete items in working memory, the span of our immediate memory. This was quantified by George Miller in the Psychological Review in “The Magical Number Seven, Plus or Minus Two”1. But our brain doesn’t hit this limit because it’s poorly designed. It hits the limit because it’s optimally designed for the various constraints it is under. Traditional attention mechanisms haven't learned this lesson.
In my previous blog, I introduced attention as the brain's ability to selectively focus. But what if I told you that the way we've been computing attention is like forcing your brain to consciously track every possible connection between every thought you've ever had? Flash Attention (Dao et al.)2 changes this. It doesn't just make attention faster, it makes it more like our brain.
A Tale of Two Memories
Modern GPUs have two types of memory, and understanding the difference lets you appreciate Flash Attention's brilliance. Think of GPU memory like a city:
HBM (High Bandwidth Memory) is like a massive warehouse on the outskirts. It can store enormous amounts of data but it's far away. Every trip there takes time, and moving large shipments is expensive in terms of energy and bandwidth.
SRAM (On-chip Memory) is like a small office right in the city center. It's tiny but lightning fast. If you can fit your work here, everything happens instantly.
Traditional attention is like a company that stores everything in the distant warehouse and makes constant round trips. Flash Attention is like a smart company that brings just what it needs to the city center office and does all the work there.
The Quadratic Memory Explosion
Imagine if your brain tried to compute attention by creating a giant matrix showing how every thought relates to every other thought you've ever had. For just 2048 thoughts, that's 2048 × 2048 ≈ 4.2 million relationships to store and compute. Traditional attention does exactly this. For a sequence length of L tokens, it creates an L × L attention matrix. That's L² memory usage. This massive matrix has to be stored somewhere, usually in that distant warehouse memory, creating a bottleneck that slows down performance.
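To put numbers on the L² growth, here is a minimal sketch that sizes one attention matrix at the sequence lengths used later in this post. It assumes fp16 storage (2 bytes per score) and counts a single head of a single batch item; `attention_matrix_bytes` is a helper name made up for this illustration.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one seq_len x seq_len score matrix.

    Assumption: 2 bytes per element (fp16), one head, one batch item.
    """
    return seq_len * seq_len * bytes_per_element


for seq_len in (2048, 4096, 8192, 16384):
    entries = seq_len * seq_len
    mib = attention_matrix_bytes(seq_len) / 2**20
    print(f"L={seq_len:>6}: {entries:>12,} scores, {mib:>7.1f} MiB in fp16")
```

Doubling the sequence length quadruples the footprint: under these assumptions one head needs 8 MiB at 2048 tokens and 512 MiB at 16,384 tokens, before multiplying by the number of heads and batch items.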
Flash Attention Intuition
The core inefficiency in the traditional attention mechanism is the ordering of operations. Each input sequence requires computing a full attention matrix of shape L×L, storing it in memory, applying softmax, and then multiplying by the value matrix. Every step materializes intermediate states that consume bandwidth and slow down computation. FlashAttention reorders this process. Instead of generating and storing the full attention matrix, it computes the output directly by streaming small blocks of queries, keys, and values through fast on-chip memory. This is called tiling. It processes each tile one at a time, using just the pieces of Q, K, and V it needs. And instead of remembering every attention weight for the backward pass, it recomputes the attention scores for each tile when they are needed. By fusing the softmax and matrix multiplications into a single kernel, a technique known as kernel fusion, it eliminates the need to write intermediate results to HBM and read them back, removing redundant HBM reads and writes.
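For reference, the traditional ordering can be written out in a few lines. This is a plain-Python sketch for a single head, not the PyTorch baseline used in the benchmarks. `naive_attention` is a name chosen for this example, and Q, K, V are assumed to be lists of rows.

```python
import math


def naive_attention(Q, K, V):
    """Standard attention for one head: softmax(Q K^T / sqrt(d)) V.

    The full len(Q) x len(K) score matrix is built before anything else
    happens, which is exactly the intermediate that FlashAttention avoids.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Step 1: materialize every score (the L x L matrix).
    scores = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # Step 2: row-wise softmax, subtracting the row max for stability.
    weights = []
    for row in scores:
        row_max = max(row)
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    # Step 3: weighted sum of the value rows.
    d_v = len(V[0])
    return [
        [sum(w * v[c] for w, v in zip(w_row, V)) for c in range(d_v)]
        for w_row in weights
    ]
```

Each of the three steps reads and writes a structure with L × L entries. On a GPU those would be round trips to HBM.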
Online Softmax
Applying softmax to the attention scores traditionally requires all logits (attention scores before normalization) for each query to be available simultaneously. FlashAttention circumvents this by computing softmax incrementally, using an online softmax algorithm. This is similar to how we compute a running average: you don’t need to store every number, you just update your running total and divide at the end. This technique lets FlashAttention avoid storing the attention matrix entirely, while still giving the same result as a standard softmax, up to floating-point rounding.
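A minimal sketch of the idea for a single row of logits follows. It keeps only a running maximum and a running sum while visiting the logits block by block. The function name and the block size are illustrative choices, not taken from the actual kernel.

```python
import math


def online_softmax_stats(logits, block_size=4):
    """Return (running_max, running_sum) after one blockwise pass.

    Invariant: running_sum == sum(exp(x - running_max)) over all logits
    seen so far, so the softmax denominator is ready once the pass ends.
    """
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(logits), block_size):
        block = logits[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum to the new reference max, then add the block.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in block
        )
        running_max = new_max
    return running_max, running_sum


logits = [2.0, -1.0, 0.5, 7.0, 3.0, 3.0, -4.0]
m, s = online_softmax_stats(logits, block_size=3)
probs = [math.exp(x - m) / s for x in logits]  # same values as a one-shot softmax
```

Unlike a plain running average, the running sum has to be rescaled whenever a larger maximum shows up. That rescaling is what keeps the exponentials from overflowing.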
Flash Attention v1 vs v2
FlashAttention v1 introduced blockwise computation and fused kernel operations. However, it was still sequential along the query sequence: for each block of queries, it iterated through all blocks of keys and values.
FlashAttention v2 addressed this bottleneck by parallelizing across query tiles. Rather than handling query tiles one after another, v2 processes multiple query tiles simultaneously. Each parallel task keeps its own query tile and running softmax statistics in SRAM and streams the key-value tiles through it. FlashAttention v2 pairs this with a more careful division of work inside each thread block, which reduces synchronization and redundant memory traffic. As a result, it significantly increases throughput by making better use of GPU parallelism, reducing the overall runtime without sacrificing the efficiency gains of tiling and recomputation.
By breaking the quadratic memory wall, FlashAttention makes possible applications that were previously impossible or impractical. Instead of requiring enterprise-grade GPUs with hundreds of gigabytes of memory, sophisticated language models can run on more accessible hardware configurations. Longer context windows now become feasible, allowing transformers to process entire documents, complete codebases, or long conversations without hitting memory limits.
Implementation Deep Dive
I started by implementing the scaled dot product attention mechanism in PyTorch along with Multi Headed Attention for the purpose of benchmarking speed and memory consumption. I then implemented both versions of FlashAttention from scratch in Triton3, a GPU programming language designed by OpenAI for block-level abstraction and performance. Unlike CUDA, which requires manual management of individual threads, Triton operates on a higher-level abstraction of data blocks, allowing the expression of the algorithm in a way that maps cleanly to the GPU's architecture.
The core of the implementation is the Triton kernel, which is launched on a 2D or 3D grid of program instances. In my implementation, the grid is defined to parallelize across the batch, head, and sequence dimensions. For example, the grid
(triton.cdiv(seq_len, BLOCK_SIZE_M), batch_size * heads)
assigns a unique program_id to each instance. This id is used to calculate memory offsets, ensuring each kernel instance operates on a distinct tile of the Q matrix while iterating through all necessary tiles of K and V.
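The grid arithmetic can be checked without a GPU. The sketch below mirrors it in plain Python: `cdiv` reproduces the ceiling division that `triton.cdiv` performs, while `launch_grid` and `tile_for_program` are hypothetical helpers for this example. The batch-major decoding of the second program id is an assumption about the layout, not something confirmed by the kernel.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b


def launch_grid(seq_len, batch_size, heads, block_size_m):
    """Grid shape: (number of query tiles, number of batch-head pairs)."""
    return (cdiv(seq_len, block_size_m), batch_size * heads)


def tile_for_program(pid_m, pid_bh, heads, block_size_m, seq_len):
    """Decode two program ids into (batch index, head index, query rows).

    Assumption: batch-head pairs are numbered batch-major, so
    pid_bh = batch_idx * heads + head_idx.
    """
    batch_idx, head_idx = divmod(pid_bh, heads)
    row_start = pid_m * block_size_m
    row_end = min(row_start + block_size_m, seq_len)  # last tile may be partial
    return batch_idx, head_idx, range(row_start, row_end)


# Benchmark-like shape: 1 batch item, 8 heads, 1024 tokens, 64-row query tiles.
grid = launch_grid(seq_len=1024, batch_size=1, heads=8, block_size_m=64)
batch_idx, head_idx, rows = tile_for_program(15, 7, heads=8, block_size_m=64, seq_len=1024)
```

With these example values the grid holds 16 × 8 program instances, each owning one 64-row slice of Q for one head. The tile size of 64 is only a placeholder for BLOCK_SIZE_M.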
Inside the kernel, the implementation hinges on two main optimizations that are fused into the forward pass.
Memory Hierarchy Management: A block of the Q matrix is first loaded from the high-latency HBM into the fast SRAM. The kernel then enters a loop over the sequence dimension, loading blocks of K and V in each iteration. This tiling strategy ensures the inner loop that performs the dot product of Q and K blocks operates almost entirely out of SRAM, maximizing memory throughput.
Online Softmax: Storing the full N×N attention matrix is avoided by implementing an online softmax. This is not just a simple running average but a numerically stable update rule derived from the log-sum-exp trick. For each new block, we compute the local attention scores, find the new maximum logit, and use it to rescale the previous running sum and accumulator before adding the new block's contribution. This update, implemented element-wise, maintains numerical stability and produces an output that matches a standard softmax up to floating-point rounding, without having to materialize the intermediate attention matrix.
The final, normalized output block is written back to HBM only once the loop over K and V tiles is complete. The N×N score matrix never touches HBM, so the memory needed beyond the inputs and the output drops from O(N²) to O(N).
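The two optimizations above can be combined into a small CPU reference. It is a plain-Python sketch of the blockwise forward pass for one head, handling one query row at a time rather than a whole query tile. It is not the Triton kernel, and names such as `tiled_attention` and `block_size` are illustrative.

```python
import math


def tiled_attention(Q, K, V, block_size=2):
    """Blockwise attention forward pass with an online softmax (one head).

    Per query row we keep a running max `m`, a running normalizer `l`,
    and an unnormalized accumulator `acc`. K and V are visited one block
    at a time, so a full row of scores is never stored.
    """
    d, d_v = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        m, l, acc = float("-inf"), 0.0, [0.0] * d_v
        for start in range(0, len(K), block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            # Local scores for this block only.
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)  # rescales everything seen so far
            p = [math.exp(x - m_new) for x in s]
            l = l * alpha + sum(p)
            acc = [
                alpha * a + sum(p_j * v[c] for p_j, v in zip(p, v_blk))
                for c, a in enumerate(acc)
            ]
            m = m_new
        out.append([a / l for a in acc])  # normalize once, at the end
    return out


Q = [[1.0, 0.0], [0.5, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 2.0], [0.3, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
one_shot = tiled_attention(Q, K, V, block_size=len(K))  # one block: ordinary softmax
tiled = tiled_attention(Q, K, V, block_size=2)
max_diff = max(abs(a - b) for r1, r2 in zip(one_shot, tiled) for a, b in zip(r1, r2))
```

Setting `block_size` to the full key length reduces the loop to an ordinary softmax, so `max_diff` measures how far the tiled result drifts from it. Any difference comes from floating-point rounding, not from the algorithm.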
Benchmarking
To quantify the prowess of Flash Attention, I benchmarked its performance against the scaled dot product attention and multi-headed attention implementations as baselines. The goal was to measure forward pass performance and observe how each of the 4 implementations scales with increasing sequence length. We expect flash attention to pull ahead at longer sequence lengths. I ran the benchmark on a single GPU across a range of sequence lengths from 1,024 to 16,384, while keeping the batch size, head count, and head dimension constant.
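A timing loop for this kind of measurement might look like the sketch below. It is a hypothetical harness, not the one behind the numbers in this post: `benchmark_ms` and its parameters are made up for illustration. On a GPU the `sync` argument would be something like `torch.cuda.synchronize`, since kernels launch asynchronously and the clock must not stop before they finish.

```python
import time


def benchmark_ms(fn, *args, warmup=3, iters=10, sync=None):
    """Middle wall-clock time of fn(*args) in milliseconds.

    `sync` is an optional zero-argument callable invoked before each
    clock read (for GPU work, e.g. torch.cuda.synchronize).
    """
    for _ in range(warmup):  # warm caches and trigger any JIT compilation
        fn(*args)
    times = []
    for _ in range(iters):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        times.append((time.perf_counter() - start) * 1e3)
    times.sort()
    return times[len(times) // 2]  # upper middle when iters is even


# CPU-only usage example: time sorting a reversed list.
elapsed_ms = benchmark_ms(sorted, list(range(100_000, 0, -1)))
```

Warmup iterations matter for Triton kernels in particular, because the first call includes compilation time that would otherwise dominate the short-sequence measurements.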
Batch: 1, Seq: 1024, Heads: 8, Head Dim: 64
naive_attention: 0.88ms
multi_head : 1.07ms
flash_v1 : 0.37ms
flash_v2 : 0.40ms
Batch: 1, Seq: 4096, Heads: 8, Head Dim: 64
naive_attention: 13.71ms
multi_head : 10.81ms
flash_v1 : 1.70ms
flash_v2 : 1.97ms
Batch: 1, Seq: 8192, Heads: 8, Head Dim: 64
naive_attention: 32.78ms
multi_head : 34.90ms
flash_v1 : 8.50ms
flash_v2 : 6.71ms
Batch: 1, Seq: 16384, Heads: 8, Head Dim: 64
naive_attention: 147.47ms
multi_head : 157.54ms
flash_v1 : 33.98ms
flash_v2 : 23.35ms
Fastest Implementation: flash_v2
Speedup vs Naive Attention:
flash_v2 : 6.01x faster
flash_v1 : 4.37x faster
multi_head : 0.95x faster
The results largely validate our intuition. The quadratic scaling of the naive attention implementations is evident: as the sequence length doubles from 8,192 to 16,384, the runtime grows about 4.5x (32.78ms to 147.47ms), the signature of an O(N²) operation bottlenecked by memory I/O. It's making too many trips to that distant HBM warehouse.
Both flash attention implementations still perform O(N²) arithmetic, so their runtimes are not linear either: over the same doubling, flash_v1 grows about 4.0x (8.50ms to 33.98ms) and flash_v2 about 3.5x (6.71ms to 23.35ms). The difference is the constant factor. By avoiding the materialization of the full attention matrix and keeping the computation in fast SRAM, my implementations are no longer bound by memory bandwidth in the same way and are likely closer to compute-bound. Summed over all four sequence lengths, my flash_v2 implementation was 6x faster than the naive baseline.
At shorter sequence lengths, flash_v1 is slightly faster than flash_v2. This is likely due to the overhead of v2's more complex scheduling. However, once the sequence length grows past 4,096, flash_v2 pulls ahead. This is consistent with v2's improved parallelization strategy: by allowing different query blocks to be processed in parallel, it achieves better GPU utilization on longer sequences, overcoming the sequential bottleneck of v1.
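The summary figures can be rechecked directly from the timings listed above. The sketch below assumes the reported speedups are ratios of total runtime summed over the four sequence lengths. That reading is an assumption, but it reproduces the printed 6.01x, 4.37x, and 0.95x. The helper names are made up for this example.

```python
# Timings in milliseconds, copied from the benchmark output above.
timings_ms = {
    "naive_attention": {1024: 0.88, 4096: 13.71, 8192: 32.78, 16384: 147.47},
    "multi_head":      {1024: 1.07, 4096: 10.81, 8192: 34.90, 16384: 157.54},
    "flash_v1":        {1024: 0.37, 4096: 1.70, 8192: 8.50, 16384: 33.98},
    "flash_v2":        {1024: 0.40, 4096: 1.97, 8192: 6.71, 16384: 23.35},
}


def total_speedup(name, baseline="naive_attention"):
    """Baseline total runtime divided by this implementation's total runtime."""
    return sum(timings_ms[baseline].values()) / sum(timings_ms[name].values())


def doubling_ratio(name, short=8192, long=16384):
    """How much slower the run gets when the sequence length doubles."""
    return timings_ms[name][long] / timings_ms[name][short]


for name in timings_ms:
    print(
        f"{name:<16} speedup {total_speedup(name):.2f}x, "
        f"8192->16384 ratio {doubling_ratio(name):.2f}x"
    )
```

A doubling ratio near 4 indicates quadratic growth and a ratio near 2 would indicate linear growth, which makes it a quick way to read the scaling behaviour off a table of timings.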
Code
The full implementation is on GitHub at https://github.com/arnav0811/flash-attention
Conclusion
Flash Attention represents a fundamental shift in how we think about computation under constraints. Instead of treating memory limitations as obstacles to overcome with brute force, it shows us how to work with those constraints to achieve better results. The transition from standard attention to FlashAttention mirrors the lesson from "The Magical Number Seven". The limitations of our own minds aren't bugs, they are features that force us to be efficient. FlashAttention applies this same philosophy to silicon, teaching our models to selectively focus and process information in a way that is more like the human brain.