Data Dependency Graph (DDG) of the backward-pass statements (discussed in the DDG section below):

```mermaid
graph TD
    S1["S1: Q[i]"]
    S2["S2: B[i]"]
    S3["S3: dB[i]"]
    S4["S4: K[j]"]
    S5["S5: V[j]"]
    S6["S6: S_ij"]
    S7["S7: P_ij"]
    S8["S8: dP_ij"]
    S9["S9: dS_ij"]
    S10["S10: dQ_i: LC-j @ dQ[i]"]
    S11["S11: dV_j: LC-i @ dV[j]"]
    S12["S12: dK_j: LC-i @ dK[j]"]

    %% Intra-iteration dependencies
    S1 --> S6
    S1 --> S12
    S2 --> S9
    S3 --> S8
    S3 --> S9
    S4 --> S6
    S4 --> S10
    S5 --> S8
    S5 --> S11
    S6 --> S7
    S7 --> S9
    S7 --> S11
    S8 --> S9
    S9 --> S10
    S9 --> S12

    %% Force same level for S10, S11, S12
    S10 ~~~ S11
    S11 ~~~ S12

    %% Loop-carried dependencies
    S10 --> S10
    S11 --> S11
    S12 --> S12

    %% Invisible styling
    linkStyle 15 stroke-width:0px;
    linkStyle 16 stroke-width:0px;
    classDef hidden stroke-width:0,color:none,fill:none;
```
FlashAttention Kernel: Backward Pass (Parallelism)
Continuing from my previous blogs, FlashAttention Kernel: Backward Pass (MATH) and FlashAttention Kernel: Forward Pass (Parallelism), here we will explore the possibilities for parallelism in the backward-pass kernel.
Note: Most of the conceptual details used here were discussed in FlashAttention Kernel: Forward Pass (Parallelism), so I urge you to kindly read that one first.
Flash Attention Backward Pass:
In my previous blog we saw how the math works in the FlashAttention backward pass, and these were the final expressions we derived there:
\[\begin{align} dB \in \mathbb{R}^{[M \times D]}, \{Q, dQ\} \in \mathbb{R}^{[M \times D]}, \\ \{K, dK\} \in \mathbb{R}^{[N \times D]}, \{V, dV\} \in \mathbb{R}^{[N \times D]}\\ \end{align}\]
\[\begin{align} dV = P^T \cdot dB \\ dP = dB \cdot V^T \\ dS_{i'j'} = P_{i'j'} \left[dP_{i'j'} - dP_{i':}^T \circ P_{i':} \right] \\ dQ = dS \cdot K \\ dK = dS^T \cdot Q \end{align}\]
After a few manipulations, these yield the simpler per-row / per-column form:
\[\begin{align} S_{ij} = q_i \circ k_j \\ dV_j = \sum_i dB_{i} \frac{\exp(q_i \circ k_j)}{L_i} \\ dP_{ij} = dB_i \circ V_j \\ dS_{ij} = P_{ij} \left[dP_{ij} - dB_i \circ B_i \right] \\ dQ_i = \sum_j dS_{ij} \, K_j \\ dK_j = \sum_i dS_{ij} \, Q_i \\ \end{align}\]
Note: Here `B` is a slight abuse of notation; you should think of it as the output `O` from the previous blogs.
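Before moving on, here is a minimal NumPy sketch (not from the original derivation) of the matrix-form gradients above; the sizes, random inputs, and the recomputation of `P` from scratch are assumptions made purely for illustration:

```python
import numpy as np

# Hypothetical sizes, for illustration only
M, N, D = 8, 16, 4
rng = np.random.default_rng(0)

Q  = rng.standard_normal((M, D))
K  = rng.standard_normal((N, D))
V  = rng.standard_normal((N, D))
dB = rng.standard_normal((M, D))              # upstream gradient of the output B (a.k.a. O)

S = Q @ K.T                                   # [M, N] scores
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)             # [M, N] softmax probabilities

dV = P.T @ dB                                 # [N, D]
dP = dB @ V.T                                 # [M, N]
# Row-wise softmax backward: dS_ij = P_ij * (dP_ij - sum_j' dP_ij' * P_ij')
dS = P * (dP - np.sum(dP * P, axis=1, keepdims=True))
dQ = dS @ K                                   # [M, D]
dK = dS.T @ Q                                 # [N, D]
```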
Parallelization Analysis: Backward Pass
Using the math expressions above for the FlashAttention backward pass, we can derive the following (partly pseudo) code:
```python
for i in range(0, M):
    q_i  = Q[i]       # S1: No self dependency: [SLoop: i, TLoop: i]
    B_i  = B[i]       # S2: No self dependency: [SLoop: i, TLoop: i]
    dB_i = dB[i]      # S3: No self dependency: [SLoop: i, TLoop: i]
    for j in range(0, N):
        k_j = K[j]    # S4: No self dependency: [SLoop: j, TLoop: j]
        v_j = V[j]    # S5: No self dependency: [SLoop: j, TLoop: j]

        S_ij = q_i @ k_j            # S6: S1->S6 RAW q_i: S4->S6 RAW k_j : [SLoop: ij, TLoop: ij]
        P_ij = exp(S_ij - m_i)/l_i  # S7: S6->S7 RAW S_ij: [SLoop: ij, TLoop: ij]

        dP_ij = dB_i @ v_j          # S8: S3->S8 RAW dB_i: S5->S8 RAW v_j : [SLoop: ij, TLoop: ij]

        dS_ij = P_ij * (dP_ij - dB_i @ B_i)  # S9: S7->S9 RAW P_ij: S8->S9 RAW dP_ij: S3->S9 RAW dB_i: S2->S9 RAW B_i: [SLoop: ij, TLoop: ij]

        dQ_i += dS_ij * k_j   # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
        dV_j += dB_i * P_ij   # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
        dK_j += dS_ij * q_i   # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
```
The code is annotated with comments similar to those discussed previously in the forward-pass blog.
Data Dependency Graph (DDG)
In this code, each of `S10`, `S11`, and `S12` individually forms a Strongly Connected Component (SCC); together these SCCs prevent naive parallelization of both the outer `i` loop and the inner `j` loop.
- Here `LC-<i/j>` denotes a loop-carried dependency: the statement cannot be parallelized over the loop that carries the dependency.
- One way to deal with such a constraint is to perform the computation up to the `LC` statement in parallel across threads, and then use mechanisms like `atomics` or `warp-shuffle` to communicate across threads, as sketched below.
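A rough CPU-side analogue of that pattern for the `dV` accumulation: each chunk of the `i` range plays the role of a thread that builds a thread-local partial, and the final `sum` stands in for the atomic / warp-shuffle reduction a GPU kernel would use. The chunking scheme and the placeholder values for `P` and `dB` are assumptions for illustration only.

```python
import numpy as np

M, N, D, n_threads = 32, 16, 8, 4
rng = np.random.default_rng(1)
P  = rng.random((M, N))            # softmax probabilities (placeholder values)
dB = rng.standard_normal((M, D))   # upstream output gradient (placeholder values)

# Each "thread" owns a chunk of the i range and builds a local partial dV,
# so no thread touches the shared accumulator while computing.
partials = []
for i_chunk in np.array_split(np.arange(M), n_threads):
    dV_local = P[i_chunk].T @ dB[i_chunk]   # [N, D] thread-local partial
    partials.append(dV_local)

# The cross-thread communication step (atomics / warp-shuffle on a GPU)
# reduces the partials into the final dV.
dV = np.sum(partials, axis=0)

assert np.allclose(dV, P.T @ dB)
```

On a GPU, the same structure maps to per-thread registers for the partials and a warp-level shuffle or atomic add for the final reduction.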
Loop interchange analysis
Question: can we interchange the loops, \(i \leftrightarrow j\), to improve the locality of `k_j` and `v_j`? Per iteration there are only three reads from HBM (`q_i`, `k_j`, and `v_j`), excluding the outputs and `m_i` & `l_i`.
The inner `loop-j` is responsible for loading both `k_j` and `v_j` from HBM. If the loops are interchanged, a single load of `k_j` and `v_j` can be reused for all `q_i`; the other way around, a single load of `q_i` is reused across the sequential loads of all `k_j` and `v_j`.
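To make the traffic argument concrete, here is a toy count of HBM row loads under the two loop orders, hoisting whichever operand depends only on the outer index; the sizes are arbitrary and the count ignores `m_i`, `l_i`, and output traffic, as above.

```python
M, N = 1024, 2048  # arbitrary example sizes

# Loop order i -> j: q_i is hoisted out of the inner loop,
# but k_j and v_j are re-loaded for every i.
loads_ij = M * 1 + M * N * 2          # Q rows + (K rows + V rows)

# Loop order j -> i: k_j and v_j are hoisted and loaded once per j,
# while q_i is re-loaded for every j.
loads_ji = N * 2 + M * N * 1          # (K rows + V rows) + Q rows

print(loads_ij, loads_ji)             # 4195328 vs 2101248 row loads
```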
For loop interchange, one important condition is that "the loop-iteration dependence vectors must not become lexicographically negative".
Example:
```c
for (i = 1; i < N; i++) {
    for (j = 1; j < N; j++) {
        A[i][j] = A[i-1][j+1];   // RAW dependencies on i and j
    }
}
```
In this loop, the dependence carried by `A[i-1][j+1]` has iteration direction vector `(1, -1)`. After switching the loops:
```c
for (j = 1; j < N; j++) {
    for (i = 1; i < N; i++) {
        A[i][j] = A[i-1][j+1];
    }
}
```
the vector becomes `(-1, 1)`, which is lexicographically negative; the interchange is therefore not allowed, because it would reverse the execution order of the dependent iterations.
Put simply: in the original order, `A[i-1][j+1]` is written before `A[i][j]` reads it. After the reorder, the iteration computing `A[i][j]` would run before the one that updates `A[i-1][j+1]`, so we cannot interchange the loops.
Applying the concept to the FlashAttention Backward Pass:
In the original loop order, i.e. \(i \rightarrow j\), let's list the dependencies (vectors are written as (i, j) distances, matching the `SLoop`/`TLoop` annotations in the code):

- `S7`: `P_ij` RAW on `S_ij`: loop direction vector `(0, 0)`
- `S9`: `dS_ij` RAW on `P_ij`: loop direction vector `(0, 0)`
- `S10`: `dQ_i` RAW on `dQ_i` from the previous `j` iteration and on `dS_ij`: loop direction vectors `(0, 1)` and `(0, 0)`
- `S11`: `dV_j` RAW on `dV_j` from the previous `i` iteration and on `P_ij`: loop direction vectors `(1, 0)` and `(0, 0)`
- `S12`: `dK_j` RAW on `dK_j` from the previous `i` iteration: loop direction vector `(1, 0)`
Now, if we change the loop order, does any of these vectors become lexicographically negative? The answer is no: after swapping the components, none of the loop direction vectors has a negative leading element, so the loops can be interchanged and the dependence order is preserved.
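A tiny helper makes this check mechanical; the dependence vectors are the ones listed above, and swapping the components models the interchange.

```python
def lex_nonnegative(vec):
    """A dependence vector is legal if its first nonzero component is positive."""
    for component in vec:
        if component > 0:
            return True
        if component < 0:
            return False
    return True  # the all-zero vector (loop-independent dependence)

deps = {"S7": (0, 0), "S9": (0, 0), "S10": (0, 1), "S11": (1, 0), "S12": (1, 0)}
for stmt, (di, dj) in deps.items():
    swapped = (dj, di)  # dependence vector after interchanging loop-i and loop-j
    print(stmt, swapped, lex_nonnegative(swapped))  # all True -> interchange is legal
```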
Transformed Code:
```python
for j in range(0, N):
    k_j = K[j]        # S4: No self dependency: [SLoop: j, TLoop: j]
    v_j = V[j]        # S5: No self dependency: [SLoop: j, TLoop: j]
    for i in range(0, M):
        q_i  = Q[i]   # S1: No self dependency: [SLoop: i, TLoop: i]
        B_i  = B[i]   # S2: No self dependency: [SLoop: i, TLoop: i]
        dB_i = dB[i]  # S3: No self dependency: [SLoop: i, TLoop: i]

        S_ij = q_i @ k_j            # S6: S1->S6 RAW q_i: S4->S6 RAW k_j : [SLoop: ij, TLoop: ij]
        P_ij = exp(S_ij - m_i)/l_i  # S7: S6->S7 RAW S_ij: [SLoop: ij, TLoop: ij]

        dP_ij = dB_i @ v_j          # S8: S3->S8 RAW dB_i: S5->S8 RAW v_j : [SLoop: ij, TLoop: ij]

        dS_ij = P_ij * (dP_ij - dB_i @ B_i)  # S9: S7->S9 RAW P_ij: S8->S9 RAW dP_ij: S3->S9 RAW dB_i: S2->S9 RAW B_i: [SLoop: ij, TLoop: ij]

        dQ_i += dS_ij * k_j   # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
        dV_j += dB_i * P_ij   # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
        dK_j += dS_ij * q_i   # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
```
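As an additional sanity check (not part of the original post), a direct NumPy simulation of the two loop orders should produce identical gradients; `m`, `l`, and `B` are recomputed from scratch here for the toy inputs.

```python
import numpy as np

M, N, D = 6, 5, 4
rng = np.random.default_rng(2)
Q, K, V = (rng.standard_normal(s) for s in [(M, D), (N, D), (N, D)])
dB = rng.standard_normal((M, D))

# Forward-pass statistics and output, recomputed for this toy check
S = Q @ K.T
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
P = np.exp(S - m[:, None]) / l[:, None]
B = P @ V

def backward(order):
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    if order == "ij":
        pairs = [(i, j) for i in range(M) for j in range(N)]
    else:
        pairs = [(i, j) for j in range(N) for i in range(M)]
    for i, j in pairs:
        S_ij = Q[i] @ K[j]
        P_ij = np.exp(S_ij - m[i]) / l[i]
        dP_ij = dB[i] @ V[j]
        dS_ij = P_ij * (dP_ij - dB[i] @ B[i])
        dQ[i] += dS_ij * K[j]
        dV[j] += dB[i] * P_ij
        dK[j] += dS_ij * Q[i]
    return dQ, dK, dV

# Both loop orders accumulate the same terms, just in a different order
assert all(np.allclose(a, b) for a, b in zip(backward("ij"), backward("ji")))
```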
Parallelization Analysis:
There are three statements in this code that form SCCs:
- Statement `S10` has a RAW dependence on `dQ_i` which is `LC-j`, so it can't be parallelized over `loop-j`; but since it does not depend on any of the SCCs carrying an `LC-i` dependency, it can be parallelized over `loop-i`.
- `S11` has a RAW dependence on `dV_j` which is `LC-i`, so it can't be parallelized over `loop-i`; but because it does not depend on any of the SCCs carrying an `LC-j` dependency, it can be parallelized over `loop-j`.
- `S12` is the same as `S11`, i.e. parallelizable over `loop-j`.
The other statements `S1`-`S9` do not carry any loop-carried dependencies, so it is possible to parallelize `S1`-`S9` over both `loop-i` & `loop-j` (see the sketch below).
Ways to parallelize the code [WIP]:
- Loop order \(j \rightarrow i\): parallelize `loop-j` and run `loop-i` up to statement `S9` in parallel, then sync all the threads, followed by an across-thread aggregation for `dQ_i`, `dK_j`, and `dV_j`.
- Loop order \(i \rightarrow j\): parallelize `loop-i` and run `loop-j` up to statement `S9` in parallel, followed by an across-thread aggregation for `dQ_i`, `dK_j`, and `dV_j`.
- Split the computation for the two types of loop-carried dependencies, i.e. `LC-i` & `LC-j`. This means there would be two `micro-kernels`: one for `dQ_i` and the other for `dV_j` and `dK_j`.
For the third method, the two micro-kernels would look like this:
micro-kernel #1: all the communication is constrained to the inner `loop-j`
```python
# Parallelize over `loop-i`
# Loop collapsed under the parallelization
# for i in range(0, M):
q_i  = Q[i]       # S1: No self dependency: [SLoop: i, TLoop: i]
B_i  = B[i]       # S2: No self dependency: [SLoop: i, TLoop: i]
dB_i = dB[i]      # S3: No self dependency: [SLoop: i, TLoop: i]

# Parallelize over `loop-j`
# Loop collapsed under the parallelization
# for j in range(0, N):
k_j = K[j]        # S4: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j]        # S5: No self dependency: [SLoop: j, TLoop: j]

# ---
# S6 -> S9
# ---

# Create a thread-local variable to store the per-thread result
dQ_ij = dS_ij * k_j

# Parallel till now
# Sync threads
# Across-thread accumulation
dQ_i += dQ_ij     # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
```
micro-kernel #2: all the communication is constrained to the inner `loop-i`
```python
# Parallelize over `loop-j`
# Loop collapsed under the parallelization
# for j in range(0, N):
k_j = K[j]        # S4: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j]        # S5: No self dependency: [SLoop: j, TLoop: j]

# Parallelize over `loop-i`
# Loop collapsed under the parallelization
# for i in range(0, M):
q_i  = Q[i]       # S1: No self dependency: [SLoop: i, TLoop: i]
B_i  = B[i]       # S2: No self dependency: [SLoop: i, TLoop: i]
dB_i = dB[i]      # S3: No self dependency: [SLoop: i, TLoop: i]

# ---
# S6 -> S9
# ---

# Create thread-local variables to store the per-thread results
dV_ij = dB_i * P_ij
dK_ij = dS_ij * q_i

# Parallel till now
# Sync threads
# Across-thread accumulation
dV_j += dV_ij     # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
dK_j += dK_ij     # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
```
micro-kernel #3: fused `i`-`j`-loop parallelism
```python
# Parallelize over `loop-j`
# Loop collapsed under the parallelization
# for j in range(0, N):
k_j = K[j]        # S4: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j]        # S5: No self dependency: [SLoop: j, TLoop: j]

# Parallelize over `loop-i`
# Loop collapsed under the parallelization
# for i in range(0, M):
q_i  = Q[i]       # S1: No self dependency: [SLoop: i, TLoop: i]
B_i  = B[i]       # S2: No self dependency: [SLoop: i, TLoop: i]
dB_i = dB[i]      # S3: No self dependency: [SLoop: i, TLoop: i]

# ---
# S6 -> S9
# ---

# Create thread-local variables to store the per-thread results
dQ_ij = dS_ij * k_j
dV_ij = dB_i * P_ij
dK_ij = dS_ij * q_i

# Parallel till now
# Sync threads
# Across-thread accumulation
dQ_i += dQ_ij     # S10: S9->S10 RAW dS_ij: S4->S10 RAW k_j: [SLoop: ij, TLoop: i(j+1)]
dV_j += dV_ij     # S11: S3->S11 RAW dB_i: S7->S11 RAW P_ij: [SLoop: ij, TLoop: (i+1)j]
dK_j += dK_ij     # S12: S9->S12 RAW dS_ij: S1->S12 RAW q_i: [SLoop: ij, TLoop: (i+1)j]
```
Note: These micro-kernels could be merged in either order to recover parallelization methods #1 and #2, but that puts additional pressure on synchronization, because both loops then have to be kept in sync, which means the total number of threads participating in the sync is `M+N`. By splitting the kernel, the number of threads that have to be in sync is `N` for `micro-kernel #1` and `M` for `micro-kernel #2`, and the two can run in parallel, thus reducing the pressure on the synchronization system.