FlashAttention Kernel: Forward Pass (MATH)
\(A \in \mathbb{R}^{s_2 s_1}\) where \(s_2\) and \(s_1\) are index sets; for example, in a 5D tensor \(A \in \mathbb{R}^{ijklm}\) a possible split could be \(s_2 = \{ i, j \}\) and \(s_1 = \{k, l, m \}\).
\(B_{s_2 s_1} = flash_{s_1}(Q_{s_2 s_1}, K_{s_3 s_1}, V_{s_3 s_1})\) is a flash attention operation over the \(s_1\) index set of the tensor set \(\langle Q, K, V \rangle\); the resulting index set is that of \(Q\), s.t. \(B \in \mathbb{R}^{s_2 s_1}\)
Attention Operation
For this exercise we will simplify the target index sets to match the most common setup, i.e. \(Q_{s_2 s_1} \in \mathbb{R}^{[M \times d]}\), \(K_{s_3 s_1} \in \mathbb{R}^{[N \times d]}\), \(V_{s_3 s_1} \in \mathbb{R}^{[N \times d]}\)
Note: In this doc, whenever we need the exact dimensions instead of the index set, we write \(.^{[... \times ... \times \dots]}\), where the \(\times\) symbol separates the different index sets.
Thus our operation becomes:
\[\begin{align} B_{s_2 s_1} = attention_{s_1}(Q_{s_2 s_1}, K_{s_3 s_1}, V_{s_3 s_1}) \\ B^{[M \times d]} = attention_{d}(Q^{[M \times d]}, K^{[N \times d]}, V^{[N \times d]}) \end{align}\]
\[\begin{align} S^{[M \times N]} = Q @ K^T = Q *_{(M \times d, N \times d, M \times N)} K \in \mathbb{R}^{[M \times N]} \\ S_{rmax}^{M} = \max_{N} S^{[M \times N]} \in \mathbb{R}^{M} \\ S_{rm}^{[m, n]} = S^{[m, n]} - S_{rmax}^{m} \forall [m, n] \in [M \times N] \\ P^{[M \times N]} = softmax_{N}(S_{rm}^{[M \times N]}) \\ O^{[M \times d]} = P @ V = P *_{(M \times N, N \times d, M \times d)} V \in \mathbb{R}^{[M \times d]} \end{align}\]
Note: Notation abuse -> \(O\) and \(B\) are used interchangeably.
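To make the shapes concrete, here is a minimal NumPy sketch of the forward pass above (function and variable names are mine; the \(1/\sqrt{d}\) scaling used in practice is omitted to match the equations):

```python
import numpy as np

def attention_forward(Q, K, V):
    """Naive attention forward pass following the equations above.

    Q: [M, d], K: [N, d], V: [N, d]  ->  O: [M, d].
    The 1/sqrt(d) scaling used in practice is omitted here.
    """
    S = Q @ K.T                              # S^{[M x N]}
    S_rmax = S.max(axis=1, keepdims=True)    # row max over N
    E = np.exp(S - S_rmax)                   # exp of stabilized scores
    P = E / E.sum(axis=1, keepdims=True)     # softmax_N
    return P @ V                             # O^{[M x d]}
```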
Simplifying Forward Pass
\[\begin{align} P^{[m, n]} = \frac{\exp(S^{[m, n]} - \max_{n}(S^{[m, n]}))}{\sum_n \exp(S^{[m, n]} - \max_{n}(S^{[m, n]}))} = \frac{\exp(S^{[m, n]})}{\sum_n \exp(S^{[m, n]})} \end{align}\]
Note: Softmax is independent across the \(M\) dimension; the only aggregation required is over the \(N\) dimension.
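A quick numerical check of this shift invariance (shapes and seed are arbitrary, just for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
S = rng.normal(size=(4, 8))  # a small S^{[M x N]}

# softmax with the row max subtracted ...
P_shift = np.exp(S - S.max(axis=1, keepdims=True))
P_shift /= P_shift.sum(axis=1, keepdims=True)

# ... equals the raw softmax (safe here since S is small)
P_raw = np.exp(S)
P_raw /= P_raw.sum(axis=1, keepdims=True)

assert np.allclose(P_shift, P_raw)
```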
Computation in Chunks
Here we will first look at what is required to generate the output for a single query, i.e. \(O^{[m, d]}\)
\[\begin{align} O^{[m, d]} = P_m *_{(N, N \times d, d)} V = \sum_n P^{[m, n]} \cdot V^{[n, d]} \\ P^{[m, n]} = softmax_{n}(S^{[m, n]}) = \frac{\exp(S^{[m, n]})}{\sum_{n'} \exp(S^{[m, n']})} \\ O^{[m, d]} = \sum_n P^{[m, n]} \cdot V^{[n, d]} = \sum_n \frac{\exp(S^{[m, n]})}{\sum_{n'} \exp(S^{[m, n']})} \cdot V^{[n, d]} \\ = \frac{1}{\sum_n \exp(S^{[m, n]})}\sum_n \exp(S^{[m, n]}) \cdot V^{[n, d]} \end{align}\]
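In code, the single-query output is just a weighted sum over the \(N\) rows of \(V\) (a sketch without stabilization, matching the last expression; names are mine):

```python
import numpy as np

def output_for_one_query(q, K, V):
    """O^{[m, d]} for one query row. q: [d], K: [N, d], V: [N, d]."""
    s = K @ q                 # S^{[m, n]} for all n, shape [N]
    w = np.exp(s)             # unnormalized weights (no max subtracted yet)
    return (w @ V) / w.sum()  # (1 / sum_n exp) * sum_n exp(S) * V^{[n, d]}
```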
We want to process \(O^{[m, d]} = \sum_n \dots\) over \(n\) sequentially, to avoid loading the whole sequence at once.
For the sequence processed up to \(n = j\) we can write: \[\begin{align} O^{[m, d]}_j = \frac{1}{\sum_{n\in[0 \dots j]} \exp(S^{[m, n]})}\sum_{n\in[0 \dots j]} \exp(S^{[m, n]}) \cdot V^{[n, d]} = \frac{1}{l_j} u_j \end{align}\]
Let’s say we advance by a single step to \(n = j+1\): \[\begin{align} O^{[m, d]}_{j+1} = \frac{1}{\sum_{n\in[0 \dots j, j+1]} \exp(S^{[m, n]})}\sum_{n\in[0 \dots j, j+1]} \exp(S^{[m, n]}) \cdot V^{[n, d]}\\ = \frac{1}{l_j + \exp(S^{[m, n=j+1]})} (u_j + \exp(S^{[m, n=j+1]}) \cdot V^{[n=j+1, d]}) \\ = \frac{O^{[m, d]}_{j} \cdot l_j}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \end{align}\]
Thus we can compute the output \(O^{[m, d]}\) simply by iterating over the \(N\) dimension. The final expression:
\[\begin{align} O^{[m, d]}_{0} = \frac{\exp(S^{[m, n=0]}) \cdot V^{[n=0, d]}}{l_{0}}\\ O^{[m, d]}_{j+1} = \frac{O^{[m, d]}_{j} \cdot l_j}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \\ l_0 = \exp(S^{[m, n=0]}) \\ l_{j+1} = l_j + \exp(S^{[m, n=j+1]}) \end{align}\]
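A sketch of this recurrence, processing one key/value row at a time (names are mine; as the next note warns, this version is numerically unsafe):

```python
import numpy as np

def streaming_output(q, K, V):
    """Sequential computation of O^{[m, d]} via the recurrence above.

    Numerically unsafe: np.exp(s) overflows for large scores.
    """
    s = float(K[0] @ q)
    l = np.exp(s)                    # l_0
    O = np.exp(s) * V[0] / l         # O_0
    for j in range(1, K.shape[0]):   # advance to n = j
        s = float(K[j] @ q)
        l_new = l + np.exp(s)        # l_{j+1} = l_j + exp(S^{[m, n=j+1]})
        O = O * l / l_new + np.exp(s) * V[j] / l_new
        l = l_new
    return O
```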
WTF: \(\exp\) can explode because of large dot-product values.
The max operation is used for numerical stability of the softmax, specifically to keep \(\exp\) from overflowing.
Next we will incorporate this stabilization technique into the framework derived above.
\[\begin{align} m_0 = S^{[m, n=0]} \\ m_{j+1} = \max(m_j, S^{[m, n=j+1]}) \\ l_0 = \exp(S^{[m, n=0]} - m_0) \\ l_{j+1} = \sum_{n\in[0 \dots j, j+1]}\exp(S^{[m, n]} - m_{j+1}) \\ l_{j+1} = \exp(m_{j} - m_{j+1})\sum_{n\in[0 \dots j]}\exp(S^{[m, n]} - m_j) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ l_{j+1} = l_{j}\exp(m_{j} - m_{j+1}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ O^{[m, d]}_{0} = \frac{\exp(S^{[m, n=0]} - m_{0}) \cdot V^{[n=0, d]}}{l_{0}} \\ O^{[m, d]}_{j+1} = \frac{O^{[m, d]}_{j} \cdot l_j \exp(m_j - m_{j+1})}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]} - m_{j+1}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \\ \end{align}\]
Note that whenever the running max changes, both \(l_j\) and the accumulated output must be rescaled by \(\exp(m_j - m_{j+1})\), since they were computed with the old max.
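Putting it together, a stabilized version of the streaming loop (a sketch of the per-query inner loop, not the blocked kernel; names are mine):

```python
import numpy as np

def streaming_output_stable(q, K, V):
    """O^{[m, d]} via the max-stabilized recurrence above."""
    m = float(K[0] @ q)              # m_0
    l = 1.0                          # l_0 = exp(S_0 - m_0) = 1
    O = V[0].astype(float)           # O_0 = exp(S_0 - m_0) V_0 / l_0 = V_0
    for j in range(1, K.shape[0]):
        s = float(K[j] @ q)
        m_new = max(m, s)            # running max
        alpha = np.exp(m - m_new)    # rescale factor for the old sums
        l_new = l * alpha + np.exp(s - m_new)
        O = O * l * alpha / l_new + np.exp(s - m_new) * V[j] / l_new
        m, l = m_new, l_new
    return O
```

Against the naive version, `streaming_output_stable(q, K, V)` should agree with `attention_forward(q[None], K, V)[0]` up to floating-point error, even when the scores are large enough that the unstabilized loop overflows.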