FlashAttention Kernel: Forward Pass (Parallelism)

FlashAttention
Transformers
Attention
Compute
Autograd
Parallelism
CUDA
Author

Shivam Pandey

Published

April 14, 2025

Continuing from my previous post, FlashAttention Kernel: Forward Pass (MATH), this post explores how the forward-pass kernel can be parallelized. We transform the code step by step until it reaches a form much closer to the CUDA programming model.

Flash Attention Forward Pass:

In my last post we saw how the math works in Flash Attention. This is the final set of expressions derived there:

\[\begin{align}
m_0 &= S^{[m, n=0]} \\
m_{j+1} &= \max(m_j, S^{[m, n=j+1]}) \\
l_0 &= \exp(S^{[m, n=0]} - m_0) \\
l_{j+1} &= \sum_{n\in[0 \dots j, j+1]}\exp(S^{[m, n]} - m_{j+1}) \\
l_{j+1} &= \frac{\exp(-m_{j})}{\exp(m_{j+1} - m_{j})}\sum_{n\in[0 \dots j]}\exp(S^{[m, n]}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\
l_{j+1} &= l_{j}\exp(m_{j} - m_{j+1}) + \exp(S^{[m, n=j+1]} - m_{j+1}) \\
O^{[m, d]}_{0} &= \frac{\exp(S^{[m, n=0]} - m_{0}) \cdot V^{[n=0, d]}}{l_{0}} \\
O^{[m, d]}_{j+1} &= \frac{O^{[m, d]}_{j} \cdot l_j \cdot \exp(m_{j} - m_{j+1})}{l_{j+1}} + \frac{\exp(S^{[m, n=j+1]} - m_{j+1}) \cdot V^{[n=j+1, d]}}{l_{j+1}}
\end{align}\]
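
As a quick sanity check, the recurrence can be run for a single query row and compared with a one-shot softmax. This is a minimal pure-Python sketch: the function names and the toy inputs are made up for illustration. Note that the previous output is rebased onto the new maximum with the same factor exp(m_j - m_{j+1}) that rescales l_j.

```python
import math

def attention_row_reference(scores, values):
    """softmax(scores) @ values for one query row, computed in one shot."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def attention_row_online(scores, values):
    """Same row, built one key/value pair at a time from running (m, l, O)."""
    m = scores[0]                                   # m_0
    l = math.exp(scores[0] - m)                     # l_0, always 1.0
    out = [math.exp(scores[0] - m) * v / l for v in values[0]]  # O_0
    for s, v_row in zip(scores[1:], values[1:]):
        m_new = max(m, s)                           # m_{j+1}
        scale = math.exp(m - m_new)                 # rebases old terms onto the new max
        p = math.exp(s - m_new)
        l_new = l * scale + p                       # l_{j+1}
        out = [(o * l * scale + p * v) / l_new for o, v in zip(out, v_row)]
        m, l = m_new, l_new
    return out

# Toy inputs (hypothetical): 4 keys, head dimension 2.
scores = [0.0, 10.0, -3.0, 4.5]
values = [[1.0, 0.0], [0.0, 2.0], [5.0, 5.0], [-1.0, 3.0]]
reference = attention_row_reference(scores, values)
online = attention_row_online(scores, values)
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(reference, online))
```

The second score is deliberately much larger than the first, so the running maximum changes and the rescale factor is actually exercised.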

Parallelization Analysis: Forward Pass

Using the expressions above for the Flash Attention forward pass, we can derive the following code (partly pseudocode):

for i in range(0, M):
    for j in range(0, N):
        for d in range(0, D):
            q_i = Q[i] # S1: No self dependency: [SLoop: i, TLoop: i]
            k_j = K[j] # S2: No self dependency: [SLoop: j, TLoop: j]

            m_i = m[i] # S3: No self dependency: [SLoop: i, TLoop: i]
            l_i = l[i] # S4: No self dependency: [SLoop: i, TLoop: i]

            S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j

            m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
            l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij) # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij

            o_id = O[i, d] # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
            v_jd = V[j, d] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

            o_ijd = o_id * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_jd / l_ij # S10: No self dependency: [SLoop: ((i, d), i, i, (i, j), (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, m_i, m_ij, l_ij, S_ij, m_ij, v_jd, l_ij

            ### Finally assign the results back to buffers
            O[i, d] = o_ijd # S11: Aggregation over j # RAW on o_ijd
            m[i] = m_ij # S12: Aggregation over j # RAW on m_ij
            l[i] = l_ij # S13: Aggregation over j # RAW on l_ij

The code is annotated with comments that follow a simple template:

  1. Each comment starts with the letter S followed by a number, e.g. S2. This labels the statement, so S2 stands for statement 2.
  2. After the statement number, the comment says whether the statement depends on itself, i.e. whether the variable it updates is updated again in a later iteration.
  3. The comment then gives the loop iteration order in two parts: Source Loop (SLoop) and Target Loop (TLoop). When a statement uses several variables, this entry becomes a list of lists, with one entry per variable in the order they are used. The iteration state of each variable is written as the loop variable it depends on (i, j, or d in this case), or as a tuple such as (i, d) when it depends on several loops.
  4. Next the comment lists the type of dependency (WAW, RAW, or WAR) and the elements it applies to.
  5. Sometimes a comment says Aggregation over j, which means a reduction is performed over that loop variable. Although this is generally an in-place operation, dependencies are easier to separate, update, and reuse if we first compute a local variable that depends on the source iteration, e.g. o_ijd, and then assign it to a variable that is independent of the target loop, e.g. O[i, d] = o_ijd.

Data Dependency Graph (DDG)

This graph is derived from the annotated code above and represents the data flow and dependencies between statements.

The graph marks several SCCs. SCC stands for Strongly Connected Component: a sub-graph in which every node is reachable from every other node. The statements in such a cycle cannot be parallelized directly over the loop that carries the dependency.

Note: SCCs generally appear together with a loop-carried dependency, denoted LC-i/j in the graph to show which loop carries it. E.g. if the edge is LC-i, the SCC cannot be parallelized over loop i.

graph TD
    S1["S1: Q[i]"]
    S2["S2: K[j]"]
    S3["S3: m[i]"]
    S4["S4: l[i]"]
    S5["S5: S_ij = f(q_i, k_j)"]
    S6["S6: m_ij = f(m_i, S_ij)"]
    S7["S7: l_ij = f(l_i, m_i, m_ij, S_ij, m_ij)"]
    S8["S8: O[i, d]"]
    S9["S9: V[j, d]"]
    S10["S10: o_ijd = f(o_id, l_i, m_i, m_ij, l_ij, S_ij, v_jd)"]
    S11["S11: O[i, d] = o_ijd"]
    S12["S12: m[i] = m_ij"]
    S13["S13: l[i] = l_ij"]

    %% Intra-iteration dependencies
    S1 --> S5
    S2 --> S5
    S5 --> S6
    S3 --> S7
    S5 --> S7
    S6 --> S7
    S3 --> S10
    S4 --> S10
    S5 --> S10
    S6 --> S10
    S7 --> S10
    S9 --> S10

    %%  Force same level for {S11, S12, S13}
    %% S11 ~~~ S12
    %% S12 ~~~ S13

    %% Strongly Connected Components (SCCs)
    subgraph SCC1
        S3 --> S6
        S6 --> S12
        S12 -.->|LC-j| S3
    end

    subgraph SCC2
        S4 --> S7
        S7 --> S13
        S13 -.->|LC-j| S4
    end

    subgraph SCC3
        S8 --> S10
        S10 --> S11
        S11 -.->|LC-j| S8
    end
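
The components drawn above can also be checked mechanically. The sketch below is an illustration only: the edge list is a hand transcription of the graph (an assumption of this sketch), and the helper names are made up. Extra forward edges into S10 would not change the outcome, because nothing leads from S10 back to S3 or S4. It groups nodes that can reach each other and keeps the groups with more than one statement.

```python
from collections import defaultdict

# (source, target) pairs; the last edge of each cycle is the loop-carried LC-j edge.
EDGES = [
    ("S1", "S5"), ("S2", "S5"), ("S5", "S6"), ("S3", "S7"), ("S5", "S7"),
    ("S6", "S7"), ("S4", "S10"), ("S5", "S10"), ("S6", "S10"), ("S7", "S10"),
    ("S9", "S10"),
    ("S3", "S6"), ("S6", "S12"), ("S12", "S3"),      # SCC1
    ("S4", "S7"), ("S7", "S13"), ("S13", "S4"),      # SCC2
    ("S8", "S10"), ("S10", "S11"), ("S11", "S8"),    # SCC3
]

def reachable(graph, start):
    """All nodes reachable from `start` by following at least one edge."""
    seen, stack = set(), [start]
    while stack:
        node = stack.pop()
        for nxt in graph[node]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

def nontrivial_sccs(edges):
    """Strongly connected components that contain more than one node."""
    graph, nodes = defaultdict(list), set()
    for src, dst in edges:
        graph[src].append(dst)
        nodes.update((src, dst))
    reach = {n: reachable(graph, n) for n in nodes}
    components = {
        frozenset({a} | {b for b in reach[a] if a in reach[b]}) for a in nodes
    }
    number = lambda name: int(name[1:])              # "S12" -> 12
    found = [sorted(c, key=number) for c in components if len(c) > 1]
    return sorted(found, key=lambda comp: number(comp[0]))

sccs = nontrivial_sccs(EDGES)
assert sccs == [["S3", "S6", "S12"], ["S4", "S7", "S13"], ["S8", "S10", "S11"]]
```

The mutual-reachability test is quadratic in the number of statements, which is fine for a graph of this size; a linear-time algorithm such as Tarjan's would be the usual choice for larger graphs.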

Deductions from DDG:

  1. There are no loop-carried dependencies on the i and d loops, so both are parallelizable.
  2. Every loop order (any permutation of i, j, d) is valid, because the only loop-carried dependency is on the j loop and it is lexicographically positive, i.e. the source iteration is j-1 and the target iteration is j.
  3. To parallelize over the j loop we need to localize the dependent variables and use atomics to communicate across threads.
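
The second deduction can be spelled out with a small legality check. This is a sketch under the assumption that the whole loop nest has a single loop-carried distance vector, (0, 1, 0) in (i, j, d) order; the helper names are hypothetical.

```python
from itertools import permutations

def is_lex_positive(vector):
    """True if the first non-zero entry is positive (all zeros is loop-independent)."""
    for entry in vector:
        if entry != 0:
            return entry > 0
    return True

def legal_loop_orders(loop_names, distance_vectors):
    """Loop orders under which every permuted distance vector stays lex-positive."""
    legal = []
    for order in permutations(range(len(loop_names))):
        permuted = [[vec[k] for k in order] for vec in distance_vectors]
        if all(is_lex_positive(vec) for vec in permuted):
            legal.append(tuple(loop_names[k] for k in order))
    return legal

# Only the j loop carries a dependency: iteration j-1 feeds iteration j.
orders = legal_loop_orders(("i", "j", "d"), [(0, 1, 0)])
assert len(orders) == 6   # all 3! permutations are legal
```

A vector with a single positive entry stays lexicographically positive wherever that entry moves, which is why all six orders pass.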

Loop Interchange Analysis

Question: can we interchange the loops \(i \leftrightarrow j\) to improve the locality of k_j and v_j? There are only 3 reads from HBM (q_i, k_j, and v_j), excluding the outputs and m_i and l_i.

The inner j loop loads both k_j and v_j from HBM. If the loops are interchanged, a single load of k_j and v_j can be reused for every q_i. In the current order it is the other way around: a single load of q_i is reused across the sequential loads of all k_j and v_j.

For a loop interchange to be legal, one important condition is that no loop-carried dependence vector becomes lexicographically negative.

Example:

for (i=1; i<N; i++) {
  for (j=1; j<N; j++) {
    A[i][j] = A[i-1][j+1]; // RAW dependencies on i and j
  }
}

In this loop the distance vector of the single dependency, through A[i-1][j+1], is (1, -1), which after switching the loops:

for (j=1; j<N; j++) {
  for (i=1; i<N; i++) {
    A[i][j] = ...
  }
}

becomes (-1, 1), which is lexicographically negative, so the interchange is not allowed: it would reverse the order in which the dependent iterations execute.

Simply stated: in the original order, A[i-1][j+1] is written in iteration (i-1, j+1), before it is read in iteration (i, j). After the interchange, iteration (i, j) runs first and reads A[i-1][j+1] before it has been updated, so we can't interchange the loops.
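
The effect is easy to observe by running both loop orders on the same data. The following is a small Python transcription of the example; the grid contents and function names are arbitrary choices for illustration. The grid has one extra column so that index j+1 stays in range.

```python
def make_grid(n):
    """n rows, n + 1 columns; each cell encodes its own (row, column) as 10*row + column."""
    return [[10 * row + col for col in range(n + 1)] for row in range(n)]

def run_i_outer(n):
    """Original order: i is the outer loop."""
    a = make_grid(n)
    for i in range(1, n):
        for j in range(1, n):
            a[i][j] = a[i - 1][j + 1]
    return a

def run_j_outer(n):
    """Interchanged order: j is the outer loop."""
    a = make_grid(n)
    for j in range(1, n):
        for i in range(1, n):
            a[i][j] = a[i - 1][j + 1]
    return a

original = run_i_outer(4)
swapped = run_j_outer(4)
assert original != swapped
# a[2][1] copies a[1][2]. In the original order a[1][2] was already overwritten
# with a[0][3] == 3; after the interchange it still holds its initial value 12.
assert original[2][1] == 3
assert swapped[2][1] == 12
```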

Modified Code: #1 Improved locality

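As noted previously, hoisting statements out of the loops they do not depend on increases variable locality: the loaded values can be reused readily, which reduces pressure on memory. Keeping the updates of m and l out of the d loop also matters for correctness, since they must advance once per (i, j) pair, not once per d.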

for i in range(0, M):
    q_i = Q[i] # S1: No self dependency: [SLoop: i, TLoop: i]
    m_i = m[i] # S3: No self dependency: [SLoop: i, TLoop: i]
    l_i = l[i] # S4: No self dependency: [SLoop: i, TLoop: i]
    for j in range(0, N):
        k_j = K[j] # S2: No self dependency: [SLoop: j, TLoop: j]

        S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j

        m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij

        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij) # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij

        for d in range(0, D):
            o_id = O[i, d] # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
            v_jd = V[j, d] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

            o_ijd = o_id * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_jd / l_ij # S10: No self dependency: [SLoop: ((i, d), i, i, (i, j), (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, m_i, m_ij, l_ij, S_ij, m_ij, v_jd, l_ij

            ### Finally assign the results back to buffers
            O[i, d] = o_ijd # S11: Aggregation over j # RAW on o_ijd

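        # Independent of `d` loop: carry the running statistics into the next j iteration
        m_i = m_ij
        l_i = l_ij

    # Write the final statistics back once per row
    m[i] = m_i # S12: Aggregation over j # RAW on m_i
    l[i] = l_i # S13: Aggregation over j # RAW on l_i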

Modified Code: #2 i & d loops full parallelization

Since the i and d loops carry no dependencies, these two loops can be fully parallelized.

# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]
m_i = m[i] # S3: No self dependency: [SLoop: i, TLoop: i]
l_i = l[i] # S4: No self dependency: [SLoop: i, TLoop: i]
o_i = O[i, :] # S8: No self dependency: [SLoop: (i, d), TLoop: (i, d)]
for j in range(0, N):
    k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
    v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

    S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j

    m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij

    l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij) # S7: No self dependency: [SLoop: (j, (i, (i, j)), ((i, j), (i, j))), TLoop: (i, j)] # RAW on l_i, m_i, m_ij, S_ij, m_ij

    # Loop Collapsed under parallelization
    # for d in range(0, D):
    o_ij = o_i * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_j / l_ij # S10: No self dependency: [SLoop: ((i, d), i, i, (i, j), (i, j), (i, j), (i, j), (j, d), (i, j)), TLoop: (i, j, d)] # RAW on o_id, l_i, m_i, m_ij, l_ij, S_ij, m_ij, v_jd, l_ij

    o_i = o_ij
    m_i = m_ij
    l_i = l_ij

# Independent of `d` loop
O[i] = o_i # S11: Aggregation over j # RAW on o_i
m[i] = m_i # S12: Aggregation over j # RAW on m_i
l[i] = l_i # S13: Aggregation over j # RAW on l_i
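
To see the row-level independence in action, the listing can be turned into a runnable per-row function and mapped over i. This is a pure-Python sketch, not the CUDA kernel: a thread pool stands in for one thread per row, the running state starts at m = -inf, l = 0, O = 0 (an assumption about how the buffers are initialized), and all names are illustrative.

```python
import math
from concurrent.futures import ThreadPoolExecutor

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def flash_row(q_i, K, V):
    """One 'thread' of the i loop: sequential over j, vectorized over d."""
    m_i, l_i = float("-inf"), 0.0
    o_i = [0.0] * len(V[0])
    for k_j, v_j in zip(K, V):
        s_ij = dot(q_i, k_j)
        m_ij = max(m_i, s_ij)
        scale = math.exp(m_i - m_ij)       # exp(-inf) == 0.0 on the first step
        p = math.exp(s_ij - m_ij)
        l_ij = l_i * scale + p
        o_i = [(o * l_i * scale + p * v) / l_ij for o, v in zip(o_i, v_j)]
        m_i, l_i = m_ij, l_ij              # carry the state to the next j
    return o_i

def softmax_attention_row(q_i, K, V):
    """Reference: materialize the whole score row, then normalize."""
    scores = [dot(q_i, k) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wj * v[d] for wj, v in zip(w, V)) / total for d in range(len(V[0]))]

def flash_forward(Q, K, V, workers=4):
    """Rows share no state, so they can be handed to workers in any order."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(lambda q: flash_row(q, K, V), Q))

Q = [[1.0, 0.0], [0.5, -2.0], [3.0, 3.0]]
K = [[1.0, 2.0], [-1.0, 0.0], [4.0, 1.0], [0.0, 0.5]]
V = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [5.0, 5.0, 5.0], [-2.0, 3.0, 1.0]]
out = flash_forward(Q, K, V)
for q_i, row in zip(Q, out):
    expected = softmax_attention_row(q_i, K, V)
    assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
               for a, b in zip(row, expected))
```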

Modified Code: #3 attempt to parallelize loop j

Though there are multiple SCCs with LC-j edges, we can still fit them into the CUDA programming model by launching all of the corresponding threads at once, while enforcing the order in which the threads execute the dependent statements.

This is generally done with atomics in CUDA.

# assume a global variable `counter = 0` (one per row i) that threads increment atomically.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]

# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j

# `tid` is this thread's index along j; spin until it is this thread's turn.
while True:
    if tid == counter:
        m_i = m[i]
        m_ij = max(m_i, S_ij) # S6: No self dependency: [SLoop: (i, (i, j)), TLoop: (i, j)] # RAW on m_i and S_ij
        m[i] = m_ij
        l_i = l[i]
        l_ij = l_i * exp(m_i - m_ij) + exp(S_ij - m_ij)
        l[i] = l_ij
        o_i = O[i, :]
        o_ij = o_i * l_i * exp(m_i - m_ij) / l_ij + exp(S_ij - m_ij) * v_j / l_ij
        O[i, :] = o_ij
        atomicadd(counter, 1)
        break
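
The turn-taking idea can be tried out on the CPU with ordinary threads. This is only a sketch of the ordering scheme, not of the CUDA kernel: `threading.Condition` stands in for the spin loop and the atomic increment, the score row is assumed to be precomputed, and every name is made up for illustration. The threads are started in reverse order on purpose, to show that the counter, not the launch order, decides who updates the state.

```python
import math
import threading

def ordered_row_update(scores, values):
    """One thread per j; the shared counter lets them touch (m, l, O) only in j order."""
    state = {"m": float("-inf"), "l": 0.0, "o": [0.0] * len(values[0]),
             "counter": 0, "log": []}
    turn = threading.Condition()

    def worker(tid):
        s, v = scores[tid], values[tid]        # order-free work happens before waiting
        with turn:
            turn.wait_for(lambda: state["counter"] == tid)
            m_new = max(state["m"], s)
            scale = math.exp(state["m"] - m_new)
            p = math.exp(s - m_new)
            l_new = state["l"] * scale + p
            state["o"] = [(o * state["l"] * scale + p * x) / l_new
                          for o, x in zip(state["o"], v)]
            state["m"], state["l"] = m_new, l_new
            state["log"].append(tid)
            state["counter"] += 1              # plays the role of the atomic add
            turn.notify_all()

    threads = [threading.Thread(target=worker, args=(tid,))
               for tid in reversed(range(len(scores)))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return state["o"], state["log"]

scores = [0.0, 10.0, -3.0, 4.5]
values = [[1.0, 0.0], [0.0, 2.0], [5.0, 5.0], [-1.0, 3.0]]
output, execution_order = ordered_row_update(scores, values)
assert execution_order == [0, 1, 2, 3]
```

The dependent section is fully serialized, so this arrangement saves nothing on the recurrence itself; only the loads and the dot product before the wait can overlap.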

Modified Code: #3 (variant) attempt to parallelize loop j, this time without enforcing an order, using atomics

# assume `m[i]` starts at -inf and `l[i]`, `O[i, :]` start at 0 before the kernel runs.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]

# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j

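# Phase 1: every thread folds its score into the row maximum
atomic_max(m[i], S_ij)

grid.sync()  # m[i] is now the final maximum of row i

# Phase 2: accumulate the softmax denominator against the final maximum
atomic_add(l[i], exp(S_ij - m[i]))

grid.sync()  # l[i] is now the final denominator of row i

# Phase 3: accumulate the normalized, weighted values
atomic_add(O[i, :], exp(S_ij - m[i]) * v_j / l[i])

grid.sync()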
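
Because each phase only uses commutative updates, the order in which the threads arrive should not matter beyond floating-point rounding. The following Python simulation illustrates this; it is sequential code in which each loop plays the role of one phase and the gap between loops plays the role of the barrier, and the function name and inputs are hypothetical.

```python
import math
import random

def three_phase_row(scores, values, order):
    """Visit the j 'threads' in `order`; each loop is one phase between two barriers."""
    dim = len(values[0])
    m = float("-inf")
    for j in order:                        # phase 1: atomic max
        m = max(m, scores[j])
    l = 0.0
    for j in order:                        # phase 2: atomic add into the denominator
        l += math.exp(scores[j] - m)
    out = [0.0] * dim
    for j in order:                        # phase 3: atomic add into the output row
        weight = math.exp(scores[j] - m) / l
        for d in range(dim):
            out[d] += weight * values[j][d]
    return out

scores = [0.0, 10.0, -3.0, 4.5, 9.5]
values = [[1.0, 0.0], [0.0, 2.0], [5.0, 5.0], [-1.0, 3.0], [2.0, 2.0]]
in_order = three_phase_row(scores, values, list(range(len(scores))))

rng = random.Random(0)
shuffled = list(range(len(scores)))
rng.shuffle(shuffled)
out_of_order = three_phase_row(scores, values, shuffled)

# Floating-point addition is not associative, so compare with a tolerance.
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(in_order, out_of_order))
```

No rescaling factor is needed here because every term is computed against the final maximum; the price is that each score has to be kept or recomputed across the three phases.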

Modified Code: #4 attempt to parallelize loop j, this time without enforcing an order, using CTA reductions

Another way to achieve this is a reduction across a Cooperative Thread Array (CTA, the PTX term for a CUDA thread block). More on this in the next post, when we dive deep into the CUDA code itself.

# assume `m[i]` starts at -inf and `l[i]`, `O[i, :]` start at 0 before the kernel runs.
# Loop Collapsed under parallelization
# for i in range(0, M):
q_i = Q[i, :] # S1: No self dependency: [SLoop: i, TLoop: i]

# Loop collapsed under parallelization
# for j in range(0, N):
k_j = K[j, :] # S2: No self dependency: [SLoop: j, TLoop: j]
v_j = V[j, :] # S9: No self dependency: [SLoop: (j, d), TLoop: (j, d)]

S_ij = q_i @ k_j # S5: No self dependency: [SLoop: (i, j), TLoop: (i, j)] # Loop interchange possible # RAW on q_i, and k_j

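# Phase 1: reduce the scores of row i to their maximum
maxreductionthroughCTA(m[i], S_ij)

grid.sync()  # m[i] is now the final maximum of row i

# Phase 2: reduce the softmax denominator against the final maximum
addreductionthroughCTA(l[i], exp(S_ij - m[i]))

grid.sync()  # l[i] is now the final denominator of row i

# Phase 3: reduce the normalized, weighted values into the output row
addreductionthroughCTA(O[i, :], exp(S_ij - m[i]) * v_j / l[i])

grid.sync()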
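
As a rough picture of what such a reduction does, here is a Python simulation of the pairwise tree pattern that a thread block typically runs over shared memory. It is a sketch, not CUDA: `block_reduce` and `cta_row` are made-up names, the list `buf` stands in for shared memory, and each pass of the inner loop stands in for the threads that are active in that step.

```python
import math
import operator

def block_reduce(items, op):
    """Pairwise tree reduction: about log2(n) steps instead of n sequential updates."""
    buf = list(items)                      # stands in for shared memory
    if not buf:
        raise ValueError("nothing to reduce")
    stride = 1
    while stride < len(buf):
        # Every idx below would be handled by a different thread in the same step.
        for idx in range(0, len(buf) - stride, 2 * stride):
            buf[idx] = op(buf[idx], buf[idx + stride])
        stride *= 2                        # a block-level barrier would sit here
    return buf[0]

def cta_row(scores, values):
    """The three phases, each expressed as one block-level reduction."""
    m = block_reduce(scores, max)
    l = block_reduce([math.exp(s - m) for s in scores], operator.add)
    dim = len(values[0])
    return [
        block_reduce([math.exp(s - m) * v[d] / l for s, v in zip(scores, values)],
                     operator.add)
        for d in range(dim)
    ]

scores = [0.0, 10.0, -3.0, 4.5, 9.5]       # 5 entries: not a power of two on purpose
values = [[1.0, 0.0], [0.0, 2.0], [5.0, 5.0], [-1.0, 3.0], [2.0, 2.0]]
row = cta_row(scores, values)

assert block_reduce([3, 1, 4, 1, 5], max) == 5
assert block_reduce([1, 2, 3, 4, 5], operator.add) == 15
```

Compared with atomics, the tree keeps contention off a single memory location and fixes the order of the additions, so the result is reproducible from run to run.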

Wondering where the code analysis for the backward pass is? That is coming in the next post.