|
19 | 19 | import tensorflow.experimental.numpy as tnp |
20 | 20 | from datasets import load_dataset, load_from_disk |
21 | 21 |
|
22 | | -from maxdiffusion import multihost_dataloading |
| 22 | +from maxdiffusion import multihost_dataloading, max_logging |
23 | 23 |
|
24 | 24 | AUTOTUNE = tf.data.AUTOTUNE |
25 | 25 |
|
@@ -73,7 +73,6 @@ def make_tf_iterator( |
73 | 73 | train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh) |
74 | 74 | return train_iter |
75 | 75 |
|
76 | | - |
77 | 76 | def make_cached_tfrecord_iterator( |
78 | 77 | config, |
79 | 78 | dataloading_host_index, |
@@ -105,6 +104,7 @@ def prepare_sample(features): |
105 | 104 |
|
106 | 105 | # This pipeline reads the sharded files and applies the parsing and preparation. |
107 | 106 | filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*")) |
| 107 | + |
108 | 108 | train_ds = ( |
109 | 109 | tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE) |
110 | 110 | .shard(num_shards=dataloading_host_count, index=dataloading_host_index) |
@@ -133,8 +133,13 @@ def make_tfrecord_iterator( |
133 | 133 | check out preparation script |
134 | 134 | maxdiffusion/pedagogical_examples/to_tfrecords.py |
135 | 135 | """ |
136 | | - if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): |
| 136 | + |
| 137 | + # Set load_tfrecord_cached to True in config to use a pre-processed tfrecord dataset. |
| 138 | + # Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf preprocessed dataset to tfrecord. |
| 139 | + # The dataset cache in the GitHub runner test doesn't contain all the features since it's shared, so use the default tfrecord iterator. |
| 140 | + if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location) and config.get("load_tfrecord_cached", False): |
137 | 141 | return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size) |
| 142 | + |
138 | 143 | feature_description = { |
139 | 144 | "moments": tf.io.FixedLenFeature([], tf.string), |
140 | 145 | "clip_embeddings": tf.io.FixedLenFeature([], tf.string), |
|
0 commit comments