Skip to content

Commit 80762f9

Browse files
Merge pull request #2865 from ROCm:rocm-main-pr
PiperOrigin-RevId: 863778682
2 parents f44534f + 1f56029 commit 80762f9

139 files changed

Lines changed: 1855 additions & 725 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ jobs:
9191
HF_TOKEN: ${{ secrets.HF_TOKEN }}
9292
run: |
9393
MAXTEXT_REPO_ROOT=$(pwd)
94-
MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/MaxText/examples"
94+
MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/maxtext/examples"
9595
9696
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
9797
filename=$(basename "$notebook")

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ Check out our [Read The Docs site](https://maxtext.readthedocs.io/en/latest/) or
3535
See our installation guide to [install MaxText with pip from PyPI](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-pypi-recommended).
3636

3737
## Decoupled mode
38-
See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/guides/run_maxtext/decoupled_mode.html).
38+
See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/run_maxtext/decoupled_mode.html).
3939

4040
<!-- NEWS START -->
4141

4242
## 🔥 Latest news 🔥
4343

4444
* \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
4545
* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model.
46-
* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) are available.
46+
* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available.
4747
* \[December 4, 2025\] The [ReadTheDocs documentation site](https://maxtext.readthedocs.io/en/latest/index.html) has been reorganized.
4848
* \[December 3, 2025\] Multi-host support for GSPO and GRPO is now available via [new RL tutorials](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl_on_multi_host.html).
4949
* \[November 20, 2025\] A new guide, [What is Post Training in MaxText?](https://maxtext.readthedocs.io/en/latest/tutorials/post_training_index.html), is now available.

codecov.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fixes:
3434
ignore:
3535
- "src/maxtext/assets"
3636
- "src/MaxText/configs"
37-
- "src/MaxText/examples"
37+
- "src/maxtext/examples"
3838
- "src/MaxText/experimental"
3939
- "src/maxtext/inference"
4040
- "src/maxtext/scratch_code"
@@ -65,7 +65,7 @@ coverage:
6565
patch:
6666
default:
6767
target: auto
68-
threshold: 5% # fail on 5+ percent degradation
68+
threshold: 10% # fail on 10+ percent degradation
6969
flags:
7070
- regular
7171

dependencies/requirements/requirements_decoupled_jax_0_7.1.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ flax
88
grain>=0.2.12
99
grpcio>=1.75.1
1010
huggingface_hub>=0.35.3
11+
jax==0.7.1
1112
jaxtyping>=0.3.3
1213
jsonlines>=4.0.0
1314
matplotlib>=3.10.3
@@ -19,6 +20,7 @@ omegaconf>=2.3.0
1920
optax>=0.2.6
2021
orbax-checkpoint>=0.11.25
2122
pandas>=2.3.3
23+
parameterized==0.9.0
2224
pathwaysutils>=0.1.3
2325
pillow>=11.3.0
2426
protobuf>=5.29.5
@@ -39,5 +41,4 @@ tiktoken>=0.12.0
3941
tqdm>=4.67.1
4042
transformers>=4.57.0
4143
urllib3>=2.5.0
42-
jax==0.7.1
43-
git+https://github.com/google/tunix.git
44+
git+https://github.com/google/tunix.git

docs/guides/run_python_notebook.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Before proceeding, please verify that the specific notebook you are running work
4343

4444
### Step 1: Choose an Example
4545

46-
1.a. Visit the [MaxText examples directory](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) on Github.
46+
1.a. Visit the [MaxText examples directory](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) on Github.
4747

4848
1.b. Find the notebook you want to run (e.g., `sft_qwen3_demo.ipynb`) and copy its URL.
4949

docs/run_maxtext/decoupled_mode.md

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,38 +14,40 @@
1414
limitations under the License.
1515
-->
1616

17-
1817
# Via Decoupled Mode (No Google Cloud Dependencies)
1918

2019
Set `DECOUPLE_GCLOUD=TRUE` to run MaxText tests and local development without any Google Cloud SDK, `gs://` buckets, JetStream, or Vertex AI integrations.
2120

2221
When enabled:
23-
* Skips external integration tests with markers:
24-
* `external_serving` (`jetstream`, `serving`, `decode_server`)
25-
* `external_training` (`goodput`)
26-
* `decoupled` – Applied by `tests/conftest.py` to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
27-
* Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally when decoupled mode is ON.
28-
* Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
29-
* Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
30-
* Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
31-
* All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_helper.py`. This helper ensures usage of `decoupled_base_test.yml`.
22+
23+
- Skips external integration tests with markers:
24+
- `external_serving` (`jetstream`, `serving`, `decode_server`)
25+
- `external_training` (`goodput`)
26+
- `decoupled` – Applied by `tests/conftest.py` to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
27+
- Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally when decoupled mode is ON.
28+
- Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
29+
- Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
30+
- Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
31+
- All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_utils.py`. This helper ensures usage of `decoupled_base_test.yml`.
3232

3333
Minimal datasets included (checked into the repo):
34-
* ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
34+
35+
- ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
3536
located in `local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-{train,validation}.array_record-*`
36-
* Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
37+
- Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
3738
located in `local_datasets/c4_en_dataset_minimal/hf/c4`
3839

39-
4040
Run a local smoke test fully offline:
41+
4142
```bash
4243
export DECOUPLE_GCLOUD=TRUE
4344
pytest -k train_gpu_smoke_test -q
4445
```
4546

4647
Optional environment variables:
47-
* `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
48-
* `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.
48+
49+
- `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
50+
- `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.
4951

5052
## Centralized Decoupling API (`gcloud_stub.py`)
5153

@@ -55,32 +57,36 @@ MaxText exposes a single module `MaxText.gcloud_stub` to avoid scattering enviro
5557
from MaxText.gcloud_stub import is_decoupled, cloud_diagnostics, jetstream
5658

5759
if is_decoupled():
58-
# Skip optional integrations or use local fallbacks
59-
pass
60+
# Skip optional integrations or use local fallbacks
61+
pass
6062

6163
# Cloud diagnostics (returns diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration)
62-
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = cloud_diagnostics()
64+
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = (
65+
cloud_diagnostics()
66+
)
6367

6468
# JetStream (serving) components
6569
config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream()
6670
TokenizerParameters = getattr(token_params_ns, "TokenizerParameters", object)
6771
```
6872

6973
Behavior when `DECOUPLE_GCLOUD=TRUE`:
70-
* `is_decoupled()` returns True.
71-
* Each helper returns lightweight stubs whose attributes are safe to access; calling methods raises a clear `RuntimeError` only when actually invoked.
72-
* Prevents import-time failures for optional dependencies (JetStream).
74+
75+
- `is_decoupled()` returns True.
76+
- Each helper returns lightweight stubs whose attributes are safe to access; calling methods raises a clear `RuntimeError` only when actually invoked.
77+
- Prevents import-time failures for optional dependencies (JetStream).
7378

7479
## Guidelines:
75-
* Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
76-
* Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking.
77-
* Use `get_test_config_path()` instead of hard-coded `base.yml`.
78-
* Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
79-
* Please add the appropriate external dependency marker (`external_serving` or `external_training`) for new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only a part of a file needs an external dependency.
80-
* Tests add a `decoupled` marker if DECOUPLE_GCLOUD && not marked with external dependency markers. Run tests with:
80+
81+
- Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
82+
- Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking.
83+
- Use `get_test_config_path()` instead of hard-coded `base.yml`.
84+
- Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
85+
- Please add the appropriate external dependency marker (`external_serving` or `external_training`) for new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only a part of a file needs an external dependency.
86+
- Tests add a `decoupled` marker if DECOUPLE_GCLOUD && not marked with external dependency markers. Run tests with:
87+
8188
```
8289
pytest -m decoupled -vv tests
8390
```
8491

8592
This centralized approach keeps optional integrations cleanly separated from core MaxText logic, making local development (e.g. on ROCm/NVIDIA GPUs) frictionless.
86-

docs/tutorials/first_run.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ In the same TPU VM where you just installed all the dependencies of MaxText, You
7575

7676
#### Decoding in MaxText via notebook
7777

78-
You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint.
78+
You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint.
7979

8080
### Run MaxText on NVIDIA GPUs
8181

docs/tutorials/posttraining/multimodal.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ This document provides a guide to use the multimodal functionalities in MaxText
66
- **Multimodal Decode**: Inference with text+images as input.
77
- **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset.
88

9-
We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support:
9+
We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support:
1010

1111
| Models | Input Modalities | Output Modalities |
1212
| :--------------------------------------------- | :--------------- | :---------------- |

src/MaxText/configs/decoupled_base_test.yml

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
2-
# Inherit all model defaults from base.yml but override any cloud-coupled paths and disable optional cloud features.
3-
base_config: base.yml
2+
# Inherit all model defaults (PyDantic already does this) but override any cloud-coupled paths and disable
3+
# optional cloud features.
44

55
# Output goes to a local relative directory so tests do not require GCS.
6-
base_output_directory: ./maxtext_local_output
6+
base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
77
run_name: test_decoupled
88

99
# Disable checkpointing by default for speed unless a test explicitly enables it.
@@ -23,7 +23,9 @@ profile_periodically_period: 0
2323
profiler_steps: 0
2424

2525
# Leave dataset-related keys to be overridden by individual tests.
26-
dataset_type: ""
26+
dataset_path: "tests/assets/local_datasets/c4_en_dataset_minimal/"
27+
dataset_name: 'c4/en:3.1.0'
28+
eval_dataset_name: 'c4/en:3.1.0'
2729

2830
# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
2931
attention: "dot_product"
@@ -44,6 +46,8 @@ ici_tensor_sequence_parallelism: 1
4446
ici_autoregressive_parallelism: 1
4547
ici_fsdp_parallelism: 1
4648
ici_fsdp_transpose_parallelism: 1
49+
# Allow higher unsharded parameter percentage for small device count
50+
sharding_tolerance: 0.3
4751

4852
# DCN dimensions to 1 (no multi-slice expectation locally).
4953
dcn_data_parallelism: 1
@@ -68,12 +72,4 @@ goodput_upload_interval_seconds: 0
6872
enable_pathways_goodput: false
6973
enable_gcp_goodput_metrics: false
7074

71-
# Disable any cloud logging / BigQuery or external metric uploads.
72-
enable_cloud_logging: false
73-
upload_metrics_to_bigquery: false
74-
bigquery_project: ""
75-
bigquery_dataset: ""
76-
bigquery_table: ""
77-
78-
# Force local-only behavior for tests: avoid accidental env pickup.
79-
tensorboard_dir: "./maxtext_local_output/tensorboard"
75+
tensorboard_dir: "./maxtext_local_output/gcloud_decoupled_test_logs/tensorboard"

src/MaxText/configs/rl.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ reasoning_start_token: '<reasoning>'
171171
reasoning_end_token: '</reasoning>'
172172
solution_start_token: '<answer>'
173173
solution_end_token: '</answer>'
174-
chat_template_path: 'src/MaxText/examples/chat_templates/gsm8k_rl.json'
174+
chat_template_path: 'src/maxtext/examples/chat_templates/gsm8k_rl.json'
175175
skip_jax_distributed_system: True
176176

177177
# # TODO(@mazumdera): fix this

0 commit comments

Comments
 (0)