@@ -78,105 +78,49 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
-def make_cached_tfrecord_iterator(
-    dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training: bool
-):
-  """
-  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
-  latents, input_ids, prompt_embeds, and text_embeds.
-  """
-
-  def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
-
-  # This pipeline reads the sharded files and applies the parsing and preparation.
-  filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
-  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
-
-  # --- PADDING LOGIC FOR EVALUATION ---
-  if not is_training:
-    num_eval_samples = 0
-    for _ in ds:
-      num_eval_samples += 1
-
-    remainder = num_eval_samples % global_batch_size
-    if remainder != 0:
-      num_to_pad = global_batch_size - remainder
-      # Create a dataset of padding samples from the beginning
-      padding_ds = ds.take(num_to_pad)
-      # Add the padding samples to the end
-      ds = ds.concatenate(padding_ds)
-      print(f"Padded evaluation dataset with {num_to_pad} samples.")
-
-  ds = (
-      ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
-      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
-  )
-  if is_training:
-    ds = (
-        ds.shuffle(global_batch_size * 10)
-        .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
-        .repeat(-1)
-        .prefetch(AUTOTUNE)
-    )
-  # For Evaluation
-  else:
-    ds = (
-        ds.batch(global_batch_size // dataloading_host_count, drop_remainder=False)
-        .prefetch(AUTOTUNE)
-    )
-
-  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
-  iter = multihost_dataloading.MultiHostDataLoadIterator(ds, mesh)
-  return iter
-
-
 # TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
 def _make_tfrecord_iterator(
-    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn, dataset_path, is_training: bool
+    config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description_fn, prepare_sample_fn, dataset_path, is_training: bool
 ):
   # set load_tfrecord_cached to True in config to use a pre-processed tfrecord dataset.
   # Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf preprocessed dataset to tfrecord.
   # The dataset cache in the GitHub runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.
+  # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.
 
   # checks that the dataset path is valid. In the case of GCS, the existence of the dir is not checked.
-  # if is_training is True, loads the training dataset. If False, loads the evaluation dataset.
   is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
 
-  if (
-      config.cache_latents_text_encoder_outputs
-      and is_dataset_dir_valid
-      and "load_tfrecord_cached" in config.get_keys()
-      and config.load_tfrecord_cached
-  ):
-    return make_cached_tfrecord_iterator(
-        dataloading_host_index,
-        dataloading_host_count,
-        mesh,
-        global_batch_size,
-        feature_description,
-        prepare_sample_fn,
-        dataset_path,
-        is_training
-    )
+  # Determine whether to use the "cached" dataset, which requires externally
+  # provided parsing functions, or the default one with its internal parsing logic.
+  make_cached_tfrecord_iterator = (
+      config.cache_latents_text_encoder_outputs
+      and is_dataset_dir_valid
+      and "load_tfrecord_cached" in config.get_keys()
+      and config.load_tfrecord_cached
+  )
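+  # The cached records already hold pre-computed latents and embeddings (latents, input_ids,
+  # prompt_embeds, text_embeds), so on that path the parsing logic is supplied by the caller.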
 
   feature_description = {
       "moments": tf.io.FixedLenFeature([], tf.string),
       "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
   }
 
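+  # Pick the feature spec: caller-supplied for cached records, the default above otherwise.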
+  used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
+
   def _parse_tfrecord_fn(example):
-    return tf.io.parse_single_example(example, feature_description)
+    return tf.io.parse_single_example(example, used_feature_description)
 
   def prepare_sample(features):
     moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
     clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
     return {"pixel_values": moments, "input_ids": clip_embeddings}
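+  # Inverse of the caching step: each field was presumably written via tf.io.serialize_tensor, so parse_tensor restores it.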
 
   filenames = tf.io.gfile.glob(os.path.join(dataset_path, "*"))
-
   ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+
   # --- PADDING LOGIC FOR EVALUATION ---
   if not is_training:
     num_eval_samples = 0
@@ -191,11 +135,13 @@ def prepare_sample(features):
       # Add the padding samples to the end
       ds = ds.concatenate(padding_ds)
       print(f"Padded evaluation dataset with {num_to_pad} samples.")
-
+
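+  # Pick the sample-prep fn the same way: caller-supplied for cached records, local prepare_sample otherwise.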
+  used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
   ds = (
       ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
       .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
-      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
+      .map(used_prepare_sample, num_parallel_calls=AUTOTUNE)
   )
   if is_training:
     ds = (
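
For reference, here is a standalone sketch of the serialize/parse/prepare round trip behind the default path above. The "moments"/"clip_embeddings" keys and the "pixel_values"/"input_ids" outputs come from the diff; the tensor shapes, and the assumption that the cache-conversion script wrote each field with tf.io.serialize_tensor, are illustrative and not verified against the repo.

import tensorflow as tf

def _bytes_feature(t):
  # Serialize a tensor into a bytes feature so FixedLenFeature([], tf.string) can read it back.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(t).numpy()]))

# Build one fake cached record (shapes are made up for illustration).
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "moments": _bytes_feature(tf.zeros([4, 32, 32], tf.float32)),
            "clip_embeddings": _bytes_feature(tf.zeros([77, 768], tf.float32)),
        }
    )
).SerializeToString()

# The same two stages the iterator maps over the TFRecordDataset.
feature_description = {
    "moments": tf.io.FixedLenFeature([], tf.string),
    "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
}
parsed = tf.io.parse_single_example(example, feature_description)
sample = {
    "pixel_values": tf.io.parse_tensor(parsed["moments"], out_type=tf.float32),
    "input_ids": tf.io.parse_tensor(parsed["clip_embeddings"], out_type=tf.float32),
}
print(sample["pixel_values"].shape, sample["input_ids"].shape)  # (4, 32, 32) (77, 768)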