SoftMax Kernel: Forward and Backward Pass (MATH)
\(A \in \mathbb{R}^{s_2 s_1}\), where \(s_2\) and \(s_1\) are index sets; for example, for a 5D tensor \(A \in \mathbb{R}^{ijklm}\) a possible choice of index sets is \(s_2 = \{ i, j \}\) and \(s_1 = \{k, l, m \}\).
\(B_{s_2 s_1} = softmax_{s_1}(A_{s_2 s_1})\) is a softmax operation over the \(s_1\) index set of \(A\); the resulting index set remains the same as that of \(A\), s.t. \(B \in \mathbb{R}^{s_2 s_1}\).
Softmax Operation
\[\begin{equation} B_{s_2 s_1} = softmax_{s_1}(A_{s_2 s_1}) \end{equation}\]
Intermediate Result: #1
\[\begin{equation} I^1_{s_2 s_1} = \exp{(A_{s_2 s_1})} \end{equation}\]
Intermediate Result: #2
\[\begin{equation} I^2_{s_2} = \sum_{s_1}I^1_{s_2 s_1} \end{equation}\]
Softmax:
\[\begin{equation} B_{s_2 s_1} = \frac{I^1_{s_2 s_1}}{I^2_{s_2}} = \frac{\exp{(A_{s_2 s_1})}}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \end{equation}\]
Backward Pass:
Here we will just look at the backward pass of the softmax kernel alone, as it will help us understand a much wider concept of having a multidimensional loss function instead of a scalar loss.
Fortunately this also simplifies the problem for us as we won’t have to account for any pullbacks for the output \(B_{s_2 s_1}\) itself.
\[\begin{equation} \frac{\partial B_{s_2 s_1}}{\partial A_{s'_2 s'_1}} \in \mathbb{R}^{s_2 s_1 s'_2 s'_1} \end{equation}\]
Derivation:
\[\begin{align} \frac{\partial B_{s_2 s_1}}{\partial A_{s'_2 s'_1}} = \frac{1}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \frac{\partial \exp{(A_{s_2 s_1})}}{\partial A_{s'_2 s'_1}} + \exp(A_{s_2 s_1})\frac{\partial}{\partial A_{s'_2 s'_1}}\left[\frac{1}{\sum_{s_1} \exp(A_{s_2 s_1})}\right] \end{align}\]
\[\begin{align} \frac{1}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \frac{\partial \exp{(A_{s_2 s_1})}}{\partial A_{s'_2 s'_1}} = \frac{\exp(A_{s_2 s_1})}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \frac{\partial A_{s_2 s_1}}{\partial A_{s'_2 s'_1}} = \frac{\exp(A_{s_2 s_1})}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} \end{align}\]
\[\begin{equation} \exp(A_{s_2 s_1})\frac{\partial}{\partial A_{s'_2 s'_1}}\left[\frac{1}{\sum_{s_1} \exp(A_{s_2 s_1})}\right] = -\frac{\exp(A_{s_2 s_1})}{[\sum_{s_1} \exp(A_{s_2 s_1})]^2} \frac{\partial \sum_{s_1} \exp(A_{s_2 s_1})}{\partial A_{s'_2 s'_1}} \end{equation}\]
\[\begin{align} \frac{\partial \sum_{s_1} \exp(A_{s_2 s_1})}{\partial A_{s'_2 s'_1}} = \sum_{s_1} \frac{\partial \exp(A_{s_2 s_1})}{\partial A_{s'_2 s'_1}} = \sum_{s_1} \exp(A_{s_2 s_1}) \frac{\partial A_{s_2 s_1}}{\partial A_{s'_2 s'_1}} \\ = \sum_{s_1} \exp(A_{s_2 s_1}) \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} = \exp(A_{s_2 s'_1}) \mathbb{1}_{s_2 = s'_2} \end{align}\]
Final Derivative Simplification:
\[\begin{align} \frac{\partial B_{s_2 s_1}}{\partial A_{s'_2 s'_1}} = \frac{\exp(A_{s_2 s_1})}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} - \frac{\exp(A_{s_2 s_1})}{[\sum_{s_1} exp(A_{s_2 s_1})]^2} \exp(A_{s_2 s'_1}) \mathbb{1}_{s_2 = s'_2} \\ = \frac{\exp(A_{s_2 s_1})}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} - \frac{\exp(A_{s_2 s_1}) \exp(A_{s_2 s'_1})}{[\sum_{s_1} exp(A_{s_2 s_1})]^2} \mathbb{1}_{s_2 = s'_2} \\ = B_{s_2 s_1} \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} - B_{s_2 s_1}B_{s_2 s'_1} \mathbb{1}_{s_2 = s'_2} \end{align}\]
\[\begin{align} \frac{\partial B_{s_2 s_1}}{\partial A_{s'_2 s'_1}} = B_{s_2 s_1} \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} - B_{s_2 s_1}B_{s_2 s'_1} \mathbb{1}_{s_2 = s'_2} \end{align}\]
Full reduction: With Loss Derivative
Let’s assume the output \(B_{s_2 s_1}\) is used by some downstream function to generate the loss \(O_{s_3}\), and that we already have the pullback of \(O\) w.r.t. \(B\), namely \(dB_{s_3 s_2 s_1} = \frac{\partial O_{s_3}}{\partial B_{s_2 s_1}}\). We are now interested in finding the pullback of \(O_{s_3}\) w.r.t. \(A_{s'_2 s'_1}\), i.e. \(\frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}}\).
\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = \sum_{s_2 s_1} \frac{\partial O_{s_3}}{\partial B_{s_2 s_1}}\frac{\partial B_{s_2 s_1}}{\partial A_{s'_2 s'_1}} \\ = \sum_{s_2 s_1} dB_{s_3 s_2 s_1} \left[ \frac{\exp(A_{s_2 s_1})}{\sum_{s_1}\exp{(A_{s_2 s_1})}} \mathbb{1}_{(s_2 s_1) = (s'_2 s'_1)} - \frac{\exp(A_{s_2 s_1}) \exp(A_{s_2 s'_1})}{[\sum_{s_1} exp(A_{s_2 s_1})]^2} \mathbb{1}_{s_2 = s'_2} \right] \\ = dB_{s_3 s'_2 s'_1} \frac{\exp(A_{s'_2 s'_1})}{\sum_{s_1}\exp{(A_{s'_2 s_1})}} - \sum_{s_2 s_1} dB_{s_3 s_2 s_1} \frac{\exp(A_{s_2 s_1}) \exp(A_{s_2 s'_1})}{[\sum_{s_1} exp(A_{s_2 s_1})]^2} \mathbb{1}_{s_2 = s'_2} \end{align}\]
\[\begin{align} \sum_{s_2 s_1} dB_{s_3 s_2 s_1} \frac{\exp(A_{s_2 s_1}) \exp(A_{s_2 s'_1})}{[\sum_{s_1} exp(A_{s_2 s_1})]^2} \mathbb{1}_{s_2 = s'_2} = \sum_{s_1} dB_{s_3 s'_2 s_1} \frac{\exp(A_{s'_2 s_1}) \exp(A_{s'_2 s'_1})}{[\sum_{s_1} exp(A_{s'_2 s_1})]^2} \\ = \frac{\exp(A_{s'_2 s'_1})}{[\sum_{s_1} exp(A_{s'_2 s_1})]^2} \sum_{s_1} [dB_{s_3 s'_2 s_1} \exp(A_{s'_2 s_1})] \end{align}\]
\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = dB_{s_3 s'_2 s'_1} \frac{\exp(A_{s'_2 s'_1})}{\sum_{s_1}\exp{(A_{s'_2 s_1})}} - \frac{\exp(A_{s'_2 s'_1})}{[\sum_{s_1} exp(A_{s'_2 s_1})]^2} \sum_{s_1} [dB_{s_3 s'_2 s_1} \exp(A_{s'_2 s_1})] \\ = dB_{s_3 s'_2 s'_1} B_{s'_2 s'_1} - B_{s'_2 s'_1}\frac{1}{\sum_{s_1} exp(A_{s'_2 s_1})} \sum_{s_1} [dB_{s_3 s'_2 s_1} \exp(A_{s'_2 s_1})] \\ = dB_{s_3 s'_2 s'_1} B_{s'_2 s'_1} - B_{s'_2 s'_1} \sum_{s_1} [dB_{s_3 s'_2 s_1} B_{s'_2 s_1}] = B_{s'_2 s'_1} [dB_{s_3 s'_2 s'_1} - \sum_{s_1} [dB_{s_3 s'_2 s_1} B_{s'_2 s_1}]] \end{align}\]
\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = B_{s'_2 s'_1} \left[dB_{s_3 s'_2 s'_1} - \sum_{s_1} dB_{s_3 s'_2 s_1} B_{s'_2 s_1}\right] \end{align}\]
Application in Attention: Specialize the Expressions
Setup:
\[\begin{align} P_{ij} = softmax_{j}(S_{ij}) \\ S_{ij} \in \mathbb{R}^{M \times N} \\ P_{ij} \in \mathbb{R}^{M \times N} \\ O_{\phi} \in \mathbb{R} \\ dP_{\phi ij} = \frac{\partial O_{\phi}}{\partial P_{ij}} \in \mathbb{R}^{\phi \times M \times N} \implies \text{Known} \end{align}\]
Differentiation:
\[\begin{align} \frac{\partial O_{\phi}}{\partial S_{i'j'}} = \sum_{ij} \frac{\partial O_{\phi}}{\partial P_{ij}}\frac{\partial P_{ij}}{\partial S_{i'j'}} = P_{i'j'} \left[dP_{\phi i'j'} - \sum_{j} dP_{\phi i'j} P_{i'j} \right] \\ = P_{i'j'} \left[dP_{i'j'} - \sum_{j} dP_{i'j} P_{i'j} \right] = P_{i'j'} \left[dP_{i'j'} - dP_{i':}^T \circ P_{i':} \right] \\ = P \times \left[dP - BMM(dP_{i'1j}, P_{i'j1}) \right] \end{align}\]
\[\begin{align} \frac{\partial O_{\phi}}{\partial S_{i'j'}} = P_{i'j'} \left[dP_{i'j'} - dP_{i':}^T \circ P_{i':} \right] \end{align}\]