Skip to content

Commit 7992261

Browse files
committed
Fall back to regular tfrecord iterator for datasets without all the processed features
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent f418a0c commit 7992261

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tensorflow.experimental.numpy as tnp
2020
from datasets import load_dataset, load_from_disk
2121

22-
from maxdiffusion import multihost_dataloading
22+
from maxdiffusion import multihost_dataloading, max_logging
2323

2424
AUTOTUNE = tf.data.AUTOTUNE
2525

@@ -134,7 +134,11 @@ def make_tfrecord_iterator(
134134
maxdiffusion/pedagogical_examples/to_tfrecords.py
135135
"""
136136
if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
137-
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
137+
try:
138+
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
139+
except Exception as e:
140+
max_logging.log(f"Cached tfrecord dataset doesn't contain required features, making regular tfrecord_iterator {e}")
141+
138142
feature_description = {
139143
"moments": tf.io.FixedLenFeature([], tf.string),
140144
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),

0 commit comments

Comments (0)