
Commit 1c12384

Merge branch 'main' into landing-page
2 parents b8dbf70 + b279b99 commit 1c12384

78 files changed: 2259 additions & 988 deletions


.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion

@@ -33,4 +33,4 @@ src/MaxText/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh
 .github/workflows @gobbleturk @khatwanimohit @shralex @parambole @bvandermoon @richjames0

 # Benchmarking/Recipes
-benchmarks @SujeethJinesh @bvandermoon @richjames0 @shralex @vipannalla @mitalisi @RissyRan @shauryagup @NuojCheng @gobbleturk @khatwanimohit @Obliviour
+benchmarks @SujeethJinesh @bvandermoon @richjames0 @shralex @vipannalla @mitalisi @RissyRan @shauryagup @NuojCheng @gobbleturk @khatwanimohit @Obliviour @notabee

.github/workflows/run_pathways_tests_internal.yml

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ jobs:
 python3 -m pip install -e . --no-dependencies &&
 python3 -m pip uninstall -y libtpu &&
 # TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported.
-python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest" --durations=0
+python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0

 services:
   resource_manager:

.github/workflows/run_tests_internal.yml

Lines changed: 1 addition & 1 deletion

@@ -81,4 +81,4 @@ jobs:
 python3 -m pip install -e . --no-dependencies
 [ "${{ inputs.total_workers }}" -gt 1 ] && python3 -m pip install --quiet pytest-split && SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }}" || SPLIT_ARGS=""
 export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
-python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 $SPLIT_ARGS
+python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" --durations=0 --deselect "tests/aot_hlo_identical_test.py::AotHloIdenticalTest::test_default_hlo_match" $SPLIT_ARGS
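The SPLIT_ARGS line in this workflow uses the bash `cmd && a || b` idiom as a compact if/else. A minimal standalone sketch of how the split arguments get built, with plain shell variables standing in for the `${{ inputs.* }}` workflow expressions:

```shell
#!/bin/bash
# Stand-ins for ${{ inputs.total_workers }} and ${{ inputs.worker_group }}.
total_workers=3
worker_group=2

# `&& ... || ...` behaves like if/else here, with one caveat: if any command in
# the && chain fails (e.g. the pip install in the real workflow), the || branch
# runs and SPLIT_ARGS silently falls back to the empty string.
[ "$total_workers" -gt 1 ] && SPLIT_ARGS="--splits $total_workers --group $worker_group" || SPLIT_ARGS=""
echo "SPLIT_ARGS=$SPLIT_ARGS"
```

With more than one worker this prints `SPLIT_ARGS=--splits 3 --group 2`; with a single worker, SPLIT_ARGS stays empty and pytest runs the full suite unsplit.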

README.md

Lines changed: 4 additions & 3 deletions

@@ -22,7 +22,7 @@
 MaxText is a high performance, highly scalable, open-source LLM library and reference implementation written in pure Python/[JAX](https://docs.jax.dev/en/latest/jax-101.html) and targeting Google Cloud TPUs and GPUs for training.

-MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO, a type of Reinforcement Learning).
+MaxText provides a library of high performance models to choose from, including Gemma, Llama, DeepSeek, Qwen, and Mistral. For each of these models, MaxText supports pre-training (up to tens of thousands of chips) and scalable post-training, with popular techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO, a type of Reinforcement Learning), and Group Sequence Policy Optimization (GSPO, another Reinforcement Learning technique).

 MaxText achieves high Model FLOPs Utilization (MFU) and tokens/second from single host to very large clusters while staying simple and largely "optimization-free" thanks to the power of JAX and the XLA compiler.

@@ -74,9 +74,10 @@ Check out these getting started guides:
 * Supervised Fine Tuning (SFT)
   * [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft.html)
   * [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft_on_multi_host.html)
-* Group Relative Policy Optimization (GRPO)
+* Group Relative & Group Sequence Policy Optimization (GRPO & GSPO)
   * [GRPO on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html)
-  * [GRPO on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo_with_pathways.html)
+  * [GRPO on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo_with_pathways.html)
+  * [GSPO](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html#run-gspo) (pass `loss_algo=gspo-token` to run GSPO)

 ### Model library

benchmarks/maxtest/getting_started.md

Lines changed: 15 additions & 0 deletions

@@ -44,6 +44,21 @@ EXIT_CODE=0
 - maxtest.sh will generate a YAML file in the directory that is passed to kubectl. This file can be modified and reused by running `kubectl apply -f maxtest.yaml`

+### Passing custom libtpu or XLA flags ###
+
+Custom libtpu or XLA flags can be passed by specifying `--libtpu_args`.
+
+#### Setting flags for SDC checking ####
+
+This is useful for checking for silent data corruption (SDC) on TPU hardware.
+
+```
+bash maxtest.sh --project $TPU_PROJECT --cluster $CLUSTER --region $REGION --nodepool $NODEPOOL_NAME --num_workers $NUM_WORKERS --libtpu_args '--xla_tpu_enable_sdc_checker'
+```
+
 ### Debugging common job errors ###

 If the job does not exit with `EXIT_CODE=0`, there is a failure among one of

benchmarks/maxtest/maxtest.sh

Lines changed: 5 additions & 3 deletions

@@ -1,4 +1,4 @@
-#!/bin/bash
+#!bin/bash

 function usage() {
   echo "error: $1"

@@ -15,6 +15,7 @@ while [[ "$#" > 0 ]]; do case $1 in
   -r|--region) GKE_REGION="$2";shift;shift;;
   --nodepool) NODEPOOL="$2";shift;shift;;
   --num_workers) NUM_WORKERS="$2";shift;shift;;
+  --libtpu_args) LIBTPU_ARGS="$2";shift;shift;;
   *) usage "Unknown parameter passed: $1"; shift; shift;;
 esac; done

@@ -32,19 +33,20 @@ if [ -z "$TPU_ACCELERATOR" ]; then exit; fi;

 UUID=$(uuidgen)
 export JOB_NAME="${UUID:0:5}-maxtest"
-export DOCKER_IMAGE="gcr.io/cloud-tpu-images-public/tpu/healthscan"
+export DOCKER_IMAGE="us-docker.pkg.dev/cloud-tpu-images-public/tpu/healthscan:latest"
 export NODEPOOL
 export TPU_TOPOLOGY
 export TPU_ACCELERATOR
 export GKE_PROJECT
 export GKE_REGION
 export GKE_CLUSTER
+export LIBTPU_ARGS

 export MEMORY_PER_HOST="407Gi"
 export TPU_CHIPS_PER_HOST=4
 export COMPLETIONS=$NUM_WORKERS # Number of VMs in the nodepool (v6e -> 2 VMs for v6e-8, v5p -> 1 VM for a v5p-8)

-YAML_VARS='$JOB_NAME $DOCKER_IMAGE $NODEPOOL $TPU_TOPOLOGY $TPU_ACCELERATOR $COMPLETIONS $MEMORY_PER_HOST $TPU_CHIPS_PER_HOST $GKE_PROJECT $GKE_REGION $GKE_CLUSTER'
+YAML_VARS='$JOB_NAME $DOCKER_IMAGE $NODEPOOL $TPU_TOPOLOGY $TPU_ACCELERATOR $COMPLETIONS $MEMORY_PER_HOST $TPU_CHIPS_PER_HOST $GKE_PROJECT $GKE_REGION $GKE_CLUSTER $LIBTPU_ARGS'

 envsubst "${YAML_VARS}" < maxtest.yaml.template > maxtest.yaml
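The script derives a unique Kubernetes job name from the first five characters of a fresh UUID using bash substring expansion (`${var:offset:length}`). A standalone sketch of that naming scheme, with a fixed UUID in place of `uuidgen` so the result is deterministic:

```shell
#!/bin/bash
# maxtest.sh uses UUID=$(uuidgen); a fixed value is used here purely so the
# illustration is reproducible.
UUID="1c12384a-0b2c-4d5e-8f90-abcdef012345"
JOB_NAME="${UUID:0:5}-maxtest"   # ${var:offset:length} -> first 5 chars
echo "$JOB_NAME"                 # -> 1c123-maxtest
```

Five hex characters keep the job name short while making collisions between concurrent runs unlikely.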

benchmarks/maxtest/maxtest.yaml.template

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ spec:
 _sigterm() (kill -SIGTERM $! 2>/dev/null;);
 trap _sigterm SIGTERM;

-(export TPU_STDERR_LOG_LEVEL=0 && export TPU_MIN_LOG_LEVEL=0 && export TF_CPP_MIN_LOG_LEVEL=0 && python3 -m benchmarks.benchmark_runner healthscan --device_type=$TPU_ACCELERATOR_TYPE --base_output_directory=gke-healthscan-output --num_steps=5) & PID=$1;
+(export TPU_STDERR_LOG_LEVEL=0 && export TPU_MIN_LOG_LEVEL=0 && export TF_CPP_MIN_LOG_LEVEL=0 && echo LIBTPU_INIT_ARGS='$LIBTPU_ARGS' && export LIBTPU_INIT_ARGS='$LIBTPU_ARGS' && python3 -m benchmarks.benchmark_runner healthscan --device_type=$TPU_ACCELERATOR_TYPE --base_output_directory=gke-healthscan-output --num_steps=5) & PID=$1;

 while kill -0 $PID 2>/dev/null;
 do sleep 5;
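The template launches the healthscan workload in a background subshell and then polls it: record the child's PID, test liveness with `kill -0` (which delivers no signal, only checks existence), and loop until the child is gone. A minimal sketch of that supervision pattern; note it captures the PID with `$!`, the shell's last-background-job PID (the template's `PID=$1` reads a positional parameter, so `$!` here is an assumption about the intent):

```shell
#!/bin/bash
# Run a stand-in workload in the background and record its PID via $!.
(sleep 0.2) & PID=$!

# kill -0 sends no signal; it only reports whether the process still exists.
while kill -0 "$PID" 2>/dev/null; do
  sleep 0.1
done
echo "workload exited"
```

This keeps the container's main process alive exactly as long as the workload, while the `trap _sigterm SIGTERM` line above lets Kubernetes terminations propagate to the child.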

dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile

Lines changed: 6 additions & 0 deletions

@@ -52,6 +52,12 @@ RUN if [ "$DEVICE" = "tpu" ]; then \
     python3 -m pip install 'google-tunix>=0.1.2'; \
     fi

+# Temporarily pin JAX to 0.8.1 for GPU images
+RUN if [ "$DEVICE" = "gpu" ]; then \
+    python3 -m pip install -U "jax[cuda12]==0.8.1"; \
+    python3 -m pip install -U "transformer-engine-cu12" "transformer-engine-jax" "transformer-engine"; \
+    fi

 # Now copy the remaining code (source files that may change frequently)
 COPY . .
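Both Dockerfile branches hinge on the `DEVICE` build argument (supplied via `docker build --build-arg DEVICE=...`); the conditional itself is plain POSIX shell, sketched here outside Docker with `echo` standing in for the pip installs:

```shell
#!/bin/sh
# DEVICE normally arrives as a Docker build argument (ARG DEVICE).
DEVICE="gpu"
if [ "$DEVICE" = "tpu" ]; then
  echo "install TPU extras (google-tunix)"
elif [ "$DEVICE" = "gpu" ]; then
  echo "install jax[cuda12] and Transformer Engine"
fi
```

Because each `RUN if ...` is a single shell invocation, an unmatched `DEVICE` simply makes the step a no-op rather than failing the build.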

dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@ RUN pip install numba==0.61.2
 # Install vLLM for Jax and TPUs
 RUN pip install vllm-tpu

+RUN pip install --no-deps qwix==0.1.4
+
 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     pip uninstall -y jax jaxlib libtpu && \
     pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \

dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile

Lines changed: 4 additions & 13 deletions

@@ -28,27 +28,18 @@ RUN pip install keyring keyrings.google-artifactregistry-auth
 RUN pip install numba==0.61.2

 COPY tunix /tunix
+RUN pip uninstall -y google-tunix
 RUN pip install -e /tunix --no-cache-dir


 COPY vllm /vllm
-RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
-    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
-    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
-    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir


 COPY tpu-inference /tpu-inference
-RUN pip install -e /tpu-inference --no-cache-dir --pre \
-    --extra-index-url https://pypi.org/simple/ \
-    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
-    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
+RUN pip install -e /tpu-inference --no-cache-dir

+RUN pip install --no-deps qwix==0.1.4

 RUN if [ "$MODE" = "post-training-experimental" ]; then \
     echo "MODE=post-training-experimental: Re-installing JAX/libtpu"; \
