Skip to content

Commit 5dd5b3c

Browse files
authored
Merge branch 'main' into sanbao/format
2 parents 58c89cf + 77edafe commit 5dd5b3c

76 files changed

Lines changed: 1748 additions & 652 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei @NuojCheng @jiangjy1982 @suexu1025 @NicoGrande @jesselu-google
1+
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei @NuojCheng @jiangjy1982 @suexu1025 @NicoGrande @jesselu-google @dipannita08 @igorts-git
22

33
# Model bring-up
44
src/MaxText/assets @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande

.github/workflows/pypi_release.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ jobs:
4444

4545
publish_maxtext_package_to_pypi:
4646
name: Publish MaxText package to PyPI
47-
needs: [build_and_test_maxtext_package]
47+
# Temporarily only require release_approval for a one-time upload.
48+
# Immediately revert this to `needs: [build_and_test_maxtext_package]`.
49+
needs: [release_approval]
4850
runs-on: ubuntu-latest
4951
environment: release
5052
steps:

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,33 +57,18 @@ jobs:
5757
- name: Install MaxText and Dependencies
5858
shell: bash
5959
run: |
60+
# 1. Create virtual environment
6061
python3 -m uv venv --seed
6162
source .venv/bin/activate
62-
63-
# Install MaxText package
6463
maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null)
65-
uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest
66-
uv pip install -r src/install_maxtext_extra_deps/extra_deps_from_github.txt
6764
68-
# Install dependencies for running notebooks
69-
uv pip install papermill ipykernel ipywidgets
65+
# 2. Install MaxText package and all the post training dependencies
66+
uv pip install ${maxtext_wheel}[tpu-post-train] --resolution=lowest
67+
#TODO: @mazumdera: replace this with the following after release
68+
# uv pip install maxtext[tpu-post-train] --resolution=lowest
69+
install_maxtext_tpu_post_train_extra_deps
7070
.venv/bin/python3 -m ipykernel install --user --name maxtext_venv
71-
72-
# Install Tunix for post-training notebooks
73-
git clone https://github.com/google/tunix
74-
uv pip install ./tunix
7571
76-
# Install vllm for post-training notebooks
77-
git clone https://github.com/vllm-project/vllm.git
78-
VLLM_TARGET_DEVICE="tpu" uv pip install ./vllm
79-
80-
# Install tpu-inference for post-training notebooks
81-
git clone https://github.com/vllm-project/tpu-inference.git
82-
uv pip install ./tpu-inference
83-
84-
uv pip install --no-deps qwix==0.1.4
85-
uv pip install --no-deps protobuf==5.29.5
86-
uv pip install math-verify==0.9.0
8772
python3 -m pip freeze
8873
- name: Run Post-Training Notebooks
8974
shell: bash

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,6 @@ dmypy.json
148148
# Gemini CLI
149149
.gemini/
150150
gha-creds-*.json
151+
152+
# vscode workspace
153+
maxtext.code-workspace

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
4141

4242
## 🔥 Latest news 🔥
4343

44-
* \[February 27, 2026\] New MaxText structure! MaxText has been restructured according to [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/1b9e38aa0a19b6018feb3aed757406126b6953a1/RESTRUCTURE.md). Please feel free to share your thoughts and feedback.
44+
* \[March 5, 2026\] [Qwen3-Next](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) is now supported.
45+
* \[February 27, 2026\] New MaxText structure! MaxText has been restructured according to [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/1b9e38aa0a19b6018feb3aed757406126b6953a1/RESTRUCTURE.md). Please feel free to share your thoughts and feedback.
4546
* \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
4647
* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model.
4748
* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available.

benchmarks/globals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os.path
1818

1919
# This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
20-
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")
20+
MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/maxtext")
2121

2222
# This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
2323
MAXTEXT_REPO_ROOT = os.environ.get(

benchmarks/maxtext_xpk_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import omegaconf
3636

3737
import benchmarks.maxtext_trillium_model_configs as model_configs
38-
from benchmarks.globals import MAXTEXT_CONFIGS_DIR
38+
from benchmarks.globals import MAXTEXT_PKG_DIR
3939
from benchmarks.command_utils import run_command_with_updates
4040
import benchmarks.xla_flags_library as xla_flags
4141
from benchmarks.disruption_management.disruption_handler import DisruptionConfig
@@ -107,7 +107,7 @@ class WorkloadConfig:
107107
generate_metrics_and_upload_to_big_query: bool = True
108108
hardware_id: str = "v6e"
109109
metrics_gcs_file: str = ""
110-
base_config: str = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
110+
base_config: str = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
111111
topology: str = dataclasses.field(init=False)
112112
num_devices_per_slice: int = dataclasses.field(init=False)
113113
db_project: str = ""
@@ -354,7 +354,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
354354
"xla_flags": f"'{xla_flags_str}'",
355355
"dataset": dataset,
356356
"run_type": "maxtext-xpk",
357-
"config_file": os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml"),
357+
"config_file": os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
358358
"topology": wl_config.topology,
359359
"tuning_params": f"'{tuning_params_str}'",
360360
"db_project": wl_config.db_project,
@@ -439,8 +439,8 @@ def build_user_command(
439439
"export ENABLE_PATHWAYS_PERSISTENCE=1 &&",
440440
f"export JAX_PLATFORMS={jax_platforms} &&",
441441
"export ENABLE_PJRT_COMPATIBILITY=true &&",
442-
"export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/MaxText MAXTEXT_REPO_ROOT=/deps &&"
443-
f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")}',
442+
"export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/maxtext MAXTEXT_REPO_ROOT=/deps &&"
443+
f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")}',
444444
f"{config_tuning_params}",
445445
f"steps={wl_config.num_steps}",
446446
f"model_name={wl_config.model.model_type}",

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 0 additions & 44 deletions
This file was deleted.

dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.12:${PATH}"
2323
ARG MODE
2424
ENV ENV_MODE=$MODE
2525

26+
ARG WORKFLOW
27+
ENV ENV_WORKFLOW=$WORKFLOW
28+
2629
ARG JAX_VERSION
2730
ENV ENV_JAX_VERSION=$JAX_VERSION
2831

@@ -34,7 +37,7 @@ ENV ENV_DEVICE=$DEVICE
3437

3538
ENV MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets
3639
ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/tests/assets
37-
ENV MAXTEXT_PKG_DIR=/deps/src/MaxText
40+
ENV MAXTEXT_PKG_DIR=/deps/src/maxtext
3841
ENV MAXTEXT_REPO_ROOT=/deps
3942

4043
# Set the working directory in the container
@@ -43,14 +46,15 @@ WORKDIR /deps
4346
# Copy setup files and dependency files separately for better caching
4447
COPY tools/setup tools/setup/
4548
COPY dependencies/requirements/ dependencies/requirements/
46-
COPY src/install_maxtext_extra_deps/extra_deps_from_github.txt src/install_maxtext_extra_deps/
49+
COPY src/install_maxtext_extra_deps/ src/install_maxtext_extra_deps/
50+
COPY src/maxtext/integration/vllm/ src/maxtext/integration/vllm/
4751

4852
# Copy the custom libtpu.so file if it exists inside maxtext repository
4953
COPY libtpu.so* /root/custom_libtpu/
5054

5155
# Install dependencies - these steps are cached unless the copied files change
52-
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_VERSION=$ENV_LIBTPU_VERSION DEVICE=${ENV_DEVICE}"
53-
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_VERSION=${ENV_LIBTPU_VERSION} DEVICE=${ENV_DEVICE}
56+
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE WORKFLOW=$ENV_WORKFLOW JAX_VERSION=$ENV_JAX_VERSION LIBTPU_VERSION=$ENV_LIBTPU_VERSION DEVICE=${ENV_DEVICE}"
57+
RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} WORKFLOW=${ENV_WORKFLOW} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_VERSION=${ENV_LIBTPU_VERSION} DEVICE=${ENV_DEVICE}
5458

5559
# Now copy the remaining code (source files that may change frequently)
5660
COPY . .

dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ antlr4-python3-runtime>=4.9.3
1111
anyio>=4.11.0
1212
aqtp>=0.9.0
1313
array_record>=0.8.3
14+
asttokens>=3.0.1
1415
astor>=0.8.1
1516
astroid>=4.0.2
1617
astunparse>=1.6.3
@@ -24,7 +25,7 @@ botocore>=1.42.56
2425
build>=1.3.0
2526
cachetools>=6.2.2
2627
cbor2>=5.8.0
27-
certifi>=2025.11.12
28+
certifi>=2026.2.25
2829
cffi>=2.0.0
2930
cfgv>=3.5.0
3031
charset-normalizer>=3.4.4
@@ -38,6 +39,7 @@ clu>=0.0.12
3839
cmake>=4.2.1
3940
colorama>=0.4.6
4041
colorful>=0.5.8
42+
comm>=0.2.3
4143
compressed-tensors>=0.13.0
4244
contourpy>=1.3.3
4345
coverage>=7.12.0
@@ -46,6 +48,7 @@ cycler>=0.12.1
4648
dacite>=1.9.2
4749
dataclasses-json>=0.6.7
4850
datasets>=4.6.0
51+
debugpy>=1.8.20
4952
decorator>=5.2.1
5053
depyf>=0.20.0
5154
dill>=0.4.0
@@ -60,13 +63,16 @@ editdistance>=0.8.1
6063
einops>=0.8.1
6164
einshape>=1.0
6265
email-validator>=2.3.0
66+
entrypoints>=0.4
6367
etils>=1.13.0
6468
evaluate>=0.4.6
6569
execnet>=2.1.2
70+
executing>=2.2.1
6671
fastapi>=0.122.0
6772
fastapi-cli>=0.0.24
6873
fastapi-cloud-cli>=0.13.0
6974
fastar>=0.8.0
75+
fastjsonschema>=2.21.2
7076
filelock>=3.20.0
7177
flatbuffers>=25.9.23
7278
flax>=0.12.4
@@ -126,13 +132,18 @@ importlib_metadata>=8.7.0
126132
importlib_resources>=6.5.2
127133
iniconfig>=2.3.0
128134
interegular>=0.3.3
135+
ipykernel>=7.2.0
136+
ipython>=9.10.0
137+
ipython_pygments_lexers>=1.1.1
138+
ipywidgets>=8.1.8
129139
isort>=7.0.0
130140
jaraco.classes>=3.4.0
131141
jaraco.context>=6.1.0
132142
jaraco.functools>=4.3.0
133143
jax>=0.8.3
134144
jaxlib>=0.8.3
135145
jaxtyping>=0.3.3
146+
jedi>=0.19.2
136147
jeepney>=0.9.0
137148
Jinja2>=3.1.6
138149
jiter>=0.13.0
@@ -141,6 +152,9 @@ joblib>=1.5.2
141152
jsonlines>=4.0.0
142153
jsonschema>=4.26.0
143154
jsonschema-specifications>=2025.9.1
155+
jupyter_client>=8.8.0
156+
jupyter_core>=5.9.1
157+
jupyterlab_widgets>=3.0.16
144158
kagglehub>=0.3.13
145159
keras>=3.12.0
146160
keyring>=25.7.0
@@ -162,6 +176,7 @@ MarkupSafe>=3.0.3
162176
marshmallow>=3.26.2
163177
math-verify>=0.9.0
164178
matplotlib>=3.10.7
179+
matplotlib-inline>=0.2.1
165180
mccabe>=0.7.0
166181
mcp>=1.26.0
167182
mdurl>=0.1.2
@@ -178,6 +193,8 @@ multidict>=6.7.0
178193
multiprocess>=0.70.18
179194
mypy_extensions>=1.1.0
180195
namex>=0.1.0
196+
nbclient>=0.10.4
197+
nbformat>=5.10.4
181198
nest-asyncio>=1.6.0
182199
networkx>=3.6
183200
ninja>=1.13.0
@@ -186,7 +203,7 @@ nltk>=3.9.2
186203
nodeenv>=1.9.1
187204
numba>=0.62.1
188205
numpy>=2.2.6
189-
numpy-typing-compat>=20250818.2.0
206+
numpy-typing-compat>=20250818.2.2
190207
nvidia-cublas-cu12>=12.8.4.1
191208
nvidia-cuda-cupti-cu12>=12.8.90
192209
nvidia-cuda-nvrtc-cu12>=12.8.93
@@ -218,7 +235,7 @@ opentelemetry-exporter-prometheus>=0.60b1
218235
opentelemetry-proto>=1.39.1
219236
opentelemetry-sdk>=1.39.1
220237
opentelemetry-semantic-conventions>=0.60b1
221-
opentelemetry-semantic-conventions-ai>=0.4.14
238+
opentelemetry-semantic-conventions-ai>=0.4.15
222239
opt_einsum>=3.4.0
223240
optax>=0.2.6
224241
optree>=0.18.0
@@ -228,23 +245,29 @@ orbax-export>=0.0.8
228245
outlines_core>=0.2.11
229246
packaging>=26.0
230247
pandas>=2.3.3
248+
papermill>=2.7.0
231249
parameterized>=0.9.0
250+
parso>=0.8.6
232251
partial-json-parser>=0.2.1.1.post7
233252
pathspec>=0.12.1
234253
pathwaysutils>=0.1.4
235254
perfetto>=0.16.0
255+
pexpect>=4.9.0
236256
pillow>=12.0.0
237-
platformdirs>=4.5.0
257+
platformdirs>=4.9.2
238258
pluggy>=1.6.0
239259
portpicker>=1.6.0
240260
pre_commit>=4.5.0
241261
prometheus-fastapi-instrumentator>=7.1.0
242262
prometheus_client>=0.23.1
243263
promise>=2.3
264+
prompt_toolkit>=3.0.52
244265
propcache>=0.4.1
245266
proto-plus>=1.26.1
246267
protobuf>=5.29.6
247-
psutil>=7.1.3
268+
psutil>=7.2.2
269+
ptyprocess>=0.7.0
270+
pure_eval>=0.2.3
248271
py-cpuinfo>=9.0.0
249272
py-spy>=0.4.1
250273
pyarrow>=22.0.0
@@ -316,18 +339,19 @@ sniffio>=1.3.1
316339
sortedcontainers>=2.4.0
317340
sse-starlette>=3.2.0
318341
starlette>=0.50.0
342+
stack-data>=0.6.3
319343
supervisor>=4.3.0
320344
sympy>=1.14.0
321345
tabulate>=0.9.0
322-
tenacity>=9.1.2
323-
tensorboard>=2.19.0
346+
tenacity>=9.1.4
347+
tensorboard>=2.20.0
324348
tensorboard-data-server>=0.7.2
325349
tensorboard-plugin-profile>=2.13.0
326350
tensorboardX>=2.6.4
327-
tensorflow>=2.19.1
351+
tensorflow>=2.20.0
328352
tensorflow-datasets>=4.9.9
329353
tensorflow-metadata>=1.17.2
330-
tensorflow-text>=2.19.0
354+
tensorflow-text>=2.20.0
331355
tensorstore>=0.1.79
332356
termcolor>=3.2.0
333357
tiktoken>=0.12.0
@@ -339,8 +363,10 @@ toolz>=1.1.0
339363
torch>=2.9.0
340364
torchax>=0.0.11
341365
torchvision>=0.24.0
366+
tornado>=6.5.4
342367
tpu-info>=0.7.1
343-
tqdm>=4.67.1
368+
tqdm>=4.67.3
369+
traitlets>=5.14.3
344370
transformers>=4.57.1
345371
treescope>=0.1.10
346372
triton>=3.5.0
@@ -358,9 +384,11 @@ uvloop>=0.22.1
358384
virtualenv>=20.35.4
359385
wadler_lindig>=0.1.7
360386
watchfiles>=1.1.1
387+
wcwidth>=0.6.0
361388
websockets>=15.0.1
362389
Werkzeug>=3.1.3
363390
wheel>=0.46.3
391+
widgetsnbextension>=4.0.15
364392
wrapt>=2.0.1
365393
xgrammar>=0.1.29
366394
xprof>=2.21.1

0 commit comments

Comments
 (0)