kallyaleksiev

This is a compact guide to the main attention forms, the kernel families that power them, and benchmarks of how they behave across prefill / append / decode workloads.

Numbers and code referenced come from the accompanying repo.

History (Algorithms & Variants)

Core Attention Forms

Attention started with the multi-head formulation and subsequently evolved toward more efficient variants, driven mostly by long-context, throughput, and latency requirements. The lineage of the most important forms is as follows:

Notation: \( X\in\mathbb{R}^{L\times d_{\text{model}}} \) (sequence length \( L \), model dim \( d_{\text{model}} \)); heads \( h \); per-head dims \( d_k = d_v = d_{\text{model}}/h \).

$$ \mathrm{Attn}(Q,K,V)=\operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$

  • MHA (Multi-Head Attention)

This is the original form, introduced in the Attention Is All You Need paper and later adapted to decoder-only models. Each head has its own Q/K/V projections.

Computes:

$$ Q_i = X W_i^{Q}, \quad K_i = X W_i^{K},\quad V_i = X W_i^{V},\quad i=1,\dots,h $$

$$ H_i = \mathrm{Attn}(Q_i,K_i,V_i) $$

$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$

KV-cache: \( h \times T \times d_k \).
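As a concrete reference, here is a minimal PyTorch sketch of the MHA computation above (shapes only, no causal mask or dropout); the function and weight names are illustrative and not tied to any particular implementation.

```python
import torch

def mha(x, w_q, w_k, w_v, w_o, h):
    """Minimal multi-head attention. x: (B, L, d_model); each w_*: (d_model, d_model)."""
    B, L, d_model = x.shape
    d_k = d_model // h
    # Project and split into h heads: (B, h, L, d_k)
    q = (x @ w_q).view(B, L, h, d_k).transpose(1, 2)
    k = (x @ w_k).view(B, L, h, d_k).transpose(1, 2)
    v = (x @ w_v).view(B, L, h, d_k).transpose(1, 2)
    # Scaled dot-product attention per head
    scores = q @ k.transpose(-2, -1) / d_k**0.5        # (B, h, L, L)
    heads = torch.softmax(scores, dim=-1) @ v          # (B, h, L, d_k)
    # Concatenate heads and apply the output projection
    return heads.transpose(1, 2).reshape(B, L, d_model) @ w_o
```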

  • MQA (Multi-Query Attention)

This introduces a single K/V projection shared by all heads. The benefit is a significantly smaller KV cache, which is great for memory-bandwidth-bound decode at batch-1 and for longer-context workloads. It computes:

$$ Q_i = X W_i^{Q},\quad K_\star = X W^{K},\quad V_\star = X W^{V} $$

$$ H_i = \mathrm{Attn}(Q_i,K_\star,V_\star) $$

$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$

KV-cache: \( T \times d_k \).
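To make the cache savings concrete, here is a back-of-the-envelope comparison of per-layer KV-cache size for MHA vs MQA. The shapes (32 query heads, head dim 128, bfloat16, 4096 cached tokens) are purely illustrative, roughly in the range of a Llama-3-8B-class attention layer.

```python
h, d_k, T = 32, 128, 4096          # heads, head dim, cached tokens (illustrative)
bytes_per_elem = 2                 # bfloat16
kv = 2                             # one K tensor and one V tensor

mha_bytes = kv * h * T * d_k * bytes_per_elem   # per layer, per sequence
mqa_bytes = kv * 1 * T * d_k * bytes_per_elem   # single shared K/V head

print(f"MHA: {mha_bytes / 2**20:.1f} MiB, MQA: {mqa_bytes / 2**20:.1f} MiB per layer")
# MHA: 64.0 MiB, MQA: 2.0 MiB per layer -> a reduction by a factor of h (32x here)
```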

  • GQA (Group-Query Attention)

Middle ground: K/V projections are shared within each group of query heads (group size \( g \), so \( G = h/g \) K/V groups). It trades off between MHA quality and MQA efficiency. It computes:

$$ Q_i = X W_i^{Q},\quad K_{(j)} = X W^{K}_{(j)},\quad V_{(j)} = X W^{V}_{(j)},\quad j=1,\dots,G $$

$$ H_i = \mathrm{Attn}\big(Q_i,\,K_{(j(i))},\,V_{(j(i))}\big) $$

$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$

KV-cache: \( \frac{h}{g} \times T \times d_k = G \times T \times d_k \).
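A minimal GQA sketch, assuming \( h \) query heads and \( G = h/g \) KV heads: each KV head is simply repeated for the \( g \) query heads in its group, so the per-head attention itself is unchanged. MQA is the \( G = 1 \) special case and MHA the \( G = h \) case.

```python
import torch

def gqa(q, k, v):
    """q: (B, h, L, d_k); k, v: (B, G, L, d_k) with h divisible by G."""
    h, G = q.shape[1], k.shape[1]
    # Repeat each KV head for the g = h // G query heads in its group
    k = k.repeat_interleave(h // G, dim=1)              # (B, h, L, d_k)
    v = v.repeat_interleave(h // G, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1]**0.5
    return torch.softmax(scores, dim=-1) @ v            # (B, h, L, d_k)
```

Newer PyTorch releases can also do this KV-head broadcasting for you (e.g. via an enable_gqa flag on scaled_dot_product_attention), so the explicit repeat above is mostly illustrative.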

  • MLA (Multi-head Latent Attention)

This form was introduced by DeepSeek for their V2 architecture. Its benefit is that it reduces KV-cache size even further while preserving query expressivity. The assumption it rests on is that the key/value heads are redundant, in the sense that all of their semantic information can be stored in a lower-dimensional latent space. It computes:

$$ Q_i = X W_i^{Q} U_{Q} \in \mathbb{R}^{T\times r},\qquad K_\ell = X U_{K} \in \mathbb{R}^{T\times r},\qquad V_\ell = X U_{V} \in \mathbb{R}^{T\times r} $$

$$ H_i = \operatorname{Attn}(Q_i, K_\ell, V_\ell) \in \mathbb{R}^{T\times r} $$

$$ Y = \operatorname{Concat}(H_1,\dots,H_h)\,W^{O} $$

Where: \( W_i^{Q}\in\mathbb{R}^{d_{\text{model}}\times d_k} \), \( U_Q\in\mathbb{R}^{d_k\times r} \) (reduces queries to latent dim), \( U_K,\ U_V\in\mathbb{R}^{d_{\text{model}}\times r} \) (project keys/values to latent dim).

KV-cache: \( T \times r \) (or \( \frac{h}{g} \times T \times r \) if latents are shared per group of \( g \) heads).
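Here is a simplified sketch that follows the equations above; it is not DeepSeek's full formulation, which additionally has a decoupled RoPE path and up-projections from the cached latent back to per-head keys/values. All tensor names are illustrative.

```python
import torch

def mla_simplified(x, w_q, u_q, u_k, u_v, w_o):
    """Simplified latent attention (no RoPE decoupling).
    x: (B, T, d_model); w_q: (h, d_model, d_k); u_q: (d_k, r);
    u_k, u_v: (d_model, r); w_o: (h * r, d_model)."""
    B, T, _ = x.shape
    h = w_q.shape[0]
    r = u_k.shape[-1]
    # Per-head queries projected down to the latent dimension: (B, h, T, r)
    q = torch.einsum("btd,hdk,kr->bhtr", x, w_q, u_q)
    # Shared latent keys/values: this (B, T, r) pair is all that needs caching
    k = x @ u_k
    v = x @ u_v
    # Scale by the latent dim here; the right scale depends on the exact formulation
    scores = q @ k.transpose(-2, -1).unsqueeze(1) / r**0.5   # (B, h, T, T)
    heads = torch.softmax(scores, dim=-1) @ v.unsqueeze(1)   # (B, h, T, r)
    return heads.transpose(1, 2).reshape(B, T, h * r) @ w_o
```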


Kernel Variants

  • FlashAttention (v1→v3)

    • FlashAttention v1 introduced an exact attention computation that never materialises the full \( L \times L \) score matrix in memory. At its core it processes I/O-aware tiles of Q/K/V in on-chip SRAM and performs the softmax online across tiles, so the extra memory stays roughly linear in sequence length (a sketch of the idea follows this list)
    • FlashAttention v2 introduces mechanical improvements to the original algorithm, such as better work partitioning and removing repeated non-matmul work, for a considerable speedup
    • FlashAttention v3 makes use of Hopper-specific features (e.g. TMA and warp specialisation) for a better distribution of work; it achieves roughly 75% of peak FLOPs utilisation on Hopper, compared to about 35% for FlashAttention v2
  • Flash Decoding

    • Flash Decoding optimises attention for the Q ≪ KV regime (decode and small-q_chunk append): it splits the work along the KV length so the GPU stays busy even with a single query token, computes an online (log-sum-exp) softmax per split, and then merges the partial results in a small reduction step. It targets memory-bound decoding, boosting batch-1 and small-batch throughput.
  • Paged Attention

    • Paged Attention, first introduced in vLLM, is the de-facto standard for inference engines with high-throughput requirements
    • It is a memory-management technique which stores the KV cache of each request in a set of fixed-size pages, similar to how an operating system allocates memory for processes: each request is allocated pages dynamically, as it needs them (a toy sketch of the bookkeeping follows this list)
    • Before paged attention, the standard was to pre-allocate a contiguous region equal to the maximum that could be needed (the full context length), which each request might or might not fill
    • That leads to external fragmentation, because there is no guarantee about where each request's contiguous region will land, and to internal fragmentation, because most requests do not use the whole context
    • Paged attention eliminates external fragmentation entirely, because memory is handed out as fixed-size pages (each page typically holds several slots, i.e. the KV cache for a certain number of tokens), and it bounds internal fragmentation to at most one page per active request: KV-cache size per token × slots per page × number of active requests
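To make the tiling idea concrete, here is a small educational sketch of online-softmax attention over KV tiles for a single head. It mirrors the running-max / running-sum recurrence that FlashAttention builds on, but it is plain PyTorch reading from HBM, not an optimised fused kernel; float32 inputs are assumed for simplicity.

```python
import torch

def tiled_attention(q, k, v, tile=128):
    """Exact attention for one head, computed over KV tiles of size `tile`.
    q: (Lq, d); k, v: (Lk, d). Never materialises the full (Lq, Lk) matrix."""
    Lq, d = q.shape
    out = torch.zeros(Lq, v.shape[-1], device=q.device)
    m = torch.full((Lq, 1), float("-inf"), device=q.device)  # running row-max of the logits
    s = torch.zeros(Lq, 1, device=q.device)                  # running softmax denominator
    for start in range(0, k.shape[0], tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        logits = q @ kt.T / d**0.5                 # only a (Lq, tile) block at a time
        m_new = torch.maximum(m, logits.max(dim=-1, keepdim=True).values)
        scale = torch.exp(m - m_new)               # rescale previously accumulated results
        p = torch.exp(logits - m_new)
        s = s * scale + p.sum(dim=-1, keepdim=True)
        out = out * scale + p @ vt
        m = m_new
    return out / s

# sanity check against the dense reference:
# ref = torch.softmax(q @ k.T / q.shape[-1]**0.5, dim=-1) @ v
# torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5)
```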
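And here is a toy sketch of the paged bookkeeping itself: a per-request page table plus a global free list, with pages handed out on demand. It only illustrates the allocation pattern, not vLLM's actual implementation (there is no eviction or preemption handling here).

```python
class PagedKVCache:
    """Toy page allocator: each request gets pages of `slots_per_page` token slots on demand."""

    def __init__(self, num_pages, slots_per_page):
        self.slots_per_page = slots_per_page
        self.free_pages = list(range(num_pages))
        self.page_tables = {}          # request id -> list of physical page ids
        self.lengths = {}              # request id -> number of tokens cached so far

    def append_token(self, req_id):
        """Reserve a slot for one new token; allocate a fresh page only when needed."""
        table = self.page_tables.setdefault(req_id, [])
        length = self.lengths.get(req_id, 0)
        if length % self.slots_per_page == 0:           # current page is full (or none yet)
            table.append(self.free_pages.pop())
        self.lengths[req_id] = length + 1
        page = table[length // self.slots_per_page]
        return page, length % self.slots_per_page       # (physical page, slot within page)

    def free(self, req_id):
        """Return all of a finished request's pages to the free list."""
        self.free_pages.extend(self.page_tables.pop(req_id, []))
        self.lengths.pop(req_id, None)
```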

Other Variants

  • Sliding-Window Attention (SWA) — Restricts each query to attend only to the most recent W tokens (a banded causal mask), dropping cost from O(L²) to O(L·W) for long contexts while keeping KV writes linear in L — proposed for better long-context capabilities (a minimal mask sketch follows this list)

  • Ring Attention — Shards the sequence (and thus the KV cache) across devices and circulates KV blocks around a device ring while each device accumulates partial attention outputs for its local queries, trading HBM capacity for interconnect bandwidth to extend effective context or fit larger models.
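For sliding-window attention, the only change relative to dense causal attention is the mask. Below is a minimal sketch of the banded causal mask for window size W; building it explicitly is purely for illustration, since real kernels never materialise it.

```python
import torch

def sliding_window_mask(L, W):
    """Boolean (L, L) mask: query i may attend keys j with i - W < j <= i."""
    i = torch.arange(L).unsqueeze(1)
    j = torch.arange(L).unsqueeze(0)
    return (j <= i) & (j > i - W)

# Can be passed as a boolean attn_mask to scaled_dot_product_attention,
# where True marks positions that participate in attention.
```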

Workload Phases

Usually there are three workload types when discussing LLM inference, each with slightly different computational requirements:

  • Prefill

    • initial phase, create KV cache for a prompt (Q=KV=L)
    • High arithmetic intensity, so compute bound
  • Append

    • Forward pass for q_chunk tokens at a time (1≪q≪L).
  • Decode

    • One token at a time autoregressive generation (Q=1)
    • Memory-bandwidth bound; KV reads dominate (a rough arithmetic-intensity comparison follows this list)
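A back-of-the-envelope comparison makes the contrast concrete. Assuming per-head attention FLOPs of roughly 4·L_q·L_kv·d (the QKᵀ and PV matmuls) and K/V read once in bfloat16, the arithmetic intensity collapses from about L_q FLOPs per byte in prefill to about 1 in decode; the shapes below are illustrative only.

```python
def attn_intensity(L_q, L_kv, d=128, bytes_per_elem=2):
    """Very rough arithmetic intensity (FLOPs per KV byte read) for one head."""
    flops = 4 * L_q * L_kv * d                 # two matmuls of ~2 * L_q * L_kv * d each
    kv_bytes = 2 * L_kv * d * bytes_per_elem   # read K and V once, bfloat16
    return flops / kv_bytes

print(f"prefill (L_q = L_kv = 4096): {attn_intensity(4096, 4096):.0f} FLOPs/byte")
print(f"decode  (L_q = 1, L_kv = 4096): {attn_intensity(1, 4096):.0f} FLOPs/byte")
# prefill ~4096 FLOPs/byte (compute-bound); decode ~1 FLOP/byte (memory-bandwidth-bound)
```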

This FlashInfer blog post does a great job of analysing the requirements of the different workloads.

Frameworks

There are quite a few ways to compute attention within a single framework like PyTorch, and there are many dedicated frameworks which evolve quite rapidly.

In our benchmarking code we test the following approaches:

  • Barebones PyTorch (naïve) – uses non-fused torch operators and computes the attention op by op
  • torch.compile – uses the torch.compile feature, which in this case helps by fusing the eager ops and producing better underlying kernels
  • PyTorch SDPA – the dedicated torch scaled-dot-product attention op
  • PyTorch FlexAttention – pattern‑based attention API
  • Transformer Engine (TE) – NVIDIA kernels introduced to exploit the Hopper architecture to the maximum (e.g. efficient use of tensor cores)
  • flash‑attn – the original FlashAttention library
  • flash‑infer – FlashInfer, a heavily optimised attention kernel library targeted at LLM serving

These frameworks evolve very quickly as workloads and state-of-the-art techniques change, so it's not trivial to conclude which one is best for a particular use case, and the relative performance ordering can change rapidly when a new data type, library version, or workload is introduced.

Some frameworks, like torch SDPA, even have a routing mechanism that tries to dispatch operations to the most efficient available backend (see the sketch below).
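For example, recent PyTorch releases let you pin SDPA to a particular backend via a context manager; the exact module path and enum names have shifted across versions, so treat this as a sketch rather than a pinned API reference.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch >= 2.3; older versions used torch.backends.cuda.sdp_kernel

q = k = v = torch.randn(1, 8, 1024, 128, device="cuda", dtype=torch.bfloat16)

# Restrict dispatch to the FlashAttention backend; SDPA errors out if it cannot honour the constraint
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```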

Framework modules can be found in the repo here.

Testing Framework (what we measure)

Right now we test in bfloat16 and use GQA variants.

We use the following architectures as testing templates: Llama‑3.1‑8B, Qwen3‑30B‑A3B, and Qwen3‑235B‑A22B.

Each experiment is defined with its batch size, sequence lengths, and architecture shapes like this:

```python
from dataclasses import dataclass

# WorkloadType is an enum defined in the repo (prefill / append / decode)

@dataclass
class ExperimentSpec:
    workload_type: WorkloadType

    batch_size: int

    q_seq_len: int        # number of query tokens per sequence
    kv_seq_len: int       # number of cached KV tokens per sequence

    num_q_heads: int      # architecture shape: query heads
    num_kv_heads: int     # architecture shape: KV heads (GQA groups)
    head_dim: int
```

The repository contains a top-level run.py script which measures the performance of a given framework on a set of experiments.

We report latency, throughput, and TFLOPs/s.

Results (high‑level takeaways)

For all hardware we use bfloat16.

Here are the results for batch_size=256, seq_len=4096, q_chunk=128 (for append); kernel‑only p95 latency:

H200

| Framework | Llama-3-8B prefill | Llama-3-8B append | Llama-3-8B decode | Qwen3-30B prefill | Qwen3-30B append | Qwen3-30B decode | Qwen3-235B prefill | Qwen3-235B append | Qwen3-235B decode |
|---|---|---|---|---|---|---|---|---|---|
| flash attn | 188.00 | 6.59 | 1.61 | 188.42 | 6.67 | 0.90 | 377.61 | 13.36 | 0.85 |
| flashinfer | 139.01 | 7.23 | 1.62 | 141.16 | 7.24 | 1.52 | 278.40 | 12.63 | ERROR |
| te | 117.70 | 3.74 | 1.19 | 117.66 | 3.63 | 0.71 | 239.45 | 7.55 | 0.72 |
| torch flex | OOM | 141.33 | 30.76 | OOM | 140.86 | 30.30 | OOM | OOM | 58.97 |
| torch compile | 217.03 | 14.93 | 6.54 | 213.95 | 13.06 | 6.23 | 432.18 | 25.57 | 9.95 |
| torch naive | OOM | 43.07 | 21.34 | OOM | 42.58 | 20.93 | OOM | OOM | 41.19 |
| torch sdpa | OOM | 23.52 | 23.38 | OOM | 23.11 | 22.95 | OOM | 45.90 | 45.73 |