Skip to content

Commit 3bc185f

Browse files
SurbhiJainUSC authored and Google-ML-Automation committed
Move dpo_utils.py, load_and_quantize_checkpoint.py, vllm_decode.py and scratch_code/ to src/maxtext
PiperOrigin-RevId: 863355252
1 parent fb96299 commit 3bc185f

25 files changed

Lines changed: 27 additions & 28 deletions

File tree

codecov.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ ignore:
3636
- "src/MaxText/configs"
3737
- "src/MaxText/examples"
3838
- "src/MaxText/experimental"
39-
- "src/MaxText/inference"
40-
- "src/MaxText/inference_mlperf"
41-
- "src/MaxText/scratch_code"
39+
- "src/maxtext/inference"
40+
- "src/maxtext/scratch_code"
4241
- "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
4342
- "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
4443

docs/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ def run_apidoc(_):
175175
# Paths to exclude
176176
os.path.join(MAXTEXT_REPO_ROOT, "tests"),
177177
os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "experimental"),
178-
os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_mlperf"),
179-
os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "scratch_code"),
178+
os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "inference"),
179+
os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "scratch_code"),
180180
os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "utils", "ckpt_conversion"),
181181
os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "rl"),
182182
os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "multimodal_utils.py"),

docs/guides/optimization/pallas_kernels_performance.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth
6262

6363
- **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.
6464

65-
- [`src/MaxText/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention.py)
66-
- [`src/MaxText/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention_kernel_v2.py)
65+
- [`src/maxtext/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/inference/paged_attention.py)
66+
- [`src/maxtext/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/inference/paged_attention_kernel_v2.py)
6767

6868
- **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata.
6969

src/MaxText/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from MaxText import pyconfig
3333
from MaxText.layers import models
34-
from MaxText import dpo_utils
34+
from maxtext.trainers.post_train.dpo import dpo_utils
3535
from maxtext.utils import maxtext_utils
3636
from maxtext.utils import model_creation_utils
3737
from maxtext.utils.model_creation_utils import from_config

src/MaxText/experimental/agent/integrative_rag_agent/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103
# for converting PyTorch code to JAX
104104
block_for_rag = [
105105
"src/MaxText/layers", # Neural network layers and building blocks
106-
"src/MaxText/inference", # Inference and prediction code
106+
"src/maxtext/inference", # Inference and prediction code
107107
"src/MaxText/common_types.py", # Common data types and structures
108-
"src/MaxText/maxtext_utils.py", # Utility functions and helpers
108+
"src/maxtext/utils/maxtext_utils.py", # Utility functions and helpers
109109
]

src/MaxText/experimental/agent/orchestration_agent/split_python_file.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def visit_Attribute(self, node):
9393
if base_name in self.git_aliases:
9494
# It's an external dependency. We need to format it with the attribute path.
9595
# Example: base_name='page_manager', attr_chain='PageState'
96-
# self.git_dependencies['page_manager'] might be 'src/MaxText/inference/page_manager.py#page_manager'
96+
# self.git_dependencies['page_manager'] might be 'src/maxtext/inference/page_manager.py#page_manager'
9797
path, obj = self.git_dependencies[base_name].split("#", 1)
9898

9999
# As per the user request, we append the attribute access to the object name.
@@ -198,8 +198,8 @@ def convert_package_to_path(self, path):
198198
199199
Example:
200200
"from maxtext.inference import page_manager, utils" ->
201-
{"page_manager": "src/MaxText/inference.py#page_manager",
202-
"utils": "src/MaxText/inference.py#utils"}
201+
{"page_manager": "src/maxtext/inference.py#page_manager",
202+
"utils": "src/maxtext/inference.py#utils"}
203203
204204
Args:
205205
path (str): A normalized absolute import string.
@@ -216,7 +216,7 @@ def convert_package_to_path(self, path):
216216
# The logic in get_absolute_imports should ideally resolve this ambiguity.
217217
# A heuristic could be used here (e.g., checking casing) but we stick to the current logic.
218218
# The user's example `from maxtext.inference import page_manager` creates a path
219-
# `src/MaxText/inference.py#page_manager`, which is what the new visitor expects to correct.
219+
# `src/maxtext/inference.py#page_manager`, which is what the new visitor expects to correct.
220220
import_dict[pkg.strip()] = path_form + ".py#" + pkg.strip()
221221
return import_dict
222222

src/MaxText/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151

5252
from MaxText.gradient_accumulation import gradient_accumulation_loss_and_grad
5353
from MaxText.vocabulary_tiling import vocab_tiling_linen_loss
54-
from MaxText.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
5554
# pylint: disable=too-many-positional-arguments
5655

5756
from maxtext.common import checkpointing, profiler
@@ -63,6 +62,7 @@
6362
)
6463
from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
6564
from maxtext.common.vertex_tensorboard import VertexTensorboardManager
65+
from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
6666
from maxtext.utils import exceptions
6767
from maxtext.utils import gcs_utils
6868
from maxtext.utils import max_logging

src/MaxText/load_and_quantize_checkpoint.py renamed to src/maxtext/checkpoint_conversion/load_and_quantize_checkpoint.py

File renamed without changes.

src/maxtext/inference/mlperf/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ cd ~
6464
git clone https://github.com/AI-Hypercomputer/maxtext.git
6565
cd maxtext
6666
bash setup.sh
67-
python3 -m pip install -r src/MaxText/inference_mlperf/requirements.txt
67+
python3 -m pip install -r src/maxtext/inference/mlperf/requirements.txt
6868
```
6969

7070
### Generate quantized checkpoint
@@ -125,7 +125,7 @@ export MODEL_SIZE=llama3.1-405b
125125
export QUANTIZE_TYPE=int8
126126

127127
cd maxtext && \
128-
python3 -m MaxText.load_and_quantize_checkpoint src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=${MODEL_SIZE} ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 attention=dot_product quantization=${QUANTIZE_TYPE} save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} async_checkpointing=false
128+
python3 -m maxtext.checkpoint_conversion.load_and_quantize_checkpoint src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=${MODEL_SIZE} ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 attention=dot_product quantization=${QUANTIZE_TYPE} save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} async_checkpointing=false
129129
```
130130

131131
The quantized checkpoint is saved at `${SAVE_QUANT_PARAMS_PATH}`
@@ -141,7 +141,7 @@ huggingface-cli login --token $HUGGING_FACE_TOKEN
141141
#### For trillium
142142
#### LLama2-70b:
143143
```
144-
cd ~/maxtext/src/MaxText/inference_mlperf/trillium
144+
cd ~/maxtext/src/maxtext/inference/mlperf/trillium
145145
```
146146

147147
##### Test Run

src/maxtext/inference/mlperf/llama_offline_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ else
117117
export DATASET_TYPE=full
118118
export DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
119119
export TOTAL_SAMPLE_COUNT=24576
120-
export USER_CONFIG=user.conf # NOTE: you may need to change this path(e.g. `src/MaxText/inference_mlperf/user.conf`)
120+
export USER_CONFIG=user.conf # NOTE: you may need to change this path(e.g. `src/maxtext/inference/mlperf/user.conf`)
121121
fi
122122

123123
# LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

0 commit comments

Comments (0)