Vizuara Kernel Engineering
Mentor Handbook · 04 Teaching Inference Kernels

Teaching FlashAttention: never write the scores down

By the end of this chapter you'll be able to stand at a whiteboard and teach the single most famous kernel in modern AI — FlashAttention — as one clean idea a student can hold in their head: never write the big scores matrix down. No CUDA required. You need one honest picture of why attention wastes memory, one very good "grading papers" metaphor for streaming, and the courage to do a tiny running-max by hand. Let's build it so it's yours.

First, what is attention doing? (the two-minute recap)

You don't need to re-teach the whole transformer here; the earlier chapters did that. You only need the shape of one attention head. Say it plainly. Attention takes three grids of numbers, all the same size: Q (the questions), K (the keys), and V (the values). Each is N × d: N rows, one per word in the sentence, and d columns (the "head dimension," usually 64 or 128).

The math is three steps:

S = Q @ K.T            # (N, N)  every word scores every other word
P = softmax(S, dim=-1) # (N, N)  turn scores into weights that sum to 1
O = P @ V              # (N, d)  each word = a weighted blend of the values
🎤 Say this at the board
"Every word looks at every other word and asks 'how much should I pay attention to you?' That produces a big square grid of scores — N words by N words. Softmax turns each row of scores into a set of weights that add up to 1. Then each word becomes a weighted blend of everybody's values. That middle grid — N by N — is the villain of this whole story."

The thing to burn into the room is the size of that middle grid, S. It is N × N. For a sequence of 8,192 words, that is about 67 million numbers, or 256 MiB at 4 bytes each (fp32), for a single head in a single layer. And a model has dozens of heads and dozens of layers.

🧠 Metaphor
The scores grid is a giant seating chart for a party of N guests: for every pair of guests, one number saying how interested guest A is in guest B. If N doubles, the chart doesn't double — it quadruples. That quadratic blow-up is why long context is so painful, and why this one grid is worth building a whole kernel to avoid.
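
A quick back-of-the-envelope check of that size can be sketched as follows. It assumes the scores are stored as 4-byte fp32 values (fp16 would halve every figure), and the helper name `score_matrix_bytes` is purely illustrative.

```python
def score_matrix_bytes(n_tokens: int, bytes_per_value: int = 4) -> int:
    """Bytes needed to store one head's N x N score matrix S."""
    return n_tokens * n_tokens * bytes_per_value

MIB = 2 ** 20

# Doubling N quadruples the grid: watch the MiB column.
for n in (2048, 4096, 8192, 16384):
    fp32 = score_matrix_bytes(n, 4) / MIB   # assumed 4-byte scores
    fp16 = score_matrix_bytes(n, 2) / MIB   # assumed 2-byte scores
    print(f"N={n:>6}: {n * n:>12,} scores, {fp32:>6.0f} MiB fp32, {fp16:>6.0f} MiB fp16")
```

Reading down the table is a good way to let students see the quadrupling for themselves before you say the word "quadratic."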

Why naive attention is a memory disaster

Now the part students always get wrong, so you must get it right. Everyone assumes the two matrix multiplies are the expensive part. They are not. The expensive part is the little softmax in the middle — not because it does much math, but because of where it forces the data to travel.

Walk through the trips the naive version makes:

  1. Compute S (N × N) and write all of it out to the GPU's main memory (HBM).
  2. Read all of S back to find each row's max and sum, so softmax can normalize it.
  3. Write P (N × N) back out.
  4. Read P back in to multiply by V.
The click
Four full trips across the slowest road on the chip, dragging a quarter-gigabyte matrix each way. And for what? The softmax itself does almost no arithmetic: one exponential and a couple of adds per number. So this is nearly pure waiting on memory, with no real math to hide the wait behind. Say the punchline: "The GPU's fast math units sit idle while a matrix that we didn't even want to keep gets shoved back and forth four times." That is the disaster, in one sentence.

This connects straight to the course's spine (the "feeding the cooks" chapter): the score matrix is a memory-bound intermediate. FlashAttention's entire pitch is one word — fusion — glue the two matmuls and the softmax into a single kernel so that S is computed, used, and thrown away without ever touching main memory.

Naive attention as a truck making four pointless warehouse trips with a giant crate it discards at the end. The driving is the cost.
Naive attention makes four trips over an N-by-N matrix. FlashAttention makes none — the scores are born and die on-chip.

The obstacle: softmax wants to see the whole row

Here is why nobody had just done this obvious fusion for years, and it's worth pausing on because it's the intellectual heart of the chapter.

Softmax is global. To normalize one row of scores you need two things that depend on every number in that row: the row's maximum (you subtract it before exponentiating, so the numbers don't overflow) and the row's sum of exponentials (the denominator). You cannot divide by a total you haven't finished adding up. You cannot subtract a max you haven't finished searching for.
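
To make the "two things that depend on every number" concrete, here is a minimal sketch of an ordinary, all-at-once stable softmax in plain Python. The function name `softmax_row` is illustrative, not taken from any library; the point is only that both passes touch the full row before a single weight can be produced.

```python
import math

def softmax_row(scores):
    """Numerically stable softmax over one full row of scores."""
    row_max = max(scores)                           # pass 1: needs every score
    exps = [math.exp(s - row_max) for s in scores]  # largest exponent is now 0
    total = sum(exps)                               # pass 2: needs every exponential
    return [e / total for e in exps]

weights = softmax_row([1.0, 3.0, 2.0, 5.0])
print([round(w, 3) for w in weights], "sum =", round(sum(weights), 6))

# Why the max is subtracted: a raw exponential of a large score overflows.
try:
    math.exp(1000.0)
except OverflowError:
    print("exp(1000) overflows a float, but the shifted version is fine:",
          [round(w, 3) for w in softmax_row([1000.0, 999.0])])
```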

So streaming seems impossible. If you only look at one tile of the row at a time, you don't yet know the max or the sum for the whole row.

⚠️ Where students trip
This is the moment students seize up: "You can't do softmax without the whole row — so how can you tile it?" Don't wave it away. Name their objection out loud and promise the resolution: "You're right that you need the whole row eventually. The trick is you can carry a running guess and correct it as new numbers arrive — like updating a running average." Then show them the metaphor before any algebra.

The metaphor: grading a stack of exam papers

This is the metaphor to draw and act out. It makes online softmax feel obvious.

Imagine you're grading a huge stack of exams and you want two things at the end: the highest score in the stack, and a running tally that lets you compute everyone's grade relative to that highest score. But the stack is too tall to lay out on your desk at once. So you grade one small pile at a time.

You keep three sticky notes on your desk:

  • m: the highest score you've seen so far.
  • ℓ: a running total (of scores measured relative to that highest-so-far).
  • O: a running blended answer you're building up.

You grade a pile. If the top score in this new pile is higher than your old m, then your old running total was measured against a max that was too low — every number in it is now a little too big. So you rescale: shrink the old total by exactly the right factor, then add in the new pile. Update your sticky note m to the new high. Move to the next pile.

At the very end — and only at the end — you divide by the final total to get everyone's proper grade.

🧠 Metaphor
The whole of FlashAttention's softmax is: "grade the exams one pile at a time, keep three sticky notes (highest-so-far, running total, running blend), and whenever a new pile beats your old high score, gently shrink your old totals to match before adding the new pile in." That "gently shrink the past to match the present" step is the one and only clever line in the algorithm.
The core trick as a human task: you never lay out the whole stack; you keep three sticky notes and fix them up pile by pile.

Do it by hand: a tiny running softmax

Numbers make it real. Do this slowly on the board. Take one row of scores — just four numbers — and pretend they arrive in two piles of two.

Row of scores: [1, 3, 2, 5]. Piles: [1, 3] then [2, 5].

Pile 1 = [1, 3].

  • New max m = 3.
  • Exponentials relative to the max: exp(1-3)=0.135, exp(3-3)=1.
  • Running sum ℓ = 0.135 + 1 = 1.135.

Pile 2 = [2, 5]. The top of this pile is 5 — bigger than our old max of 3!

  • New max m_new = 5.
  • Shrink factor α = exp(m_old − m_new) = exp(3 − 5) = exp(−2) = 0.135.
  • Rescale the old sum: 0.135 × 1.135 = 0.153.
  • New pile's exponentials: exp(2-5)=0.050, exp(5-5)=1.
  • Running sum ℓ = 0.153 + 0.050 + 1 = 1.203.

Now check it against the honest, all-at-once answer. The true denominator is exp(1-5)+exp(3-5)+exp(2-5)+exp(5-5) = 0.018 + 0.135 + 0.050 + 1 = 1.203. Identical. Not an approximation — exactly the same number, computed without ever holding all four scores at once.

🔢 By hand
Put both totals on the board side by side: the streamed ℓ = 1.203 and the all-at-once ℓ = 1.203. Circle them. Say: "Same answer. We never had the whole row on the desk, and softmax came out exact. That is online softmax — and it's the entire reason we can throw the big grid away." This is the jaw-drop moment; let it land before moving on.
The by-hand proof: streaming the softmax in two piles gives the exact same denominator as computing it all at once.
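
If you want students to replay the board example themselves, the two-pile calculation can be sketched in a few lines of plain Python. The function name `streamed_denominator` is made up for this illustration; it tracks only the running max and running sum, exactly as on the board.

```python
import math

def streamed_denominator(piles):
    """Return (m, l): the running max and the running sum of exp(score - m)."""
    m = float("-inf")   # highest score seen so far
    l = 0.0             # running total, measured relative to m
    for pile in piles:
        m_new = max(m, max(pile))
        alpha = math.exp(m - m_new)   # shrink factor: 0.0 on the first pile, 1.0 if the max held
        l = alpha * l + sum(math.exp(s - m_new) for s in pile)
        m = m_new
        print(f"after pile {pile}: m = {m}, l = {l:.3f}")
    return m, l

row = [1, 3, 2, 5]
m, l_streamed = streamed_denominator([row[:2], row[2:]])

# The honest, all-at-once denominator for comparison.
l_direct = sum(math.exp(s - max(row)) for s in row)
print(f"streamed {l_streamed:.3f} vs all-at-once {l_direct:.3f}")
```

The running sum prints as 1.135 after the first pile and 1.203 after the second, matching the board. Try other splits of the same row (one pile of four, four piles of one) to show that the pile boundaries do not matter.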

The real recurrence, built from the sticky notes

Now generalize the by-hand example into the actual loop. One block of the GPU owns one block of query rows (Q_i, say 64 of them) and keeps its three sticky notes on-chip the whole time. It walks across the key/value blocks one at a time. Here is the whole kernel, in plain form:

import numpy as np

def flash_attention_block(Q_i, K, V, B_c=64):
    # One block owns query rows Q_i (B_r x d), resident on-chip.
    B_r, d = Q_i.shape
    m = np.full((B_r, 1), -np.inf)    # running max   (sticky note 1), one per query row
    l = np.zeros((B_r, 1))            # running sum   (sticky note 2), one per query row
    O = np.zeros((B_r, d))            # running blend (sticky note 3), shape (B_r, d)

    for j in range(0, K.shape[0], B_c):            # stream the key/value blocks
        K_j, V_j = K[j:j + B_c], V[j:j + B_c]      # small tiles -> on-chip
        S = Q_i @ K_j.T                            # (B_r, B_c) scores, stays on-chip
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))  # did this tile raise the max?
        P = np.exp(S - m_new)                      # exponentials vs the new max
        alpha = np.exp(m - m_new)                  # the shrink factor (0 on the first tile)
        l = alpha * l + P.sum(axis=1, keepdims=True)   # shrink old sum, add new
        O = alpha * O + P @ V_j                    # shrink old blend, add new
        m = m_new

    return O / l    # divide ONCE, at the very end; this (B_r x d) block of O is the only write

Point at each line and match it to a sticky note. The S tile is small (B_r × B_c, say 64×64), lives on-chip, and is consumed instantly; it is never written to main memory. The alpha line is the "shrink the past to match the present" move, the only line that would look strange to someone who's only written a normal softmax. And the division by ℓ (l in the code) happens exactly once, at the end, because only then do we know the true total.

The recurrence in three panels: scores are tiled, each new key block may raise the max, and we rescale the running sum and output before folding it in.
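
For a classroom check that the recurrence really reproduces ordinary attention, including the blended output and not just the denominator, here is a small pure-Python sketch for a single query row. The function names and the toy numbers are invented for illustration, and like the rest of this chapter it leaves out the usual 1/sqrt(d) score scaling.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention_row(q, keys, values):
    """All-at-once attention for one query row: softmax(q . K^T) @ V."""
    scores = [dot(q, k) for k in keys]            # the full row of scores
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [sum(e * v[c] for e, v in zip(exps, values)) / total
            for c in range(len(values[0]))]

def streamed_attention_row(q, keys, values, block=2):
    """Same result, visiting keys/values one block at a time with three running values."""
    m, l, out = float("-inf"), 0.0, [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [dot(q, k) for k in k_blk]       # one small tile of scores, then discarded
        m_new = max(m, max(scores))
        p = [math.exp(s - m_new) for s in scores]
        alpha = math.exp(m - m_new)               # shrink the past to match the present
        l = alpha * l + sum(p)
        out = [alpha * o + sum(pi * v[c] for pi, v in zip(p, v_blk))
               for c, o in enumerate(out)]
        m = m_new
    return [o / l for o in out]                   # divide once, at the end

q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, 2.0], [0.5, -1.0], [2.0, 1.0], [0.0, 3.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, 0.5]]

reference = naive_attention_row(q, keys, values)
streamed = streamed_attention_row(q, keys, values, block=2)
print("all-at-once:", reference)
print("streamed:   ", streamed)
```

The two printed rows should agree to floating-point noise. Changing `block` to 1 or 5 is a nice follow-up: the answer does not move, only the size of the tile held at any one time.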

What we actually bought

Be honest with students about what the win is and isn't. FlashAttention does the same matrix math as naive attention — in fact a few extra multiplies for the rescales. It does not save FLOPs. The entire win is in memory traffic.

Naive attention moves the N × N matrix across main memory four times: traffic grows like N². FlashAttention moves only Q, K, V, and O (each N × d) once: traffic grows like N × d. The ratio is roughly N / d. At N = 8192 and d = 128, that is about 64× less traffic across the slowest road on the chip.
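
The ledger above can be tabulated with a few lines of Python. This is a sketch of the chapter's simplified counting only (four N × N transfers versus four N × d transfers); the function names are illustrative, and a real kernel re-reads K and V once per query block, so measured ratios will differ from these idealized ones.

```python
def naive_traffic(n: int) -> int:
    """Element transfers in the simplified model: write S, read S, write P, read P."""
    return 4 * n * n

def fused_traffic(n: int, d: int) -> int:
    """Element transfers in the simplified model: read Q, K, V once and write O once."""
    return 4 * n * d

d = 128
for n in (512, 2048, 8192, 32768):
    ratio = naive_traffic(n) / fused_traffic(n, d)
    print(f"N={n:>6}, d={d}: about {ratio:g}x less main-memory traffic")
```

The short-sequence rows make the footnote's point visible: the saving is modest when N is only a few times d and grows linearly with context length.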

The click
The number to write huge on the board: at 8k context, FlashAttention moves roughly 64× fewer bytes across main memory than naive attention, for the exact same answer. And because the memory bottleneck is gone, the fast math units can finally run much closer to their ceiling (an H100's tensor cores peak at roughly 989 trillion dense FP16/BF16 operations per second) instead of idling. A memory-bound layer becomes a compute-bound one. That flip is the whole prize.
The ledger: traffic drops from roughly 4N-squared to 4Nd, dozens of times fewer bytes for long sequences.

1 The multiplier depends heavily on N. At short sequence lengths the N² term is small and fusion barely helps; the gain grows with context length. That's exactly why FlashAttention arrived the same moment models started chasing long context: the two needs met.

FlashAttention-2, in one breath

You don't need to teach FA2 in depth to a first-time audience, but you should be able to say what it added, because someone will ask. FA1 answered where the bytes live. FA2 answered how the work is split up — and it roughly doubled the speed again, taking the forward pass from around 35% of peak to about 70%.

The one-liner for each of its three moves:

  • Do the slow softmax work less often. The exp and rescale run on the GPU's slow math units (~16× slower per operation than the tensor cores). FA2 keeps the output un-normalized and divides by ℓ just once at the end instead of every inner step. Fewer slow instructions blocking the fast matmuls.
  • Parallelize over sequence length. FA1 gave one block to each (batch, head) pair. With batch=1 long-context inference — 32 heads on a 132-SM machine — three-quarters of the GPU sits idle. FA2 hands each query block its own SM, filling the machine.
  • Skip the masked upper triangle. In causal attention each word only looks backward, so half the score grid is thrown away anyway. FA2, with tiles as explicit loop indices, simply never computes it — roughly 2× less matmul work for free.
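
The arithmetic behind the second and third moves can be sketched like this. The 64-row tile size is an assumed value chosen for illustration, the tiles are assumed square, and the helper names are hypothetical; only the 32 heads, 132 SMs, and N = 8192 come from the text above.

```python
import math

def idle_fraction(busy_blocks: int, n_sms: int) -> float:
    """Fraction of SMs with nothing to do when there are fewer blocks than SMs."""
    return max(0.0, 1.0 - busy_blocks / n_sms)

def causal_tiles(n: int, tile: int):
    """Return (tiles_computed, tiles_total) when tiles strictly above the diagonal are skipped."""
    t = math.ceil(n / tile)
    return t * (t + 1) // 2, t * t

heads, n_sms, n, tile = 32, 132, 8192, 64   # tile = 64 is an assumption

# One block per (batch, head) pair at batch = 1.
print(f"one block per head: {idle_fraction(heads, n_sms):.0%} of SMs idle")

# One block per (head, query block) instead.
query_blocks = heads * math.ceil(n / tile)
print(f"one block per query block: {query_blocks} blocks to spread over {n_sms} SMs")

# Causal masking: only the diagonal and lower-triangle tiles are computed.
computed, total = causal_tiles(n, tile)
print(f"causal: {computed} of {total} tiles computed, {total / computed:.2f}x fewer")
```

The causal ratio lands just under 2× rather than exactly 2× because the tiles on the diagonal still have to be computed and then partially masked.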
🏭 In production today
This is not a paper on a shelf. FlashAttention is in essentially every serving stack shipping today — vLLM, the reference transformer libraries, every long-context model you've used. When DeepSeek or Meta serve a model to millions, the fused-attention kernel is a direct line item on the electricity bill and the GPU count. Your students are learning the exact kernel that made long context affordable — and the Hopper rewrite (with TMA bulk copies and wgmma async tensor-core instructions) is where the frontier keeps pushing on H100 and B200 today.

Teaching notes: the board plan

Here's the order that works. Don't deviate — each step sets up the next.

🎓 Teaching note
Board sequence: (1) Draw naive attention's three lines and the big N x N S in the middle. (2) Draw the four round-trips to main memory and say "four trips, a quarter-gig each, for a matrix we don't even keep." That's the problem. (3) State the obstacle: "softmax needs the whole row" — let it feel impossible for a beat. (4) Introduce the grading-exams metaphor with three sticky notes and act it out — pull piles off a stack. (5) ONLY THEN do the four-number by-hand example and show the streamed total equals the all-at-once total. That equality is the payoff; pause on it. (6) Generalize to the recurrence. (7) Write the 64x traffic number huge and tie to production. Metaphor and by-hand number come BEFORE the algebra, always.
▶️ Live demo
The one live demo: in a notebook, run naive attention and FlashAttention on a long sequence, and print peak GPU memory for each. (torch.nn.functional.scaled_dot_product_attention dispatches to a fused kernel when the inputs qualify, for example fp16 or bf16 tensors on a CUDA GPU.) Naive will allocate the giant N × N buffer; fused won't. Watch the peak-memory number drop by an order of magnitude or more at long sequence lengths while the outputs match to floating-point noise. Same answer, a fraction of the memory: that's the whole chapter in one cell.
⚠️ Where students trip
Two confusions to pre-empt. First: "Isn't the streamed softmax just an approximation?" No — walk them back to the by-hand example where both totals were 1.203 exactly. It is algebraically identical. Second: "So we made the math faster?" No — same math, even a few extra multiplies. We made the memory movement smaller. Keep hammering: FlashAttention is a logistics win, not an arithmetic one. That distinction is the mark of a student who truly gets it.

You can now teach

  • Why naive attention is a memory disaster — the N × N score grid written and re-read four times, quadratic in N, while the softmax does almost no real math.
  • The obstacle — softmax needs the whole row (its max and its sum) — and why that makes fusion look impossible.
  • Online softmax as grading exams pile by pile — three sticky notes (running max, running sum, running blend) and the one clever "shrink the past to match the present" rescale.
  • The by-hand proof — a four-number, two-pile softmax that comes out exactly equal to the all-at-once answer without ever holding the whole row.
  • What was actually bought — same FLOPs, roughly N/d (~64×) less main-memory traffic, flipping a memory-bound layer into a compute-bound one.
  • The production stakes and FA2 in one breath — where FlashAttention runs today, and the three FA2 moves (defer the divide, parallelize over sequence length, skip the causal upper triangle).