Commit a54b374

Merge pull request #3088 from AI-Hypercomputer:aireen/hf_epoch
PiperOrigin-RevId: 866183800
2 parents: 27eada9 + 8e52df8

8 files changed: 53 additions & 14 deletions

.pre-commit-config.yaml (1 addition & 0 deletions)

```diff
@@ -56,6 +56,7 @@ repos:
     rev: 0.7.22
     hooks:
       - id: mdformat
+        args: ['--number']
         additional_dependencies: [mdformat-myst, mdformat-ruff]
         files: (docs/.)
         exclude: docs/guides/checkpointing_solutions.md
```
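
The `--number` flag makes mdformat renumber ordered lists consecutively instead of preserving repeated `1.` markers; it is what produces the list renumbering in the two documentation diffs below. A minimal sketch of the effect, assuming mdformat's Python API (`mdformat.text` with an `options` mapping that mirrors the CLI flags):

```python
# A sketch of the `--number` behavior, assuming mdformat's Python API;
# the sample markdown is illustrative.
import mdformat

md = "1. profile\n1. tile\n1. validate\n"
print(mdformat.text(md, options={"number": True}))
# Renumbered output:
# 1. profile
# 2. tile
# 3. validate
```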

docs/guides/data_input_pipeline/data_input_hf.md (0 additions & 1 deletion)

```diff
@@ -43,4 +43,3 @@ tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-

 1. Streaming data directly from Hugging Face Hub may be impacted by the traffic of the server. During peak hours you may encounter "504 Server Error: Gateway Time-out". It's recommended to download the Hugging Face dataset to a Cloud Storage bucket or disk for the most stable experience.
 2. Streaming data directly from Hugging Face Hub works in multi-host settings with a small number of hosts. With a host number larger than 16, you might encounter a "read time out" error.
-3. Only supports `num_epoch=1` at the moment.
```
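
For the first note above, a minimal sketch of pre-downloading a Hub dataset instead of streaming it, assuming the Hugging Face `datasets` library; the dataset name and output path are arbitrary examples, not MaxText defaults:

```python
# A hedged sketch, assuming the `datasets` library; the dataset and the
# output path below are arbitrary examples.
from datasets import load_dataset

# Download once up front instead of streaming from the Hub during training.
ds = load_dataset("stanfordnlp/imdb", split="train")

# Persist in Arrow format to local disk (or a mounted Cloud Storage path),
# then point the input pipeline at the saved copy.
ds.save_to_disk("/data/imdb_train")
```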

docs/guides/optimization/pallas_kernels_performance.md (5 additions & 5 deletions)

```diff
@@ -214,11 +214,11 @@ Dispatch a kernel on multiple devices with `jax.shard_map`. It’s usually simpl
 ## ✅ Putting it all together (checklist)

 1. **Profile** the baseline using `named_scope` and `block_until_ready`.
-1. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
-1. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
-1. **Validate** end-to-end performance in the model, not just microbenchmarks.
-1. Consider **maintainability** and guard the new kernel with tests.
-1. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.
+2. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
+3. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
+4. **Validate** end-to-end performance in the model, not just microbenchmarks.
+5. Consider **maintainability** and guard the new kernel with tests.
+6. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.

 ## 📚 References
```
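
On the last checklist item, a toy sketch of `jax.vmap` over a Pallas kernel; the elementwise-add kernel is illustrative only, and `interpret=True` keeps it runnable without a TPU:

```python
# A toy sketch: vmapping a Pallas kernel behaves like prepending a grid
# dimension. The kernel and shapes here are illustrative assumptions.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # Read each input block in full and write the elementwise sum.
  o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
  return pl.pallas_call(
      add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      interpret=True,  # run in interpreter mode for portability
  )(x, y)

x = jnp.ones((8, 128, 128), jnp.float32)
y = jnp.ones((8, 128, 128), jnp.float32)
out = jax.vmap(add)(x, y)  # maps `add` over the leading batch axis
assert out.shape == (8, 128, 128)
```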

docs/run_maxtext/run_maxtext_localhost.md (3 additions & 3 deletions)

````diff
@@ -13,7 +13,7 @@ Before you can begin a training run, you need to configure your storage environ
 You'll need a GCS bucket to store all your training artifacts, such as logs, metrics, and model checkpoints.

 1. In your Google Cloud project, create a new storage bucket.
-1. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the `Storage Admin` (`roles/storage.admin`) role to the service account associated with your VMs.
+2. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the `Storage Admin` (`roles/storage.admin`) role to the service account associated with your VMs.

 ### Setup MaxText

@@ -36,14 +36,14 @@ Local development on a single host TPU/GPU VM is a convenient way to run MaxText

 1. Create and SSH to the single host VM of your choice. You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. For GPUs, you can use `nvidia-h100-mega-80gb`, `nvidia-h200-141gb`, or `nvidia-b200`. For setting up a TPU VM, use the Cloud TPU documentation available at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus.

-1. Clone MaxText onto that VM.
+2. Clone MaxText onto that VM.

    ```bash
    git clone https://github.com/google/maxtext.git
    cd maxtext
    ```

-1. Once you have cloned the repository, you have two primary options for setting up the necessary dependencies on your VM: Installing in a Python Environment, or building a Docker container. For single host workloads, we recommend to install dependencies in a python environment, and for multihost workloads we recommend the containerized approach.
+3. Once you have cloned the repository, you have two primary options for setting up the necessary dependencies on your VM: Installing in a Python Environment, or building a Docker container. For single host workloads, we recommend to install dependencies in a python environment, and for multihost workloads we recommend the containerized approach.

 Within the root directory of the cloned repo, create a virtual environment and install dependencies and the pre-commit hook by running:
````

src/MaxText/configs/base.yml (1 addition & 1 deletion)

```diff
@@ -566,7 +566,7 @@ train_image_column: 'image'
 eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
 eval_image_column: 'image'
 packing: True
-num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
+num_epoch: 1
 generate_padding_batch_train: False
 generate_padding_batch_eval: False
 # Maximum number of segments that can be packed into a single sequence
```

src/MaxText/configs/types.py (26 additions & 2 deletions)

```diff
@@ -2289,10 +2289,34 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     ):
       logger.warning("`tokenizer_type` is not 'tiktoken' when using llama3 tokenizer. Overriding to 'tiktoken'.")
      self.tokenizer_type = TokenizerType.TIKTOKEN
+    # Data input validations
+    if self.dataset_type == DatasetType.HF:
+      if not self.hf_path:
+        raise ValueError("hf_path can't be empty when dataset_type=hf")
+      if self.hf_eval_files:
+        self.hf_eval_split = "train"
+      if self.eval_interval > 0 and not self.hf_eval_split:
+        raise ValueError("Please specify hf_eval_split or set eval_interval to <=0.")
+    elif self.dataset_type == DatasetType.GRAIN:
+      if not self.grain_train_files and not self.grain_train_mixture_config_path:
+        raise ValueError("When dataset_type=grain, please set grain_train_files or grain_train_mixture_config_path")
+      if self.eval_interval > 0 and not self.grain_eval_files:
+        raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
+      if self.tokenizer_type not in (TokenizerType.SENTENCEPIECE, TokenizerType.HUGGINGFACE):
+        raise ValueError(
+            f"grain pipeline only supports tokenizer_type: sentencepiece, huggingface, but got {self.tokenizer_type}"
+        )
+    elif self.dataset_type == DatasetType.TFDS:
+      if not self.dataset_name:
+        raise ValueError("dataset_name can't be empty when dataset_type=tfds")
+      if self.eval_interval > 0 and not self.eval_split:
+        raise ValueError("Please specify eval_split or set eval_interval to <=0.")
+
+    if self.sharding_tolerance > 1.0 or self.sharding_tolerance < 0.0:
+      logger.warning("'sharding_tolerance: allowed percentage of non-sharded parameters' should be between 0.0 and 1.0")
+
     if self.eval_interval > 0 >= self.eval_steps and self.generate_padding_batch_eval:
       raise ValueError("`eval_steps` must be > 0 when `generate_padding_batch_eval` is True.")
-    if self.dataset_type == "hf" and self.num_epoch != 1:
-      raise ValueError("HuggingFace pipeline only supports num_epoch=1.")
     if self.rl.loss_algo == "grpo":
       self.use_grpo = True
     else:
```
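
The new block replaces string comparisons with `DatasetType` enum dispatch and fails fast on missing per-pipeline fields. A simplified standalone sketch of that pattern; the `Config` stand-in below is hypothetical, not MaxText's real class:

```python
# A simplified, standalone rendering of the validation dispatch above;
# this Config dataclass is a hypothetical stand-in.
from dataclasses import dataclass
from enum import Enum

class DatasetType(Enum):
  HF = "hf"
  GRAIN = "grain"
  TFDS = "tfds"

@dataclass
class Config:
  dataset_type: DatasetType
  hf_path: str = ""
  hf_eval_split: str = ""
  eval_interval: int = -1

def validate_data_input(cfg: Config) -> None:
  # Each dataset type has its own required fields; raise a targeted error
  # at config time instead of failing deep inside the input pipeline.
  if cfg.dataset_type == DatasetType.HF:
    if not cfg.hf_path:
      raise ValueError("hf_path can't be empty when dataset_type=hf")
    if cfg.eval_interval > 0 and not cfg.hf_eval_split:
      raise ValueError("Please specify hf_eval_split or set eval_interval to <=0.")

validate_data_input(Config(dataset_type=DatasetType.HF, hf_path="allenai/c4"))
```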

src/MaxText/input_pipeline/_hf_data_processing.py (16 additions & 2 deletions)

```diff
@@ -61,8 +61,14 @@ def vision_sft_preprocessing_pipeline(
   else:
     batch_size = global_batch_size // jax.process_count()

-  if config.enable_data_shuffling:
+  # for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
+  if config.enable_data_shuffling and config.num_epoch > 1:
+    epoch_datasets = [dataset.shuffle(seed=config.data_shuffle_seed + i) for i in range(config.num_epoch)]
+    dataset = datasets.concatenate_datasets(epoch_datasets)
+  elif config.enable_data_shuffling:
     dataset = dataset.shuffle(seed=config.data_shuffle_seed)
+  elif config.num_epoch > 1:
+    dataset = dataset.repeat(config.num_epoch)

   # If multiple image columns are provided, merge them into a single 'images' column.
   if isinstance(image_column, list):
@@ -206,6 +212,7 @@ def preprocessing_pipeline(
     sft_train_on_completion_only=True,
     grain_worker_count=1,  # only support 0 or 1
     max_segments_per_seq=None,
+    num_epoch=1,
 ):
   """pipeline for preprocessing HF dataset"""

@@ -217,8 +224,14 @@
   else:
     batch_size = global_batch_size // jax.process_count()

-  if shuffle:
+  # for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
+  if shuffle and num_epoch > 1:
+    epoch_datasets = [dataset.shuffle(seed=data_shuffle_seed + i) for i in range(num_epoch)]
+    dataset = datasets.concatenate_datasets(epoch_datasets)
+  elif shuffle:
     dataset = dataset.shuffle(seed=data_shuffle_seed)
+  elif num_epoch > 1:
+    dataset = dataset.repeat(num_epoch)

   tokenizer = transformers.AutoTokenizer.from_pretrained(
       tokenizer_path,
@@ -409,6 +422,7 @@ def make_hf_train_iterator(
       sft_train_on_completion_only=config.sft_train_on_completion_only,
       chat_template_path=config.chat_template_path,
       max_segments_per_seq=config.max_segments_per_seq,
+      num_epoch=config.num_epoch,
   )
   return train_iter
```
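
The heart of the change: instead of rejecting `num_epoch > 1`, each epoch is shuffled with a different seed and the epochs are concatenated, so no two epochs present examples in the same order; without shuffling, the dataset is simply repeated. A minimal standalone sketch of the shuffle-then-concat pattern on a toy in-memory dataset (the column name "x" is arbitrary):

```python
# A minimal sketch of the per-epoch shuffle-then-concat pattern from the
# diff above, on a toy in-memory dataset; the column "x" is arbitrary.
import datasets

num_epoch = 3
data_shuffle_seed = 42
dataset = datasets.Dataset.from_dict({"x": list(range(8))})

# One differently seeded shuffle per epoch, concatenated into one stream.
epoch_datasets = [dataset.shuffle(seed=data_shuffle_seed + i) for i in range(num_epoch)]
multi_epoch = datasets.concatenate_datasets(epoch_datasets)

assert len(multi_epoch) == num_epoch * len(dataset)
# Each epoch-sized slice is a full permutation of the dataset, and the
# permutations differ across epochs because the seeds differ.
```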

src/MaxText/pyconfig_deprecated.py (1 addition & 0 deletions)

```diff
@@ -358,6 +358,7 @@ def validate_tokamax_usage(keys):
     raise ValueError(f"Invalid tokamax's megablox kernel usage for hardware {keys['hardware']}. Only TPU is supported.")


+# All data input validations have been migrated to config/types.py
 def validate_data_input(keys):
   """validate provided parameters for data input"""
   if not keys["hf_access_token"]:
```
