@@ -73,7 +73,6 @@ def make_tf_iterator(
   train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
   return train_iter
 
-
 def make_cached_tfrecord_iterator(
     config,
     dataloading_host_index,
@@ -105,6 +104,7 @@ def prepare_sample(features):
 
   # This pipeline reads the sharded files and applies the parsing and preparation.
   filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
+
   train_ds = (
       tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
       .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
@@ -133,8 +133,16 @@ def make_tfrecord_iterator(
   check out preparation script
   maxdiffusion/pedagogical_examples/to_tfrecords.py
   """
-  if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
+
+  # Set load_tfrecord_cached to True in config to use the pre-processed tfrecord dataset.
+  # Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf preprocessed dataset to tfrecord.
+  # The dataset cache in the GitHub runner test doesn't contain all the features since it is shared, so the default tfrecord iterator is used there.
+  if (config.cache_latents_text_encoder_outputs and
+      os.path.isdir(config.dataset_save_location) and
+      hasattr(config, 'load_tfrecord_cached') and
+      config.load_tfrecord_cached):
     return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
+
   feature_description = {
       "moments": tf.io.FixedLenFeature([], tf.string),
       "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
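The second hunk's pipeline shards the globbed TFRecord files across dataloading hosts. Below is a minimal sketch of that sharding pattern, not part of this commit: build_host_shard and its arguments are hypothetical names for illustration, and only TensorFlow's public tf.data API is assumed.

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

def build_host_shard(filenames, host_index, host_count):
  # Read record files in parallel, then keep only this host's slice.
  # Calling .shard() before any shuffle/batch keeps the per-host slices
  # disjoint and deterministic.
  ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
  return ds.shard(num_shards=host_count, index=host_index)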
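The third hunk gates the cached-iterator path on three config conditions. A minimal sketch of the same gate, not part of this commit: the SimpleNamespace stand-in for config is hypothetical, and getattr with a default is shown as an equivalent to the diff's hasattr-plus-truthiness pair.

import os
from types import SimpleNamespace

def use_cached_tfrecords(config):
  # Equivalent to: hasattr(config, 'load_tfrecord_cached') and config.load_tfrecord_cached
  return (config.cache_latents_text_encoder_outputs
          and os.path.isdir(config.dataset_save_location)
          and getattr(config, "load_tfrecord_cached", False))

config = SimpleNamespace(  # hypothetical config values for illustration
    cache_latents_text_encoder_outputs=True,
    dataset_save_location="/tmp",
    load_tfrecord_cached=True,
)
print(use_cached_tfrecords(config))  # True only when all three conditions hold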