@@ -78,9 +78,18 @@ def make_tf_iterator(
7878 train_iter = multihost_dataloading .MultiHostDataLoadIterator (train_ds , mesh )
7979 return train_iter
8080
81+
8182# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
8283def _make_tfrecord_iterator (
83- config , dataloading_host_index , dataloading_host_count , mesh , global_batch_size , feature_description_fn , prepare_sample_fn , dataset_path , is_training : bool
84+ config ,
85+ dataloading_host_index ,
86+ dataloading_host_count ,
87+ mesh ,
88+ global_batch_size ,
89+ feature_description_fn ,
90+ prepare_sample_fn ,
91+ dataset_path ,
92+ is_training : bool ,
8493):
8594 # Set load_tfrecord_cached to True in config to use a pre-processed tfrecord dataset.
8695 # Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf-preprocessed dataset to tfrecord.
@@ -93,10 +102,10 @@ def _make_tfrecord_iterator(
93102 # Determine whether to use the "cached" dataset, which requires externally
94103 # provided parsing functions, or the default one with its internal parsing logic.
95104 make_cached_tfrecord_iterator = (
96- config .cache_latents_text_encoder_outputs
97- and is_dataset_dir_valid
98- and "load_tfrecord_cached" in config .get_keys ()
99- and config .load_tfrecord_cached
105+ config .cache_latents_text_encoder_outputs
106+ and is_dataset_dir_valid
107+ and "load_tfrecord_cached" in config .get_keys ()
108+ and config .load_tfrecord_cached
100109 )
101110
102111 feature_description = {
@@ -121,42 +130,47 @@ def prepare_sample(features):
121130 if not is_training :
122131 num_eval_samples = 0
123132 for _ in ds :
124- num_eval_samples += 1
133+ num_eval_samples += 1
125134
126135 remainder = num_eval_samples % global_batch_size
127136 if remainder != 0 :
128- num_to_pad = global_batch_size - remainder
129- # Create a dataset of padding samples from the beginning
130- padding_ds = ds .take (num_to_pad )
131- # Add the padding samples to the end
132- ds = ds .concatenate (padding_ds )
133- max_logging .log (f"Padded evaluation dataset with { num_to_pad } samples." )
137+ num_to_pad = global_batch_size - remainder
138+ # Create a dataset of padding samples from the beginning
139+ padding_ds = ds .take (num_to_pad )
140+ # Add the padding samples to the end
141+ ds = ds .concatenate (padding_ds )
142+ max_logging .log (f"Padded evaluation dataset with { num_to_pad } samples." )
134143
135144 used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
136145 ds = (
137- ds .shard (num_shards = dataloading_host_count , index = dataloading_host_index )
138- .map (_parse_tfrecord_fn , num_parallel_calls = AUTOTUNE )
139- .map (used_prepare_sample , num_parallel_calls = AUTOTUNE )
146+ ds .shard (num_shards = dataloading_host_count , index = dataloading_host_index )
147+ .map (_parse_tfrecord_fn , num_parallel_calls = AUTOTUNE )
148+ .map (used_prepare_sample , num_parallel_calls = AUTOTUNE )
140149 )
141150 if is_training :
142151 ds = (
143- ds .shuffle (global_batch_size * 10 )
144- .batch (global_batch_size // dataloading_host_count , drop_remainder = True )
145- .repeat (- 1 )
146- .prefetch (AUTOTUNE )
152+ ds .shuffle (global_batch_size * 10 )
153+ .batch (global_batch_size // dataloading_host_count , drop_remainder = True )
154+ .repeat (- 1 )
155+ .prefetch (AUTOTUNE )
147156 )
148157 # For Evaluation
149158 else :
150- ds = (
151- ds .batch (global_batch_size // dataloading_host_count , drop_remainder = False )
152- .prefetch (AUTOTUNE )
153- )
159+ ds = ds .batch (global_batch_size // dataloading_host_count , drop_remainder = False ).prefetch (AUTOTUNE )
154160
155161 iter = multihost_dataloading .MultiHostDataLoadIterator (ds , mesh )
156162 return iter
157163
164+
158165def make_tfrecord_iterator (
159- config , dataloading_host_index , dataloading_host_count , mesh , global_batch_size , feature_description , prepare_sample_fn , is_training
166+ config ,
167+ dataloading_host_index ,
168+ dataloading_host_count ,
169+ mesh ,
170+ global_batch_size ,
171+ feature_description ,
172+ prepare_sample_fn ,
173+ is_training ,
160174):
161175 """Iterator for TFRecord format. For Laion dataset,
162176 check out preparation script
@@ -165,4 +179,14 @@ def make_tfrecord_iterator(
166179 # Currently, evaluation is only supported for the tfrecord format. To avoid changing existing behavior,
166180 # check whether this is the training dataset before selecting the data directory.
166181 # TODO: refactor to support evaluation on all dataset formats.
167181 dataset_path = config .train_data_dir if is_training else config .eval_data_dir
168- return _make_tfrecord_iterator (config , dataloading_host_index , dataloading_host_count , mesh , global_batch_size , feature_description , prepare_sample_fn , dataset_path , is_training )
182+ return _make_tfrecord_iterator (
183+ config ,
184+ dataloading_host_index ,
185+ dataloading_host_count ,
186+ mesh ,
187+ global_batch_size ,
188+ feature_description ,
189+ prepare_sample_fn ,
190+ dataset_path ,
191+ is_training ,
192+ )
0 commit comments