
Commit 33209a8

Reverts 885cd56
PiperOrigin-RevId: 864609609
1 parent 885cd56 · commit 33209a8

39 files changed

Lines changed: 170 additions & 158 deletions
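
For context, this revert restores the original import convention across the input pipeline: the capitalized MaxText package name, underscore-prefixed (private) helper modules, and top-level multihost_dataloading and tokenizer modules. A minimal before/after sketch assembled from the import hunks below (the grouping into one snippet is illustrative):

# State introduced by 885cd56 (now removed):
# from maxtext.input_pipeline import input_pipeline_utils
# from maxtext.input_pipeline import grain_tokenizer
# from maxtext.input_pipeline import multihost_dataloading
# from maxtext.input_pipeline import tokenizer

# State restored by this commit:
from MaxText.input_pipeline import _input_pipeline_utils
from MaxText.input_pipeline import _grain_tokenizer
from MaxText import multihost_dataloading
from MaxText import tokenizer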

src/MaxText/experimental/rl/grpo_input_pipeline.py

Lines changed: 5 additions & 5 deletions
@@ -34,8 +34,8 @@
 
 import grain.python as grain
 
-from maxtext.input_pipeline import input_pipeline_interface
-from maxtext.input_pipeline import input_pipeline_utils
+from MaxText.input_pipeline import input_pipeline_interface
+from MaxText.input_pipeline import _input_pipeline_utils
 
 
 class SingleHostDataLoader:
@@ -143,7 +143,7 @@ def preprocessing_pipeline(
   )
 
   dataset = dataset.map(
-      input_pipeline_utils.tokenization,
+      _input_pipeline_utils.tokenization,
       batched=True,
       fn_kwargs={
           "hf_tokenizer": tokenizer,
@@ -153,7 +153,7 @@ def preprocessing_pipeline(
       },
   )
   dataset = dataset.select_columns(data_column_names)
-  dataset = input_pipeline_utils.HFDataSource(
+  dataset = _input_pipeline_utils.HFDataSource(
       dataset,
       dataloading_host_index,
       dataloading_host_count,
@@ -168,7 +168,7 @@ def lists2array(x):
 
   operations = [
       grain.MapOperation(lists2array),
-      input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
+      _input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True),
      grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder),
   ]
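
For orientation, an operations list like the one above is normally handed to a grain.DataLoader together with an index sampler; the hf_data_processing.py diff further down builds exactly such a grain.IndexSampler. A minimal wiring sketch, assuming dataset, operations, and the dataloading_host_* variables are the ones from this pipeline (worker_count and seed are placeholder values):

import grain.python as grain

# Hypothetical wiring sketch: shard the index space across hosts, then let
# the DataLoader pull records and run the operations chain defined above.
index_sampler = grain.IndexSampler(
    num_records=len(dataset),
    num_epochs=1,
    shard_options=grain.ShardOptions(
        shard_index=dataloading_host_index,
        shard_count=dataloading_host_count,
        drop_remainder=False,
    ),
    shuffle=False,
    seed=0,  # placeholder
)
dataloader = grain.DataLoader(
    data_source=dataset,
    sampler=index_sampler,
    operations=operations,
    worker_count=1,  # placeholder
)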

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023-2026 Google LLC
+# Copyright 2023-2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

src/maxtext/input_pipeline/distillation_data_processing.py renamed to src/MaxText/input_pipeline/_distillation_data_processing.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 
 import datasets
 
-from maxtext.input_pipeline import input_pipeline_utils
+from MaxText.input_pipeline import _input_pipeline_utils
 from maxtext.utils import max_logging
 
 
@@ -85,7 +85,7 @@ def process_dataset(config, dataset): # pylint: disable=redefined-outer-name
   assert any(
       set(data_column_names) == set(supported) for supported in supported_columns
   ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}"
-  assert input_pipeline_utils.is_conversational(
+  assert _input_pipeline_utils.is_conversational(
       dataset.features, data_column_names
   ), "Dataset is not in conversational format."
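
The is_conversational assertion checks the column shape rather than the text itself: per the dataset_features spec in the hf_data_processing.py hunk at @@ -237,15 +237,15 @@ below, a conversational column holds a list of {role, content} string records. A toy row in that shape for illustration (the column name "messages" is hypothetical):

example_row = {
    "messages": [
        {"role": "user", "content": "Summarize the change."},
        {"role": "assistant", "content": "It reverts the package rename."},
    ]
}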

src/maxtext/input_pipeline/grain_data_processing.py renamed to src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 17 additions & 17 deletions
@@ -26,10 +26,10 @@
 from grain.experimental import BestFitPackIterDataset, pick_performance_config
 import grain.python as grain
 
-from maxtext.input_pipeline import input_pipeline_utils
-from maxtext.input_pipeline import grain_tokenizer
-from maxtext.input_pipeline import multihost_dataloading
-from maxtext.input_pipeline import tokenizer
+from MaxText.input_pipeline import _input_pipeline_utils
+from MaxText.input_pipeline import _grain_tokenizer
+from MaxText import multihost_dataloading
+from MaxText import tokenizer
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_logging
 
@@ -199,10 +199,10 @@ def pretrain_preprocessing_pipeline(
 ):
   """Use grain pipeline to pre-process the dataset and return iterators for pretrain"""
   if config.grain_file_type == "arrayrecord":
-    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
-    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+    dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
+    dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
   else:
-    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
+    dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns))
 
   assert len(data_columns) == 1
   text_column = data_columns[0]
@@ -224,13 +224,13 @@ def pretrain_preprocessing_pipeline(
 
   if tokenize:
     if config.use_truncation:
-      dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, tokenizer_model))
+      dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, tokenizer_model))
     else:
-      dataset = dataset.apply(grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model))
+      dataset = dataset.apply(_grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model))
 
   data_columns = ("inputs", "targets")
   rekey_dict = {col: text_column for col in data_columns}
-  dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
+  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
 
   # Pack and Batch examples.
   batch_size = config.global_batch_size_to_load // jax.process_count()
@@ -273,15 +273,15 @@ def pretrain_preprocessing_pipeline(
         "targets_position": "targets_positions",
         "inputs_position": "inputs_positions",
     }
-    dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
+    dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
   else:
-    dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
+    dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
     batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
     dataset = dataset.batch(batch_size, batch_fn=batch_fn)
 
   # Shift inputs for teacher-forced training
   dataset = dataset.map(
-      input_pipeline_utils.ShiftData(
+      _input_pipeline_utils.ShiftData(
          ignored_ids=[pad_id],
          axis=1,
      )
@@ -313,8 +313,8 @@
 ):
   """Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
   if config.grain_file_type == "arrayrecord":
-    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
-    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+    dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
+    dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
   tokenizer_model = tokenizer.build_tokenizer(
       config.tokenizer_path,
       config.tokenizer_type,
@@ -331,9 +331,9 @@
   pad_id = -1
 
   if tokenize:
-    dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))
+    dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))
 
-  dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
+  dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
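
The ShiftData step under "# Shift inputs for teacher-forced training" above implements the usual teacher-forcing offset: each position predicts the next token. A small numpy sketch of the idea (a concept illustration only, not MaxText's ShiftData; real handling of ignored_ids and packed segment boundaries is more involved, and the sketch assumes the sequence axis is the last one):

import numpy as np

def shift_for_teacher_forcing(inputs, pad_id):
  """Targets are inputs shifted one step left; the vacated slot is padded."""
  targets = np.roll(inputs, shift=-1, axis=-1)
  targets[..., -1] = pad_id  # the rolled-around first token is meaningless
  return targets

batch = np.array([[101, 7, 9, 4, 0, 0]])
print(shift_for_teacher_forcing(batch, pad_id=0))
# [[7 9 4 0 0 0]]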

src/maxtext/input_pipeline/grain_tokenizer.py renamed to src/MaxText/input_pipeline/_grain_tokenizer.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from typing import Any
 import grain.python as grain
 import numpy as np
-from maxtext.input_pipeline import tokenizer
+from MaxText import tokenizer
 
 
 @dataclasses.dataclass

src/maxtext/input_pipeline/hf_data_processing.py renamed to src/MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 25 additions & 25 deletions
@@ -26,9 +26,9 @@
 
 import numpy as np
 
-from maxtext.input_pipeline import input_pipeline_utils
-from maxtext.input_pipeline import instruction_data_processing
-from maxtext.input_pipeline import multihost_dataloading
+from MaxText.input_pipeline import _input_pipeline_utils
+from MaxText.input_pipeline import instruction_data_processing
+from MaxText import multihost_dataloading
 
 
 def _get_pad_id(tokenizer):
@@ -61,7 +61,7 @@ def vision_sft_preprocessing_pipeline(
   # If multiple image columns are provided, merge them into a single 'images' column.
   if isinstance(image_column, list):
     dataset = dataset.map(
-        input_pipeline_utils.merge_image_columns,
+        _input_pipeline_utils.merge_image_columns,
         fn_kwargs={
             "image_columns": image_column,
             "max_num_images_per_example": config.max_num_images_per_example,
@@ -75,20 +75,20 @@ def vision_sft_preprocessing_pipeline(
     dataset = dataset.rename_column(image_column, "images")
 
   dataset = dataset.map(
-      input_pipeline_utils.reformat_prompt,
+      _input_pipeline_utils.reformat_prompt,
       fn_kwargs={
           "column": text_columns[0],
           "image_placeholder": config.image_placeholder,
           "model_name": config.model_name,
       },
   )
   dataset = dataset.map(
-      input_pipeline_utils.reformat_response,
+      _input_pipeline_utils.reformat_response,
       fn_kwargs={"column": text_columns[1], "model_name": config.model_name},
   )
 
   dataset = dataset.map(
-      input_pipeline_utils.pre_process_image_sft,
+      _input_pipeline_utils.pre_process_image_sft,
       fn_kwargs={"image_column": "images", "model_name": config.model_name},
   )
 
@@ -102,7 +102,7 @@ def vision_sft_preprocessing_pipeline(
   pad_id = _get_pad_id(tokenizer)
 
   dataset = dataset.map(
-      input_pipeline_utils.tokenization,
+      _input_pipeline_utils.tokenization,
       batched=True,
       batch_size=global_batch_size,
       fn_kwargs={
@@ -113,11 +113,11 @@ def vision_sft_preprocessing_pipeline(
       },
   )
   dataset = dataset.map(
-      input_pipeline_utils.prepare_text_for_image_fusion,
+      _input_pipeline_utils.prepare_text_for_image_fusion,
       fn_kwargs={"column_name": text_columns[0], "model_name": config.model_name},
   )
 
-  dataset = input_pipeline_utils.HFDataSource(
+  dataset = _input_pipeline_utils.HFDataSource(
       dataset=dataset,
       dataloading_host_index=dataloading_host_index,
       dataloading_host_count=dataloading_host_count,
@@ -127,7 +127,7 @@ def vision_sft_preprocessing_pipeline(
   )
   operations = []
   operations.append(
-      input_pipeline_utils.SFTPromptMaskingVision(
+      _input_pipeline_utils.SFTPromptMaskingVision(
          query_column=text_columns[0],
          response_column=text_columns[1],
          max_target_length=config.max_target_length,
@@ -136,17 +136,17 @@ def vision_sft_preprocessing_pipeline(
   )
   # TODO(aireenmei, hengtaoguo): support packing
   operations.append(
-      input_pipeline_utils.PadOrTrimToMaxLength(
+      _input_pipeline_utils.PadOrTrimToMaxLength(
          config.max_target_length,
          pad_id,
          model_name=config.model_name,
          max_num_images_per_example=config.max_num_images_per_example,
      )
   )
-  operations.append(input_pipeline_utils.ExtractImagesAndMasks())
+  operations.append(_input_pipeline_utils.ExtractImagesAndMasks())
   operations.append(grain.Batch(batch_size=batch_size, drop_remainder=True))
-  operations.append(input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name))
-  operations.append(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
+  operations.append(_input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name))
+  operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
   dummy_index_sampler = grain.IndexSampler(
       num_records=len(dataset),
       num_epochs=1,
@@ -227,7 +227,7 @@ def preprocessing_pipeline(
       dataset=dataset, data_columns=data_column_names, chat_template_path=chat_template_path
   )
 
-  assert input_pipeline_utils.is_conversational(
+  assert _input_pipeline_utils.is_conversational(
       dataset.features, data_column_names
   ), "Dataset is not in conversational format."
 
@@ -237,15 +237,15 @@ def preprocessing_pipeline(
       {combined_column_name: [{"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")}]}
   )
   dataset = dataset.map(
-      input_pipeline_utils.combine_columns,
+      _input_pipeline_utils.combine_columns,
       fn_kwargs={"columns": data_column_names, "data_column": combined_column_name},
       remove_columns=data_column_names,
       features=dataset_features,
   )
 
   data_column_names = list(dataset.features.keys())
   dataset = dataset.map(
-      input_pipeline_utils.apply_chat_template,
+      _input_pipeline_utils.apply_chat_template,
       fn_kwargs={"tokenizer_model": tokenizer, "data_column_name": data_column_names[0]},
   )
   else:
@@ -255,7 +255,7 @@ def preprocessing_pipeline(
 
   if tokenize:
     dataset = dataset.map(
-        input_pipeline_utils.tokenization,
+        _input_pipeline_utils.tokenization,
         batched=True,
         fn_kwargs={
            "hf_tokenizer": tokenizer,
@@ -265,7 +265,7 @@
        },
     )
 
-  dataset = input_pipeline_utils.HFDataSource(
+  dataset = _input_pipeline_utils.HFDataSource(
       dataset,
       dataloading_host_index,
       dataloading_host_count,
@@ -276,7 +276,7 @@
   operations = []
   if use_sft:
     operations.append(
-        input_pipeline_utils.SFTPromptMasking(
+        _input_pipeline_utils.SFTPromptMasking(
            text_column_name=data_column_names[0],
            completion_only=sft_train_on_completion_only,
            max_target_length=max_target_length,
@@ -293,7 +293,7 @@ def lists2array(x):
       operations.append(grain.MapOperation(lists2array))
     else:
       assert len(data_column_names) == 1
-      operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
+      operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
     data_column_names = ("inputs", "targets")
 
   if packing and not use_dpo:
@@ -308,13 +308,13 @@ def lists2array(x):
            max_sequences_per_bin=max_segments,
        )
     )
-    operations.append(input_pipeline_utils.ReformatPacking(data_column_names))
+    operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
   else:
-    operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
+    operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
   operations.append(grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder))
 
   if shift and not use_dpo:
-    operations.append(input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1))
+    operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1))
 
   # Since HuggingFace IterableDataset does not support access through index
   # Indexes generated by dummy_index_sampler is not used.
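
The SFTPromptMasking operation above is what sft_train_on_completion_only toggles: with completion_only set, prompt tokens are excluded from the loss so gradients come only from the response. A toy numpy illustration of that masking idea (a concept sketch, not the MaxText implementation; the -100 sentinel is a hypothetical choice):

import numpy as np

IGNORE_ID = -100  # hypothetical sentinel the loss function would skip

def mask_prompt_tokens(tokens, prompt_len, completion_only=True):
  """Return loss targets with prompt positions masked out."""
  targets = np.asarray(tokens).copy()
  if completion_only:
    targets[:prompt_len] = IGNORE_ID
  return targets

# Prompt occupies the first 3 positions; only the completion contributes loss.
print(mask_prompt_tokens([5, 8, 2, 11, 13, 4], prompt_len=3))
# -> masked targets: [-100, -100, -100, 11, 13, 4]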

src/maxtext/input_pipeline/input_pipeline_utils.py renamed to src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 1 addition & 1 deletion
@@ -23,8 +23,8 @@
 import grain.python as grain
 import numpy as np
 import tensorflow as tf
+from MaxText import tokenizer
 from MaxText import multimodal_utils
-from maxtext.input_pipeline import tokenizer
 from maxtext.utils import max_logging
 
 Features = dict[str, tf.Tensor]
