
How FlashAttention Accelerates Generative AI Revolution

By Jia-Bin Huang

Summary

## Key takeaways

- **Attention's Fatal Memory Bottleneck**: Data frequently travels back and forth between HBM and the GPU cores: load Q, K from HBM, compute S, and write it back; load S, compute the softmax, write A; load A, V, compute O, and write the output. This back-and-forth global memory access adds significant latency. [02:04], [02:39]
- **Tiling Slashes Memory Access**: Tiling partitions matrices into small blocks that fit in fast on-chip SRAM, reducing global memory accesses for matrix multiplication by a factor of the tile size (n for n-by-n tiles). Instead of 32 accesses for 4 output values, 2x2 tiles need just 16. [03:34], [04:49]
- **Safe Softmax Needs 3 Passes**: Safe softmax finds the max M in a first pass, computes the sum of exp(s_i - M) in a second, and normalizes in a third, iterating over the sequence three times. This is inefficient due to repeated global memory accesses. [06:30], [06:53]
- **Online Softmax Fuses to 2 Passes**: Replace the global max M with a running max m_i and use the recurrence d_i = exp(s_i - m_i) + d_{i-1} * exp(m_{i-1} - m_i), merging the first two loops into a single pass over the sequence. [07:23], [08:07]
- **FlashAttention: Single-Pass Fusion**: Extend online softmax to attention with the running max m_i, running scale d_i, and running output o_i = (o_{i-1} * d_{i-1} * exp(m_{i-1} - m_i) + v_i * exp(s_i - m_i)) / d_i, fusing the QK^T product, the softmax, and the weighted sum over V into one loop over SRAM tiles. The N x N attention matrix is never materialized, giving exact, fast, memory-efficient computation. [08:24], [09:40]

Topics Covered

  • Attention Slows from HBM Bottleneck
  • Tiling Cuts Memory Access by n
  • Online Softmax Fuses Three Passes
  • FlashAttention Fuses Single Loop

Full Transcript

Transformers are the key driving force behind today's AI boom. However, the attention mechanism is slow and memory hungry. Many have tried using approximations to speed it up, making the attention matrix sparse or low-rank. These techniques are very interesting, but none of them comes for free: these approximate attention methods often sacrifice accuracy while not achieving actual wall-clock speedup. That changed with FlashAttention, a clever algorithm that is fast, memory efficient, and exact. In this video, I will first explain why self-attention is slow and talk about how one could speed up the computation using IO-aware algorithms like tiling, using matrix multiplication as an example. It's not immediately clear, though, how we can use an IO-aware algorithm to speed up attention. To do this, we first need to understand a clever trick called online softmax. I will then show how extending online softmax leads to FlashAttention.

Let's start by looking at self-attention. The self-attention module first computes the query and key vectors, and then computes the dot product for every query-key pair to capture the correlation between tokens. This creates an N-by-N attention matrix. The matrix isn't too large here, but it becomes memory heavy with sequences of tens of thousands of tokens. We can express the computation more concisely with matrix multiplication by stacking all the query vectors along the rows and all the key vectors along the columns. Next, we apply the softmax function to each row to get the attention weights, and compute the output as the weighted average of the value vectors. So the whole process boils down to these three simple equations: S = QK^T, A = softmax(S) applied row by row, and O = AV. For simplicity, I'm leaving out some details like dot-product scaling, dropout, causal masking, and multi-head attention.
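
To make these three equations concrete, here is a minimal NumPy sketch of the naive computation (the toy sizes are my own choices for illustration; like the video, it omits scaling, masking, dropout, and multiple heads):

```python
import numpy as np

def naive_attention(Q, K, V):
    """The three equations, computed directly.

    S and A are full N x N matrices; in the naive GPU implementation each of
    them makes a round trip through HBM, which is the bottleneck discussed next.
    """
    S = Q @ K.T                              # (N, N) pre-softmax scores
    A = np.exp(S)                            # row-wise softmax (the numerically
    A = A / A.sum(axis=1, keepdims=True)     #  "safe" version appears later)
    O = A @ V                                # (N, d) weighted average of values
    return O

# Toy check with sequence length N = 6 and head dimension d = 4.
rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 6, 4))
print(naive_attention(Q, K, V).shape)        # (6, 4)
```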

To understand why it is slow, let's look at how this computation happens on a GPU. The query, key, and value matrices are stored in high-bandwidth memory (HBM), outside the GPU cores. We load the query and key matrices from HBM, compute the dot products, and save the result S back to HBM. Next, we load the matrix S from HBM, compute the row-wise softmax, and write the N-by-N attention matrix A back to HBM. Finally, we load the attention matrix A and the value matrix V from HBM, compute the weighted average, and write the output O to HBM. As we can see, data frequently travels back and forth between the HBM and the GPU cores, and this back-and-forth global memory access adds significant latency. So how can we speed this up? Let's examine the memory hierarchy, focusing on bandwidth and memory size.

All computers have DRAM; it's large but slow in bandwidth. On a GPU there is HBM, smaller in size (about 40 GB) but with much higher bandwidth than DRAM. On-chip SRAM is an order of magnitude faster than HBM, but much smaller still. The core idea is to leverage on-chip SRAM to avoid the cost of loading and writing the large N-by-N attention matrix to slow HBM. But SRAM is much smaller: clearly we cannot fit the full attention matrix in SRAM. This is where the tiling technique comes in.

Let's use matrix multiplication as an example. Here we want to multiply two 4x4 matrices A and B to produce the output matrix C. The value of C11 is the dot product between the first row vector of matrix A and the first column vector of matrix B. Similarly, the value of C12 is the dot product between the first row vector of matrix A and the second column vector of matrix B. Following this rule, we compute the values of C21 and C22. Now let's check how many global memory accesses this requires. To compute each value, we need to load one row vector from matrix A and one column vector from matrix B, so 8 accesses per value; this results in 32 memory accesses in total. Can we do better?

We see that C11 can be computed by adding two shorter dot products, and the same applies to the other values. Now let's group them along the column and row dimensions. We now see that a 2x2 block of the output matrix C can be computed by adding two 2x2 matrix multiplications, and each of these multiplications requires 8 accesses. With tiling, we just need 16 memory accesses to compute the same four values. Using n-by-n blocks, we can cut global memory access by a factor of n. To complete the matrix multiplication, we partition the matrices A, B, and C into 2x2 blocks; these 2x2 blocks can be moved to on-chip SRAM for faster processing, and we then combine the partial tile-based matrix multiplications to get the final result. That's great progress!
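
As a sanity check of the counting argument, here is a NumPy sketch of the 4x4 example with 2x2 tiles (the explicit Python loops stand in for the tile-level schedule, and it assumes the matrix size is divisible by the tile size):

```python
import numpy as np

def tiled_matmul(A, B, b=2):
    """Blocked matrix multiply: each b x b output tile is a sum of b x b tile products."""
    n = A.shape[0]
    C = np.zeros((n, n))
    for i in range(0, n, b):                 # row tile of C
        for j in range(0, n, b):             # column tile of C
            for k in range(0, n, b):         # the two b x b operand tiles would sit in SRAM
                C[i:i+b, j:j+b] += A[i:i+b, k:k+b] @ B[k:k+b, j:j+b]
    return C

A = np.arange(16.0).reshape(4, 4)
B = np.arange(16.0, 32.0).reshape(4, 4)
assert np.allclose(tiled_matmul(A, B), A @ B)    # same result as the untiled product
```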

Now, how can we apply the tiling technique to speed up attention? We already know how to apply tiling to matrix multiplications, but we still have the softmax operation in between. Let's focus on this and see how we can break down the computation using tiling. For simplicity, we only consider one row of the attention matrix at a time. This gives us a sequence of N numbers, x_1 to x_N. To compute the softmax, we first apply the exponential to make all the numbers positive, then normalize them so they sum up to one. One issue with this approach is that it's not numerically stable: with half precision, overflow occurs in the exponential function for input values over 11. Luckily, we have a simple fix: before the exponential step, we first subtract M, the maximum value of the sequence, from each input. This prevents the overflow issue, because all the input values to the exponential function are now less than or equal to zero. This is known as safe softmax.
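
A quick NumPy illustration of the problem and the fix (the threshold of about 11 comes from float16's maximum value of 65504, roughly exp(11.09)):

```python
import numpy as np

x = np.array([3.0, 9.0, 12.0], dtype=np.float16)

naive = np.exp(x)            # exp(12) > 65504, the float16 max -> overflows to inf
print(naive)

safe = np.exp(x - x.max())   # subtract the max first, so every input is <= 0
print(safe / safe.sum())     # finite weights that sum to one
```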

So how do we implement safe softmax? Essentially, it requires three passes over the sequence. The first pass finds the maximum value of the sequence; we call it m_N here. In the second pass, we compute the sum of the exponential values, d_N. Finally, we normalize the exponential values to obtain the attention weights, which sum up to one. This is very inefficient, because the algorithm requires iterating over the sequence three times.
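
Spelled out as explicit loops, a sketch of the three passes over one row x of pre-softmax scores:

```python
import numpy as np

def safe_softmax_3pass(x):
    N = len(x)
    m = -np.inf
    for i in range(N):                 # pass 1: global maximum m_N
        m = max(m, x[i])
    d = 0.0
    for i in range(N):                 # pass 2: normalizer d_N = sum_i exp(x_i - m_N)
        d += np.exp(x[i] - m)
    a = np.empty(N)
    for i in range(N):                 # pass 3: attention weights a_i
        a[i] = np.exp(x[i] - m) / d
    return a

x = np.random.default_rng(0).standard_normal(8)
assert np.allclose(safe_softmax_3pass(x).sum(), 1.0)
```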

Let's explore how to reduce global memory access by fusing computations. We cannot fuse the second pass with the first one due to the dependency on m_N, the maximum value of the sequence, which is only known at the end of the first pass. We can remove this dependency by replacing the value m_N with m_i, where m_i is the maximum value of the partial sequence from x_1 to x_i; this defines a surrogate running sum d_i' = sum_{j<=i} exp(x_j - m_i), and at the end d_N' equals d_N. But how do we compute the sequence d_i' iteratively? We first separate the i-th term from the summation, then multiply by a factor equal to one. Since multiplication is associative, we can swap these terms, and we recognize that the summation in the parentheses is just d_{i-1}'. This gives us the recurrence relation between d_i' and d_{i-1}': d_i' = d_{i-1}' * exp(m_{i-1} - m_i) + exp(x_i - m_i). Using this recurrence, we can implement the softmax function with just two passes over the sequence: the two for-loops merge into one, because the recurrence only depends on m_i and m_{i-1}, not m_N. This is the idea of online softmax.
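
A sketch of the two-pass online softmax, with the running max m_i and running denominator d_i' updated by the recurrence above:

```python
import numpy as np

def online_softmax_2pass(x):
    m, d = -np.inf, 0.0
    for xi in x:                          # pass 1: fused max + normalizer
        m_new = max(m, xi)
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)   # d_i' from d_{i-1}'
        m = m_new
    return np.exp(x - m) / d              # pass 2: normalize with the final m_N, d_N'

x = np.random.default_rng(0).standard_normal(8)
assert np.allclose(online_softmax_2pass(x).sum(), 1.0)
```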

Now let's see if we can apply the same idea to self-attention. In self-attention, the x_i are the pre-softmax logits computed by the dot product between the query and key vectors, and this can be handled in the first loop. We then compute the output O as the weighted average of the value vectors according to the attention weights a_i, so this still requires two passes. Can we fuse everything into one loop? Here we denote o_i as the weighted average of the value vectors up to the i-th token. We cannot fuse the computation into one single loop due to the dependency on m_N and d_N. Let's try the same trick as before, replacing m_N and d_N with m_i and d_i'. This gives us a new sequence o_i'. First, we move the i-th term out of the summation and multiply by some factors equal to one. Now we can swap these terms and regroup the equation. We recognize that the summation in the parentheses is simply o_{i-1}', which gives a recurrence relation between o_i' and o_{i-1}': o_i' = (o_{i-1}' * d_{i-1}' * exp(m_{i-1} - m_i) + exp(x_i - m_i) * v_i) / d_i'. By removing the dependency on m_N and d_N, we can now fuse all the computation into one single loop. This approach forms FlashAttention. FlashAttention avoids materializing the large attention matrix by fusing all the computations together, and we still get exact results after iterating through all the elements in the sequence.
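
For a single query row, the fused single-pass loop is a direct transcription of this recurrence; here is a sketch (the naive reference at the bottom checks that the result is exact):

```python
import numpy as np

def flash_attention_row(q, K, V):
    """Single pass over the keys/values for one query vector q.

    Maintains the running max m, running denominator d, and running output o;
    the full row of attention weights is never formed.
    """
    d_head = V.shape[1]
    m, d = -np.inf, 0.0
    o = np.zeros(d_head)
    for i in range(K.shape[0]):
        s = q @ K[i]                                  # pre-softmax logit x_i
        m_new = max(m, s)
        d_new = d * np.exp(m - m_new) + np.exp(s - m_new)
        # Rescale the previous output and fold in the new value vector.
        o = (o * d * np.exp(m - m_new) + np.exp(s - m_new) * V[i]) / d_new
        m, d = m_new, d_new
    return o

rng = np.random.default_rng(0)
K, V = rng.standard_normal((2, 8, 4))
q = rng.standard_normal(4)
s = K @ q                                             # naive reference for comparison
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(flash_attention_row(q, K, V), ref)
```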

Let's visualize these steps. Here we have the query, key, and value matrices. We first partition these matrices into tiles. We load the first query tile from HBM to on-chip SRAM, and similarly we load the first key and value tiles to SRAM. With all these values, we can perform the attention computation and save the partial result to HBM; now we have the partial result labeled O_1'. Then we load the next key and value tiles, perform the attention computation in SRAM, and update the partial result to O_2'. We repeat this for all tiles, updating the output each time. At the end of the loop, we have the complete attention result for this query tile, O_5'. Similarly, we repeat this process for the next query tile, loading the corresponding key and value blocks, computing the attention, and updating the results. Using tiling, we never materialize the full attention matrix, and this significantly reduces global memory access. This achieves fast, memory-efficient, and exact computation of the attention mechanism, and that's the core of FlashAttention.
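
Putting the tiling and the online update together, here is a compact NumPy sketch of the blocked forward pass (the tile sizes Br and Bc are arbitrary illustrative choices; in a real kernel each tile of Q, K, and V lives in SRAM, and only the output tiles and per-row statistics are written back to HBM). The loop order, queries outside and key/value tiles inside, follows the video's walkthrough.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=2, Bc=2):
    """Blocked exact attention: the N x N score matrix is never materialized."""
    N, d_head = Q.shape
    O = np.zeros((N, d_head))
    for i in range(0, N, Br):                    # outer loop over query tiles
        Qi = Q[i:i+Br]
        m = np.full(len(Qi), -np.inf)            # running max per query row
        d = np.zeros(len(Qi))                    # running denominator per query row
        Oi = np.zeros((len(Qi), d_head))
        for j in range(0, N, Bc):                # inner loop over key/value tiles
            Sij = Qi @ K[j:j+Bc].T               # (Br, Bc) tile of scores, kept in SRAM
            m_new = np.maximum(m, Sij.max(axis=1))
            scale = np.exp(m - m_new)            # rescales the old statistics
            P = np.exp(Sij - m_new[:, None])     # unnormalized weights for this tile
            d_new = d * scale + P.sum(axis=1)
            Oi = (Oi * (d * scale)[:, None] + P @ V[j:j+Bc]) / d_new[:, None]
            m, d = m_new, d_new
        O[i:i+Br] = Oi                           # write the finished output tile back
    return O

# Check against the naive computation on a toy example.
rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 6, 4))
S = Q @ K.T
A = np.exp(S - S.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), A @ V)
```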

There is already follow-up work that improves the efficiency of FlashAttention further, such as FlashAttention-2 and FlashAttention-3. It's amazing how hardware-aware algorithms can make attention so much faster and more memory efficient. In summary, we covered why attention is slow, how an IO-aware algorithm like tiling can speed things up, online softmax, and its extension to FlashAttention. Thanks for watching, and I'll see you next time!
