FlashAttention Kernel: Forward Pass (Parallelism)

Data Dependency Graph (DDG) for the annotated forward-pass code (derived and discussed below):

```mermaid
graph TD
    S1["S1: Q[i]"]
    S2["S2: K[j]"]
    S3["S3: m[i]"]
    S4["S4: l[i]"]
    S5["S5: S_ij = f(q_i, k_j)"]
    S6["S6: m_ij = f(m_i, S_ij)"]
    S7["S7: l_ij = f(l_i, m_i, m_ij, S_ij, m_ij)"]
    S8["S8: O[i, d]"]
    S9["S9: V[j, d]"]
    S10["S10: o_ijd = f(o_id, l_i, l_ij, S_ij, m_ij, v_jd, l_ij)"]
    S11["S11: O[i, d] = o_ijd"]
    S12["S12: m[i] = m_ij"]
    S13["S13: l[i] = l_ij"]

    %% Intra-iteration dependencies
    S1 --> S5
    S2 --> S5
    S5 --> S6
    S3 --> S7
    S5 --> S7
    S6 --> S7
    S4 --> S10
    S5 --> S10
    S6 --> S10
    S7 --> S10
    S9 --> S10

    %% Force same level for {S11, S12, S13}
    %% S11 ~~~ S12
    %% S12 ~~~ S13

    %% Strongly Connected Components (SCCs)
    subgraph SCC1
        S3 --> S6
        S6 --> S12
        S12 -.->|LC-j| S3
    end
    subgraph SCC2
        S4 --> S7
        S7 --> S13
        S13 -.->|LC-j| S4
    end
    subgraph SCC3
        S8 --> S10
        S10 --> S11
        S11 -.->|LC-j| S8
    end
```
Continuing from my previous blog, FlashAttention Kernel: Forward Pass (MATH), here we explore the possibilities for parallelism in the forward-pass kernel through step-by-step code transformations, finally reaching a form that is much closer to the CUDA programming model.
Flash Attention Forward Pass:
In my last blog we saw how the math works in Flash Attention, and these were the final expressions we derived there:
\[\begin{align} m_0 &= S^{[m, n=0]} \\ m_{j+1} &= \max(m_j, S^{[m, n=j+1]}) \\ l_0 &= \exp(S^{[m, n=0]} - m_0) \\ l_{j+1} &= \sum_{n\in[0 \dots j+1]}\exp(S^{[m, n]} - m_{j+1}) \\ l_{j+1} &= \frac{\exp(-m_{j})}{\exp(m_{j+1} - m_{j})}\sum_{n\in[0 \dots j]}\exp(S^{[m, n]}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ l_{j+1} &= l_{j}\exp(m_{j} - m_{j+1}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ O^{[m, d]}_{0} &= \frac{\exp(S^{[m, n=0]} - m_{0}) \cdot V^{[n=0, d]}}{l_{0}} \\ O^{[m, d]}_{j+1} &= \frac{O^{[m, d]}_{j} \cdot l_j}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]} - m_{j+1}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \end{align}\]
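Before analyzing parallelism, here is a quick sanity check of the running-max and running-normalizer recurrences: a minimal NumPy sketch of my own (not part of the original derivation) comparing the online update against a direct computation over one row of scores.

```python
import numpy as np

rng = np.random.default_rng(0)
S = rng.normal(size=16)                 # scores S^[m, n] for one fixed row m

# Online recurrences: m_{j+1} = max(m_j, S_{j+1})
#                     l_{j+1} = l_j * exp(m_j - m_{j+1}) + exp(S_{j+1} - m_{j+1})
m_run, l_run = S[0], 1.0                # base case: l_0 = exp(S_0 - m_0) = 1
for s in S[1:]:
    m_new = max(m_run, s)
    l_run = l_run * np.exp(m_run - m_new) + np.exp(s - m_new)
    m_run = m_new

# Direct computation using the final max
m_ref = S.max()
l_ref = np.exp(S - m_ref).sum()
assert np.isclose(m_run, m_ref) and np.isclose(l_run, l_ref)
```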
Parallelization Analysis: Forward Pass
Using the math expressions above for the Flash Attention forward pass, we can derive the following code (partly pseudocode):
```python
for i in range(0, M):
    for j in range(0, N):
        for d in range(0, D):
            q_i = Q[i]   # S1: No self dependency: [SLoop: i, TLoop: i]
            k_j = K[j]   # S2: No self dependency: [SLoop: j, TLoop: j]

            m_i = m[i]   # S3: No self dependency: [SLoop: i, TLoop: i]
            l_i = l[i]   # S4: No self dependency: [SLoop: i, TLoop: i]

            S_ij = q_i @ k_j   # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i and k_j

            m_ij = max(m_i, S_ij)   # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
            l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)   # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij

            o_id = O[i, d]   # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
            v_jd = V[j, d]   # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

            o_ijd = o_id * l_i / l_ij + exp(S_ij - m_ij) * v_jd / l_ij   # S10: No self dependency: [SLoop: ((i, d), i, (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, l_ij, S_ij, m_ij, v_jd, l_ij

            ### Finally assign the results back to buffers
            O[i, d] = o_ijd   # S11: Aggregation over j # RAW on o_ijd
            m[i]    = m_ij    # S12: Aggregation over j # RAW on m_ij
            l[i]    = l_ij    # S13: Aggregation over j # RAW on l_ij
```
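One detail worth calling out: the loop nest reads `O[i, d]`, `m[i]`, and `l[i]` before ever writing them, so those running buffers need initial values. The post does not spell them out, but an initialization consistent with the `j = 0` base case of the recurrences would be the following sketch (my assumption, with illustrative sizes):

```python
import numpy as np

M, N, D = 128, 128, 64        # illustrative sizes (not from the post)
Q = np.random.randn(M, D)
K = np.random.randn(N, D)
V = np.random.randn(N, D)

O = np.zeros((M, D))          # running output accumulator
m = np.full(M, -np.inf)       # running row-wise max, so max(m_i, S_ij) = S_ij at j = 0
l = np.zeros(M)               # running normalizer, so l_ij = exp(S_ij - m_ij) at j = 0
```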
The code is annotated with comments that follow a simple information template:

- The comment starts with the letter `S` followed by a number, e.g. `S2`; this indicates a statement along with its number, so `S2` stands for statement 2.
- After the statement number, the comment states whether the statement depends on itself, i.e. whether the variable updated in the statement is updated again in a future iteration.
- The comment then uses a simple notation for loop iteration order, which consists of two parts: Source Loop (`SLoop`) and Target Loop (`TLoop`). Where a statement uses multiple variables, this entry becomes a list of lists, with one entry per variable in order of use. For each variable, the iteration state is denoted either by a single loop variable, i.e. `i`, `j`, or `d` in this particular case, or, if the variable depends on multiple loops, by a tuple of those variables, e.g. `(i, d)`.
- Following this, the comment lists the type of dependency (one of `WAW`, `RAW`, and `WAR`) and the elements to which that dependency applies.
- Sometimes a comment says something like `Aggregation over j`, which means that a reduction is being performed over that loop variable. Although this is generally an in-place operation, for better segregation of dependencies, updates, and reuse of a variable, it can be done by computing a source-iteration-dependent local variable, e.g. `o_ijd`, and then assigning it to a variable that is independent of the target loop, e.g. `O[i, d] = o_ijd`.
Data Dependency Graph (DDG)
This graph, shown at the top of this post, is derived from the annotated code above and represents the data flow and dependencies across the different statements.
Here we have mentioned `SCC` several times, which stands for Strongly Connected Component: a sub-graph in which every node is reachable from every other node within that sub-graph. In such a case, parallelization is not possible.
Note: SCCs generally appear together with a Loop-Carried Dependency, denoted as `LC-i`/`LC-j` in the graph to indicate over which loop the dependency occurs. E.g., if it is an `LC-i`, the SCC cannot be parallelized over loop `i`.
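To make the SCC claim easy to check, here is a small sketch of my own (assuming `networkx` is available) that rebuilds the DDG above, including the `LC-j` back-edges, and recovers the three non-trivial SCCs.

```python
import networkx as nx

G = nx.DiGraph()
G.add_edges_from([
    # intra-iteration dependencies
    ("S1", "S5"), ("S2", "S5"), ("S5", "S6"),
    ("S3", "S7"), ("S5", "S7"), ("S6", "S7"),
    ("S4", "S10"), ("S5", "S10"), ("S6", "S10"), ("S7", "S10"), ("S9", "S10"),
    # SCC edges, including the loop-carried (LC-j) back-edges
    ("S3", "S6"), ("S6", "S12"), ("S12", "S3"),
    ("S4", "S7"), ("S7", "S13"), ("S13", "S4"),
    ("S8", "S10"), ("S10", "S11"), ("S11", "S8"),
])

sccs = [c for c in nx.strongly_connected_components(G) if len(c) > 1]
print(sccs)   # {S3, S6, S12}, {S4, S7, S13}, {S8, S10, S11} (in some order)
```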
Deductions from the DDG:

- There are no loop-carried dependencies for the `i` and `d` loops, so both are parallelizable.
- If we analyze loop-order changes, every ordering of `i <-> j <-> d` is valid, because the only loop-carried dependency is on the `j` loop, and its direction is lexicographically positive, i.e. the source iteration is `j-1` and the target iteration is `j`.
- For parallelization over the `j` loop we need to localize the dependent variables and use `atomics` to communicate across threads (see the sketch below).
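To make that last point concrete, here is a small NumPy sketch of my own showing that the `(m, l)` running state can be computed independently per chunk of `j` and then merged with an associative combine; this is why localizing the dependent variables works, with the per-thread partials merged at the end (here sequentially, on a GPU via atomics or a reduction).

```python
import numpy as np

def chunk_state(S_chunk):
    """Local (max, normalizer) state for one chunk of scores along j."""
    m = S_chunk.max()
    return m, np.exp(S_chunk - m).sum()

def merge(a, b):
    """Associative combine of two (m, l) partial states."""
    (m1, l1), (m2, l2) = a, b
    m = max(m1, m2)
    return m, l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)

S = np.random.default_rng(1).normal(size=32)        # scores for one row i
parts = [chunk_state(c) for c in np.split(S, 4)]    # four independent j-chunks
m, l = parts[0]
for p in parts[1:]:
    m, l = merge((m, l), p)

assert np.isclose(m, S.max()) and np.isclose(l, np.exp(S - S.max()).sum())
```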
Loop Interchange Analysis
Question: Could we interchange the loops \(i \leftrightarrow j\) to improve the locality of `k_j` and `v_j`? There are only three reads (`q_i`, `k_j`, and `v_j`) from HBM (excluding the outputs and `m_i` & `l_i`).
The inner `j` loop is responsible for loading both `k_j` and `v_j` from HBM. If the loops are interchanged, a single load of `k_j` and `v_j` can be reused for all `q_i`; the other way around, a single load of `q_i` is reused across all the sequential loads of `k_j` and `v_j`.

For loop interchange, one important requirement is that the loop-iteration dependence vector must not become lexicographically negative.
Example:
```c
for (i = 1; i < N; i++) {
  for (j = 1; j < N; j++) {
    A[i][j] = A[i-1][j+1];   // RAW dependencies on i and j
  }
}
```
In this loop the dependence direction vector for the one dependency, i.e. on `A[i-1][j+1]`, is `(1, -1)`. After switching the loops:
```c
for (j = 1; j < N; j++) {
  for (i = 1; i < N; i++) {
    A[i][j] = A[i-1][j+1];   // same statement, loops interchanged
  }
}
```
the vector becomes `(-1, 1)`, which is lexicographically negative, and thus the loop interchange is not allowed, because the order of the dependent iterations would flip if the interchange happened.
Simply stated: in the original order, `A[i-1][j+1]` is written (in iteration `(i-1, j+1)`) before it is read in iteration `(i, j)`. After the reorder, the reading iteration would execute before the writing one, so we cannot interchange the loops.
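For completeness, here is a sketch of my own rearrangement of the original statements with `j` outermost, which is what the interchange question above asks for. The only loop-carried dependency is still carried by `j` and its direction stays positive, while `k_j` and `v_j` are now loaded once per outer iteration and reused for all `q_i`. (The sizes and buffer initialization are my assumptions, as in the earlier sketch.)

```python
from math import exp
import numpy as np

M, N, D = 64, 64, 32                      # illustrative sizes
Q, K, V = np.random.randn(M, D), np.random.randn(N, D), np.random.randn(N, D)
O, m, l = np.zeros((M, D)), np.full(M, -np.inf), np.zeros(M)

for j in range(0, N):                     # j is now the outer loop
    k_j = K[j]                            # loaded once, reused for every i
    v_j = V[j]                            # likewise for the value row
    for i in range(0, M):
        q_i = Q[i]
        m_i, l_i = m[i], l[i]

        S_ij = q_i @ k_j
        m_ij = max(m_i, S_ij)
        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)

        for d in range(0, D):
            O[i, d] = O[i, d] * l_i / l_ij + exp(S_ij - m_ij) * v_j[d] / l_ij

        m[i], l[i] = m_ij, l_ij
```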
Modified Code: #1 Improved locality
As noted previously, hoisting statements out of loops whose iterator they do not depend on increases variable locality: the values can be reused readily, reducing pressure on memory.
```python
for i in range(0, M):
    q_i = Q[i]   # S1: No self dependency: [SLoop: i, TLoop: i]
    m_i = m[i]   # S3: No self dependency: [SLoop: i, TLoop: i]
    l_i = l[i]   # S4: No self dependency: [SLoop: i, TLoop: i]
    for j in range(0, N):
        k_j = K[j]   # S2: No self dependency: [SLoop: j, TLoop: j]

        S_ij = q_i @ k_j   # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i and k_j

        m_ij = max(m_i, S_ij)   # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)   # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij

        for d in range(0, D):
            o_id = O[i, d]   # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
            v_jd = V[j, d]   # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

            o_ijd = o_id * l_i / l_ij + exp(S_ij - m_ij) * v_jd / l_ij   # S10: No self dependency: [SLoop: ((i, d), i, (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, l_ij, S_ij, m_ij, v_jd, l_ij

            ### Finally assign the results back to buffers
            O[i, d] = o_ijd   # S11: Aggregation over j # RAW on o_ijd

        # Independent of `d` loop
        m[i] = m_ij   # S12: Aggregation over j # RAW on m_ij
        l[i] = l_ij   # S13: Aggregation over j # RAW on l_ij
        # Carry the hoisted running statistics into the next j iteration
        m_i = m_ij
        l_i = l_ij
```
Modified Code: #2 `i` & `d` loops full parallelization

Since there are no self-dependencies in loop iterations `i` and `d`, these two loops can be fully parallelized.
```python
# Loop collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :]   # S1: No self dependency: [SLoop: i, TLoop: i]
m_i = m[i]      # S3: No self dependency: [SLoop: i, TLoop: i]
l_i = l[i]      # S4: No self dependency: [SLoop: i, TLoop: i]
o_i = O[i, :]   # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
for j in range(0, N):
    k_j = K[j, :]   # S2: No self dependency: [SLoop: j, TLoop: j]
    v_j = V[j, :]   # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

    S_ij = q_i @ k_j   # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i and k_j

    m_ij = max(m_i, S_ij)   # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
    l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)   # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij

    # Loop collapsed under parallelization
    # for d in range(0, D):
    o_ij = o_i * l_i / l_ij + exp(S_ij - m_ij) * v_j / l_ij   # S10: No self dependency: [SLoop: ((i, d), i, (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_i, l_i, l_ij, S_ij, m_ij, v_j, l_ij

    # Carry the running state into the next j iteration
    o_i = o_ij
    m_i = m_ij
    l_i = l_ij

# Independent of `d` loop
O[i] = o_i   # S11: Aggregation over j # RAW on o_i
m[i] = m_i   # S12: Aggregation over j # RAW on m_i
l[i] = l_i   # S13: Aggregation over j # RAW on l_i
```
Modified Code: #3 attempt to parallelize loop j
Though there are multiple SCCs with `LC-j`, we can still map them onto the CUDA programming model by launching all of the corresponding threads in one shot while enforcing an order in which the threads execute. This is generally done with `atomics` in CUDA:
```python
# assume a global variable `counter = 0` which is incremented atomically

# Loop collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :]   # S1: No self dependency: [SLoop: i, TLoop: i]

# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :]   # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :]   # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

S_ij = q_i @ k_j   # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i and k_j

while True:
    if tid == counter:   # our turn: run the serial j-step of the SCCs
        m_i  = m[i]
        m_ij = max(m_i, S_ij)   # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
        m[i] = m_ij
        l_i  = l[i]
        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)
        l[i] = l_ij
        o_i  = O[i, :]
        o_ij = o_i * l_i / l_ij + exp(S_ij - m_ij) * v_j / l_ij
        O[i, :] = o_ij
        atomic_add(counter, 1)   # release the next thread in j order
        break
```
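As a CPU-side illustration of the same ordering trick (my own sketch, with Python threads standing in for CUDA threads and a lock standing in for hardware atomics), each worker spins until the shared counter reaches its id, applies its serial `j`-step, and then releases the next worker.

```python
import math
import random
import threading
import time

N = 8
S = [random.gauss(0, 1) for _ in range(N)]     # one score per "thread" (one j each)
state = {"m": -math.inf, "l": 0.0}             # shared running (m, l) for one row i
counter = 0                                    # plays the role of the atomic counter
lock = threading.Lock()                        # stands in for atomicity

def worker(tid):
    global counter
    while True:
        with lock:
            if counter == tid:                 # my turn: perform the serial j-step
                m_new = max(state["m"], S[tid])
                state["l"] = state["l"] * math.exp(state["m"] - m_new) \
                             + math.exp(S[tid] - m_new)
                state["m"] = m_new
                counter += 1                   # release the next thread in j order
                return
        time.sleep(1e-4)                       # back off until it is our turn

threads = [threading.Thread(target=worker, args=(t,)) for t in range(N)]
for t in threads: t.start()
for t in threads: t.join()
assert math.isclose(state["l"], sum(math.exp(s - state["m"]) for s in S))
```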
Modified Code: #3 (alternative) attempt to parallelize loop j, this time without enforcing an order
```python
# assume m[i], l[i], and o[i] are updated with atomics and grid-wide synchronization

# Loop collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :]   # S1: No self dependency: [SLoop: i, TLoop: i]

# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :]   # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :]   # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

S_ij = q_i @ k_j   # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i and k_j

# Compute the complete max till the end, before the normalizer and output are accumulated
atomic_max(m[i], S_ij)
grid.sync()

atomic_add(l[i], exp(S_ij - m[i]))
grid.sync()

atomic_add(o[i], exp(S_ij - m[i]) * v_j / l[i])
grid.sync()
```
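Because the max is now finalized before the normalizer is accumulated, and the normalizer before the output, the three phases reduce to plain reductions. A quick NumPy check of that claim for a single query row (my own sketch, not from the post):

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 32, 16
S = rng.normal(size=N)            # scores q_i @ K.T for one row i
V = rng.normal(size=(N, D))

m_i = S.max()                                             # phase 1: complete max
l_i = np.exp(S - m_i).sum()                               # phase 2: complete normalizer
o_i = (np.exp(S - m_i)[:, None] * V).sum(axis=0) / l_i    # phase 3: weighted accumulation

assert np.allclose(o_i, (np.exp(S) / np.exp(S).sum()) @ V)   # matches plain softmax(S) @ V
```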
Modified Code: #4 attempt to parallelize loop j, again without enforcing an order

Another way to achieve this is with Cooperative Thread Arrays (CTA), available in CUDA; more on this in the next blog, when we dive deep into the CUDA code itself.
```python
# assume the reductions below are performed cooperatively across the threads of a CTA

# Loop collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :]   # S1: No self dependency: [SLoop: i, TLoop: i]

# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :]   # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :]   # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

S_ij = q_i @ k_j   # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i and k_j

# Compute the complete max till the end, before the normalizer and output are accumulated
max_reduction_through_CTA(m[i], S_ij)
grid.sync()

add_reduction_through_CTA(l[i], exp(S_ij - m[i]))
grid.sync()

add_reduction_through_CTA(o[i], exp(S_ij - m[i]) * v_j / l[i])
grid.sync()
```
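As a preview of what `max_reduction_through_CTA` / `add_reduction_through_CTA` would hide, here is a minimal Python sketch of my own (not CUDA) of the log-depth pairwise pattern that such a cooperative reduction typically follows:

```python
import math

def tree_reduce(values, combine):
    """Pairwise (log-depth) reduction: the pattern a CTA/warp reduction follows."""
    vals = list(values)
    stride = 1
    while stride < len(vals):
        # on a GPU, each lane idx would combine with lane idx + stride in parallel
        for idx in range(0, len(vals), 2 * stride):
            if idx + stride < len(vals):
                vals[idx] = combine(vals[idx], vals[idx + stride])
        stride *= 2
    return vals[0]

S = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7, 0.9, 2.4]
m = tree_reduce(S, max)                                             # max-reduction
l = tree_reduce([math.exp(s - m) for s in S], lambda a, b: a + b)   # add-reduction
assert math.isclose(m, max(S)) and math.isclose(l, sum(math.exp(s - m) for s in S))
```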
Wondering where the code analysis for the backward pass is? That is coming next: click here.