Computing Attention via Monoid Aggregates
The softmax operation is at the core of the attention mechanism within the transformer architecture, from its inception to the state-of-the-art LLMs of today. Given a vector $x \in \mathbb{R}^n$, the softmax is defined by

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}.$$
While this formula is deceptively simple, using it efficiently in modern LLMs is not. Implemented naively, softmax struggles with numerical stability. This is especially a problem for modern LLMs, which work with low-precision floating-point formats that leave little headroom. The issue arises when the sum of the exponentials in the denominator becomes huge, reaching the upper limit of what the floating-point format can represent.
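To see the failure concretely, here is a small NumPy sketch (the specific values are mine, chosen to force the overflow):

```python
import numpy as np

x = np.array([10.0, 1000.0, 10.0], dtype=np.float32)

# exp(1000) overflows float32 (whose maximum is about 3.4e38), so the
# denominator becomes inf and the result degenerates to zeros and nan.
naive = np.exp(x) / np.exp(x).sum()
print(naive)
```

The finite entries are divided by infinity and collapse to zero, while the overflowed entry becomes `inf / inf = nan`.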
The safe softmax algorithm fixes the numerical stability problems using the following trick. By applying some basic algebra, we see that we can compute the same value as the softmax when we add an arbitrary constant $c$ to every element:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{c}\, e^{x_i}}{e^{c} \sum_j e^{x_j}} = \frac{e^{x_i + c}}{\sum_j e^{x_j + c}}.$$
By choosing $c = -\max_j x_j$ and thus subtracting the maximum of the data from each data point, we ensure that $x_i + c \le 0$ and therefore $e^{x_i + c} \le 1$. This fixes the issue of the exploding sum in the denominator, as it is now bounded by the number of data points:

$$\sum_{j=1}^{n} e^{x_j - \max_k x_k} \le n.$$
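A minimal NumPy sketch of safe softmax following this trick (the function name is mine):

```python
import numpy as np

def safe_softmax(x):
    # Subtract the maximum first: every exponent is <= 0, so each term
    # is <= 1 and the denominator is bounded by len(x). No overflow.
    z = np.exp(x - np.max(x))
    return z / z.sum()

x = np.array([10.0, 1000.0, 10.0], dtype=np.float32)
print(safe_softmax(x))
```

The same input that broke the naive version now yields a valid probability distribution concentrated on the large entry.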
But now we have introduced a different problem: Naive softmax takes just one traversal of the data, to compute the sum of the exponentials. Safe softmax appears to require two traversals: one to compute the maximum, and one to sum the adjusted exponentials.
LLMs are limited by the speed at which data can be transferred within the memory hierarchy. GPUs have a big and slow pool of memory (global memory) together with a small but fast memory (shared memory) for every streaming multiprocessor (SM). To compute anything on a set of data, the GPU has to transfer the data from global memory into the shared memory of the SM that performs the computation. Because the shared memory is not large enough to fit all of the data required by attention at once, the data needs to be streamed in chunks. But that means that, if we implemented safe softmax with two traversals, we would have to stream the data from global to shared memory twice. To make matters worse, once we have computed the softmax, we use it to weight the values in the attention mechanism, thus incurring a third traversal of the data.
FlashAttention [1] is an implementation of softmax attention that employs a clever trick to avoid these additional traversals. This trick wasn't immediately obvious to me when I first read the paper, so in this post I will discuss an alternative presentation that made it click for me. The paper describes how to compute both the forward and backward passes; here we will concentrate on the forward pass.
To get started, let us first think about how we would go about calculating a sum

$$s = \sum_{i=1}^{n} x_i$$
for a vector $x \in \mathbb{R}^n$ whose dimension $n$ is sufficiently large that $x$ does not fit into shared memory all at once. As a first attempt, we can stream the elements of $x$ one by one while adding them to a running tally. That way, we only hold a small and constant amount of data in shared memory, regardless of how large the dimension gets. We therefore evaluate the sum like this:

$$s = (\cdots((x_1 + x_2) + x_3) + \cdots) + x_n$$
Streaming elements individually makes suboptimal use of the hardware. Memory is typically transferred several elements at a time, and shared memory will be large enough to fit multiple elements. Moreover, modern compute hardware has capabilities such as SIMD instructions and tensor cores that can perform an operation on multiple data points at once. For $n = 8$ and a chunk size of $4$ elements, we would evaluate the sum as follows:

$$s = (x_1 + x_2 + x_3 + x_4) + (x_5 + x_6 + x_7 + x_8)$$
The reason this works is that addition is associative and unital. Associativity guarantees that no matter where we put the parentheses, we compute the same result (terms and conditions apply regarding floating-point math). Therefore computing the sum in chunks produces the same outcome as summing the elements one by one. Unitality provides a unit, also called a neutral element. For addition this is $0$, since $x + 0 = x$ for all $x$. The unit allows us to pad our collections so that they are evenly divisible by the chunk size. This padding can be physical, in that we add additional elements to the collection in memory. More often, the padding is performed through a masking step that avoids reading past the bounds of a buffer in memory and instead produces the neutral element for those locations.
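A small Python sketch of a chunked sum with neutral-element padding (the function name and chunk size are mine, for illustration):

```python
import numpy as np

def chunked_sum(x, chunk=4):
    # Pad with the neutral element 0 so the length is evenly divisible
    # by the chunk size; the padding cannot change the result.
    pad = (-len(x)) % chunk
    x = np.concatenate([x, np.zeros(pad)])
    total = 0.0
    for i in range(0, len(x), chunk):
        total += x[i:i + chunk].sum()  # one SIMD-friendly partial sum
    return total

x = np.arange(1.0, 11.0)   # 10 elements, not divisible by 4
print(chunked_sum(x))      # 55.0
```

Associativity is what licenses reordering the additions into per-chunk partial sums; unitality is what licenses the zero padding.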
These tricks aren’t restricted to the addition of real numbers but work for any monoid: a monoid is a set together with a binary operation that is both associative and unital. Any computation that involves repeatedly applying the operation of a monoid to a collection of elements immediately inherits a bag of tricks:
- Split an aggregate into chunks.
- Compute multiple chunks in parallel or stream them one by one.
- Pad chunks to a regular size.
- Efficiently compute prefix sums.
- Incrementally update aggregates when inputs change.
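The first three tricks can be sketched generically in Python, assuming the monoid is given as a binary operation plus its unit (the function name is mine, not from a particular library):

```python
from functools import reduce

def fold_chunks(op, unit, xs, chunk=4):
    """Aggregate xs with any monoid (op, unit), chunk by chunk."""
    total = unit
    for i in range(0, len(xs), chunk):
        # Each chunk could be reduced in parallel or with SIMD;
        # associativity guarantees the same result either way, and
        # short final chunks behave as if padded with the unit.
        partial = reduce(op, xs[i:i + chunk], unit)
        total = op(total, partial)
    return total

# Works for any monoid, e.g. max with unit -inf, or + with unit 0.
print(fold_chunks(max, float("-inf"), [3, 1, 4, 1, 5, 9, 2, 6], chunk=3))  # 9
print(fold_chunks(lambda a, b: a + b, 0, list(range(10))))                 # 45
```

Swapping in a different `(op, unit)` pair changes the aggregate computed, not the traversal strategy.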
The monoid that is relevant to FlashAttention allows us to compute a sum of vectors with exponential weights in a single traversal while retaining numerical stability. It uses the same idea of subtracting the maximum exponent as we explored for safe softmax earlier. However, it avoids having to compute the maximum of the entire collection first via a clever trick that reweights the vectors when a new maximum element is discovered.
Definition: For any vector space $V$ we can equip the set

$$M(V) = \big(\mathbb{R} \cup \{-\infty\}\big) \times V$$

with a composition operation $\oplus$ defined by

$$(m_1, v_1) \oplus (m_2, v_2) = \big(m,\; e^{m_1 - m}\, v_1 + e^{m_2 - m}\, v_2\big),$$

where $m = \max(m_1, m_2)$, with the convention $e^{-\infty} = 0$. We write $\pi$ for the projection map $M(V) \to V$, $(m, v) \mapsto v$.
Lemma: Let $V$ be a vector space. Then $\oplus$ is an associative operation on $M(V)$ and $(-\infty, 0)$ is the neutral element for $\oplus$. In particular $\big(M(V), \oplus, (-\infty, 0)\big)$ is a monoid.
Theorem: Let $V$ be a vector space, $x \in \mathbb{R}^n$, and $u_1, \dots, u_n \in V$. Then

$$(x_1, u_1) \oplus \cdots \oplus (x_n, u_n) = \Big(\max_i x_i,\; \sum_{i=1}^{n} e^{x_i - \max_j x_j}\, u_i\Big).$$
Proof: By induction on $n$. When $n = 0$ both sides are $(-\infty, 0)$. For the induction step, suppose the claim holds for some $n$ and let $x \in \mathbb{R}^{n+1}$, $u_1, \dots, u_{n+1} \in V$. Writing $m_k = \max_{i \le k} x_i$ and using the induction hypothesis we can then calculate:

$$\begin{aligned}
(x_1, u_1) \oplus \cdots \oplus (x_{n+1}, u_{n+1})
&= \Big(m_n,\; \sum_{i=1}^{n} e^{x_i - m_n}\, u_i\Big) \oplus (x_{n+1}, u_{n+1}) \\
&= \Big(m_{n+1},\; e^{m_n - m_{n+1}} \sum_{i=1}^{n} e^{x_i - m_n}\, u_i + e^{x_{n+1} - m_{n+1}}\, u_{n+1}\Big) \\
&= \Big(m_{n+1},\; \sum_{i=1}^{n+1} e^{x_i - m_{n+1}}\, u_i\Big). \qquad \square
\end{aligned}$$
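The definition and the theorem can be checked numerically. Here is a small Python sketch of the monoid, scalar-valued for simplicity and with names of my choosing:

```python
import math
from functools import reduce

UNIT = (float("-inf"), 0.0)  # the neutral element (-inf, 0)

def combine(a, b):
    # (m1, v1) + (m2, v2) = (max(m1, m2), e^(m1-m) v1 + e^(m2-m) v2),
    # using the convention e^(-inf) = 0 for the weights.
    (m1, v1), (m2, v2) = a, b
    m = max(m1, m2)
    w = lambda mi: 0.0 if mi == float("-inf") else math.exp(mi - m)
    return (m, w(m1) * v1 + w(m2) * v2)

# Check the theorem on a small example: the aggregate equals
# (max_i x_i, sum_i e^(x_i - max) u_i).
x, u = [2.0, 5.0, 3.0], [1.0, 10.0, 100.0]
m_agg, v_agg = reduce(combine, zip(x, u), UNIT)
expected = sum(math.exp(xi - max(x)) * ui for xi, ui in zip(x, u))
print(m_agg == max(x) and math.isclose(v_agg, expected))  # True
```

Note that no exponent larger than $0$ is ever computed: each `combine` only exponentiates differences against the running maximum, which is exactly where the numerical stability comes from.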
We can now apply this directly to the softmax of some vector $x \in \mathbb{R}^n$. The numerator of the fraction defining the softmax is again a vector in $\mathbb{R}^n$, obtained as the sum of the one-hot vectors $e_i$ weighted by $e^{x_i}$. It can therefore be calculated via an aggregate in the monoid $M(\mathbb{R}^n)$:

$$\sum_{i=1}^{n} e^{x_i}\, e_i = e^{m}\, \pi\big((x_1, e_1) \oplus \cdots \oplus (x_n, e_n)\big), \qquad m = \max_i x_i.$$
The denominator is a scalar and is therefore computed via an aggregate in $M(\mathbb{R})$:

$$\sum_{i=1}^{n} e^{x_i} = e^{m}\, \pi\big((x_1, 1) \oplus \cdots \oplus (x_n, 1)\big).$$
Putting these together, we can cancel the factor $e^{m}$ from the fraction:

$$\mathrm{softmax}(x) = \frac{\pi\big((x_1, e_1) \oplus \cdots \oplus (x_n, e_n)\big)}{\pi\big((x_1, 1) \oplus \cdots \oplus (x_n, 1)\big)}.$$
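This single-traversal computation can be sketched in Python by fusing the two aggregates into one loop that carries the running maximum together with the rescaled numerator and denominator (a toy implementation with names of my choosing, not an optimized kernel):

```python
import numpy as np

def softmax_one_pass(x):
    # Fold the elements in one traversal. Discovering a new maximum
    # reweights both accumulators by exp(m - m_new) before adding the
    # new element's contribution exp(x_i - m_new).
    n = len(x)
    m, num, den = float("-inf"), np.zeros(n), 0.0
    for i, xi in enumerate(x):
        m_new = max(m, xi)
        scale = np.exp(m - m_new) if m > float("-inf") else 0.0
        num *= scale
        num[i] = np.exp(xi - m_new)   # the rescaled one-hot vector e_i
        den = den * scale + np.exp(xi - m_new)
        m = m_new
    return num / den

x = np.array([10.0, 1000.0, 10.0])
print(softmax_one_pass(x))
```

The result matches safe softmax on inputs that would overflow the naive formula, yet the maximum is never computed in a separate first pass.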
This trick extends to the fused attention kernel, in which the weights for the softmax are derived via scalar products between keys and queries, and where the outputs of the softmax are used as the coefficients for a convex combination of values. Suppose $Q$, $K$, and $V$ are the query, key, and value matrices. Then for any given position $i$, writing $q_i$, $k_j$, and $v_j$ for the rows of $Q$, $K$, and $V$, (unmasked) attention can be calculated via

$$\mathrm{attention}(Q, K, V)_i = \frac{\sum_{j} e^{q_i \cdot k_j}\, v_j}{\sum_{j} e^{q_i \cdot k_j}} = \frac{\pi\big(\bigoplus_j (q_i \cdot k_j,\; v_j)\big)}{\pi\big(\bigoplus_j (q_i \cdot k_j,\; 1)\big)}.$$
While this still looks like two traversals, since we have aggregates in both the $M(\mathbb{R}^d)$ and $M(\mathbb{R})$ monoids in the numerator and denominator, respectively, we can calculate this in one traversal by aggregating over the product monoid $M(\mathbb{R}^d) \times M(\mathbb{R})$. By loading the data in chunks, we have then effectively rederived the forward pass of FlashAttention.
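A toy single-pass version of this aggregate in NumPy, processing keys and values in chunks (the function name and chunk size are mine; unlike a production kernel, there is no $1/\sqrt{d}$ scaling, masking, or tiling over queries):

```python
import numpy as np

def attention_row(q, K, V, chunk=4):
    # One pass over keys/values in chunks, aggregating in the product
    # monoid M(R^d) x M(R): the running max is shared by the rescaled
    # numerator vector and denominator scalar.
    m, num, den = float("-inf"), np.zeros(V.shape[1]), 0.0
    for j in range(0, len(K), chunk):
        s = K[j:j + chunk] @ q                 # scores q . k_j for this chunk
        m_new = max(m, float(s.max()))
        scale = np.exp(m - m_new) if m > float("-inf") else 0.0
        w = np.exp(s - m_new)                  # weights rescaled to the new max
        num = num * scale + w @ V[j:j + chunk]
        den = den * scale + w.sum()
        m = m_new
    return num / den

# Hypothetical toy data, checked against a direct two-pass computation.
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=3), rng.normal(size=(10, 3)), rng.normal(size=(10, 3))
s = K @ q
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
print(np.allclose(attention_row(q, K, V), ref))  # True
```

Each chunk of keys and values is loaded once, combined with the running aggregate, and discarded, which is the essence of the single-pass streaming structure.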
Bibliography
- [1] T. Dao, D. Y. Fu, S. Ermon, A. Rudra, and C. Ré, “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,” arXiv:2205.14135, June 2022. doi: 10.48550/arXiv.2205.14135.