Commit 50ec5a0

Merge branch 'AI-Hypercomputer:main' into pr/decoupling-core
2 parents 4c1423a + b279b99 commit 50ec5a0

38 files changed

Lines changed: 1166 additions & 602 deletions

.github/workflows/run_tests_internal.yml

Lines changed: 1 addition & 1 deletion
@@ -81,4 +81,4 @@ jobs:
     python3 -m pip install -e . --no-dependencies
     [ "${{ inputs.total_workers }}" -gt 1 ] && python3 -m pip install --quiet pytest-split && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }}" || SPLIT_ARGS=""
     export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
-    python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 $SPLIT_ARGS
+    python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" $SPLIT_ARGS
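The workflow computes `SPLIT_ARGS` with a chained `test && install && assign || fallback` one-liner. As a sketch of the same logic, here it is as an explicit if/else, with local `total_workers`/`worker_group` variables standing in for the workflow's `inputs.*` values (names here are illustrative, not part of the workflow):

```shell
# Hypothetical stand-ins for the workflow's inputs.total_workers and
# inputs.worker_group expressions.
total_workers=4
worker_group=2

# Equivalent of the `[ ... ] && ... || SPLIT_ARGS=""` one-liner: only when
# more than one worker is requested do we pass pytest-split's flags.
if [ "$total_workers" -gt 1 ]; then
  SPLIT_ARGS="--splits $total_workers --group $worker_group"
else
  SPLIT_ARGS=""
fi
echo "$SPLIT_ARGS"
```

Note one subtlety of the original one-liner: if the `pip install --quiet pytest-split` step fails, the `||` branch still runs and `SPLIT_ARGS` is silently set to empty, so the job falls back to running the full test set on one worker rather than failing outright.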

README.md

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 
 MaxText is a high performance, highly scalable, open-source LLM library and reference implementation written in pure Python/[JAX](https://docs.jax.dev/en/latest/jax-101.html) and targeting Google Cloud TPUs and GPUs for training.
 
-MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO, a type of Reinforcement Learning).
+MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Group Sequence Policy Optimization (GSPO), the latter two being Reinforcement Learning methods.
 
 MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from single host to very large clusters while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.
 
@@ -73,7 +73,7 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d
 Check out these getting started guides:
 
 * [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh) (Supervised Fine Tuning)
-* [GRPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) (Group Relative Policy Optimization)
+* [GRPO / GSPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) (Group Relative & Group Sequence Policy Optimization; pass `loss_algo=gspo-token` to run GSPO)
 
 ### Model library

dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile

Lines changed: 6 additions & 0 deletions
@@ -52,6 +52,12 @@ RUN if [ "$DEVICE" = "tpu" ]; then \
     python3 -m pip install 'google-tunix>=0.1.2'; \
     fi
 
+# Temporarily pin JAX to 0.8.1 for GPU images
+RUN if [ "$DEVICE" = "gpu" ]; then \
+    python3 -m pip install -U "jax[cuda12]==0.8.1"; \
+    python3 -m pip install -U "transformer-engine-cu12" "transformer-engine-jax" "transformer-engine"; \
+    fi
+
 # Now copy the remaining code (source files that may change frequently)
 COPY . .
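The Dockerfile above pins an exact JAX version, while the requirements files in this commit use `>=` floors. When checking such versions programmatically, dotted version strings must be compared numerically, not lexically. A minimal sketch in plain Python (illustrative only, not part of the image build; real tooling would use the `packaging` library):

```python
def parse_version(v: str) -> tuple[int, ...]:
    """Turn a dotted version string like '0.8.1' into a tuple of ints.

    Note: this toy parser handles only purely numeric segments, not
    suffixes like '.post0' or 'rc1'.
    """
    return tuple(int(part) for part in v.split("."))

# Tuple comparison is element-wise, so 0.10.x correctly sorts after 0.9.x,
# whereas plain string comparison gets this wrong ("0.10.0" < "0.9.0").
assert parse_version("0.8.1") >= parse_version("0.8.0")
assert parse_version("0.10.0") > parse_version("0.9.0")
assert "0.10.0" < "0.9.0"  # the string-comparison trap
```

This is exactly why pins like `jax[cuda12]==0.8.1` are checked by a real version parser in pip rather than by string equality on arbitrary text.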

dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

Lines changed: 2 additions & 13 deletions
@@ -33,22 +33,11 @@ RUN pip install -e /tunix --no-cache-dir
 
 
 COPY vllm /vllm
-RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
-    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
-    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir
 
 
 COPY tpu-inference /tpu-inference
-RUN pip install -e /tpu-inference --no-cache-dir --pre \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
+RUN pip install -e /tpu-inference --no-cache-dir
 
 RUN pip install --no-deps qwix==0.1.4

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 67 additions & 70 deletions
@@ -4,37 +4,37 @@
 absl-py>=2.3.1
 aiofiles>=25.1.0
 aiohappyeyeballs>=2.6.1
-aiohttp>=3.13.1
+aiohttp>=3.13.2
 aiosignal>=1.4.0
-annotated-doc>=0.0.3
+annotated-doc>=0.0.4
 annotated-types>=0.7.0
 antlr4-python3-runtime>=4.9.3
 anyio>=4.11.0
 aqtp>=0.9.0
-array-record>=0.8.2
-astroid>=4.0.1
+array-record>=0.8.3
+astroid>=4.0.2
 astunparse>=1.6.3
 attrs>=25.4.0
-auditwheel>=6.4.2
+auditwheel>=6.5.0
 black>=24.10.0
 blobfile>=3.1.0
 build>=1.3.0
-cachetools>=6.2.1
-certifi>=2025.10.5
-cfgv>=3.4.0
+cachetools>=6.2.2
+certifi>=2025.11.12
+cfgv>=3.5.0
 charset-normalizer>=3.4.4
-cheroot>=11.0.0
+cheroot>=11.1.2
 chex>=0.1.91
-click>=8.3.0
+click>=8.3.1
 cloud-accelerator-diagnostics>=0.1.1
 cloud-tpu-diagnostics>=0.1.5
-cloudpickle>=3.1.1
+cloudpickle>=3.1.2
 clu>=0.0.12
 colorama>=0.4.6
 contourpy>=1.3.3
-coverage>=7.11.0
+coverage>=7.12.0
 cycler>=0.12.1
-datasets>=4.3.0
+datasets>=4.4.1
 decorator>=5.2.1
 dill>=0.4.0
 distlib>=0.4.0
@@ -45,41 +45,40 @@ einops>=0.8.1
 einshape>=1.0
 etils>=1.13.0
 evaluate>=0.4.6
-execnet>=2.1.1
-fastapi>=0.120.1
+execnet>=2.1.2
+fastapi>=0.122.0
 filelock>=3.20.0
 flatbuffers>=25.9.23
-flax>=0.12.0
+flax>=0.12.1
 fonttools>=4.60.1
 frozenlist>=1.8.0
-fsspec>=2025.9.0
+fsspec>=2025.10.0
 gast>=0.6.0
-gcsfs>=2025.9.0
-google-api-core>=2.28.0
-google-api-python-client>=2.185.0
-google-auth-httplib2>=0.2.0
+gcsfs>=2025.10.0
+google-api-core>=2.28.1
+google-api-python-client>=2.187.0
+google-auth-httplib2>=0.2.1
 google-auth-oauthlib>=1.2.2
-google-auth>=2.41.1
-google-benchmark>=1.9.4
-google-cloud-aiplatform>=1.122.0
+google-auth>=2.43.0
+google-cloud-aiplatform>=1.128.0
 google-cloud-appengine-logging>=1.7.0
 google-cloud-audit-log>=0.4.0
 google-cloud-bigquery>=3.38.0
-google-cloud-core>=2.4.3
+google-cloud-core>=2.5.0
 google-cloud-logging>=3.12.1
 google-cloud-monitoring>=2.28.0
 google-cloud-resource-manager>=1.15.0
-google-cloud-storage>=2.19.0
+google-cloud-storage>=3.6.0
 google-crc32c>=1.7.1
-google-genai>=1.46.0
+google-genai>=1.52.0
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
 google-pasta>=0.2.0
-google-resumable-media>=2.7.2
-googleapis-common-protos>=1.71.0
-grain>=0.2.13
+google-resumable-media>=2.8.0
+googleapis-common-protos>=1.72.0
+grain>=0.2.15
 grpc-google-iam-v1>=0.14.3
 grpcio-status>=1.71.2
-grpcio>=1.75.1
+grpcio>=1.76.0
 gviz-api>=1.10.0
 h11>=0.16.0
 h5py>=3.15.1
@@ -96,43 +95,42 @@ immutabledict>=4.2.2
 importlab>=0.8.1
 importlib-metadata>=8.7.0
 importlib-resources>=6.5.2
-iniconfig>=2.1.0
+iniconfig>=2.3.0
 isort>=7.0.0
 jaraco-functools>=4.3.0
-jax-cuda12-pjrt>=0.8.0 ; sys_platform == 'linux'
-jax-cuda12-plugin>=0.8.0 ; sys_platform == 'linux'
-jax-triton>=0.3.0
-jax>=0.8.0
-jaxlib>=0.8.0
+jax-cuda12-pjrt>=0.8.1 ; sys_platform == 'linux'
+jax-cuda12-plugin>=0.8.1 ; sys_platform == 'linux'
+jax>=0.8.1
+jaxlib>=0.8.1
 jaxtyping>=0.3.3
 jinja2>=3.1.6
 joblib>=1.5.2
 jsonlines>=4.0.0
-keras>=3.11.3
+keras>=3.12.0
 kiwisolver>=1.4.9
 libclang>=18.1.1
-libcst>=1.8.5
+libcst>=1.8.6
 lxml>=6.0.2
 markdown-it-py>=4.0.0
-markdown>=3.9
+markdown>=3.10
 markupsafe>=3.0.3
 matplotlib>=3.10.7
 mccabe>=0.7.0
 mdurl>=0.1.2
 ml-collections>=1.1.0
-ml-dtypes>=0.5.3
+ml-dtypes>=0.5.4
 ml-goodput-measurement>=0.0.15
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
 more-itertools>=10.8.0
 mpmath>=1.3.0
 msgpack>=1.1.2
-msgspec>=0.19.0
+msgspec>=0.20.0
 multidict>=6.7.0
-multiprocess>=0.70.16
+multiprocess>=0.70.18
 mypy-extensions>=1.1.0
 namex>=0.1.0
 nest-asyncio>=1.6.0
-networkx>=3.5
+networkx>=3.6
 ninja>=1.13.0
 nltk>=3.9.2
 nodeenv>=1.9.1
@@ -143,21 +141,21 @@ nvidia-cuda-cupti-cu12>=12.9.79 ; sys_platform == 'linux'
 nvidia-cuda-nvcc-cu12>=12.9.86 ; sys_platform == 'linux'
 nvidia-cuda-nvrtc-cu12>=12.9.86 ; sys_platform == 'linux'
 nvidia-cuda-runtime-cu12>=12.9.79 ; sys_platform == 'linux'
-nvidia-cudnn-cu12>=9.14.0.64 ; sys_platform == 'linux'
+nvidia-cudnn-cu12>=9.16.0.29 ; sys_platform == 'linux'
 nvidia-cufft-cu12>=11.4.1.4 ; sys_platform == 'linux'
 nvidia-cusolver-cu12>=11.7.5.82 ; sys_platform == 'linux'
 nvidia-cusparse-cu12>=12.5.10.65 ; sys_platform == 'linux'
-nvidia-nccl-cu12>=2.28.3 ; sys_platform == 'linux'
+nvidia-nccl-cu12>=2.28.9 ; sys_platform == 'linux'
 nvidia-nvjitlink-cu12>=12.9.86 ; sys_platform == 'linux'
 nvidia-nvshmem-cu12>=3.4.5 ; sys_platform == 'linux'
 oauthlib>=3.3.1
 omegaconf>=2.3.0
 opentelemetry-api>=1.38.0
 opt-einsum>=3.4.0
 optax>=0.2.6
-optree>=0.17.0
+optree>=0.18.0
 optype>=0.14.0
-orbax-checkpoint>=0.11.26
+orbax-checkpoint>=0.11.28
 packaging>=25.0
 pandas>=2.3.3
 parameterized>=0.9.0
@@ -167,26 +165,26 @@ pillow>=12.0.0
 platformdirs>=4.5.0
 pluggy>=1.6.0
 portpicker>=1.6.0
-pre-commit>=4.3.0
+pre-commit>=4.5.0
 prometheus-client>=0.23.1
 promise>=2.3
 propcache>=0.4.1
 proto-plus>=1.26.1
 protobuf>=5.29.5
-psutil>=7.1.0
+psutil>=7.1.3
 pyarrow>=22.0.0
 pyasn1-modules>=0.4.2
 pyasn1>=0.6.1
 pycnite>=2024.7.31
 pycryptodomex>=3.23.0
-pydantic-core>=2.41.4
-pydantic>=2.12.3
+pydantic-core>=2.41.5
+pydantic>=2.12.5
 pydot>=4.0.1
 pyelftools>=0.32
 pyglove>=0.4.5
 pygments>=2.19.2
 pyink>=24.10.1
-pylint>=4.0.2
+pylint>=4.0.3
 pyparsing>=3.2.5
 pyproject-hooks>=1.2.0
 pytest-xdist>=3.8.0
@@ -195,15 +193,15 @@ python-dateutil>=2.9.0.post0
 pytype>=2024.10.11
 pytz>=2025.2
 pyyaml>=6.0.3
-qwix>=0.1.1
-regex>=2025.10.23
+qwix>=0.1.4
+regex>=2025.11.3
 requests-oauthlib>=2.0.0
 requests>=2.32.5
 rich>=14.2.0
 rsa>=4.9.1
-safetensors>=0.6.2
-scipy-stubs>=1.16.2.4
-scipy>=1.16.2
+safetensors>=0.7.0
+scipy-stubs>=1.16.3.0
+scipy>=1.16.3
 sentencepiece>=0.2.1
 seqio>=0.0.20
 setuptools>=80.9.0
@@ -214,7 +212,7 @@ simplejson>=3.20.2
 six>=1.17.0
 sniffio>=1.3.1
 sortedcontainers>=2.4.0
-starlette>=0.48.0
+starlette>=0.50.0
 sympy>=1.14.0
 tabulate>=0.9.0
 tenacity>=9.1.2
@@ -226,35 +224,34 @@ tensorflow-datasets>=4.9.9
 tensorflow-metadata>=1.17.2
 tensorflow-text>=2.19.0
 tensorflow>=2.19.1
-tensorstore>=0.1.78
-termcolor>=3.1.0
+tensorstore>=0.1.79
+termcolor>=3.2.0
 tiktoken>=0.12.0
-tokamax>=0.0.4
+tokamax>=0.0.8
 tokenizers>=0.22.1
 toml>=0.10.2
 tomlkit>=0.13.3
 toolz>=1.1.0
 tqdm>=4.67.1
-transformer-engine-cu12>=2.8.0
-transformer-engine-jax>=2.8.0
-transformer-engine>=2.8.0
-transformers>=4.57.1
+transformer-engine-cu12>=2.9.0
+transformer-engine-jax>=2.9.0
+transformer-engine>=2.9.0
+transformers>=4.57.3
 treescope>=0.1.10
-triton>=3.5.0
 typeguard>=2.13.3
 typing-extensions>=4.15.0
 typing-inspection>=0.4.2
 tzdata>=2025.2
 uritemplate>=4.2.0
 urllib3>=2.5.0
 uvicorn>=0.38.0
-virtualenv>=20.35.3
+virtualenv>=20.35.4
 wadler-lindig>=0.1.7
 websockets>=15.0.1
 werkzeug>=3.1.3
 wheel>=0.45.1
-wrapt>=2.0.0
-xprof>=2.20.7
+wrapt>=2.0.1
+xprof>=2.21.1
 xxhash>=3.6.0
 yarl>=1.22.0
 zipp>=3.23.0
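The requirements diff above mixes three line shapes: `name>=version`, `name>=version ; marker`, and direct-URL entries like `google-jetstream @ https://…`. A rough sketch of a parser for just those shapes (illustrative only; it covers the subset used in this file, not the full PEP 508 grammar, which real tooling gets from the `packaging` library):

```python
import re

def parse_requirement(line: str):
    """Parse 'name>=version' or 'name>=version ; marker' into a tuple
    (name, op, version, marker).

    Returns None for shapes this sketch does not handle, such as the
    direct-URL 'name @ https://...' entries in the requirements file.
    """
    # Split off an optional environment marker after ';'.
    spec, _, marker = line.partition(";")
    m = re.fullmatch(
        r"\s*([A-Za-z0-9._-]+)\s*(==|>=|<=|~=|>|<)\s*(\S+)\s*", spec
    )
    if m is None:
        return None
    name, op, version = m.groups()
    return name, op, version, marker.strip() or None
```

For example, `parse_requirement("nvidia-nccl-cu12>=2.28.9 ; sys_platform == 'linux'")` yields the name, the `>=` floor, the version, and the platform marker as separate fields, while a direct-URL line falls through to `None`.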
