Attention Guide

This is a compact guide to the main attention forms, the kernel families that power them, and benchmarks of how they behave across prefill / append / decode workloads.

Numbers and code referenced come from the accompanying repo.

History (Algorithms & Variants)

Core Attention Forms

Attention started with the multi-head formulation and subsequently evolved to enable more efficient computation, mostly targeting long context and throughput / latency. The lineage of the most important forms is as follows:

Notation: \( X\in\mathbb{R}^{T\times d_{\text{model}}} \) (sequence length \( T \), model dim \( d_{\text{model}} \)); heads \( h \); per-head dims \( d_k = d_v = d_{\text{model}}/h \).

$$ \mathrm{Attn}(Q,K,V)=\operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$

Multi-Head Attention (MHA)

This is the original form, introduced in the Attention Is All You Need paper and later refined for decoder-only models. Each head has its own Q/K/V projection.

Computes:

$$ Q_i = X W_i^{Q}, \quad K_i = X W_i^{K},\quad V_i = X W_i^{V},\quad i=1,\dots,h $$

$$ H_i = \mathrm{Attn}(Q_i,K_i,V_i) $$

$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$

KV-cache: \( h \times T \times d_k \).
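
As a concrete reference, here is a minimal PyTorch sketch of the multi-head form above (naive, unmasked, no KV-cache; the function and argument names are just for illustration):

import torch
import torch.nn.functional as F

def mha(x, w_q, w_k, w_v, w_o, h):
    # x: (T, d_model); w_q, w_k, w_v, w_o: (d_model, d_model); h: number of heads
    T, d_model = x.shape
    d_k = d_model // h
    # Project, then split into h heads of dim d_k -> (h, T, d_k)
    q = (x @ w_q).view(T, h, d_k).transpose(0, 1)
    k = (x @ w_k).view(T, h, d_k).transpose(0, 1)
    v = (x @ w_v).view(T, h, d_k).transpose(0, 1)
    # softmax(Q K^T / sqrt(d_k)) V, computed independently per head
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5   # (h, T, T)
    heads = F.softmax(scores, dim=-1) @ v           # (h, T, d_k)
    # Concatenate heads and apply the output projection
    return heads.transpose(0, 1).reshape(T, d_model) @ w_o   # (T, d_model)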

Multi-Query Attention (MQA)

MQA shares a single K/V projection across all heads, cutting the KV-cache size by a factor of \( h \). This helps memory-bandwidth-bound decode at batch-1 and long-context workloads. It computes:

$$ Q_i = X W_i^{Q},\quad K_\star = X W^{K},\quad V_\star = X W^{V} $$

$$ H_i = \mathrm{Attn}(Q_i,K_\star,V_\star) $$

$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$

KV-cache: \( T \times d_k \).

Grouped-Query Attention (GQA)

A middle ground: K/V projections are shared per group of query heads (group size \( g \), giving \( G = h/g \) KV groups), trading off MHA quality against MQA efficiency. It computes:

$$ Q_i = X W_i^{Q},\quad K_{(j)} = X W^{K}_{(j)},\quad V_{(j)} = X W^{V}_{(j)},\quad j=1,\dots,G $$

$$ H_i = \mathrm{Attn}\big(Q_i,\,K_{(j(i))},\,V_{(j(i))}\big) $$

$$ Y = \mathrm{Concat}(H_1,\dots,H_h)\,W^{O} $$

KV-cache: \( \frac{h}{g} \times T \times d_k \).
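
The grouped sharing is straightforward to express by broadcasting each K/V head across the query heads in its group. A minimal sketch (illustrative names; it degenerates to MHA when num_kv_heads == num_q_heads and to MQA when num_kv_heads == 1):

import torch
import torch.nn.functional as F

def gqa(q, k, v):
    # q: (num_q_heads, T, d_k); k, v: (num_kv_heads, T, d_k), with num_q_heads % num_kv_heads == 0
    num_q_heads, num_kv_heads = q.shape[0], k.shape[0]
    group = num_q_heads // num_kv_heads
    # Broadcast each shared K/V head to the query heads in its group
    k = k.repeat_interleave(group, dim=0)   # (num_q_heads, T, d_k)
    v = v.repeat_interleave(group, dim=0)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5   # (num_q_heads, T, T)
    return F.softmax(scores, dim=-1) @ v                    # (num_q_heads, T, d_k)

Note that the cache only ever stores the (num_kv_heads, T, d_k) tensors; real kernels broadcast at compute time rather than materialising the repeated copies.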

Multi-Head Latent Attention (MLA)

Latent attention projects keys and values (and, in this simplified form, the per-head queries) into a shared low-rank latent space of dimension \( r \); the cache then stores compact latents rather than full per-head K/V. It computes:

$$ Q_i = X W_i^{Q} U_{Q} \in \mathbb{R}^{T\times r},\qquad K_\ell = X U_{K} \in \mathbb{R}^{T\times r},\qquad V_\ell = X U_{V} \in \mathbb{R}^{T\times r} $$

$$ H_i = \operatorname{Attn}(Q_i, K_\ell, V_\ell) \in \mathbb{R}^{T\times r} $$

$$ Y = \operatorname{Concat}(H_1,\dots,H_h)\,W^{O} $$

Where: \( W_i^{Q}\in\mathbb{R}^{d_{\text{model}}\times d_k} \), \( U_Q\in\mathbb{R}^{d_k\times r} \) (reduces queries to latent dim), \( U_K,\ U_V\in\mathbb{R}^{d_{\text{model}}\times r} \) (project keys/values to latent dim).

KV-cache: \( T \times r \) (or \( \frac{h}{g} \times T \times r \) if latents are shared per group of \( g \) heads).
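
A minimal sketch of the latent form above, under the simplifying assumptions used in these equations: a single shared latent per layer, scaling by \( \sqrt{r} \), and no RoPE handling (which production MLA implementations treat separately):

import torch
import torch.nn.functional as F

def latent_attention(x, w_q, u_q, u_k, u_v, w_o):
    # x: (T, d_model); w_q: (h, d_model, d_k); u_q: (h, d_k, r); u_k, u_v: (d_model, r); w_o: (h*r, d_model)
    h = w_q.shape[0]
    # Per-head queries reduced to the latent dim r: (h, T, r)
    q = torch.einsum('td,hdk,hkr->htr', x, w_q, u_q)
    # Shared latent keys/values, cached as (T, r) regardless of head count
    k_lat = x @ u_k
    v_lat = x @ u_v
    scores = q @ k_lat.transpose(0, 1) / k_lat.shape[-1] ** 0.5   # (h, T, T), scaled by sqrt(r)
    heads = F.softmax(scores, dim=-1) @ v_lat                     # (h, T, r)
    # Concatenate the h latent head outputs and project back to d_model
    return heads.transpose(0, 1).reshape(x.shape[0], -1) @ w_o    # (T, d_model)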


Kernel Variants

Other Variants

Workload Phases

Usually there are three workload types to consider in LLM inference, each with slightly different computational requirements: prefill, append, and decode.

This FlashInfer blog post does a great job of analysing the requirements of the different workloads.

Frameworks

There are quite a few ways to compute attention within a single framework like PyTorch, and there are many dedicated frameworks that evolve quite rapidly.

In our benchmarking code we test the following approaches:

These frameworks evolve quickly as workloads and state-of-the-art techniques change, so it's not trivial to conclude which one is best for a particular use-case, and the relative performance ordering can change rapidly when a new data type, library version, or workload is introduced.

Some frameworks, like torch SDPA, even have a routing mechanism that tries to dispatch each call to the most efficient available backend.
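
For example, PyTorch's scaled_dot_product_attention chooses among several backends (FlashAttention, memory-efficient attention, or a math fallback) based on dtype, shapes, and hardware, and also lets you pin a backend explicitly. A minimal sketch, assuming PyTorch 2.5+ (for sdpa_kernel and enable_gqa) and a CUDA device; the shapes are illustrative:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim), bfloat16 as in the benchmarks; 32 query heads vs. 8 KV heads is a GQA layout
q = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)

# Default routing: SDPA picks whichever backend it considers best for these inputs
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Pin dispatch to the FlashAttention backend only (errors if the inputs aren't supported there)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)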

Framework modules can be found in the repo here.

Testing Framework (what we measure)

Right now we test in bfloat16 and we use GQA variants.

We use the following architectures as testing templates: Llama‑3.1‑8B, Qwen3‑30B‑A3B, and Qwen3‑235B‑A22B.

Each experiment is defined by its batch size, sequence lengths, and architecture shapes, like this:

from dataclasses import dataclass

# WorkloadType is an enum over the three phases (prefill / append / decode), defined elsewhere in the repo.

@dataclass
class ExperimentSpec:
    workload_type: WorkloadType

    batch_size: int

    # Number of query tokens vs. total KV length (q == kv for prefill, q << kv for append, q == 1 for decode)
    q_seq_len: int
    kv_seq_len: int

    # Attention shape of the emulated architecture
    num_q_heads: int
    num_kv_heads: int
    head_dim: int
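
As an illustration, the three workload phases mostly differ in the q_seq_len / kv_seq_len relationship. Hypothetical specs for Llama-3.1-8B-shaped attention (32 query heads, 8 KV heads, head dim 128), mirroring the settings used in the results below; the WorkloadType member names here are assumptions:

# Prefill: the full prompt is processed at once, so q_seq_len == kv_seq_len
prefill = ExperimentSpec(WorkloadType.PREFILL, batch_size=256,
                         q_seq_len=4096, kv_seq_len=4096,
                         num_q_heads=32, num_kv_heads=8, head_dim=128)

# Append: a chunk of new tokens attends to a long existing KV-cache
append = ExperimentSpec(WorkloadType.APPEND, batch_size=256,
                        q_seq_len=128, kv_seq_len=4096,
                        num_q_heads=32, num_kv_heads=8, head_dim=128)

# Decode: one new token per sequence attends to the whole cache
decode = ExperimentSpec(WorkloadType.DECODE, batch_size=256,
                        q_seq_len=1, kv_seq_len=4096,
                        num_q_heads=32, num_kv_heads=8, head_dim=128)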

The repository contains a top-level run.py script that measures the performance of a given framework on a set of experiments.

We report latency, throughput, and TFLOP/s.
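
The TFLOP/s figure follows the usual attention FLOP count: roughly \( 2\,T_q\,T_{kv}\,d_k \) FLOPs for \( QK^\top \) and the same again for the \( PV \) product, per head and per sequence. A hypothetical helper in that spirit (the exact accounting in run.py may differ, e.g. for causal masking):

def attention_tflops_per_s(spec: ExperimentSpec, latency_s: float) -> float:
    # 2*q*kv*d FLOPs for Q@K^T plus 2*q*kv*d for P@V, per head, per sequence in the batch
    flops = 4 * spec.batch_size * spec.num_q_heads * spec.q_seq_len * spec.kv_seq_len * spec.head_dim
    return flops / latency_s / 1e12  # achieved TFLOP/s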

Results (high‑level takeaways)

For all hardware we use bfloat16.

Here are the results for batch_size=256, seq_len=4096, q_chunk=128 (for append); kernel‑only p95 latency:

H200

| Framework | Llama 3-8B prefill | Llama 3-8B append | Llama 3-8B decode | Qwen3-30B prefill | Qwen3-30B append | Qwen3-30B decode | Qwen3-235B prefill | Qwen3-235B append | Qwen3-235B decode |
|---|---|---|---|---|---|---|---|---|---|
| flash attn | 188.00 | 6.59 | 1.61 | 188.42 | 6.67 | 0.90 | 377.61 | 13.36 | 0.85 |
| flashinfer | 139.01 | 7.23 | 1.62 | 141.16 | 7.24 | 1.52 | 278.40 | 12.63 | ERROR |
| te | 117.70 | 3.74 | 1.19 | 117.66 | 3.63 | 0.71 | 239.45 | 7.55 | 0.72 |
| torch flex | OOM | 141.33 | 30.76 | OOM | 140.86 | 30.30 | OOM | OOM | 58.97 |
| torch compile | 217.03 | 14.93 | 6.54 | 213.95 | 13.06 | 6.23 | 432.18 | 25.57 | 9.95 |
| torch naive | OOM | 43.07 | 21.34 | OOM | 42.58 | 20.93 | OOM | OOM | 41.19 |
| torch sdpa | OOM | 23.52 | 23.38 | OOM | 23.11 | 22.95 | OOM | 45.90 | 45.73 |