FlashAttention Kernel: Forward Pass (MATH)

Categories: FlashAttention, Transformers, Attention, Compute, Autograd, MATH

Author: Shivam Pandey

Published: March 30, 2025

Attention Operation

For this exercise we will simplify the target index sets to match the most common setup, i.e. \(Q_{s_2 s_1} \in \mathbb{R}^{[M \times d]}\), \(K_{s_3 s_1} \in \mathbb{R}^{[N \times d]}\), \(V_{s_3 s_1} \in \mathbb{R}^{[N \times d]}\).

Note: In this doc, whenever we need to denote exact dimensions instead of an index set, we will write \(.^{[... \times ... \times \dots]}\), where the \(\times\) symbol separates the different index sets.

Thus our operation becomes:

\[\begin{align} B_{s_2 s_1} = attention_{s_1}(Q_{s_2 s_1}, K_{s_3 s_1}, V_{s_3 s_1}) \\ B^{[M \times d]} = attention_{d}(Q^{[M \times d]}, K^{[N \times d]}, V^{[N \times d]}) \end{align}\]


\[\begin{align} S^{[M \times N]} = Q @ K^T = Q *_{(M \times d, N \times d, M \times N)} K \in \mathbb{R}^{[M \times N]} \\ S_{rmax}^{M} = \max_{N} S^{[M \times N]} \in \mathbb{R}^{M} \\ S_{rm}^{[m, n]} = S^{[m, n]} - S_{rmax}^{m} \quad \forall [m, n] \in [M \times N] \\ P^{[M \times N]} = softmax_{N}(S_{rm}^{[M \times N]}) \\ O^{[M \times d]} = P @ V = P *_{(M \times N, N \times d, M \times d)} V \in \mathbb{R}^{[M \times d]} \end{align}\]

Note: Slight notation abuse: \(O\) and \(B\) are used interchangeably below.
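Below is a minimal NumPy sketch of this reference (non-chunked) forward pass. The function name `attention_reference` is made up for illustration, and the usual \(1/\sqrt{d}\) scaling and masking are left out to stay close to the equations above.

```python
import numpy as np

def attention_reference(Q, K, V):
    """Reference attention forward pass, a direct transcription of the
    equations above. Shapes: Q [M, d], K [N, d], V [N, d] -> O [M, d]."""
    S = Q @ K.T                              # scores S [M, N]
    S_rmax = S.max(axis=1, keepdims=True)    # row-wise max S_rmax [M, 1]
    P = np.exp(S - S_rmax)                   # stabilized exponentials
    P = P / P.sum(axis=1, keepdims=True)     # softmax over the N dimension
    return P @ V                             # output O [M, d]
```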

Simplifying Forward Pass

\[\begin{align} P^{[m, n]} = \frac{\exp(S^{[m, n]} - \max_{n}(S^{[m, n]}))}{\sum_n \exp(S^{[m, n]} - \max_{n}(S^{[m, n]}))} = \frac{\exp(S^{[m, n]})}{\sum_n \exp(S^{[m, n]})} \end{align}\]

Note: Because the softmax rows are independent over the \(M\) dimension, the only aggregation required is over the \(N\) dimension.
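A quick numeric sanity check of this independence (a sketch, with a made-up `softmax` helper): applying the softmax one row at a time gives the same \(P\) as applying it to the full \(S\) at once.

```python
import numpy as np

def softmax(x, axis=-1):
    # hypothetical helper: numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
S = rng.normal(size=(4, 6))                                    # small [M, N] score matrix

P_full = softmax(S, axis=1)                                    # all rows at once
P_rows = np.stack([softmax(S[m]) for m in range(S.shape[0])])  # one row at a time
assert np.allclose(P_full, P_rows)                             # rows never interact
```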

Computation in chunks

Here we will first look at what is required to generate the output for a single query, i.e. \(O^{[m, d]}\).

\[\begin{align} O^{[m, d]} = P_m *_{(N, N \times d, d)} V = \sum_n P^{[m, n]} \cdot V^{[n, d]} \\ P^{[m, n]} = softmax_{n}(S^{[m, n]}) = \frac{\exp(S^{[m, n]})}{\sum_n \exp(S^{[m, n]})} \\ O^{[m, d]} = \sum_n P^{[m, n]} \cdot V^{[n, d]} = \sum_n \frac{\exp(S^{[m, n]})}{\sum_n \exp(S^{[m, n]})} \cdot V^{[n, d]} \\ = \frac{1}{\sum_n \exp(S^{[m, n]})}\sum_n \exp(S^{[m, n]}) \cdot V^{[n, d]} \end{align}\]
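The same expression for a single query row, written as an explicit sum over \(n\) (a sketch; it mirrors the final line above with no max subtraction, so it assumes modest score values):

```python
import numpy as np

def single_query_output(q, K, V):
    """O[m, :] for a single query row q [d], via the explicit sum over n."""
    s = K @ q                          # S[m, :], shape [N]
    w = np.exp(s)                      # exp(S[m, n]); fine while scores stay small
    denom = w.sum()                    # sum_n exp(S[m, n])
    # (1 / denom) * sum_n exp(S[m, n]) * V[n, :]
    return (w[:, None] * V).sum(axis=0) / denom
```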

We want to process \(O^{[m, d]} = \sum_N \dots\) over \(n\) sequentially, to avoid loading the whole sequence at once.

For the sequence processed up to \(n = j\) we can write: \[\begin{align} O^{[m, d]}_j = \frac{1}{\sum_{n\in[0 \dots j]} \exp(S^{[m, n]})}\sum_{n\in[0 \dots j]} \exp(S^{[m, n]}) \cdot V^{[n, d]} = \frac{1}{l_j} u_j \end{align}\]

Let’s say we proceed by a single step to \(n = j+1\): \[\begin{align} O^{[m, d]}_{j+1} = \frac{1}{\sum_{n\in[0 \dots j, j+1]} \exp(S^{[m, n]})}\sum_{n\in[0 \dots j, j+1]} \exp(S^{[m, n]}) \cdot V^{[n, d]}\\ = \frac{1}{l_j + \exp(S^{[m, n=j+1]})} (u_j + \exp(S^{[m, n=j+1]}) \cdot V^{[n=j+1, d]}) \\ = \frac{O^{[m, d]}_{j} * l_j}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \end{align}\]

Thus we can compute the output simply by iterating over the \(N\) dimension. The final recurrence for \(O^{[m, d]}\) is:

\[\begin{align} O^{[m, d]}_{0} = \frac{\exp(S^{[m, n=0]}) \cdot V^{[n=0, d]}}{l_{0}}\\ O^{[m, d]}_{j+1} = \frac{O^{[m, d]}_{j} * l_j}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \\ l_0 = \exp(S^{[m, n=0]}) \\ l_{j+1} = l_j + \exp(S^{[m, n=j+1]}) \end{align}\]
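A sketch of this recurrence for a single query, streaming over \(n\) one key/value row at a time (the helper name is made up; there is no max stabilization yet, so this is only safe for small scores):

```python
import numpy as np

def single_query_streaming(q, K, V):
    """Online computation of O[m, :] without max stabilization:
    keeps only a running denominator l and a running normalized output O."""
    N, _ = V.shape
    l = np.exp(K[0] @ q)              # l_0 = exp(S[m, 0])
    O = np.exp(K[0] @ q) * V[0] / l   # O_0
    for j1 in range(1, N):            # j1 plays the role of j + 1
        e = np.exp(K[j1] @ q)         # exp(S[m, n = j+1])
        l_new = l + e                 # l_{j+1} = l_j + exp(...)
        O = O * l / l_new + e * V[j1] / l_new
        l = l_new
    return O
```

On random inputs this matches the explicit-sum version above up to floating-point error, but it overflows as soon as any \(S^{[m, n]}\) is large, which motivates the next step.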

WTF: \(\exp\) can explode (overflow) because the dot-product scores can take large values.

The max operation is used for numerical stability of the softmax, specifically to keep \(\exp\) from overflowing.

Next we will incorporate this stabilization technique into the framework derived above.

\[\begin{align} m_0 = S^{[m, n=0]} \\ m_{j+1} = \max(m_j, S^{[m, n=j+1]}) \\ l_0 = \exp(S^{[m, n=0]} - m_0) \\ l_{j+1} = \sum_{n\in[0 \dots j, j+1]}\exp(S^{[m, n]} - m_{j+1}) \\ l_{j+1} = \frac{\exp(-m_{j})}{\exp(m_{j+1} - m_{j})}\sum_{n\in[0 \dots j]}\exp(S^{[m, n]}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ l_{j+1} = l_{j}\exp(m_{j} - m_{j+1}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ O^{[m, d]}_{0} = \frac{\exp(S^{[m, n=0]} - m_{0}) \cdot V^{[n=0, d]}}{l_{0}} \\ O^{[m, d]}_{j+1} = \frac{O^{[m, d]}_{j} * l_j \exp(m_{j} - m_{j+1})}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]} - m_{j+1}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \\ \end{align}\]

Note the extra \(\exp(m_j - m_{j+1})\) factor on the previous partial output \(O^{[m, d]}_{j}\): since \(O_j\) was built with the old max \(m_j\), it must be rescaled whenever the running max changes, mirroring the rescaling of \(l_j\) in the \(l_{j+1}\) update.
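A sketch of the stabilized recurrence per query row; the variable `alpha` stands for the \(\exp(m_j - m_{j+1})\) rescaling applied to both the running denominator and the running output:

```python
import numpy as np

def single_query_streaming_stable(q, K, V):
    """Numerically stable online softmax-attention for a single query row."""
    N, _ = V.shape
    s = K[0] @ q
    m = s                              # m_0
    l = np.exp(s - m)                  # l_0 (equals 1.0)
    O = np.exp(s - m) * V[0] / l       # O_0
    for j1 in range(1, N):
        s = K[j1] @ q                  # S[m, n = j+1]
        m_new = max(m, s)              # m_{j+1}
        alpha = np.exp(m - m_new)      # rescaling exp(m_j - m_{j+1})
        e = np.exp(s - m_new)          # exp(S[m, n = j+1] - m_{j+1})
        l_new = l * alpha + e          # l_{j+1}
        O = O * l * alpha / l_new + e * V[j1] / l_new
        l = l_new
        m = m_new
    return O
```

Stacking this over all query rows reproduces the reference forward pass without ever materializing the full \(S\) or \(P\) matrices, which is the memory saving the FlashAttention kernel builds on.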
