Max Kernel: Forward and Backward Pass (MATH)
Let \(A \in \mathbb{R}^{s_2 s_1}\), where \(s_2\) and \(s_1\) are index sets. For example, for a 5D tensor \(A \in \mathbb{R}^{ijklm}\), one possible choice is \(s_2 = \{ i, j \}\) and \(s_1 = \{k, l, m \}\).
\(B_{s_2} = \max_{s_1}(A_{s_2 s_1})\) is a max operation over the \(s_1\) index set of \(A\). The resulting index set is that of \(A\) with \(s_1\) reduced away, so that \(B \in \mathbb{R}^{s_2}\).
Max (Reduction) operation
\[\begin{align} B_{s_2} = \max_{s_1}(A_{s_2 s_1}) = A_{s_2 s^m_1} \big|_{s^m_1 = \operatorname{argmax}_{s_1}(A_{s_2 s_1})} \end{align}\]
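The reduction above can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: \(A\) is a list of rows, the row index stands in for \(s_2\) and the column index for \(s_1\), and the helper name `max_forward` is hypothetical. Ties are resolved by taking the first maximiser.

```python
def max_forward(A):
    """Reduce each row of A (the s_1 axis) to its maximum.

    Returns (B, s1_max): B[i] is the row maximum and s1_max[i] is the
    column index where it occurs (the s^m_1 of the formula).
    """
    B, s1_max = [], []
    for row in A:
        # max() with a key returns the first maximal index on ties
        m = max(range(len(row)), key=row.__getitem__)
        s1_max.append(m)
        B.append(row[m])
    return B, s1_max


A = [[1.0, 3.0, 2.0],
     [4.0, 0.5, 4.0]]
B, s1_max = max_forward(A)
print(B)       # [3.0, 4.0]
print(s1_max)  # [1, 0]
```

Keeping the argmax indices alongside \(B\) is a design choice of this sketch: the backward pass only needs those indices, not \(A\) itself.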
Backward Pass:
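Here we derive the pullback through the \(\max\) operation, starting from the derivative of \(B\) w.r.t. \(A\). We assume the maximiser \(s^m_1\) is unique for each \(s_2\), so that the derivative is well defined.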
\[\begin{align} \frac{\partial B_{s_2}}{\partial A_{s'_2 s'_1}} = \frac{\partial A_{s_2 s^m_1}}{\partial A_{s'_2 s'_1}} = \mathbb{1}_{s_2 = s'_2} \cdot \mathbb{1}_{s^m_1 = s'_1} \end{align}\]
Full reduction: With Loss Derivative
Assume the output \(B_{s_2}\) is consumed by some downstream function that produces a loss \(O_{s_3}\), and that we already have the pullback of \(B\) w.r.t. \(O\), namely \(dB_{s_3 s_2} = \frac{\partial O_{s_3}}{\partial B_{s_2}}\). We now want the pullback of \(A_{s'_2 s'_1}\) w.r.t. \(O_{s_3}\), i.e. \(\frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}}\).
\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = \sum_{s_2} \frac{\partial O_{s_3}}{\partial B_{s_2}} \frac{\partial B_{s_2}}{\partial A_{s'_2 s'_1}} = \sum_{s_2} dB_{s_3 s_2} \mathbb{1}_{s_2 = s'_2} \cdot \mathbb{1}_{s^m_1 = s'_1} = dB_{s_3 s'_2} \mathbb{1}_{s^m_1 = s'_1} \big|_{s^m_1 = \operatorname{argmax}_{s_1}(A_{s'_2 s_1})} \end{align}\]
\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = dB_{s_3 s'_2} \mathbb{1}_{s^m_1 = s'_1} \big|_{s^m_1 = \operatorname{argmax}_{s_1}(A_{s'_2 s_1})} \end{align}\]
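In code, the result says that each upstream gradient is routed to the argmax position of its row and every other entry receives zero. The sketch below assumes a scalar loss (so the \(s_3\) index is dropped), a 2D list for \(A\), and a unique maximiser per row; `max_backward` is a hypothetical helper name.

```python
def max_backward(A, dB):
    """Scatter dB[i] to the argmax position of row i of A; zeros elsewhere."""
    dA = [[0.0] * len(row) for row in A]
    for i, row in enumerate(A):
        # s^m_1 for this row: the column index of the maximum
        m = max(range(len(row)), key=row.__getitem__)
        dA[i][m] = dB[i]
    return dA


A = [[1.0, 3.0, 2.0],
     [4.0, 0.5, 5.0]]
dB = [10.0, -2.0]
dA = max_backward(A, dB)
print(dA)  # [[0.0, 10.0, 0.0], [0.0, 0.0, -2.0]]
```

If a row has tied maxima the derivative is not uniquely defined; this sketch simply sends the whole gradient to the first maximiser.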
Application in Attention: Specialize the Expressions
Setup:
\[\begin{align} S_{ij} &= S^p_{ij} - \max_j(S^p_{ij}) \\ P_{ij} &= \operatorname{softmax}_j(S_{ij}) \\ S_{ij} &\in \mathbb{R}^{M \times N} \\ P_{ij} &\in \mathbb{R}^{M \times N} \\ O_{\phi} &\in \mathbb{R} \\ dP_{\phi ij} &= \frac{\partial O_{\phi}}{\partial P_{ij}} \in \mathbb{R}^{\phi \times M \times N} \implies \text{Known} \\ dS_{\phi ij} &= \frac{\partial O_{\phi}}{\partial S_{ij}} = P_{ij} \left[dP_{\phi ij} - \sum_{k} dP_{\phi ik} P_{ik} \right] \in \mathbb{R}^{\phi \times M \times N} \end{align}\]
Differentiation
\[\begin{align} \frac{\partial O_{\phi}}{\partial S^p_{i'j'}} &= \sum_{ij} \frac{\partial O_{\phi}}{\partial S_{ij}} \frac{\partial S_{ij}}{\partial S^p_{i'j'}} = \sum_{ij} dS_{\phi ij} \left(\mathbb{1}_{i = i'} \mathbb{1}_{j = j'} - \mathbb{1}_{i = i'} \mathbb{1}_{j^m = j'} \big|_{j^m = \operatorname{argmax}_{j}(S^p_{i'j})}\right) \\ &= dS_{\phi i'j'} - \sum_{j} dS_{\phi i'j} \mathbb{1}_{j^m = j'} \big|_{j^m = \operatorname{argmax}_{j}(S^p_{i'j})} = dS_{\phi i'j'} - \mathbb{1}_{j^m = j'} \big|_{j^m = \operatorname{argmax}_{j}(S^p_{i'j})} \sum_{j} dS_{\phi i'j} \end{align}\]
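The expression can be evaluated numerically for a single row \(i'\). The sketch below assumes a scalar loss (the \(\phi\) index is dropped), takes the softmax backward as \(dS_j = P_j (dP_j - \sum_k dP_k P_k)\), and uses made-up input values; the function names are hypothetical.

```python
import math


def softmax_row(s):
    """Plain softmax of one row of (already shifted) scores."""
    e = [math.exp(x) for x in s]
    z = sum(e)
    return [x / z for x in e]


def grad_wrt_raw_scores(Sp_row, dP_row):
    """Return (dS, dSp, row_sum) for one row, given raw scores and dP."""
    jm = max(range(len(Sp_row)), key=Sp_row.__getitem__)  # argmax_j S^p
    S = [x - Sp_row[jm] for x in Sp_row]                  # max-shifted scores
    P = softmax_row(S)
    dot = sum(dp * p for dp, p in zip(dP_row, P))
    dS = [p * (dp - dot) for p, dp in zip(P, dP_row)]     # softmax backward
    row_sum = sum(dS)
    # dS^p_j = dS_j - 1[j == jm] * sum_j dS_j
    dSp = [ds - (row_sum if j == jm else 0.0) for j, ds in enumerate(dS)]
    return dS, dSp, row_sum


dS, dSp, row_sum = grad_wrt_raw_scores([0.2, 1.5, -0.7], [1.0, -2.0, 0.5])
print(row_sum)
print(dS)
print(dSp)
```

For these inputs the row sum comes out as zero up to floating-point rounding, so `dS` and `dSp` agree to within that rounding.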
Replacing \(dS_{ij}\):
\[\begin{align} \frac{\partial O_{\phi}}{\partial S^p_{i'j'}} = dS_{\phi i'j'} - \mathbb{1}_{j^m = j'} \big|_{j^m = \operatorname{argmax}_{j}(S^p_{i'j})} \sum_{j} dS_{\phi i'j} \end{align}\]
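As a sketch of how the substitution plays out, fix \(\phi\) (index suppressed) and take the softmax backward from the setup as \(dS_{ij} = P_{ij}\left[dP_{ij} - \sum_k dP_{ik} P_{ik}\right]\), with the bracketed row term read as a dot product along \(j\). The row sum that multiplies the indicator is then
\[\begin{align} \sum_{j} dS_{i'j} = \sum_{j} P_{i'j}\, dP_{i'j} - \Big(\sum_{j} P_{i'j}\Big) \sum_{k} dP_{i'k} P_{i'k} = 0 \end{align}\]
because each softmax row sums to one. Under that reading the correction term vanishes and \(\frac{\partial O}{\partial S^p_{i'j'}} = dS_{i'j'}\): subtracting the row maximum contributes no gradient.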