Computing Attention via Monoid Aggregates


The softmax operation is at the core of the attention mechanism of the transformer architecture, from its inception to the state-of-the-art LLMs of today. Given a vector $x \in \mathbb{R}^d$, the softmax is defined by

$$\operatorname{softmax}(x)_i := \frac{e^{x_i}}{\sum_{j=1}^d e^{x_j}}.$$

While this formula is deceptively simple, using it efficiently in modern LLMs is not. Implemented naively, softmax suffers from numerical instability. This is especially a problem for modern LLMs, which work with low-precision floating-point formats that leave little headroom. The issue arises when the sum of the exponentials in the denominator becomes huge, reaching the upper limit of what the floating-point format can represent.
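To see the instability concretely, here is a small NumPy sketch (the `naive_softmax` helper is my own illustration, not code from any library):

```python
import numpy as np

def naive_softmax(x):
    """Textbook softmax: exponentiate, then normalize."""
    e = np.exp(x)          # overflows to inf for large inputs
    return e / e.sum()     # inf / inf produces nan

x = np.array([800.0, 801.0, 802.0], dtype=np.float32)
print(naive_softmax(x))    # all nan: e^800 exceeds the float range
```

The exponentials overflow to infinity, and the division then yields `nan` for every output element.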

The safe softmax algorithm fixes the numerical stability problem using the following trick. By applying some basic algebra, we see that we compute the same value as the softmax when we add an arbitrary constant $c$ to every element:

$$\frac{e^{x_i + c}}{\sum_{j=1}^d e^{x_j + c}} = \frac{e^{x_i} e^c}{\sum_{j=1}^d e^{x_j} e^c} = \frac{e^c}{e^c} \cdot \frac{e^{x_i}}{\sum_{j=1}^d e^{x_j}} = \frac{e^{x_i}}{\sum_{j=1}^d e^{x_j}}.$$

By choosing $c := -\max(x)$, i.e. subtracting the maximum of the data from each data point, we ensure that $x_i + c \le 0$ and therefore $0 < e^{x_i + c} \le 1$. This fixes the issue of the exploding sum in the denominator, as it is now bounded by the number of data points:

$$0 < \sum_{j=1}^d e^{x_j + c} \le d.$$
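In NumPy, safe softmax is a two-line change (the `safe_softmax` name is my own, but the shift-by-the-maximum trick is standard):

```python
import numpy as np

def safe_softmax(x):
    """Numerically stable softmax: shift by the maximum first."""
    z = x - x.max()        # every entry is <= 0 ...
    e = np.exp(z)          # ... so every exponential is in (0, 1]
    return e / e.sum()     # denominator is bounded by len(x)

x = np.array([800.0, 801.0, 802.0], dtype=np.float32)
print(safe_softmax(x))     # finite probabilities summing to 1
```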

But now we have introduced a different problem: naive softmax takes just one traversal of the data to compute the sum of the exponentials. Safe softmax appears to require two traversals: one to compute the maximum, and one to sum the adjusted exponentials.

LLMs are limited by the speed at which data can be transferred within the memory hierarchy. GPUs have a big but slow pool of memory, global memory, together with a small but fast memory, shared memory, for every streaming multiprocessor (SM). To compute anything on a set of data, the GPU has to transfer the data from global memory into the shared memory of the SM that performs the computation. Because shared memory is not large enough to fit all of the data required by attention at once, the data needs to be streamed in chunks. But that means that, if we implemented safe softmax with two traversals, we would have to stream the data from global to shared memory twice. To make matters worse, once we have computed the softmax, we use it to weight the values in the attention mechanism, incurring a third traversal of the data.

FlashAttention 1 is an implementation of softmax attention that employs a clever trick to avoid these additional traversals. The trick wasn't immediately obvious to me when I first read the paper, so in this post I will discuss an alternative presentation that made it click for me. The paper describes how to compute both the forward and backward passes; here we will concentrate on the forward pass.

To get started, let us first think about how we would go about calculating a sum

$$\sum_{i=1}^d x_i$$

for a vector $x \in \mathbb{R}^d$ whose dimension $d$ is sufficiently large that $x$ does not fit into shared memory all at once. As a first attempt, we can stream the elements of $x$ one by one while adding them to a running tally. That way, we only hold a small and constant amount of data in shared memory, regardless of how large the dimension $d$ gets. If $d = 10$ we therefore evaluate the sum like this:

$$\sum_{i=1}^{10} x_i = ((((((((x_1 + x_2) + x_3) + x_4) + x_5) + x_6) + x_7) + x_8) + x_9) + x_{10}.$$

Streaming elements individually makes suboptimal use of the hardware. Memory is typically transferred in chunks, and shared memory is large enough to fit multiple elements. Moreover, modern compute hardware has capabilities such as SIMD instructions and tensor cores that can perform an operation on multiple data points at once. For $d = 10$ and a chunk size of $4$ elements, we would evaluate the sum as follows:

$$\sum_{i=1}^{10} x_i = \bigl((x_1 + x_2 + x_3 + x_4) + (x_5 + x_6 + x_7 + x_8)\bigr) + (x_9 + x_{10} + 0 + 0).$$

The reason this works is that addition is associative and unital. Associativity guarantees that no matter where we put the parentheses, we compute the same result (terms and conditions apply regarding floating-point math). Therefore computing the sum in chunks produces the same outcome as summing the elements one by one. Unitality provides a unit, also called a neutral element; for addition this is $0$, since $x + 0 = x = 0 + x$ for all $x \in \mathbb{R}$. The unit allows us to pad our collections so that they are evenly divisible by the chunk size. This padding can be physical, in that we add additional elements to the collection in memory. More often, the padding is performed through a masking step that avoids reading past the bounds of a buffer in memory and instead produces the neutral element for those locations.
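The chunked evaluation with zero-padding can be sketched in NumPy as follows (real kernels would mask loads instead of physically padding):

```python
import numpy as np

x = np.arange(1.0, 11.0)     # x_1, ..., x_10
chunk = 4

# Pad with the additive unit 0 so the length is divisible by the chunk size.
padded = np.pad(x, (0, -len(x) % chunk))
# One partial sum per chunk; the chunks could run in parallel.
partial = padded.reshape(-1, chunk).sum(axis=1)

assert partial.sum() == x.sum() == 55.0
```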

These tricks aren’t restricted to addition of real numbers but work for any monoid. A monoid is a set $M$ together with a binary operation $\cdot : M \times M \to M$ that is both associative and unital. Any computation that involves repeatedly applying the operation of a monoid to a collection of elements immediately inherits a bag of tricks:

  • Split an aggregate into chunks.
  • Compute multiple chunks in parallel or stream them one by one.
  • Pad chunks to a regular size.
  • Efficiently compute prefix sums.
  • Incrementally update aggregates when inputs change.
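As a sketch of this genericity, here is a chunked reduction parameterized over an arbitrary monoid `(op, unit)` (the helper name `chunked_reduce` is my own):

```python
import math
from functools import reduce

def chunked_reduce(op, unit, xs, chunk=4):
    """Reduce xs with any monoid (op, unit), one chunk at a time.

    Associativity lets us reduce each chunk independently; unitality
    lets us pad the final chunk to the full chunk size."""
    acc = unit
    for i in range(0, len(xs), chunk):
        block = list(xs[i:i + chunk])
        block += [unit] * (chunk - len(block))  # pad with the unit
        acc = op(acc, reduce(op, block))
    return acc

xs = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0]
assert chunked_reduce(lambda a, b: a + b, 0.0, xs) == sum(xs)
assert chunked_reduce(max, -math.inf, xs) == max(xs)
```

The same function computes sums, maxima, or any other monoid aggregate, which is exactly the structure the FlashAttention monoid will plug into below.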

The monoid that is relevant to FlashAttention allows us to compute a sum of vectors with exponential weights in a single traversal while retaining numerical stability. It uses the same idea of subtracting the maximum exponent that we explored for safe softmax earlier. However, it avoids having to compute the maximum of the entire collection first via a clever trick that reweights the vectors whenever a new maximum element is discovered.

Definition

For any vector space $V$ we can equip the set

$$\operatorname{SE}(V) := V \times \bigl(\mathbb{R} \cup \{-\infty\}\bigr)$$

with a composition operation $\cdot$ defined by

$$(v_1, m_1) \cdot (v_2, m_2) = \bigl(e^{m_1 - m} v_1 + e^{m_2 - m} v_2,\; m\bigr)$$

where $m = \max(m_1, m_2)$ and we use the convention $e^{-\infty} := 0$. We write

$$\pi : \operatorname{SE}(V) \to V$$

for the projection map $\pi(v, m) = v$.
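In code, the composition could be sketched as follows for $V = \mathbb{R}$, with `-inf` standing in for $-\infty$ so that `exp(-inf) == 0.0` implements the convention $e^{-\infty} = 0$ (the names `se_op` and `UNIT` are my own):

```python
import math

def se_op(a, b):
    """Composition in SE(R): rescale both parts to the shared maximum."""
    v1, m1 = a
    v2, m2 = b
    m = max(m1, m2)
    if m == -math.inf:               # both operands are the unit (0, -inf)
        return (0.0, -math.inf)
    return (math.exp(m1 - m) * v1 + math.exp(m2 - m) * v2, m)

UNIT = (0.0, -math.inf)

# (0, -inf) is neutral, thanks to exp(-inf) == 0.0:
assert se_op(UNIT, (3.0, 1.0)) == (3.0, 1.0)
assert se_op((3.0, 1.0), UNIT) == (3.0, 1.0)
```

The guard for `m == -inf` avoids the indeterminate expression `-inf - (-inf)` when both operands are the unit.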

Lemma

Let $V$ be a vector space. Then $\cdot$ is an associative operation on $\operatorname{SE}(V)$ and $(0, -\infty)$ is the neutral element for $\cdot$. In particular, $\bigl(\operatorname{SE}(V), \cdot, (0, -\infty)\bigr)$ is a monoid.

Theorem

Let $V$ be a vector space, $v_1, \dots, v_d \in V$, and $x_1, \dots, x_d \in \mathbb{R}$. Then

$$\bigodot_{i=1}^d (v_i, x_i) = \left( e^{-\max(x)} \sum_{i=1}^d e^{x_i} v_i,\; \max(x) \right),$$

where $\max(x) := \max(x_1, \dots, x_d)$.

Proof

By induction on $d \ge 0$. When $d = 0$ both sides are $(0, -\infty)$. For the induction step, suppose the claim holds for some $d \ge 0$ and let $v_{d+1} \in V$, $x_{d+1} \in \mathbb{R}$. Using the induction hypothesis we can then calculate:

$$\begin{aligned} \bigodot_{i=1}^{d+1} (v_i, x_i) &= \left(\bigodot_{i=1}^{d} (v_i, x_i)\right) \cdot (v_{d+1}, x_{d+1}) \\ &= \left(e^{-\max(x_{1:d})} \sum_{i=1}^{d} e^{x_i} v_i,\; \max(x_{1:d})\right) \cdot (v_{d+1}, x_{d+1}) \\ &= \left( e^{\max(x_{1:d}) - \max(x) - \max(x_{1:d})} \sum_{i=1}^{d} e^{x_i} v_i + e^{x_{d+1} - \max(x)} v_{d+1}, \; \max(x) \right) \\ &= \left( e^{-\max(x)} \sum_{i=1}^{d+1} e^{x_i} v_i, \; \max(x) \right). \end{aligned}$$

We can now apply this directly to the softmax of some vector $x \in \mathbb{R}^d$. The numerator of the fraction defining the softmax is again a vector in $\mathbb{R}^d$, obtained as the sum of the one-hot vectors $u_i \in \mathbb{R}^d$ weighted by $e^{x_i}$. It can therefore be calculated via an aggregate in the monoid $\operatorname{SE}(\mathbb{R}^d)$:

$$\sum_{i=1}^d e^{x_i} u_i = e^{\max(x)} \; \pi\left(\bigodot_{i=1}^d (u_i, x_i)\right).$$

The denominator is a scalar and is therefore computed via $\operatorname{SE}(\mathbb{R})$:

$$\sum_{i=1}^d e^{x_i} = e^{\max(x)} \; \pi\left(\bigodot_{i=1}^d (1, x_i)\right).$$

Putting these together, we can cancel the $e^{\max(x)}$ factor from the fraction:

$$\operatorname{softmax}(x) = \frac{\sum_{i=1}^d e^{x_i} u_i}{\sum_{i=1}^d e^{x_i}} = \frac{ e^{\max(x)} \; \pi\left(\bigodot_{i=1}^d (u_i, x_i)\right) }{ e^{\max(x)} \; \pi\left(\bigodot_{i=1}^d (1, x_i)\right) } = \frac{ \pi\left(\bigodot_{i=1}^d (u_i, x_i)\right) }{ \pi\left(\bigodot_{i=1}^d (1, x_i)\right) }.$$
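This identity can be checked numerically. Below is a sketch (helper names are my own) that folds over $\operatorname{SE}(\mathbb{R}^d)$ and $\operatorname{SE}(\mathbb{R})$ and agrees with safe softmax, even for inputs that overflow the naive formula:

```python
import numpy as np

def se_op(a, b):
    """Composition in SE(V), broadcasting over the vector part."""
    v1, m1 = a
    v2, m2 = b
    m = max(m1, m2)
    return (np.exp(m1 - m) * v1 + np.exp(m2 - m) * v2, m)

def softmax_via_monoid(x):
    d = len(x)
    num = (np.zeros(d), -np.inf)           # aggregate in SE(R^d)
    den = (0.0, -np.inf)                   # aggregate in SE(R)
    for i, xi in enumerate(x):
        u_i = np.eye(d)[i]                 # one-hot vector u_i
        num = se_op(num, (u_i, xi))
        den = se_op(den, (1.0, xi))
    return num[0] / den[0]                 # pi(numerator) / pi(denominator)

x = np.array([800.0, 801.0, 802.0])
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(softmax_via_monoid(x), ref)
```

Note that the maximum is never computed up front: the fold rescales the running sum whenever a larger exponent arrives.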

This trick extends to the fused attention kernel, in which the weights for the softmax are derived via scalar products between keys and queries, and where the outputs of the softmax are used as the coefficients for a convex combination of values. Suppose $Q \in \mathbb{R}^{L \times d_k}$, $K \in \mathbb{R}^{L \times d_k}$, and $V \in \mathbb{R}^{L \times d_v}$ are the query, key, and value matrices. Then for any given position $1 \le q \le L$, unmasked attention can be calculated via

$$\operatorname{attention}(Q, K, V)_{(q, :)} = \frac{ \pi\left( \bigodot_{i=1}^{L} \left( V_{(i, :)}, \frac{Q_{(q, :)}^\top K_{(i, :)}}{\sqrt{d_k}} \right) \right) }{ \pi\left( \bigodot_{i=1}^{L} \left( 1, \frac{Q_{(q, :)}^\top K_{(i, :)}}{\sqrt{d_k}} \right) \right) }.$$

While this still looks like two traversals, since we have aggregates in both the $\operatorname{SE}(\mathbb{R}^{d_v})$ and $\operatorname{SE}(\mathbb{R})$ monoids in the numerator and denominator, respectively, we can calculate this in one traversal by aggregating over the product monoid $\operatorname{SE}(\mathbb{R}^{d_v}) \times \operatorname{SE}(\mathbb{R})$. By loading the data in chunks, we have then effectively rederived the forward pass of FlashAttention.
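Putting everything together, here is a sketch of one output row of attention computed in a single pass over the keys and values (element by element for clarity; FlashAttention processes chunks of keys and values with this same monoid, and the helper names are my own):

```python
import numpy as np

def se_op(a, b):
    """Composition in SE(V): rescale both summands to the shared maximum."""
    v1, m1 = a
    v2, m2 = b
    m = max(m1, m2)
    return (np.exp(m1 - m) * v1 + np.exp(m2 - m) * v2, m)

def attention_row(q, K, V):
    """One output row via a single pass over K and V, aggregating in the
    product monoid SE(R^dv) x SE(R)."""
    d_k = K.shape[1]
    num = (np.zeros(V.shape[1]), -np.inf)   # SE(R^dv) component
    den = (0.0, -np.inf)                    # SE(R) component
    for i in range(K.shape[0]):
        s = float(q @ K[i]) / np.sqrt(d_k)  # score for key i
        num = se_op(num, (V[i], s))
        den = se_op(den, (1.0, s))
    return num[0] / den[0]

# Compare against a reference that materializes the full softmax.
rng = np.random.default_rng(0)
L, d_k, d_v = 8, 4, 5
Q = rng.normal(size=(L, d_k))
K = rng.normal(size=(L, d_k))
V = rng.normal(size=(L, d_v))

scores = Q @ K.T / np.sqrt(d_k)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
reference = weights @ V

assert all(np.allclose(attention_row(Q[q], K, V), reference[q]) for q in range(L))
```

The single-pass version never stores the score vector or the softmax weights, which is precisely what lets FlashAttention avoid the extra round trips through global memory.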

Footnotes

  1. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré