3434
3535import grain .python as grain
3636
37- from MaxText .input_pipeline import input_pipeline_interface
38- from MaxText .input_pipeline import _input_pipeline_utils
37+ from maxtext .input_pipeline import input_pipeline_interface
38+ from maxtext .input_pipeline import input_pipeline_utils
3939
4040
4141class SingleHostDataLoader :
@@ -143,7 +143,7 @@ def preprocessing_pipeline(
143143 )
144144
145145 dataset = dataset .map (
146- _input_pipeline_utils .tokenization ,
146+ input_pipeline_utils .tokenization ,
147147 batched = True ,
148148 fn_kwargs = {
149149 "hf_tokenizer" : tokenizer ,
@@ -153,7 +153,7 @@ def preprocessing_pipeline(
153153 },
154154 )
155155 dataset = dataset .select_columns (data_column_names )
156- dataset = _input_pipeline_utils .HFDataSource (
156+ dataset = input_pipeline_utils .HFDataSource (
157157 dataset ,
158158 dataloading_host_index ,
159159 dataloading_host_count ,
@@ -168,7 +168,7 @@ def lists2array(x):
168168
169169 operations = [
170170 grain .MapOperation (lists2array ),
171- _input_pipeline_utils .PadOrTrimToMaxLength (max_target_length , add_true_length = True ),
171+ input_pipeline_utils .PadOrTrimToMaxLength (max_target_length , add_true_length = True ),
172172 grain .Batch (batch_size = global_batch_size // jax .process_count (), drop_remainder = drop_remainder ),
173173 ]
174174
0 commit comments