2626
2727import numpy as np
2828
29- from maxtext .input_pipeline import input_pipeline_utils
30- from maxtext .input_pipeline import instruction_data_processing
31- from maxtext . input_pipeline import multihost_dataloading
29+ from MaxText .input_pipeline import _input_pipeline_utils
30+ from MaxText .input_pipeline import instruction_data_processing
31+ from MaxText import multihost_dataloading
3232
3333
3434def _get_pad_id (tokenizer ):
@@ -61,7 +61,7 @@ def vision_sft_preprocessing_pipeline(
6161 # If multiple image columns are provided, merge them into a single 'images' column.
6262 if isinstance (image_column , list ):
6363 dataset = dataset .map (
64- input_pipeline_utils .merge_image_columns ,
64+ _input_pipeline_utils .merge_image_columns ,
6565 fn_kwargs = {
6666 "image_columns" : image_column ,
6767 "max_num_images_per_example" : config .max_num_images_per_example ,
@@ -75,20 +75,20 @@ def vision_sft_preprocessing_pipeline(
7575 dataset = dataset .rename_column (image_column , "images" )
7676
7777 dataset = dataset .map (
78- input_pipeline_utils .reformat_prompt ,
78+ _input_pipeline_utils .reformat_prompt ,
7979 fn_kwargs = {
8080 "column" : text_columns [0 ],
8181 "image_placeholder" : config .image_placeholder ,
8282 "model_name" : config .model_name ,
8383 },
8484 )
8585 dataset = dataset .map (
86- input_pipeline_utils .reformat_response ,
86+ _input_pipeline_utils .reformat_response ,
8787 fn_kwargs = {"column" : text_columns [1 ], "model_name" : config .model_name },
8888 )
8989
9090 dataset = dataset .map (
91- input_pipeline_utils .pre_process_image_sft ,
91+ _input_pipeline_utils .pre_process_image_sft ,
9292 fn_kwargs = {"image_column" : "images" , "model_name" : config .model_name },
9393 )
9494
@@ -102,7 +102,7 @@ def vision_sft_preprocessing_pipeline(
102102 pad_id = _get_pad_id (tokenizer )
103103
104104 dataset = dataset .map (
105- input_pipeline_utils .tokenization ,
105+ _input_pipeline_utils .tokenization ,
106106 batched = True ,
107107 batch_size = global_batch_size ,
108108 fn_kwargs = {
@@ -113,11 +113,11 @@ def vision_sft_preprocessing_pipeline(
113113 },
114114 )
115115 dataset = dataset .map (
116- input_pipeline_utils .prepare_text_for_image_fusion ,
116+ _input_pipeline_utils .prepare_text_for_image_fusion ,
117117 fn_kwargs = {"column_name" : text_columns [0 ], "model_name" : config .model_name },
118118 )
119119
120- dataset = input_pipeline_utils .HFDataSource (
120+ dataset = _input_pipeline_utils .HFDataSource (
121121 dataset = dataset ,
122122 dataloading_host_index = dataloading_host_index ,
123123 dataloading_host_count = dataloading_host_count ,
@@ -127,7 +127,7 @@ def vision_sft_preprocessing_pipeline(
127127 )
128128 operations = []
129129 operations .append (
130- input_pipeline_utils .SFTPromptMaskingVision (
130+ _input_pipeline_utils .SFTPromptMaskingVision (
131131 query_column = text_columns [0 ],
132132 response_column = text_columns [1 ],
133133 max_target_length = config .max_target_length ,
@@ -136,17 +136,17 @@ def vision_sft_preprocessing_pipeline(
136136 )
137137 # TODO(aireenmei, hengtaoguo): support packing
138138 operations .append (
139- input_pipeline_utils .PadOrTrimToMaxLength (
139+ _input_pipeline_utils .PadOrTrimToMaxLength (
140140 config .max_target_length ,
141141 pad_id ,
142142 model_name = config .model_name ,
143143 max_num_images_per_example = config .max_num_images_per_example ,
144144 )
145145 )
146- operations .append (input_pipeline_utils .ExtractImagesAndMasks ())
146+ operations .append (_input_pipeline_utils .ExtractImagesAndMasks ())
147147 operations .append (grain .Batch (batch_size = batch_size , drop_remainder = True ))
148- operations .append (input_pipeline_utils .FoldImagesIntoBatch (model_name = config .model_name ))
149- operations .append (input_pipeline_utils .ShiftData (ignored_ids = [pad_id ], axis = 1 ))
148+ operations .append (_input_pipeline_utils .FoldImagesIntoBatch (model_name = config .model_name ))
149+ operations .append (_input_pipeline_utils .ShiftData (ignored_ids = [pad_id ], axis = 1 ))
150150 dummy_index_sampler = grain .IndexSampler (
151151 num_records = len (dataset ),
152152 num_epochs = 1 ,
@@ -227,7 +227,7 @@ def preprocessing_pipeline(
227227 dataset = dataset , data_columns = data_column_names , chat_template_path = chat_template_path
228228 )
229229
230- assert input_pipeline_utils .is_conversational (
230+ assert _input_pipeline_utils .is_conversational (
231231 dataset .features , data_column_names
232232 ), "Dataset is not in conversational format."
233233
@@ -237,15 +237,15 @@ def preprocessing_pipeline(
237237 {combined_column_name : [{"content" : datasets .Value (dtype = "string" ), "role" : datasets .Value (dtype = "string" )}]}
238238 )
239239 dataset = dataset .map (
240- input_pipeline_utils .combine_columns ,
240+ _input_pipeline_utils .combine_columns ,
241241 fn_kwargs = {"columns" : data_column_names , "data_column" : combined_column_name },
242242 remove_columns = data_column_names ,
243243 features = dataset_features ,
244244 )
245245
246246 data_column_names = list (dataset .features .keys ())
247247 dataset = dataset .map (
248- input_pipeline_utils .apply_chat_template ,
248+ _input_pipeline_utils .apply_chat_template ,
249249 fn_kwargs = {"tokenizer_model" : tokenizer , "data_column_name" : data_column_names [0 ]},
250250 )
251251 else :
@@ -255,7 +255,7 @@ def preprocessing_pipeline(
255255
256256 if tokenize :
257257 dataset = dataset .map (
258- input_pipeline_utils .tokenization ,
258+ _input_pipeline_utils .tokenization ,
259259 batched = True ,
260260 fn_kwargs = {
261261 "hf_tokenizer" : tokenizer ,
@@ -265,7 +265,7 @@ def preprocessing_pipeline(
265265 },
266266 )
267267
268- dataset = input_pipeline_utils .HFDataSource (
268+ dataset = _input_pipeline_utils .HFDataSource (
269269 dataset ,
270270 dataloading_host_index ,
271271 dataloading_host_count ,
@@ -276,7 +276,7 @@ def preprocessing_pipeline(
276276 operations = []
277277 if use_sft :
278278 operations .append (
279- input_pipeline_utils .SFTPromptMasking (
279+ _input_pipeline_utils .SFTPromptMasking (
280280 text_column_name = data_column_names [0 ],
281281 completion_only = sft_train_on_completion_only ,
282282 max_target_length = max_target_length ,
@@ -293,7 +293,7 @@ def lists2array(x):
293293 operations .append (grain .MapOperation (lists2array ))
294294 else :
295295 assert len (data_column_names ) == 1
296- operations .append (input_pipeline_utils .HFNormalizeFeatures (data_column_names [0 ]))
296+ operations .append (_input_pipeline_utils .HFNormalizeFeatures (data_column_names [0 ]))
297297 data_column_names = ("inputs" , "targets" )
298298
299299 if packing and not use_dpo :
@@ -308,13 +308,13 @@ def lists2array(x):
308308 max_sequences_per_bin = max_segments ,
309309 )
310310 )
311- operations .append (input_pipeline_utils .ReformatPacking (data_column_names ))
311+ operations .append (_input_pipeline_utils .ReformatPacking (data_column_names ))
312312 else :
313- operations .append (input_pipeline_utils .PadOrTrimToMaxLength (max_target_length , pad_id ))
313+ operations .append (_input_pipeline_utils .PadOrTrimToMaxLength (max_target_length , pad_id ))
314314 operations .append (grain .Batch (batch_size = global_batch_size // jax .process_count (), drop_remainder = drop_remainder ))
315315
316316 if shift and not use_dpo :
317- operations .append (input_pipeline_utils .ShiftData (ignored_ids = [pad_id , tokenizer .bos_token_id ], axis = 1 ))
317+ operations .append (_input_pipeline_utils .ShiftData (ignored_ids = [pad_id , tokenizer .bos_token_id ], axis = 1 ))
318318
319319 # Since HuggingFace IterableDataset does not support access through index,
320320 # the indexes generated by dummy_index_sampler are not used.
0 commit comments