
Commit 4ef205e

Merge pull request #3189 from AI-Hypercomputer:bvandermoon-restructure
PiperOrigin-RevId: 872774667
2 parents: fdef529 + 3532d64

104 files changed: 794 additions & 746 deletions
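The bulk of this change renames MaxText's training entry point from `MaxText.train` (`src/MaxText/train.py`) to `maxtext.trainers.pre_train.train` (`src/maxtext/trainers/pre_train/train.py`) across editor configs, docs, and benchmark scripts. For anyone updating launch scripts against this commit, the before/after invocation looks like this (the `run_name` value here is illustrative):

```sh
# Before this commit
python3 -m MaxText.train src/maxtext/configs/base.yml run_name=my_run

# After this commit
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=my_run
```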


.vscode/launch.json

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
 "console": "integratedTerminal",
 "justMyCode": false,
 "python": "python3",
-"module": "MaxText.train",
+"module": "maxtext.trainers.pre_train.train",
 "args": ["src/maxtext/configs/base.yml",
 "run_name=runner_$(date +%Y-%m-%d-%H-%M)",
 "base_output_directory=gs://test-maxtext-output",

PREFLIGHT.md

Lines changed: 4 additions & 4 deletions
@@ -7,12 +7,12 @@ Before you run ML workload on Multihost with GCE or GKE, simply apply `bash pref
 
 Here is an example for GCE:
 ```
-bash preflight.sh PLATFORM=GCE && python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
+bash preflight.sh PLATFORM=GCE && python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
 ```
 
 Here is an example for GKE:
 ```
-bash preflight.sh PLATFORM=GKE && python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
+bash preflight.sh PLATFORM=GKE && python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
 ```
 
 # Optimization 2: Numa binding (You can only apply this to v4 and v5p)
@@ -22,14 +22,14 @@ For GCE,
 [preflight.sh](https://github.com/google/maxtext/blob/main/preflight.sh) will help you install `numactl` dependency, so you can use it directly, here is an example:
 
 ```
-bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
+bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
 ```
 
 For GKE,
 `numactl` should be built into your docker image from [maxtext_tpu_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. Here is an example
 
 ```
-bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
+bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
 ```
 
 1. `numactl`: This is the command-line tool used for controlling NUMA policy for processes or shared memory. It's particularly useful on multi-socket systems where memory locality can impact performance.
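As background on the flags in these examples (standard `numactl` semantics, not MaxText-specific): `--membind 0` forces memory allocations onto NUMA node 0, and `--cpunodebind=0` restricts execution to the CPUs of that node, keeping the training process memory-local. A quick way to inspect a machine's NUMA layout before picking a node:

```sh
# Print the NUMA topology: node count, CPUs per node, and per-node memory
numactl --hardware

# Pin both memory and CPU to node 0, matching the examples above
numactl --membind 0 --cpunodebind=0 python3 -m maxtext.trainers.pre_train.train \
  src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME
```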

benchmarks/maxtext_xpk_runner.py

Lines changed: 1 addition & 1 deletion
@@ -440,7 +440,7 @@ def build_user_command(
 f"export JAX_PLATFORMS={jax_platforms} &&",
 "export ENABLE_PJRT_COMPATIBILITY=true &&",
 "export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/MaxText MAXTEXT_REPO_ROOT=/deps &&"
-f'{hlo_dump} python3 -m MaxText.train {os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")}',
+f'{hlo_dump} python3 -m maxtext.trainers.pre_train.train {os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")}',
 f"{config_tuning_params}",
 f"steps={wl_config.num_steps}",
 f"model_name={wl_config.model.model_type}",

docs/guides/checkpointing_solutions/emergency_checkpointing.md

Lines changed: 1 addition & 1 deletion
@@ -157,5 +157,5 @@ The flags below would give the user access to the ramdisk in their workload:
 --num-slices=${NUM_SLICES} \
 --ramdisk-directory=${RAMDISK_DIRECTORY} \
 --mtc-enabled \
---command "python3 src/MaxText/train.py src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_emergency_checkpoint=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY}"
+--command "python3 src/maxtext/trainers/pre_train/train.py src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_emergency_checkpoint=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY}"
 ```

docs/guides/checkpointing_solutions/multi_tier_checkpointing.md

Lines changed: 1 addition & 1 deletion
@@ -186,5 +186,5 @@ The flags below would give the user access to the ramdisk in their workload:
 --num-slices=${NUM_SLICES} \
 --ramdisk-directory=${RAMDISK_DIRECTORY} \
 --mtc-enabled \
---command "python3 src/MaxText/train.py src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_multi_tier_checkpointing=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY} multi_tier_checkpointing_backup_interval_minutes=${MULTI_TIER_CHECKPOINTING_BACKUP_INT_MIN}"
+--command "python3 src/maxtext/trainers/pre_train/train.py src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_multi_tier_checkpointing=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY} multi_tier_checkpointing_backup_interval_minutes=${MULTI_TIER_CHECKPOINTING_BACKUP_INT_MIN}"
 ```
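Both the emergency and multi-tier checkpointing examples above write local checkpoints to `/${RAMDISK_DIRECTORY}`. A generic pre-launch sanity check (standard Linux tooling, not part of MaxText) that the ramdisk is actually mounted inside the workload:

```sh
# Confirm the ramdisk path is a live mount point and report its free space
mountpoint -q /${RAMDISK_DIRECTORY} && df -h /${RAMDISK_DIRECTORY}
```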

docs/guides/data_input_pipeline/data_input_grain.md

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pr
 bash tools/setup/setup_gcsfuse.sh \
 DATASET_GCS_BUCKET=maxtext-dataset \
 MOUNT_PATH=/tmp/gcsfuse && \
-python3 -m MaxText.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 run_name=<RUN_NAME> base_output_directory=gs://<MY_BUCKET> \
 dataset_type=grain \
 grain_file_type=arrayrecord # or parquet \
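Since this pipeline reads training data through the FUSE mount, it can be worth confirming the mount is live before launching (plain `mountpoint`/`ls` usage, with the paths from the example above):

```sh
# Verify the gcsfuse mount is active and the dataset files are visible
mountpoint -q /tmp/gcsfuse && ls /tmp/gcsfuse | head
```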

docs/guides/monitoring_and_debugging/features_and_diagnostics.md

Lines changed: 2 additions & 2 deletions
@@ -84,7 +84,7 @@ To load the compiled train_step, you just need to pass `compiled_trainstep_file=
 ```sh
 # Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
 export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
-python3 -m MaxText.train src/maxtext/configs/base.yml run_name=example_load_compile \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=example_load_compile \
 compiled_trainstep_file=my_compiled_train.pickle \
 global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \
 base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
@@ -122,7 +122,7 @@ To load the compiled `train_step`, you just need to pass `compiled_trainstep_fil
 ```sh
 # Run the below on each of the 4 target A3 hosts.
 export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
-python3 -m MaxText.train src/maxtext/configs/base.yml run_name=example_load_compile \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml run_name=example_load_compile \
 compiled_trainstep_file=my_compiled_train.pickle \
 attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \
 base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
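For context, the `my_compiled_train.pickle` file loaded in both examples is produced by MaxText's ahead-of-time compilation entry point, documented before this restructure as `MaxText.train_compile`; this diff excerpt does not show its new module path, so the sketch below uses the pre-restructure name and the topology flags from the project's AOT docs as assumptions:

```sh
# AOT-compile the train step for the target topology (pre-restructure entry
# point; the post-restructure module path is not visible in this diff)
python3 -m MaxText.train_compile src/maxtext/configs/base.yml \
  compile_topology=v5e-256 compile_topology_num_slices=2 \
  compiled_trainstep_file=my_compiled_train.pickle \
  global_parameter_scale=16 per_device_batch_size=4 learning_rate=1e-3
```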

docs/guides/monitoring_and_debugging/ml_workload_diagnostics.md

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,7 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu
 1. Enable ML Diagnostics to just capture Maxtext metrics and configs
 
 ```
-python3 -m MaxText.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 run_name=${USER}-tpu-job \
 base_output_directory="gs://your-output-bucket/" \
 dataset_path="gs://your-dataset-bucket/" \
@@ -47,7 +47,7 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu
 2. Enable ML Diagnostics to capture Maxtext metrics, configs and singlehost profiles (on the first TPU device)
 
 ```
-python3 -m MaxText.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 run_name=${USER}-tpu-job \
 base_output_directory="gs://your-output-bucket/" \
 dataset_path="gs://your-dataset-bucket/" \
@@ -60,7 +60,7 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu
 3. Enable ML Diagnostics to capture Maxtext metrics, configs and multihost profiles (on all TPU devices)
 
 ```
-python3 -m MaxText.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 run_name=${USER}-tpu-job \
 base_output_directory="gs://your-output-bucket/" \
 dataset_path="gs://your-dataset-bucket/" \

docs/guides/monitoring_and_debugging/monitor_goodput.md

Lines changed: 4 additions & 4 deletions
@@ -89,7 +89,7 @@ Please use a unique workload name, unless you intend to monitor cumulative Goodp
 MaxText enables Goodput recording and monitoring by default with `enable_goodput_recording=True` and `monitor_goodput=True`. You can configure the goodput upload frequency by setting `goodput_upload_interval_seconds`.
 
 ```bash
-python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH \
 dataset_path=$DATA_PATH run_name=goodput-test-run steps=200 goodput_upload_interval_seconds=30
 ```
 
@@ -98,7 +98,7 @@ python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUT
 MaxText enables step time deviation monitoring by default with `monitor_step_time_deviation=True`. You can configure the upload frequency by setting `step_deviation_interval_seconds`.
 
 ```bash
-python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH \
 dataset_path=$DATA_PATH run_name=goodput-test-run steps=200 step_deviation_interval_seconds=30
 ```
 
@@ -111,7 +111,7 @@ Enabling `enable_pathways_goodput` turns on Goodput measurement for Pathways wor
 ```
 
 ```bash
-python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \
 run_name=goodput-test-run steps=200 goodput_upload_interval_seconds=30 enable_pathways_goodput=True
 ```
 
@@ -168,7 +168,7 @@ and `enable_gcp_step_deviation_metrics` to `False` for disabling step deviation
 metrics.
 
 ```bash
-python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \
 run_name=goodput-test-run steps=200 goodput_upload_interval_seconds=30 enable_gcp_goodput_metrics=False \
 enable_gcp_step_deviation_metrics=False
 ```

docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ When you run a training job, MaxText produces detailed output logs. This guide s
 To start, run a simple pretraining job on a single-host TPU. For instance, we can run the following command on TPU v5p-8. The resulting log is used as an example throughout this guide.
 
 ```bash
-python3 -m MaxText.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 base_output_directory=gs://runner-maxtext-logs run_name=demo \
 model_name=deepseek2-16b \
 per_device_batch_size=24 max_target_length=2048 steps=10 dataset_type=synthetic enable_checkpointing=false
@@ -123,7 +123,7 @@ To generate all optional artifacts in one run, you can set the corresponding fla
 This command enables tensorboard, profiler, text metrics, config saving, and checkpointing:
 
 ```bash
-python3 -m MaxText.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
 base_output_directory=gs://runner-maxtext-logs run_name=demo2 \
 model_name=deepseek2-16b \
 per_device_batch_size=24 max_target_length=2048 steps=10 dataset_type=synthetic \
