Vizuara Kernel Engineering
Mentor Handbook · 04 Teaching Inference Kernels

Teaching fusion: stop driving to the warehouse

By the end of this chapter you'll be able to stand at a whiteboard and teach operator fusion — the single highest-leverage optimization in AI inference — so plainly that a student who has never heard the word "kernel" will understand why merging a few tiny operations together can make a model run twice as fast, without changing a single number the model computes. You don't need to know any CUDA to teach this well. You need one warehouse, one honest count of trips, and the discipline to keep saying: we didn't do less math — we drove to the warehouse fewer times.

This is the trick FlashAttention is built on. Own it.

The one-sentence answer

Between the big, expensive matrix multiplies in a neural network, there is always a chain of tiny, cheap operations — add a bias, apply an activation like GELU, add the residual, normalize. Each of these is trivial arithmetic. The naive way to run them is one at a time, and each one reads the entire data tensor out of the GPU's slow main memory and writes it straight back. That back-and-forth is the whole cost. Fusion means: do the whole chain of tiny operations in one go, while the data is sitting in the fast on-chip registers, and only touch the slow memory at the very start and the very end.

Same math. Far fewer trips. Often a clean 2× speedup for free.

🧠 Metaphor
The warehouse. Imagine you're assembling a gift box, and you have four tiny jobs to do to each item: add a ribbon, add a label, add tissue paper, add a bow. The item lives in a giant warehouse across town (that's HBM — the GPU's big, slow main memory). The naive worker drives to the warehouse, picks up the item, ties the ribbon, drives it back to the warehouse and drops it off — then immediately drives back out, picks up the same item, adds the label, drops it off again... four separate round-trips for four tiny jobs. The fused worker drives out once, does all four jobs on the loading dock while holding the item, and drives it back once. The ribbon-tying didn't get faster. The driving got 4× cheaper. On a GPU, the driving is the whole bill.
The core metaphor: unfused ops make a full round-trip to slow memory for each tiny job; fusion does them all in one tripfigure rendering · The core metaphor: unfused ops make a full round-trip to slow memory f
The core metaphor: unfused ops make a full round-trip to slow memory for each tiny job; fusion does them all in one trip.
🎓 Teaching note
Draw the warehouse and the winding road first, and act out the driving. Physically walk back and forth across the front of the room four times, sighing, for the unfused version — then walk across once, briskly, for the fused version. The comedy of all that pointless walking is what makes it land. Only after the picture is drawn do you write the word "HBM" and say it means "the GPU's big main memory — huge, but slow and far from the math." Never lead with the acronym.

Where the tiny operations come from

Students need to believe these little ops are everywhere, not a contrived example. A single transformer layer isn't just matmuls — it's matmuls wrapped in small element-wise operations. After a linear layer you add a bias. Then an activation function (GELU, SiLU). Then the residual connection. Before the next matmul you normalize (RMSNorm). Each of those touches every single number in the activation tensor, does almost no arithmetic to it, and — done naively — drags the whole tensor through slow memory to do it.

🔢 By hand
Say it as a chain out loud: "y = gelu(bias + x). That's two tiny ops — an add and a GELU — sitting between two big matmuls. Looks innocent. Written as two separate library calls, watch what the GPU actually does." Then walk the four trips on the board (next section). The gap between how tiny the math looks and how expensive the movement is — that's the whole lesson.

Count the trips — the by-hand number

This is the heart of the chapter, and it's just counting. Take y = gelu(bias + x) written as two separate operations, one kernel each. (A kernel is one launch of work on the GPU — one "trip.") Track how many times the full tensor crosses the slow memory.

  • Kernel 1 (bias add): reads the whole tensor from memory (trip 1), adds the bias, writes the whole tensor back (trip 2).
  • Kernel 2 (GELU): reads that same tensor back from memory (trip 3), applies GELU, writes it out again (trip 4).

That's four full passes over the tensor. Now look hard at trips 2 and 3 — the write at the end of kernel 1 and the read at the start of kernel 2. The intermediate result bias + x never needed to be seen by anyone except the very next op. We paid to store it in the slowest memory on the chip and then immediately paid again to fetch it right back. Those two trips exist for exactly one reason: the two ops live in separate kernels.

Fuse them into one kernel and it becomes: read once (trip 1), do the add and the GELU while the data sits in registers, write once (trip 2). Two passes instead of four. Half the memory traffic. About 2× faster.

The click
Here's the number that makes the room go quiet: x.cos().cos() takes almost exactly the same wall-clock time as x.cos() alone. Two cosines, same time as one. Why? Both read the tensor once and write it once — and the second cosine happens for free on data that's already in a register. Then push it further: "This is why every activation function costs about the same. GELU does far more arithmetic than ReLU — and they benchmark identically. Neither is limited by its math. Both are limited by the two trips to slow memory that bracket them." Students never forget this one.
The unfused chain pays for the intermediate tensor twice; fusion deletes both of those trips.figure rendering · The unfused chain pays for the intermediate tensor twice; fusion delet
The unfused chain pays for the intermediate tensor twice; fusion deletes both of those trips.

Why fusion saves movement, not math

This is the subtle point that separates a mentor who really gets it from one who's reciting. Fusion does not do less arithmetic. The fused kernel runs the exact same add and the exact same GELU as the two unfused kernels. What it removes is bytes moved through slow memory. And for these tiny element-wise ops, bytes moved is the only thing that ever mattered.

Here's the honest reason, and it's worth putting one real number on the board. On an NVIDIA H100 GPU, the math units can chew through roughly 989 trillion operations per second, but the pipe from slow memory only delivers about 3.35 trillion bytes per second. So the chip is starving for data: it can do about 295 math operations for every single byte it manages to fetch before the math even becomes the bottleneck. A bias-add does one operation per number while moving 8 bytes (a read and a write). That's an intensity of about 0.1 — roughly three thousand times below what the chip wants. These ops are pure memory movement with a speck of math stapled on.

⚠️ Where students trip
The number-one confusion: a student thinks "if fusion is faster, we must be doing less work / less math." Gently correct it every time: "We do the identical math — every add, every GELU. We just stopped shuttling the data to the slow warehouse and back between them. Fusion is a logistics win, not an arithmetic win." Tie it straight back to the CPU-vs-GPU chapter: the cooks were never the bottleneck; feeding them was. Fusion is one of the best ways to keep them fed.
🏭 In production today
This isn't academic. When DeepSeek, Meta, or OpenAI serve a model to millions of users, fusion is one of the first things their kernels do, on every H100 and B200 in the rack. The reason torch.compile exists in PyTorch is largely to find these chains automatically and emit one fused kernel instead of ten. Every hardware generation, compute grows faster than memory bandwidth — so the memory pipe gets relatively slower, and more of the network falls into the region where fusion is the win. Fusion is the optimization that keeps paying as the hardware improves.
Fusing pointwise ops keeps you memory-bound; the win is halving the byte traffic, not crossing into the math-bound zone.figure rendering · Fusing pointwise ops keeps you memory-bound; the win is halving the by
Fusing pointwise ops keeps you memory-bound; the win is halving the byte traffic, not crossing into the math-bound zone.

The best fusion of all: glue the tiny ops onto the matmul

Now the move that matters most in a real transformer. The biggest win isn't fusing two tiny ops together — it's welding the tiny ops onto the big matmul that produced the data in the first place. This is called a fused epilogue, and it's why serious GPU math libraries expose an "epilogue" hook at all.

Think about what a matmul already does. It computes its answer one tile at a time, and it builds up each output tile in the fast on-chip registers. The very last thing it does is write that finished tile from registers out to slow memory. That final write is unavoidable — the answer has to land somewhere. But the naive linear layer then launches a whole separate kernel that reads the answer back, adds the bias, and writes it again. You just paid two extra trips to add a bias — when the matmul had the answer sitting right there in registers, one instruction away from adding the bias for free.

🎤 Say this at the board
"The matmul is holding your answer in its hands, in the fast registers, about to set it down in the warehouse. The unfused version sets it down, drives away, drives back, picks it up, adds the bias, sets it down again. The fused epilogue says: before you set it down, just add the bias while it's in your hands. One extra instruction, zero extra trips. The tile touches slow memory exactly once, already finished."
A fused epilogue does the bias and activation while the output tile is still in registers, so the result touches slow mefigure rendering · A fused epilogue does the bias and activation while the output tile is
A fused epilogue does the bias and activation while the output tile is still in registers, so the result touches slow memory exactly once.

The code change is almost nothing — after the matmul finishes a tile, instead of just writing it, you transform it first:

// acc[i][j] holds the finished output tile, in registers.
float v = acc[i][j] + bias[col + j];   // fused bias — free, data already here
v = gelu(v);                           // fused activation — also free
C[(row + i) * N + col + j] = v;        // the ONE write we were always going to do

No extra kernel launch, no intermediate tensor, no second read. And there's a mirror-image trick on the input side: an RMSNorm that would normally run as its own kernel before a matmul can be folded into the matmul's loading stage, so the normalized activation never gets written to slow memory at all. Between the fused-in norm on the read side and the fused-out bias-and-activation on the write side, a whole transformer sub-block can collapse from five or six kernels down to essentially "one matmul with decorations."

🏭 In production today
This is exactly the idea behind FlashAttention, the most famous kernel in modern AI. Attention computes softmax(Q·Kᵀ)·V, and the naive version writes a giant N × N attention matrix out to slow memory and reads it back. FlashAttention refuses to ever write that intermediate — it fuses the whole chain and keeps the running result on-chip. Same principle as our warehouse, applied to the most expensive intermediate in the transformer. The entire industry adopted it within months. When your students understand fusion, they understand the beating heart of FlashAttention.

How to see the win before you write a line of code

Teach students the discipline, not just the trick: predict, then measure. Before fusing anything, count the trips. Write down how many times each byte of the activation crosses slow memory in the unfused version, then how many in the fused version. The ratio of those two counts is your predicted speedup — because these kernels are memory-bound, wall-clock time is very nearly proportional to bytes moved.

▶️ Live demo
The one live demo for this block. In PyTorch, time y = x.cos() and then y = x.cos().cos() on a large tensor. They come back nearly identical — the second cosine is free. Then time a linear → bias → gelu chain written as three separate operations, versus the same thing under torch.compile (which fuses it). Predict the speedup by counting passes first (roughly six passes down to two → about 3×), write your prediction on the board, then run it. When the measured number matches your predicted trip-ratio, the room believes you. When it falls a little short, that's a teaching gift — "the compiler probably fused fewer ops than we hoped; let's go find out which."

When the measured speedup matches your trip-count ratio, the student understands the kernel. When it doesn't, they've found something worth knowing.

You can now teach

  • Operator fusion as the warehouse metaphor: tiny ops between the matmuls each make a pointless round-trip to slow memory, and fusion does them all in one trip.
  • Counting the trips on gelu(bias + x) by hand — four passes unfused, two passes fused — and which two trips fusion deletes and why they existed.
  • Why fusion saves movement, not math ("same math, half the trips"), grounded in the H100's ~295-ops-per-byte imbalance and the x.cos().cos() jaw-dropper.
  • The fused epilogue: gluing the bias and activation onto the matmul while the output tile is still in registers, so it touches slow memory exactly once — plus the read-side RMSNorm mirror.
  • The production link: torch.compile fuses these automatically, and FlashAttention is this exact idea applied to attention's giant intermediate.
  • The predict-then-measure discipline: count passes, predict the speedup from the ratio, then confirm it — and treat any gap as the interesting part.