3232
3333import grain .python as grain
3434
35- from MaxText .input_pipeline import input_pipeline_interface
36- from MaxText .input_pipeline import _input_pipeline_utils
35+ from maxtext .input_pipeline import input_pipeline_interface
36+ from maxtext .input_pipeline import input_pipeline_utils
3737
3838
3939class SingleHostDataLoader :
@@ -141,7 +141,7 @@ def preprocessing_pipeline(
141141 )
142142
143143 dataset = dataset .map (
144- _input_pipeline_utils .tokenization ,
144+ input_pipeline_utils .tokenization ,
145145 batched = True ,
146146 fn_kwargs = {
147147 "hf_tokenizer" : tokenizer ,
@@ -151,7 +151,7 @@ def preprocessing_pipeline(
151151 },
152152 )
153153 dataset = dataset .select_columns (data_column_names )
154- dataset = _input_pipeline_utils .HFDataSource (
154+ dataset = input_pipeline_utils .HFDataSource (
155155 dataset ,
156156 dataloading_host_index ,
157157 dataloading_host_count ,
@@ -166,7 +166,7 @@ def lists2array(x):
166166
167167 operations = [
168168 grain .MapOperation (lists2array ),
169- _input_pipeline_utils .PadOrTrimToMaxLength (max_target_length , add_true_length = True ),
169+ input_pipeline_utils .PadOrTrimToMaxLength (max_target_length , add_true_length = True ),
170170 grain .Batch (batch_size = global_batch_size // jax .process_count (), drop_remainder = drop_remainder ),
171171 ]
172172
0 commit comments