graph TD
S1["S1: Q[i]"]
S2["S2: K[j]"]
S3["S3: m[i]"]
S4["S4: l[i]"]
S5["S5: S_ij = f(q_i, k_j)"]
S6["S6: m_ij = f(m_i, S_ij)"]
S7["S7: l_ij = f(l_i, m_i, m_ij, S_ij, m_ij)"]
S8["S8: O[i, d]"]
S9["S9: V[j, d]"]
S10["S10: o_ijd = f(o_id, l_i, l_ij, S_ij, m_ij, v_jd, l_ij)"]
S11["S11: O[i, d] = o_ijd"]
S12["S12: m[i] = m_ij"]
S13["S13: l[i] = l_ij"]
%% Intra-iteration dependencies
S1 --> S5
S2 --> S5
S5 --> S6
S3 --> S7
S5 --> S7
S6 --> S7
S3 --> S10
S4 --> S10
S5 --> S10
S6 --> S10
S7 --> S10
S9 --> S10
%% Force same level for {S11, S12, S13}
%% S11 ~~~ S12
%% S12 ~~~ S13
%% Strongly Connected Components (SCCs)
subgraph SCC1
S3 --> S6
S6 --> S12
S12 -.->|LC-j| S3
end
subgraph SCC2
S4 --> S7
S7 --> S13
S13 -.->|LC-j| S4
end
subgraph SCC3
S8 --> S10
S10 --> S11
S11 -.->|LC-j| S8
end
FlashAttention Kernel: Forward Pass (Parallelism)
Continuing from my previous blog, FlashAttention Kernel: Forward Pass (MATH), here we will explore the possibilities for parallelism in the forward-pass kernel through step-by-step code transformations, finally reaching a stage that is much closer to the CUDA programming model.
Flash Attention Forward Pass:
In my last blog we saw how the math works in FlashAttention, and this was the final set of expressions we derived there:
\[\begin{align} m_0 &= S^{[m, n=0]} \\ m_{j+1} &= \max(m_j, S^{[m, n=j+1]}) \\ l_0 &= \exp(S^{[m, n=0]} - m_0) \\ l_{j+1} &= \sum_{n\in[0 \dots j+1]}\exp(S^{[m, n]} - m_{j+1}) \\ l_{j+1} &= \frac{\exp(-m_{j})}{\exp(m_{j+1} - m_{j})}\sum_{n\in[0 \dots j]}\exp(S^{[m, n]}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ l_{j+1} &= l_{j}\exp(m_{j} - m_{j+1}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\ O^{[m, d]}_{0} &= \frac{\exp(S^{[m, n=0]} - m_{0}) \cdot V^{[n=0, d]}}{l_{0}} \\ O^{[m, d]}_{j+1} &= \frac{O^{[m, d]}_{j} \cdot l_j \cdot \exp(m_j - m_{j+1})}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]} - m_{j+1}) \cdot V^{[n=j+1, d]}}{l_{j+1}} \\ \end{align}\]
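Before analyzing parallelism, here is a small numpy sketch (a throwaway check of my own, not part of the kernel; the variable names q, K, V, m, l, o are mine) that verifies the streaming recurrence above against a direct softmax for a single query row.

import numpy as np

rng = np.random.default_rng(0)
N, D = 16, 8                        # number of key/value rows and head dimension (arbitrary)
q = rng.standard_normal(D)          # one query row
K = rng.standard_normal((N, D))
V = rng.standard_normal((N, D))

# Streaming accumulation over j, following the recurrence above.
m, l, o = -np.inf, 0.0, np.zeros(D)
for j in range(N):
    s_j = q @ K[j]
    m_new = max(m, s_j)
    l_new = l * np.exp(m - m_new) + np.exp(s_j - m_new)
    o = (o * l * np.exp(m - m_new)) / l_new + np.exp(s_j - m_new) * V[j] / l_new
    m, l = m_new, l_new

# Reference: one row of softmax(q K^T) V computed directly.
s = K @ q
p = np.exp(s - s.max()) / np.exp(s - s.max()).sum()
assert np.allclose(o, p @ V)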
Parallelization Analysis: Forward Pass
Using the math expressions above for the FlashAttention forward pass, we can derive the following code (partly pseudocode):
for i in range(0, M):
    for j in range(0, N):
        for d in range(0, D):
            q_i = Q[i] # S1: No self dependency: [SLoop: i, TLoop: i]
            k_j = K[j] # S2: No self dependency: [SLoop: j, TLoop: j]
            m_i = m[i] # S3: No self dependency: [SLoop: i, TLoop: i]
            l_i = l[i] # S4: No self dependency: [SLoop: i, TLoop: i]
            S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j
            m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
            l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij) # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij
            o_id = O[i, d] # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
            v_jd = V[j, d] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]
            o_ijd = o_id * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_jd / l_ij # S10: No self dependency: [SLoop: ((i, d), i, i, (i, j), (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, m_i, m_ij, l_ij, S_ij, m_ij, v_jd, l_ij
            ### Finally assign the results back to buffers
            O[i, d] = o_ijd # S11: Aggregation over j # RAW on o_ijd
            m[i] = m_ij # S12: Aggregation over j # RAW on m_ij
            l[i] = l_ij # S13: Aggregation over j # RAW on l_ij

The code is annotated with comments that follow a simple information template:
- Each comment starts with the letter `S` followed by a number, e.g. `S2`; this identifies the statement, so `S2` stands for statement 2.
- After the statement number, the comment states whether the statement depends on itself, i.e. whether the variable updated in the statement is updated again in a future iteration.
- The comment then follows a simple notation for loop iteration order, consisting of two parts: Source Loop (`SLoop`) and Target Loop (`TLoop`). When a statement uses multiple variables, this entry becomes a list of lists, where each entry corresponds to a variable in its order of use. The iteration state of a variable is denoted either by a single loop variable (`i`, `j`, or `d` in this case) or, if the variable depends on multiple loops, by a tuple of those variables, e.g. `(i, d)`.
- Next, the comment lists the type of dependency (one of `WAW`, `RAW`, and `WAR`) and the elements to which that dependency applies.
- Sometimes a comment says something like `Aggregation over j`, which means a reduction is performed over that loop variable. Although this is generally an in-place operation, for better segregation of dependencies, updates, and reuse of a variable, it can be done by computing a local variable that depends on the source iteration, e.g. `o_ijd`, and then assigning it to a variable that is independent of the target loop, e.g. `O[i, d] = o_ijd`. An example of reading one annotation through this template is shown right after this list.
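For example, here is the annotation on `S6` decoded field by field (my own paraphrase of each field of the template):

m_ij = max(m_i, S_ij)  # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
# S6                 -> statement number 6
# No self dependency -> the value computed here (m_ij) is not updated again by a future iteration
# SLoop: (i, (i, j)) -> the values read (m_i, then S_ij) come from iterations indexed by i and (i, j) respectively
# TLoop: (i, j)      -> the statement itself executes in iteration (i, j)
# RAW on m_i, S_ij   -> read-after-write dependencies on the statements that wrote m_i and S_ij (S3 and S5)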
Data Dependency Graph (DDG)
This graph, shown at the top of the post, is derived from the annotated code above and represents the data flow and dependencies across the different statements.
We have mentioned SCC several times; it stands for Strongly Connected Component, which occurs in a graph when every node of a subgraph is reachable from every other node within that subgraph. In such a case parallelization is not possible without additional synchronization.
Note:
`SCCs` generally appear together with a Loop Carried Dependency, denoted `LC-i`/`LC-j` in the code to indicate over which loop the dependency is carried. E.g. if it is an `LC-i`, the SCC cannot be parallelized over loop `i`.
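As a quick cross-check of the diagram, the SCCs can also be computed programmatically. Here is a small sketch using networkx (my choice of library; the edge list is my transcription of the DDG above, with the dashed `LC-j` back-edges included):

import networkx as nx

# Intra-iteration edges plus the loop-carried (LC-j) back-edges from the DDG.
edges = [
    ("S1", "S5"), ("S2", "S5"), ("S5", "S6"),
    ("S3", "S6"), ("S3", "S7"), ("S5", "S7"), ("S6", "S7"),
    ("S3", "S10"), ("S4", "S10"), ("S5", "S10"), ("S6", "S10"),
    ("S7", "S10"), ("S8", "S10"), ("S9", "S10"),
    ("S6", "S12"), ("S7", "S13"), ("S10", "S11"),
    ("S12", "S3"), ("S13", "S4"), ("S11", "S8"),  # LC-j back-edges
]
ddg = nx.DiGraph(edges)
sccs = [c for c in nx.strongly_connected_components(ddg) if len(c) > 1]
print(sccs)  # three SCCs: {S3, S6, S12}, {S4, S7, S13}, {S8, S10, S11} (order may vary)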
Deductions from DDG:
- There are no loop-carried dependencies for the `i` and `d` loops, so both are parallelizable.
- If we analyze loop-order changes, all loop orders `i <-> j <-> d` are valid, because the only loop-carried dependency is in the `j` loop, and it has a positive lexicographic direction, i.e. the source iteration is `j-1` and the target iteration is `j`.
- For parallelization over the `j` loop we need to localize the dependent variables and utilize `atomics` to communicate across threads.
Loop Interchange Analysis
Question: can we interchange the loops \(i \leftrightarrow j\) to improve the locality of `k_j` and `v_j`? There are only 3 reads (`q_i`, `k_j`, and `v_j`) from HBM (excluding the outputs and `m_i` & `l_i`).
The inner loop `j` is responsible for loading both `k_j` and `v_j` from HBM. If the loops are interchanged, a single load of `k_j` and `v_j` can be reused for all `q_i`; the other way around, a single load of `q_i` is reused across all the sequential loads of `k_j` and `v_j`.
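A rough count makes the reuse argument concrete (a back-of-the-envelope sketch assuming nothing is cached across iterations; `M` and `N` are just illustrative sizes):

M, N = 1024, 1024  # query and key/value sequence lengths (illustrative only)

# i outer, j inner: q_i is loaded once per i, while k_j and v_j are reloaded for every (i, j).
row_loads_i_outer = M + 2 * M * N

# j outer, i inner: k_j and v_j are loaded once per j, while q_i is reloaded for every (i, j).
row_loads_j_outer = 2 * N + M * N

print(row_loads_i_outer, row_loads_j_outer)  # the interchange roughly halves the HBM row loads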
For a loop interchange, one important requirement is that "the loop iteration dependence vector should not become lexicographically negative".
Example:
for (i=1; i<N; i++) {
    for (j=1; j<N; j++) {
        A[i][j] = A[i-1][j+1]; // loop-carried RAW dependency across i and j
    }
}

In this loop the dependence distance vector for the one dependency, i.e. the read of A[i-1][j+1], is (1, -1). After switching the loops:
for (j=1; j<N; j++) {
    for (i=1; i<N; i++) {
        A[i][j] = ...
    }
}

it becomes (-1, 1), which is lexicographically negative, and thus the loop interchange is not allowed, since it would reverse the execution order of the dependent iterations.
Simply stated: in the original order, A[i][j] is written at iteration (i, j) before it is read (as A[i-1][j+1]) at iteration (i+1, j-1). After the reorder, the reading iteration would execute before the writing one, so we cannot interchange the loops.
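The same rule can be written down as a tiny helper (a sketch with hypothetical names, not a full dependence test): a permutation of the loop order is legal only if every dependence distance vector stays lexicographically positive (or all-zero) under that permutation.

def lex_positive(vec):
    """True if the first non-zero entry is positive (or the vector is all zeros)."""
    for x in vec:
        if x != 0:
            return x > 0
    return True  # loop-independent dependence: unaffected by interchange

def interchange_legal(distance_vectors, perm):
    """A loop permutation is legal if every permuted distance vector stays lexicographically positive."""
    return all(lex_positive([v[p] for p in perm]) for v in distance_vectors)

# The example above: A[i][j] = A[i-1][j+1] has distance vector (1, -1) in (i, j) order.
print(interchange_legal([(1, -1)], perm=(0, 1)))       # True  -> original (i, j) order is fine
print(interchange_legal([(1, -1)], perm=(1, 0)))       # False -> (j, i) order turns it into (-1, 1)

# The FlashAttention loops only carry a (0, 1, 0)-style dependency over j,
# so any permutation of (i, j, d) remains legal.
print(interchange_legal([(0, 1, 0)], perm=(1, 0, 2)))  # True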
Modified Code: #1 Improved locality
As noted previously, hoisting statements out of the loops whose iterations they do not depend on increases variable locality, so the values can be readily reused, reducing pressure on memory.
for i in range(0, M):
    q_i = Q[i] # S1: No self dependency: [SLoop: i, TLoop: i]
    m_i = m[i] # S3: No self dependency: [SLoop: i, TLoop: i]
    l_i = l[i] # S4: No self dependency: [SLoop: i, TLoop: i]
    for j in range(0, N):
        k_j = K[j] # S2: No self dependency: [SLoop: j, TLoop: j]
        S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j
        m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij) # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij
        for d in range(0, D):
            o_id = O[i, d] # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
            v_jd = V[j, d] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]
            o_ijd = o_id * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_jd / l_ij # S10: No self dependency: [SLoop: ((i, d), i, i, (i, j), (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, m_i, m_ij, l_ij, S_ij, m_ij, v_jd, l_ij
            ### Finally assign the results back to buffers
            O[i, d] = o_ijd # S11: Aggregation over j # RAW on o_ijd
        # Independent of `d` loop
        m[i] = m_ij # S12: Aggregation over j # RAW on m_ij
        l[i] = l_ij # S13: Aggregation over j # RAW on l_ij
        # Refresh the hoisted local copies so the next j iteration uses the running m and l
        m_i = m_ij
        l_i = l_ij

Modified Code: #2 i & d loops full parallelization
Since there are no loop-carried dependencies over the `i` and `d` loops, these two loops can be fully parallelized.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]
m_i = m[i] # S3: No self dependency: [SLoop: i, TLoop: i]
l_i = l[i] # S4: No self dependency: [SLoop: i, TLoop: i]
o_i = O[i, :] # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
for j in range(0, N):
    k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
    v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]
    S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j
    m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
    l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij) # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij
    # Loop Collapsed under parallelization
    # for d in range(0, D):
    o_ij = o_i * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_j / l_ij # S10: No self dependency: [SLoop: ((i, d), i, i, (i, j), (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, m_i, m_ij, l_ij, S_ij, m_ij, v_jd, l_ij
    o_i = o_ij
    m_i = m_ij
    l_i = l_ij
# Write the aggregated results back once the `j` loop completes
O[i, :] = o_i # S11: Aggregation over j # RAW on o_i
m[i] = m_i # S12: Aggregation over j # RAW on m_i
l[i] = l_i # S13: Aggregation over j # RAW on l_i

Modified Code: #3 attempt to parallelize loop j
Although there are multiple SCCs with `LC-j`, we can still map them onto the CUDA programming model by launching all of the corresponding threads in one shot while enforcing the order in which the threads execute.
This is generally done with atomics in CUDA:
# assume global variable `counter = 0` which is incremented as per atomics.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]
# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]
S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j
while True:
    if tid == counter:
        m_i = m[i]
        m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
        m[i] = m_ij
        l_i = l[i]
        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)
        l[i] = l_ij
        o_i = O[i, :]
        o_ij = o_i * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_j / l_ij
        O[i, :] = o_ij
        atomicadd(counter, 1)
        break

Modified Code: #4 attempt to parallelize loop j, this time without enforcing an order
# assume `m[i]` is initialized to -inf and `l[i]`, `O[i, :]` to zeros before this point,
# and that `atomic-max` / `atomic-add` act element-wise on the buffers.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]
# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]
S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j
# Compute the complete max over all j before moving on
atomic-max(m[i], S_ij)
grid.sync()
atomic-add(l[i], exp(S_ij - m[i]))
grid.sync()
atomic-add(O[i, :], exp(S_ij - m[i]) * v_j / l[i])
grid.sync()

Modified Code: #5 parallelize loop j without enforcing an order, via CTA reductions
Another way to achieve this is with Cooperative Thread Arrays (CTAs), available in CUDA; more on this in the next blog, when we dive deep into the CUDA code itself.
# assume `m[i]` is initialized to -inf and `l[i]`, `O[i, :]` to zeros before this point.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]
# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]
S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j
# Compute the complete max over all j before moving on
max_reduction_through_CTA(m[i], S_ij)
grid.sync()
add_reduction_through_CTA(l[i], exp(S_ij - m[i]))
grid.sync()
add_reduction_through_CTA(O[i, :], exp(S_ij - m[i]) * v_j / l[i])
grid.sync()

Wondering where the code analysis for the backward pass is? That's coming next; click here.