
Commit 0d909c4

Merge pull request #3013 from AI-Hypercomputer:jacobplatin/refactor-inference
PiperOrigin-RevId: 862346681
2 parents 15177f2 + 4720def commit 0d909c4

139 files changed

Lines changed: 809 additions & 767 deletions

.github/CODEOWNERS

Lines changed: 3 additions & 3 deletions
@@ -18,9 +18,9 @@ src/MaxText/elastic_train.py @lukebaumann @shauryagup @richjames0 @shralex
 src/MaxText/layers/quantizations.py @khatwanimohit @jshin1394 @liudangyi @richjames0 @shralex
 
 # Inference
-src/MaxText/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
-src/MaxText/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
-src/MaxText/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
+src/maxtext/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
+src/maxtext/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
+src/maxtext/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
 
 # Dockerfiles and dependencies
 *.Dockerfile @bvandermoon @parambole @richjames0 @shralex

.vscode/launch.json

Lines changed: 6 additions & 6 deletions
@@ -8,7 +8,7 @@
 "console": "integratedTerminal",
 "justMyCode": false,
 "python": "python3",
-"module": "MaxText.decode",
+"module": "maxtext.decode",
 "args": ["src/MaxText/configs/base.yml",
 "run_name=runner_$(date +%Y-%m-%d-%H-%M)",
 "base_output_directory=gs://test-maxtext-output",
@@ -35,9 +35,9 @@
 "console": "integratedTerminal",
 "justMyCode": false,
 "python": "python3",
-"module": "MaxText.decode",
+"module": "maxtext.decode",
 "args": ["src/MaxText/configs/base.yml",
-"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
+"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
 "base_output_directory=gs://test-maxtext-output",
 "dataset_path=gs://test-maxtext-dataset",
 "steps=2",
@@ -53,7 +53,7 @@
 "python": "python3",
 "module": "MaxText.train",
 "args": ["src/MaxText/configs/base.yml",
-"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
+"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
 "base_output_directory=gs://test-maxtext-output",
 "dataset_path=gs://test-maxtext-dataset",
 "steps=2",
@@ -66,7 +66,7 @@
 "console": "integratedTerminal",
 "justMyCode": false,
 "python": "python3",
-"module": "MaxText.inference_microbenchmark",
+"module": "maxtext.inference.inference_microbenchmark",
 "args": [
 "src/MaxText/configs/base.yml",
 "model_name=llama2-7b",
@@ -82,7 +82,7 @@
 "inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
 "inference_microbenchmark_stages=generate",
 "inference_microbenchmark_loop_iters=1",
-"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
+"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
 "base_output_directory=gs://test-maxtext-output",
 "prefill_cache_axis_order=0,2,1,3",
 "ar_cache_axis_order=0,2,1,3",

codecov.yml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 # During scheduled runs, the 'regular' flag is carried forward from the last PR.
 
 # Exclude non-source code, deprecated and experimental folders from coverage tracking
-codecov:
+codecov:
 token: 35742a22-fb1f-4839-97ff-b54da5588689
 # By default file names in the coverage report will have their path in the file system, which in our
 # runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path

docs/guides/optimization/pallas_kernels_performance.md

Lines changed: 47 additions & 39 deletions
@@ -26,8 +26,8 @@ This guide explains **when** to consider Pallas, a **workflow** for developing a
 
 Think in **roofline** terms ([All About Rooflines](https://jax-ml.github.io/scaling-book/roofline/)) and in terms of **structure the compiler can’t see**:
 
-* **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling.
-* **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help.
+- **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling.
+- **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help.
 
 **Know when XLA is enough.** Before writing a custom kernel, always [profile your baseline](#1-high-level-profiling). If a standard operation (like a dense [`jnp.matmul`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html)) is already performing well, the XLA compiler is doing its job. In these cases, a Pallas kernel will increase code complexity and maintenance burden with minimal performance improvement.
 
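The roofline framing in the hunk above reduces to comparing an op's arithmetic intensity (FLOPs per byte moved) against the hardware ridge point. A minimal sketch of that back-of-the-envelope check, with assumed, illustrative peak-FLOP and bandwidth numbers:

```python
# Rough roofline check (sketch; the hardware constants are illustrative
# placeholders -- substitute your accelerator's peak FLOP/s and HBM bandwidth).
def matmul_arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
  """FLOPs per byte moved for a dense (m, k) @ (k, n) matmul."""
  flops = 2 * m * k * n                                   # multiply-adds
  bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # read A and B, write C
  return flops / bytes_moved

PEAK_FLOPS = 1.0e15     # assumed peak compute, FLOP/s
HBM_BANDWIDTH = 1.0e12  # assumed HBM bandwidth, bytes/s
ridge = PEAK_FLOPS / HBM_BANDWIDTH  # intensity where compute and bandwidth balance

ai = matmul_arithmetic_intensity(4096, 4096, 4096)
print(f"intensity {ai:.0f} FLOPs/byte vs ridge {ridge:.0f}")
# ai well above the ridge -> compute-limited; XLA's GEMM is likely fine.
# ai well below the ridge -> bandwidth-limited; better tiling (Pallas) may help.
```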
@@ -42,29 +42,34 @@ it is very difficult to automatically infer the dual of the memory pipeline.
 
 For dense, regular GEMMs, XLA’s libraries are hard to beat. The exception is **Mixture-of-Experts (MoE)** MLPs with **ragged token→expert layouts** (some tokens routed to different experts; shapes are irregular). Zero-padding to make dense tensors wastes FLOPs; a custom kernel can operate only on the actually-selected tokens.
 
-* In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions.
+- In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions.
 
 **Note:** *Megablox* is an efficient, non-capped MoE implementation in JAX. *Megablocks* refers to the equivalent PyTorch implementation. See [arXiv:2211.15841](https://arxiv.org/abs/2211.15841) for more details.
 
 ### 2. Memory-Access-Bound work (attention)
 
-Attention kernels are classically **bandwidth-limited** if you materialize the full \[L,L\] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate.
+Attention kernels are classically **bandwidth-limited** if you materialize the full [L,L] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate.
 
-* MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts.
+- MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts.
 
 ## 🛠️ Pallas kernels in MaxText
 
 To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth-bound or structurally irregular operations that a general-purpose compiler cannot optimize as effectively. Below are the key kernels we use. **Note**: Examples evolve; treat this list as guidance.
 
-* **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large \[L,L\] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation.
-* [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py)
-* **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.
-* [`src/MaxText/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention.py)
-* [`src/MaxText/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention_kernel_v2.py)
-* **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata.
+- **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation.
 
-> This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.
-* [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py)
+- [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py)
+
+- **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.
+
+- [`src/MaxText/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention.py)
+- [`src/MaxText/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention_kernel_v2.py)
+
+- **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata.
+
+> This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.
+
+- [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py)
 
 **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/MaxText/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/moe.py)).
 
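The "online softmax accumulation" mentioned above is the trick of folding each K/V tile into running statistics so the full [L,L] score matrix never exists. A simplified plain-`jnp` sketch of the idea (not the MaxText Splash or paged kernels):

```python
import jax.numpy as jnp

def attention_online(q, k, v, block=128):
  """q: [Lq, d]; k, v: [Lk, d]. Processes K/V one tile at a time."""
  m = jnp.full((q.shape[0], 1), -jnp.inf)    # running max of scores per query
  l = jnp.zeros((q.shape[0], 1))             # running softmax normalizer
  acc = jnp.zeros((q.shape[0], v.shape[1]))  # running weighted sum of V
  for start in range(0, k.shape[0], block):
    s = q @ k[start:start + block].T         # scores for this tile only
    m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
    corr = jnp.exp(m - m_new)                # rescale previous partial results
    p = jnp.exp(s - m_new)                   # tile softmax numerator
    l = l * corr + p.sum(axis=-1, keepdims=True)
    acc = acc * corr + p @ v[start:start + block]
    m = m_new
  return acc / l                             # equals softmax(q @ k.T) @ v
```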
@@ -74,7 +79,7 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth
 
 Give the kernel a clear name in traces and capture a profile. Always use [`jax.block_until_ready()`](https://docs.jax.dev/en/latest/_autosummary/jax.block_until_ready.html) when timing your operations.
 
-``` python
+```python
 import jax
 from jax import profiler
 
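The timing pattern referenced above — a named scope for traceability plus `jax.block_until_ready()` so asynchronous dispatch does not hide the real cost — can be sketched as follows (the op and sizes are illustrative assumptions):

```python
import time
import jax
import jax.numpy as jnp
from jax import profiler

x = jnp.ones((4096, 4096), dtype=jnp.bfloat16)

@jax.jit
def op(a):
  with jax.named_scope("baseline_matmul"):  # labeled span in the trace viewer
    return a @ a

op(x).block_until_ready()                   # compile and warm up once

with profiler.trace("/tmp/jax-trace"):      # inspect in TensorBoard/Perfetto
  start = time.perf_counter()
  jax.block_until_ready(op(x))              # wait for the result, then stop the clock
  print(f"steady-state step: {time.perf_counter() - start:.4f} s")
```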
@@ -104,24 +109,24 @@ For a more automated approach, consider using libraries like [tune-jax](https://
 
 Pallas exposes the underlying hardware primitives for you to control.
 
-* **HBM:** High-Bandwidth Memory (standard device memory).
-* **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs.
-* **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables).
-* **Semaphores:** Available for advanced async/barrier patterns in manual pipelines.
-* **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions.
-* **VPU:** The Vector Processing Unit, used for elementwise/vector work.
+- **HBM:** High-Bandwidth Memory (standard device memory).
+- **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs.
+- **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables).
+- **Semaphores:** Available for advanced async/barrier patterns in manual pipelines.
+- **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions.
+- **VPU:** The Vector Processing Unit, used for elementwise/vector work.
 
 **Alignment & Constraints:** Respect TPU BlockSpec constraints (divisibility/shape rules for trailing dimensions and supported block shapes). Start with tile shapes that fit in VMEM and meet these requirements, then sweep different sizes to find the optimum. Let profiling guide you; don't assume powers of two are always best.
 
 ## 🧱 Core Pallas design patterns
 
 These are the common techniques used in MaxText's Pallas kernels.
 
-* **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back.
-* **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering).
-* **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays.
-* **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory.
-* **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible.
+- **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back.
+- **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering).
+- **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays.
+- **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory.
+- **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible.
 
 ## ✍️ Writing & integrating a kernel
 
@@ -136,9 +141,11 @@ import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
+
 def add_vectors_kernel(x_ref, y_ref, o_ref):
   o_ref[:] = x_ref[:] + y_ref[:]
 
+
 def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
   assert x.shape == y.shape
   return pl.pallas_call(
@@ -156,14 +163,16 @@ import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
+
 def tile_add_kernel(x_ref, y_ref, o_ref):
   # Operate on the tile slices handed in by BlockSpecs (already in VMEM on TPU).
   o_ref[:, :] = x_ref[:, :] + y_ref[:, :]
 
+
 def tile_add(x: jax.Array, y: jax.Array) -> jax.Array:
   assert x.shape == y.shape and x.ndim == 2
   B0 = min(128, x.shape[0]) # Example choice; tune this with a sweep
-  B1 = x.shape[1] # Full width tile (for illustration)
+  B1 = x.shape[1] # Full width tile (for illustration)
 
   # Map program id (tile index) -> tile origin in the full (HBM) array.
   # NOTE: The runtime advances origins by `block_shape`, so `i` is already a tile
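The `tile_add` hunk ends before the `pallas_call` itself; the sketch below shows one way a grid and `BlockSpec`s typically complete such a tiled elementwise op. The block shapes, even-division assumption, and `BlockSpec` usage are illustrative, not the file's exact code:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def tile_add_kernel(x_ref, y_ref, o_ref):
  o_ref[:, :] = x_ref[:, :] + y_ref[:, :]

def tile_add(x: jax.Array, y: jax.Array) -> jax.Array:
  assert x.shape == y.shape and x.ndim == 2
  B0 = min(128, x.shape[0])   # row-tile height; tune with a sweep
  B1 = x.shape[1]             # full-width tile for illustration
  grid = (x.shape[0] // B0,)  # one program per row tile (assumes even division)
  spec = pl.BlockSpec((B0, B1), lambda i: (i, 0))  # program id -> tile index
  return pl.pallas_call(
      tile_add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=grid,
      in_specs=[spec, spec],
      out_specs=spec,
  )(x, y)

out = tile_add(jnp.ones((256, 512)), jnp.ones((256, 512)))  # example usage
```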
@@ -192,29 +201,28 @@ def tile_add(x: jax.Array, y: jax.Array) -> jax.Array:
 
 Prefer `pl.pallas_call` with scratch buffers allocated in the appropriate memory space (VMEM/SMEM) and use multi-buffering to overlap HBM loads with compute. Advanced pipelining to consider: custom prefetch block order via a scalar prefetch grid (for details see [here](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html)), which lets you control block execution order based on runtime values.
 
-
 ## 🌐 Distributed execution
 
 Dispatch a kernel on multiple devices with `jax.shard_map`. It’s usually simpler and more maintainable than in-kernel cross-device communication. While Pallas supports low-level comms, `shard_map` is the right first choice for multi-device parallelism, and you can **communicate with `shard_map` collectives** when needed.
 
 ## 🐞 Debugging tips
 
-* Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA.
-* Start with a tiny problem size and assert on invariants inside the kernel.
-* Add `jax.named_scope` liberally so kernels are easy to spot in performance traces.
+- Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA.
+- Start with a tiny problem size and assert on invariants inside the kernel.
+- Add `jax.named_scope` liberally so kernels are easy to spot in performance traces.
 
 ## ✅ Putting it all together (checklist)
 
 1. **Profile** the baseline using `named_scope` and `block_until_ready`.
-2. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
-3. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
-4. **Validate** end-to-end performance in the model, not just microbenchmarks.
-5. Consider **maintainability** and guard the new kernel with tests.
-6. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.
+1. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
+1. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
+1. **Validate** end-to-end performance in the model, not just microbenchmarks.
+1. Consider **maintainability** and guard the new kernel with tests.
+1. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.
 
 ## 📚 References
 
-* **Pallas Docs & Quickstart:** [docs.jax.dev/en/latest/pallas/index.html](https://docs.jax.dev/en/latest/pallas/index.html)
-* **JAX Profiling Guides:** [jax.readthedocs.io/en/latest/profiling.html](https://jax.readthedocs.io/en/latest/profiling.html)
-* **Manual Parallelism (shard_map):** [docs.jax.dev/en/latest/notebooks/shard_map.html](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
-* **Distributed Pallas on TPU:** [docs.jax.dev/en/latest/pallas/tpu/distributed.html](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html)
+- **Pallas Docs & Quickstart:** [docs.jax.dev/en/latest/pallas/index.html](https://docs.jax.dev/en/latest/pallas/index.html)
+- **JAX Profiling Guides:** [jax.readthedocs.io/en/latest/profiling.html](https://jax.readthedocs.io/en/latest/profiling.html)
+- **Manual Parallelism (shard_map):** [docs.jax.dev/en/latest/notebooks/shard_map.html](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
+- **Distributed Pallas on TPU:** [docs.jax.dev/en/latest/pallas/tpu/distributed.html](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html)
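The `interpret=True` debugging tip above pairs naturally with a tiny input and an invariant check; a minimal sketch using the guide's `add_vectors_kernel` example:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  o_ref[:] = x_ref[:] + y_ref[:]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)

out = pl.pallas_call(
    add_vectors_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # run the kernel body on CPU without lowering through XLA
)(x, y)

assert jnp.allclose(out, x + y)  # tiny problem size; assert on the invariant
print(out)
```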
