Attention Guide
This is a compact guide to the main attention forms, the kernel families that power them, and benchmarks of how they behave across prefill / append / decode workloads.
Numbers and code referenced come from the accompanying repo.
History (Algorithms & Variants)
Core Attention Forms
Attention started with the multi-head formulation and subsequently evolved towards more efficient variants, mostly targeting long context and throughput / latency. The lineage of the most important forms is as follows:
Notation: \( X\in\mathbb{R}^{L\times d_{\text{model}}} \) (sequence length \( L \), model dim \( d_{\text{model}} \)); heads \( h \); per-head dims \( d_k = d_v = d_{\text{model}}/h \); \( T \) denotes the number of cached tokens.
$$ \mathrm{Attn}(Q,K,V)=\operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$
- MHA (Multi-Head Attention)
This is the original form, introduced in the Attention Is All You Need paper and later adapted for decoder-only models. Each head has its own Q/K/V projections.
Computes:
$$ Q_i = X W_i^{Q}, \quad K_i = X W_i^{K},\quad V_i = X W_i^{V},\quad i=1,\dots,h $$
$$ H_i = \mathrm{Attn}(Q_i,K_i,V_i) $$
$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$
KV-cache: \( h \times T \times d_k \).
- MQA (Multi-Query Attention)
This introduces shared K/V projections for all heads. The benefit is cutting KV cache size significantly, which is great for memory-bandwidth-bound decode at batch-1 and longer context workloads. It computes:
$$ Q_i = X W_i^{Q},\quad K_\star = X W^{K},\quad V_\star = X W^{V} $$
$$ H_i = \mathrm{Attn}(Q_i,K_\star,V_\star) $$
$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$
KV-cache: \( T \times d_k \).
- GQA (Grouped-Query Attention)
Middle ground: K/V projections are shared per group of query heads (group size \( g \), i.e. \( G = h/g \) groups), trading off between MHA quality and MQA efficiency. It computes:
$$ Q_i = X W_i^{Q},\quad K_{(j)} = X W^{K}_{(j)},\quad V_{(j)} = X W^{V}_{(j)},\quad j=1,\dots,G $$
$$ H_i = \mathrm{Attn}\big(Q_i,\,K_{(j(i))},\,V_{(j(i))}\big) $$
$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$
KV-cache: \( \frac{h}{g} \times T \times d_k \).
- MLA (Multi-head Latent Attention) This form was introduced by DeepSeek for their V2 architecture. It reduces the KV-cache size even further while preserving query expressivity. The assumption it rests on is that the key/value heads are redundant: their semantic information can be stored in a lower-dimensional latent. It computes:
$$ Q_i = X W_i^{Q} U_{Q} \in \mathbb{R}^{T\times r},\qquad K_\ell = X U_{K} \in \mathbb{R}^{T\times r},\qquad V_\ell = X U_{V} \in \mathbb{R}^{T\times r} $$
$$ H_i = \operatorname{Attn}(Q_i, K_\ell, V_\ell) \in \mathbb{R}^{T\times r} $$
$$ Y = \operatorname{Concat}(H_1,\dots,H_h)\,W^{O} $$
Where: \( W_i^{Q}\in\mathbb{R}^{d_{\text{model}}\times d_k} \), \( U_Q\in\mathbb{R}^{d_k\times r} \) (reduces queries to latent dim), \( U_K,\ U_V\in\mathbb{R}^{d_{\text{model}}\times r} \) (project keys/values to latent dim).
KV-cache: \( T \times r \) (or \( \frac{h}{g} \times T \times r \) if latents are shared per group of \( g \) heads); a short sketch comparing these cache footprints follows this list.
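As a quick sanity check on the cache formulas above, here is a minimal sketch comparing per-token footprints. The shapes are illustrative (roughly Llama-3.1-8B-like), the latent dim \( r \) is an assumed value, and the factor of 2 accounts for caching both K and V:

```python
# Per-token, per-layer KV-cache footprint implied by the formulas above.
# Shapes are illustrative (roughly Llama-3.1-8B-like); r is an assumed latent dim.
h = 32             # query heads
g = 4              # GQA group size  -> h // g = 8 KV heads
d_k = 128          # per-head dim
r = 512            # MLA latent dim (illustrative)
bytes_per_el = 2   # bfloat16

mha = 2 * h * d_k * bytes_per_el          # factor of 2: keys and values are both cached
mqa = 2 * 1 * d_k * bytes_per_el
gqa = 2 * (h // g) * d_k * bytes_per_el
mla = r * bytes_per_el                    # a single shared r-dim latent per token (T x r above)

print(f"MHA {mha} B | MQA {mqa} B | GQA {gqa} B | MLA {mla} B  (per token, per layer)")
```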
Kernel Variants
FlashAttention (v1→v3)
- FlashAttention v1 introduced an exact attention computation that never materialises the full \( L \times L \) score matrix in memory. At its core it is I/O-aware: it streams K/V tiles through on-chip SRAM and computes the softmax online over those tiles (tracking a running max and normaliser), using memory roughly linear in sequence length (see the online-softmax sketch after this list).
- FlashAttention v2 introduces mechanical improvements to the original algorithm, such as reducing non-matmul work, parallelising across the sequence dimension, and better work partitioning between warps, for a considerable speedup
- FlashAttention v3 makes use of Hopper-specific features (warp specialisation, TMA, asynchronous Tensor Core instructions) for better distribution of work; it reaches roughly 75% of peak FLOPs utilisation on Hopper, compared to about 35% for FlashAttention v2
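A minimal sketch of the online-softmax tiling idea behind these kernels, written in plain PyTorch for a single head with no masking; the real kernels fuse this loop into on-chip tiles rather than looping in Python:

```python
import torch

def tiled_attention(q, k, v, tile=128):
    """Exact attention computed one KV tile at a time with an online softmax.
    q: (Lq, d), k/v: (Lk, d); single head, no masking, for illustration only."""
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)                             # running (unnormalised) output
    row_max = torch.full((q.shape[0], 1), -float("inf"))  # running max of scores per query
    row_sum = torch.zeros(q.shape[0], 1)                  # running softmax denominator
    for start in range(0, k.shape[0], tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        s = (q @ kt.T) * scale                            # scores for this tile: (Lq, tile)
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)         # rescale previous accumulators
        p = torch.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vt
        row_max = new_max
    return out / row_sum

q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5)
```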
Flash Decoding
- Flash Decoding optimizes attention for the Q≪K regime (decode, or append with a small q_chunk). It splits the work along the KV length, keeps a running log-sum-exp so the partial softmaxes from each split can be merged exactly, and amortises per-token kernel-launch overhead with fused, memory-efficient reads. It is designed for memory-bound decoding and boosts batch-1 / small-batch throughput (see the split-and-merge sketch below).
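A minimal sketch of the split-and-merge idea, again in plain PyTorch for a single query vector with no masking; real kernels process the splits in parallel and fuse the reduction:

```python
import torch

def attn_with_lse(q, k, v):
    """Attention of one query over one KV split, plus the split's log-sum-exp."""
    s = (k @ q) * q.shape[0] ** -0.5          # scores over this split: (Lk_split,)
    lse = torch.logsumexp(s, dim=0)
    return torch.exp(s - lse) @ v, lse        # partial output (d,), scalar lse

def flash_decode(q, k, v, num_splits=4):
    """Split the KV cache, attend to each split independently, merge via LSE weights."""
    outs, lses = zip(*(attn_with_lse(q, ks, vs)
                       for ks, vs in zip(k.chunk(num_splits), v.chunk(num_splits))))
    weights = torch.softmax(torch.stack(lses), dim=0)   # each split's share of the denominator
    return (torch.stack(outs) * weights[:, None]).sum(dim=0)

q = torch.randn(64)                                      # one decode-step query, (d,)
k, v = torch.randn(4096, 64), torch.randn(4096, 64)
ref = torch.softmax((k @ q) * 64 ** -0.5, dim=0) @ v
assert torch.allclose(flash_decode(q, k, v), ref, atol=1e-5)
```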
Paged Attention
- Paged Attention, first introduced in vLLM, is the de-facto standard for inference engines with high-throughput requirements
- It is a memory-management technique which stores the KV cache of each request in a set of fixed-size pages, similar to how an operating system allocates memory for processes. Each request is allocated memory dynamically, as it needs it
- Before paged attention, the standard was to allocate a contiguous region sized for the maximum possible context length, which each request may or may not fill
- This leads to external fragmentation, because there is no guarantee where each request's memory region will sit, and to internal fragmentation, because most requests never use the full context
- Paged attention eliminates external fragmentation entirely, because memory is allocated as fixed-size pages, each typically holding several slots (the KV cache for a certain number of tokens). It also minimises internal fragmentation, upper-bounding the waste to roughly one page per active request, i.e. per-token KV-cache size × slots per page × number of active requests (a toy page-table sketch follows this list)
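A toy sketch of the bookkeeping involved: a free list of physical pages plus a per-request page table mapping token slots to pages. The class, names, and page size are illustrative, not vLLM's actual implementation:

```python
# Toy KV-cache page allocator: a free list of physical pages plus a per-request
# page table. A real engine stores the actual K/V tensors inside each page.
PAGE_SIZE = 16  # token slots per page (illustrative)

class PagedKVCache:
    def __init__(self, num_pages):
        self.free_pages = list(range(num_pages))
        self.page_tables = {}   # request_id -> list of physical page ids
        self.lengths = {}       # request_id -> tokens written so far

    def append_token(self, request_id):
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length % PAGE_SIZE == 0:              # current page full (or first token)
            table.append(self.free_pages.pop())  # grab a new physical page on demand
        self.lengths[request_id] = length + 1
        page, offset = table[length // PAGE_SIZE], length % PAGE_SIZE
        return page, offset                      # where this token's K/V would be written

    def free(self, request_id):
        self.free_pages.extend(self.page_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)

cache = PagedKVCache(num_pages=1024)
print([cache.append_token("req-0") for _ in range(18)])  # spills onto a second page at token 16
```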
Other Variants
Sliding-Window Attention (SWA) – restricts each query to attend only to the most recent W tokens (a banded causal mask), dropping cost from O(L²) to O(L·W) for long contexts and allowing the KV cache to be capped at W entries with a rolling buffer; proposed for cheaper long-context modelling (see the mask sketch after this list)
Ring Attention – shards the sequence across devices and circulates K/V blocks around a device ring/mesh while each device keeps its query shard, trading HBM capacity for interconnect bandwidth to extend effective context or fit larger models.
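A minimal sketch of the banded causal mask SWA uses (plain PyTorch; FlexAttention or a fused kernel would express the same predicate without materialising an L×L mask):

```python
import torch

def sliding_window_causal_mask(L, window):
    """True where query i may attend key j: causal and within the last `window` tokens."""
    i = torch.arange(L)[:, None]
    j = torch.arange(L)[None, :]
    return (j <= i) & (j > i - window)

mask = sliding_window_causal_mask(L=8, window=3)
print(mask.int())
# Each row has at most `window` ones, so scores/work per query are O(window), not O(L).
```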
Workload Phases
There are usually three workload types discussed for LLM inference, each with slightly different computational requirements:
Prefill
- Initial phase: builds the KV cache for the whole prompt (Q=KV=L)
- High arithmetic intensity, so compute bound
Append
- Forward pass over q_chunk tokens at a time (1≪q≪L), e.g. chunked prefill or speculative-decoding verification.
Decode
- One token at a time autoregressive generation (Q=1)
- Memory-bandwidth bound; KV reads dominate (see the arithmetic-intensity sketch below)
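A back-of-the-envelope sketch of why this is so, counting attention FLOPs against the bytes of K/V that must be read (single head, bf16; the shapes and simplifications are assumptions for illustration):

```python
# Rough arithmetic intensity (attention FLOPs per byte of K/V read) for one head.
# FLOPs: 2*q*kv*d for Q@K^T plus 2*q*kv*d for P@V; bytes: K and V, kv*d each, bf16 (2 B).
def attn_intensity(q_len, kv_len, d=128):
    flops = 4 * q_len * kv_len * d
    kv_bytes = 2 * kv_len * d * 2
    return flops / kv_bytes        # simplifies to q_len: intensity scales with query count

for name, q in [("prefill", 4096), ("append", 128), ("decode", 1)]:
    print(f"{name:8s} q={q:5d}  ~{attn_intensity(q, kv_len=4096):.0f} FLOPs per KV byte")

# An H200 delivers on the order of 200 bf16 FLOPs per byte of HBM bandwidth, so decode
# (~1 FLOP/byte) is firmly bandwidth-bound, prefill is compute-bound, and append sits
# near the crossover.
```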
This flash-infer blogpost does a great job of analysing the requirements of the different workloads.
Frameworks
There are quite a lot of ways to compute attention within a single framework like PyTorch, and there are many dedicated frameworks which evolve quite rapidly.
In our benchmarking code we test the following approaches:
- Barebones PyTorch (naïve) – uses non-fused torch operators and computes the attention op by op
- torch.compile – uses torch.compile, which here helps by fusing operations and generating better underlying kernels
- PyTorch SDPA – dedicated torch scaled-dot-product attention op
- PyTorch FlexAttention – PyTorch's programmable attention API, where custom masks and score modifications are compiled into fused kernels
- Transformer Engine (TE) – NVIDIA's library of fused Transformer kernels, built to exploit the Hopper architecture to the maximum (e.g. efficient use of Tensor Cores)
- flash-attn – the original FlashAttention library
- flash-infer – an optimised attention kernel library with dedicated prefill/append/decode kernels and paged-KV support
These frameworks evolve very quickly as workloads and state-of-the-art techniques change, so it's not trivial to conclude which one is best for a particular use-case, and the relative performance ordering can change rapidly when a new data type, library version, or workload is introduced.
Some frameworks, like torch SDPA, even have a routing mechanism that tries to dispatch each call to the most efficient available backend (see the sketch below).
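For instance, torch SDPA chooses between its FlashAttention, memory-efficient, and math backends; with a recent PyTorch (2.3+) you can pin a backend via the sdpa_kernel context manager to compare them directly. A minimal sketch:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # available in recent PyTorch (2.3+)

q, k, v = (torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Let SDPA route to whichever backend it considers best for these inputs ...
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# ... or pin a backend explicitly to compare implementations against each other.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out_mem_eff = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```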
Framework modules can be found in the repo here.
Testing Framework (what we measure)
Right now we test on bfloat16 and we use GQA variants.
We use the following architectures as testing templates: Llama-3.1-8B, Qwen3-30B-A3B, and Qwen3-235B-A22B.
Each experiment is defined with its batch size, sequence lengths, and architecture shapes like this:
from dataclasses import dataclass

@dataclass
class ExperimentSpec:
    workload_type: WorkloadType  # prefill / append / decode (enum defined in the repo)
    batch_size: int
    q_seq_len: int               # query tokens per request (L for prefill, q_chunk for append, 1 for decode)
    kv_seq_len: int              # KV length the queries attend to
    num_q_heads: int
    num_kv_heads: int            # fewer than num_q_heads for the GQA architectures we test
    head_dim: int
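For example, a decode-style spec matching the shapes benchmarked below might look like this; WorkloadType.DECODE and the head counts are assumptions for illustration (the actual enum members and configs live in the repo):

```python
# Hypothetical decode spec matching the batch_size=256, seq_len=4096 runs reported below.
# WorkloadType.DECODE is an assumed enum member; head counts are Llama-3.1-8B-like.
spec = ExperimentSpec(
    workload_type=WorkloadType.DECODE,
    batch_size=256,
    q_seq_len=1,        # one new token per request
    kv_seq_len=4096,    # context already in the KV cache
    num_q_heads=32,
    num_kv_heads=8,     # GQA: 4 query heads share each KV head
    head_dim=128,
)
```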
The repository contains a top-level run.py script which measures the performance of a given framework on a set of experiments.
We report latency, throughput, and TFLOP/s.
Results (high‑level takeaways)
For all hardware we use bfloat16.
Here are the results for batch_size=256, seq_len=4096, q_chunk=128 (for append); each cell is kernel-only p95 latency (lower is better):
H200
| Framework | Llama-3.1-8B prefill | Llama-3.1-8B append | Llama-3.1-8B decode | Qwen3-30B prefill | Qwen3-30B append | Qwen3-30B decode | Qwen3-235B prefill | Qwen3-235B append | Qwen3-235B decode |
|---|---|---|---|---|---|---|---|---|---|
| flash attn | 188.00 | 6.59 | 1.61 | 188.42 | 6.67 | 0.90 | 377.61 | 13.36 | 0.85 |
| flashinfer | 139.01 | 7.23 | 1.62 | 141.16 | 7.24 | 1.52 | 278.40 | 12.63 | ERROR |
| te | 117.70 | 3.74 | 1.19 | 117.66 | 3.63 | 0.71 | 239.45 | 7.55 | 0.72 |
| torch flex | OOM | 141.33 | 30.76 | OOM | 140.86 | 30.30 | OOM | OOM | 58.97 |
| torch compile | 217.03 | 14.93 | 6.54 | 213.95 | 13.06 | 6.23 | 432.18 | 25.57 | 9.95 |
| torch naive | OOM | 43.07 | 21.34 | OOM | 42.58 | 20.93 | OOM | OOM | 41.19 |
| torch sdpa | OOM | 23.52 | 23.38 | OOM | 23.11 | 22.95 | OOM | 45.90 | 45.73 |