
Commit 9e36543: Support TFRecord in Grain pipeline
1 parent 487bb6f

7 files changed (191 additions, 269 deletions)

docs/guides/data_input_pipeline.md (20 additions, 13 deletions)
````diff
@@ -15,29 +15,34 @@
 -->
 
 (data-input-pipeline)=
+
 # Data pipelines
 
 Currently MaxText has three data input pipelines:
 
-| Pipeline | Dataset formats | Features | Limitations |
-| -------- | --------------- | -------- | ----------- |
-| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended)| [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview), or [conversion](https://github.com/google/array_record/tree/main/beam))<br>[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With arrayrecord: fully deterministic, resilient to preemption; global shuffle <br>With parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | |
-| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets in [Hugging Face Hub](https://huggingface.co/datasets)<br>local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenience; <br>multiple formats | limit scalability using the Hugging Face Hub (no limit using Cloud Storage); <br>non-deterministic with preemption<br>(deterministic without preemption)<br> |
-| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords; <br>non-deterministic with preemption<br>(deterministic without preemption) |
+| Pipeline | Dataset formats | Features | Limitations |
+| -------- | --------------- | -------- | ----------- |
+| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended) | [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview), or [conversion](https://github.com/google/array_record/tree/main/beam))<br>[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview))<br>[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With ArrayRecord: fully deterministic, resilient to preemption; global shuffle<br>With Parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | |
+| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets in [Hugging Face Hub](https://huggingface.co/datasets)<br>local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenient;<br>multiple formats | limited scalability using the Hugging Face Hub (no limit using Cloud Storage);<br>non-deterministic with preemption<br>(deterministic without preemption) |
+| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords;<br>non-deterministic with preemption<br>(deterministic without preemption) |
 
 ## Multihost dataloading best practice
+
 Training in a multi-host environment presents unique challenges for data input pipelines. An effective data loading strategy must address three key issues:
+
 1. **Concurrent access**: Multiple hosts need to read from the same dataset simultaneously without causing conflicts.
 2. **Data uniqueness**: Each host must be fed a unique, non-overlapping subset of the data to ensure the model sees each example correctly.
-3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can lead to hanging.
-The approaches to solve these challenges depend on whether your dataset supports random access or is limited to sequential access.
+3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can lead to hanging.
+
+The approaches to solve these challenges depend on whether your dataset supports random access or is limited to sequential access.
 
 ### Random access dataset (Recommended)
+
 Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.<br>
 In MaxText, this is best supported by the ArrayRecord format using the Grain input pipeline. This approach gracefully handles the key challenges:
-* **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.
 
-* **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met.
+- **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.
+
+- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met.
 
 ```{note}
 When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this.
@@ -48,12 +53,14 @@ If all hosts exhaust their data before the target step count is reached, both `t
 ```
 
 ### Sequential access dataset
-* **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files.
-* **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early.
 
-```{toctree}
-:hidden:
+- **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files.
+- **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early.
 
+```{toctree}
+---
+hidden:
+---
 data_input_pipeline/data_input_grain
 data_input_pipeline/data_input_hf
 data_input_pipeline/data_input_tfds
````
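The sharding and padding rules described in the guide can be sketched in a few lines of Python. This is an illustrative toy, not MaxText's or Grain's actual implementation; `shard_files` and `with_padding_batches` are hypothetical names.

```python
def shard_files(files, host_index, num_hosts):
    """Give each data-loading host exclusive access to a subset of files."""
    if len(files) % num_hosts != 0:
        # Key requirement from the guide: uneven shards worsen "uneven completion".
        raise ValueError("number of data files must be a multiple of the number of hosts")
    return files[host_index::num_hosts]


def with_padding_batches(batches, target_steps, padding_batch):
    """Yield real batches, then padding batches until `target_steps` are reached.

    Mimics the idea behind `generate_padding_batch_train`: a host that exhausts
    its shard early keeps emitting padding batches so no host hangs.
    """
    steps = 0
    for batch in batches:
        if steps == target_steps:
            return
        yield batch
        steps += 1
    while steps < target_steps:
        yield padding_batch
        steps += 1


files = [f"shard_{i:02d}.arrayrecord" for i in range(8)]
host0_files = shard_files(files, 0, 4)  # ['shard_00.arrayrecord', 'shard_04.arrayrecord']
```

With 8 files and 4 hosts, every host reads exactly two files and no two hosts touch the same file; with a file count that is not a multiple of the host count, the sketch fails fast rather than sharding unevenly.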

docs/guides/data_input_pipeline/data_input_grain.md (3 additions, 3 deletions)
````diff
@@ -32,9 +32,9 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state
 
 ## Using Grain
 
-1. Grain currently supports two data formats: [ArrayRecord](https://github.com/google/array_record) (random access) and [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
+1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random access through row groups), and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
    - **Community Resource**: The MaxText community has created [ArrayRecord documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not yet been verified by the MaxText or ArrayRecord developers.
-2. When the dataset is hosted on a Cloud Storage bucket, Grain can read it through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.
+2. If the dataset is hosted on a Cloud Storage bucket, a `gs://` path can be provided directly. However, for the best performance, it is recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, since metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.
 
 ```sh
 bash tools/setup/setup_gcsfuse.sh \
@@ -45,7 +45,7 @@ MOUNT_PATH=${MOUNT_PATH?} \
 
 Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)).
 
-1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.
+1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.
 
 2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
 
````
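Put together, the steps in this guide reduce to a handful of `base.yml` overrides (or the equivalent command-line flags). A minimal sketch; the mount path and file pattern below are hypothetical placeholders:

```yaml
dataset_type: grain
grain_file_type: tfrecord
grain_train_files: "/tmp/gcsfuse/mybucket/train-*.tfrecord"  # hypothetical mounted path
grain_worker_count: 4
```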

src/maxtext/configs/base.yml (1 addition, 0 deletions)
```diff
@@ -679,6 +679,7 @@ grain_ram_budget_mb: 1024 # RAM budget (MB) for auto-tuning worker count. Only u
 grain_num_threads_eval: 16
 grain_prefetch_buffer_size_eval: 500
 grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.
+grain_shuffle_buffer_size: 100 # Shuffle buffer size when using sequential access formats such as Parquet and TFRecord.
 # for using pathways
 colocated_python_data_input: False # experimental feature, under testing
```
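The new `grain_shuffle_buffer_size` option sizes a windowed shuffle for sequential-access formats, which cannot be globally shuffled by index. A minimal sketch of how such a buffer behaves (illustrative only, not Grain's implementation; `buffered_shuffle` is a hypothetical name):

```python
import random


def buffered_shuffle(stream, buffer_size, seed=0):
    """Windowed shuffle for sequential-access data.

    Keeps at most `buffer_size` items in memory and emits a randomly chosen
    one as each new item arrives, then drains the buffer in random order.
    """
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) > buffer_size:
            pick = rng.randrange(len(buffer))
            buffer[pick], buffer[-1] = buffer[-1], buffer[pick]  # O(1) swap-and-pop
            yield buffer.pop()
    rng.shuffle(buffer)
    yield from buffer
```

Every record is emitted exactly once; a larger buffer approximates a global shuffle more closely at the cost of memory, which is the trade-off this config knob exposes.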

src/maxtext/configs/types.py (6 additions, 9 deletions)
```diff
@@ -1030,17 +1030,13 @@ class GrainDataset(BaseModel):
       "",
       description="Path to a JSON file specifying the mixture weights for Grain training data.",
   )
-  grain_file_type: str = Field("arrayrecord", description="File type for Grain data.")
-  grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
-  grain_per_worker_buffer_size: int = Field(
-      1,
-      description="Buffer size for each worker for Grain data loading during training.",
+  grain_file_type: str = Field(
+      "arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
   )
+  grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
+  grain_per_worker_buffer_size: int = Field(1, description="Per-worker buffer size for Grain train data loading.")
   grain_worker_count_eval: int = Field(1, description="Number of workers for Grain eval data loading.")
-  grain_per_worker_buffer_size_eval: int = Field(
-      1,
-      description="Buffer size for each worker for Grain data loading during evaluation.",
-  )
+  grain_per_worker_buffer_size_eval: int = Field(1, description="Per-worker buffer size for Grain eval data loading.")
   grain_ram_budget_mb: int = Field(1024, description="RAM budget (MB) for auto-tuning worker count.")
   grain_num_threads: int = Field(16, description="Number of threads for Grain ReadOptions during training.")
   grain_prefetch_buffer_size: int = Field(500, description="Prefetch buffer size for Grain ReadOptions during training.")
@@ -1052,6 +1048,7 @@ class GrainDataset(BaseModel):
       16,
       description="Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.",
   )
+  grain_shuffle_buffer_size: int = Field(100, description="Shuffle buffer size when using Parquet or TFRecord.")
 
 
 class FineTuning(BaseModel):
```
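These are standard pydantic `Field` declarations, so defaults and descriptions behave as usual. A trimmed, self-contained sketch assuming pydantic is installed (it mirrors only a subset of the fields in the diff, not the full MaxText class):

```python
from pydantic import BaseModel, Field


class GrainDataset(BaseModel):
    """Trimmed sketch of the config model above, for illustration only."""

    grain_file_type: str = Field(
        "arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
    )
    grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
    grain_shuffle_buffer_size: int = Field(100, description="Shuffle buffer size when using Parquet or TFRecord.")


# Override one field; the others fall back to their Field defaults.
cfg = GrainDataset(grain_file_type="tfrecord")
```

Only the overridden field changes; `grain_worker_count` and `grain_shuffle_buffer_size` keep their defaults of 1 and 100.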
