Think in **roofline** terms ([All About Rooflines](https://jax-ml.github.io/scaling-book/roofline/)) and in terms of **structure the compiler can’t see**:
- **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling (see the quick arithmetic-intensity sketch below).
- **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help.
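For a quick back-of-the-envelope roofline check, compare an op's arithmetic intensity (FLOPs per byte moved to/from HBM) with the chip's compute-to-bandwidth ratio. The sketch below is illustrative only; the machine-balance number is a placeholder, not a measured TPU spec.

```python
def matmul_arithmetic_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per HBM byte for a dense (m, k) @ (k, n) matmul in bf16."""
    flops = 2 * m * k * n                                  # multiply-adds
    hbm_bytes = bytes_per_elem * (m * k + k * n + m * n)   # read A and B, write C
    return flops / hbm_bytes

# Placeholder machine balance (peak FLOPs / HBM bandwidth), in FLOPs per byte.
machine_balance = 200.0

intensity = matmul_arithmetic_intensity(4096, 4096, 4096)
print(f"arithmetic intensity ≈ {intensity:.0f} FLOPs/byte")
print("compute-limited" if intensity > machine_balance else "bandwidth-limited")
```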
**Know when XLA is enough.** Before writing a custom kernel, always [profile your baseline](#1-high-level-profiling). If a standard operation (like a dense [`jnp.matmul`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html)) is already performing well, the XLA compiler is doing its job. In these cases, a Pallas kernel will increase code complexity and maintenance burden with minimal performance improvement.
For dense, regular GEMMs, XLA’s libraries are hard to beat. The exception is **Mixture-of-Experts (MoE)** MLPs with **ragged token→expert layouts** (some tokens routed to different experts; shapes are irregular). Zero-padding to make dense tensors wastes FLOPs; a custom kernel can operate only on the actually-selected tokens.
- In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions.
**Note:** *Megablox* is an efficient, non-capped MoE implementation in JAX. *Megablocks* refers to the equivalent PyTorch implementation. See [arXiv:2211.15841](https://arxiv.org/abs/2211.15841) for more details.
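To make the ragged-batch idea concrete, here is a conceptual, host-style sketch of a grouped matmul driven by per-expert group sizes. It is illustrative only, not MaxText's Megablox kernel (which performs the same per-expert matmuls as tiled device code); `grouped_matmul` and its arguments are hypothetical names.

```python
import jax.numpy as jnp

def grouped_matmul(tokens, expert_weights, group_sizes):
    """Conceptual grouped matmul over tokens pre-sorted by expert.

    tokens:         [num_tokens, d_model], sorted so each expert's tokens are contiguous
    expert_weights: [num_experts, d_model, d_ff]
    group_sizes:    host-side Python list of ints; tokens routed to each expert
    """
    outputs, start = [], 0
    for e, size in enumerate(group_sizes):
        # Dense matmul only over the tokens actually routed to expert `e`,
        # so no FLOPs are spent on zero-padding.
        outputs.append(tokens[start:start + size] @ expert_weights[e])
        start += size
    return jnp.concatenate(outputs, axis=0)
```

The key point is the metadata: once tokens are sorted by expert and the group sizes are known, every matmul is dense and no work goes to padding.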
### 2. Memory-Access-Bound work (attention)
Attention kernels are classically **bandwidth-limited** if you materialize the full [L,L] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate.

- MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts.
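The core trick can be sketched in plain JAX (a reference implementation for intuition, not the Pallas kernel itself): stream K/V in tiles and keep a running max, denominator, and weighted sum, so the full [L,L] score matrix is never materialized.

```python
import jax.numpy as jnp

def attention_online(q, k, v, block=128):
    """q: [Lq, d]; k, v: [Lk, d]. Processes K/V in tiles of `block` rows."""
    Lq, d = q.shape
    m = jnp.full((Lq, 1), -jnp.inf)   # running max of scores
    l = jnp.zeros((Lq, 1))            # running softmax denominator
    acc = jnp.zeros((Lq, d))          # running weighted sum of V
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / jnp.sqrt(d)                          # [Lq, block] score tile
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        p = jnp.exp(s - m_new)
        correction = jnp.exp(m - m_new)                     # rescale old partial sums
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l
```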
## 🛠️ Pallas kernels in MaxText
To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth-bound or structurally irregular operations that a general-purpose compiler cannot optimize as effectively. Below are the key kernels we use. **Note**: Examples evolve; treat this list as guidance.
- **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation.
- **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.
- **Grouped Matrix Multiplication (Megablox):** Used for Mixture-of-Experts (MoE) layers.

  > This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.

  **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/MaxText/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/moe.py)).
Give the kernel a clear name in traces and capture a profile. Always use [`jax.block_until_ready()`](https://docs.jax.dev/en/latest/_autosummary/jax.block_until_ready.html) when timing your operations.
```python
import jax
from jax import profiler

# Sketch only: `my_op` and `x` are placeholders for your kernel and its inputs.
with profiler.trace("/tmp/jax-trace"), jax.named_scope("my_pallas_kernel"):
    jax.block_until_ready(my_op(x))  # wait for async dispatch to finish inside the trace
```
Pallas exposes the underlying hardware primitives for you to control.
- **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs.
- **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables).
- **Semaphores:** Available for advanced async/barrier patterns in manual pipelines.
- **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions.
- **VPU:** The Vector Processing Unit, used for elementwise/vector work.
**Alignment & Constraints:** Respect TPU BlockSpec constraints (divisibility/shape rules for trailing dimensions and supported block shapes). Start with tile shapes that fit in VMEM and meet these requirements, then sweep different sizes to find the optimum. Let profiling guide you; don't assume powers of two are always best.
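As a minimal illustration (a toy elementwise kernel, not one of MaxText's), the tile shape below is a multiple of the TPU-friendly (8, 128) layout and small enough to fit comfortably in VMEM; it uses the current `pl.BlockSpec(block_shape, index_map)` signature.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each grid step sees one (256, 256) VMEM tile of each operand.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y, block=(256, 256)):
    n, m = x.shape  # assumes n and m are divisible by the block shape
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n // block[0], m // block[1]),
        in_specs=[pl.BlockSpec(block, lambda i, j: (i, j)),
                  pl.BlockSpec(block, lambda i, j: (i, j))],
        out_specs=pl.BlockSpec(block, lambda i, j: (i, j)),
    )(x, y)

x = jnp.ones((1024, 1024), jnp.float32)
print(add(x, x).shape)  # (1024, 1024)
```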
## 🧱 Core Pallas design patterns
These are the common techniques used in MaxText's Pallas kernels.
- **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back.
- **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering).
- **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays.
- **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory.
- **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible.
Prefer `pl.pallas_call` with scratch buffers allocated in the appropriate memory space (VMEM/SMEM) and use multi-buffering to overlap HBM loads with compute. Advanced pipelining to consider: custom prefetch block order via a scalar prefetch grid (for details see [here](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html)), which lets you control block execution order based on runtime values.
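As a concrete sketch of tiling plus online accumulation (the standard Pallas matmul pattern, not a MaxText kernel): the innermost grid dimension walks over K, and each step accumulates a partial product into the output tile that stays resident on-chip while Pallas pipelines the HBM→VMEM copies for upcoming tiles.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # Zero the output tile on the first K step, then accumulate partial products.
    @pl.when(pl.program_id(2) == 0)
    def _init():
        o_ref[...] = jnp.zeros_like(o_ref)
    o_ref[...] += jnp.dot(x_ref[...], y_ref[...],
                          preferred_element_type=jnp.float32)

def matmul(x, y, bm=256, bk=512, bn=256):
    m, k = x.shape
    _, n = y.shape  # assumes m, k, n are divisible by bm, bk, bn
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        grid=(m // bm, n // bn, k // bk),  # K is innermost, so the output block is revisited
        in_specs=[pl.BlockSpec((bm, bk), lambda i, j, kk: (i, kk)),
                  pl.BlockSpec((bk, bn), lambda i, j, kk: (kk, j))],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, kk: (i, j)),
    )(x, y)
```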
## 🌐 Distributed execution
Dispatch a kernel on multiple devices with `jax.shard_map`. It’s usually simpler and more maintainable than in-kernel cross-device communication. While Pallas supports low-level comms, `shard_map` is the right first choice for multi-device parallelism, and you can **communicate with `shard_map` collectives** when needed.
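A minimal sketch of dispatching a kernel across devices with `jax.shard_map`: here `local_op` is a placeholder standing in for a `pallas_call` wrapper, and the exact `shard_map` signature can vary slightly across JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

def local_op(x):
    # Placeholder for a Pallas kernel wrapper; each device runs this on its local shard.
    return x * 2.0

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
distributed_op = jax.jit(jax.shard_map(local_op, mesh=mesh,
                                       in_specs=P("data", None),
                                       out_specs=P("data", None)))

x = jnp.ones((8 * jax.device_count(), 128), jnp.float32)
y = distributed_op(x)  # rows are split across devices along the "data" axis
```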
## 🐞 Debugging tips
- Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA.
- Start with a tiny problem size and assert on invariants inside the kernel.
- Add `jax.named_scope` liberally so kernels are easy to spot in performance traces.
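For example, a toy kernel can be exercised end-to-end on CPU (illustrative sketch; `double_kernel` is a hypothetical name):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2.0

# interpret=True runs the kernel body with a Python interpreter on CPU,
# so you can insert prints and assertions while iterating on the logic.
out = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,
)(jnp.ones((8, 128), jnp.float32))
```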
## ✅ Putting it all together (checklist)
1. **Profile** the baseline using `named_scope` and `block_until_ready`.
2. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
3. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
4. **Validate** end-to-end performance in the model, not just microbenchmarks.
5. Consider **maintainability** and guard the new kernel with tests.
6. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically (a small sketch follows below).
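A small sketch of that last point (toy kernel, hypothetical names): the unbatched kernel stays unchanged, and `jax.vmap` maps the leading batch axis onto an extra grid dimension.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2.0

def scale(x):
    # Unbatched kernel: a single block covering the whole (8, 128) array.
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

xb = jnp.ones((4, 8, 128), jnp.float32)  # leading batch axis of 4
out = jax.vmap(scale)(xb)                # batched without rewriting the kernel
```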