04 Kernels for Inference
Where the GEMM skills meet real LLMs. Fusion, softmax, attention, FlashAttention, the KV cache and the quantized kernels that serve tokens at scale.
→
Prefill vs decode: two different machines
Why the compute-bound prefill (GEMM) and the memory-bound decode (GEMV) demand completely different kernels.
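The compute-bound versus memory-bound split can be made concrete with a rough arithmetic-intensity estimate (FLOPs per byte moved). The sketch below is only a back-of-the-envelope model: the layer sizes (a 4096×4096 weight matrix, a 2048-token prompt) and the 2-byte FP16 element size are illustrative assumptions, and the helper name `arithmetic_intensity` is hypothetical.

```python
def arithmetic_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per byte for an (m x k) @ (k x n) matmul.

    Assumes each operand is read once and the output written once,
    i.e. ideal reuse with no cache effects.
    """
    flops = 2 * m * k * n                       # one multiply + one add per term
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


# Illustrative sizes (assumptions, not taken from any specific model).
prefill = arithmetic_intensity(m=2048, k=4096, n=4096)  # many tokens at once: GEMM
decode = arithmetic_intensity(m=1, k=4096, n=4096)      # one token: GEMV

print(f"prefill: {prefill:.1f} FLOP/byte")
print(f"decode:  {decode:.3f} FLOP/byte")
```

Under these assumptions prefill lands around a thousand FLOPs per byte while decode stays just under one, which is why the same weight matrix behaves like two different workloads.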
→
Operator fusion
The highest-leverage inference optimization: stop round-tripping tensors through HBM between elementwise ops.
→
Softmax from scratch (and online)
The numerically stable, single-pass online softmax: the trick that makes FlashAttention possible.
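As a rough illustration of the online idea, here is a minimal pure-Python sketch, assuming finite inputs and a non-empty list. It tracks the running maximum and the running sum of exponentials together in one pass, rescaling the sum whenever the maximum changes. The final normalization loop is kept separate here for clarity; the function name `online_softmax` is an illustrative choice, not an API from any library.

```python
import math


def online_softmax(xs):
    """Softmax where max and normalizer are computed in one streaming pass."""
    running_max = float("-inf")
    running_sum = 0.0
    for x in xs:
        new_max = max(running_max, x)
        # Rescale the old sum to the new max, then add the new term.
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + math.exp(x - new_max))
        running_max = new_max
    # Normalization pass using the final statistics.
    return [math.exp(x - running_max) / running_sum for x in xs]


probs = online_softmax([1000.0, 1001.0, 1002.0])  # would overflow a naive exp()
print(probs)
```

Because every exponent is taken relative to the current maximum, large logits do not overflow, and the rescaling step is what lets the statistics be updated one block at a time.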
→
RMSNorm & LayerNorm kernels
Reduction + normalize + affine, fused into one pass — the workhorse normalization of modern LLMs.
→
Attention, the naive way
QKᵀ → softmax → multiply by V as three separate kernels, the O(N²) HBM traffic this causes, and why that traffic is the problem to solve.
→
FlashAttention I: tiling attention (FA)
Fusing the whole attention into one kernel with online softmax so the N×N scores never touch HBM.
→
FlashAttention II: better work partitioning (FA2)
Rebalancing work across warps and the sequence dimension for a roughly 2× speedup over FA1.
→
FlashAttention III: Hopper & async (FA3)
Warp specialization, TMA and FP8 on Hopper — attention at the hardware's speed of light.
→
The KV cache & PagedAttention
Why decode is memory-bound, how paging fixes fragmentation, and the kernel that reads a paged cache.
→
Quantization kernels: FP8, INT4, W4A16
Dequantize-in-the-kernel, scale handling, and the low-precision matmuls that halve or quarter memory traffic.
→
The SwiGLU kernel
SiLU(gate) × up, fused: a small kernel that shows up in the MLP of most modern LLMs and in many kernel benchmarks.
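For reference, the elementwise math being fused is SiLU applied to the gate projection, multiplied by the up projection. The sketch below is a plain-Python reference of that formula only, not a GPU kernel; the names `silu` and `swiglu` are illustrative, and the inputs are assumed to be the already-computed gate and up projections.

```python
import math


def silu(x):
    """SiLU(x) = x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu(gate, up):
    """Elementwise SiLU(gate) * up in one loop, with no intermediate list."""
    return [silu(g) * u for g, u in zip(gate, up)]


print(swiglu([0.0, 1.0, -2.0], [5.0, 2.0, 3.0]))
```

Computing both steps in one loop mirrors what the fused kernel does: the activated gate is consumed immediately and never written out as a separate tensor.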
→
Batched decode: the GEMV problem
Serving many sequences one token at a time: skinny matmuls, memory-bound reality, and how batching helps.
