
Commit d5ea751

Merge pull request #2814 from AI-Hypercomputer:hengtaoguo-links
PiperOrigin-RevId: 844535808
2 parents 377bf56 + 94a232e commit d5ea751

9 files changed: 16 additions & 17 deletions

docs/guides/optimization/benchmark_and_performance.md

Lines changed: 2 additions & 3 deletions
@@ -22,7 +22,7 @@ Arithmetic intensity is calculated as the ratio of floating-point operations (FL
 
 This metric helps determine whether a computation is MXU-bound (high arithmetic intensity) or memory-bound/communication-bound (low arithmetic intensity).
 
-[This sharding document](sharding_on_TPUs) illustrates various sharding strategies and their roofline analysis, through AI analysis.
+[This sharding document](sharding.md) illustrates various sharding strategies and their roofline analysis, through AI analysis.
 
 ## Metrics for benchmark analysis
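The arithmetic-intensity definition quoted in the hunk above (FLOPs divided by bytes moved) can be made concrete with a short sketch. This is illustrative only, not part of this commit or of MaxText's code:

```python
# Illustrative sketch: arithmetic intensity of a bf16 matmul,
# i.e. floating-point operations per byte moved to and from memory.
def matmul_arithmetic_intensity(m, k, n, bytes_per_element=2):
    flops = 2 * m * k * n  # one multiply and one add per (m, k, n) triple
    # read A (m*k), read B (k*n), write C (m*n)
    bytes_moved = bytes_per_element * (m * k + k * n + m * n)
    return flops / bytes_moved

# Large square matmuls have high intensity and tend to be MXU-bound;
# skinny ones have low intensity and tend to be memory-bound.
print(matmul_arithmetic_intensity(4096, 4096, 4096))
print(matmul_arithmetic_intensity(8, 4096, 4096))
```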

@@ -74,8 +74,7 @@ See [](quantization).
 ### Choose sharding strategy
 
 Sharding is crucial for optimizing model performance. MaxText offers various sharding strategies and hybrid options, including FSDP, TP, EP, CP, and PP, which can be configured through your MaxText settings.
-[This document](sharding_on_TPUs) illustrates in detail how sharding works in maxtext and chooses the right sharding config for your workload.
-[This document](sharding_on_TPUs) illustrates in detail how sharding works in maxtext and chooses the write sharding config for your workload.
+[This document](sharding.md) illustrates in detail how sharding works in maxtext and chooses the right sharding config for your workload.
 
 ### Performance tuning on custom Pallas call
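As an aside on the sharding hunk above: the hybrid strategies it names (FSDP, TP, EP, CP, PP) are typically combined through the ICI parallelism settings in a MaxText config. A hedged sketch follows; the flag names are assumed from base.yml and the values are illustrative, not part of this commit:

```yaml
# Hypothetical sketch of a hybrid FSDP x TP layout on 8 chips in one slice.
ici_fsdp_parallelism: 4      # FSDP across 4 chips
ici_tensor_parallelism: 2    # TP across 2 chips
ici_pipeline_parallelism: 1  # PP disabled
```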

docs/reference/architecture/architecture_overview.md

Lines changed: 2 additions & 2 deletions
@@ -91,7 +91,7 @@ The modularity of this design is clearly demonstrated by third-party extensions.
 
 ### Data ingestion (`input_pipeline.py`)
 
-[The data ingestion pipeline](data-input-pipeline) is a critical component for performance at scale. In MaxText, the main training loop interfaces with the data pipeline through the create\_data\_iterator function, which is called from train.py. This function acts as a facade, abstracting the specific data loading implementation from the rest of the training logic.
+[The data ingestion pipeline](../../guides/data_input_pipeline.md) is a critical component for performance at scale. In MaxText, the main training loop interfaces with the data pipeline through the create\_data\_iterator function, which is called from train.py. This function acts as a facade, abstracting the specific data loading implementation from the rest of the training logic.
 
 MaxText supports three primary data loading backends:
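The "facade" dispatch described in the hunk above can be sketched as follows. Only the names create\_data\_iterator and dataset\_type come from the text; everything else is illustrative and is not MaxText's actual implementation:

```python
# Hedged sketch of a facade over several data-loading backends: the caller
# gets an iterator without knowing which backend produced it.
def create_data_iterator(config):
    backends = {
        "tfds": lambda c: iter(["tfds-batch"]),
        "grain": lambda c: iter(["grain-batch"]),
        "hf": lambda c: iter(["hf-batch"]),
    }
    backend = backends.get(config["dataset_type"])
    if backend is None:
        raise ValueError(f"unknown dataset_type: {config['dataset_type']}")
    return backend(config)

print(next(create_data_iterator({"dataset_type": "tfds"})))
```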

@@ -153,7 +153,7 @@ This logical mesh abstraction enables the implementation of the standard paralle
 
 In MaxText, these strategies are implemented by annotating the model's PyTrees (the nested Python structures of arrays that hold the parameters and state) with sharding specifications. This is done using Flax's partitioning utilities, such as nn\_partitioning. These annotations provide requirements and hints to the compiler, telling it how each tensor should be distributed across the axes of the device mesh. The compiler then generates the appropriate collective communication operations (e.g., all-reduce, all-gather) needed to execute the parallel computation correctly and efficiently.
 
-For more information on sharding see [our sharding documentation](sharding_on_TPUs).
+For more information on sharding see [our sharding documentation](../../guides/optimization/sharding.md).
 
 ### Hardware abstraction and performance via XLA
docs/reference/models/supported_models_and_architectures.md

Lines changed: 4 additions & 4 deletions
@@ -70,7 +70,7 @@ MaxText supports a wide range of parallelism strategies for scaling training and
 
 The following summarizes observed runtime efficiency and scaling behaviors of MaxText across different hardware and model types, based on published benchmarks and large-scale runs.
 
-* **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](performance-metrics) for the definition and how we calculate it.
+* **High MFU**: MaxText targets high Model FLOPs Utilization across scales; exact numbers vary by model, hardware and config. See [**Performance Metrics → MFU**](../performance_metrics.md#performance-metrics) for the definition and how we calculate it.
 * **Quantization**: MaxText supports quantization via both the AQT and Qwix libraries. Qwix is the recommended approach, providing a non-intrusive way to apply various quantization techniques, including Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ).
 * **MoE**: The Mixture-of-Experts implementation features dropless routing with Megablox and `jax.lax.ragged_dot` kernels for enhanced performance.
 * **Multi-Token Prediction (MTP)**: This feature improves training efficiency on DeepSeek-style models by adding an auxiliary loss based on predicting multiple future tokens.
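The MFU metric referenced in the first bullet (achieved model FLOP/s as a fraction of the hardware's peak FLOP/s) can be sketched as follows. This is an illustrative definition, not MaxText's actual metric code, and the example numbers are made up:

```python
# Hedged sketch: Model FLOPs Utilization = achieved FLOP/s / peak FLOP/s.
def mfu(model_flops_per_step, step_time_s, peak_flops_per_s):
    achieved = model_flops_per_step / step_time_s  # FLOP/s actually delivered
    return achieved / peak_flops_per_s

# e.g. 2.2e15 model FLOPs per step, 4 s per step, on hardware peaking at 1e15 FLOP/s
print(mfu(2.2e15, 4.0, 1e15))  # 0.55
```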
@@ -91,6 +91,6 @@ The following summarizes observed runtime efficiency and scaling behaviors of Ma
 
 
 * **Technical Explanations:**
-  * [Parallelism & Sharding](sharding_on_TPUs)
-  * [Quantization Documentation](quantization)
-  * [AOT Compilation Instructions](aot-compilation)
+  * [Parallelism & Sharding](../../guides/optimization/sharding.md)
+  * [Quantization Documentation](../core_concepts/quantization.md)
+  * [AOT Compilation Instructions](../../guides/monitoring_and_debugging/features_and_diagnostics.md#ahead-of-time-compilation-aot)

docs/tutorials/first_run.md

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \
   dataset_type=synthetic \
   steps=10
 ```
-Optional: If you want to try training on a Hugging Face dataset, see [Data Input Pipeline](data-input-pipeline) for data input options.
+Optional: If you want to try training on a Hugging Face dataset, see [Data Input Pipeline](../guides/data_input_pipeline.md) for data input options.
 
 5. To demonstrate model output, run the following command:
 ```sh
@@ -92,7 +92,7 @@ Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.a
 
 ## Multihost development
 
-Google Kubernetes Engine (GKE) is the recommended way to run MaxText on multiple hosts. It provides a managed environment for deploying and scaling containerized applications, including those that require TPUs or GPUs. See [Running Maxtext with XPK](run-xpk) for details.
+Google Kubernetes Engine (GKE) is the recommended way to run MaxText on multiple hosts. It provides a managed environment for deploying and scaling containerized applications, including those that require TPUs or GPUs. See [Running Maxtext with XPK](../run_maxtext/run_maxtext_via_xpk.md) for details.
 
 ## Next steps: preflight optimizations

docs/tutorials/posttraining/full_finetuning.md

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ These scripts can provide a reference point for various scripts.
 
 ### MaxText checkpoint to Hugging Face
 
-Post finetuning or pre-training, MaxText also provides scripts to convert MaxText format weights back to [Hugging Face](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/llama_mistral_mixtral_orbax_to_hf.py).
+Post finetuning or pre-training, MaxText also provides scripts to convert MaxText format weights back to [Hugging Face](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py).
 
 #### Dataset

end_to_end/tpu/gemma/Run_Gemma.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 
 Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials).
 
-After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma).
+After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/end_to_end/tpu/gemma).
 
 ## MaxText supports pretraining and finetuning with high performance

end_to_end/tpu/gemma3/Run_Gemma3.md

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_
 ```
 
 ## Checkpoint Conversion
-To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [convert_gemma3_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/convert_gemma3_chkpt.py) script to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket.
+To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/utils/ckpt_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket.
 
 ## Fine-tuning
 After the conversion, you will have a MaxText compatible checkpoint which allows you to fine-tune it with different datasets. One example command to fine-tune a Gemma3-4B model is as follows:

src/MaxText/checkpointing.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def save(
       self,
       directory: epath.Path,
       # `item` is for backwards compatibility with older Orbax API, see
-      # https://orbax.readthedocs.io/en/latest/api_refactor.html.
+      # https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html.
       item: Optional[Any] = None,
       args: Any = None,
   ):

src/MaxText/configs/base.yml

Lines changed: 2 additions & 2 deletions
@@ -537,7 +537,7 @@ per_device_batch_size: 12.0
 # When expansion_factor_real_data is set to > 1, total_hosts//expansion_factor_real_data will load data.
 # Each data-loading host will load per_device_batch_size * expansion_factor_real_data.
 # When set to between 0 and 1, it's for grain pipeline to use a smaller chip count to read checkpoint from a larger chip count job.
-# Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_grain.md#using-grain
+# Details in https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md#using-grain
 expansion_factor_real_data: -1.0
 eval_per_device_batch_size: 0.0
 max_corpus_chars: 10_000_000
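The expansion\_factor\_real\_data arithmetic described in the comments above can be worked through with a small sketch. This only restates the quoted comments for the factor > 1 case; it is not MaxText's implementation:

```python
# Worked example of the expansion_factor_real_data comments above, for factor > 1:
# total_hosts // factor hosts load data, each at per_device_batch_size * factor.
def data_loading_plan(total_hosts, per_device_batch_size, expansion_factor):
    num_loading_hosts = total_hosts // expansion_factor
    per_host_batch = per_device_batch_size * expansion_factor
    return num_loading_hosts, per_host_batch

# With 8 hosts, per_device_batch_size=12.0 and a factor of 2,
# 4 hosts load data at an effective batch of 24.0 each.
print(data_loading_plan(8, 12.0, 2))  # (4, 24.0)
```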
@@ -578,7 +578,7 @@ use_sft: False
 sft_train_on_completion_only: False
 
 # dataset_type must be synthetic, hf, grain, tfds
-# details in: https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md
+# details in: https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline.md
 dataset_type: tfds
 # for TFDS input pipeline (dataset_type=tfds)
 dataset_path: "" # your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
