
Commit 535be22

Implement Best Fit Packing Algorithm in Grain Pipeline for Reduced Padding
Signed-off-by: bzantium <ryumin93@gmail.com>
1 parent 08216c6 commit 535be22

4 files changed: 41 additions & 5 deletions


src/MaxText/configs/base.yml

Lines changed: 2 additions & 2 deletions
@@ -604,10 +604,10 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
-grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
+grain_packing_type: 'first_fit' # 'first_fit', 'best_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
-# num_threads and prefetch_buffer_size are per-worker per-dataset.
+# num_threads and prefetch_buffer_size are per-worker per-dataset.
 # When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
 # When using parquet, grain_num_threads is the number of files to read and interleave in parallel
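Usage note: the new value is selected the same way as the existing packing types, by overriding grain_packing_type when the config is built. A minimal sketch following the pattern of the test added in this commit; the import path, config path, and extra keys are illustrative assumptions, and only grain_packing_type="best_fit" reflects the new option:

# Minimal sketch: overriding grain_packing_type when initializing the MaxText config.
# Mirrors the pyconfig.initialize(...) pattern used in the new test below; the config
# path and the other overrides shown here are illustrative, not required values.
import os
import sys

from MaxText import pyconfig  # assumed import path, as used elsewhere in the repo's tests

config = pyconfig.initialize(
    [sys.argv[0], os.path.join("src", "MaxText", "configs", "base.yml")],
    dataset_type="grain",
    grain_packing_type="best_fit",  # the option added by this commit
)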

src/MaxText/configs/types.py

Lines changed: 2 additions & 2 deletions
@@ -874,9 +874,9 @@ class DatasetGeneral(BaseModel):
       True,
       description="Whether to pack multiple short examples into a single sequence.",
   )
-  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
+  grain_packing_type: Literal["first_fit", "best_fit", "concat_then_split"] = Field(
       "first_fit",
-      description="Packing type when using Grain pipeline. 'first_fit' or 'concat_then_split'.",
+      description="Packing type when using Grain pipeline. 'first_fit', 'best_fit' or 'concat_then_split'.",
   )
   max_segments_per_seq: int = Field(
       32,
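Because the field is typed as a Literal, an unsupported packing type is rejected when the config model is validated rather than surfacing later inside the input pipeline. A standalone sketch of that behavior with a toy pydantic model; ToyDatasetGeneral is illustrative and only mirrors the field above, not the real DatasetGeneral:

# Toy pydantic model mirroring the Literal-typed field; shows how an unknown
# packing type is rejected at validation time. Illustrative only.
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class ToyDatasetGeneral(BaseModel):
  grain_packing_type: Literal["first_fit", "best_fit", "concat_then_split"] = Field(
      "first_fit",
      description="Packing type when using Grain pipeline. 'first_fit', 'best_fit' or 'concat_then_split'.",
  )


print(ToyDatasetGeneral(grain_packing_type="best_fit").grain_packing_type)  # -> best_fit

try:
  ToyDatasetGeneral(grain_packing_type="worst_fit")  # not one of the allowed literals
except ValidationError as err:
  print(f"rejected with {len(err.errors())} validation error(s)")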

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 3 additions & 1 deletion
@@ -23,7 +23,7 @@

 import jax

-from grain.experimental import pick_performance_config
+from grain.experimental import BestFitPackIterDataset, pick_performance_config
 import grain.python as grain

 from MaxText.utils import gcs_utils

@@ -246,6 +246,8 @@ def pretrain_preprocessing_pipeline(
       dataset = grain.experimental.FirstFitPackIterDataset(
           dataset, length_struct=length_struct, num_packing_bins=batch_size
       )
+    elif config.grain_packing_type == "best_fit":
+      dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
     elif config.grain_packing_type == "concat_then_split":
       if config.add_bos and hasattr(tokenizer_model, "bos_id"):
         dataset = grain.experimental.ConcatThenSplitIterDataset(
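For intuition on the padding reduction in the commit title: first-fit places each example into the first open bin with enough room, while best-fit places it into the open bin that would be left with the least unused room, so bins tend to be packed tighter before a new one is opened. The sketch below is only the generic heuristic for illustration; it is not Grain's BestFitPackIterDataset, which additionally tracks segment ids, positions, and a fixed number of packing bins:

# Generic sketch of first-fit vs. best-fit packing over token-sequence lengths.
# Illustrative only: real packing in Grain also emits segment ids / positions
# and works with a bounded number of open bins (num_packing_bins).
def pack(lengths, bin_capacity, strategy="best_fit"):
  assert all(length <= bin_capacity for length in lengths)
  bins = []  # each bin is a list of sequence lengths summing to <= bin_capacity

  for length in lengths:
    fits = [i for i, b in enumerate(bins) if sum(b) + length <= bin_capacity]
    if not fits:
      bins.append([length])  # no open bin has room: start a new packed sequence
    elif strategy == "first_fit":
      bins[fits[0]].append(length)  # first open bin with enough room
    else:  # best_fit: the open bin left with the least unused room after placement
      best = min(fits, key=lambda i: bin_capacity - sum(bins[i]) - length)
      bins[best].append(length)

  padding = sum(bin_capacity - sum(b) for b in bins)
  return bins, padding


# With capacity 1024, first_fit needs 3 bins (1072 padded tokens) for these lengths,
# while best_fit fits everything into 2 bins (48 padded tokens).
for strategy in ("first_fit", "best_fit"):
  bins, padding = pack([600, 700, 300, 400], bin_capacity=1024, strategy=strategy)
  print(strategy, "->", len(bins), "bins,", padding, "padded tokens")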

tests/grain_data_processing_test.py

Lines changed: 34 additions & 0 deletions
@@ -227,6 +227,40 @@ def test_batch_determinism(self):
     super().test_batch_determinism()


+class GrainArrayRecordBestFitPackingTest(GrainArrayRecordProcessingTest):
+  """Test grain data processing with best_fit packing strategy."""
+
+  def setUp(self):
+    super().setUp()
+    temp_dir = tempfile.gettempdir()
+    self.config = pyconfig.initialize(
+        [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        per_device_batch_size=1,
+        run_name="test",
+        mesh_axes=["data"],
+        logical_axis_rules=[["batch", "data"]],
+        data_sharding=["data"],
+        base_output_directory="gs://max-experiments/",
+        dataset_type="grain",
+        grain_train_files=os.path.join(
+            temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
+        ),
+        grain_packing_type="best_fit",  # Use best_fit packing
+        tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
+        enable_checkpointing=False,
+    )
+    self.mesh_shape_1d = (len(jax.devices()),)
+    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
+    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
+        self.config.data_sharding,
+        self.config.global_batch_size_to_load,
+        self.config.global_batch_size_to_train_on,
+        self.config.max_target_length,
+        self.mesh,
+    )
+    self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
+
+
 class GrainParquetProcessingTest(unittest.TestCase):

   @classmethod
