FlashAttention Kernel: Backward Pass (MATH)

FlashAttention
Transformers
Attention
Compute
Autograd
MATH
Author

Shivam Pandey

Published

March 30, 2025

Preliminaries

  • \(A \in \mathbb{R}^{s_2 s_1}\) where \(s_2\) and \(s_1\) are index sets; for example, for a 5D tensor \(A \in \mathbb{R}^{ijklm}\) a possible choice of index sets is \(s_2 = \{ i, j \}\) and \(s_1 = \{k, l, m \}\).

  • \(B_{s_2 s_1} = flash_{s_1}(Q_{s_2 s_1}, K_{s_3 s_1}, V_{s_3 s_1})\) is a flash attention operation over the \(s_1\) index set of the tensor set \(\langle Q, K, V \rangle\); the resulting index set is that of \(Q\), s.t. \(B \in \mathbb{R}^{s_2 s_1}\)

Attention Operation

For this exercise we will simplify the target index sets to match the most common setup, i.e. \(Q_{s_2 s_1} \in \mathbb{R}^{[M \times d]}\), \(K_{s_3 s_1} \in \mathbb{R}^{[N \times d]}\), \(V_{s_3 s_1} \in \mathbb{R}^{[N \times d]}\)

Note: In this doc, whenever we need to denote exact dimensions instead of index sets, we will write \(.^{[... \times ... \times \dots]}\), where the \(\times\) symbol separates the different index sets.

Thus our operation becomes:

\[\begin{align} B_{s_2 s_1} = attention_{s_1}(Q_{s_2 s_1}, K_{s_3 s_1}, V_{s_3 s_1}) \\ B^{[M \times d]} = attention_{d}(Q^{[M \times d]}, K^{[N \times d]}, V^{[N \times d]}) \end{align}\]


\[\begin{align} S^{[M \times N]} = Q @ K^T = Q *_{(M \times d, N \times d, M \times N)} K \in \mathbb{R}^{[M \times N]} \\ S_{rmax}^{M} = \max_{N} S^{[M \times N]} \in \mathbb{R}^{M} \\ S_{rm}^{[m, n]} = S^{[m, n]} - S_{rmax}^{m} \forall [m, n] \in [M \times N] \\ P^{[M \times N]} = softmax_{N}(S_{rm}^{[M \times N]}) \\ O^{[M \times d]} = P @ V = P *_{(M \times N, N \times d, M \times d)} V \in \mathbb{R}^{[M \times d]} \end{align}\]

Note: slight abuse of notation: \(O \iff B\) (the two are used interchangeably for the attention output).

For detailed Forward pass derivation please refer to my previous blog: FlashAttention Kernel: Forward Pass (MATH)
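To make the symbols concrete before differentiating, here is a minimal, unfused PyTorch sketch of the forward pass defined above (the function name `naive_attention_fwd` and the shapes are mine, purely illustrative; no \(1/\sqrt{d}\) scaling is applied, matching the equations):

```python
import torch

def naive_attention_fwd(Q, K, V):
    """Unfused reference of the forward pass above. Q: [M, d], K, V: [N, d]."""
    S = Q @ K.T                                    # S^{[M x N]} = Q K^T
    S_rmax = S.max(dim=-1, keepdim=True).values    # row-wise max, [M, 1]
    P = torch.softmax(S - S_rmax, dim=-1)          # row-wise softmax (shift-invariant)
    O = P @ V                                      # O^{[M x d]} = P V
    return O, P

M, N, d = 4, 6, 8
Q, K, V = torch.randn(M, d), torch.randn(N, d), torch.randn(N, d)
O, P = naive_attention_fwd(Q, K, V)
print(O.shape, P.shape)   # torch.Size([4, 8]) torch.Size([4, 6])
```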

Backward (Reverse-Mode Autodiff) Pass:

\[\begin{align} B_{s_2 s_1} \iff B^{[M \times d]} = attention_{d}(Q^{[M \times d]}, K^{[N \times d]}, V^{[N \times d]}) \end{align}\]

For a given scalar loss \(O_{s_3 = \phi}\) and a known upstream gradient \(dB_{\phi s_2 s_1} = \frac{\partial O_{s_3 = \phi}}{\partial B_{s_2 s_1}}\), we need to find \(dQ_{\phi s_2 s_1}\), \(dK_{\phi s_3 s_1}\), and \(dV_{\phi s_3 s_1}\).

Here we will differentiate the core attention operation directly, without adjusting the exponent for numerical stability (that adjustment was made in the forward pass only to keep the computation stable). We will first derive the core backward operations, then turn them into a computation, and finally mitigate any remaining sources of numerical instability.

\(dV_{\phi s_3 s_1}\):

Consider the following op: \(B_{s_2 s_1} = \sum_{s_3} P_{s_2 s_3} \cdot V_{s_3 s_1} = P^{[M \times N]} @ V^{[N \times d]} \in \mathbb{R}^{[M \times d]}_{s_2 s_1}\)

\[\begin{align} dV_{\phi s'_3 s'_1} = \frac{\partial O_{\phi}}{\partial V_{s'_3 s'_1}} = \sum_{s_2 s_1} \frac{\partial O_{\phi}}{\partial B_{s_2 s_1}} \frac{\partial B_{s_2 s_1}}{\partial V_{s'_3 s'_1}} = \sum_{s_2 s_1} dB_{s_2 s_1} \frac{\partial B_{s_2 s_1}}{\partial V_{s'_3 s'_1}} \\ \frac{\partial B_{s_2 s_1}}{\partial V_{s'_3 s'_1}} = \sum_{s_3} P_{s_2 s_3} \mathbb{1}_{(s_3 s_1) = (s'_3 s'_1)} = P_{s_2 s'_3} \mathbb{1}_{s_1 = s'_1} \\ \sum_{s_2 s_1} dB_{s_2 s_1} \frac{\partial B_{s_2 s_1}}{\partial V_{s'_3 s'_1}} = \sum_{s_2 s_1} dB_{s_2 s_1} P_{s_2 s'_3} \mathbb{1}_{s_1 = s'_1} = \sum_{s_2} dB_{s_2 s'_1} P_{s_2 s'_3} \\ dV_{\phi s'_3 s'_1} = \sum_{s_2 s_1} dB_{s_2 s_1} \frac{\partial B_{s_2 s_1}}{\partial V_{s'_3 s'_1}} = \sum_{s_2} P_{s_2 s'_3} dB_{s_2 s'_1} = P^T \cdot dB \end{align}\]
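A quick autograd sanity check of this result (a sketch with arbitrary shapes; `P` is held fixed so that only the \(B = P \cdot V\) step is differentiated):

```python
import torch

M, N, d = 4, 6, 8
P = torch.softmax(torch.randn(M, N), dim=-1)   # treated as a constant here
V = torch.randn(N, d, requires_grad=True)
dB = torch.randn(M, d)                         # upstream gradient dB

B = P @ V
B.backward(dB)                                 # reverse mode: accumulates dV into V.grad

print(torch.allclose(V.grad, P.T @ dB))        # dV = P^T . dB -> True
```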

Note: This also gives a crucial property of tensor differentiation: for the tensor product operation \(C_{s_2 s_3} = \sum_{s_1} A_{s_2 s_1} B_{s_1 s_3} = A \cdot B\) and a given \(dC_{s_o s_2 s_3}\), the derivative is \(dA_{s_o s'_2 s'_1} = \sum_{s_3} dC_{s_o s'_2 s_3} B_{s'_1 s_3}\); for a simple matmul, i.e. \(s_o = \phi\) and \(s_1, s_2, s_3\) each a single index, this shrinks to simply \(dA = dC \cdot B^T\). Similarly, \(dB_{s_o s'_1 s'_3} = \sum_{s_2} dC_{s_o s_2 s'_3} A_{s_2 s'_1}\), which for a simple matrix multiplication reduces to \(dB = A^T \cdot dC\).
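The simple-matmul case of this rule is easy to confirm numerically (a minimal sketch; the sizes are arbitrary):

```python
import torch

m, k, n = 3, 5, 7
A = torch.randn(m, k, requires_grad=True)
B = torch.randn(k, n, requires_grad=True)
dC = torch.randn(m, n)                      # upstream gradient dC

C = A @ B
C.backward(dC)

print(torch.allclose(A.grad, dC @ B.T))     # dA = dC . B^T -> True
print(torch.allclose(B.grad, A.T @ dC))     # dB = A^T . dC -> True
```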

\(dP_{\phi s_2 s_3}\):

From the formula derived above, \(dP_{\phi s'_2 s'_3} = \sum_{s_1} dB_{s'_2 s_1} V_{s'_3 s_1} = dB \cdot V^T\).

\(dQK^T_{\phi s_2 s_3}\):

Here we encounter the softmax operation, \(P_{s_2 s_3} = softmax_{s_3}(S_{s_2 s_3})\) with \(S_{s_2 s_3} = QK^T_{s_2 s_3}\). From the softmax blog we can reuse the backward formula derived there.

Direct Operation:

\[\begin{equation} B_{s_2 s_1} = softmax_{s_1}(A_{s_2 s_1}) \end{equation}\]

Here \(O_{s_3}\) is the final loss value for which we need to extract the derivatives.

\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = B_{s'_2 s'_1} \left[dB_{s_3 s'_2 s'_1} - \sum_{s_1} dB_{s_3 s'_2 s_1} B_{s'_2 s_1}\right] \end{align}\]

From the formula, we can say:

\[\begin{align} \frac{\partial O_{\phi}}{\partial QK^T_{s'_2 s'_3}} = P_{s'_2 s'_3} \left[dP_{\phi s'_2 s'_3} - \sum_{s_3} dP_{\phi s'_2 s_3} P_{s'_2 s_3}\right] \end{align}\]

For the simple-matmul case, written with explicit matrix indices:

\[\begin{align} \frac{\partial O_{\phi}}{\partial S_{i'j'}} = P_{i'j'} \left[dP_{i'j'} - dP_{i':}^T \circ P_{i':} \right] = P \times \left[dP - \mathrm{BMM}(dP_{i'1j}, P_{i'j1}) \right] \end{align}\]
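In code the bracketed correction is just a row-wise reduction; the batched-matmul view above, with shapes \([M \times 1 \times N]\) and \([M \times N \times 1]\), computes the same per-row dot product. A small sketch, checked against autograd:

```python
import torch

M, N = 4, 6
S = torch.randn(M, N, requires_grad=True)
dP = torch.randn(M, N)                        # upstream gradient dP

P = torch.softmax(S, dim=-1)
P.backward(dP)

# dS_ij = P_ij * (dP_ij - sum_j' dP_ij' * P_ij')
row_dot = (dP * P).sum(dim=-1, keepdim=True)  # [M, 1]; same as bmm(dP[M,1,N], P[M,N,1])
dS = P * (dP - row_dot)
print(torch.allclose(S.grad, dS))             # True
```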

\(dQ_{\phi s_2 s_1}\) & \(dK_{\phi s_3 s_1}\):

Since \(S = QK^T\) and we now know \(dS\), we can directly write the derivatives of both \(Q\) and \(K\) using the matmul rule derived earlier.

\[\begin{align} dQ = dS \cdot K \\ dK = dS^T \cdot Q \end{align}\]

Final Backward Pass Equations:

\[\begin{align} dB \in \mathbb{R}^{[M \times d]}, \{Q, dQ\} \in \mathbb{R}^{[M \times d]}, \\ \{K, dK\} \in \mathbb{R}^{[N \times d]}, \{V, dV\} \in \mathbb{R}^{[N \times d]} \end{align}\]


\[\begin{align} dV = P^T \cdot dB \\ dP = dB \cdot V^T \\ dS_{i'j'} = P_{i'j'} \left[dP_{i'j'} - dP_{i':}^T \circ P_{i':} \right] \\ dQ = dS \cdot K \\ dK = dS^T \cdot Q \end{align}\]
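Putting the five equations together, here is a minimal, unfused backward sketch (\(O(M \cdot N)\) memory, no blocking) checked against PyTorch autograd; the function name `naive_attention_bwd` is mine, not the kernel's:

```python
import torch

def naive_attention_bwd(Q, K, V, dB):
    S = Q @ K.T
    P = torch.softmax(S, dim=-1)
    dV = P.T @ dB                                     # dV = P^T . dB
    dP = dB @ V.T                                     # dP = dB . V^T
    dS = P * (dP - (dP * P).sum(-1, keepdim=True))    # row-wise softmax backward
    dQ = dS @ K                                       # dQ = dS . K
    dK = dS.T @ Q                                     # dK = dS^T . Q
    return dQ, dK, dV

M, N, d = 4, 6, 8
Q = torch.randn(M, d, requires_grad=True)
K = torch.randn(N, d, requires_grad=True)
V = torch.randn(N, d, requires_grad=True)
dB = torch.randn(M, d)

B = torch.softmax(Q @ K.T, dim=-1) @ V
B.backward(dB)

dQ, dK, dV = naive_attention_bwd(Q, K, V, dB)
print(torch.allclose(dQ, Q.grad, atol=1e-6),
      torch.allclose(dK, K.grad, atol=1e-6),
      torch.allclose(dV, V.grad, atol=1e-6))          # True True True
```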

Expansion in dimensions:

\[\begin{align} S_{ij} = \sum_d q_{i d} k_{j d} \\ dV_{j d} = P^T \cdot dB = \sum_i P_{i j}\, dB_{i d} = \sum_i \frac{\exp(S_{i j})}{L_i} dB_{i d} \\ dP_{i j} = dB \cdot V^T = \sum_d dB_{i d} V_{j d} \\ dS_{i j} = P_{i j} \left[dP_{i j} - \sum_{j'} dP_{i j'} P_{i j'} \right] \\ dQ_{i d} = dS \cdot K = \sum_{j} dS_{i j} K_{j d}\\ dK_{j d} = dS^T \cdot Q = \sum_{i} dS_{i j} Q_{i d} \end{align}\]

Abstracting away the d dimension:

In a future blog we will see that it is natural to do the computation along the d (embedding) dimension, since all of the computations along this dimension are independent of each other.

\[\begin{align} S_{ij} = q_i \circ k_j \\ dV_j = \sum_i \frac{\exp(q_i \circ k_j)}{L_i} dB_i \\ dP_{i j} = dB_i \circ V_j \\ dS_{i j} = P_{i j} \left[dP_{i j} - \sum_{j'} dP_{i j'} P_{i j'} \right] \\ dQ_i = \sum_j dS_{i j}\, k_j \\ dK_j = \sum_i dS_{i j}\, q_i \\ \end{align}\]


The row-wise correction term \(\sum_j dP_{i j} P_{i j}\) can be rewritten in terms of quantities we already have, namely \(dB\) and the forward output \(B\):

\[\begin{align} \sum_j dP_{i j} P_{i j} = \sum_j \big(\sum_d dB_{i d} V_{j d}\big) P_{i j} = \sum_j \sum_d dB_{i d} V_{j d} P_{i j} \\ = \sum_d \sum_j dB_{i d} V_{j d} P_{i j} = \sum_d dB_{i d} \sum_j V_{j d} P_{i j} = \sum_d dB_{i d} B_{i d} \end{align}\]
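This identity is what lets the kernel drop the row-wise reduction over \(N\) and instead compute a per-row scalar from \(dB\) and the already stored output \(B\). A quick numerical check (a sketch, shapes arbitrary):

```python
import torch

M, N, d = 4, 6, 8
P = torch.softmax(torch.randn(M, N), dim=-1)
V = torch.randn(N, d)
dB = torch.randn(M, d)

B = P @ V                            # forward output (a.k.a. O)
dP = dB @ V.T

lhs = (dP * P).sum(dim=-1)           # sum_j dP_ij P_ij
rhs = (dB * B).sum(dim=-1)           # sum_d dB_id B_id
print(torch.allclose(lhs, rhs))      # True
```

Substituting this back into the per-row equations gives: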

\[\begin{align} S_{ij} = q_i \circ k_j \\ dV_j = \sum_i \frac{\exp(q_i \circ k_j)}{L_i} dB_i \\ dP_{i j} = dB_i \circ V_j \\ dS_{i j} = P_{i j} \left[dP_{i j} - dB_i \circ B_i \right] \\ dQ_i = \sum_j dS_{i j}\, k_j \\ dK_j = \sum_i dS_{i j}\, q_i \\ \end{align}\]
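As a final illustration of the row-wise view, here is a hypothetical per-row loop (single rows rather than the blocks a real kernel would process) that implements exactly these equations, checked against the matrix form:

```python
import torch

M, N, d = 4, 6, 8
Q, K, V = torch.randn(M, d), torch.randn(N, d), torch.randn(N, d)
dB = torch.randn(M, d)

P = torch.softmax(Q @ K.T, dim=-1)      # P_ij = exp(q_i . k_j) / L_i
B = P @ V
delta = (dB * B).sum(dim=-1)            # delta_i = dB_i . B_i

dQ = torch.zeros(M, d)
dK = torch.zeros(N, d)
dV = torch.zeros(N, d)
for i in range(M):                      # rows of Q / B
    for j in range(N):                  # rows of K / V
        dP_ij = dB[i] @ V[j]            # dP_ij = dB_i . V_j
        dS_ij = P[i, j] * (dP_ij - delta[i])
        dV[j] += P[i, j] * dB[i]        # dV_j += P_ij * dB_i
        dQ[i] += dS_ij * K[j]           # dQ_i += dS_ij * k_j
        dK[j] += dS_ij * Q[i]           # dK_j += dS_ij * q_i

# agrees with the matrix form
dS = P * (dB @ V.T - delta[:, None])
print(torch.allclose(dQ, dS @ K, atol=1e-5),
      torch.allclose(dK, dS.T @ Q, atol=1e-5),
      torch.allclose(dV, P.T @ dB, atol=1e-5))   # True True True
```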
