
Commit f7cfa7b

bvandermoon authored and Google-ML-Automation committed
PR #3252: Migrate remaining src/MaxText utils to src/maxtext/utils
Imported from GitHub PR #3252

# Description

* Move the remaining utils from `src/MaxText` to `src/maxtext/utils`

# Tests

Successfully ran vllm_decode:

```
python3 -m maxtext.inference.vllm_decode src/maxtext/configs/base.yml \
  model_name=qwen3-30b-a3b \
  tokenizer_path=Qwen/Qwen3-30B-A3B \
  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
  ici_tensor_parallelism=4 \
  hbm_utilization_vllm=0.5 \
  prompt="Suggest some famous landmarks in London." \
  decode_sampling_temperature=0.0 \
  decode_sampling_nucleus_p=1.0 \
  decode_sampling_top_k=0.0 \
  use_chat_template=True
```

```
python3 -m maxtext.utils.generate_param_only_checkpoint src/maxtext/configs/base.yml \
  base_output_directory=<base_output_directory> \
  load_parameters_path=<load_parameters_path> \
  run_name=<run_name> \
  model_name=gemma-2b \
  force_unroll=true
```

# Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Copybara import of the project:

-- 62f4d58 by Branden Vandermoon <bvandermoon@google.com>:

Migrate remaining src/MaxText utils to src/maxtext/utils

Merging this change closes #3252

COPYBARA_INTEGRATE_REVIEW=#3252 from AI-Hypercomputer:bvandermoon-repo-restructure 62f4d58
PiperOrigin-RevId: 875884352
1 parent 5a4a9c3 commit f7cfa7b

38 files changed

Lines changed: 76 additions & 82 deletions
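
The repeated change across these files is mechanical: old-style `from MaxText import X` (or `from MaxText.X import Y`) becomes `from maxtext.utils import X` for the modules that moved, while imports such as `from MaxText import pyconfig` stay untouched. As an illustration only (not part of the PR), that rewrite can be sketched as a small helper; the `MOVED_MODULES` set below lists just the two modules visible in the hunks on this page, and the helper name is hypothetical.

```python
import re

# Illustrative subset: the modules moved in the hunks shown below.
# The full PR migrates more modules than these two.
MOVED_MODULES = {"accelerator_to_spec_map", "generate_param_only_checkpoint"}

def rewrite_import(line: str) -> str:
    """Rewrite an old-style `MaxText` import to the new `maxtext.utils` path.

    Handles both `from MaxText import <mod>` and `from MaxText.<mod> import <names>`.
    Imports of modules that did not move (e.g. `pyconfig`) are returned unchanged.
    """
    m = re.fullmatch(r"from MaxText import (\w+)", line)
    if m and m.group(1) in MOVED_MODULES:
        return f"from maxtext.utils import {m.group(1)}"
    m = re.fullmatch(r"from MaxText\.(\w+) import (.+)", line)
    if m and m.group(1) in MOVED_MODULES:
        return f"from maxtext.utils.{m.group(1)} import {m.group(2)}"
    return line
```

For example, `rewrite_import("from MaxText import accelerator_to_spec_map")` yields `from maxtext.utils import accelerator_to_spec_map`, matching the hunks below, while `from MaxText import pyconfig` passes through unchanged.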

src/MaxText/pyconfig_deprecated.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,7 +28,7 @@
 
 import omegaconf
 
-from MaxText import accelerator_to_spec_map
+from maxtext.utils import accelerator_to_spec_map
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR
 from maxtext.common.common_types import AttentionType, DecoderBlockType, ShardMode
 from maxtext.utils import gcs_utils
```

src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -48,7 +48,7 @@
 from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig
 
 from MaxText import pyconfig
-from MaxText.generate_param_only_checkpoint import _read_train_checkpoint
+from maxtext.utils.generate_param_only_checkpoint import _read_train_checkpoint
 from maxtext.checkpoint_conversion.standalone_scripts import llama_or_mistral_ckpt
 from maxtext.common import checkpointing
 from maxtext.utils import max_logging
```

src/maxtext/configs/types.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_utils
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
-from MaxText import accelerator_to_spec_map
+from maxtext.utils import accelerator_to_spec_map
 from pydantic.config import ConfigDict
 from pydantic.fields import Field
 from pydantic.functional_validators import field_validator, model_validator
```
Lines changed: 13 additions & 19 deletions
```diff
@@ -14,25 +14,19 @@
 """
 An example script to perform decoding using vLLM via Tunix or via MaxText on vLLM.
 
-Example usage with Tunix:
-python3 -m maxtext.vllm_decode maxtext/configs/base.yml \
-  model_name=llama3.1-8b tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
-  tokenizer_type=huggingface hf_access_token=<your_hf_token> \
-  load_parameters_path=<your_checkpoint_path> \
-  per_device_batch_size=1 run_name=vllm_decode_test max_target_length=64 \
-  use_chat_template=False prompt="Suggest some famous landmarks in London." \
-  decode_sampling_temperature=0.0 decode_sampling_nucleus_p=1.0 decode_sampling_top_k=0.0 \
-  --use_tunix \
-
-Or without Tunix using the MaxText vLLM integration:
-python3 -m maxtext.vllm_decode maxtext/configs/base.yml \
-  model_name=qwen3-30b-a3b \
-  tokenizer_path=Qwen/Qwen3-30B-A3B \
-  vllm_hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
-  load_parameters_path=<your_checkpoint_path> \
-  ici_tensor_parallelism=4 \
-  hbm_utilization_vllm=0.5 \
-  prompt="Suggest some famous landmarks in London."
+Example usage:
+python3 -m maxtext.inference.vllm_decode src/maxtext/configs/base.yml \
+  model_name=qwen3-30b-a3b \
+  tokenizer_path=Qwen/Qwen3-30B-A3B \
+  load_parameters_path=<your_checkpoint_path> \
+  vllm_hf_overrides='{architectures: ["MaxTextForCausalLM"]}' \
+  ici_tensor_parallelism=4 \
+  hbm_utilization_vllm=0.5 \
+  prompt="Suggest some famous landmarks in London." \
+  decode_sampling_temperature=0.0 \
+  decode_sampling_nucleus_p=1.0 \
+  decode_sampling_top_k=0.0 \
+  use_chat_template=True
 """
 
 import os
```

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,7 +32,7 @@
 from jax.experimental.serialize_executable import serialize
 from jax.experimental.topologies import get_topology_desc
 from jax.sharding import AxisType, Mesh
-from MaxText import accelerator_to_spec_map
+from maxtext.utils import accelerator_to_spec_map
 from MaxText import pyconfig
 from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
 from maxtext.layers import quantizations
```
File renamed without changes.
File renamed without changes.

tests/end_to_end/gpu/a3/test_gemma3_logits.sh

Lines changed: 3 additions & 3 deletions
```diff
@@ -27,7 +27,7 @@ export MODEL_BUCKET=gs://maxtext-gemma/gemma3
 
 python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gemma3_chkpt --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
 
-# Current MaxText.generate_param_only_checkpoint will need to skip on GPU due to cpu process error. reuse the unscanned ckpt generated separately.
+# Current maxtext.utils.generate_param_only_checkpoint will need to skip on GPU due to cpu process error. reuse the unscanned ckpt generated separately.
 
 # # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
 # export DATASET_PATH=gs://maxtext-dataset
@@ -38,8 +38,8 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gemma3_chkpt
 # export RUN_NAME=unscanned_chkpt_${idx}
 # export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
 # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
-# We can do this by running `src/MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
-#JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true
+# We can do this by running `src/maxtext/utils/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
+#JAX_PLATFORMS=cpu python3 -m maxtext.utils.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true
```
