FlashAttention Kernel: Backward Pass (Parallelism)

FlashAttention
Transformers
Attention
Compute
Autograd
Parallelism
CUDA
Author

Shivam Pandey

Published

April 14, 2025

Continuing from my previous blogs, FlashAttention Kernel: Backward Pass (MATH) and FlashAttention Kernel: Forward Pass (Parallelism), here we explore the opportunities for parallelism in the backward-pass kernel.

Note: Most of the concepts used here were discussed in FlashAttention Kernel: Forward Pass (Parallelism), so I urge you to read that one first.

Flash Attention Backward Pass:

In my previous blogs we saw how the math works in the FlashAttention backward pass; these are the final expressions we derived there:

\[\begin{align} dB \in \mathbb{R}^{[M \times D]}, \{Q, dQ\} \in \mathbb{R}^{[M \times D]}, \\ \{K, dK\} \in \mathbb{R}^{[N \times D]}, \{V, dV\} \in \mathbb{R}^{[N \times D]}\\ \end{align}\]


\[\begin{align} dV = P^T \cdot dB \\ dP = dB \cdot V^T \\ dS_{i'j'} = P_{i'j'} \left[dP_{i'j'} - dP_{i':}^T \circ P_{i':} \right] \\ dQ = dS \cdot K \\ dK = dS^T \cdot Q \end{align}\]

After a few manipulations, these reduce to the following simpler, row-wise form:

\[\begin{align} S_{ij} = q_i \circ k_j \\ dV_j = \sum_i \frac{\exp(q_i \circ k_j)}{L_i} \, dB_i \\ dP_{i j} = dB_i \circ V_j \\ dS_{i j} = P_{i j} \left[dP_{i j} - dB_i \circ B_i \right] \\ dQ_i = \sum_j dS_{i j} \, K_j \\ dK_j = \sum_i dS_{i j} \, Q_i \\ \end{align}\]

Note: B is a slight abuse of notation here; read it as the attention output O from the previous blogs (so dB is the upstream gradient dO). Between two row vectors, ∘ denotes their dot product.
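To sanity-check the expressions above, here is a small NumPy sketch (my own, not code from the earlier posts) that evaluates dQ, dK, and dV in matrix form, with dV computed as P^T · dB so that it has the N × D shape of V, and compares them against central finite differences. It also checks the identity dP_{i:} · P_{i:} = dB_i · B_i that turns the first dS expression into the simplified one. It assumes no 1/√D scaling and no masking; forward, backward, loss, and numeric_grad are hypothetical helpers.

```python
import numpy as np

def forward(Q, K, V):
    """Softmax attention without 1/sqrt(D) scaling or masking (assumption)."""
    S = Q @ K.T                                  # [M x N]
    m = S.max(axis=1, keepdims=True)             # row max m_i
    E = np.exp(S - m)
    P = E / E.sum(axis=1, keepdims=True)         # row-wise softmax
    return P, P @ V                              # P and the output B

def backward(Q, K, V, dB):
    """Matrix form of the backward pass."""
    P, B = forward(Q, K, V)
    dV = P.T @ dB                                # [N x D]
    dP = dB @ V.T                                # [M x N]
    dS = P * (dP - (dB * B).sum(axis=1, keepdims=True))  # uses dB_i . B_i
    return dS @ K, dS.T @ Q, dV                  # dQ, dK, dV

rng = np.random.default_rng(0)
M, N, D = 4, 5, 3
Q, K, V = (rng.standard_normal((r, D)) for r in (M, N, N))
dB = rng.standard_normal((M, D))

# Identity behind the simplified form: sum_k P_ik dP_ik == dB_i . B_i
P, B = forward(Q, K, V)
identity_ok = np.allclose((P * (dB @ V.T)).sum(axis=1), (dB * B).sum(axis=1))

def loss(Q, K, V):
    # Scalar loss whose gradient w.r.t. the output B is exactly dB
    return float((dB * forward(Q, K, V)[1]).sum())

def numeric_grad(f, X, eps=1e-6):
    """Central finite differences, one entry at a time."""
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return G

dQ, dK, dV = backward(Q, K, V, dB)
grads_ok = all(
    np.allclose(a, numeric_grad(f, X), atol=1e-5)
    for a, f, X in [
        (dQ, lambda X: loss(X, K, V), Q),
        (dK, lambda X: loss(Q, X, V), K),
        (dV, lambda X: loss(Q, K, X), V),
    ]
)
print(identity_ok, grads_ok)  # True True
```

The finite-difference check only needs the forward pass, so it is independent of the derivation it is checking.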

Parallelization Analysis: Backward Pass

Using the expressions above, we can write the following (partly pseudo) code for the FlashAttention backward pass:

for i in range(0, M):
    q_i = Q[i]         # S1: No self dependency: [SLoop: i, TLoop: i]
    B_i = B[i]         # S2: No self dependency: [SLoop: i, TLoop: i]
    dB_i = dB[i]       # S3: No self dependency: [SLoop: i, TLoop: i]
    for j in range(0, N):
        k_j = K[j]     # S4: No self dependency: [SLoop: j, TLoop: j]
        v_j = V[j]     # S5: No self dependency: [SLoop: j, TLoop: j]

        S_ij = q_i @ k_j # S6: S1->S6 RAW q_i: S4->S6 RAW k_j : [SLoop: ij, TLoop: ij]
        P_ij = exp(S_ij - m_i)/l_i # S7: S6->S7 RAW S_ij: [SLoop: ij, TLoop: ij]

        dP_ij = dB_i @ v_j # S8: S3->S8 RAW dB_i: S5->S8 RAW v_j : [SLoop: ij, TLoop: ij]

        dS_ij = P_ij * (dP_ij - dB_i @ B_i) # S9: S7->S9 RAW P_ij: S8->S9 RAW dP_ij: S3->S9 RAW dB_i: S2->S9 RAW B_i: [SLoop: ij, TLoop: ij]

        dQ_i += dS_ij * k_j # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
        dV_j += dB_i * P_ij # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
        dK_j += dS_ij * q_i # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]

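The code is annotated with comments in the same way as in the forward-pass blog: each statement lists its RAW dependencies, and SLoop/TLoop give the iteration that produces a value and the iteration that next uses it.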
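Because the listing above is partly pseudocode, here is a runnable NumPy transcription of the same i → j loop nest. It assumes m and l are the row max and the row sum of exp(S - m) saved by the forward pass (the listing uses them without loading them). The function name flash_bwd_ij is hypothetical; the comments map each line back to S1-S12, and the result is compared against the matrix form.

```python
import numpy as np

def flash_bwd_ij(Q, K, V, B, dB, m, l):
    """Loop order i -> j; m, l: forward-pass row max and row sum of exp(S - m)."""
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    for i in range(Q.shape[0]):
        q_i, B_i, dB_i = Q[i], B[i], dB[i]           # S1-S3
        for j in range(K.shape[0]):
            k_j, v_j = K[j], V[j]                    # S4-S5
            S_ij = q_i @ k_j                         # S6
            P_ij = np.exp(S_ij - m[i]) / l[i]        # S7
            dP_ij = dB_i @ v_j                       # S8
            dS_ij = P_ij * (dP_ij - dB_i @ B_i)      # S9
            dQ[i] += dS_ij * k_j                     # S10 (carried by loop-j)
            dV[j] += dB_i * P_ij                     # S11 (carried by loop-i)
            dK[j] += dS_ij * q_i                     # S12 (carried by loop-i)
    return dQ, dK, dV

# Random inputs plus the forward-pass statistics the loop expects
rng = np.random.default_rng(1)
M, N, D = 6, 7, 4
Q, K, V = (rng.standard_normal((r, D)) for r in (M, N, N))
dB = rng.standard_normal((M, D))
S = Q @ K.T
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]
B = P @ V

dQ, dK, dV = flash_bwd_ij(Q, K, V, B, dB, m, l)
dS = P * (dB @ V.T - (dB * B).sum(axis=1, keepdims=True))
print(np.allclose(dQ, dS @ K), np.allclose(dK, dS.T @ Q), np.allclose(dV, P.T @ dB))
```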

Data Dependency Graph (DDG)

graph TD
    S1["S1: Q[i]"]
    S2["S2: B[i]"]
    S3["S3: dB[i]"]
    S4["S4: K[j]"]
    S5["S5: V[j]"]
    S6["S6: S_ij"]
    S7["S7: P_ij"]
    S8["S8: dP_ij"]
    S9["S9: dS_ij"]
    S10["S10: dQ_i: LC-j @ dQ[i]"]
    S11["S11: dV_j: LC-i @ dV[j]"]
    S12["S12: dK_j: LC-i @ dK[j]"]

    %% Intra-iteration dependencies
    S1 --> S6
    S1 --> S12
    S2 --> S9
    S3 --> S8
    S3 --> S9
    S4 --> S6
    S4 --> S10
    S5 --> S8
    S3 --> S11
    S6 --> S7
    S7 --> S9
    S7 --> S11
    S8 --> S9
    S9 --> S10
    S9 --> S12

    %% Force same level for S10, S11, S12
    S10 ~~~ S11
    S11 ~~~ S12

    %% Loop-carried dependencies
    S10 --> S10
    S11 --> S11
    S12 --> S12

    %% Invisible styling
    linkStyle 15 stroke-width:0px;
    linkStyle 16 stroke-width:0px;
    classDef hidden stroke-width:0,color:none,fill:none;

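In this code, S10, S11, and S12 each form their own Strongly Connected Component (SCC), a self-loop through a loop-carried accumulation. Since S10 is carried by loop-j and S11/S12 by loop-i, together they prevent naive parallelization of both the outer loop-i and the inner loop-j.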

  • Here LC-i / LC-j denotes a dependency carried by loop-i / loop-j; a statement with such a dependency cannot be naively parallelized over the loop that carries it.
  • One way to deal with this constraint is to compute everything up to the loop-carried statement in parallel across threads, and then combine the partial results across threads using atomics or warp-shuffle reductions.
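As a sketch of the "compute in parallel, then communicate" idea for S10, the NumPy code below first materializes every per-(i, j) partial dS_ij · k_j, as independent threads would, and then sums over the loop-j axis with a pairwise, log-step reduction that follows the shape of a warp-shuffle reduction. It only emulates the pattern on the CPU; tree_reduce is a hypothetical helper, not a CUDA API.

```python
import numpy as np

def tree_reduce(partials):
    """Pairwise (log-step) sum over axis 0, padded with zeros to a power of two."""
    x = np.asarray(partials, dtype=float)
    n = 1
    while n < x.shape[0]:
        n *= 2
    pad = np.zeros((n - x.shape[0],) + x.shape[1:])
    x = np.concatenate([x, pad], axis=0)
    offset = n // 2
    while offset > 0:
        # "lane" t adds the value held by "lane" t + offset
        x = x[:offset] + x[offset:2 * offset]
        offset //= 2
    return x[0]

rng = np.random.default_rng(2)
M, N, D = 3, 5, 4
dS = rng.standard_normal((M, N))
K = rng.standard_normal((N, D))

# Step 1 (parallel over i and j): each (i, j) "thread" computes its own partial
partials = dS.T[:, :, None] * K[:, None, :]     # [N, M, D]; axis 0 is loop-j
# Step 2 (communication): reduce over loop-j instead of a serial dQ_i += ...
dQ = tree_reduce(partials)
print(np.allclose(dQ, dS @ K))  # True
```

Because the reduction adds the partials in a different order than the serial `dQ_i += ...`, the results agree only up to floating-point rounding, which is why the check uses `np.allclose`.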

Loop interchange analysis

Question: Can we interchange the loops \(i \leftrightarrow j\) to improve the locality of k_j and v_j? Excluding the outputs B_i, dB_i and the statistics m_i & l_i, there are only 3 reads from HBM (q_i, k_j, and v_j), and two of them are indexed by j.

The inner loop-j loads both k_j and v_j from HBM. In the current i → j order, a single load of q_i is reused across N sequential loads of k_j and v_j; if the loops are interchanged, a single load of k_j and v_j is reused across all q_i instead.
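The reuse argument can be made concrete by counting row fetches of q, k, and v under a simplifying assumption: only the outer loop's rows stay resident, and the inner loop's rows are fetched again on every iteration. row_loads is a hypothetical helper for illustration, not a model of any real cache.

```python
def row_loads(M, N, order):
    """Count HBM row fetches of q, k, v for the two loop orders (hypothetical model)."""
    loads = {"q": 0, "k": 0, "v": 0}
    if order == "ij":
        for _ in range(M):
            loads["q"] += 1          # q_i stays resident for the whole j sweep
            for _ in range(N):
                loads["k"] += 1      # k_j and v_j are fetched again for every i
                loads["v"] += 1
    else:  # "ji"
        for _ in range(N):
            loads["k"] += 1          # k_j and v_j stay resident for the whole i sweep
            loads["v"] += 1
            for _ in range(M):
                loads["q"] += 1      # q_i is fetched again for every j
    return loads

print(row_loads(4, 4, "ij"))  # {'q': 4, 'k': 16, 'v': 16}
print(row_loads(4, 4, "ji"))  # {'q': 16, 'k': 4, 'v': 4}
```

Like the text, this ignores B_i, dB_i, m_i, and l_i, which are also indexed by i.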

For a loop interchange to be legal, one important condition is that “no loop iteration dependence vector may become lexicographically negative”.

Example:

for (i=1; i<N; i++) {
  for (j=1; j<N; j++) {
    A[i][j] = A[i-1][j+1]; // RAW dependencies on i and j
  }
}

In this loop, the dependence on A[i-1][j+1] (written in iteration (i-1, j+1) and read in iteration (i, j)) has the direction vector (1, -1). After switching the loops:

for (j=1; j<N; j++) {
  for (i=1; i<N; i++) {
    A[i][j] = ...
  }
}

the direction vector becomes (-1, 1), which is lexicographically negative (its first nonzero entry is negative). The interchange is therefore illegal, because it would change the order in which the dependent iterations execute.

Put simply: in the original loop, A[i-1][j+1] is written before iteration (i, j) reads it. After the interchange, iteration (i, j) runs before iteration (i-1, j+1), because j < j+1 in the new outer loop, so it would read A[i-1][j+1] before that value is updated. Hence the loops cannot be interchanged.
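The legality rule and the example above can be checked mechanically. The plain-Python sketch below (lex_negative, interchange, and run are hypothetical helpers) tests the direction vector before and after the swap, and then executes the example nest in both orders to show that the interchanged version produces a different array.

```python
def lex_negative(vec):
    """True if the first nonzero component of a direction vector is negative."""
    for x in vec:
        if x != 0:
            return x < 0
    return False

def interchange(vec):
    """Direction vector of a 2-deep nest after swapping the two loops."""
    return (vec[1], vec[0])

print(lex_negative((1, -1)), lex_negative(interchange((1, -1))))  # False True

def run(n, order):
    """Execute A[i][j] = A[i-1][j+1] for 1 <= i, j < n in the given loop order."""
    A = [[10 * i + j for j in range(n + 1)] for i in range(n)]
    if order == "ij":
        pairs = [(i, j) for i in range(1, n) for j in range(1, n)]
    else:  # "ji": j is the outer loop
        pairs = [(i, j) for j in range(1, n) for i in range(1, n)]
    for i, j in pairs:
        A[i][j] = A[i - 1][j + 1]
    return A

print(run(4, "ij") == run(4, "ji"))  # False
```

For instance, A[2][1] ends up as 3 in the original order (it reads the already-updated A[1][2]) but as 12 in the interchanged order (it reads A[1][2] before the update).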

Applying the concept to the FlashAttention Backward Pass:

For the original loop order \(i \rightarrow j\), the dependencies (written as (i, j) direction vectors) are:

  1. S7: P_ij RAW on S_ij : loop direction vector (0, 0)
  2. S9: dS_ij RAW on P_ij : loop direction vector (0, 0)
  3. S10: dQ_i RAW on dQ_i from iteration (i, j-1) and on dS_ij : loop direction vectors (0, 1) and (0, 0)
  4. S11: dV_j RAW on dV_j from iteration (i-1, j) and on P_ij : loop direction vectors (1, 0) and (0, 0)
  5. S12: dK_j RAW on dK_j from iteration (i-1, j) and on dS_ij : loop direction vectors (1, 0) and (0, 0)

Interchanging the loops simply swaps the two components of each vector: (0, 1) becomes (1, 0), (1, 0) becomes (0, 1), and (0, 0) is unchanged. None of them becomes lexicographically negative, so the interchange is legal and every dependence is still respected.
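The same check applied to the FlashAttention loop nest might look like this, with the (i, j) direction vectors read off the SLoop/TLoop annotations in the i → j code (deps and swapped are illustrative names):

```python
def lex_negative(vec):
    """True if the first nonzero component of a direction vector is negative."""
    for x in vec:
        if x != 0:
            return x < 0
    return False

# (i, j) direction vectors implied by the SLoop/TLoop annotations of the i -> j code
deps = {
    "S7 P_ij": [(0, 0)],
    "S9 dS_ij": [(0, 0)],
    "S10 dQ_i": [(0, 1), (0, 0)],   # TLoop i(j+1): carried by loop-j
    "S11 dV_j": [(1, 0), (0, 0)],   # TLoop (i+1)j: carried by loop-i
    "S12 dK_j": [(1, 0), (0, 0)],   # TLoop (i+1)j: carried by loop-i
}
# Interchanging i <-> j swaps the two components of every vector
swapped = {s: [(b, a) for a, b in vecs] for s, vecs in deps.items()}
legal = not any(lex_negative(v) for vecs in swapped.values() for v in vecs)
print(legal)  # True
```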

Transformed Code:

for j in range(0, N):
    k_j = K[j]     # S4: No self dependency: [SLoop: j, TLoop: j]
    v_j = V[j]     # S5: No self dependency: [SLoop: j, TLoop: j]
    for i in range(0, M):
        q_i = Q[i]         # S1: No self dependency: [SLoop: i, TLoop: i]
        B_i = B[i]         # S2: No self dependency: [SLoop: i, TLoop: i]
        dB_i = dB[i]       # S3: No self dependency: [SLoop: i, TLoop: i]

        S_ij = q_i @ k_j # S6: S1->S6 RAW q_i: S4->S6 RAW k_j : [SLoop: ij, TLoop: ij]
        P_ij = exp(S_ij - m_i)/l_i # S7: S6->S7 RAW S_ij: [SLoop: ij, TLoop: ij]

        dP_ij = dB_i @ v_j # S8: S3->S8 RAW dB_i: S5->S8 RAW v_j : [SLoop: ij, TLoop: ij]

        dS_ij = P_ij * (dP_ij - dB_i @ B_i) # S9: S7->S9 RAW P_ij: S8->S9 RAW dP_ij: S3->S9 RAW dB_i: S2->S9 RAW B_i: [SLoop: ij, TLoop: ij]

        dQ_i += dS_ij * k_j # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
        dV_j += dB_i * P_ij # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
        dK_j += dS_ij * q_i # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
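A runnable NumPy version of the transformed j → i nest, again assuming the forward-pass statistics m and l are available, shows the practical effect of the interchange: dK_j and dV_j become local accumulators written back once per j, while dQ_i is now the value that every j iteration updates. flash_bwd_ji is a hypothetical name; the output is compared against the matrix form.

```python
import numpy as np

def flash_bwd_ji(Q, K, V, B, dB, m, l):
    """Loop order j -> i; m, l: forward-pass row max and row sum of exp(S - m)."""
    dQ = np.zeros_like(Q)
    dK, dV = np.zeros_like(K), np.zeros_like(V)
    for j in range(K.shape[0]):
        k_j, v_j = K[j], V[j]                        # S4-S5: loaded once per j
        dK_j, dV_j = np.zeros_like(k_j), np.zeros_like(v_j)  # local accumulators
        for i in range(Q.shape[0]):
            q_i, B_i, dB_i = Q[i], B[i], dB[i]       # S1-S3
            S_ij = q_i @ k_j                         # S6
            P_ij = np.exp(S_ij - m[i]) / l[i]        # S7
            dP_ij = dB_i @ v_j                       # S8
            dS_ij = P_ij * (dP_ij - dB_i @ B_i)      # S9
            dQ[i] += dS_ij * k_j                     # S10: shared across all j
            dV_j += dB_i * P_ij                      # S11: private to this j
            dK_j += dS_ij * q_i                      # S12: private to this j
        dK[j], dV[j] = dK_j, dV_j                    # single write-back per j
    return dQ, dK, dV

rng = np.random.default_rng(3)
M, N, D = 5, 8, 4
Q, K, V = (rng.standard_normal((r, D)) for r in (M, N, N))
dB = rng.standard_normal((M, D))
S = Q @ K.T
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]
B = P @ V

dQ, dK, dV = flash_bwd_ji(Q, K, V, B, dB, m, l)
dS = P * (dB @ V.T - (dB * B).sum(axis=1, keepdims=True))
print(np.allclose(dQ, dS @ K), np.allclose(dK, dS.T @ Q), np.allclose(dV, P.T @ dB))
```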

Parallelization Analysis:

There are three statements in this code that form SCCs:

  • S10 has a RAW dependency on dQ_i that is carried by loop-j (LC-j), so it cannot be parallelized over loop-j; since it has no dependency carried by loop-i, it can be parallelized over loop-i.
  • S11 has a RAW dependency on dV_j that is carried by loop-i (LC-i), so it cannot be parallelized over loop-i; since it has no dependency carried by loop-j, it can be parallelized over loop-j.
  • S12 (on dK_j) behaves like S11, i.e. it is parallelizable over loop-j.

The remaining statements, S1-S9, have no loop-carried dependencies, so they can be parallelized over both loop-i and loop-j.

Ways to parallelize the code [WIP]:

  1. Loop order \(j \rightarrow i\): parallelize loop-j, run loop-i up to statement S9 in parallel, synchronize all threads, and then aggregate dQ_i, dK_j, and dV_j across threads.
  2. Loop order \(i \rightarrow j\): parallelize loop-i, run loop-j up to statement S9 in parallel, and then aggregate dQ_i, dK_j, and dV_j across threads.
  3. Split the computation by the type of loop-carried dependency (LC-i vs. LC-j). This gives two micro-kernels: one for dQ_i and another for dV_j and dK_j.

For the 3rd method, the two micro-kernels would look like this, followed by a fused variant that performs all three reductions in a single kernel:

micro-kernel #1: all communication is confined to the inner loop-j
# Parallelize over `loop-i`
# Loop collapsed under the parallelization
# for i in range(0, M):
q_i = Q[i]         # S1: No self dependency: [SLoop: i, TLoop: i]
B_i = B[i]         # S2: No self dependency: [SLoop: i, TLoop: i]
dB_i = dB[i]       # S3: No self dependency: [SLoop: i, TLoop: i]
# Parallelize over `loop-j`
# Loop collapsed under the parallelization
# for j in range(0, N):
k_j = K[j]     # S4: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j]     # S5: No self dependency: [SLoop: j, TLoop: j]

#---
#S6 -> S9
#---

# Create a thread local variable to store the per thread result
dQ_ij = dS_ij * k_j

# Parallel till now

# Sync Thread
# Across the thread accumulation
dQ_i += dQ_ij # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
micro-kernel #2: all communication is confined to the inner loop-i
# Parallelize over `loop-j`
# Loop collapsed under the parallelization
# for j in range(0, N):
k_j = K[j]     # S4: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j]     # S5: No self dependency: [SLoop: j, TLoop: j]
# Parallelize over `loop-i`
# Loop collapsed under the parallelization
# for i in range(0, M):
q_i = Q[i]         # S1: No self dependency: [SLoop: i, TLoop: i]
B_i = B[i]         # S2: No self dependency: [SLoop: i, TLoop: i]
dB_i = dB[i]       # S3: No self dependency: [SLoop: i, TLoop: i]

#---
#S6 -> S9
#---

# Create a thread local variable to store the per thread result
dV_ij = dB_i * P_ij
dK_ij = dS_ij * q_i

# Parallel till now

# Sync Thread
# Across the thread accumulation
dV_j += dV_ij # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
dK_j += dK_ij # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
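The split can be emulated in NumPy to confirm that the two micro-kernels together reproduce all three gradients. In this sketch (micro_kernel_1 and micro_kernel_2 are hypothetical names), the Python loop stands in for the parallel grid, each vectorized line stands in for the per-thread work of the collapsed inner loop, and `.sum(axis=0)` stands in for the cross-thread reduction. Note that both kernels recompute S6-S9.

```python
import numpy as np

def micro_kernel_1(Q, K, V, B, dB, m, l):
    """dQ only: one 'block' per i, reduction over loop-j (LC-j)."""
    dQ = np.zeros_like(Q)
    for i in range(Q.shape[0]):                      # independent across i
        P_i = np.exp(K @ Q[i] - m[i]) / l[i]         # S6-S7 for all j of row i
        dS_i = P_i * (V @ dB[i] - dB[i] @ B[i])      # S8-S9 for all j of row i
        dQ_ij = dS_i[:, None] * K                    # thread-local partials [N x D]
        dQ[i] = dQ_ij.sum(axis=0)                    # cross-thread reduction over j
    return dQ

def micro_kernel_2(Q, K, V, B, dB, m, l):
    """dK and dV only: one 'block' per j, reduction over loop-i (LC-i)."""
    dK, dV = np.zeros_like(K), np.zeros_like(V)
    dB_dot_B = (dB * B).sum(axis=1)                  # dB_i . B_i for every i
    for j in range(K.shape[0]):                      # independent across j
        P_j = np.exp(Q @ K[j] - m) / l               # S6-S7 for all i of column j
        dS_j = P_j * (dB @ V[j] - dB_dot_B)          # S8-S9 for all i of column j
        dV[j] = (P_j[:, None] * dB).sum(axis=0)      # reduction over i
        dK[j] = (dS_j[:, None] * Q).sum(axis=0)      # reduction over i
    return dK, dV

rng = np.random.default_rng(4)
M, N, D = 6, 5, 3
Q, K, V = (rng.standard_normal((r, D)) for r in (M, N, N))
dB = rng.standard_normal((M, D))
S = Q @ K.T
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]
B = P @ V

dQ = micro_kernel_1(Q, K, V, B, dB, m, l)
dK, dV = micro_kernel_2(Q, K, V, B, dB, m, l)
dS = P * (dB @ V.T - (dB * B).sum(axis=1, keepdims=True))
print(np.allclose(dQ, dS @ K), np.allclose(dK, dS.T @ Q), np.allclose(dV, P.T @ dB))
```

The two kernels share only read-only inputs, so nothing in this emulation forces one to run before the other.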
micro-kernel #3: Fused i-j-loop parallelism
# Parallelize over `loop-j`
# Loop collapsed under the parallelization
# for j in range(0, N):
k_j = K[j]     # S4: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j]     # S5: No self dependency: [SLoop: j, TLoop: j]
# Parallelize over `loop-i`
# Loop collapsed under the parallelization
# for i in range(0, M):
q_i = Q[i]         # S1: No self dependency: [SLoop: i, TLoop: i]
B_i = B[i]         # S2: No self dependency: [SLoop: i, TLoop: i]
dB_i = dB[i]       # S3: No self dependency: [SLoop: i, TLoop: i]

#---
#S6 -> S9
#---

# Create a thread local variable to store the per thread result
dQ_ij = dS_ij * k_j
dV_ij = dB_i * P_ij
dK_ij = dS_ij * q_i

# Parallel till now

# Sync Thread
# Across the thread accumulation
dQ_i += dQ_ij # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
dV_j += dV_ij # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
dK_j += dK_ij # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
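For the fused variant, every (i, j) pair acts as a thread and all three reductions happen through scatter-adds. The sketch below emulates that with `np.add.at`, which performs unbuffered in-place addition so that repeated indices all accumulate, similar in effect to `atomicAdd` on a GPU. Threads are visited in a shuffled order to mimic nondeterministic scheduling; the emulation does not model contention or timing, and fused_kernel is a hypothetical name.

```python
import numpy as np

def fused_kernel(Q, K, V, B, dB, m, l, seed=0):
    """Every (i, j) pair is one 'thread'; np.add.at emulates atomic accumulation."""
    M, N = Q.shape[0], K.shape[0]
    ii, jj = np.divmod(np.arange(M * N), N)          # flatten the (i, j) grid
    order = np.random.default_rng(seed).permutation(M * N)
    ii, jj = ii[order], jj[order]                    # arbitrary thread order
    # S6-S9: no loop-carried dependency, fully parallel
    S = (Q[ii] * K[jj]).sum(axis=1)
    P = np.exp(S - m[ii]) / l[ii]
    dP = (dB[ii] * V[jj]).sum(axis=1)
    dS = P * (dP - (dB[ii] * B[ii]).sum(axis=1))
    # S10-S12: scatter-add the thread-local partials into the outputs
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    np.add.at(dQ, ii, dS[:, None] * K[jj])
    np.add.at(dV, jj, P[:, None] * dB[ii])
    np.add.at(dK, jj, dS[:, None] * Q[ii])
    return dQ, dK, dV

rng = np.random.default_rng(5)
M, N, D = 4, 6, 3
Q, K, V = (rng.standard_normal((r, D)) for r in (M, N, N))
dB = rng.standard_normal((M, D))
S_ref = Q @ K.T
m = S_ref.max(axis=1)
l = np.exp(S_ref - m[:, None]).sum(axis=1)
P_ref = np.exp(S_ref - m[:, None]) / l[:, None]
B = P_ref @ V

dQ, dK, dV = fused_kernel(Q, K, V, B, dB, m, l)
dS_ref = P_ref * (dB @ V.T - (dB * B).sum(axis=1, keepdims=True))
print(np.allclose(dQ, dS_ref @ K), np.allclose(dK, dS_ref.T @ Q), np.allclose(dV, P_ref.T @ dB))
```

Changing `seed` changes only the accumulation order, so results from different seeds agree up to floating-point rounding.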

Note: These micro-kernels could be merged, in either order, to form parallelization methods #1 and #2, but that puts additional pressure on synchronization, because both loops now have to stay in sync, which means M+N threads participate in the synchronization. By splitting the kernel, only N threads have to synchronize in micro-kernel #1 and only M in micro-kernel #2, and the two kernels can run in parallel, which reduces the pressure on the synchronization system.
