Max Kernel: Forward and Backward Pass (MATH)

AI
Compute
Autograd
MATH
Author

Shivam Pandey

Published

March 30, 2025

Max (Reduction) operation

\[\begin{align} B_{s_2} = \max_{s_1}(A_{s_2 s_1}) = A_{s_2 s^m_1} \big|_{s^m_1 = argmax_{s_1}(A_{s_2 s_1})} \end{align}\]
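As a rough illustration of the forward pass above, here is a minimal pure-Python sketch. The function name `row_max` and the nested-list layout `A[s2][s1]` are assumptions made for this example, not part of the original derivation; ties are resolved by taking the first maximal index.

```python
def row_max(A):
    """Reduce over s1: B[s2] = max_{s1} A[s2][s1], also returning the argmax."""
    B, s1_max = [], []
    for row in A:
        # Index of the first maximal entry in this row (ties -> lowest index).
        m = max(range(len(row)), key=row.__getitem__)
        s1_max.append(m)
        B.append(row[m])
    return B, s1_max


A = [[1.0, 3.0, 2.0],
     [4.0, 0.5, -1.0]]
B, s1_max = row_max(A)
```

Keeping the argmax indices around is what makes the backward pass cheap, since the derivative only needs to know which entry was selected in each row.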

Backward Pass:

Here we derive the Jacobian of \(B\) w.r.t. \(A\) under the \(\max\) operation, assuming the argmax is unique.

\[\begin{align} \frac{\partial B_{s_2}}{\partial A_{s'_2 s'_1}} = \frac{\partial A_{s_2 s^m_1}}{\partial A_{s'_2 s'_1}} = \mathbb{1}_{s_2 = s'_2} \cdot \mathbb{1}_{s^m_1 = s'_1} \end{align}\]

Full reduction: With Loss Derivative

Let’s assume the output \(B_{s_2}\) is consumed by some downstream function to produce a loss \(O_{s_3}\), and that we already have the pullback of \(B\) w.r.t. \(O\), \(dB_{s_3 s_2} = \frac{\partial O_{s_3}}{\partial B_{s_2}}\). We now want the pullback of \(A_{s'_2 s'_1}\) w.r.t. \(O_{s_3}\), i.e. \(\frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}}\).

\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = \sum_{s_2} \frac{\partial O_{s_3}}{\partial B_{s_2}} \frac{\partial B_{s_2}}{\partial A_{s'_2 s'_1}} = \sum_{s_2} dB_{s_3 s_2} \mathbb{1}_{s_2 = s'_2} \cdot \mathbb{1}_{s^m_1 = s'_1} = dB_{s_3 s'_2} \mathbb{1}_{s^m_1 = s'_1} \big|_{s^m_1 = argmax_{s_1}(A_{s'_2 s_1})} \end{align}\]

\[\begin{align} \frac{\partial O_{s_3}}{\partial A_{s'_2 s'_1}} = dB_{s_3 s'_2} \mathbb{1}_{s^m_1 = s'_1} \big|_{s^m_1 = argmax_{s_1}(A_{s'_2 s_1})} \end{align}\]
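The result above says that each row's upstream gradient is routed to the argmax position and is zero everywhere else. The sketch below illustrates this for a single loss component (the \(s_3\) index is dropped) and compares it against central finite differences. The names `row_max_backward`, `loss`, and `numeric_grad`, and the weighted-sum loss itself, are hypothetical choices for this example.

```python
def row_max_backward(A, dB):
    """dA[s2][s1] = dB[s2] if s1 is the row argmax, else 0."""
    dA = [[0.0] * len(row) for row in A]
    for s2, row in enumerate(A):
        s1_max = max(range(len(row)), key=row.__getitem__)
        dA[s2][s1_max] = dB[s2]
    return dA


def loss(A, w):
    # Hypothetical scalar loss: weighted sum of row maxima, so dO/dB = w.
    return sum(w_s2 * max(row) for w_s2, row in zip(w, A))


def numeric_grad(A, w, eps=1e-6):
    """Central finite differences of `loss` w.r.t. every entry of A."""
    grad = [[0.0] * len(row) for row in A]
    for s2 in range(len(A)):
        for s1 in range(len(A[s2])):
            plus = [row[:] for row in A]
            minus = [row[:] for row in A]
            plus[s2][s1] += eps
            minus[s2][s1] -= eps
            grad[s2][s1] = (loss(plus, w) - loss(minus, w)) / (2 * eps)
    return grad


A = [[1.0, 3.0, 2.0],
     [4.0, 0.5, -1.0]]
w = [0.5, -2.0]                      # plays the role of dB
dA = row_max_backward(A, w)
dA_numeric = numeric_grad(A, w)
```

The comparison is only meaningful when the row maxima are well separated relative to `eps`; at exact ties the max is not differentiable and the indicator formula picks one subgradient.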

Application in Attention: Specialize the Expressions

Setup:

\[\begin{align} S_{ij} &= S^p_{ij} - \max_j(S^p_{ij}) \\ P_{ij} &= \text{softmax}_j(S_{ij}) \\ S_{ij} &\in \mathbb{R}^{M \times N} \\ P_{ij} &\in \mathbb{R}^{M \times N} \\ O_{\phi} &\in \mathbb{R} \\ dP_{\phi ij} &= \frac{\partial O_{\phi}}{\partial P_{ij}} \in \mathbb{R}^{\phi \times M \times N} \implies \text{Known} \\ dS_{\phi ij} &= \frac{\partial O_{\phi}}{\partial S_{ij}} = P_{ij} \left[dP_{\phi ij} - \sum_{k} dP_{\phi ik} P_{ik} \right] \in \mathbb{R}^{\phi \times M \times N} \end{align}\]
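To make the setup concrete, the following sketch computes \(P\) from \(S^p\) with the max subtraction and then applies the softmax backward formula row by row, for a single loss component. The helper names `stable_softmax_row` and `softmax_backward_row` and the sample values are assumptions for illustration only.

```python
import math


def stable_softmax_row(sp_row):
    """P_j = softmax_j(S^p_j - max_j S^p_j) for one row."""
    m = max(sp_row)
    e = [math.exp(x - m) for x in sp_row]
    z = sum(e)
    return [v / z for v in e]


def softmax_backward_row(p_row, dp_row):
    """dS_j = P_j * (dP_j - sum_k dP_k P_k) for one row."""
    dot = sum(dp * p for dp, p in zip(dp_row, p_row))
    return [p * (dp - dot) for p, dp in zip(p_row, dp_row)]


Sp = [[1.0, 2.0, 0.5],
      [0.0, -1.0, 3.0]]
dP = [[0.3, -0.2, 0.1],
      [1.0, 0.5, -0.7]]              # made-up upstream gradient

P = [stable_softmax_row(row) for row in Sp]
dS = [softmax_backward_row(p, dp) for p, dp in zip(P, dP)]
```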

Differentiation

\[\begin{align} \frac{\partial O_{\phi}}{\partial S^p_{i'j'}} &= \sum_{ij} \frac{\partial O_{\phi}}{\partial S_{ij}} \frac{\partial S_{ij}}{\partial S^p_{i'j'}} = \sum_{ij} dS_{\phi ij} \left(\mathbb{1}_{ij = i'j'} - \mathbb{1}_{i = i'} \mathbb{1}_{j^m = j'} \big|_{j^m = argmax_{j}(S^p_{i'j})}\right) \\ &= dS_{\phi i'j'} - \sum_{j} dS_{\phi i'j} \mathbb{1}_{j^m = j'} \big|_{j^m={argmax_{j}(S^p_{i'j})}} = dS_{\phi i'j'} - \mathbb{1}_{j^m = j'} \big|_{j^m={argmax_{j}(S^p_{i'j})}} \sum_{j} dS_{\phi i'j} \end{align}\]
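A quick numerical sanity check of the expression above, sketched for a single loss component. The loss \(O = \sum_{ij} W_{ij} P_{ij}\) is a made-up choice so that \(dP = W\); the function names and sample values are likewise hypothetical.

```python
import math


def forward(Sp):
    """P = softmax_j(S^p - max_j S^p), row by row."""
    P = []
    for row in Sp:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        P.append([v / z for v in e])
    return P


def loss(Sp, W):
    # Hypothetical scalar loss O = sum_ij W_ij * P_ij, hence dP = W.
    return sum(w * p for w_row, p_row in zip(W, forward(Sp))
               for w, p in zip(w_row, p_row))


def grad_Sp(Sp, W):
    """dS_{i'j'} - 1[j' == argmax_j S^p_{i'j}] * sum_j dS_{i'j}."""
    out = []
    for sp_row, p_row, dp_row in zip(Sp, forward(Sp), W):
        dot = sum(dp * p for dp, p in zip(dp_row, p_row))
        dS = [p * (dp - dot) for p, dp in zip(p_row, dp_row)]
        jm = max(range(len(sp_row)), key=sp_row.__getitem__)
        total = sum(dS)
        out.append([d - (total if j == jm else 0.0) for j, d in enumerate(dS)])
    return out


def numeric_grad(Sp, W, eps=1e-6):
    """Central finite differences of `loss` w.r.t. every entry of Sp."""
    g = [[0.0] * len(row) for row in Sp]
    for i in range(len(Sp)):
        for j in range(len(Sp[i])):
            plus = [row[:] for row in Sp]
            minus = [row[:] for row in Sp]
            plus[i][j] += eps
            minus[i][j] -= eps
            g[i][j] = (loss(plus, W) - loss(minus, W)) / (2 * eps)
    return g


Sp = [[1.0, 2.0, 0.5],
      [0.0, -1.0, 3.0]]
W = [[0.3, -0.2, 0.1],
     [1.0, 0.5, -0.7]]
analytic = grad_Sp(Sp, W)
numeric = numeric_grad(Sp, W)
```

Agreement between `analytic` and `numeric` on a small example like this is only a spot check, not a proof, but it helps catch sign and index mistakes in the derivation.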

Replacing \(dS_{ij}\):

\[\begin{align} \frac{\partial O_{\phi}}{\partial S^p_{i'j'}} = dS_{\phi i'j'} - \mathbb{1}_{j^m = j'} \big|_{j^m={argmax_{j}(S^p_{i'j})}} \sum_{j} dS_{\phi i'j} \end{align}\]
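One way to carry the substitution through, sketched here under the assumption that \(dS\) takes the softmax backward form from the setup, \(dS_{\phi ij} = P_{ij}\left[dP_{\phi ij} - \sum_k dP_{\phi ik} P_{ik}\right]\), and that each row of \(P\) sums to one (the summation index \(k\) is introduced only for this illustration):

\[\begin{align} \sum_{j} dS_{\phi i'j} &= \sum_{j} P_{i'j} \, dP_{\phi i'j} - \left(\sum_{j} P_{i'j}\right) \sum_{k} dP_{\phi i'k} P_{i'k} = 0 \\ \implies \frac{\partial O_{\phi}}{\partial S^p_{i'j'}} &= dS_{\phi i'j'} \end{align}\]

Under these assumptions the argmax term drops out, which would suggest that the max subtraction needs no separate backward pass when it feeds directly into a softmax.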
