@@ -78,6 +78,53 @@ def make_tf_iterator(
7878 train_iter = multihost_dataloading .MultiHostDataLoadIterator (train_ds , mesh )
7979 return train_iter
8080
def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """Build a multi-host iterator over TFRecords of four pre-computed features.

  Each record is expected to hold four serialized tensors:
  ``pixel_values`` (float32), ``input_ids`` (int32),
  ``prompt_embeds`` (float32), and ``text_embeds`` (float32).

  Args:
    config: run config; must provide ``train_data_dir``, the directory
      containing the TFRecord shard files.
    dataloading_host_index: index of this host among the data-loading hosts.
    dataloading_host_count: total number of data-loading hosts.
    mesh: JAX device mesh handed to the multi-host loader.
    global_batch_size: batch size summed over all hosts; each host batches
      ``global_batch_size // dataloading_host_count`` examples.

  Returns:
    A ``multihost_dataloading.MultiHostDataLoadIterator`` over the dataset.
  """
  # NOTE(review): a same-named function with a different signature (taking
  # feature_description / prepare_sample_fn) is defined later in this module;
  # in Python the later definition shadows this one — confirm which of the
  # two is intended to exist, or rename one of them.
  feature_description = {
      "pixel_values": tf.io.FixedLenFeature([], tf.string),
      "input_ids": tf.io.FixedLenFeature([], tf.string),
      "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
      "text_embeds": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    # Split the serialized Example into its raw (still string-typed) features.
    return tf.io.parse_single_example(example, feature_description)

  def _prepare_sample(features):
    # Deserialize each string feature back into a dense tensor.
    return {
        "pixel_values": tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32),
        "input_ids": tf.io.parse_tensor(features["input_ids"], out_type=tf.int32),
        "prompt_embeds": tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32),
        "text_embeds": tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32),
    }

  # This pipeline reads the sharded files and applies parsing and preparation.
  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      # Each host keeps only its 1/dataloading_host_count slice of the records.
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(_prepare_sample, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat()  # no count == repeat(-1): loop over the data forever
      .prefetch(AUTOTUNE)
  )

  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter
127+
81128
82129def make_cached_tfrecord_iterator (
83130 config , dataloading_host_index , dataloading_host_count , mesh , global_batch_size , feature_description , prepare_sample_fn
@@ -120,22 +167,12 @@ def make_tfrecord_iterator(
120167
121168 # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
122169 # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
123- # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
124- if (
125- config .cache_latents_text_encoder_outputs
170+ # Dataset cache in github runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.
171+ if (config .cache_latents_text_encoder_outputs
126172 and os .path .isdir (config .dataset_save_location )
127- and "load_tfrecord_cached" in config .get_keys ()
128- and config .load_tfrecord_cached
129- ):
130- return make_cached_tfrecord_iterator (
131- config ,
132- dataloading_host_index ,
133- dataloading_host_count ,
134- mesh ,
135- global_batch_size ,
136- feature_description ,
137- prepare_sample_fn ,
138- )
173+ and 'load_tfrecord_cached' in config .get_keys ()
174+ and config .load_tfrecord_cached ):
175+ return make_cached_tfrecord_iterator (config , dataloading_host_index , dataloading_host_count , mesh , global_batch_size )
139176
140177 feature_description = {
141178 "moments" : tf .io .FixedLenFeature ([], tf .string ),
0 commit comments